fast-agent-mcp 0.2.16__py3-none-any.whl → 0.2.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {fast_agent_mcp-0.2.16.dist-info → fast_agent_mcp-0.2.18.dist-info}/METADATA +6 -7
  2. {fast_agent_mcp-0.2.16.dist-info → fast_agent_mcp-0.2.18.dist-info}/RECORD +48 -47
  3. mcp_agent/agents/base_agent.py +50 -6
  4. mcp_agent/agents/workflow/orchestrator_agent.py +6 -7
  5. mcp_agent/agents/workflow/router_agent.py +70 -136
  6. mcp_agent/app.py +1 -124
  7. mcp_agent/cli/commands/go.py +133 -0
  8. mcp_agent/cli/commands/setup.py +2 -2
  9. mcp_agent/cli/main.py +5 -3
  10. mcp_agent/config.py +16 -13
  11. mcp_agent/context.py +4 -22
  12. mcp_agent/core/agent_types.py +2 -2
  13. mcp_agent/core/direct_decorators.py +2 -2
  14. mcp_agent/core/direct_factory.py +2 -1
  15. mcp_agent/core/enhanced_prompt.py +12 -7
  16. mcp_agent/core/fastagent.py +39 -5
  17. mcp_agent/core/interactive_prompt.py +6 -2
  18. mcp_agent/core/request_params.py +5 -1
  19. mcp_agent/core/validation.py +12 -1
  20. mcp_agent/executor/workflow_signal.py +0 -2
  21. mcp_agent/llm/augmented_llm.py +183 -57
  22. mcp_agent/llm/augmented_llm_passthrough.py +1 -1
  23. mcp_agent/llm/augmented_llm_playback.py +21 -1
  24. mcp_agent/llm/memory.py +3 -3
  25. mcp_agent/llm/model_factory.py +3 -1
  26. mcp_agent/llm/provider_key_manager.py +1 -0
  27. mcp_agent/llm/provider_types.py +2 -1
  28. mcp_agent/llm/providers/augmented_llm_anthropic.py +50 -10
  29. mcp_agent/llm/providers/augmented_llm_deepseek.py +1 -5
  30. mcp_agent/llm/providers/augmented_llm_google.py +30 -0
  31. mcp_agent/llm/providers/augmented_llm_openai.py +96 -159
  32. mcp_agent/llm/providers/multipart_converter_openai.py +10 -27
  33. mcp_agent/llm/providers/sampling_converter_openai.py +5 -6
  34. mcp_agent/mcp/interfaces.py +6 -1
  35. mcp_agent/mcp/mcp_aggregator.py +2 -8
  36. mcp_agent/mcp/prompt_message_multipart.py +25 -2
  37. mcp_agent/resources/examples/data-analysis/analysis-campaign.py +2 -2
  38. mcp_agent/resources/examples/in_dev/agent_build.py +1 -1
  39. mcp_agent/resources/examples/internal/job.py +1 -1
  40. mcp_agent/resources/examples/mcp/state-transfer/fastagent.config.yaml +1 -1
  41. mcp_agent/resources/examples/prompting/agent.py +0 -2
  42. mcp_agent/resources/examples/prompting/fastagent.config.yaml +2 -3
  43. mcp_agent/resources/examples/researcher/fastagent.config.yaml +1 -6
  44. mcp_agent/resources/examples/workflows/fastagent.config.yaml +0 -1
  45. mcp_agent/resources/examples/workflows/parallel.py +1 -1
  46. mcp_agent/executor/decorator_registry.py +0 -112
  47. {fast_agent_mcp-0.2.16.dist-info → fast_agent_mcp-0.2.18.dist-info}/WHEEL +0 -0
  48. {fast_agent_mcp-0.2.16.dist-info → fast_agent_mcp-0.2.18.dist-info}/entry_points.txt +0 -0
  49. {fast_agent_mcp-0.2.16.dist-info → fast_agent_mcp-0.2.18.dist-info}/licenses/LICENSE +0 -0
mcp_agent/llm/augmented_llm.py CHANGED
@@ -18,6 +18,8 @@ from mcp.types import (
     PromptMessage,
     TextContent,
 )
+from openai import NotGiven
+from openai.lib._parsing import type_to_response_format_param as _type_to_response_format
 from pydantic_core import from_json
 from rich.text import Text
 
@@ -58,6 +60,20 @@ HUMAN_INPUT_TOOL_NAME = "__human_input__"
 
 
 class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT, MessageT]):
+    # Common parameter names used across providers
+    PARAM_MESSAGES = "messages"
+    PARAM_MODEL = "model"
+    PARAM_MAX_TOKENS = "maxTokens"
+    PARAM_SYSTEM_PROMPT = "systemPrompt"
+    PARAM_STOP_SEQUENCES = "stopSequences"
+    PARAM_PARALLEL_TOOL_CALLS = "parallel_tool_calls"
+    PARAM_METADATA = "metadata"
+    PARAM_USE_HISTORY = "use_history"
+    PARAM_MAX_ITERATIONS = "max_iterations"
+
+    # Base set of fields that should always be excluded
+    BASE_EXCLUDE_FIELDS = {PARAM_METADATA}
+
     """
     The basic building block of agentic systems is an LLM enhanced with augmentations
     such as retrieval, tools, and memory provided from a collection of MCP servers.
@@ -141,26 +157,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
             use_history=True,
         )
 
-    async def structured(
-        self,
-        prompt: List[PromptMessageMultipart],
-        model: Type[ModelT],
-        request_params: RequestParams | None = None,
-    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
-        """Apply the prompt and return the result as a Pydantic model, or None if coercion fails"""
-        try:
-            result: PromptMessageMultipart = await self.generate(prompt, request_params)
-            final_generation = get_text(result.content[-1]) or ""
-            await self.show_assistant_message(final_generation)
-            json_data = from_json(final_generation, allow_partial=True)
-            validated_model = model.model_validate(json_data)
-
-            return cast("ModelT", validated_model), Prompt.assistant(json_data)
-        except Exception as e:
-            logger = get_logger(__name__)
-            logger.error(f"Failed to parse structured response: {str(e)}")
-            return None, Prompt.assistant(f"Failed to parse structured response: {str(e)}")
-
     async def generate(
         self,
         multipart_messages: List[PromptMessageMultipart],
@@ -169,6 +165,12 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         """
         Create a completion with the LLM using the provided messages.
         """
+        # note - check changes here are mirrored in structured(). i've thought hard about
+        # a strategy to reduce duplication etc, but aiming for simple but imperfect for the moment
+
+        # We never expect this for structured() calls - this is for interactive use - developers
+        # can do this programatically
+        # TODO -- create a "fast-agent" control role rather than magic strings
         if multipart_messages[-1].first_text().startswith("***SAVE_HISTORY"):
             parts: list[str] = multipart_messages[-1].first_text().split(" ", 1)
             filename: str = (
@@ -180,26 +182,174 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
             )
             return Prompt.assistant(f"History saved to {filename}")
 
-        self._message_history.extend(multipart_messages)
-
-        if multipart_messages[-1].role == "user":
-            self.show_user_message(
-                render_multipart_message(multipart_messages[-1]),
-                model=self.default_request_params.model,
-                chat_turn=self.chat_turn(),
-            )
+        self._precall(multipart_messages)
 
         assistant_response: PromptMessageMultipart = await self._apply_prompt_provider_specific(
             multipart_messages, request_params
         )
 
+        # add generic error and termination reason handling/rollback
         self._message_history.append(assistant_response)
         return assistant_response
 
+    @abstractmethod
+    async def _apply_prompt_provider_specific(
+        self,
+        multipart_messages: List["PromptMessageMultipart"],
+        request_params: RequestParams | None = None,
+        is_template: bool = False,
+    ) -> PromptMessageMultipart:
+        """
+        Provider-specific implementation of apply_prompt_template.
+        This default implementation handles basic text content for any LLM type.
+        Provider-specific subclasses should override this method to handle
+        multimodal content appropriately.
+
+        Args:
+            multipart_messages: List of PromptMessageMultipart objects parsed from the prompt template
+
+        Returns:
+            String representation of the assistant's response if generated,
+            or the last assistant message in the prompt
+        """
+
+    async def structured(
+        self,
+        multipart_messages: List[PromptMessageMultipart],
+        model: Type[ModelT],
+        request_params: RequestParams | None = None,
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
+        """Return a structured response from the LLM using the provided messages."""
+        self._precall(multipart_messages)
+        result, assistant_response = await self._apply_prompt_provider_specific_structured(
+            multipart_messages, model, request_params
+        )
+
+        self._message_history.append(assistant_response)
+        return result, assistant_response
+
+    @staticmethod
+    def model_to_response_format(
+        model: Type[Any],
+    ) -> Any:
+        """
+        Convert a pydantic model to the appropriate response format schema.
+        This allows for reuse in multiple provider implementations.
+
+        Args:
+            model: The pydantic model class to convert to a schema
+
+        Returns:
+            Provider-agnostic schema representation or NotGiven if conversion fails
+        """
+        return _type_to_response_format(model)
+
+    @staticmethod
+    def model_to_schema_str(
+        model: Type[Any],
+    ) -> str:
+        """
+        Convert a pydantic model to a schema string representation.
+        This provides a simpler interface for provider implementations
+        that need a string representation.
+
+        Args:
+            model: The pydantic model class to convert to a schema
+
+        Returns:
+            Schema as a string, or empty string if conversion fails
+        """
+        import json
+
+        try:
+            schema = model.model_json_schema()
+            return json.dumps(schema)
+        except Exception:
+            return ""
+
+    async def _apply_prompt_provider_specific_structured(
+        self,
+        multipart_messages: List[PromptMessageMultipart],
+        model: Type[ModelT],
+        request_params: RequestParams | None = None,
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
+        """Base class attempts to parse JSON - subclasses can use provider specific functionality"""
+
+        request_params = self.get_request_params(request_params)
+
+        if not request_params.response_format:
+            schema = self.model_to_response_format(model)
+            if schema is not NotGiven:
+                request_params.response_format = schema
+
+        result: PromptMessageMultipart = await self._apply_prompt_provider_specific(
+            multipart_messages, request_params
+        )
+        return self._structured_from_multipart(result, model)
+
+    def _structured_from_multipart(
+        self, message: PromptMessageMultipart, model: Type[ModelT]
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
+        """Parse the content of a PromptMessage and return the structured model and message itself"""
+        try:
+            text = get_text(message.content[-1]) or ""
+            json_data = from_json(text, allow_partial=True)
+            validated_model = model.model_validate(json_data)
+            return cast("ModelT", validated_model), message
+        except ValueError as e:
+            logger = get_logger(__name__)
+            logger.warning(f"Failed to parse structured response: {str(e)}")
+            return None, message
+
+    def _precall(self, multipart_messages: List[PromptMessageMultipart]) -> None:
+        """Pre-call hook to modify the message before sending it to the provider."""
+        self._message_history.extend(multipart_messages)
+        if multipart_messages[-1].role == "user":
+            self.show_user_message(
+                render_multipart_message(multipart_messages[-1]),
+                model=self.default_request_params.model,
+                chat_turn=self.chat_turn(),
+            )
+
     def chat_turn(self) -> int:
         """Return the current chat turn number"""
         return 1 + sum(1 for message in self._message_history if message.role == "assistant")
 
+    def prepare_provider_arguments(
+        self,
+        base_args: dict,
+        request_params: RequestParams,
+        exclude_fields: set | None = None,
+    ) -> dict:
+        """
+        Prepare arguments for provider API calls by merging request parameters.
+
+        Args:
+            base_args: Base arguments dictionary with provider-specific required parameters
+            params: The RequestParams object containing all parameters
+            exclude_fields: Set of field names to exclude from params. If None, uses BASE_EXCLUDE_FIELDS.
+
+        Returns:
+            Complete arguments dictionary with all applicable parameters
+        """
+        # Start with base arguments
+        arguments = base_args.copy()
+
+        # Use provided exclude_fields or fall back to base exclusions
+        exclude_fields = exclude_fields or self.BASE_EXCLUDE_FIELDS.copy()
+
+        # Add all fields from params that aren't explicitly excluded
+        params_dict = request_params.model_dump(exclude=exclude_fields)
+        for key, value in params_dict.items():
+            if value is not None and key not in arguments:
+                arguments[key] = value
+
+        # Finally, add any metadata fields as a last layer of overrides
+        if request_params.metadata:
+            arguments.update(request_params.metadata)
+
+        return arguments
+
     def _merge_request_params(
         self, default_params: RequestParams, provided_params: RequestParams
     ) -> RequestParams:
@@ -214,7 +364,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
     def get_request_params(
         self,
         request_params: RequestParams | None = None,
-        default: RequestParams | None = None,
     ) -> RequestParams:
         """
         Get request parameters with merged-in defaults and overrides.
@@ -223,17 +372,12 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
            default: The default request parameters to use as the base.
                If unspecified, self.default_request_params will be used.
         """
-        # Start with the defaults
-        default_request_params = default or self.default_request_params
-
-        if not default_request_params:
-            default_request_params = self._initialize_default_params({})
 
         # If user provides overrides, merge them with defaults
         if request_params:
-            return self._merge_request_params(default_request_params, request_params)
+            return self._merge_request_params(self.default_request_params, request_params)
 
-        return default_request_params
+        return self.default_request_params.model_copy()
 
     @classmethod
     def convert_message_to_message_param(
@@ -435,7 +579,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         multipart_messages = PromptMessageMultipart.parse_get_prompt_result(prompt_result)
 
         # Delegate to the provider-specific implementation
-        result = await self._apply_prompt_provider_specific(multipart_messages, None)
+        result = await self._apply_prompt_provider_specific(
+            multipart_messages, None, is_template=True
+        )
         return result.first_text()
 
     async def _save_history(self, filename: str) -> None:
@@ -450,26 +596,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         # Save messages using the unified save function that auto-detects format
         save_messages_to_file(self._message_history, filename)
 
-    @abstractmethod
-    async def _apply_prompt_provider_specific(
-        self,
-        multipart_messages: List["PromptMessageMultipart"],
-        request_params: RequestParams | None = None,
-    ) -> PromptMessageMultipart:
-        """
-        Provider-specific implementation of apply_prompt_template.
-        This default implementation handles basic text content for any LLM type.
-        Provider-specific subclasses should override this method to handle
-        multimodal content appropriately.
-
-        Args:
-            multipart_messages: List of PromptMessageMultipart objects parsed from the prompt template
-
-        Returns:
-            String representation of the assistant's response if generated,
-            or the last assistant message in the prompt
-        """
-
     @property
     def message_history(self) -> List[PromptMessageMultipart]:
         """
mcp_agent/llm/augmented_llm_passthrough.py CHANGED
@@ -143,7 +143,6 @@ class PassthroughLLM(AugmentedLLM):
     ) -> PromptMessageMultipart:
         last_message = multipart_messages[-1]
 
-        # TODO -- improve when we support Audio/Multimodal gen
         if self.is_tool_call(last_message):
             result = Prompt.assistant(await self.generate_str(last_message.first_text()))
             await self.show_assistant_message(result.first_text())
@@ -158,6 +157,7 @@ class PassthroughLLM(AugmentedLLM):
             await self.show_assistant_message(self._fixed_response)
             return Prompt.assistant(self._fixed_response)
         else:
+            # TODO -- improve when we support Audio/Multimodal gen models e.g. gemini . This should really just return the input as "assistant"...
            concatenated: str = "\n".join(message.all_text() for message in multipart_messages)
            await self.show_assistant_message(concatenated)
            return Prompt.assistant(concatenated)
mcp_agent/llm/augmented_llm_playback.py CHANGED
@@ -1,9 +1,11 @@
-from typing import Any, List
+from typing import Any, List, Type
 
+from mcp_agent.core.exceptions import ModelConfigError
 from mcp_agent.core.prompt import Prompt
 from mcp_agent.llm.augmented_llm import RequestParams
 from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
 from mcp_agent.llm.provider_types import Provider
+from mcp_agent.mcp.interfaces import ModelT
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 from mcp_agent.mcp.prompts.prompt_helpers import MessageContent
 
@@ -82,3 +84,21 @@ class PlaybackLLM(PassthroughLLM):
         )
 
         return response
+
+    async def structured(
+        self,
+        multipart_messages: List[PromptMessageMultipart],
+        model: Type[ModelT],
+        request_params: RequestParams | None = None,
+    ) -> tuple[ModelT | None, PromptMessageMultipart]:
+        """
+        Handle structured requests by returning the next assistant message.
+        """
+
+        if -1 == self._current_index:
+            raise ModelConfigError("Use generate() to load playback history")
+
+        return self._structured_from_multipart(
+            self._get_next_assistant_message(),
+            model,
+        )
mcp_agent/llm/memory.py CHANGED
@@ -19,7 +19,7 @@ class Memory(Protocol, Generic[MessageParamT]):
 
     def append(self, message: MessageParamT, is_prompt: bool = False) -> None: ...
 
-    def get(self, include_history: bool = True) -> List[MessageParamT]: ...
+    def get(self, include_completion_history: bool = True) -> List[MessageParamT]: ...
 
     def clear(self, clear_prompts: bool = False) -> None: ...
 
@@ -75,7 +75,7 @@ class SimpleMemory(Memory, Generic[MessageParamT]):
         else:
             self.history.append(message)
 
-    def get(self, include_history: bool = True) -> List[MessageParamT]:
+    def get(self, include_completion_history: bool = True) -> List[MessageParamT]:
         """
         Get all messages in memory.
 
@@ -86,7 +86,7 @@ class SimpleMemory(Memory, Generic[MessageParamT]):
         Returns:
             Combined list of prompt messages and optionally history messages
         """
-        if include_history:
+        if include_completion_history:
             return self.prompt_messages + self.history
         else:
             return self.prompt_messages.copy()
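
The `include_history` to `include_completion_history` rename clarifies that the flag only gates completion-turn history; prompt/template messages are always returned. A hedged usage sketch (it assumes `SimpleMemory` takes no constructor arguments, which this diff does not show):

```python
from mcp_agent.llm.memory import SimpleMemory

memory: SimpleMemory[str] = SimpleMemory()  # assumed no-arg constructor
memory.append("system / template message", is_prompt=True)  # persisted prompt message
memory.append("user: hello")                                # completion-turn message

print(memory.get())                                   # prompt messages + completion history
print(memory.get(include_completion_history=False))   # prompt messages only
```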
mcp_agent/llm/model_factory.py CHANGED
@@ -12,6 +12,7 @@ from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.providers.augmented_llm_anthropic import AnthropicAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_deepseek import DeepSeekAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_generic import GenericAugmentedLLM
+from mcp_agent.llm.providers.augmented_llm_google import GoogleAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_openrouter import OpenRouterAugmentedLLM
 from mcp_agent.mcp.interfaces import AugmentedLLMProtocol
@@ -107,6 +108,7 @@ class ModelFactory:
         Provider.FAST_AGENT: PassthroughLLM,
         Provider.DEEPSEEK: DeepSeekAugmentedLLM,
         Provider.GENERIC: GenericAugmentedLLM,
+        Provider.GOOGLE: GoogleAugmentedLLM,  # type: ignore
         Provider.OPENROUTER: OpenRouterAugmentedLLM,
     }
 
@@ -161,7 +163,7 @@ class ModelFactory:
         Creates a factory function that follows the attach_llm protocol.
 
         Args:
-            model_string: The model specification string (e.g. "gpt-4o.high")
+            model_string: The model specification string (e.g. "gpt-4.1")
            request_params: Optional parameters to configure LLM behavior
 
         Returns:
mcp_agent/llm/provider_key_manager.py CHANGED
@@ -14,6 +14,7 @@ PROVIDER_ENVIRONMENT_MAP: Dict[str, str] = {
     "anthropic": "ANTHROPIC_API_KEY",
     "openai": "OPENAI_API_KEY",
     "deepseek": "DEEPSEEK_API_KEY",
+    "google": "GOOGLE_API_KEY",
     "openrouter": "OPENROUTER_API_KEY",
     "generic": "GENERIC_API_KEY",
 }
mcp_agent/llm/provider_types.py CHANGED
@@ -11,6 +11,7 @@ class Provider(Enum):
     ANTHROPIC = "anthropic"
     OPENAI = "openai"
     FAST_AGENT = "fast-agent"
+    GOOGLE = "google"
     DEEPSEEK = "deepseek"
     GENERIC = "generic"
-    OPENROUTER = "openrouter"
+    OPENROUTER = "openrouter"
mcp_agent/llm/providers/augmented_llm_anthropic.py CHANGED
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Tuple, Type
 
 from mcp.types import EmbeddedResource, ImageContent, TextContent
 
@@ -10,6 +10,7 @@ from mcp_agent.llm.providers.multipart_converter_anthropic import (
 from mcp_agent.llm.providers.sampling_converter_anthropic import (
     AnthropicSamplingConverter,
 )
+from mcp_agent.mcp.interfaces import ModelT
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 
 if TYPE_CHECKING:
@@ -50,6 +51,19 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
     selecting appropriate tools, and determining what information to retain.
     """
 
+    # Anthropic-specific parameter exclusions
+    ANTHROPIC_EXCLUDE_FIELDS = {
+        AugmentedLLM.PARAM_MESSAGES,
+        AugmentedLLM.PARAM_MODEL,
+        AugmentedLLM.PARAM_SYSTEM_PROMPT,
+        AugmentedLLM.PARAM_STOP_SEQUENCES,
+        AugmentedLLM.PARAM_MAX_TOKENS,
+        AugmentedLLM.PARAM_METADATA,
+        AugmentedLLM.PARAM_USE_HISTORY,
+        AugmentedLLM.PARAM_MAX_ITERATIONS,
+        AugmentedLLM.PARAM_PARALLEL_TOOL_CALLS,
+    }
+
     def __init__(self, *args, **kwargs) -> None:
         # Initialize logger - keep it simple without name reference
         self.logger = get_logger(__name__)
@@ -73,7 +87,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         assert self.context.config
         return self.context.config.anthropic.base_url if self.context.config.anthropic else None
 
-    async def generate_internal(
+    async def _anthropic_completion(
         self,
         message_param,
         request_params: RequestParams | None = None,
@@ -100,7 +114,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
         # Always include prompt messages, but only include conversation history
         # if use_history is True
-        messages.extend(self.history.get(include_history=params.use_history))
+        messages.extend(self.history.get(include_completion_history=params.use_history))
         messages.append(message_param)
 
 
@@ -120,7 +134,8 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
         for i in range(params.max_iterations):
             self._log_chat_progress(self.chat_turn(), model=model)
-            arguments = {
+            # Create base arguments dictionary
+            base_args = {
                 "model": model,
                 "messages": messages,
                 "system": self.instruction or params.systemPrompt,
@@ -129,10 +144,12 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
             }
 
             if params.maxTokens is not None:
-                arguments["max_tokens"] = params.maxTokens
+                base_args["max_tokens"] = params.maxTokens
 
-            if params.metadata:
-                arguments = {**arguments, **params.metadata}
+            # Use the base class method to prepare all arguments with Anthropic-specific exclusions
+            arguments = self.prepare_provider_arguments(
+                base_args, params, self.ANTHROPIC_EXCLUDE_FIELDS
+            )
 
             self.logger.debug(f"{arguments}")
 
@@ -265,7 +282,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         # Keep the prompt messages separate
         if params.use_history:
             # Get current prompt messages
-            prompt_messages = self.history.get(include_history=False)
+            prompt_messages = self.history.get(include_completion_history=False)
 
             # Calculate new conversation messages (excluding prompts)
             new_messages = messages[len(prompt_messages) :]
@@ -288,7 +305,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         Override this method to use a different LLM.
 
         """
-        res = await self.generate_internal(
+        res = await self._anthropic_completion(
            message_param=message_param,
            request_params=request_params,
        )
@@ -298,6 +315,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         self,
         multipart_messages: List["PromptMessageMultipart"],
         request_params: RequestParams | None = None,
+        is_template: bool = False,
     ) -> PromptMessageMultipart:
         # Check the last message role
         last_message = multipart_messages[-1]
@@ -310,7 +328,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         for msg in messages_to_add:
             converted.append(AnthropicConverter.convert_to_anthropic(msg))
 
-        self.history.extend(converted, is_prompt=True)
+        self.history.extend(converted, is_prompt=is_template)
 
         if last_message.role == "user":
             self.logger.debug("Last message in prompt is from user, generating assistant response")
@@ -321,6 +339,28 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
             self.logger.debug("Last message in prompt is from assistant, returning it directly")
             return last_message
 
+    async def _apply_prompt_provider_specific_structured(
+        self,
+        multipart_messages: List[PromptMessageMultipart],
+        model: Type[ModelT],
+        request_params: RequestParams | None = None,
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:  # noqa: F821
+        request_params = self.get_request_params(request_params)
+
+        # TODO - convert this to use Tool Calling convention for Anthropic Structured outputs
+        multipart_messages[-1].add_text(
+            """YOU MUST RESPOND IN THE FOLLOWING FORMAT:
+            {schema}
+            RESPOND ONLY WITH THE JSON, NO PREAMBLE, CODE FENCES OR 'properties' ARE PERMISSABLE """.format(
+                schema=model.model_json_schema()
+            )
+        )
+
+        result: PromptMessageMultipart = await self._apply_prompt_provider_specific(
+            multipart_messages, request_params
+        )
+        return self._structured_from_multipart(result, model)
+
     @classmethod
     def convert_message_to_message_param(cls, message: Message, **kwargs) -> MessageParam:
         """Convert a response object to an input parameter object to allow LLM calls to be chained."""
mcp_agent/llm/providers/augmented_llm_deepseek.py CHANGED
@@ -1,4 +1,3 @@
-
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
@@ -9,10 +8,7 @@ DEFAULT_DEEPSEEK_MODEL = "deepseekchat"  # current Deepseek only has two type mo
 
 class DeepSeekAugmentedLLM(OpenAIAugmentedLLM):
     def __init__(self, *args, **kwargs) -> None:
-        kwargs["provider_name"] = "Deepseek"  # Set provider name in kwargs
-        super().__init__(
-            *args, provider=Provider.DEEPSEEK, **kwargs
-        )  # Properly pass args and kwargs to parent
+        super().__init__(*args, provider=Provider.DEEPSEEK, **kwargs)
 
     def _initialize_default_params(self, kwargs: dict) -> RequestParams:
         """Initialize Deepseek-specific default parameters"""
mcp_agent/llm/providers/augmented_llm_google.py ADDED
@@ -0,0 +1,30 @@
+from mcp_agent.core.request_params import RequestParams
+from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
+
+GOOGLE_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"
+DEFAULT_GOOGLE_MODEL = "gemini-2.0-flash"
+
+
+class GoogleAugmentedLLM(OpenAIAugmentedLLM):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, provider=Provider.GOOGLE, **kwargs)
+
+    def _initialize_default_params(self, kwargs: dict) -> RequestParams:
+        """Initialize Google OpenAI Compatibility default parameters"""
+        chosen_model = kwargs.get("model", DEFAULT_GOOGLE_MODEL)
+
+        return RequestParams(
+            model=chosen_model,
+            systemPrompt=self.instruction,
+            parallel_tool_calls=False,
+            max_iterations=10,
+            use_history=True,
+        )
+
+    def _base_url(self) -> str:
+        base_url = None
+        if self.context.config and self.context.config.google:
+            base_url = self.context.config.google.base_url
+
+        return base_url if base_url else GOOGLE_BASE_URL
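
The new Google provider rides on the existing OpenAI-compatible code path: it only swaps the provider enum, the default model, and the base URL, and its key is read from GOOGLE_API_KEY per the provider_key_manager change above. A minimal sketch of what that wiring amounts to with the stock openai client (fast-agent performs this through OpenAIAugmentedLLM; the prompt text here is illustrative):

```python
import os

from openai import OpenAI

# Endpoint and default model match the new provider's constants.
GOOGLE_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"

client = OpenAI(api_key=os.environ["GOOGLE_API_KEY"], base_url=GOOGLE_BASE_URL)
response = client.chat.completions.create(
    model="gemini-2.0-flash",
    messages=[{"role": "user", "content": "Say hello in one word."}],
)
print(response.choices[0].message.content)
```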