fast-agent-mcp 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/METADATA +1 -1
- fast_agent_mcp-0.1.13.dist-info/RECORD +164 -0
- mcp_agent/agents/agent.py +37 -102
- mcp_agent/app.py +16 -27
- mcp_agent/cli/commands/bootstrap.py +22 -52
- mcp_agent/cli/commands/config.py +4 -4
- mcp_agent/cli/commands/setup.py +11 -26
- mcp_agent/cli/main.py +6 -9
- mcp_agent/cli/terminal.py +2 -2
- mcp_agent/config.py +1 -5
- mcp_agent/context.py +13 -26
- mcp_agent/context_dependent.py +3 -7
- mcp_agent/core/agent_app.py +46 -122
- mcp_agent/core/agent_types.py +29 -2
- mcp_agent/core/agent_utils.py +3 -5
- mcp_agent/core/decorators.py +6 -14
- mcp_agent/core/enhanced_prompt.py +25 -52
- mcp_agent/core/error_handling.py +1 -1
- mcp_agent/core/exceptions.py +8 -8
- mcp_agent/core/factory.py +30 -72
- mcp_agent/core/fastagent.py +48 -88
- mcp_agent/core/mcp_content.py +10 -19
- mcp_agent/core/prompt.py +8 -15
- mcp_agent/core/proxies.py +34 -25
- mcp_agent/core/request_params.py +46 -0
- mcp_agent/core/types.py +6 -6
- mcp_agent/core/validation.py +16 -16
- mcp_agent/executor/decorator_registry.py +11 -23
- mcp_agent/executor/executor.py +8 -17
- mcp_agent/executor/task_registry.py +2 -4
- mcp_agent/executor/temporal.py +28 -74
- mcp_agent/executor/workflow.py +3 -5
- mcp_agent/executor/workflow_signal.py +17 -29
- mcp_agent/human_input/handler.py +4 -9
- mcp_agent/human_input/types.py +2 -3
- mcp_agent/logging/events.py +1 -5
- mcp_agent/logging/json_serializer.py +7 -6
- mcp_agent/logging/listeners.py +20 -23
- mcp_agent/logging/logger.py +15 -17
- mcp_agent/logging/rich_progress.py +10 -8
- mcp_agent/logging/tracing.py +4 -6
- mcp_agent/logging/transport.py +24 -24
- mcp_agent/mcp/gen_client.py +4 -12
- mcp_agent/mcp/interfaces.py +107 -88
- mcp_agent/mcp/mcp_agent_client_session.py +11 -19
- mcp_agent/mcp/mcp_agent_server.py +8 -10
- mcp_agent/mcp/mcp_aggregator.py +49 -122
- mcp_agent/mcp/mcp_connection_manager.py +16 -37
- mcp_agent/mcp/prompt_message_multipart.py +12 -18
- mcp_agent/mcp/prompt_serialization.py +13 -38
- mcp_agent/mcp/prompts/prompt_load.py +99 -0
- mcp_agent/mcp/prompts/prompt_server.py +21 -128
- mcp_agent/mcp/prompts/prompt_template.py +20 -42
- mcp_agent/mcp/resource_utils.py +8 -17
- mcp_agent/mcp/sampling.py +62 -64
- mcp_agent/mcp/stdio.py +11 -8
- mcp_agent/mcp_server/__init__.py +1 -1
- mcp_agent/mcp_server/agent_server.py +10 -17
- mcp_agent/mcp_server_registry.py +13 -35
- mcp_agent/resources/examples/data-analysis/analysis-campaign.py +1 -1
- mcp_agent/resources/examples/data-analysis/analysis.py +1 -1
- mcp_agent/resources/examples/data-analysis/slides.py +110 -0
- mcp_agent/resources/examples/internal/agent.py +2 -1
- mcp_agent/resources/examples/internal/job.py +2 -1
- mcp_agent/resources/examples/internal/prompt_category.py +1 -1
- mcp_agent/resources/examples/internal/prompt_sizing.py +3 -5
- mcp_agent/resources/examples/internal/sizer.py +2 -1
- mcp_agent/resources/examples/internal/social.py +2 -1
- mcp_agent/resources/examples/mcp_researcher/researcher-eval.py +1 -1
- mcp_agent/resources/examples/prompting/__init__.py +1 -1
- mcp_agent/resources/examples/prompting/agent.py +2 -1
- mcp_agent/resources/examples/prompting/image_server.py +5 -11
- mcp_agent/resources/examples/researcher/researcher-eval.py +1 -1
- mcp_agent/resources/examples/researcher/researcher-imp.py +3 -4
- mcp_agent/resources/examples/researcher/researcher.py +2 -1
- mcp_agent/resources/examples/workflows/agent_build.py +2 -1
- mcp_agent/resources/examples/workflows/chaining.py +2 -1
- mcp_agent/resources/examples/workflows/evaluator.py +2 -1
- mcp_agent/resources/examples/workflows/human_input.py +2 -1
- mcp_agent/resources/examples/workflows/orchestrator.py +2 -1
- mcp_agent/resources/examples/workflows/parallel.py +2 -1
- mcp_agent/resources/examples/workflows/router.py +2 -1
- mcp_agent/resources/examples/workflows/sse.py +1 -1
- mcp_agent/telemetry/usage_tracking.py +2 -1
- mcp_agent/ui/console_display.py +17 -41
- mcp_agent/workflows/embedding/embedding_base.py +1 -4
- mcp_agent/workflows/embedding/embedding_cohere.py +2 -2
- mcp_agent/workflows/embedding/embedding_openai.py +4 -13
- mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +23 -57
- mcp_agent/workflows/intent_classifier/intent_classifier_base.py +5 -8
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding.py +7 -11
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding_cohere.py +4 -8
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding_openai.py +4 -8
- mcp_agent/workflows/intent_classifier/intent_classifier_llm.py +11 -22
- mcp_agent/workflows/intent_classifier/intent_classifier_llm_anthropic.py +3 -3
- mcp_agent/workflows/intent_classifier/intent_classifier_llm_openai.py +4 -6
- mcp_agent/workflows/llm/anthropic_utils.py +8 -29
- mcp_agent/workflows/llm/augmented_llm.py +94 -332
- mcp_agent/workflows/llm/augmented_llm_anthropic.py +43 -76
- mcp_agent/workflows/llm/augmented_llm_openai.py +46 -100
- mcp_agent/workflows/llm/augmented_llm_passthrough.py +42 -20
- mcp_agent/workflows/llm/augmented_llm_playback.py +8 -6
- mcp_agent/workflows/llm/memory.py +103 -0
- mcp_agent/workflows/llm/model_factory.py +9 -21
- mcp_agent/workflows/llm/openai_utils.py +1 -1
- mcp_agent/workflows/llm/prompt_utils.py +39 -27
- mcp_agent/workflows/llm/providers/multipart_converter_anthropic.py +246 -184
- mcp_agent/workflows/llm/providers/multipart_converter_openai.py +212 -202
- mcp_agent/workflows/llm/providers/openai_multipart.py +19 -61
- mcp_agent/workflows/llm/providers/sampling_converter_anthropic.py +11 -212
- mcp_agent/workflows/llm/providers/sampling_converter_openai.py +13 -215
- mcp_agent/workflows/llm/sampling_converter.py +117 -0
- mcp_agent/workflows/llm/sampling_format_converter.py +12 -29
- mcp_agent/workflows/orchestrator/orchestrator.py +24 -67
- mcp_agent/workflows/orchestrator/orchestrator_models.py +14 -40
- mcp_agent/workflows/parallel/fan_in.py +17 -47
- mcp_agent/workflows/parallel/fan_out.py +6 -12
- mcp_agent/workflows/parallel/parallel_llm.py +9 -26
- mcp_agent/workflows/router/router_base.py +29 -59
- mcp_agent/workflows/router/router_embedding.py +11 -25
- mcp_agent/workflows/router/router_embedding_cohere.py +2 -2
- mcp_agent/workflows/router/router_embedding_openai.py +2 -2
- mcp_agent/workflows/router/router_llm.py +12 -28
- mcp_agent/workflows/swarm/swarm.py +20 -48
- mcp_agent/workflows/swarm/swarm_anthropic.py +2 -2
- mcp_agent/workflows/swarm/swarm_openai.py +2 -2
- fast_agent_mcp-0.1.11.dist-info/RECORD +0 -160
- mcp_agent/workflows/llm/llm_selector.py +0 -345
- {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/licenses/LICENSE +0 -0
@@ -1,228 +1,54 @@
|
|
1
1
|
from abc import abstractmethod
|
2
|
-
|
3
2
|
from typing import (
|
4
|
-
|
3
|
+
TYPE_CHECKING,
|
4
|
+
Any,
|
5
5
|
List,
|
6
6
|
Optional,
|
7
|
-
Protocol,
|
8
7
|
Type,
|
9
|
-
|
10
|
-
TYPE_CHECKING,
|
8
|
+
cast,
|
11
9
|
)
|
12
10
|
|
13
|
-
from
|
14
|
-
from mcp_agent.mcp.
|
15
|
-
|
16
|
-
SamplingFormatConverter,
|
11
|
+
from mcp_agent.logging.logger import get_logger
|
12
|
+
from mcp_agent.mcp.interfaces import (
|
13
|
+
AugmentedLLMProtocol,
|
17
14
|
MessageParamT,
|
18
15
|
MessageT,
|
16
|
+
ModelT,
|
17
|
+
)
|
18
|
+
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
|
19
|
+
from mcp_agent.workflows.llm.sampling_format_converter import (
|
20
|
+
BasicFormatConverter,
|
21
|
+
ProviderFormatConverter,
|
19
22
|
)
|
20
23
|
|
21
24
|
# Forward reference for type annotations
|
22
25
|
if TYPE_CHECKING:
|
23
|
-
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
|
24
26
|
from mcp_agent.agents.agent import Agent
|
25
27
|
from mcp_agent.context import Context
|
28
|
+
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
|
26
29
|
|
27
30
|
|
28
|
-
from pydantic import Field
|
29
|
-
|
30
31
|
from mcp.types import (
|
31
32
|
CallToolRequest,
|
32
33
|
CallToolResult,
|
33
|
-
|
34
|
-
ModelPreferences,
|
34
|
+
GetPromptResult,
|
35
35
|
PromptMessage,
|
36
36
|
TextContent,
|
37
|
-
GetPromptResult,
|
38
37
|
)
|
38
|
+
from rich.text import Text
|
39
39
|
|
40
40
|
from mcp_agent.context_dependent import ContextDependent
|
41
|
-
from mcp_agent.core.exceptions import PromptExitError
|
41
|
+
from mcp_agent.core.exceptions import ModelConfigError, PromptExitError
|
42
|
+
from mcp_agent.core.request_params import RequestParams
|
42
43
|
from mcp_agent.event_progress import ProgressAction
|
43
44
|
from mcp_agent.mcp.mcp_aggregator import MCPAggregator
|
44
|
-
from mcp_agent.workflows.llm.llm_selector import ModelSelector
|
45
45
|
from mcp_agent.ui.console_display import ConsoleDisplay
|
46
|
-
from
|
47
|
-
|
48
|
-
|
49
|
-
ModelT = TypeVar("ModelT")
|
50
|
-
"""A type representing a structured output message from an LLM."""
|
51
|
-
|
46
|
+
from mcp_agent.workflows.llm.memory import Memory, SimpleMemory
|
52
47
|
|
53
48
|
# TODO -- move this to a constant
|
54
49
|
HUMAN_INPUT_TOOL_NAME = "__human_input__"
|
55
50
|
|
56
51
|
|
57
|
-
class Memory(Protocol, Generic[MessageParamT]):
|
58
|
-
"""
|
59
|
-
Simple memory management for storing past interactions in-memory.
|
60
|
-
"""
|
61
|
-
|
62
|
-
# TODO: saqadri - add checkpointing and other advanced memory capabilities
|
63
|
-
|
64
|
-
def __init__(self): ...
|
65
|
-
|
66
|
-
def extend(
|
67
|
-
self, messages: List[MessageParamT], is_prompt: bool = False
|
68
|
-
) -> None: ...
|
69
|
-
|
70
|
-
def set(self, messages: List[MessageParamT], is_prompt: bool = False) -> None: ...
|
71
|
-
|
72
|
-
def append(self, message: MessageParamT, is_prompt: bool = False) -> None: ...
|
73
|
-
|
74
|
-
def get(self, include_history: bool = True) -> List[MessageParamT]: ...
|
75
|
-
|
76
|
-
def clear(self, clear_prompts: bool = False) -> None: ...
|
77
|
-
|
78
|
-
|
79
|
-
class SimpleMemory(Memory, Generic[MessageParamT]):
|
80
|
-
"""
|
81
|
-
Simple memory management for storing past interactions in-memory.
|
82
|
-
|
83
|
-
Maintains both prompt messages (which are always included) and
|
84
|
-
generated conversation history (which is included based on use_history setting).
|
85
|
-
"""
|
86
|
-
|
87
|
-
def __init__(self):
|
88
|
-
self.history: List[MessageParamT] = []
|
89
|
-
self.prompt_messages: List[MessageParamT] = [] # Always included
|
90
|
-
|
91
|
-
def extend(self, messages: List[MessageParamT], is_prompt: bool = False):
|
92
|
-
"""
|
93
|
-
Add multiple messages to history.
|
94
|
-
|
95
|
-
Args:
|
96
|
-
messages: Messages to add
|
97
|
-
is_prompt: If True, add to prompt_messages instead of regular history
|
98
|
-
"""
|
99
|
-
if is_prompt:
|
100
|
-
self.prompt_messages.extend(messages)
|
101
|
-
else:
|
102
|
-
self.history.extend(messages)
|
103
|
-
|
104
|
-
def set(self, messages: List[MessageParamT], is_prompt: bool = False):
|
105
|
-
"""
|
106
|
-
Replace messages in history.
|
107
|
-
|
108
|
-
Args:
|
109
|
-
messages: Messages to set
|
110
|
-
is_prompt: If True, replace prompt_messages instead of regular history
|
111
|
-
"""
|
112
|
-
if is_prompt:
|
113
|
-
self.prompt_messages = messages.copy()
|
114
|
-
else:
|
115
|
-
self.history = messages.copy()
|
116
|
-
|
117
|
-
def append(self, message: MessageParamT, is_prompt: bool = False):
|
118
|
-
"""
|
119
|
-
Add a single message to history.
|
120
|
-
|
121
|
-
Args:
|
122
|
-
message: Message to add
|
123
|
-
is_prompt: If True, add to prompt_messages instead of regular history
|
124
|
-
"""
|
125
|
-
if is_prompt:
|
126
|
-
self.prompt_messages.append(message)
|
127
|
-
else:
|
128
|
-
self.history.append(message)
|
129
|
-
|
130
|
-
def get(self, include_history: bool = True) -> List[MessageParamT]:
|
131
|
-
"""
|
132
|
-
Get all messages in memory.
|
133
|
-
|
134
|
-
Args:
|
135
|
-
include_history: If True, include regular history messages
|
136
|
-
If False, only return prompt messages
|
137
|
-
|
138
|
-
Returns:
|
139
|
-
Combined list of prompt messages and optionally history messages
|
140
|
-
"""
|
141
|
-
if include_history:
|
142
|
-
return self.prompt_messages + self.history
|
143
|
-
else:
|
144
|
-
return self.prompt_messages.copy()
|
145
|
-
|
146
|
-
def clear(self, clear_prompts: bool = False):
|
147
|
-
"""
|
148
|
-
Clear history and optionally prompt messages.
|
149
|
-
|
150
|
-
Args:
|
151
|
-
clear_prompts: If True, also clear prompt messages
|
152
|
-
"""
|
153
|
-
self.history = []
|
154
|
-
if clear_prompts:
|
155
|
-
self.prompt_messages = []
|
156
|
-
|
157
|
-
|
158
|
-
class RequestParams(CreateMessageRequestParams):
|
159
|
-
"""
|
160
|
-
Parameters to configure the AugmentedLLM 'generate' requests.
|
161
|
-
"""
|
162
|
-
|
163
|
-
messages: None = Field(exclude=True, default=None)
|
164
|
-
"""
|
165
|
-
Ignored. 'messages' are removed from CreateMessageRequestParams
|
166
|
-
to avoid confusion with the 'message' parameter on 'generate' method.
|
167
|
-
"""
|
168
|
-
|
169
|
-
maxTokens: int = 2048
|
170
|
-
"""The maximum number of tokens to sample, as requested by the server."""
|
171
|
-
|
172
|
-
model: str | None = None
|
173
|
-
"""
|
174
|
-
The model to use for the LLM generation.
|
175
|
-
If specified, this overrides the 'modelPreferences' selection criteria.
|
176
|
-
"""
|
177
|
-
|
178
|
-
use_history: bool = True
|
179
|
-
"""
|
180
|
-
Include the message history in the generate request.
|
181
|
-
"""
|
182
|
-
|
183
|
-
max_iterations: int = 10
|
184
|
-
"""
|
185
|
-
The maximum number of iterations to run the LLM for.
|
186
|
-
"""
|
187
|
-
|
188
|
-
parallel_tool_calls: bool = True
|
189
|
-
"""
|
190
|
-
Whether to allow multiple tool calls per iteration.
|
191
|
-
Also known as multi-step tool use.
|
192
|
-
"""
|
193
|
-
|
194
|
-
|
195
|
-
class AugmentedLLMProtocol(Protocol, Generic[MessageParamT, MessageT]):
|
196
|
-
"""Protocol defining the interface for augmented LLMs"""
|
197
|
-
|
198
|
-
async def generate(
|
199
|
-
self,
|
200
|
-
message: str | MessageParamT | List[MessageParamT],
|
201
|
-
request_params: RequestParams | None = None,
|
202
|
-
) -> List[MessageT]:
|
203
|
-
"""Request an LLM generation, which may run multiple iterations, and return the result"""
|
204
|
-
|
205
|
-
async def generate_str(
|
206
|
-
self,
|
207
|
-
message: str | MessageParamT | List[MessageParamT],
|
208
|
-
request_params: RequestParams | None = None,
|
209
|
-
) -> str:
|
210
|
-
"""Request an LLM generation and return the string representation of the result"""
|
211
|
-
|
212
|
-
async def generate_structured(
|
213
|
-
self,
|
214
|
-
message: str | MessageParamT | List[MessageParamT],
|
215
|
-
response_model: Type[ModelT],
|
216
|
-
request_params: RequestParams | None = None,
|
217
|
-
) -> ModelT:
|
218
|
-
"""Request a structured LLM generation and return the result as a Pydantic model."""
|
219
|
-
|
220
|
-
async def generate_prompt(
|
221
|
-
self, prompt: PromptMessageMultipart, request_params: RequestParams | None
|
222
|
-
) -> str:
|
223
|
-
"""Request an LLM generation and return a string representation of the result"""
|
224
|
-
|
225
|
-
|
226
52
|
class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
|
227
53
|
"""
|
228
54
|
The basic building block of agentic systems is an LLM enhanced with augmentations
|
@@ -231,9 +57,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
231
57
|
selecting appropriate tools, and determining what information to retain.
|
232
58
|
"""
|
233
59
|
|
234
|
-
# TODO: saqadri - add streaming support (e.g. generate_stream)
|
235
|
-
# TODO: saqadri - consider adding middleware patterns for pre/post processing of messages, for now we have pre/post_tool_call
|
236
|
-
|
237
60
|
provider: str | None = None
|
238
61
|
|
239
62
|
def __init__(
|
@@ -243,10 +66,10 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
243
66
|
instruction: str | None = None,
|
244
67
|
name: str | None = None,
|
245
68
|
request_params: RequestParams | None = None,
|
246
|
-
type_converter: Type[
|
69
|
+
type_converter: Type[ProviderFormatConverter[MessageParamT, MessageT]] = BasicFormatConverter,
|
247
70
|
context: Optional["Context"] = None,
|
248
|
-
**kwargs,
|
249
|
-
):
|
71
|
+
**kwargs: dict[str, Any],
|
72
|
+
) -> None:
|
250
73
|
"""
|
251
74
|
Initialize the LLM with a list of server names and an instruction.
|
252
75
|
If a name is provided, it will be used to identify the LLM.
|
@@ -255,44 +78,23 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
255
78
|
# Extract request_params before super() call
|
256
79
|
self._init_request_params = request_params
|
257
80
|
super().__init__(context=context, **kwargs)
|
258
|
-
|
81
|
+
self.logger = get_logger(__name__)
|
259
82
|
self.executor = self.context.executor
|
260
|
-
self.aggregator = (
|
261
|
-
agent if agent is not None else MCPAggregator(server_names or [])
|
262
|
-
)
|
83
|
+
self.aggregator = agent if agent is not None else MCPAggregator(server_names or [])
|
263
84
|
self.name = name or (agent.name if agent else None)
|
264
|
-
self.instruction = instruction or (
|
265
|
-
agent.instruction if agent and isinstance(agent.instruction, str) else None
|
266
|
-
)
|
85
|
+
self.instruction = instruction or (agent.instruction if agent and isinstance(agent.instruction, str) else None)
|
267
86
|
self.history: Memory[MessageParamT] = SimpleMemory[MessageParamT]()
|
268
87
|
|
269
88
|
# Initialize the display component
|
270
89
|
self.display = ConsoleDisplay(config=self.context.config)
|
271
90
|
|
272
|
-
# Set initial model preferences
|
273
|
-
self.model_preferences = ModelPreferences(
|
274
|
-
costPriority=0.3,
|
275
|
-
speedPriority=0.4,
|
276
|
-
intelligencePriority=0.3,
|
277
|
-
)
|
278
|
-
|
279
91
|
# Initialize default parameters
|
280
92
|
self.default_request_params = self._initialize_default_params(kwargs)
|
281
93
|
|
282
|
-
# Update model preferences from default params
|
283
|
-
if self.default_request_params and self.default_request_params.modelPreferences:
|
284
|
-
self.model_preferences = self.default_request_params.modelPreferences
|
285
|
-
|
286
94
|
# Merge with provided params if any
|
287
95
|
if self._init_request_params:
|
288
|
-
self.default_request_params = self._merge_request_params(
|
289
|
-
self.default_request_params, self._init_request_params
|
290
|
-
)
|
291
|
-
# Update model preferences again if they changed in the merge
|
292
|
-
if self.default_request_params.modelPreferences:
|
293
|
-
self.model_preferences = self.default_request_params.modelPreferences
|
96
|
+
self.default_request_params = self._merge_request_params(self.default_request_params, self._init_request_params)
|
294
97
|
|
295
|
-
self.model_selector = self.context.model_selector
|
296
98
|
self.type_converter = type_converter
|
297
99
|
self.verb = kwargs.get("verb")
|
298
100
|
|
@@ -321,48 +123,26 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
321
123
|
) -> ModelT:
|
322
124
|
"""Request a structured LLM generation and return the result as a Pydantic model."""
|
323
125
|
|
324
|
-
|
325
|
-
# """Request an LLM generation, which may run multiple iterations, and return the result"""
|
326
|
-
# return None
|
327
|
-
|
328
|
-
async def select_model(
|
329
|
-
self, request_params: RequestParams | None = None
|
330
|
-
) -> str | None:
|
126
|
+
async def select_model(self, request_params: RequestParams | None = None) -> str | None:
|
331
127
|
"""
|
332
|
-
|
333
|
-
If a model is specified in the request, it will override the model selection criteria.
|
128
|
+
Return the configured model (legacy support)
|
334
129
|
"""
|
335
|
-
|
336
|
-
|
337
|
-
model_preferences = request_params.modelPreferences or model_preferences
|
338
|
-
model = request_params.model
|
339
|
-
if model:
|
340
|
-
return model
|
341
|
-
|
342
|
-
## TODO -- can't have been tested, returns invalid model strings (e.g. claude-35-sonnet)
|
343
|
-
if not self.model_selector:
|
344
|
-
self.model_selector = ModelSelector()
|
345
|
-
|
346
|
-
model_info = self.model_selector.select_best_model(
|
347
|
-
model_preferences=model_preferences, provider=self.provider
|
348
|
-
)
|
130
|
+
if request_params and request_params.model:
|
131
|
+
return request_params.model
|
349
132
|
|
350
|
-
|
133
|
+
raise ModelConfigError("Internal Error: Model is not configured correctly")
|
351
134
|
|
352
135
|
def _initialize_default_params(self, kwargs: dict) -> RequestParams:
|
353
136
|
"""Initialize default parameters for the LLM.
|
354
137
|
Should be overridden by provider implementations to set provider-specific defaults."""
|
355
138
|
return RequestParams(
|
356
|
-
modelPreferences=self.model_preferences,
|
357
139
|
systemPrompt=self.instruction,
|
358
140
|
parallel_tool_calls=True,
|
359
141
|
max_iterations=10,
|
360
142
|
use_history=True,
|
361
143
|
)
|
362
144
|
|
363
|
-
def _merge_request_params(
|
364
|
-
self, default_params: RequestParams, provided_params: RequestParams
|
365
|
-
) -> RequestParams:
|
145
|
+
def _merge_request_params(self, default_params: RequestParams, provided_params: RequestParams) -> RequestParams:
|
366
146
|
"""Merge default and provided request parameters"""
|
367
147
|
|
368
148
|
merged = default_params.model_dump()
|
@@ -395,32 +175,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
395
175
|
|
396
176
|
return default_request_params
|
397
177
|
|
398
|
-
def to_mcp_message_result(self, result: MessageT) -> CreateMessageResult:
|
399
|
-
"""Convert an LLM response to an MCP message result type."""
|
400
|
-
return self.type_converter.to_sampling_result(result)
|
401
|
-
|
402
|
-
def from_mcp_message_result(self, result: CreateMessageResult) -> MessageT:
|
403
|
-
"""Convert an MCP message result to an LLM response type."""
|
404
|
-
return self.type_converter.from_sampling_result(result)
|
405
|
-
|
406
|
-
def to_mcp_message_param(self, param: MessageParamT) -> SamplingMessage:
|
407
|
-
"""Convert an LLM input to an MCP message (SamplingMessage) type."""
|
408
|
-
return self.type_converter.to_sampling_message(param)
|
409
|
-
|
410
|
-
def from_mcp_message_param(self, param: SamplingMessage) -> MessageParamT:
|
411
|
-
"""Convert an MCP message (SamplingMessage) to an LLM input type."""
|
412
|
-
return self.type_converter.from_sampling_message(param)
|
413
|
-
|
414
|
-
def from_mcp_prompt_message(self, message: PromptMessage) -> MessageParamT:
|
415
|
-
return self.type_converter.from_prompt_message(message)
|
416
|
-
|
417
178
|
@classmethod
|
418
|
-
def convert_message_to_message_param(
|
419
|
-
cls, message: MessageT, **kwargs
|
420
|
-
) -> MessageParamT:
|
179
|
+
def convert_message_to_message_param(cls, message: MessageT, **kwargs: dict[str, Any]) -> MessageParamT:
|
421
180
|
"""Convert a response object to an input parameter object to allow LLM calls to be chained."""
|
422
181
|
# Many LLM implementations will allow the same type for input and output messages
|
423
|
-
return message
|
182
|
+
return cast("MessageParamT", message)
|
424
183
|
|
425
184
|
async def get_last_message(self) -> MessageParamT | None:
|
426
185
|
"""
|
@@ -435,15 +194,15 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
435
194
|
last_message = await self.get_last_message()
|
436
195
|
return self.message_param_str(last_message) if last_message else None
|
437
196
|
|
438
|
-
def show_tool_result(self, result: CallToolResult):
|
197
|
+
def show_tool_result(self, result: CallToolResult) -> None:
|
439
198
|
"""Display a tool result in a formatted panel."""
|
440
199
|
self.display.show_tool_result(result)
|
441
200
|
|
442
|
-
def show_oai_tool_result(self, result):
|
201
|
+
def show_oai_tool_result(self, result: str) -> None:
|
443
202
|
"""Display a tool result in a formatted panel."""
|
444
203
|
self.display.show_oai_tool_result(result)
|
445
204
|
|
446
|
-
def show_tool_call(self, available_tools, tool_name, tool_args):
|
205
|
+
def show_tool_call(self, available_tools, tool_name, tool_args) -> None:
|
447
206
|
"""Display a tool call in a formatted panel."""
|
448
207
|
self.display.show_tool_call(available_tools, tool_name, tool_args)
|
449
208
|
|
@@ -452,7 +211,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
452
211
|
message_text: str | Text,
|
453
212
|
highlight_namespaced_tool: str = "",
|
454
213
|
title: str = "ASSISTANT",
|
455
|
-
):
|
214
|
+
) -> None:
|
456
215
|
"""Display an assistant message in a formatted panel."""
|
457
216
|
await self.display.show_assistant_message(
|
458
217
|
message_text,
|
@@ -462,19 +221,15 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
462
221
|
name=self.name,
|
463
222
|
)
|
464
223
|
|
465
|
-
def show_user_message(self, message, model: str | None, chat_turn: int):
|
224
|
+
def show_user_message(self, message, model: str | None, chat_turn: int) -> None:
|
466
225
|
"""Display a user message in a formatted panel."""
|
467
226
|
self.display.show_user_message(message, model, chat_turn, name=self.name)
|
468
227
|
|
469
|
-
async def pre_tool_call(
|
470
|
-
self, tool_call_id: str | None, request: CallToolRequest
|
471
|
-
) -> CallToolRequest | bool:
|
228
|
+
async def pre_tool_call(self, tool_call_id: str | None, request: CallToolRequest) -> CallToolRequest | bool:
|
472
229
|
"""Called before a tool is executed. Return False to prevent execution."""
|
473
230
|
return request
|
474
231
|
|
475
|
-
async def post_tool_call(
|
476
|
-
self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult
|
477
|
-
) -> CallToolResult:
|
232
|
+
async def post_tool_call(self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult) -> CallToolResult:
|
478
233
|
"""Called after a tool execution. Can modify the result before it's returned."""
|
479
234
|
return result
|
480
235
|
|
@@ -497,7 +252,8 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
497
252
|
isError=True,
|
498
253
|
content=[
|
499
254
|
TextContent(
|
500
|
-
|
255
|
+
type="text",
|
256
|
+
text=f"Error: Tool '{request.params.name}' was not allowed to run.",
|
501
257
|
)
|
502
258
|
],
|
503
259
|
)
|
@@ -508,9 +264,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
508
264
|
tool_args = request.params.arguments
|
509
265
|
result = await self.aggregator.call_tool(tool_name, tool_args)
|
510
266
|
|
511
|
-
postprocess = await self.post_tool_call(
|
512
|
-
tool_call_id=tool_call_id, request=request, result=result
|
513
|
-
)
|
267
|
+
postprocess = await self.post_tool_call(tool_call_id=tool_call_id, request=request, result=result)
|
514
268
|
|
515
269
|
if isinstance(postprocess, CallToolResult):
|
516
270
|
result = postprocess
|
@@ -548,13 +302,13 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
548
302
|
if isinstance(part, dict) and "text" in part:
|
549
303
|
text_parts.append(part["text"])
|
550
304
|
elif hasattr(part, "text"):
|
551
|
-
text_parts.append(part.text)
|
305
|
+
text_parts.append(part.text) # type: ignore
|
552
306
|
if text_parts:
|
553
307
|
return "\n".join(text_parts)
|
554
308
|
|
555
309
|
# For objects with content attribute
|
556
310
|
if hasattr(message, "content"):
|
557
|
-
content = message.content
|
311
|
+
content = message.content # type: ignore
|
558
312
|
if isinstance(content, str):
|
559
313
|
return content
|
560
314
|
elif hasattr(content, "text"):
|
@@ -569,13 +323,13 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
569
323
|
Tries to extract just the content when possible.
|
570
324
|
"""
|
571
325
|
# First try to use the same method for consistency
|
572
|
-
result = self.message_param_str(message)
|
326
|
+
result = self.message_param_str(message) # type: ignore
|
573
327
|
if result != str(message):
|
574
328
|
return result
|
575
329
|
|
576
330
|
# Additional handling for output-specific formats
|
577
331
|
if hasattr(message, "content"):
|
578
|
-
content = message
|
332
|
+
content = getattr(message, "content")
|
579
333
|
if isinstance(content, list):
|
580
334
|
# Extract text from content blocks
|
581
335
|
text_parts = []
|
@@ -588,9 +342,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
588
342
|
# Default fallback
|
589
343
|
return str(message)
|
590
344
|
|
591
|
-
def _log_chat_progress(
|
592
|
-
self, chat_turn: Optional[int] = None, model: Optional[str] = None
|
593
|
-
):
|
345
|
+
def _log_chat_progress(self, chat_turn: Optional[int] = None, model: Optional[str] = None) -> None:
|
594
346
|
"""Log a chat progress event"""
|
595
347
|
# Determine action type based on verb
|
596
348
|
if hasattr(self, "verb") and self.verb:
|
@@ -607,7 +359,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
607
359
|
}
|
608
360
|
self.logger.debug("Chat in progress", data=data)
|
609
361
|
|
610
|
-
def _log_chat_finished(self, model: Optional[str] = None):
|
362
|
+
def _log_chat_finished(self, model: Optional[str] = None) -> None:
|
611
363
|
"""Log a chat finished event"""
|
612
364
|
data = {
|
613
365
|
"progress_action": ProgressAction.READY,
|
@@ -616,9 +368,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
616
368
|
}
|
617
369
|
self.logger.debug("Chat finished", data=data)
|
618
370
|
|
619
|
-
def _convert_prompt_messages(
|
620
|
-
self, prompt_messages: List[PromptMessage]
|
621
|
-
) -> List[MessageParamT]:
|
371
|
+
def _convert_prompt_messages(self, prompt_messages: List[PromptMessage]) -> List[MessageParamT]:
|
622
372
|
"""
|
623
373
|
Convert prompt messages to this LLM's specific message format.
|
624
374
|
To be implemented by concrete LLM classes.
|
@@ -631,7 +381,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
631
381
|
description: Optional[str] = None,
|
632
382
|
message_count: int = 0,
|
633
383
|
arguments: Optional[dict[str, str]] = None,
|
634
|
-
):
|
384
|
+
) -> None:
|
635
385
|
"""
|
636
386
|
Display information about a loaded prompt template.
|
637
387
|
|
@@ -650,9 +400,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
650
400
|
arguments=arguments,
|
651
401
|
)
|
652
402
|
|
653
|
-
async def apply_prompt_template(
|
654
|
-
self, prompt_result: GetPromptResult, prompt_name: str
|
655
|
-
) -> str:
|
403
|
+
async def apply_prompt_template(self, prompt_result: GetPromptResult, prompt_name: str) -> str:
|
656
404
|
"""
|
657
405
|
Apply a prompt template by adding it to the conversation history.
|
658
406
|
If the last message in the prompt is from a user, automatically
|
@@ -684,15 +432,35 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
684
432
|
)
|
685
433
|
|
686
434
|
# Convert to PromptMessageMultipart objects
|
687
|
-
multipart_messages = PromptMessageMultipart.parse_get_prompt_result(
|
688
|
-
|
689
|
-
|
435
|
+
multipart_messages = PromptMessageMultipart.parse_get_prompt_result(prompt_result)
|
436
|
+
|
437
|
+
# Delegate to the provider-specific implementation
|
438
|
+
return await self._apply_prompt_template_provider_specific(multipart_messages, None)
|
690
439
|
|
440
|
+
async def apply_prompt(
|
441
|
+
self,
|
442
|
+
multipart_messages: List["PromptMessageMultipart"],
|
443
|
+
request_params: RequestParams | None = None,
|
444
|
+
) -> str:
|
445
|
+
"""
|
446
|
+
Apply a list of PromptMessageMultipart messages directly to the LLM.
|
447
|
+
This is a cleaner interface to _apply_prompt_template_provider_specific.
|
448
|
+
|
449
|
+
Args:
|
450
|
+
multipart_messages: List of PromptMessageMultipart objects
|
451
|
+
request_params: Optional parameters to configure the LLM request
|
452
|
+
|
453
|
+
Returns:
|
454
|
+
String representation of the assistant's response
|
455
|
+
"""
|
691
456
|
# Delegate to the provider-specific implementation
|
692
|
-
return await self._apply_prompt_template_provider_specific(multipart_messages)
|
457
|
+
return await self._apply_prompt_template_provider_specific(multipart_messages, request_params)
|
693
458
|
|
459
|
+
# this shouln't need to be very big...
|
694
460
|
async def _apply_prompt_template_provider_specific(
|
695
|
-
self,
|
461
|
+
self,
|
462
|
+
multipart_messages: List["PromptMessageMultipart"],
|
463
|
+
request_params: RequestParams | None = None,
|
696
464
|
) -> str:
|
697
465
|
"""
|
698
466
|
Provider-specific implementation of apply_prompt_template.
|
@@ -712,9 +480,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
712
480
|
|
713
481
|
if last_message.role == "user":
|
714
482
|
# For user messages: Add all previous messages to history, then generate response to the last one
|
715
|
-
self.logger.debug(
|
716
|
-
"Last message in prompt is from user, generating assistant response"
|
717
|
-
)
|
483
|
+
self.logger.debug("Last message in prompt is from user, generating assistant response")
|
718
484
|
|
719
485
|
# Add all but the last message to history
|
720
486
|
if len(multipart_messages) > 1:
|
@@ -724,11 +490,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
724
490
|
# Fallback generic method for all LLM types
|
725
491
|
for msg in previous_messages:
|
726
492
|
# Convert each PromptMessageMultipart to individual PromptMessages
|
727
|
-
prompt_messages = msg.
|
493
|
+
prompt_messages = msg.from_multipart()
|
728
494
|
for prompt_msg in prompt_messages:
|
729
|
-
converted.append(
|
730
|
-
self.type_converter.from_prompt_message(prompt_msg)
|
731
|
-
)
|
495
|
+
converted.append(self.type_converter.from_prompt_message(prompt_msg))
|
732
496
|
|
733
497
|
self.history.extend(converted, is_prompt=True)
|
734
498
|
|
@@ -737,8 +501,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
737
501
|
for content in last_message.content:
|
738
502
|
if content.type == "text":
|
739
503
|
user_text_parts.append(content.text)
|
740
|
-
elif content.type == "resource" and
|
741
|
-
|
504
|
+
elif content.type == "resource" and getattr(content, "resource", None) is not None:
|
505
|
+
if hasattr(content.resource, "text"):
|
506
|
+
user_text_parts.append(content.resource.text) # type: ignore
|
742
507
|
elif content.type == "image":
|
743
508
|
# Add a placeholder for images
|
744
509
|
mime_type = getattr(content, "mimeType", "image/unknown")
|
@@ -752,9 +517,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
752
517
|
return await self.generate_str(user_text)
|
753
518
|
else:
|
754
519
|
# For assistant messages: Add all messages to history and return the last one
|
755
|
-
self.logger.debug(
|
756
|
-
"Last message in prompt is from assistant, returning it directly"
|
757
|
-
)
|
520
|
+
self.logger.debug("Last message in prompt is from assistant, returning it directly")
|
758
521
|
|
759
522
|
# Convert and add all messages to history
|
760
523
|
converted = []
|
@@ -762,11 +525,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
762
525
|
# Fallback to the original method for all LLM types
|
763
526
|
for msg in multipart_messages:
|
764
527
|
# Convert each PromptMessageMultipart to individual PromptMessages
|
765
|
-
prompt_messages = msg.
|
528
|
+
prompt_messages = msg.from_multipart()
|
766
529
|
for prompt_msg in prompt_messages:
|
767
|
-
converted.append(
|
768
|
-
self.type_converter.from_prompt_message(prompt_msg)
|
769
|
-
)
|
530
|
+
converted.append(self.type_converter.from_prompt_message(prompt_msg))
|
770
531
|
|
771
532
|
self.history.extend(converted, is_prompt=True)
|
772
533
|
|
@@ -783,11 +544,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
783
544
|
uri = getattr(content.resource, "uri", "")
|
784
545
|
if uri:
|
785
546
|
assistant_text_parts.append(
|
786
|
-
f"[Resource: {uri}, Type: {mime_type}]\n{content.resource.text}"
|
547
|
+
f"[Resource: {uri}, Type: {mime_type}]\n{content.resource.text}" # ignore # type: ignore
|
787
548
|
)
|
788
549
|
else:
|
789
550
|
assistant_text_parts.append(
|
790
|
-
f"[Resource Type: {mime_type}]\n{content.resource.text}"
|
551
|
+
f"[Resource Type: {mime_type}]\n{content.resource.text}" # type ignore # type: ignore
|
791
552
|
)
|
792
553
|
elif content.type == "image":
|
793
554
|
# Note the presence of images
|
@@ -800,14 +561,15 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
|
|
800
561
|
has_non_text_content = True
|
801
562
|
|
802
563
|
# Join all parts with double newlines for better readability
|
803
|
-
result = (
|
804
|
-
"\n\n".join(assistant_text_parts)
|
805
|
-
if assistant_text_parts
|
806
|
-
else str(last_message.content)
|
807
|
-
)
|
564
|
+
result = "\n\n".join(assistant_text_parts) if assistant_text_parts else str(last_message.content)
|
808
565
|
|
809
566
|
# Add a note if non-text content was present
|
810
567
|
if has_non_text_content:
|
811
568
|
result += "\n\n[Note: This message contained non-text content that may not be fully represented in text format]"
|
812
569
|
|
813
570
|
return result
|
571
|
+
|
572
|
+
|
573
|
+
#####################################
|
574
|
+
### NEW INTERFACE METHODS BELOW ###
|
575
|
+
#####################################
|