fast-agent-mcp 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.1.13.dist-info}/METADATA +1 -1
- fast_agent_mcp-0.1.13.dist-info/RECORD +164 -0
- mcp_agent/agents/agent.py +37 -79
- mcp_agent/app.py +16 -22
- mcp_agent/cli/commands/bootstrap.py +22 -52
- mcp_agent/cli/commands/config.py +4 -4
- mcp_agent/cli/commands/setup.py +11 -26
- mcp_agent/cli/main.py +6 -9
- mcp_agent/cli/terminal.py +2 -2
- mcp_agent/config.py +1 -5
- mcp_agent/context.py +13 -24
- mcp_agent/context_dependent.py +3 -7
- mcp_agent/core/agent_app.py +45 -121
- mcp_agent/core/agent_utils.py +3 -5
- mcp_agent/core/decorators.py +5 -12
- mcp_agent/core/enhanced_prompt.py +25 -52
- mcp_agent/core/exceptions.py +8 -8
- mcp_agent/core/factory.py +29 -70
- mcp_agent/core/fastagent.py +48 -88
- mcp_agent/core/mcp_content.py +8 -16
- mcp_agent/core/prompt.py +8 -15
- mcp_agent/core/proxies.py +34 -25
- mcp_agent/core/request_params.py +6 -3
- mcp_agent/core/types.py +4 -6
- mcp_agent/core/validation.py +4 -3
- mcp_agent/executor/decorator_registry.py +11 -23
- mcp_agent/executor/executor.py +8 -17
- mcp_agent/executor/task_registry.py +2 -4
- mcp_agent/executor/temporal.py +28 -74
- mcp_agent/executor/workflow.py +3 -5
- mcp_agent/executor/workflow_signal.py +17 -29
- mcp_agent/human_input/handler.py +4 -9
- mcp_agent/human_input/types.py +2 -3
- mcp_agent/logging/events.py +1 -5
- mcp_agent/logging/json_serializer.py +7 -6
- mcp_agent/logging/listeners.py +20 -23
- mcp_agent/logging/logger.py +15 -17
- mcp_agent/logging/rich_progress.py +10 -8
- mcp_agent/logging/tracing.py +4 -6
- mcp_agent/logging/transport.py +22 -22
- mcp_agent/mcp/gen_client.py +4 -12
- mcp_agent/mcp/interfaces.py +71 -86
- mcp_agent/mcp/mcp_agent_client_session.py +11 -19
- mcp_agent/mcp/mcp_agent_server.py +8 -10
- mcp_agent/mcp/mcp_aggregator.py +45 -117
- mcp_agent/mcp/mcp_connection_manager.py +16 -37
- mcp_agent/mcp/prompt_message_multipart.py +12 -18
- mcp_agent/mcp/prompt_serialization.py +13 -38
- mcp_agent/mcp/prompts/prompt_load.py +99 -0
- mcp_agent/mcp/prompts/prompt_server.py +21 -128
- mcp_agent/mcp/prompts/prompt_template.py +20 -42
- mcp_agent/mcp/resource_utils.py +8 -17
- mcp_agent/mcp/sampling.py +5 -14
- mcp_agent/mcp/stdio.py +11 -8
- mcp_agent/mcp_server/agent_server.py +10 -17
- mcp_agent/mcp_server_registry.py +13 -35
- mcp_agent/resources/examples/data-analysis/analysis-campaign.py +1 -1
- mcp_agent/resources/examples/data-analysis/analysis.py +1 -1
- mcp_agent/resources/examples/data-analysis/slides.py +110 -0
- mcp_agent/resources/examples/internal/agent.py +2 -1
- mcp_agent/resources/examples/internal/job.py +2 -1
- mcp_agent/resources/examples/internal/prompt_category.py +1 -1
- mcp_agent/resources/examples/internal/prompt_sizing.py +3 -5
- mcp_agent/resources/examples/internal/sizer.py +2 -1
- mcp_agent/resources/examples/internal/social.py +2 -1
- mcp_agent/resources/examples/mcp_researcher/researcher-eval.py +1 -1
- mcp_agent/resources/examples/prompting/agent.py +2 -1
- mcp_agent/resources/examples/prompting/image_server.py +5 -11
- mcp_agent/resources/examples/researcher/researcher-eval.py +1 -1
- mcp_agent/resources/examples/researcher/researcher-imp.py +3 -4
- mcp_agent/resources/examples/researcher/researcher.py +2 -1
- mcp_agent/resources/examples/workflows/agent_build.py +2 -1
- mcp_agent/resources/examples/workflows/chaining.py +2 -1
- mcp_agent/resources/examples/workflows/evaluator.py +2 -1
- mcp_agent/resources/examples/workflows/human_input.py +2 -1
- mcp_agent/resources/examples/workflows/orchestrator.py +2 -1
- mcp_agent/resources/examples/workflows/parallel.py +2 -1
- mcp_agent/resources/examples/workflows/router.py +2 -1
- mcp_agent/resources/examples/workflows/sse.py +1 -1
- mcp_agent/telemetry/usage_tracking.py +2 -1
- mcp_agent/ui/console_display.py +15 -39
- mcp_agent/workflows/embedding/embedding_base.py +1 -4
- mcp_agent/workflows/embedding/embedding_cohere.py +2 -2
- mcp_agent/workflows/embedding/embedding_openai.py +4 -13
- mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +23 -57
- mcp_agent/workflows/intent_classifier/intent_classifier_base.py +5 -8
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding.py +7 -11
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding_cohere.py +4 -8
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding_openai.py +4 -8
- mcp_agent/workflows/intent_classifier/intent_classifier_llm.py +11 -22
- mcp_agent/workflows/intent_classifier/intent_classifier_llm_anthropic.py +3 -3
- mcp_agent/workflows/intent_classifier/intent_classifier_llm_openai.py +4 -6
- mcp_agent/workflows/llm/anthropic_utils.py +8 -29
- mcp_agent/workflows/llm/augmented_llm.py +69 -247
- mcp_agent/workflows/llm/augmented_llm_anthropic.py +39 -73
- mcp_agent/workflows/llm/augmented_llm_openai.py +42 -97
- mcp_agent/workflows/llm/augmented_llm_passthrough.py +13 -20
- mcp_agent/workflows/llm/augmented_llm_playback.py +8 -6
- mcp_agent/workflows/llm/memory.py +103 -0
- mcp_agent/workflows/llm/model_factory.py +8 -20
- mcp_agent/workflows/llm/openai_utils.py +1 -1
- mcp_agent/workflows/llm/prompt_utils.py +1 -3
- mcp_agent/workflows/llm/providers/multipart_converter_anthropic.py +47 -89
- mcp_agent/workflows/llm/providers/multipart_converter_openai.py +20 -55
- mcp_agent/workflows/llm/providers/openai_multipart.py +19 -61
- mcp_agent/workflows/llm/providers/sampling_converter_anthropic.py +10 -12
- mcp_agent/workflows/llm/providers/sampling_converter_openai.py +7 -11
- mcp_agent/workflows/llm/sampling_converter.py +4 -11
- mcp_agent/workflows/llm/sampling_format_converter.py +12 -12
- mcp_agent/workflows/orchestrator/orchestrator.py +24 -67
- mcp_agent/workflows/orchestrator/orchestrator_models.py +14 -40
- mcp_agent/workflows/parallel/fan_in.py +17 -47
- mcp_agent/workflows/parallel/fan_out.py +6 -12
- mcp_agent/workflows/parallel/parallel_llm.py +9 -26
- mcp_agent/workflows/router/router_base.py +19 -49
- mcp_agent/workflows/router/router_embedding.py +11 -25
- mcp_agent/workflows/router/router_embedding_cohere.py +2 -2
- mcp_agent/workflows/router/router_embedding_openai.py +2 -2
- mcp_agent/workflows/router/router_llm.py +12 -28
- mcp_agent/workflows/swarm/swarm.py +20 -48
- mcp_agent/workflows/swarm/swarm_anthropic.py +2 -2
- mcp_agent/workflows/swarm/swarm_openai.py +2 -2
- fast_agent_mcp-0.1.12.dist-info/RECORD +0 -161
- {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.1.13.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.1.13.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.1.12.dist-info → fast_agent_mcp-0.1.13.dist-info}/licenses/LICENSE +0 -0
mcp_agent/workflows/llm/augmented_llm_anthropic.py

```diff
@@ -1,5 +1,5 @@
 import os
-from typing import List, Type
+from typing import TYPE_CHECKING, List, Type
 
 from mcp_agent.workflows.llm.providers.multipart_converter_anthropic import (
     AnthropicConverter,
@@ -9,6 +9,8 @@ from mcp_agent.workflows.llm.providers.sampling_converter_anthropic import (
 )
 
 if TYPE_CHECKING:
+    from mcp import ListToolsResult
+
     from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 
 
@@ -20,22 +22,22 @@ from anthropic.types import (
     TextBlockParam,
     ToolParam,
     ToolUseBlockParam,
+    Usage,
 )
 from mcp.types import (
-    CallToolRequestParams,
     CallToolRequest,
+    CallToolRequestParams,
 )
 from pydantic_core import from_json
+from rich.text import Text
 
+from mcp_agent.core.exceptions import ProviderKeyError
+from mcp_agent.logging.logger import get_logger
 from mcp_agent.workflows.llm.augmented_llm import (
     AugmentedLLM,
     ModelT,
     RequestParams,
 )
-from mcp_agent.core.exceptions import ProviderKeyError
-from rich.text import Text
-
-from mcp_agent.logging.logger import get_logger
 
 DEFAULT_ANTHROPIC_MODEL = "claude-3-7-sonnet-latest"
 
@@ -48,7 +50,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
     selecting appropriate tools, and determining what information to retain.
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         self.provider = "Anthropic"
         # Initialize logger - keep it simple without name reference
         self.logger = get_logger(__name__)
@@ -85,8 +87,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         except AuthenticationError as e:
             raise ProviderKeyError(
                 "Invalid Anthropic API key",
-                "The configured Anthropic API key was rejected.\n"
-                "Please check that your API key is valid and not expired.",
+                "The configured Anthropic API key was rejected.\nPlease check that your API key is valid and not expired.",
             ) from e
 
         # Always include prompt messages, but only include conversation history
@@ -100,14 +101,14 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
             else:
                 messages.append(message)
 
-
+        tool_list: ListToolsResult = await self.aggregator.list_tools()
         available_tools: List[ToolParam] = [
-            ToolParam(
-                name=tool.name,
-                description=tool.description or "",
-                input_schema=tool.inputSchema,
-            )
-            for tool in
+            ToolParam(
+                name=tool.name,
+                description=tool.description or "",
+                input_schema=tool.inputSchema,
+            )
+            for tool in tool_list.tools
         ]
 
         responses: List[Message] = []
@@ -134,17 +135,14 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
             self.logger.debug(f"{arguments}")
 
-            executor_result = await self.executor.execute(
-                anthropic.messages.create, **arguments
-            )
+            executor_result = await self.executor.execute(anthropic.messages.create, **arguments)
 
             response = executor_result[0]
 
             if isinstance(response, AuthenticationError):
                 raise ProviderKeyError(
                     "Invalid Anthropic API key",
-                    "The configured Anthropic API key was rejected.\n"
-                    "Please check that your API key is valid and not expired.",
+                    "The configured Anthropic API key was rejected.\nPlease check that your API key is valid and not expired.",
                 ) from response
             elif isinstance(response, BaseException):
                 error_details = str(response)
@@ -154,13 +152,9 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
             if hasattr(response, "status_code") and hasattr(response, "response"):
                 try:
                     error_json = response.response.json()
-                    error_details = (
-                        f"Error code: {response.status_code} - {error_json}"
-                    )
+                    error_details = f"Error code: {response.status_code} - {error_json}"
                 except:  # noqa: E722
-                    error_details = (
-                        f"Error code: {response.status_code} - {str(response)}"
-                    )
+                    error_details = f"Error code: {response.status_code} - {str(response)}"
 
             # Convert other errors to text response
             error_message = f"Error during generation: {error_details}"
@@ -171,7 +165,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
                 type="message",
                 content=[TextBlock(type="text", text=error_message)],
                 stop_reason="end_turn",  # Must be one of the allowed values
-                usage=
+                usage=Usage(input_tokens=0, output_tokens=0),  # Required field
             )
 
             self.logger.debug(
@@ -193,22 +187,16 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
                 await self.show_assistant_message(message_text)
 
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'end_turn'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'end_turn'")
                 break
             elif response.stop_reason == "stop_sequence":
                 # We have reached a stop sequence
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'stop_sequence'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'stop_sequence'")
                 break
            elif response.stop_reason == "max_tokens":
                 # We have reached the max tokens limit
 
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'max_tokens'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'max_tokens'")
                 if params.maxTokens is not None:
                     message_text = Text(
                         f"the assistant has reached the maximum token limit ({params.maxTokens})",
@@ -255,22 +243,16 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
                     self.show_tool_call(available_tools, tool_name, tool_args)
                     tool_call_request = CallToolRequest(
                         method="tools/call",
-                        params=CallToolRequestParams(
-                            name=tool_name, arguments=tool_args
-                        ),
+                        params=CallToolRequestParams(name=tool_name, arguments=tool_args),
                     )
                     # TODO -- support MCP isError etc.
-                    result = await self.call_tool(
-                        request=tool_call_request, tool_call_id=tool_use_id
-                    )
+                    result = await self.call_tool(request=tool_call_request, tool_call_id=tool_use_id)
                     self.show_tool_result(result)
 
                     # Add each result to our collection
                     tool_results.append((tool_use_id, result))
 
-                messages.append(
-                    AnthropicConverter.create_tool_results_message(tool_results)
-                )
+                messages.append(AnthropicConverter.create_tool_results_message(tool_results))
 
         # Only save the new conversation messages to history if use_history is true
         # Keep the prompt messages separate
@@ -351,12 +333,8 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         # Join all collected text
         return "\n".join(final_text)
 
-    async def generate_prompt(
-        self, prompt: "PromptMessageMultipart", request_params: RequestParams | None
-    ) -> str:
-        return await self.generate_str(
-            AnthropicConverter.convert_to_anthropic(prompt), request_params
-        )
+    async def generate_prompt(self, prompt: "PromptMessageMultipart", request_params: RequestParams | None) -> str:
+        return await self.generate_str(AnthropicConverter.convert_to_anthropic(prompt), request_params)
 
     async def _apply_prompt_template_provider_specific(
         self,
@@ -378,11 +356,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         last_message = multipart_messages[-1]
 
         # Add all previous messages to history (or all messages if last is from assistant)
-        messages_to_add = (
-            multipart_messages[:-1]
-            if last_message.role == "user"
-            else multipart_messages
-        )
+        messages_to_add = multipart_messages[:-1] if last_message.role == "user" else multipart_messages
         converted = []
         for msg in messages_to_add:
             converted.append(AnthropicConverter.convert_to_anthropic(msg))
@@ -390,16 +364,12 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
         if last_message.role == "user":
             # For user messages: Generate response to the last one
-            self.logger.debug(
-                "Last message in prompt is from user, generating assistant response"
-            )
+            self.logger.debug("Last message in prompt is from user, generating assistant response")
             message_param = AnthropicConverter.convert_to_anthropic(last_message)
             return await self.generate_str(message_param, request_params)
         else:
             # For assistant messages: Return the last message content as text
-            self.logger.debug(
-                "Last message in prompt is from assistant, returning it directly"
-            )
+            self.logger.debug("Last message in prompt is from assistant, returning it directly")
             return str(last_message)
 
     async def _save_history_to_file(self, command: str) -> str:
@@ -424,19 +394,17 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         messages = self.history.get(include_history=True)
 
         # Import required utilities
-        from mcp_agent.workflows.llm.anthropic_utils import (
-            anthropic_message_param_to_prompt_message_multipart,
-        )
         from mcp_agent.mcp.prompt_serialization import (
             multipart_messages_to_delimited_format,
         )
+        from mcp_agent.workflows.llm.anthropic_utils import (
+            anthropic_message_param_to_prompt_message_multipart,
+        )
 
         # Convert message params to PromptMessageMultipart objects
         multipart_messages = []
         for msg in messages:
-            multipart_messages.append(
-                anthropic_message_param_to_prompt_message_multipart(msg)
-            )
+            multipart_messages.append(anthropic_message_param_to_prompt_message_multipart(msg))
 
         # Convert to delimited format
         delimited_content = multipart_messages_to_delimited_format(
@@ -458,7 +426,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
     async def generate_structured(
         self,
-        message,
+        message: str,
         response_model: Type[ModelT],
         request_params: RequestParams | None = None,
     ) -> ModelT:
@@ -475,9 +443,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         return response_model.model_validate(from_json(response, allow_partial=True))
 
     @classmethod
-    def convert_message_to_message_param(
-        cls, message: Message, **kwargs
-    ) -> MessageParam:
+    def convert_message_to_message_param(cls, message: Message, **kwargs) -> MessageParam:
         """Convert a response object to an input parameter object to allow LLM calls to be chained."""
         content = []
 
```
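The most substantive change in this file is the tool-listing block in the generate loop: the aggregator's `list_tools()` result is now captured with an explicit `ListToolsResult` annotation before being mapped onto Anthropic `ToolParam` entries. Below is a minimal, self-contained sketch of that conversion pattern; `StubAggregator` is a hypothetical stand-in for the package's real `MCPAggregator`, and the sample tool is invented for illustration.

```python
import asyncio
from typing import List

from anthropic.types import ToolParam
from mcp import ListToolsResult
from mcp.types import Tool


class StubAggregator:
    """Hypothetical stand-in for mcp_agent.mcp.mcp_aggregator.MCPAggregator."""

    async def list_tools(self) -> ListToolsResult:
        # Return one invented tool so the mapping below has something to convert.
        return ListToolsResult(
            tools=[
                Tool(
                    name="echo",
                    description="Echo the input back to the caller",
                    inputSchema={"type": "object", "properties": {"text": {"type": "string"}}},
                )
            ]
        )


async def main() -> None:
    # Fetch the MCP tool list once, then map each tool onto an Anthropic
    # ToolParam, mirroring the reformatted generate loop above.
    tool_list: ListToolsResult = await StubAggregator().list_tools()
    available_tools: List[ToolParam] = [
        ToolParam(
            name=tool.name,
            description=tool.description or "",
            input_schema=tool.inputSchema,
        )
        for tool in tool_list.tools
    ]
    print(available_tools)


asyncio.run(main())
```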
mcp_agent/workflows/llm/augmented_llm_openai.py

```diff
@@ -1,5 +1,5 @@
 import os
-from typing import List, Type
+from typing import TYPE_CHECKING, List, Type
 
 from pydantic_core import from_json
 
@@ -10,30 +10,30 @@ from mcp_agent.workflows.llm.providers.sampling_converter_openai import (
 
 if TYPE_CHECKING:
     from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
-from openai import AuthenticationError, OpenAI
+from mcp.types import (
+    CallToolRequest,
+    CallToolRequestParams,
+    CallToolResult,
+)
+from openai import AuthenticationError, OpenAI
 
 # from openai.types.beta.chat import
 from openai.types.chat import (
-    ChatCompletionMessageParam,
     ChatCompletionMessage,
+    ChatCompletionMessageParam,
     ChatCompletionSystemMessageParam,
     ChatCompletionToolParam,
     ChatCompletionUserMessageParam,
 )
-from mcp.types import (
-    CallToolRequestParams,
-    CallToolRequest,
-    CallToolResult,
-)
+from rich.text import Text
 
+from mcp_agent.core.exceptions import ProviderKeyError
+from mcp_agent.logging.logger import get_logger
 from mcp_agent.workflows.llm.augmented_llm import (
     AugmentedLLM,
     ModelT,
     RequestParams,
 )
-from mcp_agent.core.exceptions import ProviderKeyError
-from mcp_agent.logging.logger import get_logger
-from rich.text import Text
 
 _logger = get_logger(__name__)
 
@@ -41,16 +41,14 @@ DEFAULT_OPENAI_MODEL = "gpt-4o"
 DEFAULT_REASONING_EFFORT = "medium"
 
 
-class OpenAIAugmentedLLM(
-    AugmentedLLM[ChatCompletionMessageParam, ChatCompletionMessage]
-):
+class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletionMessage]):
     """
     The basic building block of agentic systems is an LLM enhanced with augmentations
     such as retrieval, tools, and memory provided from a collection of MCP servers.
     This implementation uses OpenAI's ChatCompletion as the LLM.
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         # Set type_converter before calling super().__init__
         if "type_converter" not in kwargs:
             kwargs["type_converter"] = OpenAISamplingConverter
@@ -64,22 +62,14 @@ class OpenAIAugmentedLLM(
         # Set up reasoning-related attributes
         self._reasoning_effort = kwargs.get("reasoning_effort", None)
         if self.context and self.context.config and self.context.config.openai:
-            if self._reasoning_effort is None and hasattr(
-                self.context.config.openai, "reasoning_effort"
-            ):
+            if self._reasoning_effort is None and hasattr(self.context.config.openai, "reasoning_effort"):
                 self._reasoning_effort = self.context.config.openai.reasoning_effort
 
         # Determine if we're using a reasoning model
-        chosen_model = (
-            self.default_request_params.model if self.default_request_params else None
-        )
-        self._reasoning = chosen_model and (
-            chosen_model.startswith("o3") or chosen_model.startswith("o1")
-        )
+        chosen_model = self.default_request_params.model if self.default_request_params else None
+        self._reasoning = chosen_model and (chosen_model.startswith("o3") or chosen_model.startswith("o1"))
         if self._reasoning:
-            self.logger.info(
-                f"Using reasoning model '{chosen_model}' with '{self._reasoning_effort}' reasoning effort"
-            )
+            self.logger.info(f"Using reasoning model '{chosen_model}' with '{self._reasoning_effort}' reasoning effort")
 
     def _initialize_default_params(self, kwargs: dict) -> RequestParams:
         """Initialize OpenAI-specific default parameters"""
@@ -120,9 +110,7 @@ class OpenAIAugmentedLLM(
         return api_key
 
     def _base_url(self) -> str:
-        return (
-            self.context.config.openai.base_url if self.context.config.openai else None
-        )
+        return self.context.config.openai.base_url if self.context.config.openai else None
 
     async def generate(
         self,
@@ -143,24 +131,19 @@ class OpenAIAugmentedLLM(
         except AuthenticationError as e:
             raise ProviderKeyError(
                 "Invalid OpenAI API key",
-                "The configured OpenAI API key was rejected.\n"
-                "Please check that your API key is valid and not expired.",
+                "The configured OpenAI API key was rejected.\n" "Please check that your API key is valid and not expired.",
             ) from e
 
         system_prompt = self.instruction or params.systemPrompt
         if system_prompt:
-            messages.append(
-                ChatCompletionSystemMessageParam(role="system", content=system_prompt)
-            )
+            messages.append(ChatCompletionSystemMessageParam(role="system", content=system_prompt))
 
         # Always include prompt messages, but only include conversation history
         # if use_history is True
         messages.extend(self.history.get(include_history=params.use_history))
 
         if isinstance(message, str):
-            messages.append(
-                ChatCompletionUserMessageParam(role="user", content=message)
-            )
+            messages.append(ChatCompletionUserMessageParam(role="user", content=message))
         elif isinstance(message, list):
             messages.extend(message)
         else:
@@ -186,9 +169,7 @@ class OpenAIAugmentedLLM(
         model = await self.select_model(params)
         chat_turn = len(messages) // 2
         if self._reasoning:
-            self.show_user_message(
-                str(message), f"{model} ({self._reasoning_effort})", chat_turn
-            )
+            self.show_user_message(str(message), f"{model} ({self._reasoning_effort})", chat_turn)
         else:
             self.show_user_message(str(message), model, chat_turn)
 
@@ -217,9 +198,7 @@ class OpenAIAugmentedLLM(
             self._log_chat_progress(chat_turn, model=model)
 
             if response_model is None:
-                executor_result = await self.executor.execute(
-                    openai_client.chat.completions.create, **arguments
-                )
+                executor_result = await self.executor.execute(openai_client.chat.completions.create, **arguments)
             else:
                 executor_result = await self.executor.execute(
                     openai_client.beta.chat.completions.parse,
@@ -237,8 +216,7 @@ class OpenAIAugmentedLLM(
             if isinstance(response, AuthenticationError):
                 raise ProviderKeyError(
                     "Invalid OpenAI API key",
-                    "The configured OpenAI API key was rejected.\n"
-                    "Please check that your API key is valid and not expired.",
+                    "The configured OpenAI API key was rejected.\n" "Please check that your API key is valid and not expired.",
                 ) from response
             elif isinstance(response, BaseException):
                 self.logger.error(f"Error: {response}")
@@ -254,21 +232,14 @@ class OpenAIAugmentedLLM(
             message = choice.message
             responses.append(message)
 
-            converted_message = self.convert_message_to_message_param(
-                message, name=self.name
-            )
+            converted_message = self.convert_message_to_message_param(message, name=self.name)
             messages.append(converted_message)
             message_text = converted_message.content
-            if (
-                choice.finish_reason in ["tool_calls", "function_call"]
-                and message.tool_calls
-            ):
+            if choice.finish_reason in ["tool_calls", "function_call"] and message.tool_calls:
                 if message_text:
                     await self.show_assistant_message(
                         message_text,
-                        message.tool_calls[
-                            0
-                        ].function.name,  # TODO support displaying multiple tool calls
+                        message.tool_calls[0].function.name,  # TODO support displaying multiple tool calls
                     )
                 else:
                     await self.show_assistant_message(
@@ -290,9 +261,7 @@ class OpenAIAugmentedLLM(
                         method="tools/call",
                         params=CallToolRequestParams(
                             name=tool_call.function.name,
-                            arguments=from_json(
-                                tool_call.function.arguments, allow_partial=True
-                            ),
+                            arguments=from_json(tool_call.function.arguments, allow_partial=True),
                         ),
                     )
                     result = await self.call_tool(tool_call_request, tool_call.id)
@@ -300,18 +269,12 @@ class OpenAIAugmentedLLM(
 
                     tool_results.append((tool_call.id, result))
 
-                messages.extend(
-                    OpenAIConverter.convert_function_results_to_openai(tool_results)
-                )
+                messages.extend(OpenAIConverter.convert_function_results_to_openai(tool_results))
 
-                self.logger.debug(
-                    f"Iteration {i}: Tool call results: {str(tool_results) if tool_results else 'None'}"
-                )
+                self.logger.debug(f"Iteration {i}: Tool call results: {str(tool_results) if tool_results else 'None'}")
             elif choice.finish_reason == "length":
                 # We have reached the max tokens limit
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'length'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'length'")
                 if request_params and request_params.maxTokens is not None:
                     message_text = Text(
                         f"the assistant has reached the maximum token limit ({request_params.maxTokens})",
@@ -328,15 +291,11 @@ class OpenAIAugmentedLLM(
                 break
             elif choice.finish_reason == "content_filter":
                 # The response was filtered by the content filter
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'content_filter'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'content_filter'")
                 # TODO: saqadri - would be useful to return the reason for stopping to the caller
                 break
             elif choice.finish_reason == "stop":
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'stop'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'stop'")
                 if message_text:
                     await self.show_assistant_message(message_text, "")
                 break
@@ -416,11 +375,7 @@ class OpenAIAugmentedLLM(
         last_message = multipart_messages[-1]
 
         # Add all previous messages to history (or all messages if last is from assistant)
-        messages_to_add = (
-            multipart_messages[:-1]
-            if last_message.role == "user"
-            else multipart_messages
-        )
+        messages_to_add = multipart_messages[:-1] if last_message.role == "user" else multipart_messages
         converted = []
         for msg in messages_to_add:
             converted.append(OpenAIConverter.convert_to_openai(msg))
@@ -428,16 +383,12 @@ class OpenAIAugmentedLLM(
 
         if last_message.role == "user":
             # For user messages: Generate response to the last one
-            self.logger.debug(
-                "Last message in prompt is from user, generating assistant response"
-            )
+            self.logger.debug("Last message in prompt is from user, generating assistant response")
             message_param = OpenAIConverter.convert_to_openai(last_message)
             return await self.generate_str(message_param, request_params)
         else:
             # For assistant messages: Return the last message content as text
-            self.logger.debug(
-                "Last message in prompt is from assistant, returning it directly"
-            )
+            self.logger.debug("Last message in prompt is from assistant, returning it directly")
             return str(last_message)
 
     async def _save_history_to_file(self, command: str) -> str:
@@ -462,12 +413,12 @@ class OpenAIAugmentedLLM(
         messages = self.history.get(include_history=True)
 
         # Import required utilities
-        from mcp_agent.workflows.llm.openai_utils import (
-            openai_message_param_to_prompt_message_multipart,
-        )
         from mcp_agent.mcp.prompt_serialization import (
             multipart_messages_to_delimited_format,
         )
+        from mcp_agent.workflows.llm.openai_utils import (
+            openai_message_param_to_prompt_message_multipart,
+        )
 
         # Convert message params to PromptMessageMultipart objects
         multipart_messages = []
@@ -477,9 +428,7 @@ class OpenAIAugmentedLLM(
                 continue
 
             # Convert the message to a multipart message
-            multipart_messages.append(
-                openai_message_param_to_prompt_message_multipart(msg)
-            )
+            multipart_messages.append(openai_message_param_to_prompt_message_multipart(msg))
 
         # Convert to delimited format
         delimited_content = multipart_messages_to_delimited_format(
@@ -512,18 +461,14 @@ class OpenAIAugmentedLLM(
         )
         return responses[0].parsed
 
-    async def generate_prompt(
-        self, prompt: "PromptMessageMultipart", request_params: RequestParams | None
-    ) -> str:
+    async def generate_prompt(self, prompt: "PromptMessageMultipart", request_params: RequestParams | None) -> str:
         converted_prompt = OpenAIConverter.convert_to_openai(prompt)
         return await self.generate_str(converted_prompt, request_params)
 
     async def pre_tool_call(self, tool_call_id: str | None, request: CallToolRequest):
         return request
 
-    async def post_tool_call(
-        self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult
-    ):
+    async def post_tool_call(self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult):
         return result
 
     def message_param_str(self, message: ChatCompletionMessageParam) -> str:
```
mcp_agent/workflows/llm/augmented_llm_passthrough.py

```diff
@@ -1,8 +1,11 @@
-from typing import Any, List, Optional, Type, Union
 import json  # Import at the module level
+from typing import Any, List, Optional, Type, Union
+
 from mcp import GetPromptResult
 from mcp.types import PromptMessage
 from pydantic_core import from_json
+
+from mcp_agent.logging.logger import get_logger
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 from mcp_agent.workflows.llm.augmented_llm import (
     AugmentedLLM,
@@ -11,7 +14,6 @@ from mcp_agent.workflows.llm.augmented_llm import (
     ModelT,
     RequestParams,
 )
-from mcp_agent.logging.logger import get_logger
 
 
 class PassthroughLLM(AugmentedLLM):
@@ -23,7 +25,7 @@ class PassthroughLLM(AugmentedLLM):
     parallel workflow where no fan-in aggregation is needed.
     """
 
-    def __init__(self, name: str = "Passthrough", context=None, **kwargs):
+    def __init__(self, name: str = "Passthrough", context=None, **kwargs) -> None:
         super().__init__(name=name, context=context, **kwargs)
         self.provider = "fast-agent"
         # Initialize logger - keep it simple without name reference
@@ -61,6 +63,9 @@ class PassthroughLLM(AugmentedLLM):
 
         return str(message)
 
+    async def initialize(self) -> None:
+        pass
+
     async def _call_tool_and_return_result(self, command: str) -> str:
         """
         Call a tool based on the command and return its result as a string.
@@ -94,9 +99,7 @@ class PassthroughLLM(AugmentedLLM):
         """
         parts = command.split(" ", 2)
         if len(parts) < 2:
-            raise ValueError(
-                "Invalid format. Expected '***CALL_TOOL <tool_name> [arguments_json]'"
-            )
+            raise ValueError("Invalid format. Expected '***CALL_TOOL <tool_name> [arguments_json]'")
 
         tool_name = parts[1].strip()
         arguments = None
@@ -158,15 +161,9 @@ class PassthroughLLM(AugmentedLLM):
         elif isinstance(message, str):
             return response_model.model_validate(from_json(message, allow_partial=True))
 
-    async def generate_prompt(
-        self, prompt: "PromptMessageMultipart", request_params: RequestParams | None
-    ) -> str:
+    async def generate_prompt(self, prompt: "PromptMessageMultipart", request_params: RequestParams | None) -> str:
         # Check if this prompt contains a tool call command
-        if (
-            prompt.content
-            and prompt.content[0].text
-            and prompt.content[0].text.startswith("***CALL_TOOL ")
-        ):
+        if prompt.content and prompt.content[0].text and prompt.content[0].text.startswith("***CALL_TOOL "):
             return await self._call_tool_and_return_result(prompt.content[0].text)
 
         # Process all parts of the PromptMessageMultipart
@@ -204,9 +201,7 @@ class PassthroughLLM(AugmentedLLM):
 
         return result
 
-    async def apply_prompt_template(
-        self, prompt_result: GetPromptResult, prompt_name: str
-    ) -> str:
+    async def apply_prompt_template(self, prompt_result: GetPromptResult, prompt_name: str) -> str:
         """
         Apply a prompt template by adding it to the conversation history.
         For PassthroughLLM, this returns all content concatenated together.
@@ -233,9 +228,7 @@ class PassthroughLLM(AugmentedLLM):
         self._messages = prompt_messages
 
         # Convert prompt messages to multipart format
-        multipart_messages = PromptMessageMultipart.to_multipart(
-            prompt_messages
-        )
+        multipart_messages = PromptMessageMultipart.to_multipart(prompt_messages)
 
         # Use apply_prompt to handle the multipart messages
         return await self.apply_prompt(multipart_messages)
```
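For reference, `PassthroughLLM` recognizes a `***CALL_TOOL <tool_name> [arguments_json]` command string, as the validation in `_call_tool_and_return_result` above shows. A hedged sketch of just the parsing step follows; the real method goes on to dispatch the tool call (elided here), and treating the optional third token as a JSON object is an assumption based on the `arguments_json` naming.

```python
import json


def parse_call_tool(command: str) -> tuple[str, dict | None]:
    # Same split-and-validate step as _call_tool_and_return_result above.
    parts = command.split(" ", 2)
    if len(parts) < 2:
        raise ValueError("Invalid format. Expected '***CALL_TOOL <tool_name> [arguments_json]'")
    tool_name = parts[1].strip()
    # Assumption: the optional third token is a JSON object of tool arguments.
    arguments = json.loads(parts[2]) if len(parts) > 2 else None
    return tool_name, arguments


print(parse_call_tool('***CALL_TOOL echo {"text": "hi"}'))  # ('echo', {'text': 'hi'})
```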