fastmcp-2.3.3-py3-none-any.whl → fastmcp-2.3.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fastmcp/client/client.py CHANGED
@@ -1,21 +1,24 @@
1
1
  import datetime
2
- from contextlib import AbstractAsyncContextManager
2
+ from contextlib import AsyncExitStack, asynccontextmanager
3
3
  from pathlib import Path
4
4
  from typing import Any, cast
5
5
 
6
6
  import mcp.types
7
+ from exceptiongroup import catch
7
8
  from mcp import ClientSession
8
9
  from pydantic import AnyUrl
9
10
 
10
- from fastmcp.client.logging import LogHandler, MessageHandler
11
+ from fastmcp.client.logging import LogHandler, MessageHandler, default_log_handler
12
+ from fastmcp.client.progress import ProgressHandler, default_progress_handler
11
13
  from fastmcp.client.roots import (
12
14
  RootsHandler,
13
15
  RootsList,
14
16
  create_roots_callback,
15
17
  )
16
18
  from fastmcp.client.sampling import SamplingHandler, create_sampling_callback
17
- from fastmcp.exceptions import ClientError
19
+ from fastmcp.exceptions import ToolError
18
20
  from fastmcp.server import FastMCP
21
+ from fastmcp.utilities.exceptions import get_catch_handlers
19
22
 
20
23
  from .transports import ClientTransport, SessionKwargs, infer_transport
21
24
 
@@ -26,6 +29,7 @@ __all__ = [
26
29
  "LogHandler",
27
30
  "MessageHandler",
28
31
  "SamplingHandler",
32
+ "ProgressHandler",
29
33
  ]
30
34
 
31
35
 
@@ -33,8 +37,36 @@ class Client:
33
37
  """
34
38
  MCP client that delegates connection management to a Transport instance.
35
39
 
36
- The Client class is primarily concerned with MCP protocol logic,
37
- while the Transport handles connection establishment and management.
40
+ The Client class is responsible for MCP protocol logic, while the Transport
41
+ handles connection establishment and management. Client provides methods
42
+ for working with resources, prompts, tools and other MCP capabilities.
43
+
44
+ Args:
45
+ transport: Connection source specification, which can be:
46
+ - ClientTransport: Direct transport instance
47
+ - FastMCP: In-process FastMCP server
48
+ - AnyUrl | str: URL to connect to
49
+ - Path: File path for local socket
50
+ - dict: Transport configuration
51
+ roots: Optional RootsList or RootsHandler for filesystem access
52
+ sampling_handler: Optional handler for sampling requests
53
+ log_handler: Optional handler for log messages
54
+ message_handler: Optional handler for protocol messages
55
+ progress_handler: Optional handler for progress notifications
56
+ timeout: Optional timeout for requests (seconds or timedelta)
57
+
58
+ Examples:
59
+ ```python
60
+ # Connect to FastMCP server
61
+ client = Client("http://localhost:8080")
62
+
63
+ async with client:
64
+ # List available resources
65
+ resources = await client.list_resources()
66
+
67
+ # Call a tool
68
+ result = await client.call_tool("my_tool", {"param": "value"})
69
+ ```
38
70
  """
39
71
 
40
72
  def __init__(
@@ -45,36 +77,60 @@ class Client:
45
77
  sampling_handler: SamplingHandler | None = None,
46
78
  log_handler: LogHandler | None = None,
47
79
  message_handler: MessageHandler | None = None,
48
- read_timeout_seconds: datetime.timedelta | None = None,
80
+ progress_handler: ProgressHandler | None = None,
81
+ timeout: datetime.timedelta | float | int | None = None,
49
82
  ):
50
83
  self.transport = infer_transport(transport)
51
84
  self._session: ClientSession | None = None
52
- self._session_cm: AbstractAsyncContextManager[ClientSession] | None = None
85
+ self._exit_stack: AsyncExitStack | None = None
53
86
  self._nesting_counter: int = 0
87
+ self._initialize_result: mcp.types.InitializeResult | None = None
88
+
89
+ if log_handler is None:
90
+ log_handler = default_log_handler
91
+
92
+ if progress_handler is None:
93
+ progress_handler = default_progress_handler
94
+
95
+ self._progress_handler = progress_handler
96
+
97
+ if isinstance(timeout, int | float):
98
+ timeout = datetime.timedelta(seconds=timeout)
54
99
 
55
100
  self._session_kwargs: SessionKwargs = {
56
101
  "sampling_callback": None,
57
102
  "list_roots_callback": None,
58
103
  "logging_callback": log_handler,
59
104
  "message_handler": message_handler,
60
- "read_timeout_seconds": read_timeout_seconds,
105
+ "read_timeout_seconds": timeout,
61
106
  }
62
107
 
63
108
  if roots is not None:
64
109
  self.set_roots(roots)
65
110
 
66
111
  if sampling_handler is not None:
67
- self.set_sampling_callback(sampling_handler)
112
+ self._session_kwargs["sampling_callback"] = create_sampling_callback(
113
+ sampling_handler
114
+ )
68
115
 
69
116
  @property
70
117
  def session(self) -> ClientSession:
71
118
  """Get the current active session. Raises RuntimeError if not connected."""
72
119
  if self._session is None:
73
120
  raise RuntimeError(
74
- "Client is not connected. Use 'async with client:' context manager first."
121
+ "Client is not connected. Use the 'async with client:' context manager first."
75
122
  )
76
123
  return self._session
77
124
 
125
+ @property
126
+ def initialize_result(self) -> mcp.types.InitializeResult:
127
+ """Get the result of the initialization request."""
128
+ if self._initialize_result is None:
129
+ raise RuntimeError(
130
+ "Client is not connected. Use the 'async with client:' context manager first."
131
+ )
132
+ return self._initialize_result
133
+
78
134
  def set_roots(self, roots: RootsList | RootsHandler) -> None:
79
135
  """Set the roots for the client. This does not automatically call `send_roots_list_changed`."""
80
136
  self._session_kwargs["list_roots_callback"] = create_roots_callback(roots)
@@ -89,22 +145,47 @@ class Client:
89
145
  """Check if the client is currently connected."""
90
146
  return self._session is not None
91
147
 
148
+ @asynccontextmanager
149
+ async def _context_manager(self):
150
+ with catch(get_catch_handlers()):
151
+ async with self.transport.connect_session(
152
+ **self._session_kwargs
153
+ ) as session:
154
+ self._session = session
155
+ # Initialize the session
156
+ self._initialize_result = await self._session.initialize()
157
+
158
+ try:
159
+ yield
160
+ finally:
161
+ self._exit_stack = None
162
+ self._session = None
163
+ self._initialize_result = None
164
+
92
165
  async def __aenter__(self):
93
166
  if self._nesting_counter == 0:
94
- # create new session
95
- self._session_cm = self.transport.connect_session(**self._session_kwargs)
96
- self._session = await self._session_cm.__aenter__()
167
+ # Create exit stack to manage both context managers
168
+ stack = AsyncExitStack()
169
+ await stack.__aenter__()
170
+
171
+ await stack.enter_async_context(self._context_manager())
172
+
173
+ self._exit_stack = stack
97
174
 
98
175
  self._nesting_counter += 1
176
+
99
177
  return self
100
178
 
101
179
  async def __aexit__(self, exc_type, exc_val, exc_tb):
102
180
  self._nesting_counter -= 1
103
181
 
104
- if self._nesting_counter == 0 and self._session_cm is not None:
105
- await self._session_cm.__aexit__(exc_type, exc_val, exc_tb)
106
- self._session_cm = None
107
- self._session = None
182
+ if self._nesting_counter == 0:
183
+ # Exit the stack which will handle cleaning up the session
184
+ if self._exit_stack is not None:
185
+ try:
186
+ await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
187
+ finally:
188
+ self._exit_stack = None
108
189
 
109
190
  # --- MCP Client Methods ---
110
191
 
@@ -118,9 +199,12 @@ class Client:
118
199
  progress_token: str | int,
119
200
  progress: float,
120
201
  total: float | None = None,
202
+ message: str | None = None,
121
203
  ) -> None:
122
204
  """Send a progress notification."""
123
- await self.session.send_progress_notification(progress_token, progress, total)
205
+ await self.session.send_progress_notification(
206
+ progress_token, progress, total, message
207
+ )
124
208
 
125
209
  async def set_logging_level(self, level: mcp.types.LoggingLevel) -> None:
126
210
  """Send a logging/setLevel request."""
@@ -377,7 +461,11 @@ class Client:
377
461
  # --- Call Tool ---
378
462
 
379
463
  async def call_tool_mcp(
380
- self, name: str, arguments: dict[str, Any]
464
+ self,
465
+ name: str,
466
+ arguments: dict[str, Any],
467
+ progress_handler: ProgressHandler | None = None,
468
+ timeout: datetime.timedelta | float | int | None = None,
381
469
  ) -> mcp.types.CallToolResult:
382
470
  """Send a tools/call request and return the complete MCP protocol result.
383
471
 
@@ -387,6 +475,8 @@ class Client:
387
475
  Args:
388
476
  name (str): The name of the tool to call.
389
477
  arguments (dict[str, Any]): Arguments to pass to the tool.
478
+ timeout (datetime.timedelta | float | int | None, optional): The timeout for the tool call. Defaults to None.
479
+ progress_handler (ProgressHandler | None, optional): The progress handler to use for the tool call. Defaults to None.
390
480
 
391
481
  Returns:
392
482
  mcp.types.CallToolResult: The complete response object from the protocol,
@@ -395,34 +485,51 @@ class Client:
395
485
  Raises:
396
486
  RuntimeError: If called while the client is not connected.
397
487
  """
398
- result = await self.session.call_tool(name=name, arguments=arguments)
488
+
489
+ if isinstance(timeout, int | float):
490
+ timeout = datetime.timedelta(seconds=timeout)
491
+ result = await self.session.call_tool(
492
+ name=name,
493
+ arguments=arguments,
494
+ read_timeout_seconds=timeout,
495
+ progress_callback=progress_handler or self._progress_handler,
496
+ )
399
497
  return result
400
498
 
401
499
  async def call_tool(
402
500
  self,
403
501
  name: str,
404
502
  arguments: dict[str, Any] | None = None,
503
+ timeout: datetime.timedelta | float | int | None = None,
504
+ progress_handler: ProgressHandler | None = None,
405
505
  ) -> list[
406
506
  mcp.types.TextContent | mcp.types.ImageContent | mcp.types.EmbeddedResource
407
507
  ]:
408
508
  """Call a tool on the server.
409
509
 
410
- Unlike call_tool_mcp, this method raises a ClientError if the tool call results in an error.
510
+ Unlike call_tool_mcp, this method raises a ToolError if the tool call results in an error.
411
511
 
412
512
  Args:
413
513
  name (str): The name of the tool to call.
414
514
  arguments (dict[str, Any] | None, optional): Arguments to pass to the tool. Defaults to None.
515
+ timeout (datetime.timedelta | float | int | None, optional): The timeout for the tool call. Defaults to None.
516
+ progress_handler (ProgressHandler | None, optional): The progress handler to use for the tool call. Defaults to None.
415
517
 
416
518
  Returns:
417
519
  list[mcp.types.TextContent | mcp.types.ImageContent | mcp.types.EmbeddedResource]:
418
520
  The content returned by the tool.
419
521
 
420
522
  Raises:
421
- ClientError: If the tool call results in an error.
523
+ ToolError: If the tool call results in an error.
422
524
  RuntimeError: If called while the client is not connected.
423
525
  """
424
- result = await self.call_tool_mcp(name=name, arguments=arguments or {})
526
+ result = await self.call_tool_mcp(
527
+ name=name,
528
+ arguments=arguments or {},
529
+ timeout=timeout,
530
+ progress_handler=progress_handler,
531
+ )
425
532
  if result.isError:
426
533
  msg = cast(mcp.types.TextContent, result.content[0]).text
427
- raise ClientError(msg)
534
+ raise ToolError(msg)
428
535
  return result.content
fastmcp/client/logging.py CHANGED
@@ -6,8 +6,16 @@ from mcp.client.session import (
6
6
  )
7
7
  from mcp.types import LoggingMessageNotificationParams
8
8
 
9
+ from fastmcp.utilities.logging import get_logger
10
+
11
+ logger = get_logger(__name__)
12
+
9
13
  LogMessage: TypeAlias = LoggingMessageNotificationParams
10
14
  LogHandler: TypeAlias = LoggingFnT
11
15
  MessageHandler: TypeAlias = MessageHandlerFnT
12
16
 
13
17
  __all__ = ["LogMessage", "LogHandler", "MessageHandler"]
18
+
19
+
20
+ async def default_log_handler(params: LogMessage) -> None:
21
+ logger.debug(f"Log received: {params}")
fastmcp/client/progress.py ADDED
@@ -0,0 +1,38 @@
1
+ from typing import TypeAlias
2
+
3
+ from mcp.shared.session import ProgressFnT
4
+
5
+ from fastmcp.utilities.logging import get_logger
6
+
7
+ logger = get_logger(__name__)
8
+
9
+ ProgressHandler: TypeAlias = ProgressFnT
10
+
11
+
12
+ async def default_progress_handler(
13
+ progress: float, total: float | None, message: str | None
14
+ ) -> None:
15
+ """Default handler for progress notifications.
16
+
17
+ Logs progress updates at debug level, properly handling missing total or message values.
18
+
19
+ Args:
20
+ progress: Current progress value
21
+ total: Optional total expected value
22
+ message: Optional status message
23
+ """
24
+ if total is not None:
25
+ # We have both progress and total
26
+ percent = (progress / total) * 100
27
+ progress_str = f"{progress}/{total} ({percent:.1f}%)"
28
+ else:
29
+ # We only have progress
30
+ progress_str = f"{progress}"
31
+
32
+ # Include message if available
33
+ if message:
34
+ log_msg = f"Progress: {progress_str} - {message}"
35
+ else:
36
+ log_msg = f"Progress: {progress_str}"
37
+
38
+ logger.debug(log_msg)
fastmcp/client/transports.py CHANGED
@@ -1,17 +1,15 @@
1
1
  import abc
2
2
  import contextlib
3
3
  import datetime
4
- import inspect
5
4
  import os
6
5
  import shutil
7
6
  import sys
8
- import warnings
9
7
  from collections.abc import AsyncIterator
10
8
  from pathlib import Path
11
- from typing import Any, TypedDict
9
+ from typing import Any, TypedDict, cast
10
+ from urllib.parse import urlparse
12
11
 
13
- from exceptiongroup import BaseExceptionGroup, catch
14
- from mcp import ClientSession, McpError, StdioServerParameters
12
+ from mcp import ClientSession, StdioServerParameters
15
13
  from mcp.client.session import (
16
14
  ListRootsFnT,
17
15
  LoggingFnT,
@@ -26,8 +24,10 @@ from mcp.shared.memory import create_connected_server_and_client_session
26
24
  from pydantic import AnyUrl
27
25
  from typing_extensions import Unpack
28
26
 
29
- from fastmcp.exceptions import ClientError
30
27
  from fastmcp.server import FastMCP as FastMCPServer
28
+ from fastmcp.utilities.logging import get_logger
29
+
30
+ logger = get_logger(__name__)
31
31
 
32
32
 
33
33
  class SessionKwargs(TypedDict, total=False):
@@ -46,6 +46,7 @@ class ClientTransport(abc.ABC):
46
46
 
47
47
  A Transport is responsible for establishing and managing connections
48
48
  to an MCP server, and providing a ClientSession within an async context.
49
+
49
50
  """
50
51
 
51
52
  @abc.abstractmethod
@@ -54,7 +55,9 @@ class ClientTransport(abc.ABC):
54
55
  self, **session_kwargs: Unpack[SessionKwargs]
55
56
  ) -> AsyncIterator[ClientSession]:
56
57
  """
57
- Establishes a connection and yields an active, initialized ClientSession.
58
+ Establishes a connection and yields an active ClientSession.
59
+
60
+ The ClientSession is *not* expected to be initialized in this context manager.
58
61
 
59
62
  The session is guaranteed to be valid only within the scope of the
60
63
  async context manager. Connection setup and teardown are handled
@@ -65,7 +68,7 @@ class ClientTransport(abc.ABC):
65
68
  constructor (e.g., callbacks, timeouts).
66
69
 
67
70
  Yields:
68
- An initialized mcp.ClientSession instance.
71
+ A mcp.ClientSession instance.
69
72
  """
70
73
  raise NotImplementedError
71
74
  yield None # type: ignore
@@ -94,7 +97,6 @@ class WSTransport(ClientTransport):
94
97
  async with ClientSession(
95
98
  read_stream, write_stream, **session_kwargs
96
99
  ) as session:
97
- await session.initialize() # Initialize after session creation
98
100
  yield session
99
101
 
100
102
  def __repr__(self) -> str:
@@ -104,7 +106,12 @@ class WSTransport(ClientTransport):
104
106
  class SSETransport(ClientTransport):
105
107
  """Transport implementation that connects to an MCP server via Server-Sent Events."""
106
108
 
107
- def __init__(self, url: str | AnyUrl, headers: dict[str, str] | None = None):
109
+ def __init__(
110
+ self,
111
+ url: str | AnyUrl,
112
+ headers: dict[str, str] | None = None,
113
+ sse_read_timeout: datetime.timedelta | float | int | None = None,
114
+ ):
108
115
  if isinstance(url, AnyUrl):
109
116
  url = str(url)
110
117
  if not isinstance(url, str) or not url.startswith("http"):
@@ -112,16 +119,32 @@ class SSETransport(ClientTransport):
112
119
  self.url = url
113
120
  self.headers = headers or {}
114
121
 
122
+ if isinstance(sse_read_timeout, int | float):
123
+ sse_read_timeout = datetime.timedelta(seconds=sse_read_timeout)
124
+ self.sse_read_timeout = sse_read_timeout
125
+
115
126
  @contextlib.asynccontextmanager
116
127
  async def connect_session(
117
128
  self, **session_kwargs: Unpack[SessionKwargs]
118
129
  ) -> AsyncIterator[ClientSession]:
119
- async with sse_client(self.url, headers=self.headers) as transport:
130
+ client_kwargs = {}
131
+ # sse_read_timeout has a default value set, so we can't pass None without overriding it
132
+ # instead we simply leave the kwarg out if it's not provided
133
+ if self.sse_read_timeout is not None:
134
+ client_kwargs["sse_read_timeout"] = self.sse_read_timeout.total_seconds()
135
+ if session_kwargs.get("read_timeout_seconds", None) is not None:
136
+ read_timeout_seconds = cast(
137
+ datetime.timedelta, session_kwargs.get("read_timeout_seconds")
138
+ )
139
+ client_kwargs["timeout"] = read_timeout_seconds.total_seconds()
140
+
141
+ async with sse_client(
142
+ self.url, headers=self.headers, **client_kwargs
143
+ ) as transport:
120
144
  read_stream, write_stream = transport
121
145
  async with ClientSession(
122
146
  read_stream, write_stream, **session_kwargs
123
147
  ) as session:
124
- await session.initialize()
125
148
  yield session
126
149
 
127
150
  def __repr__(self) -> str:
@@ -131,7 +154,12 @@ class SSETransport(ClientTransport):
131
154
  class StreamableHttpTransport(ClientTransport):
132
155
  """Transport implementation that connects to an MCP server via Streamable HTTP Requests."""
133
156
 
134
- def __init__(self, url: str | AnyUrl, headers: dict[str, str] | None = None):
157
+ def __init__(
158
+ self,
159
+ url: str | AnyUrl,
160
+ headers: dict[str, str] | None = None,
161
+ sse_read_timeout: datetime.timedelta | float | int | None = None,
162
+ ):
135
163
  if isinstance(url, AnyUrl):
136
164
  url = str(url)
137
165
  if not isinstance(url, str) or not url.startswith("http"):
@@ -139,16 +167,29 @@ class StreamableHttpTransport(ClientTransport):
139
167
  self.url = url
140
168
  self.headers = headers or {}
141
169
 
170
+ if isinstance(sse_read_timeout, int | float):
171
+ sse_read_timeout = datetime.timedelta(seconds=sse_read_timeout)
172
+ self.sse_read_timeout = sse_read_timeout
173
+
142
174
  @contextlib.asynccontextmanager
143
175
  async def connect_session(
144
176
  self, **session_kwargs: Unpack[SessionKwargs]
145
177
  ) -> AsyncIterator[ClientSession]:
146
- async with streamablehttp_client(self.url, headers=self.headers) as transport:
178
+ client_kwargs = {}
179
+ # sse_read_timeout has a default value set, so we can't pass None without overriding it
180
+ # instead we simply leave the kwarg out if it's not provided
181
+ if self.sse_read_timeout is not None:
182
+ client_kwargs["sse_read_timeout"] = self.sse_read_timeout
183
+ if session_kwargs.get("read_timeout_seconds", None) is not None:
184
+ client_kwargs["timeout"] = session_kwargs.get("read_timeout_seconds")
185
+
186
+ async with streamablehttp_client(
187
+ self.url, headers=self.headers, **client_kwargs
188
+ ) as transport:
147
189
  read_stream, write_stream, _ = transport
148
190
  async with ClientSession(
149
191
  read_stream, write_stream, **session_kwargs
150
192
  ) as session:
151
- await session.initialize()
152
193
  yield session
153
194
 
154
195
  def __repr__(self) -> str:
@@ -196,7 +237,6 @@ class StdioTransport(ClientTransport):
196
237
  async with ClientSession(
197
238
  read_stream, write_stream, **session_kwargs
198
239
  ) as session:
199
- await session.initialize()
200
240
  yield session
201
241
 
202
242
  def __repr__(self) -> str:
@@ -418,26 +458,12 @@ class FastMCPTransport(ClientTransport):
418
458
  async def connect_session(
419
459
  self, **session_kwargs: Unpack[SessionKwargs]
420
460
  ) -> AsyncIterator[ClientSession]:
421
- def exception_handler(excgroup: BaseExceptionGroup):
422
- for exc in excgroup.exceptions:
423
- if isinstance(exc, BaseExceptionGroup):
424
- exception_handler(exc)
425
- raise exc
426
-
427
- def mcperror_handler(excgroup: BaseExceptionGroup):
428
- for exc in excgroup.exceptions:
429
- if isinstance(exc, BaseExceptionGroup):
430
- mcperror_handler(exc)
431
- raise ClientError(exc)
432
-
433
- # backport of 3.11's except* syntax
434
- with catch({McpError: mcperror_handler, Exception: exception_handler}):
435
- # create_connected_server_and_client_session manages the session lifecycle itself
436
- async with create_connected_server_and_client_session(
437
- server=self._fastmcp._mcp_server,
438
- **session_kwargs,
439
- ) as session:
440
- yield session
461
+ # create_connected_server_and_client_session manages the session lifecycle itself
462
+ async with create_connected_server_and_client_session(
463
+ server=self._fastmcp._mcp_server,
464
+ **session_kwargs,
465
+ ) as session:
466
+ yield session
441
467
 
442
468
  def __repr__(self) -> str:
443
469
  return f"<FastMCP(server='{self._fastmcp.name}')>"
@@ -461,36 +487,29 @@ def infer_transport(
461
487
 
462
488
  # the transport is a FastMCP server
463
489
  elif isinstance(transport, FastMCPServer):
464
- return FastMCPTransport(mcp=transport)
490
+ inferred_transport = FastMCPTransport(mcp=transport)
465
491
 
466
492
  # the transport is a path to a script
467
493
  elif isinstance(transport, Path | str) and Path(transport).exists():
468
494
  if str(transport).endswith(".py"):
469
- return PythonStdioTransport(script_path=transport)
495
+ inferred_transport = PythonStdioTransport(script_path=transport)
470
496
  elif str(transport).endswith(".js"):
471
- return NodeStdioTransport(script_path=transport)
497
+ inferred_transport = NodeStdioTransport(script_path=transport)
472
498
  else:
473
499
  raise ValueError(f"Unsupported script type: {transport}")
474
500
 
475
501
  # the transport is an http(s) URL
476
502
  elif isinstance(transport, AnyUrl | str) and str(transport).startswith("http"):
477
- if str(transport).rstrip("/").endswith("/sse"):
478
- warnings.warn(
479
- inspect.cleandoc(
480
- """
481
- As of FastMCP 2.3.0, HTTP URLs are inferred to use Streamable HTTP.
482
- The provided URL ends in `/sse`, so you may encounter unexpected behavior.
483
- If you intended to use SSE, please use the `SSETransport` class directly.
484
- """
485
- ),
486
- category=UserWarning,
487
- stacklevel=2,
488
- )
489
- return StreamableHttpTransport(url=transport)
490
-
491
- # the transport is a websocket URL
492
- elif isinstance(transport, AnyUrl | str) and str(transport).startswith("ws"):
493
- return WSTransport(url=transport)
503
+ transport_str = str(transport)
504
+ # Parse out just the path portion to check for /sse
505
+ parsed_url = urlparse(transport_str)
506
+ path = parsed_url.path
507
+
508
+ # Check if path contains /sse/ or ends with /sse
509
+ if "/sse/" in path or path.rstrip("/").endswith("/sse"):
510
+ inferred_transport = SSETransport(url=transport)
511
+ else:
512
+ inferred_transport = StreamableHttpTransport(url=transport)
494
513
 
495
514
  ## if the transport is a config dict
496
515
  elif isinstance(transport, dict):
@@ -505,7 +524,7 @@ def infer_transport(
505
524
  server_name = list(server.keys())[0]
506
525
  # Stdio transport
507
526
  if "command" in server[server_name] and "args" in server[server_name]:
508
- return StdioTransport(
527
+ inferred_transport = StdioTransport(
509
528
  command=server[server_name]["command"],
510
529
  args=server[server_name]["args"],
511
530
  env=server[server_name].get("env", None),
@@ -514,19 +533,16 @@ def infer_transport(
514
533
 
515
534
  # HTTP transport
516
535
  elif "url" in server:
517
- return SSETransport(
536
+ inferred_transport = SSETransport(
518
537
  url=server["url"],
519
538
  headers=server.get("headers", None),
520
539
  )
521
540
 
522
- # WebSocket transport
523
- elif "ws_url" in server:
524
- return WSTransport(
525
- url=server["ws_url"],
526
- )
527
-
528
541
  raise ValueError("Cannot determine transport type from dictionary")
529
542
 
530
543
  # the transport is an unknown type
531
544
  else:
532
545
  raise ValueError(f"Could not infer a valid transport from: {transport}")
546
+
547
+ logger.debug(f"Inferred transport: {inferred_transport}")
548
+ return inferred_transport
fastmcp/exceptions.py CHANGED
@@ -1,5 +1,7 @@
1
1
  """Custom exceptions for FastMCP."""
2
2
 
3
+ from mcp import McpError # noqa: F401
4
+
3
5
 
4
6
  class FastMCPError(Exception):
5
7
  """Base error for FastMCP."""
fastmcp/prompts/prompt.py CHANGED
@@ -13,7 +13,8 @@ from mcp.types import PromptArgument as MCPPromptArgument
13
13
  from pydantic import BaseModel, BeforeValidator, Field, TypeAdapter, validate_call
14
14
 
15
15
  from fastmcp.server.dependencies import get_context
16
- from fastmcp.utilities.json_schema import prune_params
16
+ from fastmcp.utilities.json_schema import compress_schema
17
+ from fastmcp.utilities.logging import get_logger
17
18
  from fastmcp.utilities.types import (
18
19
  _convert_set_defaults,
19
20
  find_kwarg_by_type,
@@ -25,6 +26,8 @@ if TYPE_CHECKING:
25
26
 
26
27
  CONTENT_TYPES = TextContent | ImageContent | EmbeddedResource
27
28
 
29
+ logger = get_logger(__name__)
30
+
28
31
 
29
32
  def Message(
30
33
  content: str | CONTENT_TYPES, role: Role | None = None, **kwargs: Any
@@ -112,7 +115,11 @@ class Prompt(BaseModel):
112
115
 
113
116
  context_kwarg = find_kwarg_by_type(fn, kwarg_type=Context)
114
117
  if context_kwarg:
115
- parameters = prune_params(parameters, params=[context_kwarg])
118
+ prune_params = [context_kwarg]
119
+ else:
120
+ prune_params = None
121
+
122
+ parameters = compress_schema(parameters, prune_params=prune_params)
116
123
 
117
124
  # Convert parameters to PromptArguments
118
125
  arguments: list[PromptArgument] = []
@@ -192,13 +199,12 @@ class Prompt(BaseModel):
192
199
  )
193
200
  )
194
201
  except Exception:
195
- raise ValueError(
196
- f"Could not convert prompt result to message: {msg}"
197
- )
202
+ raise ValueError("Could not convert prompt result to message.")
198
203
 
199
204
  return messages
200
205
  except Exception as e:
201
- raise ValueError(f"Error rendering prompt {self.name}: {e}")
206
+ logger.exception(f"Error rendering prompt {self.name}: {e}")
207
+ raise ValueError(f"Error rendering prompt {self.name}.")
202
208
 
203
209
  def __eq__(self, other: object) -> bool:
204
210
  if not isinstance(other, Prompt):