fast-agent-mcp 0.2.16__py3-none-any.whl → 0.2.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fast_agent_mcp-0.2.16.dist-info → fast_agent_mcp-0.2.18.dist-info}/METADATA +6 -7
- {fast_agent_mcp-0.2.16.dist-info → fast_agent_mcp-0.2.18.dist-info}/RECORD +48 -47
- mcp_agent/agents/base_agent.py +50 -6
- mcp_agent/agents/workflow/orchestrator_agent.py +6 -7
- mcp_agent/agents/workflow/router_agent.py +70 -136
- mcp_agent/app.py +1 -124
- mcp_agent/cli/commands/go.py +133 -0
- mcp_agent/cli/commands/setup.py +2 -2
- mcp_agent/cli/main.py +5 -3
- mcp_agent/config.py +16 -13
- mcp_agent/context.py +4 -22
- mcp_agent/core/agent_types.py +2 -2
- mcp_agent/core/direct_decorators.py +2 -2
- mcp_agent/core/direct_factory.py +2 -1
- mcp_agent/core/enhanced_prompt.py +12 -7
- mcp_agent/core/fastagent.py +39 -5
- mcp_agent/core/interactive_prompt.py +6 -2
- mcp_agent/core/request_params.py +5 -1
- mcp_agent/core/validation.py +12 -1
- mcp_agent/executor/workflow_signal.py +0 -2
- mcp_agent/llm/augmented_llm.py +183 -57
- mcp_agent/llm/augmented_llm_passthrough.py +1 -1
- mcp_agent/llm/augmented_llm_playback.py +21 -1
- mcp_agent/llm/memory.py +3 -3
- mcp_agent/llm/model_factory.py +3 -1
- mcp_agent/llm/provider_key_manager.py +1 -0
- mcp_agent/llm/provider_types.py +2 -1
- mcp_agent/llm/providers/augmented_llm_anthropic.py +50 -10
- mcp_agent/llm/providers/augmented_llm_deepseek.py +1 -5
- mcp_agent/llm/providers/augmented_llm_google.py +30 -0
- mcp_agent/llm/providers/augmented_llm_openai.py +96 -159
- mcp_agent/llm/providers/multipart_converter_openai.py +10 -27
- mcp_agent/llm/providers/sampling_converter_openai.py +5 -6
- mcp_agent/mcp/interfaces.py +6 -1
- mcp_agent/mcp/mcp_aggregator.py +2 -8
- mcp_agent/mcp/prompt_message_multipart.py +25 -2
- mcp_agent/resources/examples/data-analysis/analysis-campaign.py +2 -2
- mcp_agent/resources/examples/in_dev/agent_build.py +1 -1
- mcp_agent/resources/examples/internal/job.py +1 -1
- mcp_agent/resources/examples/mcp/state-transfer/fastagent.config.yaml +1 -1
- mcp_agent/resources/examples/prompting/agent.py +0 -2
- mcp_agent/resources/examples/prompting/fastagent.config.yaml +2 -3
- mcp_agent/resources/examples/researcher/fastagent.config.yaml +1 -6
- mcp_agent/resources/examples/workflows/fastagent.config.yaml +0 -1
- mcp_agent/resources/examples/workflows/parallel.py +1 -1
- mcp_agent/executor/decorator_registry.py +0 -112
- {fast_agent_mcp-0.2.16.dist-info → fast_agent_mcp-0.2.18.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.2.16.dist-info → fast_agent_mcp-0.2.18.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.2.16.dist-info → fast_agent_mcp-0.2.18.dist-info}/licenses/LICENSE +0 -0
mcp_agent/llm/augmented_llm.py
CHANGED
@@ -18,6 +18,8 @@ from mcp.types import (
     PromptMessage,
     TextContent,
 )
+from openai import NotGiven
+from openai.lib._parsing import type_to_response_format_param as _type_to_response_format
 from pydantic_core import from_json
 from rich.text import Text
 
@@ -58,6 +60,20 @@ HUMAN_INPUT_TOOL_NAME = "__human_input__"
 
 
 class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT, MessageT]):
+    # Common parameter names used across providers
+    PARAM_MESSAGES = "messages"
+    PARAM_MODEL = "model"
+    PARAM_MAX_TOKENS = "maxTokens"
+    PARAM_SYSTEM_PROMPT = "systemPrompt"
+    PARAM_STOP_SEQUENCES = "stopSequences"
+    PARAM_PARALLEL_TOOL_CALLS = "parallel_tool_calls"
+    PARAM_METADATA = "metadata"
+    PARAM_USE_HISTORY = "use_history"
+    PARAM_MAX_ITERATIONS = "max_iterations"
+
+    # Base set of fields that should always be excluded
+    BASE_EXCLUDE_FIELDS = {PARAM_METADATA}
+
     """
     The basic building block of agentic systems is an LLM enhanced with augmentations
     such as retrieval, tools, and memory provided from a collection of MCP servers.
@@ -141,26 +157,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
             use_history=True,
         )
 
-    async def structured(
-        self,
-        prompt: List[PromptMessageMultipart],
-        model: Type[ModelT],
-        request_params: RequestParams | None = None,
-    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
-        """Apply the prompt and return the result as a Pydantic model, or None if coercion fails"""
-        try:
-            result: PromptMessageMultipart = await self.generate(prompt, request_params)
-            final_generation = get_text(result.content[-1]) or ""
-            await self.show_assistant_message(final_generation)
-            json_data = from_json(final_generation, allow_partial=True)
-            validated_model = model.model_validate(json_data)
-
-            return cast("ModelT", validated_model), Prompt.assistant(json_data)
-        except Exception as e:
-            logger = get_logger(__name__)
-            logger.error(f"Failed to parse structured response: {str(e)}")
-            return None, Prompt.assistant(f"Failed to parse structured response: {str(e)}")
-
     async def generate(
         self,
         multipart_messages: List[PromptMessageMultipart],
@@ -169,6 +165,12 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         """
         Create a completion with the LLM using the provided messages.
         """
+        # note - check changes here are mirrored in structured(). i've thought hard about
+        # a strategy to reduce duplication etc, but aiming for simple but imperfect for the moment
+
+        # We never expect this for structured() calls - this is for interactive use - developers
+        # can do this programatically
+        # TODO -- create a "fast-agent" control role rather than magic strings
         if multipart_messages[-1].first_text().startswith("***SAVE_HISTORY"):
             parts: list[str] = multipart_messages[-1].first_text().split(" ", 1)
             filename: str = (
@@ -180,26 +182,174 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
             )
             return Prompt.assistant(f"History saved to {filename}")
 
-        self.
-
-        if multipart_messages[-1].role == "user":
-            self.show_user_message(
-                render_multipart_message(multipart_messages[-1]),
-                model=self.default_request_params.model,
-                chat_turn=self.chat_turn(),
-            )
+        self._precall(multipart_messages)
 
         assistant_response: PromptMessageMultipart = await self._apply_prompt_provider_specific(
             multipart_messages, request_params
         )
 
+        # add generic error and termination reason handling/rollback
         self._message_history.append(assistant_response)
         return assistant_response
 
+    @abstractmethod
+    async def _apply_prompt_provider_specific(
+        self,
+        multipart_messages: List["PromptMessageMultipart"],
+        request_params: RequestParams | None = None,
+        is_template: bool = False,
+    ) -> PromptMessageMultipart:
+        """
+        Provider-specific implementation of apply_prompt_template.
+        This default implementation handles basic text content for any LLM type.
+        Provider-specific subclasses should override this method to handle
+        multimodal content appropriately.
+
+        Args:
+            multipart_messages: List of PromptMessageMultipart objects parsed from the prompt template
+
+        Returns:
+            String representation of the assistant's response if generated,
+            or the last assistant message in the prompt
+        """
+
+    async def structured(
+        self,
+        multipart_messages: List[PromptMessageMultipart],
+        model: Type[ModelT],
+        request_params: RequestParams | None = None,
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
+        """Return a structured response from the LLM using the provided messages."""
+        self._precall(multipart_messages)
+        result, assistant_response = await self._apply_prompt_provider_specific_structured(
+            multipart_messages, model, request_params
+        )
+
+        self._message_history.append(assistant_response)
+        return result, assistant_response
+
+    @staticmethod
+    def model_to_response_format(
+        model: Type[Any],
+    ) -> Any:
+        """
+        Convert a pydantic model to the appropriate response format schema.
+        This allows for reuse in multiple provider implementations.
+
+        Args:
+            model: The pydantic model class to convert to a schema
+
+        Returns:
+            Provider-agnostic schema representation or NotGiven if conversion fails
+        """
+        return _type_to_response_format(model)
+
+    @staticmethod
+    def model_to_schema_str(
+        model: Type[Any],
+    ) -> str:
+        """
+        Convert a pydantic model to a schema string representation.
+        This provides a simpler interface for provider implementations
+        that need a string representation.
+
+        Args:
+            model: The pydantic model class to convert to a schema
+
+        Returns:
+            Schema as a string, or empty string if conversion fails
+        """
+        import json
+
+        try:
+            schema = model.model_json_schema()
+            return json.dumps(schema)
+        except Exception:
+            return ""
+
+    async def _apply_prompt_provider_specific_structured(
+        self,
+        multipart_messages: List[PromptMessageMultipart],
+        model: Type[ModelT],
+        request_params: RequestParams | None = None,
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
+        """Base class attempts to parse JSON - subclasses can use provider specific functionality"""
+
+        request_params = self.get_request_params(request_params)
+
+        if not request_params.response_format:
+            schema = self.model_to_response_format(model)
+            if schema is not NotGiven:
+                request_params.response_format = schema
+
+        result: PromptMessageMultipart = await self._apply_prompt_provider_specific(
+            multipart_messages, request_params
+        )
+        return self._structured_from_multipart(result, model)
+
+    def _structured_from_multipart(
+        self, message: PromptMessageMultipart, model: Type[ModelT]
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
+        """Parse the content of a PromptMessage and return the structured model and message itself"""
+        try:
+            text = get_text(message.content[-1]) or ""
+            json_data = from_json(text, allow_partial=True)
+            validated_model = model.model_validate(json_data)
+            return cast("ModelT", validated_model), message
+        except ValueError as e:
+            logger = get_logger(__name__)
+            logger.warning(f"Failed to parse structured response: {str(e)}")
+            return None, message
+
+    def _precall(self, multipart_messages: List[PromptMessageMultipart]) -> None:
+        """Pre-call hook to modify the message before sending it to the provider."""
+        self._message_history.extend(multipart_messages)
+        if multipart_messages[-1].role == "user":
+            self.show_user_message(
+                render_multipart_message(multipart_messages[-1]),
+                model=self.default_request_params.model,
+                chat_turn=self.chat_turn(),
+            )
+
     def chat_turn(self) -> int:
         """Return the current chat turn number"""
         return 1 + sum(1 for message in self._message_history if message.role == "assistant")
 
+    def prepare_provider_arguments(
+        self,
+        base_args: dict,
+        request_params: RequestParams,
+        exclude_fields: set | None = None,
+    ) -> dict:
+        """
+        Prepare arguments for provider API calls by merging request parameters.
+
+        Args:
+            base_args: Base arguments dictionary with provider-specific required parameters
+            params: The RequestParams object containing all parameters
+            exclude_fields: Set of field names to exclude from params. If None, uses BASE_EXCLUDE_FIELDS.
+
+        Returns:
+            Complete arguments dictionary with all applicable parameters
+        """
+        # Start with base arguments
+        arguments = base_args.copy()
+
+        # Use provided exclude_fields or fall back to base exclusions
+        exclude_fields = exclude_fields or self.BASE_EXCLUDE_FIELDS.copy()
+
+        # Add all fields from params that aren't explicitly excluded
+        params_dict = request_params.model_dump(exclude=exclude_fields)
+        for key, value in params_dict.items():
+            if value is not None and key not in arguments:
+                arguments[key] = value
+
+        # Finally, add any metadata fields as a last layer of overrides
+        if request_params.metadata:
+            arguments.update(request_params.metadata)
+
+        return arguments
+
     def _merge_request_params(
         self, default_params: RequestParams, provided_params: RequestParams
     ) -> RequestParams:
@@ -214,7 +364,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
     def get_request_params(
         self,
         request_params: RequestParams | None = None,
-        default: RequestParams | None = None,
     ) -> RequestParams:
         """
         Get request parameters with merged-in defaults and overrides.
@@ -223,17 +372,12 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
             default: The default request parameters to use as the base.
                 If unspecified, self.default_request_params will be used.
         """
-        # Start with the defaults
-        default_request_params = default or self.default_request_params
-
-        if not default_request_params:
-            default_request_params = self._initialize_default_params({})
 
         # If user provides overrides, merge them with defaults
         if request_params:
-            return self._merge_request_params(default_request_params, request_params)
+            return self._merge_request_params(self.default_request_params, request_params)
 
-        return default_request_params
+        return self.default_request_params.model_copy()
 
     @classmethod
     def convert_message_to_message_param(
@@ -435,7 +579,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         multipart_messages = PromptMessageMultipart.parse_get_prompt_result(prompt_result)
 
         # Delegate to the provider-specific implementation
-        result = await self._apply_prompt_provider_specific(
+        result = await self._apply_prompt_provider_specific(
+            multipart_messages, None, is_template=True
+        )
         return result.first_text()
 
     async def _save_history(self, filename: str) -> None:
@@ -450,26 +596,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         # Save messages using the unified save function that auto-detects format
         save_messages_to_file(self._message_history, filename)
 
-    @abstractmethod
-    async def _apply_prompt_provider_specific(
-        self,
-        multipart_messages: List["PromptMessageMultipart"],
-        request_params: RequestParams | None = None,
-    ) -> PromptMessageMultipart:
-        """
-        Provider-specific implementation of apply_prompt_template.
-        This default implementation handles basic text content for any LLM type.
-        Provider-specific subclasses should override this method to handle
-        multimodal content appropriately.
-
-        Args:
-            multipart_messages: List of PromptMessageMultipart objects parsed from the prompt template
-
-        Returns:
-            String representation of the assistant's response if generated,
-            or the last assistant message in the prompt
-        """
-
     @property
     def message_history(self) -> List[PromptMessageMultipart]:
         """
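Note: the reworked structured() entry point above now shares _precall() with generate() and delegates to _apply_prompt_provider_specific_structured(), falling back to plain JSON parsing in _structured_from_multipart(). A minimal usage sketch, with stated assumptions: the WeatherReport model and fetch_report wrapper are illustrative, and Prompt.user() is assumed available alongside the Prompt.assistant() helper seen in the diff.

from pydantic import BaseModel

from mcp_agent.core.prompt import Prompt


class WeatherReport(BaseModel):
    # Illustrative schema - any Pydantic model can be passed to structured()
    city: str
    temperature_c: float


async def fetch_report(llm) -> WeatherReport | None:
    # structured() returns (parsed_model_or_None, raw_assistant_message)
    report, _message = await llm.structured(
        [Prompt.user("Report the current weather in London as JSON")],
        WeatherReport,
    )
    return report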
mcp_agent/llm/augmented_llm_passthrough.py
CHANGED
@@ -143,7 +143,6 @@ class PassthroughLLM(AugmentedLLM):
     ) -> PromptMessageMultipart:
         last_message = multipart_messages[-1]
 
-        # TODO -- improve when we support Audio/Multimodal gen
         if self.is_tool_call(last_message):
             result = Prompt.assistant(await self.generate_str(last_message.first_text()))
             await self.show_assistant_message(result.first_text())
@@ -158,6 +157,7 @@ class PassthroughLLM(AugmentedLLM):
             await self.show_assistant_message(self._fixed_response)
             return Prompt.assistant(self._fixed_response)
         else:
+            # TODO -- improve when we support Audio/Multimodal gen models e.g. gemini . This should really just return the input as "assistant"...
             concatenated: str = "\n".join(message.all_text() for message in multipart_messages)
             await self.show_assistant_message(concatenated)
             return Prompt.assistant(concatenated)
mcp_agent/llm/augmented_llm_playback.py
CHANGED
@@ -1,9 +1,11 @@
-from typing import Any, List
+from typing import Any, List, Type
 
+from mcp_agent.core.exceptions import ModelConfigError
 from mcp_agent.core.prompt import Prompt
 from mcp_agent.llm.augmented_llm import RequestParams
 from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
 from mcp_agent.llm.provider_types import Provider
+from mcp_agent.mcp.interfaces import ModelT
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 from mcp_agent.mcp.prompts.prompt_helpers import MessageContent
 
@@ -82,3 +84,21 @@ class PlaybackLLM(PassthroughLLM):
         )
 
         return response
+
+    async def structured(
+        self,
+        multipart_messages: List[PromptMessageMultipart],
+        model: Type[ModelT],
+        request_params: RequestParams | None = None,
+    ) -> tuple[ModelT | None, PromptMessageMultipart]:
+        """
+        Handle structured requests by returning the next assistant message.
+        """
+
+        if -1 == self._current_index:
+            raise ModelConfigError("Use generate() to load playback history")
+
+        return self._structured_from_multipart(
+            self._get_next_assistant_message(),
+            model,
+        )
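Note: a rough sketch of how the new PlaybackLLM.structured() is intended to be driven (the wiring below is illustrative; per the hunk above, calling structured() before the playback history has been loaded raises ModelConfigError):

from pydantic import BaseModel


class Answer(BaseModel):
    # Illustrative schema for the recorded assistant message
    value: int


async def replay_structured(playback_llm, recorded_messages):
    # A generate() call loads the playback history first; without it,
    # structured() raises ModelConfigError ("Use generate() to load playback history")
    await playback_llm.generate(recorded_messages)

    # Each structured() call then consumes the next recorded assistant message
    # and attempts to validate it against the supplied model
    answer, _message = await playback_llm.structured(recorded_messages, Answer)
    return answer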
mcp_agent/llm/memory.py
CHANGED
@@ -19,7 +19,7 @@ class Memory(Protocol, Generic[MessageParamT]):
 
     def append(self, message: MessageParamT, is_prompt: bool = False) -> None: ...
 
-    def get(self,
+    def get(self, include_completion_history: bool = True) -> List[MessageParamT]: ...
 
     def clear(self, clear_prompts: bool = False) -> None: ...
 
@@ -75,7 +75,7 @@ class SimpleMemory(Memory, Generic[MessageParamT]):
         else:
             self.history.append(message)
 
-    def get(self,
+    def get(self, include_completion_history: bool = True) -> List[MessageParamT]:
         """
         Get all messages in memory.
 
@@ -86,7 +86,7 @@ class SimpleMemory(Memory, Generic[MessageParamT]):
         Returns:
             Combined list of prompt messages and optionally history messages
         """
-        if
+        if include_completion_history:
             return self.prompt_messages + self.history
         else:
             return self.prompt_messages.copy()
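Note: the completed get() signature above is what providers call when assembling a request. A small sketch of the call-site pattern, assuming only the Memory protocol defined in this file:

from typing import List

from mcp_agent.llm.memory import Memory


def build_provider_messages(history: Memory, new_message: dict, use_history: bool) -> List[dict]:
    # Mirrors the provider-side pattern: always include prompt/template messages,
    # and include prior completion turns only when use_history is True
    messages = list(history.get(include_completion_history=use_history))
    messages.append(new_message)
    return messages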
mcp_agent/llm/model_factory.py
CHANGED
@@ -12,6 +12,7 @@ from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.providers.augmented_llm_anthropic import AnthropicAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_deepseek import DeepSeekAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_generic import GenericAugmentedLLM
+from mcp_agent.llm.providers.augmented_llm_google import GoogleAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_openrouter import OpenRouterAugmentedLLM
 from mcp_agent.mcp.interfaces import AugmentedLLMProtocol
@@ -107,6 +108,7 @@ class ModelFactory:
         Provider.FAST_AGENT: PassthroughLLM,
         Provider.DEEPSEEK: DeepSeekAugmentedLLM,
         Provider.GENERIC: GenericAugmentedLLM,
+        Provider.GOOGLE: GoogleAugmentedLLM,  # type: ignore
         Provider.OPENROUTER: OpenRouterAugmentedLLM,
     }
 
@@ -161,7 +163,7 @@ class ModelFactory:
         Creates a factory function that follows the attach_llm protocol.
 
         Args:
-            model_string: The model specification string (e.g. "gpt-
+            model_string: The model specification string (e.g. "gpt-4.1")
             request_params: Optional parameters to configure LLM behavior
 
         Returns:
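Note: with Provider.GOOGLE added to the class map above, resolving an LLM implementation for a provider remains a plain dictionary lookup. An illustrative stand-alone sketch (the PROVIDER_TO_LLM name and the OpenAI fallback are assumptions, not the factory's actual attributes):

from mcp_agent.llm.provider_types import Provider
from mcp_agent.llm.providers.augmented_llm_google import GoogleAugmentedLLM
from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM

# Stand-in for the class-level map shown in the hunk above
PROVIDER_TO_LLM = {
    Provider.OPENAI: OpenAIAugmentedLLM,
    Provider.GOOGLE: GoogleAugmentedLLM,
}


def llm_class_for(provider: Provider):
    # Fall back to the OpenAI-compatible implementation for unmapped providers
    return PROVIDER_TO_LLM.get(provider, OpenAIAugmentedLLM)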
mcp_agent/llm/providers/augmented_llm_anthropic.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Tuple, Type
 
 from mcp.types import EmbeddedResource, ImageContent, TextContent
 
@@ -10,6 +10,7 @@ from mcp_agent.llm.providers.multipart_converter_anthropic import (
 from mcp_agent.llm.providers.sampling_converter_anthropic import (
     AnthropicSamplingConverter,
 )
+from mcp_agent.mcp.interfaces import ModelT
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 
 if TYPE_CHECKING:
@@ -50,6 +51,19 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
     selecting appropriate tools, and determining what information to retain.
     """
 
+    # Anthropic-specific parameter exclusions
+    ANTHROPIC_EXCLUDE_FIELDS = {
+        AugmentedLLM.PARAM_MESSAGES,
+        AugmentedLLM.PARAM_MODEL,
+        AugmentedLLM.PARAM_SYSTEM_PROMPT,
+        AugmentedLLM.PARAM_STOP_SEQUENCES,
+        AugmentedLLM.PARAM_MAX_TOKENS,
+        AugmentedLLM.PARAM_METADATA,
+        AugmentedLLM.PARAM_USE_HISTORY,
+        AugmentedLLM.PARAM_MAX_ITERATIONS,
+        AugmentedLLM.PARAM_PARALLEL_TOOL_CALLS,
+    }
+
     def __init__(self, *args, **kwargs) -> None:
         # Initialize logger - keep it simple without name reference
         self.logger = get_logger(__name__)
@@ -73,7 +87,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         assert self.context.config
         return self.context.config.anthropic.base_url if self.context.config.anthropic else None
 
-    async def
+    async def _anthropic_completion(
         self,
         message_param,
         request_params: RequestParams | None = None,
@@ -100,7 +114,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
         # Always include prompt messages, but only include conversation history
         # if use_history is True
-        messages.extend(self.history.get(
+        messages.extend(self.history.get(include_completion_history=params.use_history))
 
         messages.append(message_param)
 
@@ -120,7 +134,8 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
         for i in range(params.max_iterations):
             self._log_chat_progress(self.chat_turn(), model=model)
-            arguments
+            # Create base arguments dictionary
+            base_args = {
                 "model": model,
                 "messages": messages,
                 "system": self.instruction or params.systemPrompt,
@@ -129,10 +144,12 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
             }
 
             if params.maxTokens is not None:
-
+                base_args["max_tokens"] = params.maxTokens
 
-
-
+            # Use the base class method to prepare all arguments with Anthropic-specific exclusions
+            arguments = self.prepare_provider_arguments(
+                base_args, params, self.ANTHROPIC_EXCLUDE_FIELDS
+            )
 
             self.logger.debug(f"{arguments}")
 
@@ -265,7 +282,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         # Keep the prompt messages separate
         if params.use_history:
             # Get current prompt messages
-            prompt_messages = self.history.get(
+            prompt_messages = self.history.get(include_completion_history=False)
 
             # Calculate new conversation messages (excluding prompts)
            new_messages = messages[len(prompt_messages) :]
@@ -288,7 +305,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         Override this method to use a different LLM.
 
         """
-        res = await self.
+        res = await self._anthropic_completion(
             message_param=message_param,
             request_params=request_params,
         )
@@ -298,6 +315,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         self,
         multipart_messages: List["PromptMessageMultipart"],
         request_params: RequestParams | None = None,
+        is_template: bool = False,
     ) -> PromptMessageMultipart:
         # Check the last message role
         last_message = multipart_messages[-1]
@@ -310,7 +328,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         for msg in messages_to_add:
             converted.append(AnthropicConverter.convert_to_anthropic(msg))
 
-        self.history.extend(converted, is_prompt=
+        self.history.extend(converted, is_prompt=is_template)
 
         if last_message.role == "user":
             self.logger.debug("Last message in prompt is from user, generating assistant response")
@@ -321,6 +339,28 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
             self.logger.debug("Last message in prompt is from assistant, returning it directly")
             return last_message
 
+    async def _apply_prompt_provider_specific_structured(
+        self,
+        multipart_messages: List[PromptMessageMultipart],
+        model: Type[ModelT],
+        request_params: RequestParams | None = None,
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:  # noqa: F821
+        request_params = self.get_request_params(request_params)
+
+        # TODO - convert this to use Tool Calling convention for Anthropic Structured outputs
+        multipart_messages[-1].add_text(
+            """YOU MUST RESPOND IN THE FOLLOWING FORMAT:
+            {schema}
+            RESPOND ONLY WITH THE JSON, NO PREAMBLE, CODE FENCES OR 'properties' ARE PERMISSABLE """.format(
+                schema=model.model_json_schema()
+            )
+        )
+
+        result: PromptMessageMultipart = await self._apply_prompt_provider_specific(
+            multipart_messages, request_params
+        )
+        return self._structured_from_multipart(result, model)
+
     @classmethod
     def convert_message_to_message_param(cls, message: Message, **kwargs) -> MessageParam:
         """Convert a response object to an input parameter object to allow LLM calls to be chained."""
mcp_agent/llm/providers/augmented_llm_deepseek.py
CHANGED
@@ -1,4 +1,3 @@
-
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
@@ -9,10 +8,7 @@ DEFAULT_DEEPSEEK_MODEL = "deepseekchat"  # current Deepseek only has two type mo
 
 class DeepSeekAugmentedLLM(OpenAIAugmentedLLM):
     def __init__(self, *args, **kwargs) -> None:
-
-        super().__init__(
-            *args, provider=Provider.DEEPSEEK, **kwargs
-        )  # Properly pass args and kwargs to parent
+        super().__init__(*args, provider=Provider.DEEPSEEK, **kwargs)
 
     def _initialize_default_params(self, kwargs: dict) -> RequestParams:
         """Initialize Deepseek-specific default parameters"""
mcp_agent/llm/providers/augmented_llm_google.py
ADDED
@@ -0,0 +1,30 @@
+from mcp_agent.core.request_params import RequestParams
+from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
+
+GOOGLE_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"
+DEFAULT_GOOGLE_MODEL = "gemini-2.0-flash"
+
+
+class GoogleAugmentedLLM(OpenAIAugmentedLLM):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, provider=Provider.GOOGLE, **kwargs)
+
+    def _initialize_default_params(self, kwargs: dict) -> RequestParams:
+        """Initialize Google OpenAI Compatibility default parameters"""
+        chosen_model = kwargs.get("model", DEFAULT_GOOGLE_MODEL)
+
+        return RequestParams(
+            model=chosen_model,
+            systemPrompt=self.instruction,
+            parallel_tool_calls=False,
+            max_iterations=10,
+            use_history=True,
+        )
+
+    def _base_url(self) -> str:
+        base_url = None
+        if self.context.config and self.context.config.google:
+            base_url = self.context.config.google.base_url
+
+        return base_url if base_url else GOOGLE_BASE_URL