fastmcp 2.12.0__py3-none-any.whl → 2.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,6 +4,7 @@ import asyncio
4
4
  import json
5
5
  import webbrowser
6
6
  from asyncio import Future
7
+ from collections.abc import AsyncGenerator
7
8
  from datetime import datetime, timedelta, timezone
8
9
  from pathlib import Path
9
10
  from typing import Any, Literal
@@ -34,6 +35,12 @@ __all__ = ["OAuth"]
34
35
  logger = get_logger(__name__)
35
36
 
36
37
 
38
+ class ClientNotFoundError(Exception):
39
+ """Raised when OAuth client credentials are not found on the server."""
40
+
41
+ pass
42
+
43
+
37
44
  class StoredToken(BaseModel):
38
45
  """Token storage format with absolute expiry time."""
39
46
 
@@ -173,7 +180,7 @@ class FileTokenStorage(TokenStorage):
173
180
  for file_type in file_types:
174
181
  path = self._get_file_path(file_type)
175
182
  path.unlink(missing_ok=True)
176
- logger.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
183
+ logger.debug(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
177
184
 
178
185
  @classmethod
179
186
  def clear_all(cls, cache_dir: Path | None = None) -> None:
@@ -300,7 +307,23 @@ class OAuth(OAuthClientProvider):
300
307
  self.context.update_token_expiry(self.context.current_tokens)
301
308
 
302
309
  async def redirect_handler(self, authorization_url: str) -> None:
303
- """Open browser for authorization."""
310
+ """Open browser for authorization, with pre-flight check for invalid client."""
311
+ # Pre-flight check to detect invalid client_id before opening browser
312
+ async with httpx.AsyncClient() as client:
313
+ response = await client.get(authorization_url, follow_redirects=False)
314
+
315
+ # Check for client not found error (400 typically means bad client_id)
316
+ if response.status_code == 400:
317
+ raise ClientNotFoundError(
318
+ "OAuth client not found - cached credentials may be stale"
319
+ )
320
+
321
+ # For any non-redirect response, something is wrong
322
+ if response.status_code not in (302, 303, 307, 308):
323
+ raise RuntimeError(
324
+ f"Unexpected authorization response: {response.status_code}"
325
+ )
326
+
304
327
  logger.info(f"OAuth authorization URL: {authorization_url}")
305
328
  webbrowser.open(authorization_url)
306
329
 
@@ -336,3 +359,56 @@ class OAuth(OAuthClientProvider):
336
359
  tg.cancel_scope.cancel()
337
360
 
338
361
  raise RuntimeError("OAuth callback handler could not be started")
362
+
363
+ async def async_auth_flow(
364
+ self, request: httpx.Request
365
+ ) -> AsyncGenerator[httpx.Request, httpx.Response]:
366
+ """HTTPX auth flow with automatic retry on stale cached credentials.
367
+
368
+ If the OAuth flow fails due to invalid/stale client credentials,
369
+ clears the cache and retries once with fresh registration.
370
+ """
371
+ try:
372
+ # First attempt with potentially cached credentials
373
+ gen = super().async_auth_flow(request)
374
+ response = None
375
+ while True:
376
+ try:
377
+ yielded_request = await gen.asend(response)
378
+ response = yield yielded_request
379
+ except StopAsyncIteration:
380
+ break
381
+
382
+ except ClientNotFoundError:
383
+ logger.debug(
384
+ "OAuth client not found on server, clearing cache and retrying..."
385
+ )
386
+
387
+ # Clear cached state and retry once
388
+ self._initialized = False
389
+
390
+ # Try to clear storage if it supports it
391
+ if hasattr(self.context.storage, "clear"):
392
+ try:
393
+ self.context.storage.clear()
394
+ except Exception as e:
395
+ logger.warning(f"Failed to clear OAuth storage cache: {e}")
396
+ # Can't retry without clearing cache, re-raise original error
397
+ raise ClientNotFoundError(
398
+ "OAuth client not found and cache could not be cleared"
399
+ ) from e
400
+ else:
401
+ logger.warning(
402
+ "Storage does not support clear() - cannot retry with fresh credentials"
403
+ )
404
+ # Can't retry without clearing cache, re-raise original error
405
+ raise
406
+
407
+ gen = super().async_auth_flow(request)
408
+ response = None
409
+ while True:
410
+ try:
411
+ yielded_request = await gen.asend(response)
412
+ response = yield yielded_request
413
+ except StopAsyncIteration:
414
+ break
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from collections.abc import Awaitable, Callable
4
- from typing import Any, Generic, TypeAlias, TypeVar
4
+ from typing import Any, Generic, TypeAlias
5
5
 
6
6
  import mcp.types
7
7
  from mcp import ClientSession
@@ -10,12 +10,13 @@ from mcp.shared.context import LifespanContextT, RequestContext
10
10
  from mcp.types import ElicitRequestParams
11
11
  from mcp.types import ElicitResult as MCPElicitResult
12
12
  from pydantic_core import to_jsonable_python
13
+ from typing_extensions import TypeVar
13
14
 
14
15
  from fastmcp.utilities.json_schema_type import json_schema_to_type
15
16
 
16
17
  __all__ = ["ElicitRequestParams", "ElicitResult", "ElicitationHandler"]
17
18
 
18
- T = TypeVar("T")
19
+ T = TypeVar("T", default=Any)
19
20
 
20
21
 
21
22
  class ElicitResult(MCPElicitResult, Generic[T]):
@@ -1,3 +0,0 @@
1
- from .openai import OpenAISamplingHandler
2
-
3
- __all__ = ["OpenAISamplingHandler"]
@@ -10,15 +10,22 @@ from mcp.types import (
10
10
  SamplingMessage,
11
11
  TextContent,
12
12
  )
13
- from openai import NOT_GIVEN, OpenAI
14
- from openai.types.chat import (
15
- ChatCompletion,
16
- ChatCompletionAssistantMessageParam,
17
- ChatCompletionMessageParam,
18
- ChatCompletionSystemMessageParam,
19
- ChatCompletionUserMessageParam,
20
- )
21
- from openai.types.shared.chat_model import ChatModel
13
+
14
+ try:
15
+ from openai import NOT_GIVEN, OpenAI
16
+ from openai.types.chat import (
17
+ ChatCompletion,
18
+ ChatCompletionAssistantMessageParam,
19
+ ChatCompletionMessageParam,
20
+ ChatCompletionSystemMessageParam,
21
+ ChatCompletionUserMessageParam,
22
+ )
23
+ from openai.types.shared.chat_model import ChatModel
24
+ except ImportError:
25
+ raise ImportError(
26
+ "The `openai` package is not installed. Please install `fastmcp[openai]` or add `openai` to your dependencies manually."
27
+ )
28
+
22
29
  from typing_extensions import override
23
30
 
24
31
  from fastmcp.experimental.sampling.handlers.base import BaseLLMSamplingHandler
@@ -1,7 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from typing import Any
4
+ from urllib.parse import urljoin
4
5
 
6
+ from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
7
+ from mcp.server.auth.middleware.bearer_auth import (
8
+ BearerAuthBackend,
9
+ RequireAuthMiddleware,
10
+ )
5
11
  from mcp.server.auth.provider import (
6
12
  AccessToken as _SDKAccessToken,
7
13
  )
@@ -22,6 +28,8 @@ from mcp.server.auth.settings import (
22
28
  RevocationOptions,
23
29
  )
24
30
  from pydantic import AnyHttpUrl
31
+ from starlette.middleware import Middleware
32
+ from starlette.middleware.authentication import AuthenticationMiddleware
25
33
  from starlette.routing import Route
26
34
 
27
35
 
@@ -40,18 +48,23 @@ class AuthProvider(TokenVerifierProtocol):
40
48
  custom authentication routes.
41
49
  """
42
50
 
43
- def __init__(self, resource_server_url: AnyHttpUrl | str | None = None):
51
+ def __init__(
52
+ self,
53
+ base_url: AnyHttpUrl | str | None = None,
54
+ required_scopes: list[str] | None = None,
55
+ ):
44
56
  """
45
57
  Initialize the auth provider.
46
58
 
47
59
  Args:
48
- resource_server_url: The URL of this resource server. This is used
49
- for RFC 8707 resource indicators, including creating the WWW-Authenticate
50
- header.
60
+ base_url: The base URL of this server (e.g., http://localhost:8000).
61
+ This is used for constructing .well-known endpoints and OAuth metadata.
62
+ required_scopes: List of OAuth scopes required for all requests.
51
63
  """
52
- if isinstance(resource_server_url, str):
53
- resource_server_url = AnyHttpUrl(resource_server_url)
54
- self.resource_server_url = resource_server_url
64
+ if isinstance(base_url, str):
65
+ base_url = AnyHttpUrl(base_url)
66
+ self.base_url = base_url
67
+ self.required_scopes = required_scopes or []
55
68
 
56
69
  async def verify_token(self, token: str) -> AccessToken | None:
57
70
  """Verify a bearer token and return access info if valid.
@@ -66,7 +79,11 @@ class AuthProvider(TokenVerifierProtocol):
66
79
  """
67
80
  raise NotImplementedError("Subclasses must implement verify_token")
68
81
 
69
- def get_routes(self) -> list[Route]:
82
+ def get_routes(
83
+ self,
84
+ mcp_path: str | None = None,
85
+ mcp_endpoint: Any | None = None,
86
+ ) -> list[Route]:
70
87
  """Get the routes for this authentication provider.
71
88
 
72
89
  Each provider is responsible for creating whatever routes it needs:
@@ -75,22 +92,63 @@ class AuthProvider(TokenVerifierProtocol):
75
92
  - OAuthProvider: full OAuth authorization server routes
76
93
  - Custom providers: whatever routes they need
77
94
 
95
+ Args:
96
+ mcp_path: The path where the MCP endpoint is mounted (e.g., "/mcp")
97
+ mcp_endpoint: The MCP endpoint handler to protect with auth
98
+
99
+ Returns:
100
+ List of routes for this provider, including protected MCP endpoints if provided
101
+ """
102
+
103
+ routes = []
104
+
105
+ # Add protected MCP endpoint if provided
106
+ if mcp_path and mcp_endpoint:
107
+ resource_metadata_url = self._get_resource_url(
108
+ "/.well-known/oauth-protected-resource"
109
+ )
110
+
111
+ routes.append(
112
+ Route(
113
+ mcp_path,
114
+ endpoint=RequireAuthMiddleware(
115
+ mcp_endpoint, self.required_scopes, resource_metadata_url
116
+ ),
117
+ )
118
+ )
119
+
120
+ return routes
121
+
122
+ def get_middleware(self) -> list:
123
+ """Get HTTP application-level middleware for this auth provider.
124
+
78
125
  Returns:
79
- List of routes for this provider
126
+ List of Starlette Middleware instances to apply to the HTTP app
80
127
  """
81
- return []
128
+ return [
129
+ Middleware(
130
+ AuthenticationMiddleware,
131
+ backend=BearerAuthBackend(self),
132
+ ),
133
+ Middleware(AuthContextMiddleware),
134
+ ]
82
135
 
83
- def get_resource_metadata_url(self) -> AnyHttpUrl | None:
84
- """Get the resource metadata URL for RFC 9728 compliance."""
85
- if self.resource_server_url is None:
136
+ def _get_resource_url(self, path: str | None = None) -> AnyHttpUrl | None:
137
+ """Get the actual resource URL being protected.
138
+
139
+ Args:
140
+ path: The path where the resource endpoint is mounted (e.g., "/mcp")
141
+
142
+ Returns:
143
+ The full URL of the protected resource
144
+ """
145
+ if self.base_url is None:
86
146
  return None
87
147
 
88
- # Add .well-known path for RFC 9728 compliance
89
- resource_metadata_url = AnyHttpUrl(
90
- str(self.resource_server_url).rstrip("/")
91
- + "/.well-known/oauth-protected-resource"
92
- )
93
- return resource_metadata_url
148
+ if path:
149
+ return AnyHttpUrl(urljoin(str(self.base_url), path))
150
+
151
+ return self.base_url
94
152
 
95
153
 
96
154
  class TokenVerifier(AuthProvider):
@@ -102,20 +160,17 @@ class TokenVerifier(AuthProvider):
102
160
 
103
161
  def __init__(
104
162
  self,
105
- resource_server_url: AnyHttpUrl | str | None = None,
163
+ base_url: AnyHttpUrl | str | None = None,
106
164
  required_scopes: list[str] | None = None,
107
165
  ):
108
166
  """
109
167
  Initialize the token verifier.
110
168
 
111
169
  Args:
112
- resource_server_url: The URL of this resource server. This is used
113
- for RFC 8707 resource indicators, including creating the WWW-Authenticate
114
- header.
170
+ base_url: The base URL of this server
115
171
  required_scopes: Scopes that are required for all requests
116
172
  """
117
- super().__init__(resource_server_url=resource_server_url)
118
- self.required_scopes = required_scopes or []
173
+ super().__init__(base_url=base_url, required_scopes=required_scopes)
119
174
 
120
175
  async def verify_token(self, token: str) -> AccessToken | None:
121
176
  """Verify a bearer token and return access info if valid."""
@@ -135,13 +190,13 @@ class RemoteAuthProvider(AuthProvider):
135
190
  the authorization servers that issue valid tokens.
136
191
  """
137
192
 
138
- resource_server_url: AnyHttpUrl
193
+ base_url: AnyHttpUrl
139
194
 
140
195
  def __init__(
141
196
  self,
142
197
  token_verifier: TokenVerifier,
143
198
  authorization_servers: list[AnyHttpUrl],
144
- resource_server_url: AnyHttpUrl | str,
199
+ base_url: AnyHttpUrl | str,
145
200
  resource_name: str | None = None,
146
201
  resource_documentation: AnyHttpUrl | None = None,
147
202
  ):
@@ -150,11 +205,14 @@ class RemoteAuthProvider(AuthProvider):
150
205
  Args:
151
206
  token_verifier: TokenVerifier instance for token validation
152
207
  authorization_servers: List of authorization servers that issue valid tokens
153
- resource_server_url: URL of this resource server. This is used
154
- for RFC 8707 resource indicators, including creating the WWW-Authenticate
155
- header.
208
+ base_url: The base URL of this server
209
+ resource_name: Optional name for the protected resource
210
+ resource_documentation: Optional documentation URL for the protected resource
156
211
  """
157
- super().__init__(resource_server_url=resource_server_url)
212
+ super().__init__(
213
+ base_url=base_url,
214
+ required_scopes=token_verifier.required_scopes,
215
+ )
158
216
  self.token_verifier = token_verifier
159
217
  self.authorization_servers = authorization_servers
160
218
  self.resource_name = resource_name
@@ -164,21 +222,34 @@ class RemoteAuthProvider(AuthProvider):
164
222
  """Verify token using the configured token verifier."""
165
223
  return await self.token_verifier.verify_token(token)
166
224
 
167
- def get_routes(self) -> list[Route]:
225
+ def get_routes(
226
+ self,
227
+ mcp_path: str | None = None,
228
+ mcp_endpoint: Any | None = None,
229
+ ) -> list[Route]:
168
230
  """Get OAuth routes for this provider.
169
231
 
170
- By default, returns only the standardized OAuth 2.0 Protected Resource routes.
171
- Subclasses can override this method to add additional routes by calling
172
- super().get_routes() and extending the returned list.
232
+ Creates protected resource metadata routes and optionally wraps MCP endpoints with auth.
173
233
  """
234
+ # Start with base routes (protected MCP endpoint)
235
+ routes = super().get_routes(mcp_path, mcp_endpoint)
236
+
237
+ # Get the resource URL based on the MCP path
238
+ resource_url = self._get_resource_url(mcp_path)
239
+
240
+ if resource_url:
241
+ # Add protected resource metadata routes
242
+ routes.extend(
243
+ create_protected_resource_routes(
244
+ resource_url=resource_url,
245
+ authorization_servers=self.authorization_servers,
246
+ scopes_supported=self.token_verifier.required_scopes,
247
+ resource_name=self.resource_name,
248
+ resource_documentation=self.resource_documentation,
249
+ )
250
+ )
174
251
 
175
- return create_protected_resource_routes(
176
- resource_url=self.resource_server_url,
177
- authorization_servers=self.authorization_servers,
178
- scopes_supported=self.token_verifier.required_scopes,
179
- resource_name=self.resource_name,
180
- resource_documentation=self.resource_documentation,
181
- )
252
+ return routes
182
253
 
183
254
 
184
255
  class OAuthProvider(
@@ -200,7 +271,6 @@ class OAuthProvider(
200
271
  client_registration_options: ClientRegistrationOptions | None = None,
201
272
  revocation_options: RevocationOptions | None = None,
202
273
  required_scopes: list[str] | None = None,
203
- resource_server_url: AnyHttpUrl | str | None = None,
204
274
  ):
205
275
  """
206
276
  Initialize the OAuth provider.
@@ -212,14 +282,13 @@ class OAuthProvider(
212
282
  client_registration_options: The client registration options.
213
283
  revocation_options: The revocation options.
214
284
  required_scopes: Scopes that are required for all requests.
215
- resource_server_url: The URL of this resource server (for RFC 8707 resource indicators, defaults to base_url)
216
285
  """
217
286
 
218
- super().__init__()
219
-
220
287
  # Convert URLs to proper types
221
288
  if isinstance(base_url, str):
222
289
  base_url = AnyHttpUrl(base_url)
290
+
291
+ super().__init__(base_url=base_url, required_scopes=required_scopes)
223
292
  self.base_url = base_url
224
293
 
225
294
  if issuer_url is None:
@@ -229,15 +298,6 @@ class OAuthProvider(
229
298
  else:
230
299
  self.issuer_url = issuer_url
231
300
 
232
- # Handle our own resource_server_url and required_scopes
233
- if resource_server_url is None:
234
- self.resource_server_url = base_url
235
- elif isinstance(resource_server_url, str):
236
- self.resource_server_url = AnyHttpUrl(resource_server_url)
237
- else:
238
- self.resource_server_url = resource_server_url
239
- self.required_scopes = required_scopes or []
240
-
241
301
  # Initialize OAuth Authorization Server Provider
242
302
  OAuthAuthorizationServerProvider.__init__(self)
243
303
 
@@ -263,12 +323,17 @@ class OAuthProvider(
263
323
  """
264
324
  return await self.load_access_token(token)
265
325
 
266
- def get_routes(self) -> list[Route]:
326
+ def get_routes(
327
+ self,
328
+ mcp_path: str | None = None,
329
+ mcp_endpoint: Any | None = None,
330
+ ) -> list[Route]:
267
331
  """Get OAuth authorization server routes and optional protected resource routes.
268
332
 
269
333
  This method creates the full set of OAuth routes including:
270
334
  - Standard OAuth authorization server routes (/.well-known/oauth-authorization-server, /authorize, /token, etc.)
271
- - Optional protected resource routes if resource_server_url is configured
335
+ - Optional protected resource routes
336
+ - Protected MCP endpoints if provided
272
337
 
273
338
  Returns:
274
339
  List of OAuth routes
@@ -283,13 +348,19 @@ class OAuthProvider(
283
348
  revocation_options=self.revocation_options,
284
349
  )
285
350
 
351
+ # Get the resource URL based on the MCP path
352
+ resource_url = self._get_resource_url(mcp_path)
353
+
286
354
  # Add protected resource routes if this server is also acting as a resource server
287
- if self.resource_server_url:
355
+ if resource_url:
288
356
  protected_routes = create_protected_resource_routes(
289
- resource_url=self.resource_server_url,
357
+ resource_url=resource_url,
290
358
  authorization_servers=[self.issuer_url],
291
359
  scopes_supported=self.required_scopes,
292
360
  )
293
361
  oauth_routes.extend(protected_routes)
294
362
 
363
+ # Add protected MCP endpoint from base class
364
+ oauth_routes.extend(super().get_routes(mcp_path, mcp_endpoint))
365
+
295
366
  return oauth_routes