fast-agent-mcp 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,133 @@
1
+ """
2
+ Module for handling MCP Sampling functionality without causing circular imports.
3
+ This module is carefully designed to avoid circular imports in the agent system.
4
+ """
5
+
6
+ from mcp import ClientSession
7
+ from mcp.types import (
8
+ CreateMessageRequestParams,
9
+ CreateMessageResult,
10
+ TextContent,
11
+ )
12
+
13
+ from mcp_agent.logging.logger import get_logger
14
+ from mcp_agent.mcp.interfaces import AugmentedLLMProtocol
15
+ from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
16
+
17
+ # Protocol is sufficient to describe the interface - no need for TYPE_CHECKING imports
18
+
19
+ logger = get_logger(__name__)
20
+
21
+
22
def create_sampling_llm(
    mcp_ctx: ClientSession, model_string: str
) -> AugmentedLLMProtocol:
    """
    Build a minimal, tool-less LLM for answering MCP sampling requests.

    A throw-away ``Agent`` with no servers and no persistent connections is
    created, the model factory for ``model_string`` is applied to it, and the
    resulting LLM is attached back onto that agent.

    Args:
        mcp_ctx: The MCP session context. Not referenced by the current
            implementation; kept so callers have a stable signature.
        model_string: Model specifier handed to ``ModelFactory`` (e.g.
            "passthrough", "claude-3-5-sonnet-latest").

    Returns:
        The LLM instance produced by the model factory.
    """
    # Deferred imports: this module is designed to avoid circular imports with
    # the agent/workflow packages. Import order matches the original.
    from mcp_agent.workflows.llm.model_factory import ModelFactory
    from mcp_agent.agents.agent import Agent, AgentConfig

    # The application context is distinct from the MCP request context, so it
    # is fetched from global state; if that fails the agent gets None.
    try:
        from mcp_agent.context import get_current_context

        app_context = get_current_context()
    except Exception:
        app_context = None
        logger.warning("App context not available for sampling call")

    sampling_agent = Agent(
        config=AgentConfig(
            name="sampling_agent",
            instruction="You are a sampling agent.",
            servers=[],  # sampling needs no MCP servers
        ),
        context=app_context,
        server_names=[],  # make sure no server connections are attempted
        connection_persistence=False,  # skip server connection management
    )

    model_factory = ModelFactory.create_factory(model_string)
    sampling_llm = model_factory(agent=sampling_agent)
    # Attach the LLM to the agent that was used to build it.
    sampling_agent._llm = sampling_llm
    return sampling_llm
73
+
74
+
75
async def sample(
    mcp_ctx: ClientSession, params: CreateMessageRequestParams
) -> CreateMessageResult:
    """
    Handle an MCP sampling request using the model configured for the server.

    The model name is read from ``mcp_ctx.session.server_config.sampling.model``.
    An LLM is built for it with ``create_sampling_llm`` and asked to answer the
    text of the first message in the request. Only the first message is used.

    Args:
        mcp_ctx: Context for the request. NOTE(review): the code reads
            ``mcp_ctx.session.server_config``, which suggests this is a request
            context wrapping the session rather than a bare ClientSession;
            confirm against the caller.
        params: The sampling request parameters.

    Returns:
        A ``CreateMessageResult`` with role "assistant". If the LLM call itself
        fails, the text is an echo of the user message (stopReason "endTurn").
        If anything else fails (no model configured, no messages, non-text
        content, LLM construction error), the text describes the error and
        stopReason is "error".
    """
    model = None
    try:
        # Walk the config chain defensively: any missing or None link (e.g. a
        # server with ``sampling: None``) simply means no model is configured,
        # instead of surfacing an AttributeError.
        session = getattr(mcp_ctx, "session", None)
        server_config = getattr(session, "server_config", None)
        sampling_config = getattr(server_config, "sampling", None)
        model = getattr(sampling_config, "model", None) or None

        if model is None:
            raise ValueError("No model configured")

        # Validate the request before building an LLM for it.
        if not params.messages:
            raise ValueError("Sampling request contains no messages")
        first_content = params.messages[0].content
        if not isinstance(first_content, TextContent):
            content_type = getattr(
                first_content, "type", type(first_content).__name__
            )
            raise ValueError(f"Unsupported sampling content type: {content_type}")
        user_message = first_content.text

        # Create an LLM instance using our utility function
        llm = create_sampling_llm(mcp_ctx, model)

        # Create a multipart prompt message with the user's input
        prompt = PromptMessageMultipart(
            role="user", content=[TextContent(type="text", text=user_message)]
        )

        try:
            logger.info(f"Processing input: {user_message[:50]}...")
            llm_response = await llm.generate_prompt(prompt, None)
            logger.info(f"Generated response: {llm_response[:50]}...")
        except Exception as e:
            # Deliberate best-effort: a failing LLM degrades to an echo reply.
            logger.error(f"Error generating response: {str(e)}")
            llm_response = f"Echo response: {user_message}"

        return CreateMessageResult(
            role="assistant",
            content=TextContent(type="text", text=llm_response),
            model=model,
            stopReason="endTurn",
        )
    except Exception as e:
        # Report failures to the requesting server as an error result rather
        # than raising through the MCP protocol handler.
        logger.error(f"Error in sampling: {str(e)}")
        return CreateMessageResult(
            role="assistant",
            content=TextContent(type="text", text=f"Error in sampling: {str(e)}"),
            model=model or "unknown",
            stopReason="error",
        )
mcp_agent/mcp/stdio.py CHANGED
@@ -9,10 +9,12 @@ from anyio.streams.text import TextReceiveStream
9
9
  from mcp.client.stdio import StdioServerParameters, get_default_environment
10
10
  import mcp.types as types
11
11
  from mcp_agent.logging.logger import get_logger
12
+ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
12
13
 
13
14
  logger = get_logger(__name__)
14
15
 
15
16
 
17
+ # TODO this will be removed when client library with https://github.com/modelcontextprotocol/python-sdk/pull/343 is released
16
18
  @asynccontextmanager
17
19
  async def stdio_client_with_rich_stderr(server: StdioServerParameters):
18
20
  """
@@ -22,10 +24,16 @@ async def stdio_client_with_rich_stderr(server: StdioServerParameters):
22
24
  Args:
23
25
  server: The server parameters for the stdio connection
24
26
  """
27
+ read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
28
+ read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
29
+
30
+ write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
31
+ write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
32
+
25
33
  read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
26
34
  write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
27
-
28
35
  # Open process with stderr piped for capture
36
+
29
37
  process = await anyio.open_process(
30
38
  [server.command, *server.args],
31
39
  env=server.env if server.env is not None else get_default_environment(),
@@ -67,19 +75,19 @@ async def stdio_client_with_rich_stderr(server: StdioServerParameters):
67
75
  except anyio.ClosedResourceError:
68
76
  await anyio.lowlevel.checkpoint()
69
77
 
70
- async def stderr_reader():
71
- assert process.stderr, "Opened process is missing stderr"
72
- try:
73
- async for chunk in TextReceiveStream(
74
- process.stderr,
75
- encoding=server.encoding,
76
- errors=server.encoding_error_handler,
77
- ):
78
- if chunk.strip():
79
- # Let the logging system handle the formatting consistently
80
- logger.event("info", "mcpserver.stderr", chunk.rstrip(), None, {})
81
- except anyio.ClosedResourceError:
82
- await anyio.lowlevel.checkpoint()
78
+ # async def stderr_reader():
79
+ # assert process.stderr, "Opened process is missing stderr"
80
+ # try:
81
+ # async for chunk in TextReceiveStream(
82
+ # process.stderr,
83
+ # encoding=server.encoding,
84
+ # errors=server.encoding_error_handler,
85
+ # ):
86
+ # if chunk.strip():
87
+ # # Let the logging system handle the formatting consistently
88
+ # logger.event("info", "mcpserver.stderr", chunk.rstrip(), None, {})
89
+ # except anyio.ClosedResourceError:
90
+ # await anyio.lowlevel.checkpoint()
83
91
 
84
92
  async def stdin_writer():
85
93
  assert process.stdin, "Opened process is missing stdin"
@@ -100,5 +108,4 @@ async def stdio_client_with_rich_stderr(server: StdioServerParameters):
100
108
  async with anyio.create_task_group() as tg, process:
101
109
  tg.start_soon(stdout_reader)
102
110
  tg.start_soon(stdin_writer)
103
- tg.start_soon(stderr_reader)
104
111
  yield read_stream, write_stream
@@ -15,9 +15,9 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStre
15
15
  from mcp import ClientSession
16
16
  from mcp.client.stdio import (
17
17
  StdioServerParameters,
18
- stdio_client,
19
18
  get_default_environment,
20
19
  )
20
+ from mcp_agent.mcp.stdio import stdio_client_with_rich_stderr
21
21
  from mcp.client.sse import sse_client
22
22
 
23
23
  from mcp_agent.config import (
@@ -134,7 +134,10 @@ class ServerRegistry:
134
134
  env={**get_default_environment(), **(config.env or {})},
135
135
  )
136
136
 
137
- async with stdio_client(server_params) as (read_stream, write_stream):
137
+ async with stdio_client_with_rich_stderr(server_params) as (
138
+ read_stream,
139
+ write_stream,
140
+ ):
138
141
  session = client_session_factory(
139
142
  read_stream,
140
143
  write_stream,
@@ -6,7 +6,7 @@ fast = FastAgent("FastAgent Example")
6
6
 
7
7
 
8
8
  # Define the agent
9
- @fast.agent(servers=["fetch"])
9
+ @fast.agent(servers=["fetch", "mcp_hfspace"])
10
10
  async def main():
11
11
  # use the --model command line switch or agent arguments to change model
12
12
  async with fast.run() as agent:
@@ -50,3 +50,6 @@ mcp:
50
50
  category:
51
51
  command: "uv"
52
52
  args: ["run", "prompt_category.py"]
53
+ mcp_hfspace:
54
+ command: "npx"
55
+ args: ["@llmindset/mcp-hfspace"]
@@ -50,9 +50,9 @@ fast = FastAgent("Orchestrator-Workers")
50
50
  async def main():
51
51
  async with fast.run() as agent:
52
52
  await agent()
53
- # await agent.author(
54
- # "write a 250 word short story about kittens discovering a castle, and save it to short_story.md"
55
- # )
53
+ await agent.author(
54
+ "write a 250 word short story about kittens discovering a castle, and save it to short_story.md"
55
+ )
56
56
 
57
57
  # The orchestrator can be used just like any other agent
58
58
  task = (
@@ -1,5 +1,5 @@
1
1
  from typing import Any, List, Optional, Type, Union
2
-
2
+ import json # Import at the module level
3
3
  from mcp import GetPromptResult
4
4
  from mcp.types import PromptMessage
5
5
  from pydantic_core import from_json
@@ -45,11 +45,101 @@ class PassthroughLLM(AugmentedLLM):
45
45
  request_params: Optional[RequestParams] = None,
46
46
  ) -> str:
47
47
  """Return the input message as a string."""
48
+ # Check if this is a special command to call a tool
49
+ if isinstance(message, str) and message.startswith("***CALL_TOOL "):
50
+ return await self._call_tool_and_return_result(message)
51
+
48
52
  self.show_user_message(message, model="fastagent-passthrough", chat_turn=0)
49
53
  await self.show_assistant_message(message, title="ASSISTANT/PASSTHROUGH")
50
54
 
55
+ # Handle PromptMessage by concatenating all parts
56
+ if isinstance(message, PromptMessage):
57
+ parts_text = []
58
+ for part in message.content:
59
+ parts_text.append(str(part))
60
+ return "\n".join(parts_text)
61
+
51
62
  return str(message)
52
63
 
64
+ async def _call_tool_and_return_result(self, command: str) -> str:
65
+ """
66
+ Call a tool based on the command and return its result as a string.
67
+
68
+ Args:
69
+ command: The command string, expected format: "***CALL_TOOL <server>-<tool_name> [arguments_json]"
70
+
71
+ Returns:
72
+ Tool result as a string
73
+ """
74
+ try:
75
+ tool_name, arguments = self._parse_tool_command(command)
76
+ result = await self.aggregator.call_tool(tool_name, arguments)
77
+ return self._format_tool_result(tool_name, result)
78
+ except Exception as e:
79
+ self.logger.error(f"Error calling tool: {str(e)}")
80
+ return f"Error calling tool: {str(e)}"
81
+
82
+ def _parse_tool_command(self, command: str) -> tuple[str, Optional[dict]]:
83
+ """
84
+ Parse a tool command string into tool name and arguments.
85
+
86
+ Args:
87
+ command: The command string in format "***CALL_TOOL <tool_name> [arguments_json]"
88
+
89
+ Returns:
90
+ Tuple of (tool_name, arguments_dict)
91
+
92
+ Raises:
93
+ ValueError: If command format is invalid
94
+ """
95
+ parts = command.split(" ", 2)
96
+ if len(parts) < 2:
97
+ raise ValueError(
98
+ "Invalid format. Expected '***CALL_TOOL <tool_name> [arguments_json]'"
99
+ )
100
+
101
+ tool_name = parts[1].strip()
102
+ arguments = None
103
+
104
+ if len(parts) > 2:
105
+ try:
106
+ arguments = json.loads(parts[2])
107
+ except json.JSONDecodeError:
108
+ raise ValueError(f"Invalid JSON arguments: {parts[2]}")
109
+
110
+ self.logger.info(f"Calling tool {tool_name} with arguments {arguments}")
111
+ return tool_name, arguments
112
+
113
+ def _format_tool_result(self, tool_name: str, result) -> str:
114
+ """
115
+ Format tool execution result as a string.
116
+
117
+ Args:
118
+ tool_name: The name of the tool that was called
119
+ result: The result returned from the tool
120
+
121
+ Returns:
122
+ Formatted result as a string
123
+ """
124
+ if result.isError:
125
+ error_text = []
126
+ for content_item in result.content:
127
+ if hasattr(content_item, "text"):
128
+ error_text.append(content_item.text)
129
+ else:
130
+ error_text.append(str(content_item))
131
+ error_message = "\n".join(error_text) if error_text else "Unknown error"
132
+ return f"Error calling tool '{tool_name}': {error_message}"
133
+
134
+ result_text = []
135
+ for content_item in result.content:
136
+ if hasattr(content_item, "text"):
137
+ result_text.append(content_item.text)
138
+ else:
139
+ result_text.append(str(content_item))
140
+
141
+ return "\n".join(result_text)
142
+
53
143
  async def generate_structured(
54
144
  self,
55
145
  message: Union[str, MessageParamT, List[MessageParamT]],
@@ -71,7 +161,25 @@ class PassthroughLLM(AugmentedLLM):
71
161
async def generate_prompt(
    self, prompt: "PromptMessageMultipart", request_params: "RequestParams | None"
) -> str:
    """
    Pass a multipart prompt through, or dispatch an embedded tool call.

    If the first content part is text starting with "***CALL_TOOL ", the
    command is executed via ``_call_tool_and_return_result`` and its result is
    returned. Otherwise every part is rendered to text (the ``text`` of text
    parts, ``str()`` of anything else), joined with newlines, and handed to
    ``generate_str``.

    Args:
        prompt: The multipart message to pass through.
        request_params: Forwarded unchanged to ``generate_str``.

    Returns:
        The tool result, the ``generate_str`` output, or "" when the prompt has
        no content parts.
    """
    parts = prompt.content or []
    if not parts:
        return ""

    # Non-text parts (e.g. images) have no ``text`` attribute; getattr avoids
    # the AttributeError a direct ``.text`` access would raise.
    first_text = getattr(parts[0], "text", None)
    if isinstance(first_text, str) and first_text.startswith("***CALL_TOOL "):
        return await self._call_tool_and_return_result(first_text)

    # Use the raw text of text parts: str() of a pydantic content model gives
    # its field repr (type='text' text='...'), not the message text itself.
    rendered = []
    for part in parts:
        text = getattr(part, "text", None)
        rendered.append(text if isinstance(text, str) else str(part))

    return await self.generate_str("\n".join(rendered), request_params)
75
183
 
76
184
  async def apply_prompt_template(
77
185
  self, prompt_result: GetPromptResult, prompt_name: str