fast-agent-mcp 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fast-agent-mcp might be problematic.
- {fast_agent_mcp-0.1.10.dist-info → fast_agent_mcp-0.1.12.dist-info}/METADATA +36 -38
- {fast_agent_mcp-0.1.10.dist-info → fast_agent_mcp-0.1.12.dist-info}/RECORD +45 -42
- mcp_agent/agents/agent.py +1 -24
- mcp_agent/app.py +0 -5
- mcp_agent/config.py +9 -0
- mcp_agent/context.py +0 -2
- mcp_agent/core/agent_app.py +29 -0
- mcp_agent/core/agent_types.py +29 -2
- mcp_agent/core/decorators.py +1 -2
- mcp_agent/core/error_handling.py +1 -1
- mcp_agent/core/factory.py +2 -3
- mcp_agent/core/mcp_content.py +2 -3
- mcp_agent/core/proxies.py +3 -0
- mcp_agent/core/request_params.py +43 -0
- mcp_agent/core/types.py +4 -2
- mcp_agent/core/validation.py +14 -15
- mcp_agent/logging/transport.py +2 -2
- mcp_agent/mcp/gen_client.py +4 -4
- mcp_agent/mcp/interfaces.py +186 -0
- mcp_agent/mcp/mcp_agent_client_session.py +10 -2
- mcp_agent/mcp/mcp_aggregator.py +12 -3
- mcp_agent/mcp/sampling.py +140 -0
- mcp_agent/mcp/stdio.py +1 -2
- mcp_agent/mcp_server/__init__.py +1 -1
- mcp_agent/resources/examples/internal/agent.py +1 -1
- mcp_agent/resources/examples/internal/fastagent.config.yaml +3 -0
- mcp_agent/resources/examples/prompting/__init__.py +1 -1
- mcp_agent/ui/console_display.py +2 -2
- mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +2 -2
- mcp_agent/workflows/llm/augmented_llm.py +42 -102
- mcp_agent/workflows/llm/augmented_llm_anthropic.py +4 -3
- mcp_agent/workflows/llm/augmented_llm_openai.py +4 -3
- mcp_agent/workflows/llm/augmented_llm_passthrough.py +119 -37
- mcp_agent/workflows/llm/model_factory.py +1 -1
- mcp_agent/workflows/llm/prompt_utils.py +42 -28
- mcp_agent/workflows/llm/providers/multipart_converter_anthropic.py +244 -140
- mcp_agent/workflows/llm/providers/multipart_converter_openai.py +230 -185
- mcp_agent/workflows/llm/providers/sampling_converter_anthropic.py +5 -204
- mcp_agent/workflows/llm/providers/sampling_converter_openai.py +9 -207
- mcp_agent/workflows/llm/sampling_converter.py +124 -0
- mcp_agent/workflows/llm/sampling_format_converter.py +0 -17
- mcp_agent/workflows/router/router_base.py +10 -10
- mcp_agent/workflows/llm/llm_selector.py +0 -345
- {fast_agent_mcp-0.1.10.dist-info → fast_agent_mcp-0.1.12.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.1.10.dist-info → fast_agent_mcp-0.1.12.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.1.10.dist-info → fast_agent_mcp-0.1.12.dist-info}/licenses/LICENSE +0 -0
mcp_agent/core/types.py
CHANGED
@@ -7,7 +7,9 @@ from typing import Dict, Union, TypeAlias, TYPE_CHECKING
 from mcp_agent.agents.agent import Agent
 from mcp_agent.workflows.orchestrator.orchestrator import Orchestrator
 from mcp_agent.workflows.parallel.parallel_llm import ParallelLLM
-from mcp_agent.workflows.evaluator_optimizer.evaluator_optimizer import EvaluatorOptimizerLLM
+from mcp_agent.workflows.evaluator_optimizer.evaluator_optimizer import (
+    EvaluatorOptimizerLLM,
+)
 from mcp_agent.workflows.router.router_llm import LLMRouter

 # Avoid circular imports
@@ -19,4 +21,4 @@ WorkflowType: TypeAlias = Union[
     Orchestrator, ParallelLLM, EvaluatorOptimizerLLM, LLMRouter
 ]
 AgentOrWorkflow: TypeAlias = Union[Agent, WorkflowType]
-ProxyDict: TypeAlias = Dict[str, "BaseAgentProxy"]  # Forward reference as string
+ProxyDict: TypeAlias = Dict[str, "BaseAgentProxy"]  # Forward reference as string
mcp_agent/core/validation.py
CHANGED
@@ -5,14 +5,18 @@ Validation utilities for FastAgent configuration and dependencies.
 from typing import Dict, List, Any
 from mcp_agent.core.agent_types import AgentType
 from mcp_agent.workflows.llm.augmented_llm import AugmentedLLM
-from mcp_agent.core.exceptions import ServerConfigError, AgentConfigError, CircularDependencyError
+from mcp_agent.core.exceptions import (
+    ServerConfigError,
+    AgentConfigError,
+    CircularDependencyError,
+)


 def validate_server_references(context, agents: Dict[str, Dict[str, Any]]) -> None:
     """
     Validate that all server references in agent configurations exist in config.
     Raises ServerConfigError if any referenced servers are not defined.
-
+
     Args:
         context: Application context
         agents: Dictionary of agent configurations
@@ -39,7 +43,7 @@ def validate_workflow_references(agents: Dict[str, Dict[str, Any]]) -> None:
     Validate that all workflow references point to valid agents/workflows.
     Also validates that referenced agents have required configuration.
     Raises AgentConfigError if any validation fails.
-
+
     Args:
         agents: Dictionary of agent configurations
     """
@@ -133,11 +137,11 @@ def validate_workflow_references(agents: Dict[str, Dict[str, Any]]) -> None:


 def get_dependencies(
-    name: str,
+    name: str,
     agents: Dict[str, Dict[str, Any]],
-    visited: set,
-    path: set,
-    agent_type: AgentType = None
+    visited: set,
+    path: set,
+    agent_type: AgentType = None,
 ) -> List[str]:
     """
     Get dependencies for an agent in topological order.
@@ -184,9 +188,7 @@ def get_dependencies(
     # Get dependencies from sequence agents
     sequence = config.get("sequence", config.get("agents", []))
     for agent_name in sequence:
-        deps.extend(
-            get_dependencies(agent_name, agents, visited, path, agent_type)
-        )
+        deps.extend(get_dependencies(agent_name, agents, visited, path, agent_type))

     # Add this agent after its dependencies
     deps.append(name)
@@ -197,10 +199,7 @@


 def get_parallel_dependencies(
-    name: str,
-    agents: Dict[str, Dict[str, Any]],
-    visited: set,
-    path: set
+    name: str, agents: Dict[str, Dict[str, Any]], visited: set, path: set
 ) -> List[str]:
     """
     Get dependencies for a parallel agent in topological order.
@@ -218,4 +217,4 @@
     Raises:
         CircularDependencyError: If circular dependency detected
     """
-    return get_dependencies(name, agents, visited, path, AgentType.PARALLEL)
+    return get_dependencies(name, agents, visited, path, AgentType.PARALLEL)
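The get_dependencies hunks above are mostly formatting, but the contract is worth illustrating: the function walks an agent's references depth-first, raises on cycles, and returns dependencies before dependents. A standalone sketch of that ordering follows — not the library function (which reads agent config dictionaries), just the topological pattern it implements:

def topo_deps(name, graph, visited=None, path=None):
    # graph maps an agent name to the names it depends on.
    visited = set() if visited is None else visited
    path = set() if path is None else path
    if name in path:
        raise ValueError(f"Circular dependency detected at {name}")
    if name in visited:
        return []
    path.add(name)
    deps = []
    for child in graph.get(name, []):
        deps.extend(topo_deps(child, graph, visited, path))
    deps.append(name)  # the agent comes after everything it depends on
    visited.add(name)
    path.remove(name)
    return deps

print(topo_deps("chain", {"chain": ["a", "b"], "b": ["a"]}))  # ['a', 'b', 'chain']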
mcp_agent/logging/transport.py
CHANGED
@@ -290,7 +290,7 @@ class AsyncEventBus:
         # Update transport if provided
         cls._instance.transport = transport
         return cls._instance
-
+
     @classmethod
     def reset(cls) -> None:
         """
@@ -302,7 +302,7 @@ class AsyncEventBus:
         # Signal shutdown
         cls._instance._running = False
         cls._instance._stop_event.set()
-
+
         # Clear the singleton instance
         cls._instance = None

mcp_agent/mcp/gen_client.py
CHANGED
@@ -6,7 +6,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from mcp import ClientSession

 from mcp_agent.logging.logger import get_logger
-from mcp_agent.mcp_server_registry import ServerRegistry
+from mcp_agent.mcp.interfaces import ServerRegistryProtocol
 from mcp_agent.mcp.mcp_agent_client_session import MCPAgentClientSession

 logger = get_logger(__name__)
@@ -15,7 +15,7 @@ logger = get_logger(__name__)
 @asynccontextmanager
 async def gen_client(
     server_name: str,
-    server_registry: ServerRegistry,
+    server_registry: ServerRegistryProtocol,
     client_session_factory: Callable[
         [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None],
         ClientSession,
@@ -41,7 +41,7 @@ async def gen_client(

 async def connect(
     server_name: str,
-    server_registry: ServerRegistry,
+    server_registry: ServerRegistryProtocol,
     client_session_factory: Callable[
         [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None],
         ClientSession,
@@ -67,7 +67,7 @@ async def connect(

 async def disconnect(
     server_name: str | None,
-    server_registry: ServerRegistry,
+    server_registry: ServerRegistryProtocol,
 ) -> None:
     """
     Disconnect from the specified server. If server_name is None, disconnect from all servers.
mcp_agent/mcp/interfaces.py
ADDED
@@ -0,0 +1,186 @@
+"""
+Interface definitions to prevent circular imports.
+This module defines protocols (interfaces) that can be used to break circular dependencies.
+"""
+
+from contextlib import asynccontextmanager
+from typing import (
+    Any,
+    AsyncGenerator,
+    Callable,
+    Generic,
+    List,
+    Optional,
+    Protocol,
+    Type,
+    TypeVar,
+)
+
+from mcp import ClientSession
+from mcp.types import CreateMessageRequestParams
+from pydantic import Field
+
+from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
+
+
+class ServerRegistryProtocol(Protocol):
+    """
+    Protocol defining the minimal interface of ServerRegistry needed by gen_client.
+    This allows gen_client to depend on this protocol rather than the full ServerRegistry class.
+    """
+
+    @asynccontextmanager
+    async def initialize_server(
+        self,
+        server_name: str,
+        client_session_factory=None,
+        init_hook=None,
+    ) -> AsyncGenerator[ClientSession, None]:
+        """Initialize a server and yield a client session."""
+        ...
+
+    @property
+    def connection_manager(self) -> "ConnectionManagerProtocol":
+        """Get the connection manager."""
+        ...
+
+
+class ConnectionManagerProtocol(Protocol):
+    """
+    Protocol defining the minimal interface of ConnectionManager needed.
+    """
+
+    async def get_server(
+        self,
+        server_name: str,
+        client_session_factory=None,
+    ):
+        """Get a server connection."""
+        ...
+
+    async def disconnect_server(self, server_name: str) -> None:
+        """Disconnect from a server."""
+        ...
+
+    async def disconnect_all_servers(self) -> None:
+        """Disconnect from all servers."""
+        ...
+
+
+# Type variables for generic protocols
+MessageParamT = TypeVar("MessageParamT")
+"""A type representing an input message to an LLM."""
+
+MessageT = TypeVar("MessageT")
+"""A type representing an output message from an LLM."""
+
+ModelT = TypeVar("ModelT")
+"""A type representing a structured output message from an LLM."""
+
+
+class RequestParams(CreateMessageRequestParams):
+    """
+    Parameters to configure the AugmentedLLM 'generate' requests.
+    """
+
+    messages: None = Field(exclude=True, default=None)
+    """
+    Ignored. 'messages' are removed from CreateMessageRequestParams
+    to avoid confusion with the 'message' parameter on 'generate' method.
+    """
+
+    maxTokens: int = 2048
+    """The maximum number of tokens to sample, as requested by the server."""
+
+    model: str | None = None
+    """
+    The model to use for the LLM generation.
+    If specified, this overrides the 'modelPreferences' selection criteria.
+    """
+
+    use_history: bool = True
+    """
+    Include the message history in the generate request.
+    """
+
+    max_iterations: int = 10
+    """
+    The maximum number of iterations to run the LLM for.
+    """
+
+    parallel_tool_calls: bool = True
+    """
+    Whether to allow multiple tool calls per iteration.
+    Also known as multi-step tool use.
+    """
+
+
+class AugmentedLLMProtocol(Protocol, Generic[MessageParamT, MessageT]):
+    """Protocol defining the interface for augmented LLMs"""
+
+    async def generate(
+        self,
+        message: str | MessageParamT | List[MessageParamT],
+        request_params: RequestParams | None = None,
+    ) -> List[MessageT]:
+        """Request an LLM generation, which may run multiple iterations, and return the result"""
+
+    async def generate_str(
+        self,
+        message: str | MessageParamT | List[MessageParamT],
+        request_params: RequestParams | None = None,
+    ) -> str:
+        """Request an LLM generation and return the string representation of the result"""
+
+    async def generate_structured(
+        self,
+        message: str | MessageParamT | List[MessageParamT],
+        response_model: Type[ModelT],
+        request_params: RequestParams | None = None,
+    ) -> ModelT:
+        """Request a structured LLM generation and return the result as a Pydantic model."""
+
+    async def generate_prompt(
+        self, prompt: PromptMessageMultipart, request_params: RequestParams | None
+    ) -> str:
+        """Request an LLM generation and return a string representation of the result"""
+
+    async def apply_prompt(
+        self,
+        multipart_messages: List["PromptMessageMultipart"],
+        request_params: RequestParams | None = None,
+    ) -> str:
+        """
+        Apply a list of PromptMessageMultipart messages directly to the LLM.
+        This is a cleaner interface to _apply_prompt_template_provider_specific.
+
+        Args:
+            multipart_messages: List of PromptMessageMultipart objects
+            request_params: Optional parameters to configure the LLM request
+
+        Returns:
+            String representation of the assistant's response
+        """
+
+
+class ModelFactoryClassProtocol(Protocol):
+    """
+    Protocol defining the minimal interface of the ModelFactory class needed by sampling.
+    This allows sampling.py to depend on this protocol rather than the concrete ModelFactory class.
+    """
+
+    @classmethod
+    def create_factory(
+        cls, model_string: str, request_params: Optional[RequestParams] = None
+    ) -> Callable[..., AugmentedLLMProtocol[Any, Any]]:
+        """
+        Creates a factory function that can be used to construct an LLM instance.
+
+        Args:
+            model_string: The model specification string
+            request_params: Optional parameters to configure LLM behavior
+
+        Returns:
+            A factory function that can create an LLM instance
+        """
+        ...
mcp_agent/mcp/mcp_agent_client_session.py
CHANGED
@@ -24,6 +24,7 @@ from pydantic import AnyUrl
 from mcp_agent.config import MCPServerSettings
 from mcp_agent.context_dependent import ContextDependent
 from mcp_agent.logging.logger import get_logger
+from mcp_agent.mcp.sampling import sample

 logger = get_logger(__name__)

@@ -40,7 +41,12 @@ async def list_roots(ctx: ClientSession) -> ListRootsResult:
         and ctx.session.server_config.roots
     ):
         roots = [
-            Root(uri=AnyUrl(root.server_uri_alias or root.uri), name=root.name)
+            Root(
+                uri=AnyUrl(
+                    root.server_uri_alias or root.uri,
+                ),
+                name=root.name,
+            )
             for root in ctx.session.server_config.roots
         ]
     return ListRootsResult(roots=roots or [])
@@ -58,7 +64,9 @@ class MCPAgentClientSession(ClientSession, ContextDependent):
     """

     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs, list_roots_callback=list_roots)
+        super().__init__(
+            *args, **kwargs, list_roots_callback=list_roots, sampling_callback=sample
+        )
         self.server_config: Optional[MCPServerSettings] = None

     async def send_request(
mcp_agent/mcp/mcp_aggregator.py
CHANGED
@@ -16,6 +16,7 @@ from mcp.server.stdio import stdio_server
 from mcp.types import (
     CallToolResult,
     ListToolsResult,
+    TextContent,
     Tool,
     Prompt,
 )
@@ -459,7 +460,10 @@ class MCPAggregator(ContextDependent):

         if server_name is None or local_tool_name is None:
             logger.error(f"Error: Tool '{name}' not found")
-            return CallToolResult(
+            return CallToolResult(
+                isError=True,
+                content=[TextContent(type="text", text=f"Tool '{name}' not found")],
+            )

         logger.info(
             "Requesting tool call",
@@ -477,7 +481,9 @@ class MCPAggregator(ContextDependent):
             operation_name=local_tool_name,
             method_name="call_tool",
             method_args={"name": local_tool_name, "arguments": arguments},
-            error_factory=lambda msg: CallToolResult(
+            error_factory=lambda msg: CallToolResult(
+                isError=True, content=[TextContent(type="text", text=msg)]
+            ),
         )

     async def get_prompt(
@@ -898,7 +904,10 @@ class MCPCompoundServer(Server):
             result = await self.aggregator.call_tool(name=name, arguments=arguments)
             return result.content
         except Exception as e:
-            return CallToolResult(
+            return CallToolResult(
+                isError=True,
+                content=[TextContent(type="text", text=f"Error calling tool: {e}")],
+            )

     async def _get_prompt(
         self, name: str = None, arguments: dict[str, str] = None
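The three aggregator hunks apply the same pattern: MCP tool failures are reported as CallToolResult values with isError=True and a TextContent payload, rather than raised as exceptions. A minimal sketch of that helper shape using the real mcp types:

from mcp.types import CallToolResult, TextContent

def tool_error(message: str) -> CallToolResult:
    # Matches the error_factory shape used by MCPAggregator above.
    return CallToolResult(
        isError=True,
        content=[TextContent(type="text", text=message)],
    )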
mcp_agent/mcp/sampling.py
ADDED
@@ -0,0 +1,140 @@
+"""
+This simplified implementation directly converts between MCP types and PromptMessageMultipart.
+"""
+
+from mcp import ClientSession
+from mcp.types import (
+    CreateMessageRequestParams,
+    CreateMessageResult,
+)
+
+from mcp_agent.core.agent_types import AgentConfig
+from mcp_agent.logging.logger import get_logger
+from mcp_agent.mcp.interfaces import AugmentedLLMProtocol
+
+from mcp_agent.workflows.llm.sampling_converter import SamplingConverter
+
+logger = get_logger(__name__)
+
+
+def create_sampling_llm(
+    params: CreateMessageRequestParams, model_string: str
+) -> AugmentedLLMProtocol:
+    """
+    Create an LLM instance for sampling without tools support.
+    This utility function creates a minimal LLM instance based on the model string.
+
+    Args:
+        mcp_ctx: The MCP ClientSession
+        model_string: The model to use (e.g. "passthrough", "claude-3-5-sonnet-latest")
+
+    Returns:
+        An initialized LLM instance ready to use
+    """
+    from mcp_agent.workflows.llm.model_factory import ModelFactory
+    from mcp_agent.agents.agent import Agent
+
+    app_context = None
+    try:
+        from mcp_agent.context import get_current_context
+
+        app_context = get_current_context()
+    except Exception:
+        logger.warning("App context not available for sampling call")
+
+    agent = Agent(
+        config=sampling_agent_config(params),
+        context=app_context,
+        connection_persistence=False,
+    )
+
+    # Create the LLM using the factory
+    factory = ModelFactory.create_factory(model_string)
+    llm = factory(agent=agent)
+
+    # Attach the LLM to the agent
+    agent._llm = llm
+
+    return llm
+
+
+async def sample(
+    mcp_ctx: ClientSession, params: CreateMessageRequestParams
+) -> CreateMessageResult:
+    """
+    Handle sampling requests from the MCP protocol using SamplingConverter.
+
+    This function:
+    1. Extracts the model from the request
+    2. Uses SamplingConverter to convert types
+    3. Calls the LLM's generate_prompt method
+    4. Returns the result as a CreateMessageResult
+
+    Args:
+        mcp_ctx: The MCP ClientSession
+        params: The sampling request parameters
+
+    Returns:
+        A CreateMessageResult containing the LLM's response
+    """
+    model = None
+    try:
+        # Extract model from server config
+        if (
+            hasattr(mcp_ctx, "session")
+            and hasattr(mcp_ctx.session, "server_config")
+            and mcp_ctx.session.server_config
+            and hasattr(mcp_ctx.session.server_config, "sampling")
+            and mcp_ctx.session.server_config.sampling.model
+        ):
+            model = mcp_ctx.session.server_config.sampling.model
+
+        if model is None:
+            raise ValueError("No model configured")
+
+        # Create an LLM instance
+        llm = create_sampling_llm(params, model)
+
+        # Extract all messages from the request params
+        if not params.messages:
+            raise ValueError("No messages provided")
+
+        # Convert all SamplingMessages to PromptMessageMultipart objects
+        conversation = SamplingConverter.convert_messages(params.messages)
+
+        # Extract request parameters using our converter
+        request_params = SamplingConverter.extract_request_params(params)
+
+        # Use the new public apply_prompt method which is cleaner than calling the protected method
+        llm_response = await llm.apply_prompt(conversation, request_params)
+        logger.info(f"Complete sampling request: {llm_response[:50]}...")
+
+        # Create result using our converter
+        return SamplingConverter.create_message_result(
+            response=llm_response, model=model
+        )
+    except Exception as e:
+        logger.error(f"Error in sampling: {str(e)}")
+        return SamplingConverter.error_result(
+            error_message=f"Error in sampling: {str(e)}", model=model
+        )
+
+
+def sampling_agent_config(
+    params: CreateMessageRequestParams = None,
+) -> AgentConfig:
+    """
+    Build a sampling AgentConfig based on request parameters.
+
+    Args:
+        params: Optional CreateMessageRequestParams that may contain a system prompt
+
+    Returns:
+        An initialized AgentConfig for use in sampling
+    """
+    # Use systemPrompt from params if available, otherwise use default
+    instruction = "You are a helpful AI Agent."
+    if params and hasattr(params, "systemPrompt") and params.systemPrompt is not None:
+        instruction = params.systemPrompt
+
+    return AgentConfig(name="sampling_agent", instruction=instruction, servers=[])
mcp_agent/mcp/stdio.py
CHANGED
@@ -14,7 +14,7 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 logger = get_logger(__name__)


-# TODO this will be removed when client library
+# TODO this will be removed when client library with https://github.com/modelcontextprotocol/python-sdk/pull/343 is released
 @asynccontextmanager
 async def stdio_client_with_rich_stderr(server: StdioServerParameters):
     """
@@ -95,7 +95,6 @@ async def stdio_client_with_rich_stderr(server: StdioServerParameters):
         async with write_stream_reader:
             async for message in write_stream_reader:
                 json = message.model_dump_json(by_alias=True, exclude_none=True)
-                print(f"**********{id(process.stdin)}")
                 await process.stdin.send(
                     (json + "\n").encode(
                         encoding=server.encoding,
mcp_agent/mcp_server/__init__.py
CHANGED
@@ -6,7 +6,7 @@ fast = FastAgent("FastAgent Example")


 # Define the agent
-@fast.agent(servers=["fetch"])
+@fast.agent(servers=["fetch", "mcp_hfspace"])
 async def main():
     # use the --model command line switch or agent arguments to change model
     async with fast.run() as agent:
mcp_agent/ui/console_display.py
CHANGED
@@ -250,14 +250,14 @@ class ConsoleDisplay:

         if agent_name:
             content.append(f" for {agent_name}", style="cyan italic")
-
+
         # Add template arguments if provided
         if arguments:
             content.append("\n\nArguments:", style="cyan")
             for key, value in arguments.items():
                 content.append(f"\n  {key}: ", style="cyan bold")
                 content.append(value, style="white")
-
+
         if description:
             content.append("\n\n", style="default")
             content.append(description, style="dim white")
mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py
CHANGED
@@ -10,7 +10,8 @@ from mcp_agent.workflows.llm.augmented_llm import (
     ModelT,
     RequestParams,
 )
-from mcp_agent.agents.agent import Agent
+from mcp_agent.agents.agent import Agent
+from mcp_agent.core.agent_types import AgentConfig
 from mcp_agent.logging.logger import get_logger
 from mcp_agent.workflows.llm.augmented_llm_passthrough import PassthroughLLM

@@ -68,7 +69,6 @@ class EvaluatorOptimizerLLM(AugmentedLLM[MessageParamT, MessageT]):
     def _initialize_default_params(self, kwargs: dict) -> RequestParams:
         """Initialize default parameters using the workflow's settings."""
         return RequestParams(
-            modelPreferences=self.model_preferences,
             systemPrompt=self.instruction,
             parallel_tool_calls=True,
             max_iterations=10,