fastmcp 2.5.2__py3-none-any.whl → 2.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,6 +11,7 @@ from .transports import (
11
11
  FastMCPTransport,
12
12
  StreamableHttpTransport,
13
13
  )
14
+ from .auth import OAuth, BearerAuth
14
15
 
15
16
  __all__ = [
16
17
  "Client",
@@ -24,4 +25,6 @@ __all__ = [
24
25
  "NpxStdioTransport",
25
26
  "FastMCPTransport",
26
27
  "StreamableHttpTransport",
28
+ "OAuth",
29
+ "BearerAuth",
27
30
  ]
@@ -0,0 +1,4 @@
1
+ from .bearer import BearerAuth
2
+ from .oauth import OAuth
3
+
4
# Public auth helpers re-exported by this package.
__all__ = [
    "BearerAuth",
    "OAuth",
]
@@ -0,0 +1,17 @@
1
+ import httpx
2
+ from pydantic import SecretStr
3
+
4
+ from fastmcp.utilities.logging import get_logger
5
+
6
# Only the bearer-token auth class is part of this module's public API.
__all__ = [
    "BearerAuth",
]

logger = get_logger(__name__)
9
+
10
+
11
class BearerAuth(httpx.Auth):
    """httpx auth flow that attaches a fixed bearer token to every request.

    The token is held in a pydantic ``SecretStr`` so that printing
    ``self.token`` shows a masked value; it is unwrapped only when a request
    is being sent.
    """

    def __init__(self, token: str):
        self.token = SecretStr(token)

    def auth_flow(self, request):
        """Set the ``Authorization: Bearer <token>`` header and yield *request* once."""
        raw_token = self.token.get_secret_value()
        request.headers["Authorization"] = f"Bearer {raw_token}"
        yield request
@@ -0,0 +1,391 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import json
5
+ import webbrowser
6
+ from pathlib import Path
7
+ from typing import Any, Literal
8
+ from urllib.parse import urljoin, urlparse
9
+
10
+ import anyio
11
+ import httpx
12
+ from mcp.client.auth import OAuthClientProvider as _MCPOAuthClientProvider
13
+ from mcp.client.auth import TokenStorage
14
+ from mcp.shared.auth import (
15
+ OAuthClientInformationFull,
16
+ OAuthClientMetadata,
17
+ )
18
+ from mcp.shared.auth import (
19
+ OAuthMetadata as _MCPServerOAuthMetadata,
20
+ )
21
+ from mcp.shared.auth import (
22
+ OAuthToken as OAuthToken,
23
+ )
24
+ from pydantic import AnyHttpUrl, ValidationError
25
+
26
+ from fastmcp.client.oauth_callback import (
27
+ create_oauth_callback_server,
28
+ )
29
+ from fastmcp.settings import settings as fastmcp_global_settings
30
+ from fastmcp.utilities.http import find_available_port
31
+ from fastmcp.utilities.logging import get_logger
32
+
33
# Only the OAuth provider factory is part of this module's public API.
__all__ = [
    "OAuth",
]

logger = get_logger(__name__)
36
+
37
+
38
def default_cache_dir() -> Path:
    """Return the default directory for cached OAuth client data.

    The directory is ``oauth-mcp-client-cache`` under the FastMCP home
    directory from the global settings; it is not created here.
    """
    fastmcp_home = fastmcp_global_settings.home
    return fastmcp_home / "oauth-mcp-client-cache"
40
+
41
+
42
+ # Flexible OAuth models for real-world compatibility
43
class ServerOAuthMetadata(_MCPServerOAuthMetadata):
    """
    More flexible OAuth metadata model that accepts broader ranges of values
    than the restrictive MCP standard model.

    This handles real-world OAuth servers like PayPal that may support
    additional methods not in the MCP specification.

    Every field below is re-declared with a plain ``list[str]`` element type,
    so any string value validates. NOTE(review): these presumably override
    narrower declarations of the same fields in ``mcp.shared.auth.OAuthMetadata``
    (e.g. only S256 for code challenges, per the comment below) — confirm
    against the installed mcp version.
    """

    # Allow any code challenge methods, not just S256
    code_challenge_methods_supported: list[str] | None = None

    # Allow any token endpoint auth methods
    token_endpoint_auth_methods_supported: list[str] | None = None

    # Allow any grant types
    grant_types_supported: list[str] | None = None

    # Allow any response types; when the server omits the field the model
    # falls back to ["code"]
    response_types_supported: list[str] = ["code"]

    # Allow any response modes
    response_modes_supported: list[str] | None = None
66
+
67
+
68
class OAuthClientProvider(_MCPOAuthClientProvider):
    """
    OAuth client provider with more flexible OAuth metadata discovery.

    Only metadata discovery is overridden; everything else is inherited from
    ``mcp.client.auth.OAuthClientProvider``.
    """

    async def _discover_oauth_metadata(
        self, server_url: str
    ) -> ServerOAuthMetadata | None:
        """
        Discover OAuth metadata with flexible validation.

        This is nearly identical to the parent implementation but uses
        ServerOAuthMetadata instead of the restrictive MCP OAuthMetadata.

        The RFC 8414 well-known document is requested first with the
        ``MCP-Protocol-Version`` header. If that attempt fails for any reason,
        it is retried once without the header (for CORS compatibility).

        Returns:
            The parsed metadata, or ``None`` if the server answers 404 or both
            attempts fail.
        """
        from mcp.types import LATEST_PROTOCOL_VERSION

        # Extract base URL per MCP spec
        auth_base_url = self._get_authorization_base_url(server_url)
        url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")

        async with httpx.AsyncClient() as client:
            try:
                return await self._fetch_flexible_oauth_metadata(
                    client,
                    url,
                    headers={"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION},
                    log_label="OAuth metadata discovered",
                )
            except Exception as exc:
                # Not fatal (we retry below), but record why the first attempt
                # failed instead of discarding the error silently.
                logger.debug(
                    f"OAuth metadata discovery with MCP header failed ({exc!r}); "
                    "retrying without it"
                )

            # Retry without MCP header for CORS compatibility
            try:
                return await self._fetch_flexible_oauth_metadata(
                    client,
                    url,
                    headers=None,
                    log_label="OAuth metadata discovered (no MCP header)",
                )
            except Exception:
                logger.exception("Failed to discover OAuth metadata")
                return None

    @staticmethod
    async def _fetch_flexible_oauth_metadata(
        client: httpx.AsyncClient,
        url: str,
        headers: dict[str, str] | None,
        log_label: str,
    ) -> ServerOAuthMetadata | None:
        """Perform one GET of the metadata document; None on 404, raise on other errors."""
        response = await client.get(url, headers=headers)
        if response.status_code == 404:
            return None
        response.raise_for_status()
        metadata_json = response.json()
        logger.debug(f"{log_label}: {metadata_json}")
        return ServerOAuthMetadata.model_validate(metadata_json)
114
+
115
+
116
class FileTokenStorage(TokenStorage):
    """
    File-based token storage implementation for OAuth credentials and tokens.
    Implements the mcp.client.auth.TokenStorage protocol.

    Each instance is tied to a specific server URL for proper token isolation.
    Two JSON files are kept in ``cache_dir``: ``<key>_client_info.json`` and
    ``<key>_tokens.json``. Here ``<key>`` is derived from the server's scheme
    and network location (see ``get_cache_key``).

    Files are read and written as UTF-8 and are restricted to owner read/write
    (``0o600``) because they contain OAuth secrets.
    """

    def __init__(self, server_url: str, cache_dir: Path | None = None):
        """Initialize storage for a specific server URL.

        Args:
            server_url: Server URL; only its scheme and network location are
                used to key the cache files.
            cache_dir: Directory holding the cache files. Defaults to
                ``default_cache_dir()``; created (with parents) if missing.
        """
        self.server_url = server_url
        self.cache_dir = cache_dir or default_cache_dir()
        # mode only applies when the leaf directory is newly created (and is
        # still subject to the umask); parents get default permissions.
        self.cache_dir.mkdir(mode=0o700, exist_ok=True, parents=True)

    @staticmethod
    def get_base_url(url: str) -> str:
        """Extract the base URL (scheme + host) from a URL."""
        parsed = urlparse(url)
        return f"{parsed.scheme}://{parsed.netloc}"

    def get_cache_key(self) -> str:
        """Generate a safe filesystem key from the server's base URL.

        ``://``, ``.``, ``/`` and ``:`` are all mapped to ``_``. Note that
        distinct hosts can therefore share a key (e.g. ``a.b`` and ``a_b``).
        """
        base_url = self.get_base_url(self.server_url)
        return (
            base_url.replace("://", "_")
            .replace(".", "_")
            .replace("/", "_")
            .replace(":", "_")
        )

    def _get_file_path(self, file_type: Literal["client_info", "tokens"]) -> Path:
        """Get the file path for the specified cache file type."""
        key = self.get_cache_key()
        return self.cache_dir / f"{key}_{file_type}.json"

    @staticmethod
    def _write_private(path: Path, text: str) -> None:
        """Write *text* to *path* as UTF-8, restricting the file to its owner."""
        # Create the file with owner-only permissions before any secret lands in
        # it, and tighten files that were created earlier with default modes.
        path.touch(mode=0o600, exist_ok=True)
        path.chmod(0o600)
        path.write_text(text, encoding="utf-8")

    async def get_tokens(self) -> OAuthToken | None:
        """Load tokens from file storage.

        Returns ``None`` if the file is missing or unreadable, or if it does
        not contain a valid ``OAuthToken``. Token expiry is not checked here.
        """
        path = self._get_file_path("tokens")
        try:
            return OAuthToken.model_validate_json(path.read_text(encoding="utf-8"))
        except (
            OSError,
            UnicodeDecodeError,
            json.JSONDecodeError,
            ValidationError,
        ) as e:
            logger.debug(
                f"Could not load tokens for {self.get_base_url(self.server_url)}: {e}"
            )
            return None

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Save tokens to file storage."""
        path = self._get_file_path("tokens")
        self._write_private(path, tokens.model_dump_json(indent=2))
        logger.debug(f"Saved tokens for {self.get_base_url(self.server_url)}")

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Load client information from file storage.

        If client info exists but no valid tokens do, the client-info file is
        deleted and ``None`` is returned so a fresh registration happens.
        """
        path = self._get_file_path("client_info")
        try:
            client_info = OAuthClientInformationFull.model_validate_json(
                path.read_text(encoding="utf-8")
            )
        except (
            OSError,
            UnicodeDecodeError,
            json.JSONDecodeError,
            ValidationError,
        ) as e:
            logger.debug(
                f"Could not load client info for {self.get_base_url(self.server_url)}: {e}"
            )
            return None

        # Check if we have corresponding valid tokens
        # If no tokens exist, the OAuth flow was incomplete and we should
        # force a fresh client registration
        tokens = await self.get_tokens()
        if tokens is None:
            logger.debug(
                f"No tokens found for client info at {self.get_base_url(self.server_url)}. "
                "OAuth flow may have been incomplete. Clearing client info to force fresh registration."
            )
            # Clear the incomplete client info
            path.unlink(missing_ok=True)
            return None

        return client_info

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Save client information to file storage."""
        path = self._get_file_path("client_info")
        self._write_private(path, client_info.model_dump_json(indent=2))
        logger.debug(f"Saved client info for {self.get_base_url(self.server_url)}")

    def clear(self) -> None:
        """Clear all cached data for this server."""
        file_types: list[Literal["client_info", "tokens"]] = ["client_info", "tokens"]
        for file_type in file_types:
            path = self._get_file_path(file_type)
            path.unlink(missing_ok=True)
        logger.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")

    @classmethod
    def clear_all(cls, cache_dir: Path | None = None) -> None:
        """Clear all cached data for all servers.

        Deletes every ``*_client_info.json`` and ``*_tokens.json`` file in
        *cache_dir* (default: ``default_cache_dir()``); a missing directory is
        a no-op.
        """
        cache_dir = cache_dir or default_cache_dir()
        if not cache_dir.exists():
            return

        file_types: list[Literal["client_info", "tokens"]] = ["client_info", "tokens"]
        for file_type in file_types:
            for file in cache_dir.glob(f"*_{file_type}.json"):
                file.unlink(missing_ok=True)
        logger.info("Cleared all OAuth client cache data.")
228
+
229
+
230
async def discover_oauth_metadata(
    server_base_url: str, httpx_kwargs: dict[str, Any] | None = None
) -> _MCPServerOAuthMetadata | None:
    """
    Discover OAuth metadata from the server using RFC 8414 well-known endpoint.

    The response is validated with the lenient ``ServerOAuthMetadata`` model, a
    subclass of the MCP ``OAuthMetadata``. This matches what
    ``OAuthClientProvider._discover_oauth_metadata`` does, so servers that
    advertise values outside the MCP specification are still accepted.

    Args:
        server_base_url: Base URL of the OAuth server (e.g., "https://example.com")
        httpx_kwargs: Additional kwargs for httpx client

    Returns:
        OAuth metadata if found. Otherwise None: on a 404, on any other
        non-200 status, on a request error, or when the body is not valid JSON
        or fails validation.
    """
    well_known_url = urljoin(server_base_url, "/.well-known/oauth-authorization-server")
    logger.debug(f"Discovering OAuth metadata from: {well_known_url}")

    async with httpx.AsyncClient(**(httpx_kwargs or {})) as client:
        try:
            response = await client.get(well_known_url, timeout=10.0)
            if response.status_code == 200:
                metadata = ServerOAuthMetadata.model_validate(response.json())
                # Log only after parsing/validation actually succeeded.
                logger.debug("Successfully discovered OAuth metadata")
                return metadata
            elif response.status_code == 404:
                logger.debug(
                    "OAuth metadata not found (404) - server may not require auth"
                )
                return None
            else:
                logger.warning(f"OAuth metadata request failed: {response.status_code}")
                return None
        except (httpx.RequestError, json.JSONDecodeError, ValidationError) as e:
            logger.debug(f"OAuth metadata discovery failed: {e}")
            return None
263
+
264
+
265
async def check_if_auth_required(
    mcp_url: str, httpx_kwargs: dict[str, Any] | None = None
) -> bool:
    """
    Probe the MCP endpoint with a plain GET to guess whether it needs auth.

    Args:
        mcp_url: URL of the MCP endpoint to probe.
        httpx_kwargs: Extra keyword arguments for ``httpx.AsyncClient``.

    Returns:
        True if the response status is 401/403 or the response carries a
        ``WWW-Authenticate`` header. Also True if the request itself fails
        with ``httpx.RequestError``. False for any other response, including
        non-auth error statuses.
    """
    async with httpx.AsyncClient(**(httpx_kwargs or {})) as client:
        try:
            response = await client.get(mcp_url, timeout=5.0)
        except httpx.RequestError:
            # Unreachable endpoint: conservatively assume auth might be needed.
            return True

        rejected = response.status_code in (401, 403)
        challenged = "WWW-Authenticate" in response.headers
        return rejected or challenged
293
+
294
+
295
def OAuth(
    mcp_url: str,
    scopes: str | list[str] | None = None,
    client_name: str = "FastMCP Client",
    token_storage_cache_dir: Path | None = None,
    additional_client_metadata: dict[str, Any] | None = None,
) -> _MCPOAuthClientProvider:
    """
    Create an OAuthClientProvider for an MCP server.

    This is intended to be provided to the `auth` parameter of an
    httpx.AsyncClient (or appropriate FastMCP client/transport instance).

    A free local port is chosen when this function is called. During the
    authorization flow, the authorization URL is opened in the user's browser.
    A temporary callback server on ``http://127.0.0.1:<port>/callback`` then
    waits up to 5 minutes for the redirect. Client registration and tokens are
    cached on disk per server via ``FileTokenStorage``.

    Args:
        mcp_url: Full URL to the MCP endpoint (e.g., "http://host/mcp/sse").
            Only its scheme and network location are used.
        scopes: OAuth scopes to request. Can be a space-separated string or a
            list of strings.
        client_name: Name for this client during registration.
        token_storage_cache_dir: Directory for FileTokenStorage.
        additional_client_metadata: Extra fields for OAuthClientMetadata.
            Keys that are also set here override the defaults instead of
            raising a duplicate-argument TypeError. Examples are
            ``client_name``, ``scope`` and ``token_endpoint_auth_method``.
            Do not override ``redirect_uris``: the callback server only
            listens on the locally chosen port.

    Returns:
        OAuthClientProvider
    """
    parsed_url = urlparse(mcp_url)
    server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

    # Setup OAuth client
    redirect_port = find_available_port()
    redirect_uri = f"http://127.0.0.1:{redirect_port}/callback"

    if isinstance(scopes, list):
        scopes = " ".join(scopes)

    metadata_fields: dict[str, Any] = {
        "client_name": client_name,
        "redirect_uris": [AnyHttpUrl(redirect_uri)],
        "grant_types": ["authorization_code", "refresh_token"],
        "response_types": ["code"],
        "token_endpoint_auth_method": "client_secret_post",
        "scope": scopes,
    }
    # Caller-supplied metadata takes precedence over the defaults above.
    metadata_fields.update(additional_client_metadata or {})
    client_metadata = OAuthClientMetadata(**metadata_fields)

    # Create server-specific token storage
    storage = FileTokenStorage(
        server_url=server_base_url, cache_dir=token_storage_cache_dir
    )

    # Define OAuth handlers
    async def redirect_handler(authorization_url: str) -> None:
        """Open browser for authorization."""
        logger.info(f"OAuth authorization URL: {authorization_url}")
        webbrowser.open(authorization_url)

    async def callback_handler() -> tuple[str, str | None]:
        """Handle OAuth callback and return (auth_code, state).

        Raises:
            TimeoutError: if no callback arrives within 5 minutes.
        """
        # Create a future to capture the OAuth response
        response_future = asyncio.get_running_loop().create_future()

        # Create server with the future
        server = create_oauth_callback_server(
            port=redirect_port,
            server_url=server_base_url,
            response_future=response_future,
        )

        # Run server until response is received with timeout logic
        async with anyio.create_task_group() as tg:
            tg.start_soon(server.serve)
            logger.info(
                f"🎧 OAuth callback server started on http://127.0.0.1:{redirect_port}"
            )

            TIMEOUT = 300.0  # 5 minute timeout
            try:
                with anyio.fail_after(TIMEOUT):
                    auth_code, state = await response_future
                    return auth_code, state
            except TimeoutError:
                raise TimeoutError(f"OAuth callback timed out after {TIMEOUT} seconds")
            finally:
                # Ask the server to stop, give it a moment, then cancel it.
                server.should_exit = True
                await asyncio.sleep(0.1)  # Allow server to shutdown gracefully
                tg.cancel_scope.cancel()

    # Create OAuth provider
    oauth_provider = OAuthClientProvider(
        server_url=server_base_url,
        client_metadata=client_metadata,
        storage=storage,
        redirect_handler=redirect_handler,
        callback_handler=callback_handler,
    )

    return oauth_provider
fastmcp/client/client.py CHANGED
@@ -1,9 +1,11 @@
1
+ import asyncio
1
2
  import datetime
2
3
  from contextlib import AsyncExitStack, asynccontextmanager
3
4
  from pathlib import Path
4
- from typing import Any, Generic, cast, overload
5
+ from typing import Any, Generic, Literal, cast, overload
5
6
 
6
7
  import anyio
8
+ import httpx
7
9
  import mcp.types
8
10
  from exceptiongroup import catch
9
11
  from mcp import ClientSession
@@ -43,6 +45,7 @@ from .transports import (
43
45
 
44
46
  __all__ = [
45
47
  "Client",
48
+ "SessionKwargs",
46
49
  "RootsHandler",
47
50
  "RootsList",
48
51
  "LogHandler",
@@ -142,11 +145,11 @@ class Client(Generic[ClientTransportT]):
142
145
  progress_handler: ProgressHandler | None = None,
143
146
  timeout: datetime.timedelta | float | int | None = None,
144
147
  init_timeout: datetime.timedelta | float | int | None = None,
148
+ auth: httpx.Auth | Literal["oauth"] | str | None = None,
145
149
  ):
146
150
  self.transport = cast(ClientTransportT, infer_transport(transport))
147
- self._session: ClientSession | None = None
148
- self._exit_stack: AsyncExitStack | None = None
149
- self._nesting_counter: int = 0
151
+ if auth is not None:
152
+ self.transport._set_auth(auth)
150
153
  self._initialize_result: mcp.types.InitializeResult | None = None
151
154
 
152
155
  if log_handler is None:
@@ -187,6 +190,15 @@ class Client(Generic[ClientTransportT]):
187
190
  sampling_handler
188
191
  )
189
192
 
193
+ # session context management
194
+ self._session: ClientSession | None = None
195
+ self._exit_stack: AsyncExitStack | None = None
196
+ self._nesting_counter: int = 0
197
+ self._context_lock = anyio.Lock()
198
+ self._session_task: asyncio.Task | None = None
199
+ self._ready_event = anyio.Event()
200
+ self._stop_event = anyio.Event()
201
+
190
202
  @property
191
203
  def session(self) -> ClientSession:
192
204
  """Get the current active session. Raises RuntimeError if not connected."""
@@ -237,39 +249,75 @@ class Client(Generic[ClientTransportT]):
237
249
  except TimeoutError:
238
250
  raise RuntimeError("Failed to initialize server session")
239
251
  finally:
240
- self._exit_stack = None
241
252
  self._session = None
242
253
  self._initialize_result = None
243
254
 
244
255
async def __aenter__(self):
    """Enter ``async with client``; connects (or joins a running session) and returns the client."""
    # _connect() handles nesting, so re-entering an open client is cheap.
    await self._connect()
    return self
257
258
 
258
259
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Leave one ``async with`` level; the session stops when the outermost level exits."""
    await self._disconnect(force=False)
    # Returning None never suppresses an exception raised in the ``async with`` body.
    return None
261
+
262
async def _connect(self):
    """Start the background session if needed and register one more user of it.

    The first caller starts ``_session_runner`` as a task and waits for
    ``_ready_event``. While that task is still running, later callers only
    bump the nesting counter.

    Raises:
        Exception: whatever the session runner raised while starting up (for
            example a transport/connection error). It is re-raised here so the
            caller's ``async with`` fails instead of yielding a disconnected
            client.
        RuntimeError: if the session task was cancelled or ended before the
            session could be used.
    """
    # ensure only one session is running at a time to avoid race conditions
    async with self._context_lock:
        need_to_start = self._session_task is None or self._session_task.done()
        if need_to_start:
            self._stop_event = anyio.Event()
            self._ready_event = anyio.Event()
            self._session_task = asyncio.create_task(self._session_runner())
        await self._ready_event.wait()

        # _session_runner also sets _ready_event in its `finally`, so a set
        # event alone does not prove the session is up. A finished task means
        # startup failed: surface its error now and forget the dead task, so a
        # later _disconnect()/close() does not re-raise it.
        task = self._session_task
        if task is not None and task.done():
            self._session_task = None
            if task.cancelled():
                raise RuntimeError("Client session task was cancelled during startup")
            startup_error = task.exception()
            if startup_error is not None:
                raise startup_error
            raise RuntimeError("Client session ended before it could be used")

        self._nesting_counter += 1
    return self
260
273
 
261
- if self._nesting_counter == 0:
262
- # Exit the stack which will handle cleaning up the session
263
- if self._exit_stack is not None:
264
- try:
265
- await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
266
- finally:
267
- self._exit_stack = None
274
async def _disconnect(self, force: bool = False):
    """Release one user of the session and stop the session when none remain.

    Args:
        force: If True, drop every nesting level at once and stop the session
            (used by ``close()``); otherwise decrement the nesting counter.

    Any exception raised by the session runner while shutting down propagates
    to the caller. The connection state is reset either way so the client can
    reconnect.
    """
    # ensure only one session is running at a time to avoid race conditions
    async with self._context_lock:
        if force:
            # if we are forcing a disconnect, reset the nesting counter
            self._nesting_counter = 0
        else:
            # otherwise decrement to check if we are done nesting
            self._nesting_counter = max(0, self._nesting_counter - 1)

        # if we are still nested, return
        if self._nesting_counter > 0:
            return

        # stop the active session
        runner_task = self._session_task
        if runner_task is None:
            return
        self._stop_event.set()
        self._session_task = None

        try:
            # wait for the session to finish
            await runner_task
        finally:
            # Reset for future reconnects, even if the runner raised.
            self._stop_event = anyio.Event()
            self._ready_event = anyio.Event()
            self._session = None
            self._initialize_result = None
305
+
306
async def _session_runner(self):
    """Own the client session for its whole lifetime inside a dedicated task.

    Enters ``self._context_manager()``, sets ``_ready_event`` once it is
    active, then waits for ``_stop_event`` before leaving the context.
    """
    try:
        async with self._context_manager():
            # Session/context is live: unblock _connect().
            self._ready_event.set()
            # Park here until _disconnect() requests a stop.
            await self._stop_event.wait()
    finally:
        # Also release _connect() when entering the context failed; setting an
        # already-set event is harmless.
        self._ready_event.set()
268
317
 
269
318
async def close(self):
    """Stop the session regardless of nesting depth, then close the transport.

    The transport is closed even if stopping the session raises; that
    exception still propagates afterwards.
    """
    try:
        await self._disconnect(force=True)
    finally:
        await self.transport.close()
273
321
 
274
322
  # --- MCP Client Methods ---
275
323