fast-agent-mcp 0.2.14__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {fast_agent_mcp-0.2.14.dist-info → fast_agent_mcp-0.2.17.dist-info}/METADATA +4 -6
- {fast_agent_mcp-0.2.14.dist-info → fast_agent_mcp-0.2.17.dist-info}/RECORD +46 -46
- mcp_agent/agents/base_agent.py +50 -6
- mcp_agent/agents/workflow/orchestrator_agent.py +6 -7
- mcp_agent/agents/workflow/router_agent.py +70 -136
- mcp_agent/app.py +1 -124
- mcp_agent/cli/commands/setup.py +1 -1
- mcp_agent/config.py +19 -19
- mcp_agent/context.py +4 -22
- mcp_agent/core/agent_types.py +2 -2
- mcp_agent/core/direct_decorators.py +2 -2
- mcp_agent/core/direct_factory.py +2 -1
- mcp_agent/core/enhanced_prompt.py +6 -5
- mcp_agent/core/fastagent.py +1 -1
- mcp_agent/core/interactive_prompt.py +70 -50
- mcp_agent/core/request_params.py +5 -1
- mcp_agent/executor/workflow_signal.py +0 -2
- mcp_agent/llm/augmented_llm.py +183 -57
- mcp_agent/llm/augmented_llm_passthrough.py +1 -1
- mcp_agent/llm/augmented_llm_playback.py +21 -1
- mcp_agent/llm/memory.py +3 -3
- mcp_agent/llm/model_factory.py +3 -1
- mcp_agent/llm/provider_key_manager.py +1 -0
- mcp_agent/llm/provider_types.py +2 -1
- mcp_agent/llm/providers/augmented_llm_anthropic.py +49 -10
- mcp_agent/llm/providers/augmented_llm_deepseek.py +0 -2
- mcp_agent/llm/providers/augmented_llm_generic.py +4 -2
- mcp_agent/llm/providers/augmented_llm_google.py +30 -0
- mcp_agent/llm/providers/augmented_llm_openai.py +95 -158
- mcp_agent/llm/providers/multipart_converter_openai.py +10 -27
- mcp_agent/llm/providers/sampling_converter_openai.py +5 -6
- mcp_agent/mcp/interfaces.py +6 -1
- mcp_agent/mcp/mcp_aggregator.py +2 -8
- mcp_agent/mcp/prompt_message_multipart.py +25 -2
- mcp_agent/resources/examples/data-analysis/analysis-campaign.py +2 -2
- mcp_agent/resources/examples/in_dev/agent_build.py +1 -1
- mcp_agent/resources/examples/internal/job.py +1 -1
- mcp_agent/resources/examples/mcp/state-transfer/fastagent.config.yaml +1 -1
- mcp_agent/resources/examples/prompting/agent.py +0 -2
- mcp_agent/resources/examples/prompting/fastagent.config.yaml +2 -3
- mcp_agent/resources/examples/researcher/fastagent.config.yaml +1 -6
- mcp_agent/resources/examples/workflows/fastagent.config.yaml +0 -1
- mcp_agent/resources/examples/workflows/parallel.py +1 -1
- mcp_agent/executor/decorator_registry.py +0 -112
- {fast_agent_mcp-0.2.14.dist-info → fast_agent_mcp-0.2.17.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.2.14.dist-info → fast_agent_mcp-0.2.17.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.2.14.dist-info → fast_agent_mcp-0.2.17.dist-info}/licenses/LICENSE +0 -0
mcp_agent/core/enhanced_prompt.py CHANGED
```diff
@@ -14,6 +14,7 @@ from prompt_toolkit.key_binding import KeyBindings
 from prompt_toolkit.styles import Style
 from rich import print as rich_print
 
+from mcp_agent.core.agent_types import AgentType
 from mcp_agent.core.exceptions import PromptExitError
 
 # Get the application version
@@ -86,7 +87,7 @@ class AgentCompleter(Completer):
         for agent in self.agents:
             if agent.lower().startswith(agent_name.lower()):
                 # Get agent type or default to "Agent"
-                agent_type = self.agent_types.get(agent,
+                agent_type = self.agent_types.get(agent, AgentType.BASIC).value
                 yield Completion(
                     agent,
                     start_position=-len(agent_name),
@@ -149,7 +150,7 @@ async def get_enhanced_input(
     show_stop_hint: bool = False,
     multiline: bool = False,
     available_agent_names: List[str] = None,
-    agent_types: dict = None,
+    agent_types: dict[str, AgentType] = None,
     is_human_input: bool = False,
     toolbar_color: str = "ansiblue",
 ) -> str:
@@ -430,18 +431,18 @@ async def get_argument_input(
 async def handle_special_commands(command, agent_app=None):
     """
     Handle special input commands.
-
+
     Args:
         command: The command to handle, can be string or dictionary
         agent_app: Optional agent app reference
-
+
     Returns:
         True if command was handled, False if not, or a dict with action info
     """
     # Quick guard for empty or None commands
     if not command:
         return False
-
+
     # If command is already a dictionary, it has been pre-processed
     # Just return it directly (like when /prompts converts to select_prompt dict)
     if isinstance(command, dict):
```
mcp_agent/core/fastagent.py CHANGED
```diff
@@ -381,7 +381,7 @@ class FastAgent:
             handle_error(
                 e,
                 "Model Configuration Error",
-                "Common models: gpt-
+                "Common models: gpt-4.1, o3-mini, sonnet, haiku. for o3, set reasoning effort with o3-mini.high",
             )
         elif isinstance(e, CircularDependencyError):
             handle_error(
```
mcp_agent/core/interactive_prompt.py CHANGED
```diff
@@ -20,6 +20,7 @@ from rich import print as rich_print
 from rich.console import Console
 from rich.table import Table
 
+from mcp_agent.core.agent_types import AgentType
 from mcp_agent.core.enhanced_prompt import (
     get_argument_input,
     get_enhanced_input,
@@ -36,14 +37,14 @@ class InteractivePrompt:
     This is extracted from the original AgentApp implementation to support DirectAgentApp.
     """
 
-    def __init__(self, agent_types: Optional[Dict[str,
+    def __init__(self, agent_types: Optional[Dict[str, AgentType]] = None) -> None:
         """
         Initialize the interactive prompt.
 
         Args:
             agent_types: Dictionary mapping agent names to their types for display
         """
-        self.agent_types = agent_types or {}
+        self.agent_types: Dict[str, AgentType] = agent_types or {}
 
     async def prompt_loop(
         self,
@@ -97,7 +98,7 @@ class InteractivePrompt:
 
                 # Handle special commands - pass "True" to enable agent switching
                 command_result = await handle_special_commands(user_input, True)
-
+
                 # Check if we should switch agents
                 if isinstance(command_result, dict):
                     if "switch_agent" in command_result:
@@ -113,11 +114,13 @@ class InteractivePrompt:
                         # Use the list_prompts_func directly
                         await self._list_prompts(list_prompts_func, agent)
                         continue
-                    elif "select_prompt" in command_result and (
+                    elif "select_prompt" in command_result and (
+                        list_prompts_func and apply_prompt_func
+                    ):
                         # Handle prompt selection, using both list_prompts and apply_prompt
                         prompt_name = command_result.get("prompt_name")
                         prompt_index = command_result.get("prompt_index")
-
+
                         # If a specific index was provided (from /prompt <number>)
                         if prompt_index is not None:
                             # First get a list of all prompts to look up the index
@@ -125,20 +128,29 @@ class InteractivePrompt:
                             if not all_prompts:
                                 rich_print("[yellow]No prompts available[/yellow]")
                                 continue
-
+
                             # Check if the index is valid
                             if 1 <= prompt_index <= len(all_prompts):
                                 # Get the prompt at the specified index (1-based to 0-based)
                                 selected_prompt = all_prompts[prompt_index - 1]
                                 # Use the already created namespaced_name to ensure consistency
-                                await self._select_prompt(
+                                await self._select_prompt(
+                                    list_prompts_func,
+                                    apply_prompt_func,
+                                    agent,
+                                    selected_prompt["namespaced_name"],
+                                )
                             else:
-                                rich_print(
+                                rich_print(
+                                    f"[red]Invalid prompt number: {prompt_index}. Valid range is 1-{len(all_prompts)}[/red]"
+                                )
                                 # Show the prompt list for convenience
                                 await self._list_prompts(list_prompts_func, agent)
                         else:
                             # Use the name-based selection
-                            await self._select_prompt(
+                            await self._select_prompt(
+                                list_prompts_func, apply_prompt_func, agent, prompt_name
+                            )
                         continue
 
                 # Skip further processing if command was handled
@@ -158,11 +170,11 @@ class InteractivePrompt:
     async def _get_all_prompts(self, list_prompts_func, agent_name):
         """
         Get a list of all available prompts.
-
+
         Args:
             list_prompts_func: Function to get available prompts
             agent_name: Name of the agent
-
+
         Returns:
             List of prompt info dictionaries, sorted by server and name
         """
@@ -171,7 +183,7 @@ class InteractivePrompt:
             # the agent_name parameter should never be used as a server name
             prompt_servers = await list_prompts_func(None)
             all_prompts = []
-
+
             # Process the returned prompt servers
             if prompt_servers:
                 # First collect all prompts
@@ -179,44 +191,51 @@ class InteractivePrompt:
                 if prompts_info and hasattr(prompts_info, "prompts") and prompts_info.prompts:
                     for prompt in prompts_info.prompts:
                         # Use the SEP constant for proper namespacing
-                        all_prompts.append(
-
-
-
-
-
-
-
+                        all_prompts.append(
+                            {
+                                "server": server_name,
+                                "name": prompt.name,
+                                "namespaced_name": f"{server_name}{SEP}{prompt.name}",
+                                "description": getattr(prompt, "description", "No description"),
+                                "arg_count": len(getattr(prompt, "arguments", [])),
+                                "arguments": getattr(prompt, "arguments", []),
+                            }
+                        )
                 elif isinstance(prompts_info, list) and prompts_info:
                     for prompt in prompts_info:
                         if isinstance(prompt, dict) and "name" in prompt:
-                            all_prompts.append(
-
-
-
-
-
-
-
+                            all_prompts.append(
+                                {
+                                    "server": server_name,
+                                    "name": prompt["name"],
+                                    "namespaced_name": f"{server_name}{SEP}{prompt['name']}",
+                                    "description": prompt.get("description", "No description"),
+                                    "arg_count": len(prompt.get("arguments", [])),
+                                    "arguments": prompt.get("arguments", []),
+                                }
+                            )
                         else:
-                            all_prompts.append(
-
-
-
-
-
-
-
-
+                            all_prompts.append(
+                                {
+                                    "server": server_name,
+                                    "name": str(prompt),
+                                    "namespaced_name": f"{server_name}{SEP}{str(prompt)}",
+                                    "description": "No description",
+                                    "arg_count": 0,
+                                    "arguments": [],
+                                }
+                            )
+
             # Sort prompts by server and name for consistent ordering
             all_prompts.sort(key=lambda p: (p["server"], p["name"]))
-
+
             return all_prompts
-
+
         except Exception as e:
             import traceback
 
             from rich import print as rich_print
+
             rich_print(f"[red]Error getting prompts: {e}[/red]")
             rich_print(f"[dim]{traceback.format_exc()}[/dim]")
             return []
@@ -238,11 +257,11 @@ class InteractivePrompt:
         try:
             # Directly call the list_prompts function for this agent
             rich_print(f"\n[bold]Fetching prompts for agent [cyan]{agent_name}[/cyan]...[/bold]")
-
+
             # Get all prompts using the helper function - pass None as server name
             # to get prompts from all available servers
             all_prompts = await self._get_all_prompts(list_prompts_func, None)
-
+
             if all_prompts:
                 # Create a table for better display
                 table = Table(title="Available MCP Prompts")
@@ -251,7 +270,7 @@ class InteractivePrompt:
                 table.add_column("Prompt Name", style="bright_blue")
                 table.add_column("Description")
                 table.add_column("Args", justify="center")
-
+
                 # Add prompts to table
                 for i, prompt in enumerate(all_prompts):
                     table.add_row(
@@ -259,11 +278,11 @@ class InteractivePrompt:
                         prompt["server"],
                         prompt["name"],
                         prompt["description"],
-                        str(prompt["arg_count"])
+                        str(prompt["arg_count"]),
                     )
-
+
                 console.print(table)
-
+
                 # Add usage instructions
                 rich_print("\n[bold]Usage:[/bold]")
                 rich_print(" • Use [cyan]/prompt <number>[/cyan] to select a prompt by number")
@@ -272,10 +291,13 @@ class InteractivePrompt:
                 rich_print("[yellow]No prompts available[/yellow]")
         except Exception as e:
             import traceback
+
             rich_print(f"[red]Error listing prompts: {e}[/red]")
             rich_print(f"[dim]{traceback.format_exc()}[/dim]")
 
-    async def _select_prompt(
+    async def _select_prompt(
+        self, list_prompts_func, apply_prompt_func, agent_name, requested_name=None
+    ) -> None:
         """
         Select and apply a prompt.
 
@@ -293,7 +315,7 @@ class InteractivePrompt:
         try:
             # Get all available prompts directly from the list_prompts function
             rich_print(f"\n[bold]Fetching prompts for agent [cyan]{agent_name}[/cyan]...[/bold]")
-            # IMPORTANT: list_prompts_func gets MCP server prompts, not agent prompts
+            # IMPORTANT: list_prompts_func gets MCP server prompts, not agent prompts
             # So we pass None to get prompts from all servers, not using agent_name as server name
             prompt_servers = await list_prompts_func(None)
 
@@ -514,9 +536,7 @@ class InteractivePrompt:
 
         # Apply the prompt
         namespaced_name = selected_prompt["namespaced_name"]
-        rich_print(
-            f"\n[bold]Applying prompt [cyan]{namespaced_name}[/cyan]...[/bold]"
-        )
+        rich_print(f"\n[bold]Applying prompt [cyan]{namespaced_name}[/cyan]...[/bold]")
 
         # Call apply_prompt function with the prompt name and arguments
         await apply_prompt_func(namespaced_name, arg_values, agent_name)
```
mcp_agent/core/request_params.py CHANGED
```diff
@@ -2,7 +2,7 @@
 Request parameters definitions for LLM interactions.
 """
 
-from typing import List
+from typing import Any, List
 
 from mcp import SamplingMessage
 from mcp.types import CreateMessageRequestParams
@@ -44,3 +44,7 @@ class RequestParams(CreateMessageRequestParams):
     Whether to allow multiple tool calls per iteration.
     Also known as multi-step tool use.
     """
+    response_format: Any | None = None
+    """
+    Override response format for structured calls. Prefer sending pydantic model - only use in exceptional circumstances
+    """
```
mcp_agent/executor/workflow_signal.py CHANGED
```diff
@@ -7,8 +7,6 @@ from pydantic import BaseModel, ConfigDict
 
 SignalValueT = TypeVar("SignalValueT")
 
-# TODO: saqadri - handle signals properly that works with other execution backends like Temporal as well
-
 
 class Signal(BaseModel, Generic[SignalValueT]):
     """Represents a signal that can be sent to a workflow."""
```
mcp_agent/llm/augmented_llm.py CHANGED
```diff
@@ -18,6 +18,8 @@ from mcp.types import (
     PromptMessage,
     TextContent,
 )
+from openai import NotGiven
+from openai.lib._parsing import type_to_response_format_param as _type_to_response_format
 from pydantic_core import from_json
 from rich.text import Text
 
@@ -58,6 +60,20 @@ HUMAN_INPUT_TOOL_NAME = "__human_input__"
 
 
 class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT, MessageT]):
+    # Common parameter names used across providers
+    PARAM_MESSAGES = "messages"
+    PARAM_MODEL = "model"
+    PARAM_MAX_TOKENS = "maxTokens"
+    PARAM_SYSTEM_PROMPT = "systemPrompt"
+    PARAM_STOP_SEQUENCES = "stopSequences"
+    PARAM_PARALLEL_TOOL_CALLS = "parallel_tool_calls"
+    PARAM_METADATA = "metadata"
+    PARAM_USE_HISTORY = "use_history"
+    PARAM_MAX_ITERATIONS = "max_iterations"
+
+    # Base set of fields that should always be excluded
+    BASE_EXCLUDE_FIELDS = {PARAM_METADATA}
+
     """
     The basic building block of agentic systems is an LLM enhanced with augmentations
     such as retrieval, tools, and memory provided from a collection of MCP servers.
@@ -141,26 +157,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
             use_history=True,
         )
 
-    async def structured(
-        self,
-        prompt: List[PromptMessageMultipart],
-        model: Type[ModelT],
-        request_params: RequestParams | None = None,
-    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
-        """Apply the prompt and return the result as a Pydantic model, or None if coercion fails"""
-        try:
-            result: PromptMessageMultipart = await self.generate(prompt, request_params)
-            final_generation = get_text(result.content[-1]) or ""
-            await self.show_assistant_message(final_generation)
-            json_data = from_json(final_generation, allow_partial=True)
-            validated_model = model.model_validate(json_data)
-
-            return cast("ModelT", validated_model), Prompt.assistant(json_data)
-        except Exception as e:
-            logger = get_logger(__name__)
-            logger.error(f"Failed to parse structured response: {str(e)}")
-            return None, Prompt.assistant(f"Failed to parse structured response: {str(e)}")
-
     async def generate(
         self,
         multipart_messages: List[PromptMessageMultipart],
@@ -169,6 +165,12 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         """
         Create a completion with the LLM using the provided messages.
         """
+        # note - check changes here are mirrored in structured(). i've thought hard about
+        # a strategy to reduce duplication etc, but aiming for simple but imperfect for the moment
+
+        # We never expect this for structured() calls - this is for interactive use - developers
+        # can do this programatically
+        # TODO -- create a "fast-agent" control role rather than magic strings
         if multipart_messages[-1].first_text().startswith("***SAVE_HISTORY"):
             parts: list[str] = multipart_messages[-1].first_text().split(" ", 1)
             filename: str = (
@@ -180,26 +182,174 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
             )
             return Prompt.assistant(f"History saved to {filename}")
 
-        self.
-
-        if multipart_messages[-1].role == "user":
-            self.show_user_message(
-                render_multipart_message(multipart_messages[-1]),
-                model=self.default_request_params.model,
-                chat_turn=self.chat_turn(),
-            )
+        self._precall(multipart_messages)
 
         assistant_response: PromptMessageMultipart = await self._apply_prompt_provider_specific(
             multipart_messages, request_params
         )
 
+        # add generic error and termination reason handling/rollback
         self._message_history.append(assistant_response)
         return assistant_response
 
+    @abstractmethod
+    async def _apply_prompt_provider_specific(
+        self,
+        multipart_messages: List["PromptMessageMultipart"],
+        request_params: RequestParams | None = None,
+        is_template: bool = False,
+    ) -> PromptMessageMultipart:
+        """
+        Provider-specific implementation of apply_prompt_template.
+        This default implementation handles basic text content for any LLM type.
+        Provider-specific subclasses should override this method to handle
+        multimodal content appropriately.
+
+        Args:
+            multipart_messages: List of PromptMessageMultipart objects parsed from the prompt template
+
+        Returns:
+            String representation of the assistant's response if generated,
+            or the last assistant message in the prompt
+        """
+
+    async def structured(
+        self,
+        multipart_messages: List[PromptMessageMultipart],
+        model: Type[ModelT],
+        request_params: RequestParams | None = None,
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
+        """Return a structured response from the LLM using the provided messages."""
+        self._precall(multipart_messages)
+        result, assistant_response = await self._apply_prompt_provider_specific_structured(
+            multipart_messages, model, request_params
+        )
+
+        self._message_history.append(assistant_response)
+        return result, assistant_response
+
+    @staticmethod
+    def model_to_response_format(
+        model: Type[Any],
+    ) -> Any:
+        """
+        Convert a pydantic model to the appropriate response format schema.
+        This allows for reuse in multiple provider implementations.
+
+        Args:
+            model: The pydantic model class to convert to a schema
+
+        Returns:
+            Provider-agnostic schema representation or NotGiven if conversion fails
+        """
+        return _type_to_response_format(model)
+
+    @staticmethod
+    def model_to_schema_str(
+        model: Type[Any],
+    ) -> str:
+        """
+        Convert a pydantic model to a schema string representation.
+        This provides a simpler interface for provider implementations
+        that need a string representation.
+
+        Args:
+            model: The pydantic model class to convert to a schema
+
+        Returns:
+            Schema as a string, or empty string if conversion fails
+        """
+        import json
+
+        try:
+            schema = model.model_json_schema()
+            return json.dumps(schema)
+        except Exception:
+            return ""
+
+    async def _apply_prompt_provider_specific_structured(
+        self,
+        multipart_messages: List[PromptMessageMultipart],
+        model: Type[ModelT],
+        request_params: RequestParams | None = None,
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
+        """Base class attempts to parse JSON - subclasses can use provider specific functionality"""
+
+        request_params = self.get_request_params(request_params)
+
+        if not request_params.response_format:
+            schema = self.model_to_response_format(model)
+            if schema is not NotGiven:
+                request_params.response_format = schema
+
+        result: PromptMessageMultipart = await self._apply_prompt_provider_specific(
+            multipart_messages, request_params
+        )
+        return self._structured_from_multipart(result, model)
+
+    def _structured_from_multipart(
+        self, message: PromptMessageMultipart, model: Type[ModelT]
+    ) -> Tuple[ModelT | None, PromptMessageMultipart]:
+        """Parse the content of a PromptMessage and return the structured model and message itself"""
+        try:
+            text = get_text(message.content[-1]) or ""
+            json_data = from_json(text, allow_partial=True)
+            validated_model = model.model_validate(json_data)
+            return cast("ModelT", validated_model), message
+        except ValueError as e:
+            logger = get_logger(__name__)
+            logger.warning(f"Failed to parse structured response: {str(e)}")
+            return None, message
+
+    def _precall(self, multipart_messages: List[PromptMessageMultipart]) -> None:
+        """Pre-call hook to modify the message before sending it to the provider."""
+        self._message_history.extend(multipart_messages)
+        if multipart_messages[-1].role == "user":
+            self.show_user_message(
+                render_multipart_message(multipart_messages[-1]),
+                model=self.default_request_params.model,
+                chat_turn=self.chat_turn(),
+            )
+
     def chat_turn(self) -> int:
         """Return the current chat turn number"""
         return 1 + sum(1 for message in self._message_history if message.role == "assistant")
 
+    def prepare_provider_arguments(
+        self,
+        base_args: dict,
+        request_params: RequestParams,
+        exclude_fields: set | None = None,
+    ) -> dict:
+        """
+        Prepare arguments for provider API calls by merging request parameters.
+
+        Args:
+            base_args: Base arguments dictionary with provider-specific required parameters
+            params: The RequestParams object containing all parameters
+            exclude_fields: Set of field names to exclude from params. If None, uses BASE_EXCLUDE_FIELDS.
+
+        Returns:
+            Complete arguments dictionary with all applicable parameters
+        """
+        # Start with base arguments
+        arguments = base_args.copy()
+
+        # Use provided exclude_fields or fall back to base exclusions
+        exclude_fields = exclude_fields or self.BASE_EXCLUDE_FIELDS.copy()
+
+        # Add all fields from params that aren't explicitly excluded
+        params_dict = request_params.model_dump(exclude=exclude_fields)
+        for key, value in params_dict.items():
+            if value is not None and key not in arguments:
+                arguments[key] = value
+
+        # Finally, add any metadata fields as a last layer of overrides
+        if request_params.metadata:
+            arguments.update(request_params.metadata)
+
+        return arguments
+
     def _merge_request_params(
         self, default_params: RequestParams, provided_params: RequestParams
     ) -> RequestParams:
@@ -214,7 +364,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
     def get_request_params(
         self,
         request_params: RequestParams | None = None,
-        default: RequestParams | None = None,
     ) -> RequestParams:
         """
         Get request parameters with merged-in defaults and overrides.
@@ -223,17 +372,12 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
             default: The default request parameters to use as the base.
                 If unspecified, self.default_request_params will be used.
         """
-        # Start with the defaults
-        default_request_params = default or self.default_request_params
-
-        if not default_request_params:
-            default_request_params = self._initialize_default_params({})
 
         # If user provides overrides, merge them with defaults
         if request_params:
-            return self._merge_request_params(default_request_params, request_params)
+            return self._merge_request_params(self.default_request_params, request_params)
 
-        return default_request_params
+        return self.default_request_params.model_copy()
 
     @classmethod
     def convert_message_to_message_param(
@@ -435,7 +579,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         multipart_messages = PromptMessageMultipart.parse_get_prompt_result(prompt_result)
 
         # Delegate to the provider-specific implementation
-        result = await self._apply_prompt_provider_specific(
+        result = await self._apply_prompt_provider_specific(
+            multipart_messages, None, is_template=True
+        )
         return result.first_text()
 
     async def _save_history(self, filename: str) -> None:
@@ -450,26 +596,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         # Save messages using the unified save function that auto-detects format
         save_messages_to_file(self._message_history, filename)
 
-    @abstractmethod
-    async def _apply_prompt_provider_specific(
-        self,
-        multipart_messages: List["PromptMessageMultipart"],
-        request_params: RequestParams | None = None,
-    ) -> PromptMessageMultipart:
-        """
-        Provider-specific implementation of apply_prompt_template.
-        This default implementation handles basic text content for any LLM type.
-        Provider-specific subclasses should override this method to handle
-        multimodal content appropriately.
-
-        Args:
-            multipart_messages: List of PromptMessageMultipart objects parsed from the prompt template
-
-        Returns:
-            String representation of the assistant's response if generated,
-            or the last assistant message in the prompt
-        """
-
     @property
     def message_history(self) -> List[PromptMessageMultipart]:
         """
```
mcp_agent/llm/augmented_llm_passthrough.py CHANGED
```diff
@@ -143,7 +143,6 @@ class PassthroughLLM(AugmentedLLM):
     ) -> PromptMessageMultipart:
         last_message = multipart_messages[-1]
 
-        # TODO -- improve when we support Audio/Multimodal gen
         if self.is_tool_call(last_message):
             result = Prompt.assistant(await self.generate_str(last_message.first_text()))
             await self.show_assistant_message(result.first_text())
@@ -158,6 +157,7 @@ class PassthroughLLM(AugmentedLLM):
             await self.show_assistant_message(self._fixed_response)
             return Prompt.assistant(self._fixed_response)
         else:
+            # TODO -- improve when we support Audio/Multimodal gen models e.g. gemini . This should really just return the input as "assistant"...
            concatenated: str = "\n".join(message.all_text() for message in multipart_messages)
            await self.show_assistant_message(concatenated)
            return Prompt.assistant(concatenated)
```