fast-agent-mcp 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fast_agent_mcp-0.1.13.dist-info → fast_agent_mcp-0.2.0.dist-info}/METADATA +3 -4
- fast_agent_mcp-0.2.0.dist-info/RECORD +123 -0
- mcp_agent/__init__.py +75 -0
- mcp_agent/agents/agent.py +59 -371
- mcp_agent/agents/base_agent.py +522 -0
- mcp_agent/agents/workflow/__init__.py +1 -0
- mcp_agent/agents/workflow/chain_agent.py +173 -0
- mcp_agent/agents/workflow/evaluator_optimizer.py +362 -0
- mcp_agent/agents/workflow/orchestrator_agent.py +591 -0
- mcp_agent/{workflows/orchestrator → agents/workflow}/orchestrator_models.py +27 -11
- mcp_agent/agents/workflow/parallel_agent.py +182 -0
- mcp_agent/agents/workflow/router_agent.py +307 -0
- mcp_agent/app.py +3 -1
- mcp_agent/cli/commands/bootstrap.py +18 -7
- mcp_agent/cli/commands/setup.py +12 -4
- mcp_agent/cli/main.py +1 -1
- mcp_agent/cli/terminal.py +1 -1
- mcp_agent/config.py +24 -35
- mcp_agent/context.py +3 -1
- mcp_agent/context_dependent.py +3 -1
- mcp_agent/core/agent_types.py +10 -7
- mcp_agent/core/direct_agent_app.py +179 -0
- mcp_agent/core/direct_decorators.py +443 -0
- mcp_agent/core/direct_factory.py +476 -0
- mcp_agent/core/enhanced_prompt.py +15 -20
- mcp_agent/core/fastagent.py +151 -337
- mcp_agent/core/interactive_prompt.py +424 -0
- mcp_agent/core/mcp_content.py +19 -11
- mcp_agent/core/prompt.py +6 -2
- mcp_agent/core/validation.py +89 -16
- mcp_agent/executor/decorator_registry.py +6 -2
- mcp_agent/executor/temporal.py +35 -11
- mcp_agent/executor/workflow_signal.py +8 -2
- mcp_agent/human_input/handler.py +3 -1
- mcp_agent/llm/__init__.py +2 -0
- mcp_agent/{workflows/llm → llm}/augmented_llm.py +131 -256
- mcp_agent/{workflows/llm → llm}/augmented_llm_passthrough.py +35 -107
- mcp_agent/llm/augmented_llm_playback.py +83 -0
- mcp_agent/{workflows/llm → llm}/model_factory.py +26 -8
- mcp_agent/llm/providers/__init__.py +8 -0
- mcp_agent/{workflows/llm → llm/providers}/anthropic_utils.py +5 -1
- mcp_agent/{workflows/llm → llm/providers}/augmented_llm_anthropic.py +37 -141
- mcp_agent/llm/providers/augmented_llm_deepseek.py +53 -0
- mcp_agent/{workflows/llm → llm/providers}/augmented_llm_openai.py +112 -148
- mcp_agent/{workflows/llm → llm}/providers/multipart_converter_anthropic.py +78 -35
- mcp_agent/{workflows/llm → llm}/providers/multipart_converter_openai.py +73 -44
- mcp_agent/{workflows/llm → llm}/providers/openai_multipart.py +18 -4
- mcp_agent/{workflows/llm → llm/providers}/openai_utils.py +3 -3
- mcp_agent/{workflows/llm → llm}/providers/sampling_converter_anthropic.py +3 -3
- mcp_agent/{workflows/llm → llm}/providers/sampling_converter_openai.py +3 -3
- mcp_agent/{workflows/llm → llm}/sampling_converter.py +0 -21
- mcp_agent/{workflows/llm → llm}/sampling_format_converter.py +16 -1
- mcp_agent/logging/logger.py +2 -2
- mcp_agent/mcp/gen_client.py +9 -3
- mcp_agent/mcp/interfaces.py +67 -45
- mcp_agent/mcp/logger_textio.py +97 -0
- mcp_agent/mcp/mcp_agent_client_session.py +12 -4
- mcp_agent/mcp/mcp_agent_server.py +3 -1
- mcp_agent/mcp/mcp_aggregator.py +124 -93
- mcp_agent/mcp/mcp_connection_manager.py +21 -7
- mcp_agent/mcp/prompt_message_multipart.py +59 -1
- mcp_agent/mcp/prompt_render.py +77 -0
- mcp_agent/mcp/prompt_serialization.py +20 -13
- mcp_agent/mcp/prompts/prompt_constants.py +18 -0
- mcp_agent/mcp/prompts/prompt_helpers.py +327 -0
- mcp_agent/mcp/prompts/prompt_load.py +15 -5
- mcp_agent/mcp/prompts/prompt_server.py +154 -87
- mcp_agent/mcp/prompts/prompt_template.py +26 -35
- mcp_agent/mcp/resource_utils.py +3 -1
- mcp_agent/mcp/sampling.py +24 -15
- mcp_agent/mcp_server/agent_server.py +8 -5
- mcp_agent/mcp_server_registry.py +22 -9
- mcp_agent/resources/examples/{workflows → in_dev}/agent_build.py +1 -1
- mcp_agent/resources/examples/{data-analysis → in_dev}/slides.py +1 -1
- mcp_agent/resources/examples/internal/agent.py +4 -2
- mcp_agent/resources/examples/internal/fastagent.config.yaml +8 -2
- mcp_agent/resources/examples/prompting/image_server.py +3 -1
- mcp_agent/resources/examples/prompting/work_with_image.py +19 -0
- mcp_agent/ui/console_display.py +27 -7
- fast_agent_mcp-0.1.13.dist-info/RECORD +0 -164
- mcp_agent/core/agent_app.py +0 -570
- mcp_agent/core/agent_utils.py +0 -69
- mcp_agent/core/decorators.py +0 -448
- mcp_agent/core/factory.py +0 -422
- mcp_agent/core/proxies.py +0 -278
- mcp_agent/core/types.py +0 -22
- mcp_agent/eval/__init__.py +0 -0
- mcp_agent/mcp/stdio.py +0 -114
- mcp_agent/resources/examples/data-analysis/analysis-campaign.py +0 -188
- mcp_agent/resources/examples/data-analysis/analysis.py +0 -65
- mcp_agent/resources/examples/data-analysis/fastagent.config.yaml +0 -41
- mcp_agent/resources/examples/data-analysis/mount-point/WA_Fn-UseC_-HR-Employee-Attrition.csv +0 -1471
- mcp_agent/resources/examples/mcp_researcher/researcher-eval.py +0 -53
- mcp_agent/resources/examples/researcher/fastagent.config.yaml +0 -66
- mcp_agent/resources/examples/researcher/researcher-eval.py +0 -53
- mcp_agent/resources/examples/researcher/researcher-imp.py +0 -189
- mcp_agent/resources/examples/researcher/researcher.py +0 -39
- mcp_agent/resources/examples/workflows/chaining.py +0 -45
- mcp_agent/resources/examples/workflows/evaluator.py +0 -79
- mcp_agent/resources/examples/workflows/fastagent.config.yaml +0 -24
- mcp_agent/resources/examples/workflows/human_input.py +0 -26
- mcp_agent/resources/examples/workflows/orchestrator.py +0 -74
- mcp_agent/resources/examples/workflows/parallel.py +0 -79
- mcp_agent/resources/examples/workflows/router.py +0 -54
- mcp_agent/resources/examples/workflows/sse.py +0 -23
- mcp_agent/telemetry/__init__.py +0 -0
- mcp_agent/telemetry/usage_tracking.py +0 -19
- mcp_agent/workflows/__init__.py +0 -0
- mcp_agent/workflows/embedding/__init__.py +0 -0
- mcp_agent/workflows/embedding/embedding_base.py +0 -58
- mcp_agent/workflows/embedding/embedding_cohere.py +0 -49
- mcp_agent/workflows/embedding/embedding_openai.py +0 -37
- mcp_agent/workflows/evaluator_optimizer/__init__.py +0 -0
- mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +0 -447
- mcp_agent/workflows/intent_classifier/__init__.py +0 -0
- mcp_agent/workflows/intent_classifier/intent_classifier_base.py +0 -117
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding.py +0 -130
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding_cohere.py +0 -41
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding_openai.py +0 -41
- mcp_agent/workflows/intent_classifier/intent_classifier_llm.py +0 -150
- mcp_agent/workflows/intent_classifier/intent_classifier_llm_anthropic.py +0 -60
- mcp_agent/workflows/intent_classifier/intent_classifier_llm_openai.py +0 -58
- mcp_agent/workflows/llm/__init__.py +0 -0
- mcp_agent/workflows/llm/augmented_llm_playback.py +0 -111
- mcp_agent/workflows/llm/providers/__init__.py +0 -8
- mcp_agent/workflows/orchestrator/__init__.py +0 -0
- mcp_agent/workflows/orchestrator/orchestrator.py +0 -535
- mcp_agent/workflows/parallel/__init__.py +0 -0
- mcp_agent/workflows/parallel/fan_in.py +0 -320
- mcp_agent/workflows/parallel/fan_out.py +0 -181
- mcp_agent/workflows/parallel/parallel_llm.py +0 -149
- mcp_agent/workflows/router/__init__.py +0 -0
- mcp_agent/workflows/router/router_base.py +0 -338
- mcp_agent/workflows/router/router_embedding.py +0 -226
- mcp_agent/workflows/router/router_embedding_cohere.py +0 -59
- mcp_agent/workflows/router/router_embedding_openai.py +0 -59
- mcp_agent/workflows/router/router_llm.py +0 -304
- mcp_agent/workflows/swarm/__init__.py +0 -0
- mcp_agent/workflows/swarm/swarm.py +0 -292
- mcp_agent/workflows/swarm/swarm_anthropic.py +0 -42
- mcp_agent/workflows/swarm/swarm_openai.py +0 -41
- {fast_agent_mcp-0.1.13.dist-info → fast_agent_mcp-0.2.0.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.1.13.dist-info → fast_agent_mcp-0.2.0.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.1.13.dist-info → fast_agent_mcp-0.2.0.dist-info}/licenses/LICENSE +0 -0
- /mcp_agent/{workflows/orchestrator → agents/workflow}/orchestrator_prompts.py +0 -0
- /mcp_agent/{workflows/llm → llm}/memory.py +0 -0
- /mcp_agent/{workflows/llm → llm}/prompt_utils.py +0 -0
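The hunks below cover the largest single change, `mcp_agent/{workflows/llm → llm}/augmented_llm.py` (+131 -256): the old `generate_str`/`generate_structured` interface is replaced by `generate` and `structured` methods built around `PromptMessageMultipart`.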
```diff
--- mcp_agent/workflows/llm/augmented_llm.py
+++ mcp_agent/llm/augmented_llm.py
@@ -2,32 +2,14 @@ from abc import abstractmethod
 from typing import (
     TYPE_CHECKING,
     Any,
+    Generic,
     List,
     Optional,
     Type,
+    TypeVar,
     cast,
 )
 
-from mcp_agent.logging.logger import get_logger
-from mcp_agent.mcp.interfaces import (
-    AugmentedLLMProtocol,
-    MessageParamT,
-    MessageT,
-    ModelT,
-)
-from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
-from mcp_agent.workflows.llm.sampling_format_converter import (
-    BasicFormatConverter,
-    ProviderFormatConverter,
-)
-
-# Forward reference for type annotations
-if TYPE_CHECKING:
-    from mcp_agent.agents.agent import Agent
-    from mcp_agent.context import Context
-    from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
-
-
 from mcp.types import (
     CallToolRequest,
     CallToolResult,
@@ -35,21 +17,45 @@ from mcp.types import (
     PromptMessage,
     TextContent,
 )
+from pydantic_core import from_json
 from rich.text import Text
 
 from mcp_agent.context_dependent import ContextDependent
-from mcp_agent.core.exceptions import
+from mcp_agent.core.exceptions import PromptExitError
+from mcp_agent.core.prompt import Prompt
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.event_progress import ProgressAction
+from mcp_agent.llm.memory import Memory, SimpleMemory
+from mcp_agent.llm.sampling_format_converter import (
+    BasicFormatConverter,
+    ProviderFormatConverter,
+)
+from mcp_agent.logging.logger import get_logger
+from mcp_agent.mcp.interfaces import (
+    AugmentedLLMProtocol,
+    ModelT,
+)
 from mcp_agent.mcp.mcp_aggregator import MCPAggregator
+from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
+from mcp_agent.mcp.prompt_render import render_multipart_message
+from mcp_agent.mcp.prompt_serialization import multipart_messages_to_delimited_format
 from mcp_agent.ui.console_display import ConsoleDisplay
-
+
+# Define type variables locally
+MessageParamT = TypeVar("MessageParamT")
+MessageT = TypeVar("MessageT")
+
+# Forward reference for type annotations
+if TYPE_CHECKING:
+    from mcp_agent.agents.agent import Agent
+    from mcp_agent.context import Context
+
 
 # TODO -- move this to a constant
 HUMAN_INPUT_TOOL_NAME = "__human_input__"
 
 
-class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
+class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT, MessageT]):
     """
     The basic building block of agentic systems is an LLM enhanced with augmentations
     such as retrieval, tools, and memory provided from a collection of MCP servers.
@@ -66,7 +72,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
         instruction: str | None = None,
         name: str | None = None,
         request_params: RequestParams | None = None,
-        type_converter: Type[ProviderFormatConverter[MessageParamT, MessageT]] = BasicFormatConverter,
+        type_converter: Type[
+            ProviderFormatConverter[MessageParamT, MessageT]
+        ] = BasicFormatConverter,
         context: Optional["Context"] = None,
         **kwargs: dict[str, Any],
     ) -> None:
@@ -81,10 +89,14 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
         self.logger = get_logger(__name__)
         self.executor = self.context.executor
         self.aggregator = agent if agent is not None else MCPAggregator(server_names or [])
-        self.name =
-        self.instruction =
+        self.name = agent.name if agent else name
+        self.instruction = agent.instruction if agent else instruction
+
+        # memory contains provider specific API types.
         self.history: Memory[MessageParamT] = SimpleMemory[MessageParamT]()
 
+        self.message_history: List[PromptMessageMultipart] = []
+
         # Initialize the display component
         self.display = ConsoleDisplay(config=self.context.config)
 
@@ -93,56 +105,82 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
 
         # Merge with provided params if any
         if self._init_request_params:
-            self.default_request_params = self._merge_request_params(self.default_request_params, self._init_request_params)
+            self.default_request_params = self._merge_request_params(
+                self.default_request_params, self._init_request_params
+            )
 
         self.type_converter = type_converter
         self.verb = kwargs.get("verb")
 
-
-
-
-
-
-
-
+    def _initialize_default_params(self, kwargs: dict) -> RequestParams:
+        """Initialize default parameters for the LLM.
+        Should be overridden by provider implementations to set provider-specific defaults."""
+        return RequestParams(
+            systemPrompt=self.instruction,
+            parallel_tool_calls=True,
+            max_iterations=10,
+            use_history=True,
+        )
 
-
-    async def generate_str(
+    async def structured(
         self,
-
+        prompt: List[PromptMessageMultipart],
+        model: Type[ModelT],
         request_params: RequestParams | None = None,
-    ) ->
-        """
+    ) -> ModelT | None:
+        """Apply the prompt and return the result as a Pydantic model, or None if coercion fails"""
+        try:
+            result: PromptMessageMultipart = await self.generate(prompt, request_params)
+            json_data = from_json(result.first_text(), allow_partial=True)
+            validated_model = model.model_validate(json_data)
+            return cast("ModelT", validated_model)
+        except Exception as e:
+            logger = get_logger(__name__)
+            logger.error(f"Failed to parse structured response: {str(e)}")
+            return None
 
-
-    async def generate_structured(
+    async def generate(
         self,
-
-        response_model: Type[ModelT],
+        multipart_messages: List[PromptMessageMultipart],
         request_params: RequestParams | None = None,
-    ) ->
-        """Request a structured LLM generation and return the result as a Pydantic model."""
-
-    async def select_model(self, request_params: RequestParams | None = None) -> str | None:
+    ) -> PromptMessageMultipart:
         """
-
+        Create a completion with the LLM using the provided messages.
         """
-        if
-
+        if multipart_messages[-1].first_text().startswith("***SAVE_HISTORY"):
+            parts: list[str] = multipart_messages[-1].first_text().split(" ", 1)
+            filename: str = (
+                parts[1].strip() if len(parts) > 1 else f"{self.name or 'assistant'}_prompts.txt"
+            )
+            await self._save_history(filename)
+            self.show_user_message(
+                f"History saved to {filename}", model=self.default_request_params.model, chat_turn=0
+            )
+            return Prompt.assistant(f"History saved to {filename}")
 
-
+        self.message_history.extend(multipart_messages)
 
-
-
-
-
-
-
-
-
+        if multipart_messages[-1].role == "user":
+            self.show_user_message(
+                render_multipart_message(multipart_messages[-1]),
+                model=self.default_request_params.model,
+                chat_turn=self.chat_turn(),
+            )
+
+        assistant_response: PromptMessageMultipart = await self._apply_prompt_provider_specific(
+            multipart_messages, request_params
         )
 
-
+        self.message_history.append(assistant_response)
+        return assistant_response
+
+    def chat_turn(self) -> int:
+        """Return the current chat turn number"""
+        return 1 + sum(1 for message in self.message_history if message.role == "assistant")
+
+    def _merge_request_params(
+        self, default_params: RequestParams, provided_params: RequestParams
+    ) -> RequestParams:
         """Merge default and provided request parameters"""
 
         merged = default_params.model_dump()
@@ -176,24 +214,13 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
         return default_request_params
 
     @classmethod
-    def convert_message_to_message_param(cls, message: MessageT, **kwargs: dict[str, Any]) -> MessageParamT:
+    def convert_message_to_message_param(
+        cls, message: MessageT, **kwargs: dict[str, Any]
+    ) -> MessageParamT:
         """Convert a response object to an input parameter object to allow LLM calls to be chained."""
         # Many LLM implementations will allow the same type for input and output messages
         return cast("MessageParamT", message)
 
-    async def get_last_message(self) -> MessageParamT | None:
-        """
-        Return the last message generated by the LLM or None if history is empty.
-        This is useful for prompt chaining workflows where the last message from one LLM is used as input to another.
-        """
-        history = self.history.get()
-        return history[-1] if history else None
-
-    async def get_last_message_str(self) -> str | None:
-        """Return the string representation of the last message generated by the LLM or None if history is empty."""
-        last_message = await self.get_last_message()
-        return self.message_param_str(last_message) if last_message else None
-
     def show_tool_result(self, result: CallToolResult) -> None:
         """Display a tool result in a formatted panel."""
         self.display.show_tool_result(result)
@@ -208,10 +235,12 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
 
     async def show_assistant_message(
         self,
-        message_text: str | Text,
+        message_text: str | Text | None,
         highlight_namespaced_tool: str = "",
        title: str = "ASSISTANT",
     ) -> None:
+        if message_text is None:
+            message_text = Text("No content to display", style="dim green italic")
         """Display an assistant message in a formatted panel."""
         await self.display.show_assistant_message(
             message_text,
@@ -225,11 +254,15 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
         """Display a user message in a formatted panel."""
         self.display.show_user_message(message, model, chat_turn, name=self.name)
 
-    async def pre_tool_call(self, tool_call_id: str | None, request: CallToolRequest) -> CallToolRequest | bool:
+    async def pre_tool_call(
+        self, tool_call_id: str | None, request: CallToolRequest
+    ) -> CallToolRequest | bool:
         """Called before a tool is executed. Return False to prevent execution."""
         return request
 
-    async def post_tool_call(self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult) -> CallToolResult:
+    async def post_tool_call(
+        self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult
+    ) -> CallToolResult:
         """Called after a tool execution. Can modify the result before it's returned."""
         return result
 
@@ -264,7 +297,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
             tool_args = request.params.arguments
             result = await self.aggregator.call_tool(tool_name, tool_args)
 
-            postprocess = await self.post_tool_call(tool_call_id=tool_call_id, request=request, result=result)
+            postprocess = await self.post_tool_call(
+                tool_call_id=tool_call_id, request=request, result=result
+            )
 
             if isinstance(postprocess, CallToolResult):
                 result = postprocess
@@ -283,66 +318,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
             ],
         )
 
-    def message_param_str(self, message: MessageParamT) -> str:
-        """
-        Convert an input message to a string representation.
-        Tries to extract just the content when possible.
-        """
-        if isinstance(message, dict):
-            # For dictionary format messages
-            if "content" in message:
-                content = message["content"]
-                # Handle both string and structured content formats
-                if isinstance(content, str):
-                    return content
-                elif isinstance(content, list) and content:
-                    # Try to extract text from content parts
-                    text_parts = []
-                    for part in content:
-                        if isinstance(part, dict) and "text" in part:
-                            text_parts.append(part["text"])
-                        elif hasattr(part, "text"):
-                            text_parts.append(part.text)  # type: ignore
-                    if text_parts:
-                        return "\n".join(text_parts)
-
-        # For objects with content attribute
-        if hasattr(message, "content"):
-            content = message.content  # type: ignore
-            if isinstance(content, str):
-                return content
-            elif hasattr(content, "text"):
-                return content.text
-
-        # Default fallback
-        return str(message)
-
-    def message_str(self, message: MessageT) -> str:
-        """
-        Convert an output message to a string representation.
-        Tries to extract just the content when possible.
-        """
-        # First try to use the same method for consistency
-        result = self.message_param_str(message)  # type: ignore
-        if result != str(message):
-            return result
-
-        # Additional handling for output-specific formats
-        if hasattr(message, "content"):
-            content = getattr(message, "content")
-            if isinstance(content, list):
-                # Extract text from content blocks
-                text_parts = []
-                for block in content:
-                    if hasattr(block, "text") and block.text:
-                        text_parts.append(block.text)
-                if text_parts:
-                    return "\n".join(text_parts)
-
-        # Default fallback
-        return str(message)
-
-    def _log_chat_progress(self, chat_turn: Optional[int] = None, model: Optional[str] = None) -> None:
+    def _log_chat_progress(
+        self, chat_turn: Optional[int] = None, model: Optional[str] = None
+    ) -> None:
         """Log a chat progress event"""
         # Determine action type based on verb
         if hasattr(self, "verb") and self.verb:
@@ -435,33 +413,28 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
         multipart_messages = PromptMessageMultipart.parse_get_prompt_result(prompt_result)
 
         # Delegate to the provider-specific implementation
-
+        result = await self._apply_prompt_provider_specific(multipart_messages, None)
+        return result.first_text()
 
-    async def
-        self,
-        multipart_messages: List["PromptMessageMultipart"],
-        request_params: RequestParams | None = None,
-    ) -> str:
+    async def _save_history(self, filename: str) -> None:
         """
-
-        This is a cleaner interface to _apply_prompt_template_provider_specific.
-
-        Args:
-            multipart_messages: List of PromptMessageMultipart objects
-            request_params: Optional parameters to configure the LLM request
-
-        Returns:
-            String representation of the assistant's response
+        Save the Message History to a file in a simple delimeted format.
         """
-        #
-
+        # Convert to delimited format
+        delimited_content = multipart_messages_to_delimited_format(
+            self.message_history,
+        )
+
+        # Write to file
+        with open(filename, "w", encoding="utf-8") as f:
+            f.write("\n\n".join(delimited_content))
 
-
-    async def
+    @abstractmethod
+    async def _apply_prompt_provider_specific(
         self,
         multipart_messages: List["PromptMessageMultipart"],
         request_params: RequestParams | None = None,
-    ) ->
+    ) -> PromptMessageMultipart:
         """
         Provider-specific implementation of apply_prompt_template.
         This default implementation handles basic text content for any LLM type.
@@ -475,101 +448,3 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
         String representation of the assistant's response if generated,
         or the last assistant message in the prompt
         """
-        # Check the last message role
-        last_message = multipart_messages[-1]
-
-        if last_message.role == "user":
-            # For user messages: Add all previous messages to history, then generate response to the last one
-            self.logger.debug("Last message in prompt is from user, generating assistant response")
-
-            # Add all but the last message to history
-            if len(multipart_messages) > 1:
-                previous_messages = multipart_messages[:-1]
-                converted = []
-
-                # Fallback generic method for all LLM types
-                for msg in previous_messages:
-                    # Convert each PromptMessageMultipart to individual PromptMessages
-                    prompt_messages = msg.from_multipart()
-                    for prompt_msg in prompt_messages:
-                        converted.append(self.type_converter.from_prompt_message(prompt_msg))
-
-                self.history.extend(converted, is_prompt=True)
-
-            # For generic LLMs, extract text and describe non-text content
-            user_text_parts = []
-            for content in last_message.content:
-                if content.type == "text":
-                    user_text_parts.append(content.text)
-                elif content.type == "resource" and getattr(content, "resource", None) is not None:
-                    if hasattr(content.resource, "text"):
-                        user_text_parts.append(content.resource.text)  # type: ignore
-                elif content.type == "image":
-                    # Add a placeholder for images
-                    mime_type = getattr(content, "mimeType", "image/unknown")
-                    user_text_parts.append(f"[Image: {mime_type}]")
-
-            user_text = "\n".join(user_text_parts) if user_text_parts else ""
-            if not user_text:
-                # Fallback to original method if we couldn't extract text
-                user_text = str(last_message.content)
-
-            return await self.generate_str(user_text)
-        else:
-            # For assistant messages: Add all messages to history and return the last one
-            self.logger.debug("Last message in prompt is from assistant, returning it directly")
-
-            # Convert and add all messages to history
-            converted = []
-
-            # Fallback to the original method for all LLM types
-            for msg in multipart_messages:
-                # Convert each PromptMessageMultipart to individual PromptMessages
-                prompt_messages = msg.from_multipart()
-                for prompt_msg in prompt_messages:
-                    converted.append(self.type_converter.from_prompt_message(prompt_msg))
-
-            self.history.extend(converted, is_prompt=True)
-
-            # Return the assistant's message with proper handling of different content types
-            assistant_text_parts = []
-            has_non_text_content = False
-
-            for content in last_message.content:
-                if content.type == "text":
-                    assistant_text_parts.append(content.text)
-                elif content.type == "resource" and hasattr(content.resource, "text"):
-                    # Add resource text with metadata
-                    mime_type = getattr(content.resource, "mimeType", "text/plain")
-                    uri = getattr(content.resource, "uri", "")
-                    if uri:
-                        assistant_text_parts.append(
-                            f"[Resource: {uri}, Type: {mime_type}]\n{content.resource.text}"  # ignore # type: ignore
-                        )
-                    else:
-                        assistant_text_parts.append(
-                            f"[Resource Type: {mime_type}]\n{content.resource.text}"  # type ignore # type: ignore
-                        )
-                elif content.type == "image":
-                    # Note the presence of images
-                    mime_type = getattr(content, "mimeType", "image/unknown")
-                    assistant_text_parts.append(f"[Image: {mime_type}]")
-                    has_non_text_content = True
-                else:
-                    # Other content types
-                    assistant_text_parts.append(f"[Content of type: {content.type}]")
-                    has_non_text_content = True
-
-            # Join all parts with double newlines for better readability
-            result = "\n\n".join(assistant_text_parts) if assistant_text_parts else str(last_message.content)
-
-            # Add a note if non-text content was present
-            if has_non_text_content:
-                result += "\n\n[Note: This message contained non-text content that may not be fully represented in text format]"
-
-            return result
-
-
-#####################################
-### NEW INTERFACE METHODS BELOW ###
-#####################################
```