fast-agent-mcp 0.2.14__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {fast_agent_mcp-0.2.14.dist-info → fast_agent_mcp-0.2.17.dist-info}/METADATA +4 -6
  2. {fast_agent_mcp-0.2.14.dist-info → fast_agent_mcp-0.2.17.dist-info}/RECORD +46 -46
  3. mcp_agent/agents/base_agent.py +50 -6
  4. mcp_agent/agents/workflow/orchestrator_agent.py +6 -7
  5. mcp_agent/agents/workflow/router_agent.py +70 -136
  6. mcp_agent/app.py +1 -124
  7. mcp_agent/cli/commands/setup.py +1 -1
  8. mcp_agent/config.py +19 -19
  9. mcp_agent/context.py +4 -22
  10. mcp_agent/core/agent_types.py +2 -2
  11. mcp_agent/core/direct_decorators.py +2 -2
  12. mcp_agent/core/direct_factory.py +2 -1
  13. mcp_agent/core/enhanced_prompt.py +6 -5
  14. mcp_agent/core/fastagent.py +1 -1
  15. mcp_agent/core/interactive_prompt.py +70 -50
  16. mcp_agent/core/request_params.py +5 -1
  17. mcp_agent/executor/workflow_signal.py +0 -2
  18. mcp_agent/llm/augmented_llm.py +183 -57
  19. mcp_agent/llm/augmented_llm_passthrough.py +1 -1
  20. mcp_agent/llm/augmented_llm_playback.py +21 -1
  21. mcp_agent/llm/memory.py +3 -3
  22. mcp_agent/llm/model_factory.py +3 -1
  23. mcp_agent/llm/provider_key_manager.py +1 -0
  24. mcp_agent/llm/provider_types.py +2 -1
  25. mcp_agent/llm/providers/augmented_llm_anthropic.py +49 -10
  26. mcp_agent/llm/providers/augmented_llm_deepseek.py +0 -2
  27. mcp_agent/llm/providers/augmented_llm_generic.py +4 -2
  28. mcp_agent/llm/providers/augmented_llm_google.py +30 -0
  29. mcp_agent/llm/providers/augmented_llm_openai.py +95 -158
  30. mcp_agent/llm/providers/multipart_converter_openai.py +10 -27
  31. mcp_agent/llm/providers/sampling_converter_openai.py +5 -6
  32. mcp_agent/mcp/interfaces.py +6 -1
  33. mcp_agent/mcp/mcp_aggregator.py +2 -8
  34. mcp_agent/mcp/prompt_message_multipart.py +25 -2
  35. mcp_agent/resources/examples/data-analysis/analysis-campaign.py +2 -2
  36. mcp_agent/resources/examples/in_dev/agent_build.py +1 -1
  37. mcp_agent/resources/examples/internal/job.py +1 -1
  38. mcp_agent/resources/examples/mcp/state-transfer/fastagent.config.yaml +1 -1
  39. mcp_agent/resources/examples/prompting/agent.py +0 -2
  40. mcp_agent/resources/examples/prompting/fastagent.config.yaml +2 -3
  41. mcp_agent/resources/examples/researcher/fastagent.config.yaml +1 -6
  42. mcp_agent/resources/examples/workflows/fastagent.config.yaml +0 -1
  43. mcp_agent/resources/examples/workflows/parallel.py +1 -1
  44. mcp_agent/executor/decorator_registry.py +0 -112
  45. {fast_agent_mcp-0.2.14.dist-info → fast_agent_mcp-0.2.17.dist-info}/WHEEL +0 -0
  46. {fast_agent_mcp-0.2.14.dist-info → fast_agent_mcp-0.2.17.dist-info}/entry_points.txt +0 -0
  47. {fast_agent_mcp-0.2.14.dist-info → fast_agent_mcp-0.2.17.dist-info}/licenses/LICENSE +0 -0
mcp_agent/core/enhanced_prompt.py

@@ -14,6 +14,7 @@ from prompt_toolkit.key_binding import KeyBindings
 from prompt_toolkit.styles import Style
 from rich import print as rich_print

+from mcp_agent.core.agent_types import AgentType
 from mcp_agent.core.exceptions import PromptExitError

 # Get the application version
@@ -86,7 +87,7 @@ class AgentCompleter(Completer):
         for agent in self.agents:
             if agent.lower().startswith(agent_name.lower()):
                 # Get agent type or default to "Agent"
-                agent_type = self.agent_types.get(agent, "Agent")
+                agent_type = self.agent_types.get(agent, AgentType.BASIC).value
                 yield Completion(
                     agent,
                     start_position=-len(agent_name),
@@ -149,7 +150,7 @@ async def get_enhanced_input(
     show_stop_hint: bool = False,
     multiline: bool = False,
     available_agent_names: List[str] = None,
-    agent_types: dict = None,
+    agent_types: dict[str, AgentType] = None,
     is_human_input: bool = False,
     toolbar_color: str = "ansiblue",
 ) -> str:
@@ -430,18 +431,18 @@ async def get_argument_input(
 async def handle_special_commands(command, agent_app=None):
     """
     Handle special input commands.
-
+
     Args:
         command: The command to handle, can be string or dictionary
         agent_app: Optional agent app reference
-
+
     Returns:
         True if command was handled, False if not, or a dict with action info
     """
     # Quick guard for empty or None commands
     if not command:
         return False
-
+
     # If command is already a dictionary, it has been pre-processed
     # Just return it directly (like when /prompts converts to select_prompt dict)
     if isinstance(command, dict):
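Callers that previously supplied plain strings for agent types now pass AgentType values; the completer recovers the display label via .value. A minimal sketch (the agent name is invented, not from the package):

from mcp_agent.core.agent_types import AgentType

# Map agent names to their AgentType for completion/toolbar display.
agent_types: dict[str, AgentType] = {"researcher": AgentType.BASIC}

# The old string labels remain available for display purposes:
labels = {name: agent_type.value for name, agent_type in agent_types.items()}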
mcp_agent/core/fastagent.py

@@ -381,7 +381,7 @@ class FastAgent:
             handle_error(
                 e,
                 "Model Configuration Error",
-                "Common models: gpt-4o, o3-mini, sonnet, haiku. for o3, set reasoning effort with o3-mini.high",
+                "Common models: gpt-4.1, o3-mini, sonnet, haiku. for o3, set reasoning effort with o3-mini.high",
             )
         elif isinstance(e, CircularDependencyError):
             handle_error(
mcp_agent/core/interactive_prompt.py

@@ -20,6 +20,7 @@ from rich import print as rich_print
 from rich.console import Console
 from rich.table import Table

+from mcp_agent.core.agent_types import AgentType
 from mcp_agent.core.enhanced_prompt import (
     get_argument_input,
     get_enhanced_input,
@@ -36,14 +37,14 @@ class InteractivePrompt:
     This is extracted from the original AgentApp implementation to support DirectAgentApp.
     """

-    def __init__(self, agent_types: Optional[Dict[str, str]] = None) -> None:
+    def __init__(self, agent_types: Optional[Dict[str, AgentType]] = None) -> None:
         """
         Initialize the interactive prompt.

         Args:
             agent_types: Dictionary mapping agent names to their types for display
         """
-        self.agent_types = agent_types or {}
+        self.agent_types: Dict[str, AgentType] = agent_types or {}

     async def prompt_loop(
         self,
@@ -97,7 +98,7 @@ class InteractivePrompt:

                 # Handle special commands - pass "True" to enable agent switching
                 command_result = await handle_special_commands(user_input, True)
-
+
                 # Check if we should switch agents
                 if isinstance(command_result, dict):
                     if "switch_agent" in command_result:
@@ -113,11 +114,13 @@ class InteractivePrompt:
                         # Use the list_prompts_func directly
                         await self._list_prompts(list_prompts_func, agent)
                         continue
-                    elif "select_prompt" in command_result and (list_prompts_func and apply_prompt_func):
+                    elif "select_prompt" in command_result and (
+                        list_prompts_func and apply_prompt_func
+                    ):
                         # Handle prompt selection, using both list_prompts and apply_prompt
                         prompt_name = command_result.get("prompt_name")
                         prompt_index = command_result.get("prompt_index")
-
+
                         # If a specific index was provided (from /prompt <number>)
                         if prompt_index is not None:
                             # First get a list of all prompts to look up the index
@@ -125,20 +128,29 @@ class InteractivePrompt:
                             if not all_prompts:
                                 rich_print("[yellow]No prompts available[/yellow]")
                                 continue
-
+
                             # Check if the index is valid
                             if 1 <= prompt_index <= len(all_prompts):
                                 # Get the prompt at the specified index (1-based to 0-based)
                                 selected_prompt = all_prompts[prompt_index - 1]
                                 # Use the already created namespaced_name to ensure consistency
-                                await self._select_prompt(list_prompts_func, apply_prompt_func, agent, selected_prompt["namespaced_name"])
+                                await self._select_prompt(
+                                    list_prompts_func,
+                                    apply_prompt_func,
+                                    agent,
+                                    selected_prompt["namespaced_name"],
+                                )
                             else:
-                                rich_print(f"[red]Invalid prompt number: {prompt_index}. Valid range is 1-{len(all_prompts)}[/red]")
+                                rich_print(
+                                    f"[red]Invalid prompt number: {prompt_index}. Valid range is 1-{len(all_prompts)}[/red]"
+                                )
                                 # Show the prompt list for convenience
                                 await self._list_prompts(list_prompts_func, agent)
                         else:
                             # Use the name-based selection
-                            await self._select_prompt(list_prompts_func, apply_prompt_func, agent, prompt_name)
+                            await self._select_prompt(
+                                list_prompts_func, apply_prompt_func, agent, prompt_name
+                            )
                         continue

                 # Skip further processing if command was handled
@@ -158,11 +170,11 @@ class InteractivePrompt:
     async def _get_all_prompts(self, list_prompts_func, agent_name):
         """
         Get a list of all available prompts.
-
+
         Args:
             list_prompts_func: Function to get available prompts
             agent_name: Name of the agent
-
+
         Returns:
             List of prompt info dictionaries, sorted by server and name
         """
@@ -171,7 +183,7 @@ class InteractivePrompt:
             # the agent_name parameter should never be used as a server name
             prompt_servers = await list_prompts_func(None)
             all_prompts = []
-
+
             # Process the returned prompt servers
             if prompt_servers:
                 # First collect all prompts
@@ -179,44 +191,51 @@ class InteractivePrompt:
                     if prompts_info and hasattr(prompts_info, "prompts") and prompts_info.prompts:
                         for prompt in prompts_info.prompts:
                             # Use the SEP constant for proper namespacing
-                            all_prompts.append({
-                                "server": server_name,
-                                "name": prompt.name,
-                                "namespaced_name": f"{server_name}{SEP}{prompt.name}",
-                                "description": getattr(prompt, "description", "No description"),
-                                "arg_count": len(getattr(prompt, "arguments", [])),
-                                "arguments": getattr(prompt, "arguments", [])
-                            })
+                            all_prompts.append(
+                                {
+                                    "server": server_name,
+                                    "name": prompt.name,
+                                    "namespaced_name": f"{server_name}{SEP}{prompt.name}",
+                                    "description": getattr(prompt, "description", "No description"),
+                                    "arg_count": len(getattr(prompt, "arguments", [])),
+                                    "arguments": getattr(prompt, "arguments", []),
+                                }
+                            )
                     elif isinstance(prompts_info, list) and prompts_info:
                         for prompt in prompts_info:
                             if isinstance(prompt, dict) and "name" in prompt:
-                                all_prompts.append({
-                                    "server": server_name,
-                                    "name": prompt["name"],
-                                    "namespaced_name": f"{server_name}{SEP}{prompt['name']}",
-                                    "description": prompt.get("description", "No description"),
-                                    "arg_count": len(prompt.get("arguments", [])),
-                                    "arguments": prompt.get("arguments", [])
-                                })
+                                all_prompts.append(
+                                    {
+                                        "server": server_name,
+                                        "name": prompt["name"],
+                                        "namespaced_name": f"{server_name}{SEP}{prompt['name']}",
+                                        "description": prompt.get("description", "No description"),
+                                        "arg_count": len(prompt.get("arguments", [])),
+                                        "arguments": prompt.get("arguments", []),
+                                    }
+                                )
                             else:
-                                all_prompts.append({
-                                    "server": server_name,
-                                    "name": str(prompt),
-                                    "namespaced_name": f"{server_name}{SEP}{str(prompt)}",
-                                    "description": "No description",
-                                    "arg_count": 0,
-                                    "arguments": []
-                                })
-
+                                all_prompts.append(
+                                    {
+                                        "server": server_name,
+                                        "name": str(prompt),
+                                        "namespaced_name": f"{server_name}{SEP}{str(prompt)}",
+                                        "description": "No description",
+                                        "arg_count": 0,
+                                        "arguments": [],
+                                    }
+                                )
+
             # Sort prompts by server and name for consistent ordering
             all_prompts.sort(key=lambda p: (p["server"], p["name"]))
-
+
             return all_prompts
-
+
         except Exception as e:
             import traceback

             from rich import print as rich_print
+
             rich_print(f"[red]Error getting prompts: {e}[/red]")
             rich_print(f"[dim]{traceback.format_exc()}[/dim]")
             return []
@@ -238,11 +257,11 @@ class InteractivePrompt:
         try:
             # Directly call the list_prompts function for this agent
             rich_print(f"\n[bold]Fetching prompts for agent [cyan]{agent_name}[/cyan]...[/bold]")
-
+
             # Get all prompts using the helper function - pass None as server name
             # to get prompts from all available servers
             all_prompts = await self._get_all_prompts(list_prompts_func, None)
-
+
             if all_prompts:
                 # Create a table for better display
                 table = Table(title="Available MCP Prompts")
@@ -251,7 +270,7 @@ class InteractivePrompt:
                 table.add_column("Prompt Name", style="bright_blue")
                 table.add_column("Description")
                 table.add_column("Args", justify="center")
-
+
                 # Add prompts to table
                 for i, prompt in enumerate(all_prompts):
                     table.add_row(
@@ -259,11 +278,11 @@ class InteractivePrompt:
                         prompt["server"],
                         prompt["name"],
                         prompt["description"],
-                        str(prompt["arg_count"])
+                        str(prompt["arg_count"]),
                     )
-
+
                 console.print(table)
-
+
                 # Add usage instructions
                 rich_print("\n[bold]Usage:[/bold]")
                 rich_print(" • Use [cyan]/prompt <number>[/cyan] to select a prompt by number")
@@ -272,10 +291,13 @@ class InteractivePrompt:
                 rich_print("[yellow]No prompts available[/yellow]")
         except Exception as e:
             import traceback
+
             rich_print(f"[red]Error listing prompts: {e}[/red]")
             rich_print(f"[dim]{traceback.format_exc()}[/dim]")

-    async def _select_prompt(self, list_prompts_func, apply_prompt_func, agent_name, requested_name=None) -> None:
+    async def _select_prompt(
+        self, list_prompts_func, apply_prompt_func, agent_name, requested_name=None
+    ) -> None:
         """
         Select and apply a prompt.

@@ -293,7 +315,7 @@ class InteractivePrompt:
         try:
             # Get all available prompts directly from the list_prompts function
             rich_print(f"\n[bold]Fetching prompts for agent [cyan]{agent_name}[/cyan]...[/bold]")
-            # IMPORTANT: list_prompts_func gets MCP server prompts, not agent prompts
+            # IMPORTANT: list_prompts_func gets MCP server prompts, not agent prompts
             # So we pass None to get prompts from all servers, not using agent_name as server name
             prompt_servers = await list_prompts_func(None)

@@ -514,9 +536,7 @@ class InteractivePrompt:

             # Apply the prompt
             namespaced_name = selected_prompt["namespaced_name"]
-            rich_print(
-                f"\n[bold]Applying prompt [cyan]{namespaced_name}[/cyan]...[/bold]"
-            )
+            rich_print(f"\n[bold]Applying prompt [cyan]{namespaced_name}[/cyan]...[/bold]")

             # Call apply_prompt function with the prompt name and arguments
             await apply_prompt_func(namespaced_name, arg_values, agent_name)
mcp_agent/core/request_params.py

@@ -2,7 +2,7 @@
 Request parameters definitions for LLM interactions.
 """

-from typing import List
+from typing import Any, List

 from mcp import SamplingMessage
 from mcp.types import CreateMessageRequestParams
@@ -44,3 +44,7 @@ class RequestParams(CreateMessageRequestParams):
     Whether to allow multiple tool calls per iteration.
     Also known as multi-step tool use.
     """
+    response_format: Any | None = None
+    """
+    Override response format for structured calls. Prefer sending pydantic model - only use in exceptional circumstances
+    """
mcp_agent/executor/workflow_signal.py

@@ -7,8 +7,6 @@ from pydantic import BaseModel, ConfigDict

 SignalValueT = TypeVar("SignalValueT")

-# TODO: saqadri - handle signals properly that works with other execution backends like Temporal as well
-

 class Signal(BaseModel, Generic[SignalValueT]):
     """Represents a signal that can be sent to a workflow."""
mcp_agent/llm/augmented_llm.py

@@ -18,6 +18,8 @@ from mcp.types import (
     PromptMessage,
     TextContent,
 )
+from openai import NotGiven
+from openai.lib._parsing import type_to_response_format_param as _type_to_response_format
 from pydantic_core import from_json
 from rich.text import Text

@@ -58,6 +60,20 @@ HUMAN_INPUT_TOOL_NAME = "__human_input__"


 class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT, MessageT]):
+    # Common parameter names used across providers
+    PARAM_MESSAGES = "messages"
+    PARAM_MODEL = "model"
+    PARAM_MAX_TOKENS = "maxTokens"
+    PARAM_SYSTEM_PROMPT = "systemPrompt"
+    PARAM_STOP_SEQUENCES = "stopSequences"
+    PARAM_PARALLEL_TOOL_CALLS = "parallel_tool_calls"
+    PARAM_METADATA = "metadata"
+    PARAM_USE_HISTORY = "use_history"
+    PARAM_MAX_ITERATIONS = "max_iterations"
+
+    # Base set of fields that should always be excluded
+    BASE_EXCLUDE_FIELDS = {PARAM_METADATA}
+
     """
     The basic building block of agentic systems is an LLM enhanced with augmentations
     such as retrieval, tools, and memory provided from a collection of MCP servers.
@@ -141,26 +157,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
             use_history=True,
         )

-    async def structured(
-        self,
-        prompt: List[PromptMessageMultipart],
-        model: Type[ModelT],
-        request_params: RequestParams | None = None,
-    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
-        """Apply the prompt and return the result as a Pydantic model, or None if coercion fails"""
-        try:
-            result: PromptMessageMultipart = await self.generate(prompt, request_params)
-            final_generation = get_text(result.content[-1]) or ""
-            await self.show_assistant_message(final_generation)
-            json_data = from_json(final_generation, allow_partial=True)
-            validated_model = model.model_validate(json_data)
-
-            return cast("ModelT", validated_model), Prompt.assistant(json_data)
-        except Exception as e:
-            logger = get_logger(__name__)
-            logger.error(f"Failed to parse structured response: {str(e)}")
-            return None, Prompt.assistant(f"Failed to parse structured response: {str(e)}")
-
     async def generate(
         self,
         multipart_messages: List[PromptMessageMultipart],
@@ -169,6 +165,12 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         """
         Create a completion with the LLM using the provided messages.
         """
+        # note - check changes here are mirrored in structured(). i've thought hard about
+        # a strategy to reduce duplication etc, but aiming for simple but imperfect for the moment
+
+        # We never expect this for structured() calls - this is for interactive use - developers
+        # can do this programatically
+        # TODO -- create a "fast-agent" control role rather than magic strings
         if multipart_messages[-1].first_text().startswith("***SAVE_HISTORY"):
             parts: list[str] = multipart_messages[-1].first_text().split(" ", 1)
             filename: str = (
@@ -180,26 +182,174 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
             )
             return Prompt.assistant(f"History saved to {filename}")

-        self._message_history.extend(multipart_messages)
-
-        if multipart_messages[-1].role == "user":
-            self.show_user_message(
-                render_multipart_message(multipart_messages[-1]),
-                model=self.default_request_params.model,
-                chat_turn=self.chat_turn(),
-            )
+        self._precall(multipart_messages)

         assistant_response: PromptMessageMultipart = await self._apply_prompt_provider_specific(
             multipart_messages, request_params
         )

+        # add generic error and termination reason handling/rollback
         self._message_history.append(assistant_response)
         return assistant_response

+    @abstractmethod
+    async def _apply_prompt_provider_specific(
+        self,
+        multipart_messages: List["PromptMessageMultipart"],
+        request_params: RequestParams | None = None,
+        is_template: bool = False,
+    ) -> PromptMessageMultipart:
+        """
+        Provider-specific implementation of apply_prompt_template.
+        This default implementation handles basic text content for any LLM type.
+        Provider-specific subclasses should override this method to handle
+        multimodal content appropriately.
+
+        Args:
+            multipart_messages: List of PromptMessageMultipart objects parsed from the prompt template
+
+        Returns:
+            String representation of the assistant's response if generated,
+            or the last assistant message in the prompt
+        """
+
+    async def structured(
+        self,
+        multipart_messages: List[PromptMessageMultipart],
+        model: Type[ModelT],
+        request_params: RequestParams | None = None,
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
+        """Return a structured response from the LLM using the provided messages."""
+        self._precall(multipart_messages)
+        result, assistant_response = await self._apply_prompt_provider_specific_structured(
+            multipart_messages, model, request_params
+        )
+
+        self._message_history.append(assistant_response)
+        return result, assistant_response
+
+    @staticmethod
+    def model_to_response_format(
+        model: Type[Any],
+    ) -> Any:
+        """
+        Convert a pydantic model to the appropriate response format schema.
+        This allows for reuse in multiple provider implementations.
+
+        Args:
+            model: The pydantic model class to convert to a schema
+
+        Returns:
+            Provider-agnostic schema representation or NotGiven if conversion fails
+        """
+        return _type_to_response_format(model)
+
+    @staticmethod
+    def model_to_schema_str(
+        model: Type[Any],
+    ) -> str:
+        """
+        Convert a pydantic model to a schema string representation.
+        This provides a simpler interface for provider implementations
+        that need a string representation.
+
+        Args:
+            model: The pydantic model class to convert to a schema
+
+        Returns:
+            Schema as a string, or empty string if conversion fails
+        """
+        import json
+
+        try:
+            schema = model.model_json_schema()
+            return json.dumps(schema)
+        except Exception:
+            return ""
+
+    async def _apply_prompt_provider_specific_structured(
+        self,
+        multipart_messages: List[PromptMessageMultipart],
+        model: Type[ModelT],
+        request_params: RequestParams | None = None,
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
+        """Base class attempts to parse JSON - subclasses can use provider specific functionality"""
+
+        request_params = self.get_request_params(request_params)
+
+        if not request_params.response_format:
+            schema = self.model_to_response_format(model)
+            if schema is not NotGiven:
+                request_params.response_format = schema
+
+        result: PromptMessageMultipart = await self._apply_prompt_provider_specific(
+            multipart_messages, request_params
+        )
+        return self._structured_from_multipart(result, model)
+
+    def _structured_from_multipart(
+        self, message: PromptMessageMultipart, model: Type[ModelT]
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
+        """Parse the content of a PromptMessage and return the structured model and message itself"""
+        try:
+            text = get_text(message.content[-1]) or ""
+            json_data = from_json(text, allow_partial=True)
+            validated_model = model.model_validate(json_data)
+            return cast("ModelT", validated_model), message
+        except ValueError as e:
+            logger = get_logger(__name__)
+            logger.warning(f"Failed to parse structured response: {str(e)}")
+            return None, message
+
+    def _precall(self, multipart_messages: List[PromptMessageMultipart]) -> None:
+        """Pre-call hook to modify the message before sending it to the provider."""
+        self._message_history.extend(multipart_messages)
+        if multipart_messages[-1].role == "user":
+            self.show_user_message(
+                render_multipart_message(multipart_messages[-1]),
+                model=self.default_request_params.model,
+                chat_turn=self.chat_turn(),
+            )
+
     def chat_turn(self) -> int:
         """Return the current chat turn number"""
         return 1 + sum(1 for message in self._message_history if message.role == "assistant")

+    def prepare_provider_arguments(
+        self,
+        base_args: dict,
+        request_params: RequestParams,
+        exclude_fields: set | None = None,
+    ) -> dict:
+        """
+        Prepare arguments for provider API calls by merging request parameters.
+
+        Args:
+            base_args: Base arguments dictionary with provider-specific required parameters
+            params: The RequestParams object containing all parameters
+            exclude_fields: Set of field names to exclude from params. If None, uses BASE_EXCLUDE_FIELDS.
+
+        Returns:
+            Complete arguments dictionary with all applicable parameters
+        """
+        # Start with base arguments
+        arguments = base_args.copy()
+
+        # Use provided exclude_fields or fall back to base exclusions
+        exclude_fields = exclude_fields or self.BASE_EXCLUDE_FIELDS.copy()
+
+        # Add all fields from params that aren't explicitly excluded
+        params_dict = request_params.model_dump(exclude=exclude_fields)
+        for key, value in params_dict.items():
+            if value is not None and key not in arguments:
+                arguments[key] = value
+
+        # Finally, add any metadata fields as a last layer of overrides
+        if request_params.metadata:
+            arguments.update(request_params.metadata)
+
+        return arguments
+
     def _merge_request_params(
         self, default_params: RequestParams, provided_params: RequestParams
     ) -> RequestParams:
@@ -214,7 +364,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
     def get_request_params(
         self,
         request_params: RequestParams | None = None,
-        default: RequestParams | None = None,
     ) -> RequestParams:
         """
         Get request parameters with merged-in defaults and overrides.
@@ -223,17 +372,12 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
             default: The default request parameters to use as the base.
                 If unspecified, self.default_request_params will be used.
         """
-        # Start with the defaults
-        default_request_params = default or self.default_request_params
-
-        if not default_request_params:
-            default_request_params = self._initialize_default_params({})

         # If user provides overrides, merge them with defaults
         if request_params:
-            return self._merge_request_params(default_request_params, request_params)
+            return self._merge_request_params(self.default_request_params, request_params)

-        return default_request_params
+        return self.default_request_params.model_copy()

     @classmethod
     def convert_message_to_message_param(
@@ -435,7 +579,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         multipart_messages = PromptMessageMultipart.parse_get_prompt_result(prompt_result)

         # Delegate to the provider-specific implementation
-        result = await self._apply_prompt_provider_specific(multipart_messages, None)
+        result = await self._apply_prompt_provider_specific(
+            multipart_messages, None, is_template=True
+        )
         return result.first_text()

     async def _save_history(self, filename: str) -> None:
@@ -450,26 +596,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         # Save messages using the unified save function that auto-detects format
         save_messages_to_file(self._message_history, filename)

-    @abstractmethod
-    async def _apply_prompt_provider_specific(
-        self,
-        multipart_messages: List["PromptMessageMultipart"],
-        request_params: RequestParams | None = None,
-    ) -> PromptMessageMultipart:
-        """
-        Provider-specific implementation of apply_prompt_template.
-        This default implementation handles basic text content for any LLM type.
-        Provider-specific subclasses should override this method to handle
-        multimodal content appropriately.
-
-        Args:
-            multipart_messages: List of PromptMessageMultipart objects parsed from the prompt template
-
-        Returns:
-            String representation of the assistant's response if generated,
-            or the last assistant message in the prompt
-        """
-
     @property
     def message_history(self) -> List[PromptMessageMultipart]:
         """
mcp_agent/llm/augmented_llm_passthrough.py

@@ -143,7 +143,6 @@ class PassthroughLLM(AugmentedLLM):
     ) -> PromptMessageMultipart:
         last_message = multipart_messages[-1]

-        # TODO -- improve when we support Audio/Multimodal gen
         if self.is_tool_call(last_message):
             result = Prompt.assistant(await self.generate_str(last_message.first_text()))
             await self.show_assistant_message(result.first_text())
@@ -158,6 +157,7 @@ class PassthroughLLM(AugmentedLLM):
             await self.show_assistant_message(self._fixed_response)
             return Prompt.assistant(self._fixed_response)
         else:
+            # TODO -- improve when we support Audio/Multimodal gen models e.g. gemini . This should really just return the input as "assistant"...
             concatenated: str = "\n".join(message.all_text() for message in multipart_messages)
             await self.show_assistant_message(concatenated)
             return Prompt.assistant(concatenated)