fastmcp 2.5.2__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,377 @@
1
+ import time
2
+ from dataclasses import dataclass
3
+ from typing import Any, TypedDict
4
+
5
+ import httpx
6
+ from authlib.jose import JsonWebKey, JsonWebToken
7
+ from authlib.jose.errors import JoseError
8
+ from cryptography.hazmat.primitives import serialization
9
+ from cryptography.hazmat.primitives.asymmetric import rsa
10
+ from mcp.server.auth.provider import (
11
+ AccessToken,
12
+ AuthorizationCode,
13
+ AuthorizationParams,
14
+ RefreshToken,
15
+ )
16
+ from mcp.shared.auth import (
17
+ OAuthClientInformationFull,
18
+ OAuthToken,
19
+ )
20
+ from pydantic import SecretStr
21
+
22
+ from fastmcp.server.auth.auth import (
23
+ ClientRegistrationOptions,
24
+ OAuthProvider,
25
+ RevocationOptions,
26
+ )
27
+
28
+
29
class JWKData(TypedDict, total=False):
    """JSON Web Key data structure.

    Declared with ``total=False``: no key is required at the type level, even
    though a well-formed JWK must carry ``kty``.

    Keys:
        kty: Key type, e.g. ``"RSA"`` (required in a real JWK).
        kid: Key ID; optional but recommended.
        use: Intended usage, e.g. ``"sig"``.
        alg: Algorithm, e.g. ``"RS256"``.
        n: Modulus (RSA keys).
        e: Public exponent (RSA keys).
        x5c: X.509 certificate chain.
        x5t: X.509 certificate thumbprint.
    """

    kty: str
    kid: str
    use: str
    alg: str
    n: str
    e: str
    x5c: list[str]
    x5t: str
40
+
41
+
42
class JWKSData(TypedDict):
    """JSON Web Key Set data structure.

    The TypedDict is total, so ``keys`` is required; it holds the individual
    JWK entries of the set.
    """

    keys: list[JWKData]
46
+
47
+
48
@dataclass(frozen=True, kw_only=True, repr=False)
class RSAKeyPair:
    """An RSA key pair in PEM form, intended for testing.

    Attributes:
        private_key: PKCS#8 PEM-encoded private key (unencrypted), wrapped in
            ``SecretStr`` so the raw value is not shown when printed.
        public_key: SubjectPublicKeyInfo PEM-encoded public key.
    """

    private_key: SecretStr
    public_key: str

    @classmethod
    def generate(cls) -> "RSAKeyPair":
        """
        Generate a fresh RSA key pair for testing.

        The key is 2048 bits with public exponent 65537.

        Returns:
            RSAKeyPair: the new key pair, with ``private_key`` as a PKCS#8 PEM
            string (inside ``SecretStr``) and ``public_key`` as a
            SubjectPublicKeyInfo PEM string.
        """
        # Generate private key
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
        )

        # Get public key
        public_key = private_key.public_key()

        # Serialize private key to PEM format (unencrypted PKCS#8)
        private_pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        ).decode("utf-8")

        # Serialize public key to PEM format
        public_pem = public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        ).decode("utf-8")

        return cls(
            private_key=SecretStr(private_pem),
            public_key=public_pem,
        )

    def create_token(
        self,
        subject: str = "fastmcp-user",
        issuer: str = "https://fastmcp.example.com",
        audience: str | None = None,
        scopes: list[str] | None = None,
        expires_in_seconds: int = 3600,
        additional_claims: dict[str, Any] | None = None,
        kid: str | None = None,
    ) -> str:
        """
        Generate an RS256-signed test JWT using this pair's private key.

        Args:
            subject: Subject claim (``sub``), usually a user ID.
            issuer: Issuer claim (``iss``).
            audience: Audience claim (``aud``); omitted when empty/None.
            scopes: Scopes to include, joined with spaces into the ``scope``
                claim; omitted when empty/None.
            expires_in_seconds: Lifetime of the token; ``exp`` is set to
                ``iat + expires_in_seconds``.
            additional_claims: Extra claims merged into the payload last, so
                they override any of the standard claims above.
            kid: Key ID placed in the JWT header for JWKS lookup; omitted when
                empty/None.

        Returns:
            The signed JWT as a string.
        """
        jwt = JsonWebToken(["RS256"])

        now = int(time.time())

        # Build payload
        payload = {
            "iss": issuer,
            "sub": subject,
            "iat": now,
            "exp": now + expires_in_seconds,
        }

        if audience:
            payload["aud"] = audience

        if scopes:
            payload["scope"] = " ".join(scopes)

        if additional_claims:
            # Applied last on purpose: callers can override standard claims.
            payload.update(additional_claims)

        # Create header
        header = {"alg": "RS256"}
        if kid:
            header["kid"] = kid

        # Sign and return token (authlib returns bytes)
        token_bytes = jwt.encode(
            header,
            payload,
            key=self.private_key.get_secret_value(),
        )
        return token_bytes.decode("utf-8")
147
+
148
+
149
class BearerAuthProvider(OAuthProvider):
    """
    Simple JWT Bearer Token validator for hosted MCP servers.

    Tokens must be signed with RS256 (asymmetric). The verification key comes
    either from a static PEM public key or from a JWKS URI, which allows the
    issuer to rotate keys.

    Note that this provider DOES NOT permit client registration or revocation,
    or any OAuth flows. It is intended to be used with a control plane that
    manages clients and tokens.
    """

    def __init__(
        self,
        public_key: str | None = None,
        jwks_uri: str | None = None,
        issuer: str | None = None,
        audience: str | None = None,
        required_scopes: list[str] | None = None,
    ):
        """
        Initialize the provider. Exactly one of public_key or jwks_uri must be provided.

        Args:
            public_key: RSA public key in PEM format (for static key)
            jwks_uri: URI to fetch keys from (for key rotation)
            issuer: Expected issuer claim (optional)
            audience: Expected audience claim (optional)
            required_scopes: List of required scopes for access (optional)

        Raises:
            ValueError: If neither or both of public_key and jwks_uri are given.
        """
        if not (public_key or jwks_uri):
            raise ValueError("Either public_key or jwks_uri must be provided")
        if public_key and jwks_uri:
            raise ValueError("Provide either public_key or jwks_uri, not both")

        super().__init__(
            issuer_url=issuer or "https://fastmcp.example.com",
            client_registration_options=ClientRegistrationOptions(enabled=False),
            revocation_options=RevocationOptions(enabled=False),
            required_scopes=required_scopes,
        )

        self.issuer = issuer
        self.audience = audience
        self.public_key = public_key
        self.jwks_uri = jwks_uri
        # The decoder is restricted to RS256.
        self.jwt = JsonWebToken(["RS256"])

        # Simple JWKS cache: kid (or "_default" for a key without a kid) -> key.
        # Values are the objects returned by authlib's get_public_key(), not
        # PEM strings.
        self._jwks_cache: dict[str, Any] = {}
        self._jwks_cache_time: float = 0
        self._cache_ttl = 3600  # 1 hour

    async def _get_verification_key(self, token: str) -> Any:
        """
        Return the key used to verify ``token``'s signature.

        With a static ``public_key``, that PEM string is returned. Otherwise
        the ``kid`` from the token's (still unverified) header selects a key
        from the JWKS endpoint. The kid is only used to pick a candidate key;
        the signature is checked afterwards by the caller.

        Raises:
            ValueError: If the token header cannot be parsed, or if no suitable
                JWKS key can be obtained.
        """
        if self.public_key:
            return self.public_key

        import base64
        import json

        # Only the header parsing is covered here, so JWKS errors raised below
        # keep their own, accurate messages.
        try:
            header_b64 = token.split(".")[0]
            # Restore base64 padding stripped by JWT encoding (0-3 chars).
            header_b64 += "=" * (-len(header_b64) % 4)
            header = json.loads(base64.urlsafe_b64decode(header_b64))
            kid = header.get("kid")
            if kid is not None and not isinstance(kid, str):
                raise TypeError(f"'kid' must be a string, got {type(kid).__name__}")
        except Exception as e:
            raise ValueError(f"Failed to extract key ID from token: {e}") from e

        return await self._get_jwks_key(kid)

    async def _get_jwks_key(self, kid: str | None) -> Any:
        """
        Fetch the key for ``kid`` from the JWKS endpoint, with simple caching.

        A fresh cache hit is returned directly. A miss, or a stale cache,
        triggers a refetch, so newly rotated keys are picked up. On fetch or
        parse failure the previous cache is kept unchanged.

        Args:
            kid: Key ID from the token header, or None/empty if absent. Without
                a kid, a key is returned only if the JWKS holds exactly one key.

        Raises:
            ValueError: If no JWKS URI is configured, the fetch/parse fails, the
                kid is unknown, or no unambiguous key exists.
        """
        if not self.jwks_uri:
            raise ValueError("JWKS URI not configured")

        current_time = time.time()

        # Check cache first
        if current_time - self._jwks_cache_time < self._cache_ttl:
            if kid and kid in self._jwks_cache:
                return self._jwks_cache[kid]
            elif not kid and len(self._jwks_cache) == 1:
                # If no kid but only one key cached, use it
                return next(iter(self._jwks_cache.values()))

        # Fetch and parse JWKS into a new mapping, then swap it in atomically so
        # a failure part-way through never leaves a half-built cache behind.
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(self.jwks_uri)
                response.raise_for_status()
                jwks_data = response.json()

            new_cache: dict[str, Any] = {}
            for key_data in jwks_data.get("keys", []):
                jwk = JsonWebKey.import_key(key_data)
                public_key = jwk.get_public_key()  # type: ignore
                # Key without kid - use a default identifier
                new_cache[key_data.get("kid") or "_default"] = public_key
        except Exception as e:
            raise ValueError(f"Failed to fetch JWKS: {e}") from e

        self._jwks_cache = new_cache
        self._jwks_cache_time = current_time

        # Select the appropriate key
        if kid:
            if kid not in self._jwks_cache:
                raise ValueError(f"Key ID '{kid}' not found in JWKS")
            return self._jwks_cache[kid]

        # No kid in token - only allow if there's exactly one key
        if len(self._jwks_cache) == 1:
            return next(iter(self._jwks_cache.values()))
        if self._jwks_cache:
            raise ValueError("Multiple keys in JWKS but no key ID (kid) in token")
        raise ValueError("No keys found in JWKS")

    async def load_access_token(self, token: str) -> AccessToken | None:
        """
        Validates the provided JWT bearer token.

        The signature is checked first. Then ``exp``, ``nbf``, and (when
        configured) ``iss`` and ``aud`` are checked.

        Args:
            token: The JWT token string to validate

        Returns:
            AccessToken object if valid, None if invalid, expired, not yet
            valid, or if the verification key cannot be obtained.
        """
        try:
            # Get verification key (static or from JWKS)
            verification_key = await self._get_verification_key(token)

            # Decode and verify the signature. authlib leaves claim validation
            # to the caller, so the checks below are done explicitly.
            claims = self.jwt.decode(token, verification_key)

            now = time.time()

            # Validate expiration. Compare against None, not truthiness, so an
            # exp of 0 counts as expired instead of "never expires".
            exp = claims.get("exp")
            if exp is not None and exp < now:
                return None

            # Validate not-before (RFC 7519 section 4.1.5)
            nbf = claims.get("nbf")
            if nbf is not None and nbf > now:
                return None

            # Validate issuer - note we use issuer instead of issuer_url here because
            # issuer is optional, allowing users to make this check optional
            if self.issuer and claims.get("iss") != self.issuer:
                return None

            # Validate audience if configured; "aud" may be a string or a list
            if self.audience:
                aud = claims.get("aud")
                audiences = aud if isinstance(aud, list) else [aud]
                if self.audience not in audiences:
                    return None

            # Extract claims - prefer client_id over sub for OAuth application identification
            client_id = claims.get("client_id") or claims.get("sub") or "unknown"

            return AccessToken(
                token=token,
                client_id=str(client_id),
                scopes=self._extract_scopes(claims),
                expires_at=int(exp) if exp is not None else None,
            )

        except JoseError:
            # Bad signature, disallowed algorithm, or malformed JWT.
            return None
        except Exception:
            # Key lookup/JWKS failures or malformed claim values: reject.
            return None

    def _extract_scopes(self, claims: dict[str, Any]) -> list[str]:
        """Extract scopes from the ``scope`` claim (space-separated string or list)."""
        scope_claim = claims.get("scope", "")
        if isinstance(scope_claim, str):
            return scope_claim.split()
        elif isinstance(scope_claim, list):
            return scope_claim
        return []

    # --- Unused OAuth server methods ---
    # This provider only validates bearer tokens; every OAuth server operation
    # below raises NotImplementedError.
    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
        raise NotImplementedError("Client management not supported")

    async def register_client(self, client_info: OAuthClientInformationFull) -> None:
        raise NotImplementedError("Client registration not supported")

    async def authorize(
        self, client: OAuthClientInformationFull, params: AuthorizationParams
    ) -> str:
        raise NotImplementedError("Authorization flow not supported")

    async def load_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: str
    ) -> AuthorizationCode | None:
        raise NotImplementedError("Authorization code flow not supported")

    async def exchange_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
    ) -> OAuthToken:
        raise NotImplementedError("Authorization code exchange not supported")

    async def load_refresh_token(
        self, client: OAuthClientInformationFull, refresh_token: str
    ) -> RefreshToken | None:
        raise NotImplementedError("Refresh token flow not supported")

    async def exchange_refresh_token(
        self,
        client: OAuthClientInformationFull,
        refresh_token: RefreshToken,
        scopes: list[str],
    ) -> OAuthToken:
        raise NotImplementedError("Refresh token exchange not supported")

    async def revoke_token(
        self,
        token: AccessToken | RefreshToken,
    ) -> None:
        raise NotImplementedError("Token revocation not supported")
@@ -0,0 +1,62 @@
1
+ from pydantic_settings import BaseSettings, SettingsConfigDict
2
+
3
+ from fastmcp.server.auth.providers.bearer import BearerAuthProvider
4
+
5
+
6
+ # Sentinel object to indicate that a setting is not set
7
+ class _NotSet:
8
+ pass
9
+
10
+
11
class EnvBearerAuthProviderSettings(BaseSettings):
    """Settings for the BearerAuthProvider, sourced from the environment.

    Each field is read from an environment variable carrying the
    ``FASTMCP_AUTH_BEARER_`` prefix (e.g. ``FASTMCP_AUTH_BEARER_ISSUER``), or
    from a ``.env`` file. Extra, unrecognized values are ignored.
    """

    model_config = SettingsConfigDict(
        extra="ignore",
        env_file=".env",
        env_prefix="FASTMCP_AUTH_BEARER_",
    )

    # Static RSA public key in PEM format (alternative to jwks_uri).
    public_key: str | None = None
    # URI of a JWKS document (alternative to public_key).
    jwks_uri: str | None = None
    # Expected "iss" claim; unchecked when None.
    issuer: str | None = None
    # Expected "aud" claim; unchecked when None.
    audience: str | None = None
    # Scopes required for access.
    required_scopes: list[str] | None = None
25
+
26
+
27
class EnvBearerAuthProvider(BearerAuthProvider):
    """
    A BearerAuthProvider that loads settings from environment variables.

    Values come from ``EnvBearerAuthProviderSettings``, which reads
    ``FASTMCP_AUTH_BEARER_*`` environment variables and a ``.env`` file. Any
    setting passed explicitly to the constructor always takes precedence over
    the environment. This includes an explicit ``None``: only arguments left
    at their default are filled from the environment.
    """

    def __init__(
        self,
        public_key: str | None | type[_NotSet] = _NotSet,
        jwks_uri: str | None | type[_NotSet] = _NotSet,
        issuer: str | None | type[_NotSet] = _NotSet,
        audience: str | None | type[_NotSet] = _NotSet,
        required_scopes: list[str] | None | type[_NotSet] = _NotSet,
    ):
        """
        Initialize the provider.

        Args:
            public_key: RSA public key in PEM format (for static key)
            jwks_uri: URI to fetch keys from (for key rotation)
            issuer: Expected issuer claim (optional)
            audience: Expected audience claim (optional)
            required_scopes: List of required scopes for access (optional)

        Raises:
            ValueError: Propagated from BearerAuthProvider if, after merging
                with the environment, neither or both of public_key and
                jwks_uri are set.
        """
        kwargs = {
            "public_key": public_key,
            "jwks_uri": jwks_uri,
            "issuer": issuer,
            "audience": audience,
            "required_scopes": required_scopes,
        }
        # Forward only explicitly supplied arguments, so that they override the
        # environment while omitted ones fall back to it.
        settings = EnvBearerAuthProviderSettings(
            **{k: v for k, v in kwargs.items() if v is not _NotSet}
        )
        super().__init__(**settings.model_dump())