fast-agent-mcp 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (131)
  1. {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/METADATA +1 -1
  2. fast_agent_mcp-0.1.13.dist-info/RECORD +164 -0
  3. mcp_agent/agents/agent.py +37 -102
  4. mcp_agent/app.py +16 -27
  5. mcp_agent/cli/commands/bootstrap.py +22 -52
  6. mcp_agent/cli/commands/config.py +4 -4
  7. mcp_agent/cli/commands/setup.py +11 -26
  8. mcp_agent/cli/main.py +6 -9
  9. mcp_agent/cli/terminal.py +2 -2
  10. mcp_agent/config.py +1 -5
  11. mcp_agent/context.py +13 -26
  12. mcp_agent/context_dependent.py +3 -7
  13. mcp_agent/core/agent_app.py +46 -122
  14. mcp_agent/core/agent_types.py +29 -2
  15. mcp_agent/core/agent_utils.py +3 -5
  16. mcp_agent/core/decorators.py +6 -14
  17. mcp_agent/core/enhanced_prompt.py +25 -52
  18. mcp_agent/core/error_handling.py +1 -1
  19. mcp_agent/core/exceptions.py +8 -8
  20. mcp_agent/core/factory.py +30 -72
  21. mcp_agent/core/fastagent.py +48 -88
  22. mcp_agent/core/mcp_content.py +10 -19
  23. mcp_agent/core/prompt.py +8 -15
  24. mcp_agent/core/proxies.py +34 -25
  25. mcp_agent/core/request_params.py +46 -0
  26. mcp_agent/core/types.py +6 -6
  27. mcp_agent/core/validation.py +16 -16
  28. mcp_agent/executor/decorator_registry.py +11 -23
  29. mcp_agent/executor/executor.py +8 -17
  30. mcp_agent/executor/task_registry.py +2 -4
  31. mcp_agent/executor/temporal.py +28 -74
  32. mcp_agent/executor/workflow.py +3 -5
  33. mcp_agent/executor/workflow_signal.py +17 -29
  34. mcp_agent/human_input/handler.py +4 -9
  35. mcp_agent/human_input/types.py +2 -3
  36. mcp_agent/logging/events.py +1 -5
  37. mcp_agent/logging/json_serializer.py +7 -6
  38. mcp_agent/logging/listeners.py +20 -23
  39. mcp_agent/logging/logger.py +15 -17
  40. mcp_agent/logging/rich_progress.py +10 -8
  41. mcp_agent/logging/tracing.py +4 -6
  42. mcp_agent/logging/transport.py +24 -24
  43. mcp_agent/mcp/gen_client.py +4 -12
  44. mcp_agent/mcp/interfaces.py +107 -88
  45. mcp_agent/mcp/mcp_agent_client_session.py +11 -19
  46. mcp_agent/mcp/mcp_agent_server.py +8 -10
  47. mcp_agent/mcp/mcp_aggregator.py +49 -122
  48. mcp_agent/mcp/mcp_connection_manager.py +16 -37
  49. mcp_agent/mcp/prompt_message_multipart.py +12 -18
  50. mcp_agent/mcp/prompt_serialization.py +13 -38
  51. mcp_agent/mcp/prompts/prompt_load.py +99 -0
  52. mcp_agent/mcp/prompts/prompt_server.py +21 -128
  53. mcp_agent/mcp/prompts/prompt_template.py +20 -42
  54. mcp_agent/mcp/resource_utils.py +8 -17
  55. mcp_agent/mcp/sampling.py +62 -64
  56. mcp_agent/mcp/stdio.py +11 -8
  57. mcp_agent/mcp_server/__init__.py +1 -1
  58. mcp_agent/mcp_server/agent_server.py +10 -17
  59. mcp_agent/mcp_server_registry.py +13 -35
  60. mcp_agent/resources/examples/data-analysis/analysis-campaign.py +1 -1
  61. mcp_agent/resources/examples/data-analysis/analysis.py +1 -1
  62. mcp_agent/resources/examples/data-analysis/slides.py +110 -0
  63. mcp_agent/resources/examples/internal/agent.py +2 -1
  64. mcp_agent/resources/examples/internal/job.py +2 -1
  65. mcp_agent/resources/examples/internal/prompt_category.py +1 -1
  66. mcp_agent/resources/examples/internal/prompt_sizing.py +3 -5
  67. mcp_agent/resources/examples/internal/sizer.py +2 -1
  68. mcp_agent/resources/examples/internal/social.py +2 -1
  69. mcp_agent/resources/examples/mcp_researcher/researcher-eval.py +1 -1
  70. mcp_agent/resources/examples/prompting/__init__.py +1 -1
  71. mcp_agent/resources/examples/prompting/agent.py +2 -1
  72. mcp_agent/resources/examples/prompting/image_server.py +5 -11
  73. mcp_agent/resources/examples/researcher/researcher-eval.py +1 -1
  74. mcp_agent/resources/examples/researcher/researcher-imp.py +3 -4
  75. mcp_agent/resources/examples/researcher/researcher.py +2 -1
  76. mcp_agent/resources/examples/workflows/agent_build.py +2 -1
  77. mcp_agent/resources/examples/workflows/chaining.py +2 -1
  78. mcp_agent/resources/examples/workflows/evaluator.py +2 -1
  79. mcp_agent/resources/examples/workflows/human_input.py +2 -1
  80. mcp_agent/resources/examples/workflows/orchestrator.py +2 -1
  81. mcp_agent/resources/examples/workflows/parallel.py +2 -1
  82. mcp_agent/resources/examples/workflows/router.py +2 -1
  83. mcp_agent/resources/examples/workflows/sse.py +1 -1
  84. mcp_agent/telemetry/usage_tracking.py +2 -1
  85. mcp_agent/ui/console_display.py +17 -41
  86. mcp_agent/workflows/embedding/embedding_base.py +1 -4
  87. mcp_agent/workflows/embedding/embedding_cohere.py +2 -2
  88. mcp_agent/workflows/embedding/embedding_openai.py +4 -13
  89. mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +23 -57
  90. mcp_agent/workflows/intent_classifier/intent_classifier_base.py +5 -8
  91. mcp_agent/workflows/intent_classifier/intent_classifier_embedding.py +7 -11
  92. mcp_agent/workflows/intent_classifier/intent_classifier_embedding_cohere.py +4 -8
  93. mcp_agent/workflows/intent_classifier/intent_classifier_embedding_openai.py +4 -8
  94. mcp_agent/workflows/intent_classifier/intent_classifier_llm.py +11 -22
  95. mcp_agent/workflows/intent_classifier/intent_classifier_llm_anthropic.py +3 -3
  96. mcp_agent/workflows/intent_classifier/intent_classifier_llm_openai.py +4 -6
  97. mcp_agent/workflows/llm/anthropic_utils.py +8 -29
  98. mcp_agent/workflows/llm/augmented_llm.py +94 -332
  99. mcp_agent/workflows/llm/augmented_llm_anthropic.py +43 -76
  100. mcp_agent/workflows/llm/augmented_llm_openai.py +46 -100
  101. mcp_agent/workflows/llm/augmented_llm_passthrough.py +42 -20
  102. mcp_agent/workflows/llm/augmented_llm_playback.py +8 -6
  103. mcp_agent/workflows/llm/memory.py +103 -0
  104. mcp_agent/workflows/llm/model_factory.py +9 -21
  105. mcp_agent/workflows/llm/openai_utils.py +1 -1
  106. mcp_agent/workflows/llm/prompt_utils.py +39 -27
  107. mcp_agent/workflows/llm/providers/multipart_converter_anthropic.py +246 -184
  108. mcp_agent/workflows/llm/providers/multipart_converter_openai.py +212 -202
  109. mcp_agent/workflows/llm/providers/openai_multipart.py +19 -61
  110. mcp_agent/workflows/llm/providers/sampling_converter_anthropic.py +11 -212
  111. mcp_agent/workflows/llm/providers/sampling_converter_openai.py +13 -215
  112. mcp_agent/workflows/llm/sampling_converter.py +117 -0
  113. mcp_agent/workflows/llm/sampling_format_converter.py +12 -29
  114. mcp_agent/workflows/orchestrator/orchestrator.py +24 -67
  115. mcp_agent/workflows/orchestrator/orchestrator_models.py +14 -40
  116. mcp_agent/workflows/parallel/fan_in.py +17 -47
  117. mcp_agent/workflows/parallel/fan_out.py +6 -12
  118. mcp_agent/workflows/parallel/parallel_llm.py +9 -26
  119. mcp_agent/workflows/router/router_base.py +29 -59
  120. mcp_agent/workflows/router/router_embedding.py +11 -25
  121. mcp_agent/workflows/router/router_embedding_cohere.py +2 -2
  122. mcp_agent/workflows/router/router_embedding_openai.py +2 -2
  123. mcp_agent/workflows/router/router_llm.py +12 -28
  124. mcp_agent/workflows/swarm/swarm.py +20 -48
  125. mcp_agent/workflows/swarm/swarm_anthropic.py +2 -2
  126. mcp_agent/workflows/swarm/swarm_openai.py +2 -2
  127. fast_agent_mcp-0.1.11.dist-info/RECORD +0 -160
  128. mcp_agent/workflows/llm/llm_selector.py +0 -345
  129. {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/WHEEL +0 -0
  130. {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/entry_points.txt +0 -0
  131. {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/licenses/LICENSE +0 -0
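The largest single refactor in this release is to mcp_agent/workflows/llm/augmented_llm.py (entry 98 above, reproduced in full below): Memory and SimpleMemory move out to mcp_agent/workflows/llm/memory.py (entry 103), RequestParams moves to mcp_agent/core/request_params.py (entry 25), AugmentedLLMProtocol moves to mcp_agent/mcp/interfaces.py, and the ModelPreferences/ModelSelector machinery is removed outright. For downstream code the practical effect is an import-path change; here is a minimal sketch, assuming the relocated classes keep the fields and signatures visible in the removed 0.1.11 code below (the model string and message text are illustrative):

# 0.1.11: these classes lived in augmented_llm.py
# from mcp_agent.workflows.llm.augmented_llm import RequestParams, SimpleMemory

# 0.1.13: import them from the new dedicated modules
from mcp_agent.core.request_params import RequestParams
from mcp_agent.workflows.llm.memory import SimpleMemory

# RequestParams keeps the generate-time knobs defined in the removed class:
# maxTokens, model, use_history, max_iterations, parallel_tool_calls
params = RequestParams(model="claude-3-5-sonnet-latest", use_history=True, max_iterations=10)

# SimpleMemory still keeps prompt messages separate from conversation history
memory: SimpleMemory[str] = SimpleMemory()
memory.append("You are a release-notes assistant.", is_prompt=True)
memory.append("Hello!")
assert memory.get(include_history=True) == ["You are a release-notes assistant.", "Hello!"]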
mcp_agent/workflows/llm/augmented_llm.py

@@ -1,228 +1,54 @@
 from abc import abstractmethod
-
 from typing import (
-    Generic,
+    TYPE_CHECKING,
+    Any,
     List,
     Optional,
-    Protocol,
     Type,
-    TypeVar,
-    TYPE_CHECKING,
+    cast,
 )
 
-from mcp import CreateMessageResult, SamplingMessage
-from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
-from mcp_agent.workflows.llm.sampling_format_converter import (
-    SamplingFormatConverter,
+from mcp_agent.logging.logger import get_logger
+from mcp_agent.mcp.interfaces import (
+    AugmentedLLMProtocol,
     MessageParamT,
     MessageT,
+    ModelT,
+)
+from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
+from mcp_agent.workflows.llm.sampling_format_converter import (
+    BasicFormatConverter,
+    ProviderFormatConverter,
 )
 
 # Forward reference for type annotations
 if TYPE_CHECKING:
-    from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
     from mcp_agent.agents.agent import Agent
     from mcp_agent.context import Context
+    from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 
 
-from pydantic import Field
-
 from mcp.types import (
     CallToolRequest,
     CallToolResult,
-    CreateMessageRequestParams,
-    ModelPreferences,
+    GetPromptResult,
     PromptMessage,
     TextContent,
-    GetPromptResult,
 )
+from rich.text import Text
 
 from mcp_agent.context_dependent import ContextDependent
-from mcp_agent.core.exceptions import PromptExitError
+from mcp_agent.core.exceptions import ModelConfigError, PromptExitError
+from mcp_agent.core.request_params import RequestParams
 from mcp_agent.event_progress import ProgressAction
 from mcp_agent.mcp.mcp_aggregator import MCPAggregator
-from mcp_agent.workflows.llm.llm_selector import ModelSelector
 from mcp_agent.ui.console_display import ConsoleDisplay
-from rich.text import Text
-
-
-ModelT = TypeVar("ModelT")
-"""A type representing a structured output message from an LLM."""
-
+from mcp_agent.workflows.llm.memory import Memory, SimpleMemory
 
 # TODO -- move this to a constant
 HUMAN_INPUT_TOOL_NAME = "__human_input__"
 
 
-class Memory(Protocol, Generic[MessageParamT]):
-    """
-    Simple memory management for storing past interactions in-memory.
-    """
-
-    # TODO: saqadri - add checkpointing and other advanced memory capabilities
-
-    def __init__(self): ...
-
-    def extend(
-        self, messages: List[MessageParamT], is_prompt: bool = False
-    ) -> None: ...
-
-    def set(self, messages: List[MessageParamT], is_prompt: bool = False) -> None: ...
-
-    def append(self, message: MessageParamT, is_prompt: bool = False) -> None: ...
-
-    def get(self, include_history: bool = True) -> List[MessageParamT]: ...
-
-    def clear(self, clear_prompts: bool = False) -> None: ...
-
-
-class SimpleMemory(Memory, Generic[MessageParamT]):
-    """
-    Simple memory management for storing past interactions in-memory.
-
-    Maintains both prompt messages (which are always included) and
-    generated conversation history (which is included based on use_history setting).
-    """
-
-    def __init__(self):
-        self.history: List[MessageParamT] = []
-        self.prompt_messages: List[MessageParamT] = []  # Always included
-
-    def extend(self, messages: List[MessageParamT], is_prompt: bool = False):
-        """
-        Add multiple messages to history.
-
-        Args:
-            messages: Messages to add
-            is_prompt: If True, add to prompt_messages instead of regular history
-        """
-        if is_prompt:
-            self.prompt_messages.extend(messages)
-        else:
-            self.history.extend(messages)
-
-    def set(self, messages: List[MessageParamT], is_prompt: bool = False):
-        """
-        Replace messages in history.
-
-        Args:
-            messages: Messages to set
-            is_prompt: If True, replace prompt_messages instead of regular history
-        """
-        if is_prompt:
-            self.prompt_messages = messages.copy()
-        else:
-            self.history = messages.copy()
-
-    def append(self, message: MessageParamT, is_prompt: bool = False):
-        """
-        Add a single message to history.
-
-        Args:
-            message: Message to add
-            is_prompt: If True, add to prompt_messages instead of regular history
-        """
-        if is_prompt:
-            self.prompt_messages.append(message)
-        else:
-            self.history.append(message)
-
-    def get(self, include_history: bool = True) -> List[MessageParamT]:
-        """
-        Get all messages in memory.
-
-        Args:
-            include_history: If True, include regular history messages
-                If False, only return prompt messages
-
-        Returns:
-            Combined list of prompt messages and optionally history messages
-        """
-        if include_history:
-            return self.prompt_messages + self.history
-        else:
-            return self.prompt_messages.copy()
-
-    def clear(self, clear_prompts: bool = False):
-        """
-        Clear history and optionally prompt messages.
-
-        Args:
-            clear_prompts: If True, also clear prompt messages
-        """
-        self.history = []
-        if clear_prompts:
-            self.prompt_messages = []
-
-
-class RequestParams(CreateMessageRequestParams):
-    """
-    Parameters to configure the AugmentedLLM 'generate' requests.
-    """
-
-    messages: None = Field(exclude=True, default=None)
-    """
-    Ignored. 'messages' are removed from CreateMessageRequestParams
-    to avoid confusion with the 'message' parameter on 'generate' method.
-    """
-
-    maxTokens: int = 2048
-    """The maximum number of tokens to sample, as requested by the server."""
-
-    model: str | None = None
-    """
-    The model to use for the LLM generation.
-    If specified, this overrides the 'modelPreferences' selection criteria.
-    """
-
-    use_history: bool = True
-    """
-    Include the message history in the generate request.
-    """
-
-    max_iterations: int = 10
-    """
-    The maximum number of iterations to run the LLM for.
-    """
-
-    parallel_tool_calls: bool = True
-    """
-    Whether to allow multiple tool calls per iteration.
-    Also known as multi-step tool use.
-    """
-
-
-class AugmentedLLMProtocol(Protocol, Generic[MessageParamT, MessageT]):
-    """Protocol defining the interface for augmented LLMs"""
-
-    async def generate(
-        self,
-        message: str | MessageParamT | List[MessageParamT],
-        request_params: RequestParams | None = None,
-    ) -> List[MessageT]:
-        """Request an LLM generation, which may run multiple iterations, and return the result"""
-
-    async def generate_str(
-        self,
-        message: str | MessageParamT | List[MessageParamT],
-        request_params: RequestParams | None = None,
-    ) -> str:
-        """Request an LLM generation and return the string representation of the result"""
-
-    async def generate_structured(
-        self,
-        message: str | MessageParamT | List[MessageParamT],
-        response_model: Type[ModelT],
-        request_params: RequestParams | None = None,
-    ) -> ModelT:
-        """Request a structured LLM generation and return the result as a Pydantic model."""
-
-    async def generate_prompt(
-        self, prompt: PromptMessageMultipart, request_params: RequestParams | None
-    ) -> str:
-        """Request an LLM generation and return a string representation of the result"""
-
-
 class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
     """
     The basic building block of agentic systems is an LLM enhanced with augmentations
@@ -231,9 +57,6 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
     selecting appropriate tools, and determining what information to retain.
     """
 
-    # TODO: saqadri - add streaming support (e.g. generate_stream)
-    # TODO: saqadri - consider adding middleware patterns for pre/post processing of messages, for now we have pre/post_tool_call
-
     provider: str | None = None
 
     def __init__(
@@ -243,10 +66,10 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
         instruction: str | None = None,
         name: str | None = None,
         request_params: RequestParams | None = None,
-        type_converter: Type[SamplingFormatConverter[MessageParamT, MessageT]] = None,
+        type_converter: Type[ProviderFormatConverter[MessageParamT, MessageT]] = BasicFormatConverter,
         context: Optional["Context"] = None,
-        **kwargs,
-    ):
+        **kwargs: dict[str, Any],
+    ) -> None:
         """
         Initialize the LLM with a list of server names and an instruction.
         If a name is provided, it will be used to identify the LLM.
@@ -255,44 +78,23 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
         # Extract request_params before super() call
         self._init_request_params = request_params
         super().__init__(context=context, **kwargs)
-
+        self.logger = get_logger(__name__)
         self.executor = self.context.executor
-        self.aggregator = (
-            agent if agent is not None else MCPAggregator(server_names or [])
-        )
+        self.aggregator = agent if agent is not None else MCPAggregator(server_names or [])
         self.name = name or (agent.name if agent else None)
-        self.instruction = instruction or (
-            agent.instruction if agent and isinstance(agent.instruction, str) else None
-        )
+        self.instruction = instruction or (agent.instruction if agent and isinstance(agent.instruction, str) else None)
         self.history: Memory[MessageParamT] = SimpleMemory[MessageParamT]()
 
         # Initialize the display component
         self.display = ConsoleDisplay(config=self.context.config)
 
-        # Set initial model preferences
-        self.model_preferences = ModelPreferences(
-            costPriority=0.3,
-            speedPriority=0.4,
-            intelligencePriority=0.3,
-        )
-
         # Initialize default parameters
         self.default_request_params = self._initialize_default_params(kwargs)
 
-        # Update model preferences from default params
-        if self.default_request_params and self.default_request_params.modelPreferences:
-            self.model_preferences = self.default_request_params.modelPreferences
-
         # Merge with provided params if any
         if self._init_request_params:
-            self.default_request_params = self._merge_request_params(
-                self.default_request_params, self._init_request_params
-            )
-            # Update model preferences again if they changed in the merge
-            if self.default_request_params.modelPreferences:
-                self.model_preferences = self.default_request_params.modelPreferences
+            self.default_request_params = self._merge_request_params(self.default_request_params, self._init_request_params)
 
-        self.model_selector = self.context.model_selector
         self.type_converter = type_converter
         self.verb = kwargs.get("verb")
 
@@ -321,48 +123,26 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
     ) -> ModelT:
         """Request a structured LLM generation and return the result as a Pydantic model."""
 
-    # aysnc def generate2_str(self, prompt: PromptMessageMultipart, request_params: RequestParams | None = None) -> List[MessageT]:
-    #     """Request an LLM generation, which may run multiple iterations, and return the result"""
-    #     return None
-
-    async def select_model(
-        self, request_params: RequestParams | None = None
-    ) -> str | None:
+    async def select_model(self, request_params: RequestParams | None = None) -> str | None:
         """
-        Select an LLM based on the request parameters.
-        If a model is specified in the request, it will override the model selection criteria.
+        Return the configured model (legacy support)
         """
-        model_preferences = self.model_preferences
-        if request_params is not None:
-            model_preferences = request_params.modelPreferences or model_preferences
-            model = request_params.model
-            if model:
-                return model
-
-        ## TODO -- can't have been tested, returns invalid model strings (e.g. claude-35-sonnet)
-        if not self.model_selector:
-            self.model_selector = ModelSelector()
-
-        model_info = self.model_selector.select_best_model(
-            model_preferences=model_preferences, provider=self.provider
-        )
+        if request_params and request_params.model:
+            return request_params.model
 
-        return model_info.name
+        raise ModelConfigError("Internal Error: Model is not configured correctly")
 
     def _initialize_default_params(self, kwargs: dict) -> RequestParams:
         """Initialize default parameters for the LLM.
         Should be overridden by provider implementations to set provider-specific defaults."""
         return RequestParams(
-            modelPreferences=self.model_preferences,
             systemPrompt=self.instruction,
            parallel_tool_calls=True,
             max_iterations=10,
             use_history=True,
         )
 
-    def _merge_request_params(
-        self, default_params: RequestParams, provided_params: RequestParams
-    ) -> RequestParams:
+    def _merge_request_params(self, default_params: RequestParams, provided_params: RequestParams) -> RequestParams:
         """Merge default and provided request parameters"""
 
         merged = default_params.model_dump()
@@ -395,32 +175,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
 
         return default_request_params
 
-    def to_mcp_message_result(self, result: MessageT) -> CreateMessageResult:
-        """Convert an LLM response to an MCP message result type."""
-        return self.type_converter.to_sampling_result(result)
-
-    def from_mcp_message_result(self, result: CreateMessageResult) -> MessageT:
-        """Convert an MCP message result to an LLM response type."""
-        return self.type_converter.from_sampling_result(result)
-
-    def to_mcp_message_param(self, param: MessageParamT) -> SamplingMessage:
-        """Convert an LLM input to an MCP message (SamplingMessage) type."""
-        return self.type_converter.to_sampling_message(param)
-
-    def from_mcp_message_param(self, param: SamplingMessage) -> MessageParamT:
-        """Convert an MCP message (SamplingMessage) to an LLM input type."""
-        return self.type_converter.from_sampling_message(param)
-
-    def from_mcp_prompt_message(self, message: PromptMessage) -> MessageParamT:
-        return self.type_converter.from_prompt_message(message)
-
     @classmethod
-    def convert_message_to_message_param(
-        cls, message: MessageT, **kwargs
-    ) -> MessageParamT:
+    def convert_message_to_message_param(cls, message: MessageT, **kwargs: dict[str, Any]) -> MessageParamT:
         """Convert a response object to an input parameter object to allow LLM calls to be chained."""
         # Many LLM implementations will allow the same type for input and output messages
-        return message
+        return cast("MessageParamT", message)
 
     async def get_last_message(self) -> MessageParamT | None:
         """
@@ -435,15 +194,15 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
         last_message = await self.get_last_message()
         return self.message_param_str(last_message) if last_message else None
 
-    def show_tool_result(self, result: CallToolResult):
+    def show_tool_result(self, result: CallToolResult) -> None:
         """Display a tool result in a formatted panel."""
         self.display.show_tool_result(result)
 
-    def show_oai_tool_result(self, result):
+    def show_oai_tool_result(self, result: str) -> None:
         """Display a tool result in a formatted panel."""
         self.display.show_oai_tool_result(result)
 
-    def show_tool_call(self, available_tools, tool_name, tool_args):
+    def show_tool_call(self, available_tools, tool_name, tool_args) -> None:
         """Display a tool call in a formatted panel."""
         self.display.show_tool_call(available_tools, tool_name, tool_args)
 
@@ -452,7 +211,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
         message_text: str | Text,
         highlight_namespaced_tool: str = "",
         title: str = "ASSISTANT",
-    ):
+    ) -> None:
         """Display an assistant message in a formatted panel."""
         await self.display.show_assistant_message(
             message_text,
@@ -462,19 +221,15 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
             name=self.name,
         )
 
-    def show_user_message(self, message, model: str | None, chat_turn: int):
+    def show_user_message(self, message, model: str | None, chat_turn: int) -> None:
         """Display a user message in a formatted panel."""
         self.display.show_user_message(message, model, chat_turn, name=self.name)
 
-    async def pre_tool_call(
-        self, tool_call_id: str | None, request: CallToolRequest
-    ) -> CallToolRequest | bool:
+    async def pre_tool_call(self, tool_call_id: str | None, request: CallToolRequest) -> CallToolRequest | bool:
         """Called before a tool is executed. Return False to prevent execution."""
         return request
 
-    async def post_tool_call(
-        self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult
-    ) -> CallToolResult:
+    async def post_tool_call(self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult) -> CallToolResult:
         """Called after a tool execution. Can modify the result before it's returned."""
         return result
 
@@ -497,7 +252,8 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
                 isError=True,
                 content=[
                     TextContent(
-                        text=f"Error: Tool '{request.params.name}' was not allowed to run."
+                        type="text",
+                        text=f"Error: Tool '{request.params.name}' was not allowed to run.",
                     )
                 ],
             )
@@ -508,9 +264,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
         tool_args = request.params.arguments
         result = await self.aggregator.call_tool(tool_name, tool_args)
 
-        postprocess = await self.post_tool_call(
-            tool_call_id=tool_call_id, request=request, result=result
-        )
+        postprocess = await self.post_tool_call(tool_call_id=tool_call_id, request=request, result=result)
 
         if isinstance(postprocess, CallToolResult):
             result = postprocess
@@ -548,13 +302,13 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
                 if isinstance(part, dict) and "text" in part:
                     text_parts.append(part["text"])
                 elif hasattr(part, "text"):
-                    text_parts.append(part.text)
+                    text_parts.append(part.text)  # type: ignore
             if text_parts:
                 return "\n".join(text_parts)
 
         # For objects with content attribute
         if hasattr(message, "content"):
-            content = message.content
+            content = message.content  # type: ignore
             if isinstance(content, str):
                 return content
             elif hasattr(content, "text"):
@@ -569,13 +323,13 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
         Tries to extract just the content when possible.
         """
         # First try to use the same method for consistency
-        result = self.message_param_str(message)
+        result = self.message_param_str(message)  # type: ignore
         if result != str(message):
             return result
 
         # Additional handling for output-specific formats
         if hasattr(message, "content"):
-            content = message.content
+            content = getattr(message, "content")
             if isinstance(content, list):
                 # Extract text from content blocks
                 text_parts = []
@@ -588,9 +342,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
         # Default fallback
         return str(message)
 
-    def _log_chat_progress(
-        self, chat_turn: Optional[int] = None, model: Optional[str] = None
-    ):
+    def _log_chat_progress(self, chat_turn: Optional[int] = None, model: Optional[str] = None) -> None:
         """Log a chat progress event"""
         # Determine action type based on verb
         if hasattr(self, "verb") and self.verb:
@@ -607,7 +359,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
         }
         self.logger.debug("Chat in progress", data=data)
 
-    def _log_chat_finished(self, model: Optional[str] = None):
+    def _log_chat_finished(self, model: Optional[str] = None) -> None:
         """Log a chat finished event"""
         data = {
             "progress_action": ProgressAction.READY,
@@ -616,9 +368,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
         }
         self.logger.debug("Chat finished", data=data)
 
-    def _convert_prompt_messages(
-        self, prompt_messages: List[PromptMessage]
-    ) -> List[MessageParamT]:
+    def _convert_prompt_messages(self, prompt_messages: List[PromptMessage]) -> List[MessageParamT]:
         """
         Convert prompt messages to this LLM's specific message format.
         To be implemented by concrete LLM classes.
@@ -631,7 +381,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
         description: Optional[str] = None,
         message_count: int = 0,
         arguments: Optional[dict[str, str]] = None,
-    ):
+    ) -> None:
         """
         Display information about a loaded prompt template.
 
@@ -650,9 +400,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
             arguments=arguments,
         )
 
-    async def apply_prompt_template(
-        self, prompt_result: GetPromptResult, prompt_name: str
-    ) -> str:
+    async def apply_prompt_template(self, prompt_result: GetPromptResult, prompt_name: str) -> str:
         """
         Apply a prompt template by adding it to the conversation history.
         If the last message in the prompt is from a user, automatically
@@ -684,15 +432,35 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
         )
 
         # Convert to PromptMessageMultipart objects
-        multipart_messages = PromptMessageMultipart.parse_get_prompt_result(
-            prompt_result
-        )
+        multipart_messages = PromptMessageMultipart.parse_get_prompt_result(prompt_result)
+
+        # Delegate to the provider-specific implementation
+        return await self._apply_prompt_template_provider_specific(multipart_messages, None)
 
+    async def apply_prompt(
+        self,
+        multipart_messages: List["PromptMessageMultipart"],
+        request_params: RequestParams | None = None,
+    ) -> str:
+        """
+        Apply a list of PromptMessageMultipart messages directly to the LLM.
+        This is a cleaner interface to _apply_prompt_template_provider_specific.
+
+        Args:
+            multipart_messages: List of PromptMessageMultipart objects
+            request_params: Optional parameters to configure the LLM request
+
+        Returns:
+            String representation of the assistant's response
+        """
         # Delegate to the provider-specific implementation
-        return await self._apply_prompt_template_provider_specific(multipart_messages)
+        return await self._apply_prompt_template_provider_specific(multipart_messages, request_params)
 
+    # this shouln't need to be very big...
     async def _apply_prompt_template_provider_specific(
-        self, multipart_messages: List["PromptMessageMultipart"]
+        self,
+        multipart_messages: List["PromptMessageMultipart"],
+        request_params: RequestParams | None = None,
     ) -> str:
         """
         Provider-specific implementation of apply_prompt_template.
@@ -712,9 +480,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
 
         if last_message.role == "user":
             # For user messages: Add all previous messages to history, then generate response to the last one
-            self.logger.debug(
-                "Last message in prompt is from user, generating assistant response"
-            )
+            self.logger.debug("Last message in prompt is from user, generating assistant response")
 
             # Add all but the last message to history
             if len(multipart_messages) > 1:
@@ -724,11 +490,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
                 # Fallback generic method for all LLM types
                 for msg in previous_messages:
                     # Convert each PromptMessageMultipart to individual PromptMessages
-                    prompt_messages = msg.to_prompt_messages()
+                    prompt_messages = msg.from_multipart()
                     for prompt_msg in prompt_messages:
-                        converted.append(
-                            self.type_converter.from_prompt_message(prompt_msg)
-                        )
+                        converted.append(self.type_converter.from_prompt_message(prompt_msg))
 
                 self.history.extend(converted, is_prompt=True)
 
@@ -737,8 +501,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
             for content in last_message.content:
                 if content.type == "text":
                     user_text_parts.append(content.text)
-                elif content.type == "resource" and hasattr(content.resource, "text"):
-                    user_text_parts.append(content.resource.text)
+                elif content.type == "resource" and getattr(content, "resource", None) is not None:
+                    if hasattr(content.resource, "text"):
+                        user_text_parts.append(content.resource.text)  # type: ignore
                 elif content.type == "image":
                     # Add a placeholder for images
                     mime_type = getattr(content, "mimeType", "image/unknown")
@@ -752,9 +517,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
             return await self.generate_str(user_text)
         else:
             # For assistant messages: Add all messages to history and return the last one
-            self.logger.debug(
-                "Last message in prompt is from assistant, returning it directly"
-            )
+            self.logger.debug("Last message in prompt is from assistant, returning it directly")
 
             # Convert and add all messages to history
             converted = []
@@ -762,11 +525,9 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
             # Fallback to the original method for all LLM types
             for msg in multipart_messages:
                 # Convert each PromptMessageMultipart to individual PromptMessages
-                prompt_messages = msg.to_prompt_messages()
+                prompt_messages = msg.from_multipart()
                 for prompt_msg in prompt_messages:
-                    converted.append(
-                        self.type_converter.from_prompt_message(prompt_msg)
-                    )
+                    converted.append(self.type_converter.from_prompt_message(prompt_msg))
 
             self.history.extend(converted, is_prompt=True)
 
@@ -783,11 +544,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
                         uri = getattr(content.resource, "uri", "")
                         if uri:
                             assistant_text_parts.append(
-                                f"[Resource: {uri}, Type: {mime_type}]\n{content.resource.text}"
+                                f"[Resource: {uri}, Type: {mime_type}]\n{content.resource.text}"  # ignore # type: ignore
                             )
                         else:
                             assistant_text_parts.append(
-                                f"[Resource Type: {mime_type}]\n{content.resource.text}"
+                                f"[Resource Type: {mime_type}]\n{content.resource.text}"  # type ignore # type: ignore
                             )
                 elif content.type == "image":
                     # Note the presence of images
@@ -800,14 +561,15 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
                     has_non_text_content = True
 
         # Join all parts with double newlines for better readability
-        result = (
-            "\n\n".join(assistant_text_parts)
-            if assistant_text_parts
-            else str(last_message.content)
-        )
+        result = "\n\n".join(assistant_text_parts) if assistant_text_parts else str(last_message.content)
 
         # Add a note if non-text content was present
        if has_non_text_content:
             result += "\n\n[Note: This message contained non-text content that may not be fully represented in text format]"
 
         return result
+
+
+    #####################################
+    ### NEW INTERFACE METHODS BELOW ###
+    #####################################
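The closing banner marks where the new multipart-prompt interface begins, and the headline addition above is apply_prompt, which feeds PromptMessageMultipart messages straight to the provider-specific template handling with optional RequestParams. A minimal usage sketch, assuming llm is any concrete AugmentedLLM subclass and that PromptMessageMultipart accepts role/content keyword arguments as a Pydantic model (the prompt text is illustrative):

from mcp.types import TextContent

from mcp_agent.core.request_params import RequestParams
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart


async def demo(llm) -> None:
    # One multipart user message with a single text part.
    message = PromptMessageMultipart(
        role="user",
        content=[TextContent(type="text", text="Summarize this release.")],
    )
    # apply_prompt threads request_params through to
    # _apply_prompt_template_provider_specific, per the diff above.
    reply: str = await llm.apply_prompt(
        [message],
        request_params=RequestParams(max_iterations=3),
    )
    print(reply)

Note also that select_model now raises ModelConfigError instead of falling back to the removed ModelSelector, so a model must be configured (via the factory or RequestParams.model) before generation.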