fast-agent-mcp 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.1.13.dist-info}/METADATA +1 -1
- fast_agent_mcp-0.1.13.dist-info/RECORD +164 -0
- mcp_agent/agents/agent.py +37 -79
- mcp_agent/app.py +16 -22
- mcp_agent/cli/commands/bootstrap.py +22 -52
- mcp_agent/cli/commands/config.py +4 -4
- mcp_agent/cli/commands/setup.py +11 -26
- mcp_agent/cli/main.py +6 -9
- mcp_agent/cli/terminal.py +2 -2
- mcp_agent/config.py +1 -5
- mcp_agent/context.py +13 -24
- mcp_agent/context_dependent.py +3 -7
- mcp_agent/core/agent_app.py +45 -121
- mcp_agent/core/agent_utils.py +3 -5
- mcp_agent/core/decorators.py +5 -12
- mcp_agent/core/enhanced_prompt.py +25 -52
- mcp_agent/core/exceptions.py +8 -8
- mcp_agent/core/factory.py +29 -70
- mcp_agent/core/fastagent.py +48 -88
- mcp_agent/core/mcp_content.py +8 -16
- mcp_agent/core/prompt.py +8 -15
- mcp_agent/core/proxies.py +34 -25
- mcp_agent/core/request_params.py +6 -3
- mcp_agent/core/types.py +4 -6
- mcp_agent/core/validation.py +4 -3
- mcp_agent/executor/decorator_registry.py +11 -23
- mcp_agent/executor/executor.py +8 -17
- mcp_agent/executor/task_registry.py +2 -4
- mcp_agent/executor/temporal.py +28 -74
- mcp_agent/executor/workflow.py +3 -5
- mcp_agent/executor/workflow_signal.py +17 -29
- mcp_agent/human_input/handler.py +4 -9
- mcp_agent/human_input/types.py +2 -3
- mcp_agent/logging/events.py +1 -5
- mcp_agent/logging/json_serializer.py +7 -6
- mcp_agent/logging/listeners.py +20 -23
- mcp_agent/logging/logger.py +15 -17
- mcp_agent/logging/rich_progress.py +10 -8
- mcp_agent/logging/tracing.py +4 -6
- mcp_agent/logging/transport.py +22 -22
- mcp_agent/mcp/gen_client.py +4 -12
- mcp_agent/mcp/interfaces.py +71 -86
- mcp_agent/mcp/mcp_agent_client_session.py +11 -19
- mcp_agent/mcp/mcp_agent_server.py +8 -10
- mcp_agent/mcp/mcp_aggregator.py +45 -117
- mcp_agent/mcp/mcp_connection_manager.py +16 -37
- mcp_agent/mcp/prompt_message_multipart.py +12 -18
- mcp_agent/mcp/prompt_serialization.py +13 -38
- mcp_agent/mcp/prompts/prompt_load.py +99 -0
- mcp_agent/mcp/prompts/prompt_server.py +21 -128
- mcp_agent/mcp/prompts/prompt_template.py +20 -42
- mcp_agent/mcp/resource_utils.py +8 -17
- mcp_agent/mcp/sampling.py +5 -14
- mcp_agent/mcp/stdio.py +11 -8
- mcp_agent/mcp_server/agent_server.py +10 -17
- mcp_agent/mcp_server_registry.py +13 -35
- mcp_agent/resources/examples/data-analysis/analysis-campaign.py +1 -1
- mcp_agent/resources/examples/data-analysis/analysis.py +1 -1
- mcp_agent/resources/examples/data-analysis/slides.py +110 -0
- mcp_agent/resources/examples/internal/agent.py +2 -1
- mcp_agent/resources/examples/internal/job.py +2 -1
- mcp_agent/resources/examples/internal/prompt_category.py +1 -1
- mcp_agent/resources/examples/internal/prompt_sizing.py +3 -5
- mcp_agent/resources/examples/internal/sizer.py +2 -1
- mcp_agent/resources/examples/internal/social.py +2 -1
- mcp_agent/resources/examples/mcp_researcher/researcher-eval.py +1 -1
- mcp_agent/resources/examples/prompting/agent.py +2 -1
- mcp_agent/resources/examples/prompting/image_server.py +5 -11
- mcp_agent/resources/examples/researcher/researcher-eval.py +1 -1
- mcp_agent/resources/examples/researcher/researcher-imp.py +3 -4
- mcp_agent/resources/examples/researcher/researcher.py +2 -1
- mcp_agent/resources/examples/workflows/agent_build.py +2 -1
- mcp_agent/resources/examples/workflows/chaining.py +2 -1
- mcp_agent/resources/examples/workflows/evaluator.py +2 -1
- mcp_agent/resources/examples/workflows/human_input.py +2 -1
- mcp_agent/resources/examples/workflows/orchestrator.py +2 -1
- mcp_agent/resources/examples/workflows/parallel.py +2 -1
- mcp_agent/resources/examples/workflows/router.py +2 -1
- mcp_agent/resources/examples/workflows/sse.py +1 -1
- mcp_agent/telemetry/usage_tracking.py +2 -1
- mcp_agent/ui/console_display.py +15 -39
- mcp_agent/workflows/embedding/embedding_base.py +1 -4
- mcp_agent/workflows/embedding/embedding_cohere.py +2 -2
- mcp_agent/workflows/embedding/embedding_openai.py +4 -13
- mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +23 -57
- mcp_agent/workflows/intent_classifier/intent_classifier_base.py +5 -8
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding.py +7 -11
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding_cohere.py +4 -8
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding_openai.py +4 -8
- mcp_agent/workflows/intent_classifier/intent_classifier_llm.py +11 -22
- mcp_agent/workflows/intent_classifier/intent_classifier_llm_anthropic.py +3 -3
- mcp_agent/workflows/intent_classifier/intent_classifier_llm_openai.py +4 -6
- mcp_agent/workflows/llm/anthropic_utils.py +8 -29
- mcp_agent/workflows/llm/augmented_llm.py +69 -247
- mcp_agent/workflows/llm/augmented_llm_anthropic.py +39 -73
- mcp_agent/workflows/llm/augmented_llm_openai.py +42 -97
- mcp_agent/workflows/llm/augmented_llm_passthrough.py +13 -20
- mcp_agent/workflows/llm/augmented_llm_playback.py +8 -6
- mcp_agent/workflows/llm/memory.py +103 -0
- mcp_agent/workflows/llm/model_factory.py +8 -20
- mcp_agent/workflows/llm/openai_utils.py +1 -1
- mcp_agent/workflows/llm/prompt_utils.py +1 -3
- mcp_agent/workflows/llm/providers/multipart_converter_anthropic.py +47 -89
- mcp_agent/workflows/llm/providers/multipart_converter_openai.py +20 -55
- mcp_agent/workflows/llm/providers/openai_multipart.py +19 -61
- mcp_agent/workflows/llm/providers/sampling_converter_anthropic.py +10 -12
- mcp_agent/workflows/llm/providers/sampling_converter_openai.py +7 -11
- mcp_agent/workflows/llm/sampling_converter.py +4 -11
- mcp_agent/workflows/llm/sampling_format_converter.py +12 -12
- mcp_agent/workflows/orchestrator/orchestrator.py +24 -67
- mcp_agent/workflows/orchestrator/orchestrator_models.py +14 -40
- mcp_agent/workflows/parallel/fan_in.py +17 -47
- mcp_agent/workflows/parallel/fan_out.py +6 -12
- mcp_agent/workflows/parallel/parallel_llm.py +9 -26
- mcp_agent/workflows/router/router_base.py +19 -49
- mcp_agent/workflows/router/router_embedding.py +11 -25
- mcp_agent/workflows/router/router_embedding_cohere.py +2 -2
- mcp_agent/workflows/router/router_embedding_openai.py +2 -2
- mcp_agent/workflows/router/router_llm.py +12 -28
- mcp_agent/workflows/swarm/swarm.py +20 -48
- mcp_agent/workflows/swarm/swarm_anthropic.py +2 -2
- mcp_agent/workflows/swarm/swarm_openai.py +2 -2
- fast_agent_mcp-0.1.12.dist-info/RECORD +0 -161
- {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.1.13.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.1.13.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.1.13.dist-info}/licenses/LICENSE +0 -0
@@ -1,195 +1,54 @@
|
|
1
1
|
from abc import abstractmethod
|
2
|
-
|
3
2
|
from typing import (
|
4
|
-
|
3
|
+
TYPE_CHECKING,
|
4
|
+
Any,
|
5
5
|
List,
|
6
6
|
Optional,
|
7
|
-
Protocol,
|
8
7
|
Type,
|
9
|
-
|
10
|
-
TYPE_CHECKING,
|
8
|
+
cast,
|
11
9
|
)
|
12
10
|
|
13
|
-
from mcp_agent.
|
14
|
-
from mcp_agent.
|
15
|
-
|
11
|
+
from mcp_agent.logging.logger import get_logger
|
12
|
+
from mcp_agent.mcp.interfaces import (
|
13
|
+
AugmentedLLMProtocol,
|
16
14
|
MessageParamT,
|
17
15
|
MessageT,
|
16
|
+
ModelT,
|
17
|
+
)
|
18
|
+
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
|
19
|
+
from mcp_agent.workflows.llm.sampling_format_converter import (
|
20
|
+
BasicFormatConverter,
|
21
|
+
ProviderFormatConverter,
|
18
22
|
)
|
19
23
|
|
20
24
|
# Forward reference for type annotations
|
21
25
|
if TYPE_CHECKING:
|
22
|
-
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
|
23
26
|
from mcp_agent.agents.agent import Agent
|
24
27
|
from mcp_agent.context import Context
|
25
|
-
|
28
|
+
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
|
26
29
|
|
27
30
|
|
28
31
|
from mcp.types import (
|
29
32
|
CallToolRequest,
|
30
33
|
CallToolResult,
|
34
|
+
GetPromptResult,
|
31
35
|
PromptMessage,
|
32
36
|
TextContent,
|
33
|
-
GetPromptResult,
|
34
37
|
)
|
38
|
+
from rich.text import Text
|
35
39
|
|
36
40
|
from mcp_agent.context_dependent import ContextDependent
|
37
41
|
from mcp_agent.core.exceptions import ModelConfigError, PromptExitError
|
38
42
|
from mcp_agent.core.request_params import RequestParams
|
39
43
|
from mcp_agent.event_progress import ProgressAction
|
40
|
-
|
41
|
-
try:
|
42
|
-
from mcp_agent.mcp.mcp_aggregator import MCPAggregator
|
43
|
-
except ImportError:
|
44
|
-
# For testing purposes
|
45
|
-
class MCPAggregator:
|
46
|
-
pass
|
47
|
-
|
48
|
-
|
44
|
+
from mcp_agent.mcp.mcp_aggregator import MCPAggregator
|
49
45
|
from mcp_agent.ui.console_display import ConsoleDisplay
|
50
|
-
from
|
51
|
-
|
52
|
-
|
53
|
-
ModelT = TypeVar("ModelT")
|
54
|
-
"""A type representing a structured output message from an LLM."""
|
55
|
-
|
46
|
+
from mcp_agent.workflows.llm.memory import Memory, SimpleMemory
|
56
47
|
|
57
48
|
# TODO -- move this to a constant
|
58
49
|
HUMAN_INPUT_TOOL_NAME = "__human_input__"
|
59
50
|
|
60
51
|
|
61
|
-
class Memory(Protocol, Generic[MessageParamT]):
|
62
|
-
"""
|
63
|
-
Simple memory management for storing past interactions in-memory.
|
64
|
-
"""
|
65
|
-
|
66
|
-
# TODO: saqadri - add checkpointing and other advanced memory capabilities
|
67
|
-
|
68
|
-
def __init__(self): ...
|
69
|
-
|
70
|
-
def extend(
|
71
|
-
self, messages: List[MessageParamT], is_prompt: bool = False
|
72
|
-
) -> None: ...
|
73
|
-
|
74
|
-
def set(self, messages: List[MessageParamT], is_prompt: bool = False) -> None: ...
|
75
|
-
|
76
|
-
def append(self, message: MessageParamT, is_prompt: bool = False) -> None: ...
|
77
|
-
|
78
|
-
def get(self, include_history: bool = True) -> List[MessageParamT]: ...
|
79
|
-
|
80
|
-
def clear(self, clear_prompts: bool = False) -> None: ...
|
81
|
-
|
82
|
-
|
83
|
-
class SimpleMemory(Memory, Generic[MessageParamT]):
|
84
|
-
"""
|
85
|
-
Simple memory management for storing past interactions in-memory.
|
86
|
-
|
87
|
-
Maintains both prompt messages (which are always included) and
|
88
|
-
generated conversation history (which is included based on use_history setting).
|
89
|
-
"""
|
90
|
-
|
91
|
-
def __init__(self):
|
92
|
-
self.history: List[MessageParamT] = []
|
93
|
-
self.prompt_messages: List[MessageParamT] = [] # Always included
|
94
|
-
|
95
|
-
def extend(self, messages: List[MessageParamT], is_prompt: bool = False):
|
96
|
-
"""
|
97
|
-
Add multiple messages to history.
|
98
|
-
|
99
|
-
Args:
|
100
|
-
messages: Messages to add
|
101
|
-
is_prompt: If True, add to prompt_messages instead of regular history
|
102
|
-
"""
|
103
|
-
if is_prompt:
|
104
|
-
self.prompt_messages.extend(messages)
|
105
|
-
else:
|
106
|
-
self.history.extend(messages)
|
107
|
-
|
108
|
-
def set(self, messages: List[MessageParamT], is_prompt: bool = False):
|
109
|
-
"""
|
110
|
-
Replace messages in history.
|
111
|
-
|
112
|
-
Args:
|
113
|
-
messages: Messages to set
|
114
|
-
is_prompt: If True, replace prompt_messages instead of regular history
|
115
|
-
"""
|
116
|
-
if is_prompt:
|
117
|
-
self.prompt_messages = messages.copy()
|
118
|
-
else:
|
119
|
-
self.history = messages.copy()
|
120
|
-
|
121
|
-
def append(self, message: MessageParamT, is_prompt: bool = False):
|
122
|
-
"""
|
123
|
-
Add a single message to history.
|
124
|
-
|
125
|
-
Args:
|
126
|
-
message: Message to add
|
127
|
-
is_prompt: If True, add to prompt_messages instead of regular history
|
128
|
-
"""
|
129
|
-
if is_prompt:
|
130
|
-
self.prompt_messages.append(message)
|
131
|
-
else:
|
132
|
-
self.history.append(message)
|
133
|
-
|
134
|
-
def get(self, include_history: bool = True) -> List[MessageParamT]:
|
135
|
-
"""
|
136
|
-
Get all messages in memory.
|
137
|
-
|
138
|
-
Args:
|
139
|
-
include_history: If True, include regular history messages
|
140
|
-
If False, only return prompt messages
|
141
|
-
|
142
|
-
Returns:
|
143
|
-
Combined list of prompt messages and optionally history messages
|
144
|
-
"""
|
145
|
-
if include_history:
|
146
|
-
return self.prompt_messages + self.history
|
147
|
-
else:
|
148
|
-
return self.prompt_messages.copy()
|
149
|
-
|
150
|
-
def clear(self, clear_prompts: bool = False):
|
151
|
-
"""
|
152
|
-
Clear history and optionally prompt messages.
|
153
|
-
|
154
|
-
Args:
|
155
|
-
clear_prompts: If True, also clear prompt messages
|
156
|
-
"""
|
157
|
-
self.history = []
|
158
|
-
if clear_prompts:
|
159
|
-
self.prompt_messages = []
|
160
|
-
|
161
|
-
|
162
|
-
class AugmentedLLMProtocol(Protocol, Generic[MessageParamT, MessageT]):
|
163
|
-
"""Protocol defining the interface for augmented LLMs"""
|
164
|
-
|
165
|
-
async def generate(
|
166
|
-
self,
|
167
|
-
message: str | MessageParamT | List[MessageParamT],
|
168
|
-
request_params: RequestParams | None = None,
|
169
|
-
) -> List[MessageT]:
|
170
|
-
"""Request an LLM generation, which may run multiple iterations, and return the result"""
|
171
|
-
|
172
|
-
async def generate_str(
|
173
|
-
self,
|
174
|
-
message: str | MessageParamT | List[MessageParamT],
|
175
|
-
request_params: RequestParams | None = None,
|
176
|
-
) -> str:
|
177
|
-
"""Request an LLM generation and return the string representation of the result"""
|
178
|
-
|
179
|
-
async def generate_structured(
|
180
|
-
self,
|
181
|
-
message: str | MessageParamT | List[MessageParamT],
|
182
|
-
response_model: Type[ModelT],
|
183
|
-
request_params: RequestParams | None = None,
|
184
|
-
) -> ModelT:
|
185
|
-
"""Request a structured LLM generation and return the result as a Pydantic model."""
|
186
|
-
|
187
|
-
async def generate_prompt(
|
188
|
-
self, prompt: PromptMessageMultipart, request_params: RequestParams | None
|
189
|
-
) -> str:
|
190
|
-
"""Request an LLM generation and return a string representation of the result"""
|
191
|
-
|
192
|
-
|
193
52
|
class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
|
194
53
|
"""
|
195
54
|
The basic building block of agentic systems is an LLM enhanced with augmentations
|
@@ -198,9 +57,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
198
57
|
selecting appropriate tools, and determining what information to retain.
|
199
58
|
"""
|
200
59
|
|
201
|
-
# TODO: saqadri - add streaming support (e.g. generate_stream)
|
202
|
-
# TODO: saqadri - consider adding middleware patterns for pre/post processing of messages, for now we have pre/post_tool_call
|
203
|
-
|
204
60
|
provider: str | None = None
|
205
61
|
|
206
62
|
def __init__(
|
@@ -210,10 +66,10 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
210
66
|
instruction: str | None = None,
|
211
67
|
name: str | None = None,
|
212
68
|
request_params: RequestParams | None = None,
|
213
|
-
type_converter: Type[
|
69
|
+
type_converter: Type[ProviderFormatConverter[MessageParamT, MessageT]] = BasicFormatConverter,
|
214
70
|
context: Optional["Context"] = None,
|
215
|
-
**kwargs,
|
216
|
-
):
|
71
|
+
**kwargs: dict[str, Any],
|
72
|
+
) -> None:
|
217
73
|
"""
|
218
74
|
Initialize the LLM with a list of server names and an instruction.
|
219
75
|
If a name is provided, it will be used to identify the LLM.
|
@@ -222,15 +78,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
222
78
|
# Extract request_params before super() call
|
223
79
|
self._init_request_params = request_params
|
224
80
|
super().__init__(context=context, **kwargs)
|
225
|
-
|
81
|
+
self.logger = get_logger(__name__)
|
226
82
|
self.executor = self.context.executor
|
227
|
-
self.aggregator = (
|
228
|
-
agent if agent is not None else MCPAggregator(server_names or [])
|
229
|
-
)
|
83
|
+
self.aggregator = agent if agent is not None else MCPAggregator(server_names or [])
|
230
84
|
self.name = name or (agent.name if agent else None)
|
231
|
-
self.instruction = instruction or (
|
232
|
-
agent.instruction if agent and isinstance(agent.instruction, str) else None
|
233
|
-
)
|
85
|
+
self.instruction = instruction or (agent.instruction if agent and isinstance(agent.instruction, str) else None)
|
234
86
|
self.history: Memory[MessageParamT] = SimpleMemory[MessageParamT]()
|
235
87
|
|
236
88
|
# Initialize the display component
|
@@ -241,9 +93,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
241
93
|
|
242
94
|
# Merge with provided params if any
|
243
95
|
if self._init_request_params:
|
244
|
-
self.default_request_params = self._merge_request_params(
|
245
|
-
self.default_request_params, self._init_request_params
|
246
|
-
)
|
96
|
+
self.default_request_params = self._merge_request_params(self.default_request_params, self._init_request_params)
|
247
97
|
|
248
98
|
self.type_converter = type_converter
|
249
99
|
self.verb = kwargs.get("verb")
|
@@ -273,13 +123,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
273
123
|
) -> ModelT:
|
274
124
|
"""Request a structured LLM generation and return the result as a Pydantic model."""
|
275
125
|
|
276
|
-
async def select_model(
|
277
|
-
self, request_params: RequestParams | None = None
|
278
|
-
) -> str | None:
|
126
|
+
async def select_model(self, request_params: RequestParams | None = None) -> str | None:
|
279
127
|
"""
|
280
128
|
Return the configured model (legacy support)
|
281
129
|
"""
|
282
|
-
if request_params.model:
|
130
|
+
if request_params and request_params.model:
|
283
131
|
return request_params.model
|
284
132
|
|
285
133
|
raise ModelConfigError("Internal Error: Model is not configured correctly")
|
@@ -294,9 +142,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
294
142
|
use_history=True,
|
295
143
|
)
|
296
144
|
|
297
|
-
def _merge_request_params(
|
298
|
-
self, default_params: RequestParams, provided_params: RequestParams
|
299
|
-
) -> RequestParams:
|
145
|
+
def _merge_request_params(self, default_params: RequestParams, provided_params: RequestParams) -> RequestParams:
|
300
146
|
"""Merge default and provided request parameters"""
|
301
147
|
|
302
148
|
merged = default_params.model_dump()
|
@@ -330,12 +176,10 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
330
176
|
return default_request_params
|
331
177
|
|
332
178
|
@classmethod
|
333
|
-
def convert_message_to_message_param(
|
334
|
-
cls, message: MessageT, **kwargs
|
335
|
-
) -> MessageParamT:
|
179
|
+
def convert_message_to_message_param(cls, message: MessageT, **kwargs: dict[str, Any]) -> MessageParamT:
|
336
180
|
"""Convert a response object to an input parameter object to allow LLM calls to be chained."""
|
337
181
|
# Many LLM implementations will allow the same type for input and output messages
|
338
|
-
return message
|
182
|
+
return cast("MessageParamT", message)
|
339
183
|
|
340
184
|
async def get_last_message(self) -> MessageParamT | None:
|
341
185
|
"""
|
@@ -350,15 +194,15 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
350
194
|
last_message = await self.get_last_message()
|
351
195
|
return self.message_param_str(last_message) if last_message else None
|
352
196
|
|
353
|
-
def show_tool_result(self, result: CallToolResult):
|
197
|
+
def show_tool_result(self, result: CallToolResult) -> None:
|
354
198
|
"""Display a tool result in a formatted panel."""
|
355
199
|
self.display.show_tool_result(result)
|
356
200
|
|
357
|
-
def show_oai_tool_result(self, result):
|
201
|
+
def show_oai_tool_result(self, result: str) -> None:
|
358
202
|
"""Display a tool result in a formatted panel."""
|
359
203
|
self.display.show_oai_tool_result(result)
|
360
204
|
|
361
|
-
def show_tool_call(self, available_tools, tool_name, tool_args):
|
205
|
+
def show_tool_call(self, available_tools, tool_name, tool_args) -> None:
|
362
206
|
"""Display a tool call in a formatted panel."""
|
363
207
|
self.display.show_tool_call(available_tools, tool_name, tool_args)
|
364
208
|
|
@@ -367,7 +211,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
367
211
|
message_text: str | Text,
|
368
212
|
highlight_namespaced_tool: str = "",
|
369
213
|
title: str = "ASSISTANT",
|
370
|
-
):
|
214
|
+
) -> None:
|
371
215
|
"""Display an assistant message in a formatted panel."""
|
372
216
|
await self.display.show_assistant_message(
|
373
217
|
message_text,
|
@@ -377,19 +221,15 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
377
221
|
name=self.name,
|
378
222
|
)
|
379
223
|
|
380
|
-
def show_user_message(self, message, model: str | None, chat_turn: int):
|
224
|
+
def show_user_message(self, message, model: str | None, chat_turn: int) -> None:
|
381
225
|
"""Display a user message in a formatted panel."""
|
382
226
|
self.display.show_user_message(message, model, chat_turn, name=self.name)
|
383
227
|
|
384
|
-
async def pre_tool_call(
|
385
|
-
self, tool_call_id: str | None, request: CallToolRequest
|
386
|
-
) -> CallToolRequest | bool:
|
228
|
+
async def pre_tool_call(self, tool_call_id: str | None, request: CallToolRequest) -> CallToolRequest | bool:
|
387
229
|
"""Called before a tool is executed. Return False to prevent execution."""
|
388
230
|
return request
|
389
231
|
|
390
|
-
async def post_tool_call(
|
391
|
-
self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult
|
392
|
-
) -> CallToolResult:
|
232
|
+
async def post_tool_call(self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult) -> CallToolResult:
|
393
233
|
"""Called after a tool execution. Can modify the result before it's returned."""
|
394
234
|
return result
|
395
235
|
|
@@ -412,7 +252,8 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
412
252
|
isError=True,
|
413
253
|
content=[
|
414
254
|
TextContent(
|
415
|
-
|
255
|
+
type="text",
|
256
|
+
text=f"Error: Tool '{request.params.name}' was not allowed to run.",
|
416
257
|
)
|
417
258
|
],
|
418
259
|
)
|
@@ -423,9 +264,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
423
264
|
tool_args = request.params.arguments
|
424
265
|
result = await self.aggregator.call_tool(tool_name, tool_args)
|
425
266
|
|
426
|
-
postprocess = await self.post_tool_call(
|
427
|
-
tool_call_id=tool_call_id, request=request, result=result
|
428
|
-
)
|
267
|
+
postprocess = await self.post_tool_call(tool_call_id=tool_call_id, request=request, result=result)
|
429
268
|
|
430
269
|
if isinstance(postprocess, CallToolResult):
|
431
270
|
result = postprocess
|
@@ -463,13 +302,13 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
463
302
|
if isinstance(part, dict) and "text" in part:
|
464
303
|
text_parts.append(part["text"])
|
465
304
|
elif hasattr(part, "text"):
|
466
|
-
text_parts.append(part.text)
|
305
|
+
text_parts.append(part.text) # type: ignore
|
467
306
|
if text_parts:
|
468
307
|
return "\n".join(text_parts)
|
469
308
|
|
470
309
|
# For objects with content attribute
|
471
310
|
if hasattr(message, "content"):
|
472
|
-
content = message.content
|
311
|
+
content = message.content # type: ignore
|
473
312
|
if isinstance(content, str):
|
474
313
|
return content
|
475
314
|
elif hasattr(content, "text"):
|
@@ -484,13 +323,13 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
484
323
|
Tries to extract just the content when possible.
|
485
324
|
"""
|
486
325
|
# First try to use the same method for consistency
|
487
|
-
result = self.message_param_str(message)
|
326
|
+
result = self.message_param_str(message) # type: ignore
|
488
327
|
if result != str(message):
|
489
328
|
return result
|
490
329
|
|
491
330
|
# Additional handling for output-specific formats
|
492
331
|
if hasattr(message, "content"):
|
493
|
-
content = message
|
332
|
+
content = getattr(message, "content")
|
494
333
|
if isinstance(content, list):
|
495
334
|
# Extract text from content blocks
|
496
335
|
text_parts = []
|
@@ -503,9 +342,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
503
342
|
# Default fallback
|
504
343
|
return str(message)
|
505
344
|
|
506
|
-
def _log_chat_progress(
|
507
|
-
self, chat_turn: Optional[int] = None, model: Optional[str] = None
|
508
|
-
):
|
345
|
+
def _log_chat_progress(self, chat_turn: Optional[int] = None, model: Optional[str] = None) -> None:
|
509
346
|
"""Log a chat progress event"""
|
510
347
|
# Determine action type based on verb
|
511
348
|
if hasattr(self, "verb") and self.verb:
|
@@ -522,7 +359,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
522
359
|
}
|
523
360
|
self.logger.debug("Chat in progress", data=data)
|
524
361
|
|
525
|
-
def _log_chat_finished(self, model: Optional[str] = None):
|
362
|
+
def _log_chat_finished(self, model: Optional[str] = None) -> None:
|
526
363
|
"""Log a chat finished event"""
|
527
364
|
data = {
|
528
365
|
"progress_action": ProgressAction.READY,
|
@@ -531,9 +368,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
531
368
|
}
|
532
369
|
self.logger.debug("Chat finished", data=data)
|
533
370
|
|
534
|
-
def _convert_prompt_messages(
|
535
|
-
self, prompt_messages: List[PromptMessage]
|
536
|
-
) -> List[MessageParamT]:
|
371
|
+
def _convert_prompt_messages(self, prompt_messages: List[PromptMessage]) -> List[MessageParamT]:
|
537
372
|
"""
|
538
373
|
Convert prompt messages to this LLM's specific message format.
|
539
374
|
To be implemented by concrete LLM classes.
|
@@ -546,7 +381,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
546
381
|
description: Optional[str] = None,
|
547
382
|
message_count: int = 0,
|
548
383
|
arguments: Optional[dict[str, str]] = None,
|
549
|
-
):
|
384
|
+
) -> None:
|
550
385
|
"""
|
551
386
|
Display information about a loaded prompt template.
|
552
387
|
|
@@ -565,9 +400,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
565
400
|
arguments=arguments,
|
566
401
|
)
|
567
402
|
|
568
|
-
async def apply_prompt_template(
|
569
|
-
self, prompt_result: GetPromptResult, prompt_name: str
|
570
|
-
) -> str:
|
403
|
+
async def apply_prompt_template(self, prompt_result: GetPromptResult, prompt_name: str) -> str:
|
571
404
|
"""
|
572
405
|
Apply a prompt template by adding it to the conversation history.
|
573
406
|
If the last message in the prompt is from a user, automatically
|
@@ -599,14 +432,10 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
599
432
|
)
|
600
433
|
|
601
434
|
# Convert to PromptMessageMultipart objects
|
602
|
-
multipart_messages = PromptMessageMultipart.parse_get_prompt_result(
|
603
|
-
prompt_result
|
604
|
-
)
|
435
|
+
multipart_messages = PromptMessageMultipart.parse_get_prompt_result(prompt_result)
|
605
436
|
|
606
437
|
# Delegate to the provider-specific implementation
|
607
|
-
return await self._apply_prompt_template_provider_specific(
|
608
|
-
multipart_messages, None
|
609
|
-
)
|
438
|
+
return await self._apply_prompt_template_provider_specific(multipart_messages, None)
|
610
439
|
|
611
440
|
async def apply_prompt(
|
612
441
|
self,
|
@@ -625,10 +454,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
625
454
|
String representation of the assistant's response
|
626
455
|
"""
|
627
456
|
# Delegate to the provider-specific implementation
|
628
|
-
return await self._apply_prompt_template_provider_specific(
|
629
|
-
multipart_messages, request_params
|
630
|
-
)
|
457
|
+
return await self._apply_prompt_template_provider_specific(multipart_messages, request_params)
|
631
458
|
|
459
|
+
# this shouldn't need to be very big...
|
632
460
|
async def _apply_prompt_template_provider_specific(
|
633
461
|
self,
|
634
462
|
multipart_messages: List["PromptMessageMultipart"],
|
@@ -652,9 +480,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
652
480
|
|
653
481
|
if last_message.role == "user":
|
654
482
|
# For user messages: Add all previous messages to history, then generate response to the last one
|
655
|
-
self.logger.debug(
|
656
|
-
"Last message in prompt is from user, generating assistant response"
|
657
|
-
)
|
483
|
+
self.logger.debug("Last message in prompt is from user, generating assistant response")
|
658
484
|
|
659
485
|
# Add all but the last message to history
|
660
486
|
if len(multipart_messages) > 1:
|
@@ -664,11 +490,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
664
490
|
# Fallback generic method for all LLM types
|
665
491
|
for msg in previous_messages:
|
666
492
|
# Convert each PromptMessageMultipart to individual PromptMessages
|
667
|
-
prompt_messages = msg.
|
493
|
+
prompt_messages = msg.from_multipart()
|
668
494
|
for prompt_msg in prompt_messages:
|
669
|
-
converted.append(
|
670
|
-
self.type_converter.from_prompt_message(prompt_msg)
|
671
|
-
)
|
495
|
+
converted.append(self.type_converter.from_prompt_message(prompt_msg))
|
672
496
|
|
673
497
|
self.history.extend(converted, is_prompt=True)
|
674
498
|
|
@@ -677,8 +501,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
677
501
|
for content in last_message.content:
|
678
502
|
if content.type == "text":
|
679
503
|
user_text_parts.append(content.text)
|
680
|
-
elif content.type == "resource" and
|
681
|
-
|
504
|
+
elif content.type == "resource" and getattr(content, "resource", None) is not None:
|
505
|
+
if hasattr(content.resource, "text"):
|
506
|
+
user_text_parts.append(content.resource.text) # type: ignore
|
682
507
|
elif content.type == "image":
|
683
508
|
# Add a placeholder for images
|
684
509
|
mime_type = getattr(content, "mimeType", "image/unknown")
|
@@ -692,9 +517,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
692
517
|
return await self.generate_str(user_text)
|
693
518
|
else:
|
694
519
|
# For assistant messages: Add all messages to history and return the last one
|
695
|
-
self.logger.debug(
|
696
|
-
"Last message in prompt is from assistant, returning it directly"
|
697
|
-
)
|
520
|
+
self.logger.debug("Last message in prompt is from assistant, returning it directly")
|
698
521
|
|
699
522
|
# Convert and add all messages to history
|
700
523
|
converted = []
|
@@ -702,11 +525,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
702
525
|
# Fallback to the original method for all LLM types
|
703
526
|
for msg in multipart_messages:
|
704
527
|
# Convert each PromptMessageMultipart to individual PromptMessages
|
705
|
-
prompt_messages = msg.
|
528
|
+
prompt_messages = msg.from_multipart()
|
706
529
|
for prompt_msg in prompt_messages:
|
707
|
-
converted.append(
|
708
|
-
self.type_converter.from_prompt_message(prompt_msg)
|
709
|
-
)
|
530
|
+
converted.append(self.type_converter.from_prompt_message(prompt_msg))
|
710
531
|
|
711
532
|
self.history.extend(converted, is_prompt=True)
|
712
533
|
|
@@ -723,11 +544,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
723
544
|
uri = getattr(content.resource, "uri", "")
|
724
545
|
if uri:
|
725
546
|
assistant_text_parts.append(
|
726
|
-
f"[Resource: {uri}, Type: {mime_type}]\n{content.resource.text}"
|
547
|
+
f"[Resource: {uri}, Type: {mime_type}]\n{content.resource.text}" # ignore # type: ignore
|
727
548
|
)
|
728
549
|
else:
|
729
550
|
assistant_text_parts.append(
|
730
|
-
f"[Resource Type: {mime_type}]\n{content.resource.text}"
|
551
|
+
f"[Resource Type: {mime_type}]\n{content.resource.text}" # type ignore # type: ignore
|
731
552
|
)
|
732
553
|
elif content.type == "image":
|
733
554
|
# Note the presence of images
|
@@ -740,14 +561,15 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
740
561
|
has_non_text_content = True
|
741
562
|
|
742
563
|
# Join all parts with double newlines for better readability
|
743
|
-
result = (
|
744
|
-
"\n\n".join(assistant_text_parts)
|
745
|
-
if assistant_text_parts
|
746
|
-
else str(last_message.content)
|
747
|
-
)
|
564
|
+
result = "\n\n".join(assistant_text_parts) if assistant_text_parts else str(last_message.content)
|
748
565
|
|
749
566
|
# Add a note if non-text content was present
|
750
567
|
if has_non_text_content:
|
751
568
|
result += "\n\n[Note: This message contained non-text content that may not be fully represented in text format]"
|
752
569
|
|
753
570
|
return result
|
571
|
+
|
572
|
+
|
573
|
+
#####################################
|
574
|
+
### NEW INTERFACE METHODS BELOW ###
|
575
|
+
#####################################
|