fast-agent-mcp 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/METADATA +1 -1
- fast_agent_mcp-0.1.13.dist-info/RECORD +164 -0
- mcp_agent/agents/agent.py +37 -102
- mcp_agent/app.py +16 -27
- mcp_agent/cli/commands/bootstrap.py +22 -52
- mcp_agent/cli/commands/config.py +4 -4
- mcp_agent/cli/commands/setup.py +11 -26
- mcp_agent/cli/main.py +6 -9
- mcp_agent/cli/terminal.py +2 -2
- mcp_agent/config.py +1 -5
- mcp_agent/context.py +13 -26
- mcp_agent/context_dependent.py +3 -7
- mcp_agent/core/agent_app.py +46 -122
- mcp_agent/core/agent_types.py +29 -2
- mcp_agent/core/agent_utils.py +3 -5
- mcp_agent/core/decorators.py +6 -14
- mcp_agent/core/enhanced_prompt.py +25 -52
- mcp_agent/core/error_handling.py +1 -1
- mcp_agent/core/exceptions.py +8 -8
- mcp_agent/core/factory.py +30 -72
- mcp_agent/core/fastagent.py +48 -88
- mcp_agent/core/mcp_content.py +10 -19
- mcp_agent/core/prompt.py +8 -15
- mcp_agent/core/proxies.py +34 -25
- mcp_agent/core/request_params.py +46 -0
- mcp_agent/core/types.py +6 -6
- mcp_agent/core/validation.py +16 -16
- mcp_agent/executor/decorator_registry.py +11 -23
- mcp_agent/executor/executor.py +8 -17
- mcp_agent/executor/task_registry.py +2 -4
- mcp_agent/executor/temporal.py +28 -74
- mcp_agent/executor/workflow.py +3 -5
- mcp_agent/executor/workflow_signal.py +17 -29
- mcp_agent/human_input/handler.py +4 -9
- mcp_agent/human_input/types.py +2 -3
- mcp_agent/logging/events.py +1 -5
- mcp_agent/logging/json_serializer.py +7 -6
- mcp_agent/logging/listeners.py +20 -23
- mcp_agent/logging/logger.py +15 -17
- mcp_agent/logging/rich_progress.py +10 -8
- mcp_agent/logging/tracing.py +4 -6
- mcp_agent/logging/transport.py +24 -24
- mcp_agent/mcp/gen_client.py +4 -12
- mcp_agent/mcp/interfaces.py +107 -88
- mcp_agent/mcp/mcp_agent_client_session.py +11 -19
- mcp_agent/mcp/mcp_agent_server.py +8 -10
- mcp_agent/mcp/mcp_aggregator.py +49 -122
- mcp_agent/mcp/mcp_connection_manager.py +16 -37
- mcp_agent/mcp/prompt_message_multipart.py +12 -18
- mcp_agent/mcp/prompt_serialization.py +13 -38
- mcp_agent/mcp/prompts/prompt_load.py +99 -0
- mcp_agent/mcp/prompts/prompt_server.py +21 -128
- mcp_agent/mcp/prompts/prompt_template.py +20 -42
- mcp_agent/mcp/resource_utils.py +8 -17
- mcp_agent/mcp/sampling.py +62 -64
- mcp_agent/mcp/stdio.py +11 -8
- mcp_agent/mcp_server/__init__.py +1 -1
- mcp_agent/mcp_server/agent_server.py +10 -17
- mcp_agent/mcp_server_registry.py +13 -35
- mcp_agent/resources/examples/data-analysis/analysis-campaign.py +1 -1
- mcp_agent/resources/examples/data-analysis/analysis.py +1 -1
- mcp_agent/resources/examples/data-analysis/slides.py +110 -0
- mcp_agent/resources/examples/internal/agent.py +2 -1
- mcp_agent/resources/examples/internal/job.py +2 -1
- mcp_agent/resources/examples/internal/prompt_category.py +1 -1
- mcp_agent/resources/examples/internal/prompt_sizing.py +3 -5
- mcp_agent/resources/examples/internal/sizer.py +2 -1
- mcp_agent/resources/examples/internal/social.py +2 -1
- mcp_agent/resources/examples/mcp_researcher/researcher-eval.py +1 -1
- mcp_agent/resources/examples/prompting/__init__.py +1 -1
- mcp_agent/resources/examples/prompting/agent.py +2 -1
- mcp_agent/resources/examples/prompting/image_server.py +5 -11
- mcp_agent/resources/examples/researcher/researcher-eval.py +1 -1
- mcp_agent/resources/examples/researcher/researcher-imp.py +3 -4
- mcp_agent/resources/examples/researcher/researcher.py +2 -1
- mcp_agent/resources/examples/workflows/agent_build.py +2 -1
- mcp_agent/resources/examples/workflows/chaining.py +2 -1
- mcp_agent/resources/examples/workflows/evaluator.py +2 -1
- mcp_agent/resources/examples/workflows/human_input.py +2 -1
- mcp_agent/resources/examples/workflows/orchestrator.py +2 -1
- mcp_agent/resources/examples/workflows/parallel.py +2 -1
- mcp_agent/resources/examples/workflows/router.py +2 -1
- mcp_agent/resources/examples/workflows/sse.py +1 -1
- mcp_agent/telemetry/usage_tracking.py +2 -1
- mcp_agent/ui/console_display.py +17 -41
- mcp_agent/workflows/embedding/embedding_base.py +1 -4
- mcp_agent/workflows/embedding/embedding_cohere.py +2 -2
- mcp_agent/workflows/embedding/embedding_openai.py +4 -13
- mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +23 -57
- mcp_agent/workflows/intent_classifier/intent_classifier_base.py +5 -8
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding.py +7 -11
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding_cohere.py +4 -8
- mcp_agent/workflows/intent_classifier/intent_classifier_embedding_openai.py +4 -8
- mcp_agent/workflows/intent_classifier/intent_classifier_llm.py +11 -22
- mcp_agent/workflows/intent_classifier/intent_classifier_llm_anthropic.py +3 -3
- mcp_agent/workflows/intent_classifier/intent_classifier_llm_openai.py +4 -6
- mcp_agent/workflows/llm/anthropic_utils.py +8 -29
- mcp_agent/workflows/llm/augmented_llm.py +94 -332
- mcp_agent/workflows/llm/augmented_llm_anthropic.py +43 -76
- mcp_agent/workflows/llm/augmented_llm_openai.py +46 -100
- mcp_agent/workflows/llm/augmented_llm_passthrough.py +42 -20
- mcp_agent/workflows/llm/augmented_llm_playback.py +8 -6
- mcp_agent/workflows/llm/memory.py +103 -0
- mcp_agent/workflows/llm/model_factory.py +9 -21
- mcp_agent/workflows/llm/openai_utils.py +1 -1
- mcp_agent/workflows/llm/prompt_utils.py +39 -27
- mcp_agent/workflows/llm/providers/multipart_converter_anthropic.py +246 -184
- mcp_agent/workflows/llm/providers/multipart_converter_openai.py +212 -202
- mcp_agent/workflows/llm/providers/openai_multipart.py +19 -61
- mcp_agent/workflows/llm/providers/sampling_converter_anthropic.py +11 -212
- mcp_agent/workflows/llm/providers/sampling_converter_openai.py +13 -215
- mcp_agent/workflows/llm/sampling_converter.py +117 -0
- mcp_agent/workflows/llm/sampling_format_converter.py +12 -29
- mcp_agent/workflows/orchestrator/orchestrator.py +24 -67
- mcp_agent/workflows/orchestrator/orchestrator_models.py +14 -40
- mcp_agent/workflows/parallel/fan_in.py +17 -47
- mcp_agent/workflows/parallel/fan_out.py +6 -12
- mcp_agent/workflows/parallel/parallel_llm.py +9 -26
- mcp_agent/workflows/router/router_base.py +29 -59
- mcp_agent/workflows/router/router_embedding.py +11 -25
- mcp_agent/workflows/router/router_embedding_cohere.py +2 -2
- mcp_agent/workflows/router/router_embedding_openai.py +2 -2
- mcp_agent/workflows/router/router_llm.py +12 -28
- mcp_agent/workflows/swarm/swarm.py +20 -48
- mcp_agent/workflows/swarm/swarm_anthropic.py +2 -2
- mcp_agent/workflows/swarm/swarm_openai.py +2 -2
- fast_agent_mcp-0.1.11.dist-info/RECORD +0 -160
- mcp_agent/workflows/llm/llm_selector.py +0 -345
- {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/licenses/LICENSE +0 -0
mcp_agent/workflows/llm/augmented_llm_anthropic.py

@@ -1,5 +1,5 @@
 import os
-from typing import List, Type
+from typing import TYPE_CHECKING, List, Type
 
 from mcp_agent.workflows.llm.providers.multipart_converter_anthropic import (
     AnthropicConverter,
@@ -9,6 +9,8 @@ from mcp_agent.workflows.llm.providers.sampling_converter_anthropic import (
 )
 
 if TYPE_CHECKING:
+    from mcp import ListToolsResult
+
     from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 
 
@@ -20,22 +22,22 @@ from anthropic.types import (
     TextBlockParam,
     ToolParam,
     ToolUseBlockParam,
+    Usage,
 )
 from mcp.types import (
-    CallToolRequestParams,
     CallToolRequest,
+    CallToolRequestParams,
 )
 from pydantic_core import from_json
+from rich.text import Text
 
+from mcp_agent.core.exceptions import ProviderKeyError
+from mcp_agent.logging.logger import get_logger
 from mcp_agent.workflows.llm.augmented_llm import (
     AugmentedLLM,
     ModelT,
     RequestParams,
 )
-from mcp_agent.core.exceptions import ProviderKeyError
-from rich.text import Text
-
-from mcp_agent.logging.logger import get_logger
 
 DEFAULT_ANTHROPIC_MODEL = "claude-3-7-sonnet-latest"
 
@@ -48,7 +50,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
     selecting appropriate tools, and determining what information to retain.
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         self.provider = "Anthropic"
         # Initialize logger - keep it simple without name reference
         self.logger = get_logger(__name__)
@@ -60,7 +62,6 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
        """Initialize Anthropic-specific default parameters"""
         return RequestParams(
             model=kwargs.get("model", DEFAULT_ANTHROPIC_MODEL),
-            modelPreferences=self.model_preferences,
             maxTokens=4096,  # default haiku3
             systemPrompt=self.instruction,
             parallel_tool_calls=True,
@@ -86,8 +87,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         except AuthenticationError as e:
             raise ProviderKeyError(
                 "Invalid Anthropic API key",
-                "The configured Anthropic API key was rejected.\n"
-                "Please check that your API key is valid and not expired.",
+                "The configured Anthropic API key was rejected.\nPlease check that your API key is valid and not expired.",
             ) from e
 
         # Always include prompt messages, but only include conversation history
@@ -101,14 +101,14 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
             else:
                 messages.append(message)
 
-        response = await self.aggregator.list_tools()
+        tool_list: ListToolsResult = await self.aggregator.list_tools()
         available_tools: List[ToolParam] = [
-            ToolParam(
-                name=tool.name,
-                description=tool.description,
-                input_schema=tool.inputSchema,
-            )
-            for tool in response.tools
+            ToolParam(
+                name=tool.name,
+                description=tool.description or "",
+                input_schema=tool.inputSchema,
+            )
+            for tool in tool_list.tools
         ]
 
         responses: List[Message] = []
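The hunk above builds Anthropic ToolParam entries straight from the MCP list_tools result, defaulting a missing description to an empty string. A standalone sketch of the same mapping, with an invented tool definition (real definitions come from aggregator.list_tools()):

    from anthropic.types import ToolParam

    # Invented example tool, for illustration only.
    weather_tool = ToolParam(
        name="get_weather",
        description="Return the current weather for a city",
        input_schema={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    )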
@@ -135,17 +135,14 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
             self.logger.debug(f"{arguments}")
 
-            executor_result = await self.executor.execute(
-                anthropic.messages.create, **arguments
-            )
+            executor_result = await self.executor.execute(anthropic.messages.create, **arguments)
 
             response = executor_result[0]
 
             if isinstance(response, AuthenticationError):
                 raise ProviderKeyError(
                     "Invalid Anthropic API key",
-                    "The configured Anthropic API key was rejected.\n"
-                    "Please check that your API key is valid and not expired.",
+                    "The configured Anthropic API key was rejected.\nPlease check that your API key is valid and not expired.",
                 ) from response
             elif isinstance(response, BaseException):
                 error_details = str(response)
@@ -155,13 +152,9 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
                 if hasattr(response, "status_code") and hasattr(response, "response"):
                     try:
                         error_json = response.response.json()
-                        error_details = (
-                            f"Error code: {response.status_code} - {error_json}"
-                        )
+                        error_details = f"Error code: {response.status_code} - {error_json}"
                     except:  # noqa: E722
-                        error_details = (
-                            f"Error code: {response.status_code} - {str(response)}"
-                        )
+                        error_details = f"Error code: {response.status_code} - {str(response)}"
 
                 # Convert other errors to text response
                 error_message = f"Error during generation: {error_details}"
@@ -172,7 +165,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
                     type="message",
                     content=[TextBlock(type="text", text=error_message)],
                     stop_reason="end_turn",  # Must be one of the allowed values
-                    usage={"input_tokens": 0, "output_tokens": 0},  # Required field
+                    usage=Usage(input_tokens=0, output_tokens=0),  # Required field
                 )
 
                 self.logger.debug(
@@ -194,22 +187,16 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
                     await self.show_assistant_message(message_text)
 
-                    self.logger.debug(
-                        f"Iteration {i}: Stopping because finish_reason is 'end_turn'"
-                    )
+                    self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'end_turn'")
                     break
                 elif response.stop_reason == "stop_sequence":
                     # We have reached a stop sequence
-                    self.logger.debug(
-                        f"Iteration {i}: Stopping because finish_reason is 'stop_sequence'"
-                    )
+                    self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'stop_sequence'")
                     break
                 elif response.stop_reason == "max_tokens":
                     # We have reached the max tokens limit
 
-                    self.logger.debug(
-                        f"Iteration {i}: Stopping because finish_reason is 'max_tokens'"
-                    )
+                    self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'max_tokens'")
                     if params.maxTokens is not None:
                         message_text = Text(
                             f"the assistant has reached the maximum token limit ({params.maxTokens})",
@@ -256,22 +243,16 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
                     self.show_tool_call(available_tools, tool_name, tool_args)
                     tool_call_request = CallToolRequest(
                         method="tools/call",
-                        params=CallToolRequestParams(
-                            name=tool_name, arguments=tool_args
-                        ),
+                        params=CallToolRequestParams(name=tool_name, arguments=tool_args),
                     )
                     # TODO -- support MCP isError etc.
-                    result = await self.call_tool(
-                        request=tool_call_request, tool_call_id=tool_use_id
-                    )
+                    result = await self.call_tool(request=tool_call_request, tool_call_id=tool_use_id)
                     self.show_tool_result(result)
 
                     # Add each result to our collection
                     tool_results.append((tool_use_id, result))
 
-            messages.append(
-                AnthropicConverter.create_tool_results_message(tool_results)
-            )
+            messages.append(AnthropicConverter.create_tool_results_message(tool_results))
 
             # Only save the new conversation messages to history if use_history is true
             # Keep the prompt messages separate
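Each Anthropic tool_use block is dispatched as an MCP tools/call request, as the hunk above shows. The same construction, standalone, with invented values (the real name and arguments come from the tool_use block):

    from mcp.types import CallToolRequest, CallToolRequestParams

    request = CallToolRequest(
        method="tools/call",
        params=CallToolRequestParams(name="get_weather", arguments={"city": "Paris"}),
    )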
@@ -352,15 +333,13 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         # Join all collected text
         return "\n".join(final_text)
 
-    async def generate_prompt(
-        self, prompt: "PromptMessageMultipart", request_params: RequestParams | None
-    ) -> str:
-        return await self.generate_str(
-            AnthropicConverter.convert_to_anthropic(prompt), request_params
-        )
+    async def generate_prompt(self, prompt: "PromptMessageMultipart", request_params: RequestParams | None) -> str:
+        return await self.generate_str(AnthropicConverter.convert_to_anthropic(prompt), request_params)
 
     async def _apply_prompt_template_provider_specific(
-        self, multipart_messages: List["PromptMessageMultipart"]
+        self,
+        multipart_messages: List["PromptMessageMultipart"],
+        request_params: RequestParams | None = None,
     ) -> str:
         """
         Anthropic-specific implementation of apply_prompt_template that handles
@@ -377,11 +356,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         last_message = multipart_messages[-1]
 
         # Add all previous messages to history (or all messages if last is from assistant)
-        messages_to_add = (
-            multipart_messages[:-1]
-            if last_message.role == "user"
-            else multipart_messages
-        )
+        messages_to_add = multipart_messages[:-1] if last_message.role == "user" else multipart_messages
         converted = []
         for msg in messages_to_add:
             converted.append(AnthropicConverter.convert_to_anthropic(msg))
@@ -389,16 +364,12 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
         if last_message.role == "user":
             # For user messages: Generate response to the last one
-            self.logger.debug(
-                "Last message in prompt is from user, generating assistant response"
-            )
+            self.logger.debug("Last message in prompt is from user, generating assistant response")
             message_param = AnthropicConverter.convert_to_anthropic(last_message)
-            return await self.generate_str(message_param)
+            return await self.generate_str(message_param, request_params)
         else:
             # For assistant messages: Return the last message content as text
-            self.logger.debug(
-                "Last message in prompt is from assistant, returning it directly"
-            )
+            self.logger.debug("Last message in prompt is from assistant, returning it directly")
             return str(last_message)
 
     async def _save_history_to_file(self, command: str) -> str:
@@ -423,19 +394,17 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         messages = self.history.get(include_history=True)
 
         # Import required utilities
-        from mcp_agent.workflows.llm.anthropic_utils import (
-            anthropic_message_param_to_prompt_message_multipart,
-        )
         from mcp_agent.mcp.prompt_serialization import (
             multipart_messages_to_delimited_format,
         )
+        from mcp_agent.workflows.llm.anthropic_utils import (
+            anthropic_message_param_to_prompt_message_multipart,
+        )
 
         # Convert message params to PromptMessageMultipart objects
         multipart_messages = []
         for msg in messages:
-            multipart_messages.append(
-                anthropic_message_param_to_prompt_message_multipart(msg)
-            )
+            multipart_messages.append(anthropic_message_param_to_prompt_message_multipart(msg))
 
         # Convert to delimited format
         delimited_content = multipart_messages_to_delimited_format(
@@ -457,7 +426,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
     async def generate_structured(
         self,
-        message,
+        message: str,
         response_model: Type[ModelT],
         request_params: RequestParams | None = None,
     ) -> ModelT:
@@ -474,9 +443,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         return response_model.model_validate(from_json(response, allow_partial=True))
 
     @classmethod
-    def convert_message_to_message_param(
-        cls, message: Message, **kwargs
-    ) -> MessageParam:
+    def convert_message_to_message_param(cls, message: Message, **kwargs) -> MessageParam:
         """Convert a response object to an input parameter object to allow LLM calls to be chained."""
         content = []
 
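The generate_prompt signature above now takes the multipart prompt and request params explicitly; the OpenAI diff below makes the same change. A usage sketch (the llm and prompt_message objects are assumed for illustration, not constructed anywhere in this diff):

    # Assumes an initialized provider LLM (`llm`) and a
    # PromptMessageMultipart (`prompt_message`).
    result_text = await llm.generate_prompt(prompt_message, None)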
mcp_agent/workflows/llm/augmented_llm_openai.py

@@ -1,5 +1,5 @@
 import os
-from typing import List, Type
+from typing import TYPE_CHECKING, List, Type
 
 from pydantic_core import from_json
 
@@ -10,30 +10,30 @@ from mcp_agent.workflows.llm.providers.sampling_converter_openai import (
 
 if TYPE_CHECKING:
     from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
-from openai import AuthenticationError, OpenAI
+from mcp.types import (
+    CallToolRequest,
+    CallToolRequestParams,
+    CallToolResult,
+)
+from openai import AuthenticationError, OpenAI
 
 # from openai.types.beta.chat import
 from openai.types.chat import (
-    ChatCompletionMessageParam,
     ChatCompletionMessage,
+    ChatCompletionMessageParam,
     ChatCompletionSystemMessageParam,
     ChatCompletionToolParam,
     ChatCompletionUserMessageParam,
 )
-from mcp.types import (
-    CallToolRequestParams,
-    CallToolRequest,
-    CallToolResult,
-)
+from rich.text import Text
 
+from mcp_agent.core.exceptions import ProviderKeyError
+from mcp_agent.logging.logger import get_logger
 from mcp_agent.workflows.llm.augmented_llm import (
     AugmentedLLM,
     ModelT,
     RequestParams,
 )
-from mcp_agent.core.exceptions import ProviderKeyError
-from mcp_agent.logging.logger import get_logger
-from rich.text import Text
 
 _logger = get_logger(__name__)
 
@@ -41,16 +41,14 @@ DEFAULT_OPENAI_MODEL = "gpt-4o"
 DEFAULT_REASONING_EFFORT = "medium"
 
 
-class OpenAIAugmentedLLM(
-    AugmentedLLM[ChatCompletionMessageParam, ChatCompletionMessage]
-):
+class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletionMessage]):
     """
     The basic building block of agentic systems is an LLM enhanced with augmentations
     such as retrieval, tools, and memory provided from a collection of MCP servers.
     This implementation uses OpenAI's ChatCompletion as the LLM.
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         # Set type_converter before calling super().__init__
         if "type_converter" not in kwargs:
             kwargs["type_converter"] = OpenAISamplingConverter
@@ -64,22 +62,14 @@ class OpenAIAugmentedLLM(
         # Set up reasoning-related attributes
         self._reasoning_effort = kwargs.get("reasoning_effort", None)
         if self.context and self.context.config and self.context.config.openai:
-            if self._reasoning_effort is None and hasattr(
-                self.context.config.openai, "reasoning_effort"
-            ):
+            if self._reasoning_effort is None and hasattr(self.context.config.openai, "reasoning_effort"):
                 self._reasoning_effort = self.context.config.openai.reasoning_effort
 
         # Determine if we're using a reasoning model
-        chosen_model = (
-            self.default_request_params.model if self.default_request_params else None
-        )
-        self._reasoning = chosen_model and (
-            chosen_model.startswith("o3") or chosen_model.startswith("o1")
-        )
+        chosen_model = self.default_request_params.model if self.default_request_params else None
+        self._reasoning = chosen_model and (chosen_model.startswith("o3") or chosen_model.startswith("o1"))
         if self._reasoning:
-            self.logger.info(
-                f"Using reasoning model '{chosen_model}' with '{self._reasoning_effort}' reasoning effort"
-            )
+            self.logger.info(f"Using reasoning model '{chosen_model}' with '{self._reasoning_effort}' reasoning effort")
 
     def _initialize_default_params(self, kwargs: dict) -> RequestParams:
         """Initialize OpenAI-specific default parameters"""
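The constructor above treats o1/o3 model names as reasoning models and logs the configured reasoning effort. The prefix test, as a standalone sketch (model names are just examples):

    def is_reasoning_model(model: str | None) -> bool:
        # Mirrors the check in __init__: o1/o3 families get reasoning_effort.
        return bool(model and (model.startswith("o3") or model.startswith("o1")))

    assert is_reasoning_model("o3-mini")
    assert is_reasoning_model("o1")
    assert not is_reasoning_model("gpt-4o")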
@@ -92,7 +82,6 @@ class OpenAIAugmentedLLM(
 
         return RequestParams(
             model=chosen_model,
-            modelPreferences=self.model_preferences,
             systemPrompt=self.instruction,
             parallel_tool_calls=True,
             max_iterations=10,
@@ -121,9 +110,7 @@ class OpenAIAugmentedLLM(
         return api_key
 
     def _base_url(self) -> str:
-        return (
-            self.context.config.openai.base_url if self.context.config.openai else None
-        )
+        return self.context.config.openai.base_url if self.context.config.openai else None
 
     async def generate(
         self,
@@ -144,24 +131,19 @@ class OpenAIAugmentedLLM(
         except AuthenticationError as e:
             raise ProviderKeyError(
                 "Invalid OpenAI API key",
-                "The configured OpenAI API key was rejected.\n"
-                "Please check that your API key is valid and not expired.",
+                "The configured OpenAI API key was rejected.\n" "Please check that your API key is valid and not expired.",
             ) from e
 
         system_prompt = self.instruction or params.systemPrompt
         if system_prompt:
-            messages.append(
-                ChatCompletionSystemMessageParam(role="system", content=system_prompt)
-            )
+            messages.append(ChatCompletionSystemMessageParam(role="system", content=system_prompt))
 
         # Always include prompt messages, but only include conversation history
         # if use_history is True
         messages.extend(self.history.get(include_history=params.use_history))
 
         if isinstance(message, str):
-            messages.append(
-                ChatCompletionUserMessageParam(role="user", content=message)
-            )
+            messages.append(ChatCompletionUserMessageParam(role="user", content=message))
         elif isinstance(message, list):
             messages.extend(message)
         else:
@@ -187,9 +169,7 @@ class OpenAIAugmentedLLM(
         model = await self.select_model(params)
         chat_turn = len(messages) // 2
         if self._reasoning:
-            self.show_user_message(
-                str(message), f"{model} ({self._reasoning_effort})", chat_turn
-            )
+            self.show_user_message(str(message), f"{model} ({self._reasoning_effort})", chat_turn)
         else:
             self.show_user_message(str(message), model, chat_turn)
 
@@ -218,9 +198,7 @@ class OpenAIAugmentedLLM(
             self._log_chat_progress(chat_turn, model=model)
 
             if response_model is None:
-                executor_result = await self.executor.execute(
-                    openai_client.chat.completions.create, **arguments
-                )
+                executor_result = await self.executor.execute(openai_client.chat.completions.create, **arguments)
             else:
                 executor_result = await self.executor.execute(
                     openai_client.beta.chat.completions.parse,
@@ -238,8 +216,7 @@ class OpenAIAugmentedLLM(
             if isinstance(response, AuthenticationError):
                 raise ProviderKeyError(
                     "Invalid OpenAI API key",
-                    "The configured OpenAI API key was rejected.\n"
-                    "Please check that your API key is valid and not expired.",
+                    "The configured OpenAI API key was rejected.\n" "Please check that your API key is valid and not expired.",
                 ) from response
             elif isinstance(response, BaseException):
                 self.logger.error(f"Error: {response}")
@@ -255,21 +232,14 @@ class OpenAIAugmentedLLM(
             message = choice.message
             responses.append(message)
 
-            converted_message = self.convert_message_to_message_param(
-                message, name=self.name
-            )
+            converted_message = self.convert_message_to_message_param(message, name=self.name)
             messages.append(converted_message)
             message_text = converted_message.content
-            if (
-                choice.finish_reason in ["tool_calls", "function_call"]
-                and message.tool_calls
-            ):
+            if choice.finish_reason in ["tool_calls", "function_call"] and message.tool_calls:
                 if message_text:
                     await self.show_assistant_message(
                         message_text,
-                        message.tool_calls[
-                            0
-                        ].function.name,  # TODO support displaying multiple tool calls
+                        message.tool_calls[0].function.name,  # TODO support displaying multiple tool calls
                     )
                 else:
                     await self.show_assistant_message(
@@ -291,9 +261,7 @@ class OpenAIAugmentedLLM(
                             method="tools/call",
                             params=CallToolRequestParams(
                                 name=tool_call.function.name,
-                                arguments=from_json(
-                                    tool_call.function.arguments, allow_partial=True
-                                ),
+                                arguments=from_json(tool_call.function.arguments, allow_partial=True),
                             ),
                         )
                         result = await self.call_tool(tool_call_request, tool_call.id)
@@ -301,18 +269,12 @@ class OpenAIAugmentedLLM(
 
                         tool_results.append((tool_call.id, result))
 
-                    messages.extend(
-                        OpenAIConverter.convert_function_results_to_openai(tool_results)
-                    )
+                    messages.extend(OpenAIConverter.convert_function_results_to_openai(tool_results))
 
-                    self.logger.debug(
-                        f"Iteration {i}: Tool call results: {str(tool_results) if tool_results else 'None'}"
-                    )
+                    self.logger.debug(f"Iteration {i}: Tool call results: {str(tool_results) if tool_results else 'None'}")
             elif choice.finish_reason == "length":
                 # We have reached the max tokens limit
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'length'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'length'")
                 if request_params and request_params.maxTokens is not None:
                     message_text = Text(
                         f"the assistant has reached the maximum token limit ({request_params.maxTokens})",
@@ -329,15 +291,11 @@ class OpenAIAugmentedLLM(
                     break
             elif choice.finish_reason == "content_filter":
                 # The response was filtered by the content filter
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'content_filter'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'content_filter'")
                 # TODO: saqadri - would be useful to return the reason for stopping to the caller
                 break
             elif choice.finish_reason == "stop":
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'stop'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'stop'")
                 if message_text:
                     await self.show_assistant_message(message_text, "")
                 break
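Tool-call arguments arrive from OpenAI as a JSON string and are parsed above with pydantic_core.from_json(..., allow_partial=True), which tolerates truncated output. A small sketch (the expected results reflect pydantic_core's documented partial-parsing behaviour):

    from pydantic_core import from_json

    # A complete payload parses as usual.
    print(from_json('{"city": "Paris"}'))  # {'city': 'Paris'}
    # allow_partial lets a truncated payload parse to its complete prefix.
    print(from_json('{"city": "Paris"', allow_partial=True))  # {'city': 'Paris'}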
@@ -395,7 +353,9 @@ class OpenAIAugmentedLLM(
         return "\n".join(final_text)
 
     async def _apply_prompt_template_provider_specific(
-        self, multipart_messages: List["PromptMessageMultipart"]
+        self,
+        multipart_messages: List["PromptMessageMultipart"],
+        request_params: RequestParams | None = None,
     ) -> str:
         """
         OpenAI-specific implementation of apply_prompt_template that handles
@@ -415,11 +375,7 @@ class OpenAIAugmentedLLM(
         last_message = multipart_messages[-1]
 
         # Add all previous messages to history (or all messages if last is from assistant)
-        messages_to_add = (
-            multipart_messages[:-1]
-            if last_message.role == "user"
-            else multipart_messages
-        )
+        messages_to_add = multipart_messages[:-1] if last_message.role == "user" else multipart_messages
         converted = []
         for msg in messages_to_add:
             converted.append(OpenAIConverter.convert_to_openai(msg))
@@ -427,16 +383,12 @@ class OpenAIAugmentedLLM(
 
         if last_message.role == "user":
             # For user messages: Generate response to the last one
-            self.logger.debug(
-                "Last message in prompt is from user, generating assistant response"
-            )
+            self.logger.debug("Last message in prompt is from user, generating assistant response")
             message_param = OpenAIConverter.convert_to_openai(last_message)
-            return await self.generate_str(message_param)
+            return await self.generate_str(message_param, request_params)
         else:
             # For assistant messages: Return the last message content as text
-            self.logger.debug(
-                "Last message in prompt is from assistant, returning it directly"
-            )
+            self.logger.debug("Last message in prompt is from assistant, returning it directly")
             return str(last_message)
 
     async def _save_history_to_file(self, command: str) -> str:
@@ -461,12 +413,12 @@ class OpenAIAugmentedLLM(
         messages = self.history.get(include_history=True)
 
         # Import required utilities
-        from mcp_agent.workflows.llm.openai_utils import (
-            openai_message_param_to_prompt_message_multipart,
-        )
         from mcp_agent.mcp.prompt_serialization import (
             multipart_messages_to_delimited_format,
         )
+        from mcp_agent.workflows.llm.openai_utils import (
+            openai_message_param_to_prompt_message_multipart,
+        )
 
         # Convert message params to PromptMessageMultipart objects
         multipart_messages = []
@@ -476,9 +428,7 @@ class OpenAIAugmentedLLM(
                 continue
 
             # Convert the message to a multipart message
-            multipart_messages.append(
-                openai_message_param_to_prompt_message_multipart(msg)
-            )
+            multipart_messages.append(openai_message_param_to_prompt_message_multipart(msg))
 
         # Convert to delimited format
         delimited_content = multipart_messages_to_delimited_format(
@@ -511,18 +461,14 @@ class OpenAIAugmentedLLM(
         )
         return responses[0].parsed
 
-    async def generate_prompt(
-        self, prompt: "PromptMessageMultipart", request_params: RequestParams | None
-    ) -> str:
+    async def generate_prompt(self, prompt: "PromptMessageMultipart", request_params: RequestParams | None) -> str:
         converted_prompt = OpenAIConverter.convert_to_openai(prompt)
         return await self.generate_str(converted_prompt, request_params)
 
     async def pre_tool_call(self, tool_call_id: str | None, request: CallToolRequest):
         return request
 
-    async def post_tool_call(
-        self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult
-    ):
+    async def post_tool_call(self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult):
         return result
 
     def message_param_str(self, message: ChatCompletionMessageParam) -> str: