fast-agent-mcp 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of fast-agent-mcp might be problematic. Click here for more details.
- {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.12.dist-info}/METADATA +1 -1
- {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.12.dist-info}/RECORD +39 -38
- mcp_agent/agents/agent.py +1 -24
- mcp_agent/app.py +0 -5
- mcp_agent/context.py +0 -2
- mcp_agent/core/agent_app.py +1 -1
- mcp_agent/core/agent_types.py +29 -2
- mcp_agent/core/decorators.py +1 -2
- mcp_agent/core/error_handling.py +1 -1
- mcp_agent/core/factory.py +2 -3
- mcp_agent/core/mcp_content.py +2 -3
- mcp_agent/core/request_params.py +43 -0
- mcp_agent/core/types.py +4 -2
- mcp_agent/core/validation.py +14 -15
- mcp_agent/logging/transport.py +2 -2
- mcp_agent/mcp/interfaces.py +37 -3
- mcp_agent/mcp/mcp_agent_client_session.py +1 -1
- mcp_agent/mcp/mcp_aggregator.py +5 -6
- mcp_agent/mcp/sampling.py +60 -53
- mcp_agent/mcp_server/__init__.py +1 -1
- mcp_agent/resources/examples/prompting/__init__.py +1 -1
- mcp_agent/ui/console_display.py +2 -2
- mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +2 -2
- mcp_agent/workflows/llm/augmented_llm.py +42 -102
- mcp_agent/workflows/llm/augmented_llm_anthropic.py +4 -3
- mcp_agent/workflows/llm/augmented_llm_openai.py +4 -3
- mcp_agent/workflows/llm/augmented_llm_passthrough.py +33 -4
- mcp_agent/workflows/llm/model_factory.py +1 -1
- mcp_agent/workflows/llm/prompt_utils.py +42 -28
- mcp_agent/workflows/llm/providers/multipart_converter_anthropic.py +244 -140
- mcp_agent/workflows/llm/providers/multipart_converter_openai.py +230 -185
- mcp_agent/workflows/llm/providers/sampling_converter_anthropic.py +5 -204
- mcp_agent/workflows/llm/providers/sampling_converter_openai.py +9 -207
- mcp_agent/workflows/llm/sampling_converter.py +124 -0
- mcp_agent/workflows/llm/sampling_format_converter.py +0 -17
- mcp_agent/workflows/router/router_base.py +10 -10
- mcp_agent/workflows/llm/llm_selector.py +0 -345
- {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.12.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.12.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.12.dist-info}/licenses/LICENSE +0 -0
mcp_agent/mcp/mcp_aggregator.py
CHANGED
@@ -461,8 +461,8 @@ class MCPAggregator(ContextDependent):
|
|
461
461
|
if server_name is None or local_tool_name is None:
|
462
462
|
logger.error(f"Error: Tool '{name}' not found")
|
463
463
|
return CallToolResult(
|
464
|
-
isError=True,
|
465
|
-
content=[TextContent(type="text", text=f"Tool '{name}' not found")]
|
464
|
+
isError=True,
|
465
|
+
content=[TextContent(type="text", text=f"Tool '{name}' not found")],
|
466
466
|
)
|
467
467
|
|
468
468
|
logger.info(
|
@@ -482,8 +482,7 @@ class MCPAggregator(ContextDependent):
|
|
482
482
|
method_name="call_tool",
|
483
483
|
method_args={"name": local_tool_name, "arguments": arguments},
|
484
484
|
error_factory=lambda msg: CallToolResult(
|
485
|
-
isError=True,
|
486
|
-
content=[TextContent(type="text", text=msg)]
|
485
|
+
isError=True, content=[TextContent(type="text", text=msg)]
|
487
486
|
),
|
488
487
|
)
|
489
488
|
|
@@ -906,8 +905,8 @@ class MCPCompoundServer(Server):
|
|
906
905
|
return result.content
|
907
906
|
except Exception as e:
|
908
907
|
return CallToolResult(
|
909
|
-
isError=True,
|
910
|
-
content=[TextContent(type="text", text=f"Error calling tool: {e}")]
|
908
|
+
isError=True,
|
909
|
+
content=[TextContent(type="text", text=f"Error calling tool: {e}")],
|
911
910
|
)
|
912
911
|
|
913
912
|
async def _get_prompt(
|
mcp_agent/mcp/sampling.py
CHANGED
@@ -1,26 +1,24 @@
|
|
1
1
|
"""
|
2
|
-
|
3
|
-
This module is carefully designed to avoid circular imports in the agent system.
|
2
|
+
This simplified implementation directly converts between MCP types and PromptMessageMultipart.
|
4
3
|
"""
|
5
4
|
|
6
5
|
from mcp import ClientSession
|
7
6
|
from mcp.types import (
|
8
7
|
CreateMessageRequestParams,
|
9
8
|
CreateMessageResult,
|
10
|
-
TextContent,
|
11
9
|
)
|
12
10
|
|
11
|
+
from mcp_agent.core.agent_types import AgentConfig
|
13
12
|
from mcp_agent.logging.logger import get_logger
|
14
13
|
from mcp_agent.mcp.interfaces import AugmentedLLMProtocol
|
15
|
-
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
|
16
14
|
|
17
|
-
|
15
|
+
from mcp_agent.workflows.llm.sampling_converter import SamplingConverter
|
18
16
|
|
19
17
|
logger = get_logger(__name__)
|
20
18
|
|
21
19
|
|
22
20
|
def create_sampling_llm(
|
23
|
-
|
21
|
+
params: CreateMessageRequestParams, model_string: str
|
24
22
|
) -> AugmentedLLMProtocol:
|
25
23
|
"""
|
26
24
|
Create an LLM instance for sampling without tools support.
|
@@ -34,10 +32,8 @@ def create_sampling_llm(
|
|
34
32
|
An initialized LLM instance ready to use
|
35
33
|
"""
|
36
34
|
from mcp_agent.workflows.llm.model_factory import ModelFactory
|
37
|
-
from mcp_agent.agents.agent import Agent
|
35
|
+
from mcp_agent.agents.agent import Agent
|
38
36
|
|
39
|
-
# Get application context from global state if available
|
40
|
-
# We don't try to extract it from mcp_ctx as they're different contexts
|
41
37
|
app_context = None
|
42
38
|
try:
|
43
39
|
from mcp_agent.context import get_current_context
|
@@ -46,20 +42,10 @@ def create_sampling_llm(
|
|
46
42
|
except Exception:
|
47
43
|
logger.warning("App context not available for sampling call")
|
48
44
|
|
49
|
-
# Create a minimal agent configuration
|
50
|
-
agent_config = AgentConfig(
|
51
|
-
name="sampling_agent",
|
52
|
-
instruction="You are a sampling agent.",
|
53
|
-
servers=[], # No servers needed
|
54
|
-
)
|
55
|
-
|
56
|
-
# Create agent with our application context (not the MCP context)
|
57
|
-
# Set connection_persistence=False to avoid server connections
|
58
45
|
agent = Agent(
|
59
|
-
config=
|
46
|
+
config=sampling_agent_config(params),
|
60
47
|
context=app_context,
|
61
|
-
|
62
|
-
connection_persistence=False, # Avoid server connection management
|
48
|
+
connection_persistence=False,
|
63
49
|
)
|
64
50
|
|
65
51
|
# Create the LLM using the factory
|
@@ -76,9 +62,20 @@ async def sample(
|
|
76
62
|
mcp_ctx: ClientSession, params: CreateMessageRequestParams
|
77
63
|
) -> CreateMessageResult:
|
78
64
|
"""
|
79
|
-
Handle sampling requests from the MCP protocol.
|
80
|
-
|
81
|
-
|
65
|
+
Handle sampling requests from the MCP protocol using SamplingConverter.
|
66
|
+
|
67
|
+
This function:
|
68
|
+
1. Extracts the model from the request
|
69
|
+
2. Uses SamplingConverter to convert types
|
70
|
+
3. Calls the LLM's generate_prompt method
|
71
|
+
4. Returns the result as a CreateMessageResult
|
72
|
+
|
73
|
+
Args:
|
74
|
+
mcp_ctx: The MCP ClientSession
|
75
|
+
params: The sampling request parameters
|
76
|
+
|
77
|
+
Returns:
|
78
|
+
A CreateMessageResult containing the LLM's response
|
82
79
|
"""
|
83
80
|
model = None
|
84
81
|
try:
|
@@ -95,39 +92,49 @@ async def sample(
|
|
95
92
|
if model is None:
|
96
93
|
raise ValueError("No model configured")
|
97
94
|
|
98
|
-
# Create an LLM instance
|
99
|
-
llm = create_sampling_llm(
|
95
|
+
# Create an LLM instance
|
96
|
+
llm = create_sampling_llm(params, model)
|
100
97
|
|
101
|
-
#
|
102
|
-
|
98
|
+
# Extract all messages from the request params
|
99
|
+
if not params.messages:
|
100
|
+
raise ValueError("No messages provided")
|
103
101
|
|
104
|
-
#
|
105
|
-
|
106
|
-
|
107
|
-
|
102
|
+
# Convert all SamplingMessages to PromptMessageMultipart objects
|
103
|
+
conversation = SamplingConverter.convert_messages(params.messages)
|
104
|
+
|
105
|
+
# Extract request parameters using our converter
|
106
|
+
request_params = SamplingConverter.extract_request_params(params)
|
108
107
|
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
logger.error(f"Error generating response: {str(e)}")
|
117
|
-
llm_response = f"Echo response: {user_message}"
|
118
|
-
|
119
|
-
# Return the LLM-generated response
|
120
|
-
return CreateMessageResult(
|
121
|
-
role="assistant",
|
122
|
-
content=TextContent(type="text", text=llm_response),
|
123
|
-
model=model,
|
124
|
-
stopReason="endTurn",
|
108
|
+
# Use the new public apply_prompt method which is cleaner than calling the protected method
|
109
|
+
llm_response = await llm.apply_prompt(conversation, request_params)
|
110
|
+
logger.info(f"Complete sampling request : {llm_response[:50]}...")
|
111
|
+
|
112
|
+
# Create result using our converter
|
113
|
+
return SamplingConverter.create_message_result(
|
114
|
+
response=llm_response, model=model
|
125
115
|
)
|
126
116
|
except Exception as e:
|
127
117
|
logger.error(f"Error in sampling: {str(e)}")
|
128
|
-
return
|
129
|
-
|
130
|
-
content=TextContent(type="text", text=f"Error in sampling: {str(e)}"),
|
131
|
-
model=model or "unknown",
|
132
|
-
stopReason="error",
|
118
|
+
return SamplingConverter.error_result(
|
119
|
+
error_message=f"Error in sampling: {str(e)}", model=model
|
133
120
|
)
|
121
|
+
|
122
|
+
|
123
|
+
def sampling_agent_config(
|
124
|
+
params: CreateMessageRequestParams = None,
|
125
|
+
) -> AgentConfig:
|
126
|
+
"""
|
127
|
+
Build a sampling AgentConfig based on request parameters.
|
128
|
+
|
129
|
+
Args:
|
130
|
+
params: Optional CreateMessageRequestParams that may contain a system prompt
|
131
|
+
|
132
|
+
Returns:
|
133
|
+
An initialized AgentConfig for use in sampling
|
134
|
+
"""
|
135
|
+
# Use systemPrompt from params if available, otherwise use default
|
136
|
+
instruction = "You are a helpful AI Agent."
|
137
|
+
if params and hasattr(params, "systemPrompt") and params.systemPrompt is not None:
|
138
|
+
instruction = params.systemPrompt
|
139
|
+
|
140
|
+
return AgentConfig(name="sampling_agent", instruction=instruction, servers=[])
|
mcp_agent/mcp_server/__init__.py
CHANGED
mcp_agent/ui/console_display.py
CHANGED
@@ -250,14 +250,14 @@ class ConsoleDisplay:
|
|
250
250
|
|
251
251
|
if agent_name:
|
252
252
|
content.append(f" for {agent_name}", style="cyan italic")
|
253
|
-
|
253
|
+
|
254
254
|
# Add template arguments if provided
|
255
255
|
if arguments:
|
256
256
|
content.append("\n\nArguments:", style="cyan")
|
257
257
|
for key, value in arguments.items():
|
258
258
|
content.append(f"\n {key}: ", style="cyan bold")
|
259
259
|
content.append(value, style="white")
|
260
|
-
|
260
|
+
|
261
261
|
if description:
|
262
262
|
content.append("\n\n", style="default")
|
263
263
|
content.append(description, style="dim white")
|
@@ -10,7 +10,8 @@ from mcp_agent.workflows.llm.augmented_llm import (
|
|
10
10
|
ModelT,
|
11
11
|
RequestParams,
|
12
12
|
)
|
13
|
-
from mcp_agent.agents.agent import Agent
|
13
|
+
from mcp_agent.agents.agent import Agent
|
14
|
+
from mcp_agent.core.agent_types import AgentConfig
|
14
15
|
from mcp_agent.logging.logger import get_logger
|
15
16
|
from mcp_agent.workflows.llm.augmented_llm_passthrough import PassthroughLLM
|
16
17
|
|
@@ -68,7 +69,6 @@ class EvaluatorOptimizerLLM(AugmentedLLM[MessageParamT, MessageT]):
|
|
68
69
|
def _initialize_default_params(self, kwargs: dict) -> RequestParams:
|
69
70
|
"""Initialize default parameters using the workflow's settings."""
|
70
71
|
return RequestParams(
|
71
|
-
modelPreferences=self.model_preferences,
|
72
72
|
systemPrompt=self.instruction,
|
73
73
|
parallel_tool_calls=True,
|
74
74
|
max_iterations=10,
|
@@ -10,7 +10,6 @@ from typing import (
|
|
10
10
|
TYPE_CHECKING,
|
11
11
|
)
|
12
12
|
|
13
|
-
from mcp import CreateMessageResult, SamplingMessage
|
14
13
|
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
|
15
14
|
from mcp_agent.workflows.llm.sampling_format_converter import (
|
16
15
|
SamplingFormatConverter,
|
@@ -25,23 +24,28 @@ if TYPE_CHECKING:
|
|
25
24
|
from mcp_agent.context import Context
|
26
25
|
|
27
26
|
|
28
|
-
from pydantic import Field
|
29
27
|
|
30
28
|
from mcp.types import (
|
31
29
|
CallToolRequest,
|
32
30
|
CallToolResult,
|
33
|
-
CreateMessageRequestParams,
|
34
|
-
ModelPreferences,
|
35
31
|
PromptMessage,
|
36
32
|
TextContent,
|
37
33
|
GetPromptResult,
|
38
34
|
)
|
39
35
|
|
40
36
|
from mcp_agent.context_dependent import ContextDependent
|
41
|
-
from mcp_agent.core.exceptions import PromptExitError
|
37
|
+
from mcp_agent.core.exceptions import ModelConfigError, PromptExitError
|
38
|
+
from mcp_agent.core.request_params import RequestParams
|
42
39
|
from mcp_agent.event_progress import ProgressAction
|
43
|
-
|
44
|
-
|
40
|
+
|
41
|
+
try:
|
42
|
+
from mcp_agent.mcp.mcp_aggregator import MCPAggregator
|
43
|
+
except ImportError:
|
44
|
+
# For testing purposes
|
45
|
+
class MCPAggregator:
|
46
|
+
pass
|
47
|
+
|
48
|
+
|
45
49
|
from mcp_agent.ui.console_display import ConsoleDisplay
|
46
50
|
from rich.text import Text
|
47
51
|
|
@@ -155,43 +159,6 @@ class SimpleMemory(Memory, Generic[MessageParamT]):
|
|
155
159
|
self.prompt_messages = []
|
156
160
|
|
157
161
|
|
158
|
-
class RequestParams(CreateMessageRequestParams):
|
159
|
-
"""
|
160
|
-
Parameters to configure the AugmentedLLM 'generate' requests.
|
161
|
-
"""
|
162
|
-
|
163
|
-
messages: None = Field(exclude=True, default=None)
|
164
|
-
"""
|
165
|
-
Ignored. 'messages' are removed from CreateMessageRequestParams
|
166
|
-
to avoid confusion with the 'message' parameter on 'generate' method.
|
167
|
-
"""
|
168
|
-
|
169
|
-
maxTokens: int = 2048
|
170
|
-
"""The maximum number of tokens to sample, as requested by the server."""
|
171
|
-
|
172
|
-
model: str | None = None
|
173
|
-
"""
|
174
|
-
The model to use for the LLM generation.
|
175
|
-
If specified, this overrides the 'modelPreferences' selection criteria.
|
176
|
-
"""
|
177
|
-
|
178
|
-
use_history: bool = True
|
179
|
-
"""
|
180
|
-
Include the message history in the generate request.
|
181
|
-
"""
|
182
|
-
|
183
|
-
max_iterations: int = 10
|
184
|
-
"""
|
185
|
-
The maximum number of iterations to run the LLM for.
|
186
|
-
"""
|
187
|
-
|
188
|
-
parallel_tool_calls: bool = True
|
189
|
-
"""
|
190
|
-
Whether to allow multiple tool calls per iteration.
|
191
|
-
Also known as multi-step tool use.
|
192
|
-
"""
|
193
|
-
|
194
|
-
|
195
162
|
class AugmentedLLMProtocol(Protocol, Generic[MessageParamT, MessageT]):
|
196
163
|
"""Protocol defining the interface for augmented LLMs"""
|
197
164
|
|
@@ -269,30 +236,15 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
269
236
|
# Initialize the display component
|
270
237
|
self.display = ConsoleDisplay(config=self.context.config)
|
271
238
|
|
272
|
-
# Set initial model preferences
|
273
|
-
self.model_preferences = ModelPreferences(
|
274
|
-
costPriority=0.3,
|
275
|
-
speedPriority=0.4,
|
276
|
-
intelligencePriority=0.3,
|
277
|
-
)
|
278
|
-
|
279
239
|
# Initialize default parameters
|
280
240
|
self.default_request_params = self._initialize_default_params(kwargs)
|
281
241
|
|
282
|
-
# Update model preferences from default params
|
283
|
-
if self.default_request_params and self.default_request_params.modelPreferences:
|
284
|
-
self.model_preferences = self.default_request_params.modelPreferences
|
285
|
-
|
286
242
|
# Merge with provided params if any
|
287
243
|
if self._init_request_params:
|
288
244
|
self.default_request_params = self._merge_request_params(
|
289
245
|
self.default_request_params, self._init_request_params
|
290
246
|
)
|
291
|
-
# Update model preferences again if they changed in the merge
|
292
|
-
if self.default_request_params.modelPreferences:
|
293
|
-
self.model_preferences = self.default_request_params.modelPreferences
|
294
247
|
|
295
|
-
self.model_selector = self.context.model_selector
|
296
248
|
self.type_converter = type_converter
|
297
249
|
self.verb = kwargs.get("verb")
|
298
250
|
|
@@ -321,39 +273,21 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
321
273
|
) -> ModelT:
|
322
274
|
"""Request a structured LLM generation and return the result as a Pydantic model."""
|
323
275
|
|
324
|
-
# aysnc def generate2_str(self, prompt: PromptMessageMultipart, request_params: RequestParams | None = None) -> List[MessageT]:
|
325
|
-
# """Request an LLM generation, which may run multiple iterations, and return the result"""
|
326
|
-
# return None
|
327
|
-
|
328
276
|
async def select_model(
|
329
277
|
self, request_params: RequestParams | None = None
|
330
278
|
) -> str | None:
|
331
279
|
"""
|
332
|
-
|
333
|
-
If a model is specified in the request, it will override the model selection criteria.
|
280
|
+
Return the configured model (legacy support)
|
334
281
|
"""
|
335
|
-
|
336
|
-
|
337
|
-
model_preferences = request_params.modelPreferences or model_preferences
|
338
|
-
model = request_params.model
|
339
|
-
if model:
|
340
|
-
return model
|
282
|
+
if request_params.model:
|
283
|
+
return request_params.model
|
341
284
|
|
342
|
-
|
343
|
-
if not self.model_selector:
|
344
|
-
self.model_selector = ModelSelector()
|
345
|
-
|
346
|
-
model_info = self.model_selector.select_best_model(
|
347
|
-
model_preferences=model_preferences, provider=self.provider
|
348
|
-
)
|
349
|
-
|
350
|
-
return model_info.name
|
285
|
+
raise ModelConfigError("Internal Error: Model is not configured correctly")
|
351
286
|
|
352
287
|
def _initialize_default_params(self, kwargs: dict) -> RequestParams:
|
353
288
|
"""Initialize default parameters for the LLM.
|
354
289
|
Should be overridden by provider implementations to set provider-specific defaults."""
|
355
290
|
return RequestParams(
|
356
|
-
modelPreferences=self.model_preferences,
|
357
291
|
systemPrompt=self.instruction,
|
358
292
|
parallel_tool_calls=True,
|
359
293
|
max_iterations=10,
|
@@ -395,25 +329,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
395
329
|
|
396
330
|
return default_request_params
|
397
331
|
|
398
|
-
def to_mcp_message_result(self, result: MessageT) -> CreateMessageResult:
|
399
|
-
"""Convert an LLM response to an MCP message result type."""
|
400
|
-
return self.type_converter.to_sampling_result(result)
|
401
|
-
|
402
|
-
def from_mcp_message_result(self, result: CreateMessageResult) -> MessageT:
|
403
|
-
"""Convert an MCP message result to an LLM response type."""
|
404
|
-
return self.type_converter.from_sampling_result(result)
|
405
|
-
|
406
|
-
def to_mcp_message_param(self, param: MessageParamT) -> SamplingMessage:
|
407
|
-
"""Convert an LLM input to an MCP message (SamplingMessage) type."""
|
408
|
-
return self.type_converter.to_sampling_message(param)
|
409
|
-
|
410
|
-
def from_mcp_message_param(self, param: SamplingMessage) -> MessageParamT:
|
411
|
-
"""Convert an MCP message (SamplingMessage) to an LLM input type."""
|
412
|
-
return self.type_converter.from_sampling_message(param)
|
413
|
-
|
414
|
-
def from_mcp_prompt_message(self, message: PromptMessage) -> MessageParamT:
|
415
|
-
return self.type_converter.from_prompt_message(message)
|
416
|
-
|
417
332
|
@classmethod
|
418
333
|
def convert_message_to_message_param(
|
419
334
|
cls, message: MessageT, **kwargs
|
@@ -689,10 +604,35 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
689
604
|
)
|
690
605
|
|
691
606
|
# Delegate to the provider-specific implementation
|
692
|
-
return await self._apply_prompt_template_provider_specific(
|
607
|
+
return await self._apply_prompt_template_provider_specific(
|
608
|
+
multipart_messages, None
|
609
|
+
)
|
610
|
+
|
611
|
+
async def apply_prompt(
|
612
|
+
self,
|
613
|
+
multipart_messages: List["PromptMessageMultipart"],
|
614
|
+
request_params: RequestParams | None = None,
|
615
|
+
) -> str:
|
616
|
+
"""
|
617
|
+
Apply a list of PromptMessageMultipart messages directly to the LLM.
|
618
|
+
This is a cleaner interface to _apply_prompt_template_provider_specific.
|
619
|
+
|
620
|
+
Args:
|
621
|
+
multipart_messages: List of PromptMessageMultipart objects
|
622
|
+
request_params: Optional parameters to configure the LLM request
|
623
|
+
|
624
|
+
Returns:
|
625
|
+
String representation of the assistant's response
|
626
|
+
"""
|
627
|
+
# Delegate to the provider-specific implementation
|
628
|
+
return await self._apply_prompt_template_provider_specific(
|
629
|
+
multipart_messages, request_params
|
630
|
+
)
|
693
631
|
|
694
632
|
async def _apply_prompt_template_provider_specific(
|
695
|
-
self,
|
633
|
+
self,
|
634
|
+
multipart_messages: List["PromptMessageMultipart"],
|
635
|
+
request_params: RequestParams | None = None,
|
696
636
|
) -> str:
|
697
637
|
"""
|
698
638
|
Provider-specific implementation of apply_prompt_template.
|
@@ -60,7 +60,6 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
|
|
60
60
|
"""Initialize Anthropic-specific default parameters"""
|
61
61
|
return RequestParams(
|
62
62
|
model=kwargs.get("model", DEFAULT_ANTHROPIC_MODEL),
|
63
|
-
modelPreferences=self.model_preferences,
|
64
63
|
maxTokens=4096, # default haiku3
|
65
64
|
systemPrompt=self.instruction,
|
66
65
|
parallel_tool_calls=True,
|
@@ -360,7 +359,9 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
|
|
360
359
|
)
|
361
360
|
|
362
361
|
async def _apply_prompt_template_provider_specific(
|
363
|
-
self,
|
362
|
+
self,
|
363
|
+
multipart_messages: List["PromptMessageMultipart"],
|
364
|
+
request_params: RequestParams | None = None,
|
364
365
|
) -> str:
|
365
366
|
"""
|
366
367
|
Anthropic-specific implementation of apply_prompt_template that handles
|
@@ -393,7 +394,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
|
|
393
394
|
"Last message in prompt is from user, generating assistant response"
|
394
395
|
)
|
395
396
|
message_param = AnthropicConverter.convert_to_anthropic(last_message)
|
396
|
-
return await self.generate_str(message_param)
|
397
|
+
return await self.generate_str(message_param, request_params)
|
397
398
|
else:
|
398
399
|
# For assistant messages: Return the last message content as text
|
399
400
|
self.logger.debug(
|
@@ -92,7 +92,6 @@ class OpenAIAugmentedLLM(
|
|
92
92
|
|
93
93
|
return RequestParams(
|
94
94
|
model=chosen_model,
|
95
|
-
modelPreferences=self.model_preferences,
|
96
95
|
systemPrompt=self.instruction,
|
97
96
|
parallel_tool_calls=True,
|
98
97
|
max_iterations=10,
|
@@ -395,7 +394,9 @@ class OpenAIAugmentedLLM(
|
|
395
394
|
return "\n".join(final_text)
|
396
395
|
|
397
396
|
async def _apply_prompt_template_provider_specific(
|
398
|
-
self,
|
397
|
+
self,
|
398
|
+
multipart_messages: List["PromptMessageMultipart"],
|
399
|
+
request_params: RequestParams | None = None,
|
399
400
|
) -> str:
|
400
401
|
"""
|
401
402
|
OpenAI-specific implementation of apply_prompt_template that handles
|
@@ -431,7 +432,7 @@ class OpenAIAugmentedLLM(
|
|
431
432
|
"Last message in prompt is from user, generating assistant response"
|
432
433
|
)
|
433
434
|
message_param = OpenAIConverter.convert_to_openai(last_message)
|
434
|
-
return await self.generate_str(message_param)
|
435
|
+
return await self.generate_str(message_param, request_params)
|
435
436
|
else:
|
436
437
|
# For assistant messages: Return the last message content as text
|
437
438
|
self.logger.debug(
|
@@ -181,21 +181,42 @@ class PassthroughLLM(AugmentedLLM):
|
|
181
181
|
# Join all parts and process with generate_str
|
182
182
|
return await self.generate_str("\n".join(parts_text), request_params)
|
183
183
|
|
184
|
+
async def apply_prompt(
|
185
|
+
self,
|
186
|
+
multipart_messages: List["PromptMessageMultipart"],
|
187
|
+
request_params: Optional[RequestParams] = None,
|
188
|
+
) -> str:
|
189
|
+
"""
|
190
|
+
Apply a list of PromptMessageMultipart messages directly to the LLM.
|
191
|
+
In PassthroughLLM, this returns a concatenated string of all message content.
|
192
|
+
|
193
|
+
Args:
|
194
|
+
multipart_messages: List of PromptMessageMultipart objects
|
195
|
+
request_params: Optional parameters to configure the LLM request
|
196
|
+
|
197
|
+
Returns:
|
198
|
+
String representation of all message content concatenated together
|
199
|
+
"""
|
200
|
+
# Generate and concatenate result from all messages
|
201
|
+
result = ""
|
202
|
+
for prompt in multipart_messages:
|
203
|
+
result += await self.generate_prompt(prompt, request_params) + "\n"
|
204
|
+
|
205
|
+
return result
|
206
|
+
|
184
207
|
async def apply_prompt_template(
|
185
208
|
self, prompt_result: GetPromptResult, prompt_name: str
|
186
209
|
) -> str:
|
187
210
|
"""
|
188
211
|
Apply a prompt template by adding it to the conversation history.
|
189
|
-
|
190
|
-
generate an assistant response.
|
212
|
+
For PassthroughLLM, this returns all content concatenated together.
|
191
213
|
|
192
214
|
Args:
|
193
215
|
prompt_result: The GetPromptResult containing prompt messages
|
194
216
|
prompt_name: The name of the prompt being applied
|
195
217
|
|
196
218
|
Returns:
|
197
|
-
String representation of
|
198
|
-
or the last assistant message in the prompt
|
219
|
+
String representation of all message content concatenated together
|
199
220
|
"""
|
200
221
|
prompt_messages: List[PromptMessage] = prompt_result.messages
|
201
222
|
|
@@ -210,3 +231,11 @@ class PassthroughLLM(AugmentedLLM):
|
|
210
231
|
arguments=arguments,
|
211
232
|
)
|
212
233
|
self._messages = prompt_messages
|
234
|
+
|
235
|
+
# Convert prompt messages to multipart format
|
236
|
+
multipart_messages = PromptMessageMultipart.from_prompt_messages(
|
237
|
+
prompt_messages
|
238
|
+
)
|
239
|
+
|
240
|
+
# Use apply_prompt to handle the multipart messages
|
241
|
+
return await self.apply_prompt(multipart_messages)
|
@@ -4,9 +4,9 @@ from typing import Optional, Type, Dict, Union, Callable
|
|
4
4
|
|
5
5
|
from mcp_agent.agents.agent import Agent
|
6
6
|
from mcp_agent.core.exceptions import ModelConfigError
|
7
|
+
from mcp_agent.core.request_params import RequestParams
|
7
8
|
from mcp_agent.workflows.llm.augmented_llm_anthropic import AnthropicAugmentedLLM
|
8
9
|
from mcp_agent.workflows.llm.augmented_llm_openai import OpenAIAugmentedLLM
|
9
|
-
from mcp_agent.workflows.llm.augmented_llm import RequestParams
|
10
10
|
from mcp_agent.workflows.llm.augmented_llm_passthrough import PassthroughLLM
|
11
11
|
from mcp_agent.workflows.llm.augmented_llm_playback import PlaybackLLM
|
12
12
|
|