fastmcp 2.5.2__py3-none-any.whl → 2.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastmcp/client/__init__.py +3 -0
- fastmcp/client/auth/__init__.py +4 -0
- fastmcp/client/auth/bearer.py +17 -0
- fastmcp/client/auth/oauth.py +391 -0
- fastmcp/client/client.py +74 -26
- fastmcp/client/oauth_callback.py +310 -0
- fastmcp/client/transports.py +76 -14
- fastmcp/server/auth/__init__.py +4 -0
- fastmcp/server/auth/auth.py +45 -0
- fastmcp/server/auth/providers/bearer.py +377 -0
- fastmcp/server/auth/providers/bearer_env.py +62 -0
- fastmcp/server/auth/providers/in_memory.py +325 -0
- fastmcp/server/dependencies.py +10 -0
- fastmcp/server/http.py +38 -66
- fastmcp/server/openapi.py +2 -0
- fastmcp/server/server.py +21 -26
- fastmcp/settings.py +27 -8
- fastmcp/tools/tool.py +22 -3
- fastmcp/tools/tool_manager.py +2 -0
- fastmcp/utilities/http.py +8 -0
- fastmcp/utilities/tests.py +22 -10
- {fastmcp-2.5.2.dist-info → fastmcp-2.6.1.dist-info}/METADATA +28 -16
- {fastmcp-2.5.2.dist-info → fastmcp-2.6.1.dist-info}/RECORD +27 -19
- fastmcp/client/base.py +0 -0
- fastmcp/low_level/README.md +0 -1
- /fastmcp/{low_level → server/auth/providers}/__init__.py +0 -0
- {fastmcp-2.5.2.dist-info → fastmcp-2.6.1.dist-info}/WHEEL +0 -0
- {fastmcp-2.5.2.dist-info → fastmcp-2.6.1.dist-info}/entry_points.txt +0 -0
- {fastmcp-2.5.2.dist-info → fastmcp-2.6.1.dist-info}/licenses/LICENSE +0 -0
fastmcp/client/__init__.py
CHANGED
|
@@ -11,6 +11,7 @@ from .transports import (
|
|
|
11
11
|
FastMCPTransport,
|
|
12
12
|
StreamableHttpTransport,
|
|
13
13
|
)
|
|
14
|
+
from .auth import OAuth, BearerAuth
|
|
14
15
|
|
|
15
16
|
__all__ = [
|
|
16
17
|
"Client",
|
|
@@ -24,4 +25,6 @@ __all__ = [
|
|
|
24
25
|
"NpxStdioTransport",
|
|
25
26
|
"FastMCPTransport",
|
|
26
27
|
"StreamableHttpTransport",
|
|
28
|
+
"OAuth",
|
|
29
|
+
"BearerAuth",
|
|
27
30
|
]
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
import httpx
|
|
2
|
+
from pydantic import SecretStr
|
|
3
|
+
|
|
4
|
+
from fastmcp.utilities.logging import get_logger
|
|
5
|
+
|
|
6
|
+
__all__ = ["BearerAuth"]
|
|
7
|
+
|
|
8
|
+
logger = get_logger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BearerAuth(httpx.Auth):
    """httpx auth flow that attaches a static bearer token to every request.

    The token is kept in a ``pydantic.SecretStr`` so its value is masked if
    the ``token`` attribute is printed; it is only unwrapped when the
    ``Authorization`` header is built.
    """

    def __init__(self, token: str):
        self.token = SecretStr(token)

    def auth_flow(self, request):
        """Set ``Authorization: Bearer <token>`` and hand the request back to httpx."""
        secret = self.token.get_secret_value()
        request.headers["Authorization"] = f"Bearer {secret}"
        yield request
|
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import webbrowser
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Literal
|
|
8
|
+
from urllib.parse import urljoin, urlparse
|
|
9
|
+
|
|
10
|
+
import anyio
|
|
11
|
+
import httpx
|
|
12
|
+
from mcp.client.auth import OAuthClientProvider as _MCPOAuthClientProvider
|
|
13
|
+
from mcp.client.auth import TokenStorage
|
|
14
|
+
from mcp.shared.auth import (
|
|
15
|
+
OAuthClientInformationFull,
|
|
16
|
+
OAuthClientMetadata,
|
|
17
|
+
)
|
|
18
|
+
from mcp.shared.auth import (
|
|
19
|
+
OAuthMetadata as _MCPServerOAuthMetadata,
|
|
20
|
+
)
|
|
21
|
+
from mcp.shared.auth import (
|
|
22
|
+
OAuthToken as OAuthToken,
|
|
23
|
+
)
|
|
24
|
+
from pydantic import AnyHttpUrl, ValidationError
|
|
25
|
+
|
|
26
|
+
from fastmcp.client.oauth_callback import (
|
|
27
|
+
create_oauth_callback_server,
|
|
28
|
+
)
|
|
29
|
+
from fastmcp.settings import settings as fastmcp_global_settings
|
|
30
|
+
from fastmcp.utilities.http import find_available_port
|
|
31
|
+
from fastmcp.utilities.logging import get_logger
|
|
32
|
+
|
|
33
|
+
__all__ = ["OAuth"]
|
|
34
|
+
|
|
35
|
+
logger = get_logger(__name__)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def default_cache_dir() -> Path:
    """Return the default directory for cached OAuth client data.

    The directory lives under the FastMCP home directory taken from the
    global settings object.
    """
    home = fastmcp_global_settings.home
    return home / "oauth-mcp-client-cache"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Flexible OAuth models for real-world compatibility
|
|
43
|
+
class ServerOAuthMetadata(_MCPServerOAuthMetadata):
    """
    More flexible OAuth metadata model that accepts broader ranges of values
    than the restrictive MCP standard model.

    This handles real-world OAuth servers like PayPal that may support
    additional methods not in the MCP specification.

    Each field below is redeclared as a plain ``list[str]`` so that values
    outside the MCP-specified set still validate.
    """

    # Allow any code challenge methods, not just S256
    code_challenge_methods_supported: list[str] | None = None

    # Allow any token endpoint auth methods
    token_endpoint_auth_methods_supported: list[str] | None = None

    # Allow any grant types
    grant_types_supported: list[str] | None = None

    # Allow any response types; defaults to ["code"] when the server omits it
    response_types_supported: list[str] = ["code"]

    # Allow any response modes
    response_modes_supported: list[str] | None = None
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class OAuthClientProvider(_MCPOAuthClientProvider):
    """
    OAuth client provider with more flexible OAuth metadata discovery.

    Only metadata discovery is overridden; it validates the server's response
    with the permissive ``ServerOAuthMetadata`` model.
    """

    async def _discover_oauth_metadata(
        self, server_url: str
    ) -> ServerOAuthMetadata | None:
        """
        Discover OAuth metadata with flexible validation.

        This is nearly identical to the parent implementation but uses
        ServerOAuthMetadata instead of the restrictive MCP OAuthMetadata.

        The RFC 8414 well-known endpoint is queried first with the
        ``MCP-Protocol-Version`` header. If that attempt fails for any reason
        other than a 404 (HTTP error, bad JSON, validation error, network
        error), it is retried once without the header for CORS compatibility.

        Args:
            server_url: URL of the MCP server; the authorization base URL is
                derived from it.

        Returns:
            The parsed metadata, or None if the endpoint returns 404 or both
            attempts fail (the final failure is logged with a traceback).
        """
        from mcp.types import LATEST_PROTOCOL_VERSION

        # Extract base URL per MCP spec
        auth_base_url = self._get_authorization_base_url(server_url)
        url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")

        # (headers, suffix for the debug log) per attempt, in order.
        attempts: list[tuple[dict[str, str], str]] = [
            ({"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}, ""),
            # Retry without MCP header for CORS compatibility
            ({}, " (no MCP header)"),
        ]
        last_index = len(attempts) - 1

        async with httpx.AsyncClient() as client:
            for index, (headers, note) in enumerate(attempts):
                try:
                    response = await client.get(url, headers=headers)
                    if response.status_code == 404:
                        return None
                    response.raise_for_status()
                    metadata_json = response.json()
                    logger.debug(f"OAuth metadata discovered{note}: {metadata_json}")
                    return ServerOAuthMetadata.model_validate(metadata_json)
                except Exception:
                    # Best-effort discovery: earlier failures fall through to
                    # the next attempt; only the last one is reported. Using a
                    # loop (not a nested handler) keeps the logged traceback
                    # free of the unrelated first failure.
                    if index == last_index:
                        logger.exception("Failed to discover OAuth metadata")
        return None
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class FileTokenStorage(TokenStorage):
    """
    File-based token storage implementation for OAuth credentials and tokens.
    Implements the mcp.client.auth.TokenStorage protocol.

    Each instance is tied to a specific server URL for proper token isolation.

    Cache files are JSON encoded as UTF-8. Because they contain access/refresh
    tokens and client credentials, they are written with owner-only
    permissions (0o600 where the platform honours POSIX modes).
    """

    # Errors that mean "no usable cached data" when reading a cache file.
    # FileNotFoundError and PermissionError are both OSError subclasses.
    _READ_ERRORS = (
        OSError,
        UnicodeDecodeError,
        json.JSONDecodeError,
        ValidationError,
    )

    def __init__(self, server_url: str, cache_dir: Path | None = None):
        """Initialize storage for a specific server URL.

        Args:
            server_url: Server URL; only its scheme and host form the cache key.
            cache_dir: Directory for cache files. Defaults to
                ``default_cache_dir()``. Created if missing (the leaf directory
                with mode 0o700; an existing directory is left untouched).
        """
        self.server_url = server_url
        self.cache_dir = cache_dir or default_cache_dir()
        self.cache_dir.mkdir(mode=0o700, exist_ok=True, parents=True)

    @staticmethod
    def get_base_url(url: str) -> str:
        """Extract the base URL (scheme + host) from a URL."""
        parsed = urlparse(url)
        return f"{parsed.scheme}://{parsed.netloc}"

    def get_cache_key(self) -> str:
        """Generate a safe filesystem key from the server's base URL.

        The key format is part of the on-disk cache layout; changing it would
        orphan previously cached credentials.
        """
        base_url = self.get_base_url(self.server_url)
        return (
            base_url.replace("://", "_")
            .replace(".", "_")
            .replace("/", "_")
            .replace(":", "_")
        )

    def _get_file_path(self, file_type: Literal["client_info", "tokens"]) -> Path:
        """Get the file path for the specified cache file type."""
        key = self.get_cache_key()
        return self.cache_dir / f"{key}_{file_type}.json"

    @staticmethod
    def _write_private(path: Path, text: str) -> None:
        """Write ``text`` to ``path`` (UTF-8) readable only by the owner."""
        # Create/restrict the file *before* writing so secrets never sit in a
        # file with permissive, umask-derived permissions.
        path.touch(mode=0o600, exist_ok=True)
        try:
            path.chmod(0o600)
        except OSError as e:
            # Some filesystems reject chmod; still persist the data.
            logger.debug(f"Could not restrict permissions on {path}: {e}")
        path.write_text(text, encoding="utf-8")

    async def get_tokens(self) -> OAuthToken | None:
        """Load tokens from file storage.

        Returns:
            The cached tokens, or None if the file is missing, unreadable, or
            does not contain valid token JSON.
        """
        path = self._get_file_path("tokens")
        try:
            return OAuthToken.model_validate_json(path.read_text(encoding="utf-8"))
        except self._READ_ERRORS as e:
            logger.debug(
                f"Could not load tokens for {self.get_base_url(self.server_url)}: {e}"
            )
            return None

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Save tokens to file storage with owner-only permissions."""
        path = self._get_file_path("tokens")
        self._write_private(path, tokens.model_dump_json(indent=2))
        logger.debug(f"Saved tokens for {self.get_base_url(self.server_url)}")

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Load client information from file storage.

        If client info exists but no tokens do, the previous OAuth flow is
        treated as incomplete: the client info file is deleted and None is
        returned so that a fresh client registration happens.
        """
        path = self._get_file_path("client_info")
        try:
            client_info = OAuthClientInformationFull.model_validate_json(
                path.read_text(encoding="utf-8")
            )
        except self._READ_ERRORS as e:
            logger.debug(
                f"Could not load client info for {self.get_base_url(self.server_url)}: {e}"
            )
            return None

        # Check if we have corresponding valid tokens
        # If no tokens exist, the OAuth flow was incomplete and we should
        # force a fresh client registration
        tokens = await self.get_tokens()
        if tokens is None:
            logger.debug(
                f"No tokens found for client info at {self.get_base_url(self.server_url)}. "
                "OAuth flow may have been incomplete. Clearing client info to force fresh registration."
            )
            # Clear the incomplete client info
            path.unlink(missing_ok=True)
            return None

        return client_info

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Save client information to file storage with owner-only permissions."""
        path = self._get_file_path("client_info")
        self._write_private(path, client_info.model_dump_json(indent=2))
        logger.debug(f"Saved client info for {self.get_base_url(self.server_url)}")

    def clear(self) -> None:
        """Clear all cached data for this server."""
        file_types: list[Literal["client_info", "tokens"]] = ["client_info", "tokens"]
        for file_type in file_types:
            path = self._get_file_path(file_type)
            path.unlink(missing_ok=True)
        logger.info(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")

    @classmethod
    def clear_all(cls, cache_dir: Path | None = None) -> None:
        """Clear all cached data for all servers.

        Args:
            cache_dir: Cache directory to clean; defaults to
                ``default_cache_dir()``. Nothing happens if it does not exist.
        """
        cache_dir = cache_dir or default_cache_dir()
        if not cache_dir.exists():
            return

        file_types: list[Literal["client_info", "tokens"]] = ["client_info", "tokens"]
        for file_type in file_types:
            for file in cache_dir.glob(f"*_{file_type}.json"):
                file.unlink(missing_ok=True)
        logger.info("Cleared all OAuth client cache data.")
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
async def discover_oauth_metadata(
    server_base_url: str, httpx_kwargs: dict[str, Any] | None = None
) -> _MCPServerOAuthMetadata | None:
    """
    Fetch authorization-server metadata from the RFC 8414 well-known endpoint.

    Args:
        server_base_url: Base URL of the OAuth server (e.g., "https://example.com")
        httpx_kwargs: Additional kwargs for httpx client

    Returns:
        The validated metadata on HTTP 200. None on 404, on any other status,
        or when the request, JSON decoding, or validation fails.
    """
    well_known_url = urljoin(server_base_url, "/.well-known/oauth-authorization-server")
    logger.debug(f"Discovering OAuth metadata from: {well_known_url}")

    client_kwargs = httpx_kwargs or {}
    async with httpx.AsyncClient(**client_kwargs) as client:
        try:
            response = await client.get(well_known_url, timeout=10.0)
            status = response.status_code

            if status == 404:
                logger.debug(
                    "OAuth metadata not found (404) - server may not require auth"
                )
                return None

            if status != 200:
                logger.warning(f"OAuth metadata request failed: {status}")
                return None

            logger.debug("Successfully discovered OAuth metadata")
            return _MCPServerOAuthMetadata.model_validate(response.json())
        except (httpx.RequestError, json.JSONDecodeError, ValidationError) as err:
            logger.debug(f"OAuth metadata discovery failed: {err}")
            return None
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
async def check_if_auth_required(
    mcp_url: str, httpx_kwargs: dict[str, Any] | None = None
) -> bool:
    """
    Probe the MCP endpoint with a GET request to guess whether it needs auth.

    Returns:
        True when the response is 401/403 or carries a WWW-Authenticate header,
        and also when the request itself fails (connection/transport error).
        False otherwise.
    """
    client_kwargs = httpx_kwargs or {}
    async with httpx.AsyncClient(**client_kwargs) as client:
        try:
            response = await client.get(mcp_url, timeout=5.0)
        except httpx.RequestError:
            # Unreachable endpoint: err on the side of assuming auth is needed
            return True

        # An auth status code or an explicit auth challenge means auth is required
        return (
            response.status_code in (401, 403)
            or "WWW-Authenticate" in response.headers
        )
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def OAuth(
    mcp_url: str,
    scopes: str | list[str] | None = None,
    client_name: str = "FastMCP Client",
    token_storage_cache_dir: Path | None = None,
    additional_client_metadata: dict[str, Any] | None = None,
) -> _MCPOAuthClientProvider:
    """
    Create an OAuthClientProvider for an MCP server.

    This is intended to be provided to the `auth` parameter of an
    httpx.AsyncClient (or appropriate FastMCP client/transport instance).

    The provider stores credentials in a FileTokenStorage keyed by the server's
    base URL. When authorization is needed, it opens the authorization URL in a
    browser and runs a temporary callback server on 127.0.0.1 that waits up to
    300 seconds for the redirect.

    Args:
        mcp_url: Full URL to the MCP endpoint (e.g.,
            "http://host/mcp/sse")
        scopes: OAuth scopes to request. Can be a
            space-separated string or a list of strings.
        client_name: Name for this client during registration
        token_storage_cache_dir: Directory for FileTokenStorage
        additional_client_metadata: Extra fields for OAuthClientMetadata

    Returns:
        OAuthClientProvider

    Note:
        The callback port is picked when this function is called, not when
        the flow runs; it is assumed to still be free at that point.
    """
    parsed_url = urlparse(mcp_url)
    server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

    # Setup OAuth client
    redirect_port = find_available_port()
    redirect_uri = f"http://127.0.0.1:{redirect_port}/callback"

    if isinstance(scopes, list):
        scopes = " ".join(scopes)

    client_metadata = OAuthClientMetadata(
        client_name=client_name,
        redirect_uris=[AnyHttpUrl(redirect_uri)],
        grant_types=["authorization_code", "refresh_token"],
        response_types=["code"],
        token_endpoint_auth_method="client_secret_post",
        scope=scopes,
        **(additional_client_metadata or {}),
    )

    # Create server-specific token storage
    storage = FileTokenStorage(
        server_url=server_base_url, cache_dir=token_storage_cache_dir
    )

    # Define OAuth handlers
    async def redirect_handler(authorization_url: str) -> None:
        """Open browser for authorization."""
        logger.info(f"OAuth authorization URL: {authorization_url}")
        # webbrowser.open returns False when no browser could be launched
        # (e.g. headless machines); make sure the user can still see the URL.
        if not webbrowser.open(authorization_url):
            logger.warning(
                "Could not open a browser automatically. "
                f"Open this URL to authorize: {authorization_url}"
            )

    async def callback_handler() -> tuple[str, str | None]:
        """Handle OAuth callback and return (auth_code, state).

        Raises:
            TimeoutError: if no callback arrives within TIMEOUT seconds.
        """
        # Create a future to capture the OAuth response
        response_future = asyncio.get_running_loop().create_future()

        # Create server with the future
        server = create_oauth_callback_server(
            port=redirect_port,
            server_url=server_base_url,
            response_future=response_future,
        )

        # Run server until response is received with timeout logic
        async with anyio.create_task_group() as tg:
            tg.start_soon(server.serve)
            logger.info(
                f"🎧 OAuth callback server started on http://127.0.0.1:{redirect_port}"
            )

            TIMEOUT = 300.0  # 5 minute timeout
            try:
                with anyio.fail_after(TIMEOUT):
                    auth_code, state = await response_future
                    return auth_code, state
            except TimeoutError:
                # Replace the bare scope timeout with a descriptive one; the
                # original carries no extra information, so don't chain it.
                raise TimeoutError(
                    f"OAuth callback timed out after {TIMEOUT} seconds"
                ) from None
            finally:
                server.should_exit = True
                await asyncio.sleep(0.1)  # Allow server to shutdown gracefully
                tg.cancel_scope.cancel()

    # Create OAuth provider
    oauth_provider = OAuthClientProvider(
        server_url=server_base_url,
        client_metadata=client_metadata,
        storage=storage,
        redirect_handler=redirect_handler,
        callback_handler=callback_handler,
    )

    return oauth_provider
|
fastmcp/client/client.py
CHANGED
|
@@ -1,9 +1,11 @@
|
|
|
1
|
+
import asyncio
|
|
1
2
|
import datetime
|
|
2
3
|
from contextlib import AsyncExitStack, asynccontextmanager
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from typing import Any, Generic, cast, overload
|
|
5
|
+
from typing import Any, Generic, Literal, cast, overload
|
|
5
6
|
|
|
6
7
|
import anyio
|
|
8
|
+
import httpx
|
|
7
9
|
import mcp.types
|
|
8
10
|
from exceptiongroup import catch
|
|
9
11
|
from mcp import ClientSession
|
|
@@ -43,6 +45,7 @@ from .transports import (
|
|
|
43
45
|
|
|
44
46
|
__all__ = [
|
|
45
47
|
"Client",
|
|
48
|
+
"SessionKwargs",
|
|
46
49
|
"RootsHandler",
|
|
47
50
|
"RootsList",
|
|
48
51
|
"LogHandler",
|
|
@@ -142,11 +145,11 @@ class Client(Generic[ClientTransportT]):
|
|
|
142
145
|
progress_handler: ProgressHandler | None = None,
|
|
143
146
|
timeout: datetime.timedelta | float | int | None = None,
|
|
144
147
|
init_timeout: datetime.timedelta | float | int | None = None,
|
|
148
|
+
auth: httpx.Auth | Literal["oauth"] | str | None = None,
|
|
145
149
|
):
|
|
146
150
|
self.transport = cast(ClientTransportT, infer_transport(transport))
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
self._nesting_counter: int = 0
|
|
151
|
+
if auth is not None:
|
|
152
|
+
self.transport._set_auth(auth)
|
|
150
153
|
self._initialize_result: mcp.types.InitializeResult | None = None
|
|
151
154
|
|
|
152
155
|
if log_handler is None:
|
|
@@ -187,6 +190,15 @@ class Client(Generic[ClientTransportT]):
|
|
|
187
190
|
sampling_handler
|
|
188
191
|
)
|
|
189
192
|
|
|
193
|
+
# session context management
|
|
194
|
+
self._session: ClientSession | None = None
|
|
195
|
+
self._exit_stack: AsyncExitStack | None = None
|
|
196
|
+
self._nesting_counter: int = 0
|
|
197
|
+
self._context_lock = anyio.Lock()
|
|
198
|
+
self._session_task: asyncio.Task | None = None
|
|
199
|
+
self._ready_event = anyio.Event()
|
|
200
|
+
self._stop_event = anyio.Event()
|
|
201
|
+
|
|
190
202
|
@property
|
|
191
203
|
def session(self) -> ClientSession:
|
|
192
204
|
"""Get the current active session. Raises RuntimeError if not connected."""
|
|
@@ -237,39 +249,75 @@ class Client(Generic[ClientTransportT]):
|
|
|
237
249
|
except TimeoutError:
|
|
238
250
|
raise RuntimeError("Failed to initialize server session")
|
|
239
251
|
finally:
|
|
240
|
-
self._exit_stack = None
|
|
241
252
|
self._session = None
|
|
242
253
|
self._initialize_result = None
|
|
243
254
|
|
|
244
255
|
async def __aenter__(self):
    """Enter the client context: connect (or join the running session).

    Delegates to ``_connect()``, which starts the background session runner
    if needed and increments the nesting counter, then returns the client.
    """
    await self._connect()
    return self
|
|
257
258
|
|
|
258
259
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Leave one level of ``async with`` nesting via ``_disconnect()``.

    The exception arguments are not inspected, and because this returns None
    an in-flight exception is never suppressed.
    """
    await self._disconnect()
|
|
261
|
+
|
|
262
|
+
async def _connect(self):
    """Start (or join) the background session and bump the nesting counter.

    A new runner task is created only when none exists or the previous one
    has finished; the call then waits for the runner's ready signal.

    Returns:
        The client itself.

    Raises:
        Exception: whatever the session runner raised while starting (for
            example a transport or initialization error), instead of
            silently returning a client without a live session.
        RuntimeError: if the runner ended or was cancelled before becoming
            ready without raising an error of its own.
    """
    # ensure only one session is running at a time to avoid race conditions
    async with self._context_lock:
        need_to_start = self._session_task is None or self._session_task.done()
        if need_to_start:
            self._stop_event = anyio.Event()
            self._ready_event = anyio.Event()
            self._session_task = asyncio.create_task(self._session_runner())
        await self._ready_event.wait()

        # The runner also sets the ready event in its `finally`, so "ready"
        # may only mean the runner already exited. Surface that failure
        # rather than counting a nesting level for a dead connection.
        task = self._session_task
        if task is not None and task.done():
            # Drop the dead task so a later close()/disconnect does not
            # re-raise this error, and the next connect starts fresh.
            self._session_task = None
            error = None if task.cancelled() else task.exception()
            if error is not None:
                raise error
            raise RuntimeError("Client session ended before it became ready")

        self._nesting_counter += 1
    return self
|
|
260
273
|
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
274
|
+
async def _disconnect(self, force: bool = False):
    """Release one nesting level and stop the session when none remain.

    Args:
        force: If True, reset the nesting counter to zero so the session is
            stopped regardless of how many contexts are still open.

    Raises:
        Whatever the background session runner raised while shutting down.
        The connection state is reset even in that case, so the client can
        reconnect afterwards.
    """
    # ensure only one session is running at a time to avoid race conditions
    async with self._context_lock:
        if force:
            # forced disconnect: drop every nesting level at once
            self._nesting_counter = 0
        else:
            # otherwise decrement to check if we are done nesting
            self._nesting_counter = max(0, self._nesting_counter - 1)

        # if we are still nested, return
        if self._nesting_counter > 0:
            return

        # stop the active session
        runner_task = self._session_task
        if runner_task is None:
            return
        self._session_task = None
        self._stop_event.set()

        try:
            # wait for the session to finish
            await runner_task
        finally:
            # Reset for future reconnects, even if the runner raised
            self._stop_event = anyio.Event()
            self._ready_event = anyio.Event()
            self._session = None
            self._initialize_result = None
|
|
305
|
+
|
|
306
|
+
async def _session_runner(self):
    """Own the client session for the lifetime of one connection.

    Enters ``self._context_manager()``, signals readiness through
    ``self._ready_event``, then waits on ``self._stop_event`` until a
    disconnect asks it to stop. The ready event is also set on every exit
    path (including a failure while entering the context), so a waiter in
    ``_connect`` is never left blocked.
    """
    try:
        async with self._context_manager():
            # Session/context is now ready
            self._ready_event.set()
            # Wait until disconnect/stop is requested
            await self._stop_event.wait()
    finally:
        # Idempotent: releases any waiter regardless of how we exit
        self._ready_event.set()
|
|
268
317
|
|
|
269
318
|
async def close(self):
    """Force-disconnect the session and close the underlying transport.

    The nesting counter is ignored (``force=True``). The transport is closed
    even if tearing down the session raises, so it is not leaked; the
    teardown error still propagates to the caller.
    """
    try:
        await self._disconnect(force=True)
    finally:
        await self.transport.close()
|
|
273
321
|
|
|
274
322
|
# --- MCP Client Methods ---
|
|
275
323
|
|