fastmcp 2.9.2__py3-none-any.whl → 2.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. fastmcp/client/auth/oauth.py +5 -82
  2. fastmcp/client/client.py +114 -24
  3. fastmcp/client/elicitation.py +63 -0
  4. fastmcp/client/transports.py +50 -36
  5. fastmcp/contrib/component_manager/README.md +170 -0
  6. fastmcp/contrib/component_manager/__init__.py +4 -0
  7. fastmcp/contrib/component_manager/component_manager.py +186 -0
  8. fastmcp/contrib/component_manager/component_service.py +225 -0
  9. fastmcp/contrib/component_manager/example.py +59 -0
  10. fastmcp/prompts/prompt.py +12 -4
  11. fastmcp/resources/resource.py +8 -3
  12. fastmcp/resources/template.py +5 -0
  13. fastmcp/server/auth/auth.py +15 -0
  14. fastmcp/server/auth/providers/bearer.py +41 -3
  15. fastmcp/server/auth/providers/bearer_env.py +4 -0
  16. fastmcp/server/auth/providers/in_memory.py +15 -0
  17. fastmcp/server/context.py +144 -4
  18. fastmcp/server/elicitation.py +160 -0
  19. fastmcp/server/http.py +1 -9
  20. fastmcp/server/low_level.py +4 -2
  21. fastmcp/server/middleware/__init__.py +14 -1
  22. fastmcp/server/middleware/logging.py +11 -0
  23. fastmcp/server/middleware/middleware.py +10 -6
  24. fastmcp/server/openapi.py +19 -77
  25. fastmcp/server/proxy.py +13 -6
  26. fastmcp/server/server.py +27 -7
  27. fastmcp/settings.py +0 -17
  28. fastmcp/tools/tool.py +209 -57
  29. fastmcp/tools/tool_manager.py +2 -3
  30. fastmcp/tools/tool_transform.py +125 -26
  31. fastmcp/utilities/components.py +5 -1
  32. fastmcp/utilities/json_schema_type.py +648 -0
  33. fastmcp/utilities/openapi.py +69 -0
  34. fastmcp/utilities/types.py +50 -19
  35. {fastmcp-2.9.2.dist-info → fastmcp-2.10.1.dist-info}/METADATA +3 -2
  36. {fastmcp-2.9.2.dist-info → fastmcp-2.10.1.dist-info}/RECORD +39 -31
  37. {fastmcp-2.9.2.dist-info → fastmcp-2.10.1.dist-info}/WHEEL +0 -0
  38. {fastmcp-2.9.2.dist-info → fastmcp-2.10.1.dist-info}/entry_points.txt +0 -0
  39. {fastmcp-2.9.2.dist-info → fastmcp-2.10.1.dist-info}/licenses/LICENSE +0 -0
fastmcp/client/auth/oauth.py CHANGED
@@ -9,14 +9,11 @@ from urllib.parse import urljoin, urlparse
9
9
 
10
10
  import anyio
11
11
  import httpx
12
- from mcp.client.auth import OAuthClientProvider as _MCPOAuthClientProvider
13
- from mcp.client.auth import TokenStorage
12
+ from mcp.client.auth import OAuthClientProvider, TokenStorage
14
13
  from mcp.shared.auth import (
15
14
  OAuthClientInformationFull,
16
15
  OAuthClientMetadata,
17
- )
18
- from mcp.shared.auth import (
19
- OAuthMetadata as _MCPServerOAuthMetadata,
16
+ OAuthMetadata,
20
17
  )
21
18
  from mcp.shared.auth import (
22
19
  OAuthToken as OAuthToken,
@@ -39,80 +36,6 @@ def default_cache_dir() -> Path:
39
36
  return fastmcp_global_settings.home / "oauth-mcp-client-cache"
40
37
 
41
38
 
42
- # Flexible OAuth models for real-world compatibility
43
- class ServerOAuthMetadata(_MCPServerOAuthMetadata):
44
- """
45
- More flexible OAuth metadata model that accepts broader ranges of values
46
- than the restrictive MCP standard model.
47
-
48
- This handles real-world OAuth servers like PayPal that may support
49
- additional methods not in the MCP specification.
50
- """
51
-
52
- # Allow any code challenge methods, not just S256
53
- code_challenge_methods_supported: list[str] | None = None
54
-
55
- # Allow any token endpoint auth methods
56
- token_endpoint_auth_methods_supported: list[str] | None = None
57
-
58
- # Allow any grant types
59
- grant_types_supported: list[str] | None = None
60
-
61
- # Allow any response types
62
- response_types_supported: list[str] = ["code"]
63
-
64
- # Allow any response modes
65
- response_modes_supported: list[str] | None = None
66
-
67
-
68
- class OAuthClientProvider(_MCPOAuthClientProvider):
69
- """
70
- OAuth client provider with more flexible OAuth metadata discovery.
71
- """
72
-
73
- async def _discover_oauth_metadata(
74
- self, server_url: str
75
- ) -> ServerOAuthMetadata | None:
76
- """
77
- Discover OAuth metadata with flexible validation.
78
-
79
- This is nearly identical to the parent implementation but uses
80
- ServerOAuthMetadata instead of the restrictive MCP OAuthMetadata.
81
- """
82
- # Extract base URL per MCP spec
83
- auth_base_url = self._get_authorization_base_url(server_url)
84
- url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
85
-
86
- from mcp.types import LATEST_PROTOCOL_VERSION
87
-
88
- headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
89
-
90
- async with httpx.AsyncClient() as client:
91
- try:
92
- response = await client.get(url, headers=headers)
93
- if response.status_code == 404:
94
- return None
95
- response.raise_for_status()
96
- metadata_json = response.json()
97
- logger.debug(f"OAuth metadata discovered: {metadata_json}")
98
- return ServerOAuthMetadata.model_validate(metadata_json)
99
- except Exception:
100
- # Retry without MCP header for CORS compatibility
101
- try:
102
- response = await client.get(url)
103
- if response.status_code == 404:
104
- return None
105
- response.raise_for_status()
106
- metadata_json = response.json()
107
- logger.debug(
108
- f"OAuth metadata discovered (no MCP header): {metadata_json}"
109
- )
110
- return ServerOAuthMetadata.model_validate(metadata_json)
111
- except Exception:
112
- logger.exception("Failed to discover OAuth metadata")
113
- return None
114
-
115
-
116
39
  class FileTokenStorage(TokenStorage):
117
40
  """
118
41
  File-based token storage implementation for OAuth credentials and tokens.
@@ -229,7 +152,7 @@ class FileTokenStorage(TokenStorage):
229
152
 
230
153
  async def discover_oauth_metadata(
231
154
  server_base_url: str, httpx_kwargs: dict[str, Any] | None = None
232
- ) -> _MCPServerOAuthMetadata | None:
155
+ ) -> OAuthMetadata | None:
233
156
  """
234
157
  Discover OAuth metadata from the server using RFC 8414 well-known endpoint.
235
158
 
@@ -248,7 +171,7 @@ async def discover_oauth_metadata(
248
171
  response = await client.get(well_known_url, timeout=10.0)
249
172
  if response.status_code == 200:
250
173
  logger.debug("Successfully discovered OAuth metadata")
251
- return _MCPServerOAuthMetadata.model_validate(response.json())
174
+ return OAuthMetadata.model_validate(response.json())
252
175
  elif response.status_code == 404:
253
176
  logger.debug(
254
177
  "OAuth metadata not found (404) - server may not require auth"
@@ -298,7 +221,7 @@ def OAuth(
298
221
  client_name: str = "FastMCP Client",
299
222
  token_storage_cache_dir: Path | None = None,
300
223
  additional_client_metadata: dict[str, Any] | None = None,
301
- ) -> _MCPOAuthClientProvider:
224
+ ) -> OAuthClientProvider:
302
225
  """
303
226
  Create an OAuthClientProvider for an MCP server.
304
227
 
fastmcp/client/client.py CHANGED
@@ -1,6 +1,9 @@
1
+ from __future__ import annotations
2
+
1
3
  import asyncio
2
4
  import datetime
3
5
  from contextlib import AsyncExitStack, asynccontextmanager
6
+ from dataclasses import dataclass
4
7
  from pathlib import Path
5
8
  from typing import Any, Generic, Literal, cast, overload
6
9
 
@@ -13,6 +16,7 @@ from mcp import ClientSession
13
16
  from pydantic import AnyUrl
14
17
 
15
18
  import fastmcp
19
+ from fastmcp.client.elicitation import ElicitationHandler, create_elicitation_callback
16
20
  from fastmcp.client.logging import (
17
21
  LogHandler,
18
22
  create_log_callback,
@@ -29,8 +33,10 @@ from fastmcp.client.sampling import SamplingHandler, create_sampling_callback
29
33
  from fastmcp.exceptions import ToolError
30
34
  from fastmcp.server import FastMCP
31
35
  from fastmcp.utilities.exceptions import get_catch_handlers
36
+ from fastmcp.utilities.json_schema_type import json_schema_to_type
37
+ from fastmcp.utilities.logging import get_logger
32
38
  from fastmcp.utilities.mcp_config import MCPConfig
33
- from fastmcp.utilities.types import MCPContent
39
+ from fastmcp.utilities.types import get_cached_typeadapter
34
40
 
35
41
  from .transports import (
36
42
  ClientTransportT,
@@ -53,9 +59,12 @@ __all__ = [
53
59
  "LogHandler",
54
60
  "MessageHandler",
55
61
  "SamplingHandler",
62
+ "ElicitationHandler",
56
63
  "ProgressHandler",
57
64
  ]
58
65
 
66
+ logger = get_logger(__name__)
67
+
59
68
 
60
69
  class Client(Generic[ClientTransportT]):
61
70
  """
@@ -99,34 +108,39 @@ class Client(Generic[ClientTransportT]):
99
108
  cls,
100
109
  transport: ClientTransportT,
101
110
  **kwargs: Any,
102
- ) -> "Client[ClientTransportT]": ...
111
+ ) -> Client[ClientTransportT]: ...
103
112
 
104
113
  @overload
105
114
  def __new__(
106
115
  cls, transport: AnyUrl, **kwargs
107
- ) -> "Client[SSETransport|StreamableHttpTransport]": ...
116
+ ) -> Client[SSETransport | StreamableHttpTransport]: ...
108
117
 
109
118
  @overload
110
119
  def __new__(
111
120
  cls, transport: FastMCP | FastMCP1Server, **kwargs
112
- ) -> "Client[FastMCPTransport]": ...
121
+ ) -> Client[FastMCPTransport]: ...
113
122
 
114
123
  @overload
115
124
  def __new__(
116
125
  cls, transport: Path, **kwargs
117
- ) -> "Client[PythonStdioTransport|NodeStdioTransport]": ...
126
+ ) -> Client[PythonStdioTransport | NodeStdioTransport]: ...
118
127
 
119
128
  @overload
120
129
  def __new__(
121
130
  cls, transport: MCPConfig | dict[str, Any], **kwargs
122
- ) -> "Client[MCPConfigTransport]": ...
131
+ ) -> Client[MCPConfigTransport]: ...
123
132
 
124
133
  @overload
125
134
  def __new__(
126
135
  cls, transport: str, **kwargs
127
- ) -> "Client[PythonStdioTransport|NodeStdioTransport|SSETransport|StreamableHttpTransport]": ...
128
-
129
- def __new__(cls, transport, **kwargs) -> "Client":
136
+ ) -> Client[
137
+ PythonStdioTransport
138
+ | NodeStdioTransport
139
+ | SSETransport
140
+ | StreamableHttpTransport
141
+ ]: ...
142
+
143
+ def __new__(cls, transport, **kwargs) -> Client:
130
144
  instance = super().__new__(cls)
131
145
  return instance
132
146
 
@@ -142,6 +156,7 @@ class Client(Generic[ClientTransportT]):
142
156
  # Common args
143
157
  roots: RootsList | RootsHandler | None = None,
144
158
  sampling_handler: SamplingHandler | None = None,
159
+ elicitation_handler: ElicitationHandler | None = None,
145
160
  log_handler: LogHandler | None = None,
146
161
  message_handler: MessageHandlerT | MessageHandler | None = None,
147
162
  progress_handler: ProgressHandler | None = None,
@@ -194,6 +209,11 @@ class Client(Generic[ClientTransportT]):
194
209
  sampling_handler
195
210
  )
196
211
 
212
+ if elicitation_handler is not None:
213
+ self._session_kwargs["elicitation_callback"] = create_elicitation_callback(
214
+ elicitation_handler
215
+ )
216
+
197
217
  # session context management
198
218
  self._session: ClientSession | None = None
199
219
  self._exit_stack: AsyncExitStack | None = None
@@ -232,6 +252,14 @@ class Client(Generic[ClientTransportT]):
232
252
  sampling_callback
233
253
  )
234
254
 
255
+ def set_elicitation_callback(
256
+ self, elicitation_callback: ElicitationHandler
257
+ ) -> None:
258
+ """Set the elicitation callback for the client."""
259
+ self._session_kwargs["elicitation_callback"] = create_elicitation_callback(
260
+ elicitation_callback
261
+ )
262
+
235
263
  def is_connected(self) -> bool:
236
264
  """Check if the client is currently connected."""
237
265
  return self._session is not None
@@ -258,6 +286,21 @@ class Client(Generic[ClientTransportT]):
258
286
 
259
287
  async def __aenter__(self):
260
288
  await self._connect()
289
+
290
+ # Check if session task failed and raise error immediately
291
+ if (
292
+ self._session_task is not None
293
+ and self._session_task.done()
294
+ and not self._session_task.cancelled()
295
+ ):
296
+ exception = self._session_task.exception()
297
+ if isinstance(exception, httpx.HTTPStatusError):
298
+ raise exception
299
+ elif exception is not None:
300
+ raise RuntimeError(
301
+ f"Client failed to connect: {exception}"
302
+ ) from exception
303
+
261
304
  return self
262
305
 
263
306
  async def __aexit__(self, exc_type, exc_val, exc_tb):
@@ -308,16 +351,21 @@ class Client(Generic[ClientTransportT]):
308
351
  self._initialize_result = None
309
352
 
310
353
  async def _session_runner(self):
311
- async with AsyncExitStack() as stack:
312
- try:
313
- await stack.enter_async_context(self._context_manager())
314
- # Session/context is now ready
315
- self._ready_event.set()
316
- # Wait until disconnect/stop is requested
317
- await self._stop_event.wait()
318
- finally:
319
- # On exit, ensure ready event is set (idempotent)
320
- self._ready_event.set()
354
+ try:
355
+ async with AsyncExitStack() as stack:
356
+ try:
357
+ await stack.enter_async_context(self._context_manager())
358
+ # Session/context is now ready
359
+ self._ready_event.set()
360
+ # Wait until disconnect/stop is requested
361
+ await self._stop_event.wait()
362
+ finally:
363
+ # On exit, ensure ready event is set (idempotent)
364
+ self._ready_event.set()
365
+ except Exception:
366
+ # Ensure ready event is set even if context manager entry fails
367
+ self._ready_event.set()
368
+ raise
321
369
 
322
370
  async def close(self):
323
371
  await self._disconnect(force=True)
@@ -675,7 +723,8 @@ class Client(Generic[ClientTransportT]):
675
723
  arguments: dict[str, Any] | None = None,
676
724
  timeout: datetime.timedelta | float | int | None = None,
677
725
  progress_handler: ProgressHandler | None = None,
678
- ) -> list[MCPContent]:
726
+ raise_on_error: bool = True,
727
+ ) -> CallToolResult:
679
728
  """Call a tool on the server.
680
729
 
681
730
  Unlike call_tool_mcp, this method raises a ToolError if the tool call results in an error.
@@ -687,8 +736,13 @@ class Client(Generic[ClientTransportT]):
687
736
  progress_handler (ProgressHandler | None, optional): The progress handler to use for the tool call. Defaults to None.
688
737
 
689
738
  Returns:
690
- list[mcp.types.TextContent | mcp.types.ImageContent | mcp.types.AudioContent | mcp.types.EmbeddedResource]:
691
- The content returned by the tool.
739
+ CallToolResult:
740
+ The content returned by the tool. If the tool returns structured
741
+ outputs, they are returned as a dataclass (if an output schema
742
+ is available) or a dictionary; otherwise, a list of content
743
+ blocks is returned. Note: to receive both structured and
744
+ unstructured outputs, use call_tool_mcp instead and access the
745
+ raw result object.
692
746
 
693
747
  Raises:
694
748
  ToolError: If the tool call results in an error.
@@ -700,7 +754,43 @@ class Client(Generic[ClientTransportT]):
700
754
  timeout=timeout,
701
755
  progress_handler=progress_handler,
702
756
  )
703
- if result.isError:
757
+ data = None
758
+ if result.isError and raise_on_error:
704
759
  msg = cast(mcp.types.TextContent, result.content[0]).text
705
760
  raise ToolError(msg)
706
- return result.content
761
+ elif result.structuredContent:
762
+ try:
763
+ if name not in self.session._tool_output_schemas:
764
+ await self.session.list_tools()
765
+ if name in self.session._tool_output_schemas:
766
+ output_schema = self.session._tool_output_schemas.get(name)
767
+ if output_schema:
768
+ if output_schema.get("x-fastmcp-wrap-result"):
769
+ output_schema = output_schema.get("properties", {}).get(
770
+ "result"
771
+ )
772
+ structured_content = result.structuredContent.get("result")
773
+ else:
774
+ structured_content = result.structuredContent
775
+ output_type = json_schema_to_type(output_schema)
776
+ type_adapter = get_cached_typeadapter(output_type)
777
+ data = type_adapter.validate_python(structured_content)
778
+ else:
779
+ data = result.structuredContent
780
+ except Exception as e:
781
+ logger.error(f"Error parsing structured content: {e}")
782
+
783
+ return CallToolResult(
784
+ content=result.content,
785
+ structured_content=result.structuredContent,
786
+ data=data,
787
+ is_error=result.isError,
788
+ )
789
+
790
+
791
+ @dataclass
792
+ class CallToolResult:
793
+ content: list[mcp.types.ContentBlock]
794
+ structured_content: dict[str, Any] | None
795
+ data: Any = None
796
+ is_error: bool = False
fastmcp/client/elicitation.py ADDED
@@ -0,0 +1,63 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Awaitable, Callable
4
+ from typing import Any, Generic, TypeAlias, TypeVar
5
+
6
+ import mcp.types
7
+ from mcp import ClientSession
8
+ from mcp.client.session import ElicitationFnT
9
+ from mcp.shared.context import LifespanContextT, RequestContext
10
+ from mcp.types import ElicitRequestParams
11
+ from mcp.types import ElicitResult as MCPElicitResult
12
+ from pydantic_core import to_jsonable_python
13
+
14
+ from fastmcp.utilities.json_schema_type import json_schema_to_type
15
+
16
+ __all__ = ["ElicitRequestParams", "ElicitResult", "ElicitationHandler"]
17
+
18
+ T = TypeVar("T")
19
+
20
+
21
+ class ElicitResult(MCPElicitResult, Generic[T]):
22
+ content: T | None = None
23
+
24
+
25
+ ElicitationHandler: TypeAlias = Callable[
26
+ [
27
+ str, # message
28
+ type[T], # a class for creating a structured response
29
+ ElicitRequestParams,
30
+ RequestContext[ClientSession, LifespanContextT],
31
+ ],
32
+ Awaitable[T | dict[str, Any] | ElicitResult[T | dict[str, Any]]],
33
+ ]
34
+
35
+
36
+ def create_elicitation_callback(
37
+ elicitation_handler: ElicitationHandler,
38
+ ) -> ElicitationFnT:
39
+ async def _elicitation_handler(
40
+ context: RequestContext[ClientSession, LifespanContextT],
41
+ params: ElicitRequestParams,
42
+ ) -> MCPElicitResult | mcp.types.ErrorData:
43
+ try:
44
+ if params.requestedSchema == {"type": "object", "properties": {}}:
45
+ response_type = None
46
+ else:
47
+ response_type = json_schema_to_type(params.requestedSchema)
48
+
49
+ result = await elicitation_handler(
50
+ params.message, response_type, params, context
51
+ )
52
+ # if the user returns data, we assume they've accepted the elicitation
53
+ if not isinstance(result, ElicitResult):
54
+ result = ElicitResult(action="accept", content=result)
55
+ content = to_jsonable_python(result.content)
56
+ return MCPElicitResult(**result.model_dump() | {"content": content})
57
+ except Exception as e:
58
+ return mcp.types.ErrorData(
59
+ code=mcp.types.INTERNAL_ERROR,
60
+ message=str(e),
61
+ )
62
+
63
+ return _elicitation_handler
fastmcp/client/transports.py CHANGED
@@ -8,18 +8,23 @@ import sys
8
8
  import warnings
9
9
  from collections.abc import AsyncIterator, Callable
10
10
  from pathlib import Path
11
- from typing import Any, Literal, TypedDict, TypeVar, cast, overload
12
- from urllib.parse import urlparse, urlunparse
11
+ from typing import Any, Literal, TypeVar, cast, overload
13
12
 
14
13
  import anyio
15
14
  import httpx
16
15
  import mcp.types
17
16
  from mcp import ClientSession, StdioServerParameters
18
- from mcp.client.session import ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
17
+ from mcp.client.session import (
18
+ ElicitationFnT,
19
+ ListRootsFnT,
20
+ LoggingFnT,
21
+ MessageHandlerFnT,
22
+ SamplingFnT,
23
+ )
19
24
  from mcp.server.fastmcp import FastMCP as FastMCP1Server
20
25
  from mcp.shared.memory import create_client_server_memory_streams
21
26
  from pydantic import AnyUrl
22
- from typing_extensions import Unpack
27
+ from typing_extensions import TypedDict, Unpack
23
28
 
24
29
  import fastmcp
25
30
  from fastmcp.client.auth.bearer import BearerAuth
@@ -56,6 +61,7 @@ class SessionKwargs(TypedDict, total=False):
56
61
  sampling_callback: SamplingFnT | None
57
62
  list_roots_callback: ListRootsFnT | None
58
63
  logging_callback: LoggingFnT | None
64
+ elicitation_callback: ElicitationFnT | None
59
65
  message_handler: MessageHandlerFnT | None
60
66
  client_info: mcp.types.Implementation | None
61
67
 
@@ -161,11 +167,8 @@ class SSETransport(ClientTransport):
161
167
  if not isinstance(url, str) or not url.startswith("http"):
162
168
  raise ValueError("Invalid HTTP/S URL provided for SSE.")
163
169
 
164
- # Ensure the URL path ends with a trailing slash to avoid automatic redirects
165
- parsed = urlparse(url)
166
- if not parsed.path.endswith("/"):
167
- parsed = parsed._replace(path=parsed.path + "/")
168
- url = urlunparse(parsed)
170
+ # Don't modify the URL path - respect the exact URL provided by the user
171
+ # Some servers are strict about trailing slashes (e.g., PayPal MCP)
169
172
 
170
173
  self.url = url
171
174
  self.headers = headers or {}
@@ -236,11 +239,8 @@ class StreamableHttpTransport(ClientTransport):
236
239
  if not isinstance(url, str) or not url.startswith("http"):
237
240
  raise ValueError("Invalid HTTP/S URL provided for Streamable HTTP.")
238
241
 
239
- # Ensure the URL path ends with a trailing slash to avoid automatic redirects
240
- parsed = urlparse(url)
241
- if not parsed.path.endswith("/"):
242
- parsed = parsed._replace(path=parsed.path + "/")
243
- url = urlunparse(parsed)
242
+ # Don't modify the URL path - respect the exact URL provided by the user
243
+ # Some servers are strict about trailing slashes (e.g., PayPal MCP)
244
244
 
245
245
  self.url = url
246
246
  self.headers = headers or {}
@@ -361,34 +361,48 @@ class StdioTransport(ClientTransport):
361
361
  async def _connect_task():
362
362
  from mcp.client.stdio import stdio_client
363
363
 
364
- async with contextlib.AsyncExitStack() as stack:
365
- try:
366
- server_params = StdioServerParameters(
367
- command=self.command, args=self.args, env=self.env, cwd=self.cwd
368
- )
369
- transport = await stack.enter_async_context(
370
- stdio_client(server_params)
371
- )
372
- read_stream, write_stream = transport
373
- self._session = await stack.enter_async_context(
374
- ClientSession(read_stream, write_stream, **session_kwargs)
375
- )
376
-
377
- logger.debug("Stdio transport connected")
378
- self._ready_event.set()
379
-
380
- # Wait until disconnect is requested (stop_event is set)
381
- await self._stop_event.wait()
382
- finally:
383
- # Clean up client on exit
384
- self._session = None
385
- logger.debug("Stdio transport disconnected")
364
+ try:
365
+ async with contextlib.AsyncExitStack() as stack:
366
+ try:
367
+ server_params = StdioServerParameters(
368
+ command=self.command,
369
+ args=self.args,
370
+ env=self.env,
371
+ cwd=self.cwd,
372
+ )
373
+ transport = await stack.enter_async_context(
374
+ stdio_client(server_params)
375
+ )
376
+ read_stream, write_stream = transport
377
+ self._session = await stack.enter_async_context(
378
+ ClientSession(read_stream, write_stream, **session_kwargs)
379
+ )
380
+
381
+ logger.debug("Stdio transport connected")
382
+ self._ready_event.set()
383
+
384
+ # Wait until disconnect is requested (stop_event is set)
385
+ await self._stop_event.wait()
386
+ finally:
387
+ # Clean up client on exit
388
+ self._session = None
389
+ logger.debug("Stdio transport disconnected")
390
+ except Exception:
391
+ # Ensure ready event is set even if connection fails
392
+ self._ready_event.set()
393
+ raise
386
394
 
387
395
  # start the connection task
388
396
  self._connect_task = asyncio.create_task(_connect_task())
389
397
  # wait for the client to be ready before returning
390
398
  await self._ready_event.wait()
391
399
 
400
+ # Check if connect task completed with an exception (early failure)
401
+ if self._connect_task.done():
402
+ exception = self._connect_task.exception()
403
+ if exception is not None:
404
+ raise exception
405
+
392
406
  async def disconnect(self):
393
407
  if self._connect_task is None:
394
408
  return