fast-agent-mcp 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (131)
  1. {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/METADATA +1 -1
  2. fast_agent_mcp-0.1.13.dist-info/RECORD +164 -0
  3. mcp_agent/agents/agent.py +37 -102
  4. mcp_agent/app.py +16 -27
  5. mcp_agent/cli/commands/bootstrap.py +22 -52
  6. mcp_agent/cli/commands/config.py +4 -4
  7. mcp_agent/cli/commands/setup.py +11 -26
  8. mcp_agent/cli/main.py +6 -9
  9. mcp_agent/cli/terminal.py +2 -2
  10. mcp_agent/config.py +1 -5
  11. mcp_agent/context.py +13 -26
  12. mcp_agent/context_dependent.py +3 -7
  13. mcp_agent/core/agent_app.py +46 -122
  14. mcp_agent/core/agent_types.py +29 -2
  15. mcp_agent/core/agent_utils.py +3 -5
  16. mcp_agent/core/decorators.py +6 -14
  17. mcp_agent/core/enhanced_prompt.py +25 -52
  18. mcp_agent/core/error_handling.py +1 -1
  19. mcp_agent/core/exceptions.py +8 -8
  20. mcp_agent/core/factory.py +30 -72
  21. mcp_agent/core/fastagent.py +48 -88
  22. mcp_agent/core/mcp_content.py +10 -19
  23. mcp_agent/core/prompt.py +8 -15
  24. mcp_agent/core/proxies.py +34 -25
  25. mcp_agent/core/request_params.py +46 -0
  26. mcp_agent/core/types.py +6 -6
  27. mcp_agent/core/validation.py +16 -16
  28. mcp_agent/executor/decorator_registry.py +11 -23
  29. mcp_agent/executor/executor.py +8 -17
  30. mcp_agent/executor/task_registry.py +2 -4
  31. mcp_agent/executor/temporal.py +28 -74
  32. mcp_agent/executor/workflow.py +3 -5
  33. mcp_agent/executor/workflow_signal.py +17 -29
  34. mcp_agent/human_input/handler.py +4 -9
  35. mcp_agent/human_input/types.py +2 -3
  36. mcp_agent/logging/events.py +1 -5
  37. mcp_agent/logging/json_serializer.py +7 -6
  38. mcp_agent/logging/listeners.py +20 -23
  39. mcp_agent/logging/logger.py +15 -17
  40. mcp_agent/logging/rich_progress.py +10 -8
  41. mcp_agent/logging/tracing.py +4 -6
  42. mcp_agent/logging/transport.py +24 -24
  43. mcp_agent/mcp/gen_client.py +4 -12
  44. mcp_agent/mcp/interfaces.py +107 -88
  45. mcp_agent/mcp/mcp_agent_client_session.py +11 -19
  46. mcp_agent/mcp/mcp_agent_server.py +8 -10
  47. mcp_agent/mcp/mcp_aggregator.py +49 -122
  48. mcp_agent/mcp/mcp_connection_manager.py +16 -37
  49. mcp_agent/mcp/prompt_message_multipart.py +12 -18
  50. mcp_agent/mcp/prompt_serialization.py +13 -38
  51. mcp_agent/mcp/prompts/prompt_load.py +99 -0
  52. mcp_agent/mcp/prompts/prompt_server.py +21 -128
  53. mcp_agent/mcp/prompts/prompt_template.py +20 -42
  54. mcp_agent/mcp/resource_utils.py +8 -17
  55. mcp_agent/mcp/sampling.py +62 -64
  56. mcp_agent/mcp/stdio.py +11 -8
  57. mcp_agent/mcp_server/__init__.py +1 -1
  58. mcp_agent/mcp_server/agent_server.py +10 -17
  59. mcp_agent/mcp_server_registry.py +13 -35
  60. mcp_agent/resources/examples/data-analysis/analysis-campaign.py +1 -1
  61. mcp_agent/resources/examples/data-analysis/analysis.py +1 -1
  62. mcp_agent/resources/examples/data-analysis/slides.py +110 -0
  63. mcp_agent/resources/examples/internal/agent.py +2 -1
  64. mcp_agent/resources/examples/internal/job.py +2 -1
  65. mcp_agent/resources/examples/internal/prompt_category.py +1 -1
  66. mcp_agent/resources/examples/internal/prompt_sizing.py +3 -5
  67. mcp_agent/resources/examples/internal/sizer.py +2 -1
  68. mcp_agent/resources/examples/internal/social.py +2 -1
  69. mcp_agent/resources/examples/mcp_researcher/researcher-eval.py +1 -1
  70. mcp_agent/resources/examples/prompting/__init__.py +1 -1
  71. mcp_agent/resources/examples/prompting/agent.py +2 -1
  72. mcp_agent/resources/examples/prompting/image_server.py +5 -11
  73. mcp_agent/resources/examples/researcher/researcher-eval.py +1 -1
  74. mcp_agent/resources/examples/researcher/researcher-imp.py +3 -4
  75. mcp_agent/resources/examples/researcher/researcher.py +2 -1
  76. mcp_agent/resources/examples/workflows/agent_build.py +2 -1
  77. mcp_agent/resources/examples/workflows/chaining.py +2 -1
  78. mcp_agent/resources/examples/workflows/evaluator.py +2 -1
  79. mcp_agent/resources/examples/workflows/human_input.py +2 -1
  80. mcp_agent/resources/examples/workflows/orchestrator.py +2 -1
  81. mcp_agent/resources/examples/workflows/parallel.py +2 -1
  82. mcp_agent/resources/examples/workflows/router.py +2 -1
  83. mcp_agent/resources/examples/workflows/sse.py +1 -1
  84. mcp_agent/telemetry/usage_tracking.py +2 -1
  85. mcp_agent/ui/console_display.py +17 -41
  86. mcp_agent/workflows/embedding/embedding_base.py +1 -4
  87. mcp_agent/workflows/embedding/embedding_cohere.py +2 -2
  88. mcp_agent/workflows/embedding/embedding_openai.py +4 -13
  89. mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +23 -57
  90. mcp_agent/workflows/intent_classifier/intent_classifier_base.py +5 -8
  91. mcp_agent/workflows/intent_classifier/intent_classifier_embedding.py +7 -11
  92. mcp_agent/workflows/intent_classifier/intent_classifier_embedding_cohere.py +4 -8
  93. mcp_agent/workflows/intent_classifier/intent_classifier_embedding_openai.py +4 -8
  94. mcp_agent/workflows/intent_classifier/intent_classifier_llm.py +11 -22
  95. mcp_agent/workflows/intent_classifier/intent_classifier_llm_anthropic.py +3 -3
  96. mcp_agent/workflows/intent_classifier/intent_classifier_llm_openai.py +4 -6
  97. mcp_agent/workflows/llm/anthropic_utils.py +8 -29
  98. mcp_agent/workflows/llm/augmented_llm.py +94 -332
  99. mcp_agent/workflows/llm/augmented_llm_anthropic.py +43 -76
  100. mcp_agent/workflows/llm/augmented_llm_openai.py +46 -100
  101. mcp_agent/workflows/llm/augmented_llm_passthrough.py +42 -20
  102. mcp_agent/workflows/llm/augmented_llm_playback.py +8 -6
  103. mcp_agent/workflows/llm/memory.py +103 -0
  104. mcp_agent/workflows/llm/model_factory.py +9 -21
  105. mcp_agent/workflows/llm/openai_utils.py +1 -1
  106. mcp_agent/workflows/llm/prompt_utils.py +39 -27
  107. mcp_agent/workflows/llm/providers/multipart_converter_anthropic.py +246 -184
  108. mcp_agent/workflows/llm/providers/multipart_converter_openai.py +212 -202
  109. mcp_agent/workflows/llm/providers/openai_multipart.py +19 -61
  110. mcp_agent/workflows/llm/providers/sampling_converter_anthropic.py +11 -212
  111. mcp_agent/workflows/llm/providers/sampling_converter_openai.py +13 -215
  112. mcp_agent/workflows/llm/sampling_converter.py +117 -0
  113. mcp_agent/workflows/llm/sampling_format_converter.py +12 -29
  114. mcp_agent/workflows/orchestrator/orchestrator.py +24 -67
  115. mcp_agent/workflows/orchestrator/orchestrator_models.py +14 -40
  116. mcp_agent/workflows/parallel/fan_in.py +17 -47
  117. mcp_agent/workflows/parallel/fan_out.py +6 -12
  118. mcp_agent/workflows/parallel/parallel_llm.py +9 -26
  119. mcp_agent/workflows/router/router_base.py +29 -59
  120. mcp_agent/workflows/router/router_embedding.py +11 -25
  121. mcp_agent/workflows/router/router_embedding_cohere.py +2 -2
  122. mcp_agent/workflows/router/router_embedding_openai.py +2 -2
  123. mcp_agent/workflows/router/router_llm.py +12 -28
  124. mcp_agent/workflows/swarm/swarm.py +20 -48
  125. mcp_agent/workflows/swarm/swarm_anthropic.py +2 -2
  126. mcp_agent/workflows/swarm/swarm_openai.py +2 -2
  127. fast_agent_mcp-0.1.11.dist-info/RECORD +0 -160
  128. mcp_agent/workflows/llm/llm_selector.py +0 -345
  129. {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/WHEEL +0 -0
  130. {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/entry_points.txt +0 -0
  131. {fast_agent_mcp-0.1.11.dist-info → fast_agent_mcp-0.1.13.dist-info}/licenses/LICENSE +0 -0
mcp_agent/workflows/llm/augmented_llm_anthropic.py
@@ -1,5 +1,5 @@
 import os
-from typing import List, Type, TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Type
 
 from mcp_agent.workflows.llm.providers.multipart_converter_anthropic import (
     AnthropicConverter,
@@ -9,6 +9,8 @@ from mcp_agent.workflows.llm.providers.sampling_converter_anthropic import (
 )
 
 if TYPE_CHECKING:
+    from mcp import ListToolsResult
+
     from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 
 
@@ -20,22 +22,22 @@ from anthropic.types import (
     TextBlockParam,
     ToolParam,
     ToolUseBlockParam,
+    Usage,
 )
 from mcp.types import (
-    CallToolRequestParams,
     CallToolRequest,
+    CallToolRequestParams,
 )
 from pydantic_core import from_json
+from rich.text import Text
 
+from mcp_agent.core.exceptions import ProviderKeyError
+from mcp_agent.logging.logger import get_logger
 from mcp_agent.workflows.llm.augmented_llm import (
     AugmentedLLM,
     ModelT,
     RequestParams,
 )
-from mcp_agent.core.exceptions import ProviderKeyError
-from rich.text import Text
-
-from mcp_agent.logging.logger import get_logger
 
 DEFAULT_ANTHROPIC_MODEL = "claude-3-7-sonnet-latest"
 
@@ -48,7 +50,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
     selecting appropriate tools, and determining what information to retain.
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         self.provider = "Anthropic"
         # Initialize logger - keep it simple without name reference
         self.logger = get_logger(__name__)
@@ -60,7 +62,6 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         """Initialize Anthropic-specific default parameters"""
         return RequestParams(
             model=kwargs.get("model", DEFAULT_ANTHROPIC_MODEL),
-            modelPreferences=self.model_preferences,
             maxTokens=4096,  # default haiku3
             systemPrompt=self.instruction,
             parallel_tool_calls=True,
@@ -86,8 +87,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         except AuthenticationError as e:
             raise ProviderKeyError(
                 "Invalid Anthropic API key",
-                "The configured Anthropic API key was rejected.\n"
-                "Please check that your API key is valid and not expired.",
+                "The configured Anthropic API key was rejected.\nPlease check that your API key is valid and not expired.",
             ) from e
 
         # Always include prompt messages, but only include conversation history
@@ -101,14 +101,14 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         else:
             messages.append(message)
 
-        response = await self.aggregator.list_tools()
+        tool_list: ListToolsResult = await self.aggregator.list_tools()
         available_tools: List[ToolParam] = [
-            {
-                "name": tool.name,
-                "description": tool.description,
-                "input_schema": tool.inputSchema,
-            }
-            for tool in response.tools
+            ToolParam(
+                name=tool.name,
+                description=tool.description or "",
+                input_schema=tool.inputSchema,
+            )
+            for tool in tool_list.tools
         ]
 
         responses: List[Message] = []
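The tool list now flows through the `ToolParam` constructor rather than a raw dict, with `description` coerced to an empty string since MCP allows it to be `None`. A minimal standalone sketch of the same conversion (the helper name is ours, for illustration only):

```python
from anthropic.types import ToolParam
from mcp import ListToolsResult


def mcp_tools_to_anthropic(tool_list: ListToolsResult) -> list[ToolParam]:
    """Map MCP tool definitions onto Anthropic's ToolParam typed dict."""
    return [
        ToolParam(
            name=tool.name,
            description=tool.description or "",  # MCP descriptions may be None
            input_schema=tool.inputSchema,
        )
        for tool in tool_list.tools
    ]
```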
@@ -135,17 +135,14 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
             self.logger.debug(f"{arguments}")
 
-            executor_result = await self.executor.execute(
-                anthropic.messages.create, **arguments
-            )
+            executor_result = await self.executor.execute(anthropic.messages.create, **arguments)
 
             response = executor_result[0]
 
             if isinstance(response, AuthenticationError):
                 raise ProviderKeyError(
                     "Invalid Anthropic API key",
-                    "The configured Anthropic API key was rejected.\n"
-                    "Please check that your API key is valid and not expired.",
+                    "The configured Anthropic API key was rejected.\nPlease check that your API key is valid and not expired.",
                 ) from response
             elif isinstance(response, BaseException):
                 error_details = str(response)
@@ -155,13 +152,9 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
                 if hasattr(response, "status_code") and hasattr(response, "response"):
                     try:
                         error_json = response.response.json()
-                        error_details = (
-                            f"Error code: {response.status_code} - {error_json}"
-                        )
+                        error_details = f"Error code: {response.status_code} - {error_json}"
                     except:  # noqa: E722
-                        error_details = (
-                            f"Error code: {response.status_code} - {str(response)}"
-                        )
+                        error_details = f"Error code: {response.status_code} - {str(response)}"
 
                 # Convert other errors to text response
                 error_message = f"Error during generation: {error_details}"
@@ -172,7 +165,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
                     type="message",
                     content=[TextBlock(type="text", text=error_message)],
                     stop_reason="end_turn",  # Must be one of the allowed values
-                    usage={"input_tokens": 0, "output_tokens": 0},  # Required field
+                    usage=Usage(input_tokens=0, output_tokens=0),  # Required field
                 )
 
             self.logger.debug(
@@ -194,22 +187,16 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
                 await self.show_assistant_message(message_text)
 
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'end_turn'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'end_turn'")
                 break
             elif response.stop_reason == "stop_sequence":
                 # We have reached a stop sequence
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'stop_sequence'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'stop_sequence'")
                 break
             elif response.stop_reason == "max_tokens":
                 # We have reached the max tokens limit
 
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'max_tokens'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'max_tokens'")
                 if params.maxTokens is not None:
                     message_text = Text(
                         f"the assistant has reached the maximum token limit ({params.maxTokens})",
@@ -256,22 +243,16 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
                     self.show_tool_call(available_tools, tool_name, tool_args)
                     tool_call_request = CallToolRequest(
                         method="tools/call",
-                        params=CallToolRequestParams(
-                            name=tool_name, arguments=tool_args
-                        ),
+                        params=CallToolRequestParams(name=tool_name, arguments=tool_args),
                     )
                     # TODO -- support MCP isError etc.
-                    result = await self.call_tool(
-                        request=tool_call_request, tool_call_id=tool_use_id
-                    )
+                    result = await self.call_tool(request=tool_call_request, tool_call_id=tool_use_id)
                     self.show_tool_result(result)
 
                     # Add each result to our collection
                     tool_results.append((tool_use_id, result))
 
-                messages.append(
-                    AnthropicConverter.create_tool_results_message(tool_results)
-                )
+                messages.append(AnthropicConverter.create_tool_results_message(tool_results))
 
                 # Only save the new conversation messages to history if use_history is true
                 # Keep the prompt messages separate
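The `CallToolRequest`/`CallToolRequestParams` pair used here comes directly from `mcp.types`. A small sketch of constructing one by hand, with a hypothetical tool name and arguments:

```python
from mcp.types import CallToolRequest, CallToolRequestParams

# Hypothetical tool invocation; the tool name and arguments are illustrative.
request = CallToolRequest(
    method="tools/call",
    params=CallToolRequestParams(
        name="fetch_page",
        arguments={"url": "https://example.com"},
    ),
)
```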
@@ -352,15 +333,13 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         # Join all collected text
         return "\n".join(final_text)
 
-    async def generate_prompt(
-        self, prompt: "PromptMessageMultipart", request_params: RequestParams | None
-    ) -> str:
-        return await self.generate_str(
-            AnthropicConverter.convert_to_anthropic(prompt), request_params
-        )
+    async def generate_prompt(self, prompt: "PromptMessageMultipart", request_params: RequestParams | None) -> str:
+        return await self.generate_str(AnthropicConverter.convert_to_anthropic(prompt), request_params)
 
     async def _apply_prompt_template_provider_specific(
-        self, multipart_messages: List["PromptMessageMultipart"]
+        self,
+        multipart_messages: List["PromptMessageMultipart"],
+        request_params: RequestParams | None = None,
     ) -> str:
         """
         Anthropic-specific implementation of apply_prompt_template that handles
@@ -377,11 +356,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         last_message = multipart_messages[-1]
 
         # Add all previous messages to history (or all messages if last is from assistant)
-        messages_to_add = (
-            multipart_messages[:-1]
-            if last_message.role == "user"
-            else multipart_messages
-        )
+        messages_to_add = multipart_messages[:-1] if last_message.role == "user" else multipart_messages
         converted = []
         for msg in messages_to_add:
             converted.append(AnthropicConverter.convert_to_anthropic(msg))
@@ -389,16 +364,12 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
         if last_message.role == "user":
             # For user messages: Generate response to the last one
-            self.logger.debug(
-                "Last message in prompt is from user, generating assistant response"
-            )
+            self.logger.debug("Last message in prompt is from user, generating assistant response")
             message_param = AnthropicConverter.convert_to_anthropic(last_message)
-            return await self.generate_str(message_param)
+            return await self.generate_str(message_param, request_params)
         else:
             # For assistant messages: Return the last message content as text
-            self.logger.debug(
-                "Last message in prompt is from assistant, returning it directly"
-            )
+            self.logger.debug("Last message in prompt is from assistant, returning it directly")
             return str(last_message)
 
     async def _save_history_to_file(self, command: str) -> str:
@@ -423,19 +394,17 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         messages = self.history.get(include_history=True)
 
         # Import required utilities
-        from mcp_agent.workflows.llm.anthropic_utils import (
-            anthropic_message_param_to_prompt_message_multipart,
-        )
         from mcp_agent.mcp.prompt_serialization import (
             multipart_messages_to_delimited_format,
         )
+        from mcp_agent.workflows.llm.anthropic_utils import (
+            anthropic_message_param_to_prompt_message_multipart,
+        )
 
         # Convert message params to PromptMessageMultipart objects
         multipart_messages = []
         for msg in messages:
-            multipart_messages.append(
-                anthropic_message_param_to_prompt_message_multipart(msg)
-            )
+            multipart_messages.append(anthropic_message_param_to_prompt_message_multipart(msg))
 
         # Convert to delimited format
         delimited_content = multipart_messages_to_delimited_format(
@@ -457,7 +426,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
 
     async def generate_structured(
         self,
-        message,
+        message: str,
         response_model: Type[ModelT],
         request_params: RequestParams | None = None,
     ) -> ModelT:
@@ -474,9 +443,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         return response_model.model_validate(from_json(response, allow_partial=True))
 
     @classmethod
-    def convert_message_to_message_param(
-        cls, message: Message, **kwargs
-    ) -> MessageParam:
+    def convert_message_to_message_param(cls, message: Message, **kwargs) -> MessageParam:
        """Convert a response object to an input parameter object to allow LLM calls to be chained."""
        content = []

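`generate_structured` validates the model's JSON reply with `pydantic_core.from_json(..., allow_partial=True)` before handing it to Pydantic, which tolerates a truncated JSON tail. A sketch of that parsing step, with a hypothetical response model:

```python
from pydantic import BaseModel
from pydantic_core import from_json


class WeatherReport(BaseModel):  # hypothetical response model
    city: str
    temperature_c: float | None = None


# allow_partial=True lets a cut-off response still parse and validate.
raw = '{"city": "Oslo", "temperature_c": 3.5'
report = WeatherReport.model_validate(from_json(raw, allow_partial=True))
```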
mcp_agent/workflows/llm/augmented_llm_openai.py
@@ -1,5 +1,5 @@
 import os
-from typing import List, Type, TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Type
 
 from pydantic_core import from_json
 
@@ -10,30 +10,30 @@ from mcp_agent.workflows.llm.providers.sampling_converter_openai import (
 
 if TYPE_CHECKING:
     from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
-from openai import OpenAI, AuthenticationError
+from mcp.types import (
+    CallToolRequest,
+    CallToolRequestParams,
+    CallToolResult,
+)
+from openai import AuthenticationError, OpenAI
 
 # from openai.types.beta.chat import
 from openai.types.chat import (
-    ChatCompletionMessageParam,
     ChatCompletionMessage,
+    ChatCompletionMessageParam,
     ChatCompletionSystemMessageParam,
     ChatCompletionToolParam,
     ChatCompletionUserMessageParam,
 )
-from mcp.types import (
-    CallToolRequestParams,
-    CallToolRequest,
-    CallToolResult,
-)
+from rich.text import Text
 
+from mcp_agent.core.exceptions import ProviderKeyError
+from mcp_agent.logging.logger import get_logger
 from mcp_agent.workflows.llm.augmented_llm import (
     AugmentedLLM,
     ModelT,
     RequestParams,
 )
-from mcp_agent.core.exceptions import ProviderKeyError
-from mcp_agent.logging.logger import get_logger
-from rich.text import Text
 
 _logger = get_logger(__name__)
 
@@ -41,16 +41,14 @@ DEFAULT_OPENAI_MODEL = "gpt-4o"
 DEFAULT_REASONING_EFFORT = "medium"
 
 
-class OpenAIAugmentedLLM(
-    AugmentedLLM[ChatCompletionMessageParam, ChatCompletionMessage]
-):
+class OpenAIAugmentedLLM(AugmentedLLM[ChatCompletionMessageParam, ChatCompletionMessage]):
     """
     The basic building block of agentic systems is an LLM enhanced with augmentations
     such as retrieval, tools, and memory provided from a collection of MCP servers.
     This implementation uses OpenAI's ChatCompletion as the LLM.
     """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         # Set type_converter before calling super().__init__
         if "type_converter" not in kwargs:
             kwargs["type_converter"] = OpenAISamplingConverter
@@ -64,22 +62,14 @@
         # Set up reasoning-related attributes
         self._reasoning_effort = kwargs.get("reasoning_effort", None)
         if self.context and self.context.config and self.context.config.openai:
-            if self._reasoning_effort is None and hasattr(
-                self.context.config.openai, "reasoning_effort"
-            ):
+            if self._reasoning_effort is None and hasattr(self.context.config.openai, "reasoning_effort"):
                 self._reasoning_effort = self.context.config.openai.reasoning_effort
 
         # Determine if we're using a reasoning model
-        chosen_model = (
-            self.default_request_params.model if self.default_request_params else None
-        )
-        self._reasoning = chosen_model and (
-            chosen_model.startswith("o3") or chosen_model.startswith("o1")
-        )
+        chosen_model = self.default_request_params.model if self.default_request_params else None
+        self._reasoning = chosen_model and (chosen_model.startswith("o3") or chosen_model.startswith("o1"))
         if self._reasoning:
-            self.logger.info(
-                f"Using reasoning model '{chosen_model}' with '{self._reasoning_effort}' reasoning effort"
-            )
+            self.logger.info(f"Using reasoning model '{chosen_model}' with '{self._reasoning_effort}' reasoning effort")
 
     def _initialize_default_params(self, kwargs: dict) -> RequestParams:
         """Initialize OpenAI-specific default parameters"""
@@ -92,7 +82,6 @@
 
         return RequestParams(
             model=chosen_model,
-            modelPreferences=self.model_preferences,
             systemPrompt=self.instruction,
             parallel_tool_calls=True,
             max_iterations=10,
@@ -121,9 +110,7 @@
         return api_key
 
     def _base_url(self) -> str:
-        return (
-            self.context.config.openai.base_url if self.context.config.openai else None
-        )
+        return self.context.config.openai.base_url if self.context.config.openai else None
 
     async def generate(
         self,
@@ -144,24 +131,19 @@
         except AuthenticationError as e:
             raise ProviderKeyError(
                 "Invalid OpenAI API key",
-                "The configured OpenAI API key was rejected.\n"
-                "Please check that your API key is valid and not expired.",
+                "The configured OpenAI API key was rejected.\n" "Please check that your API key is valid and not expired.",
             ) from e
 
         system_prompt = self.instruction or params.systemPrompt
         if system_prompt:
-            messages.append(
-                ChatCompletionSystemMessageParam(role="system", content=system_prompt)
-            )
+            messages.append(ChatCompletionSystemMessageParam(role="system", content=system_prompt))
 
         # Always include prompt messages, but only include conversation history
         # if use_history is True
         messages.extend(self.history.get(include_history=params.use_history))
 
         if isinstance(message, str):
-            messages.append(
-                ChatCompletionUserMessageParam(role="user", content=message)
-            )
+            messages.append(ChatCompletionUserMessageParam(role="user", content=message))
         elif isinstance(message, list):
             messages.extend(message)
         else:
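The message list is assembled in a fixed order: optional system prompt, then stored history, then the new user turn. A hedged sketch of that ordering using the `openai` typed params (the helper and its inputs are illustrative):

```python
from openai.types.chat import (
    ChatCompletionMessageParam,
    ChatCompletionSystemMessageParam,
    ChatCompletionUserMessageParam,
)


def build_messages(
    system_prompt: str | None,
    history: list[ChatCompletionMessageParam],
    user_text: str,
) -> list[ChatCompletionMessageParam]:
    """Fixed ordering: system prompt (if any), stored history, new user turn."""
    messages: list[ChatCompletionMessageParam] = []
    if system_prompt:
        messages.append(ChatCompletionSystemMessageParam(role="system", content=system_prompt))
    messages.extend(history)
    messages.append(ChatCompletionUserMessageParam(role="user", content=user_text))
    return messages
```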
@@ -187,9 +169,7 @@
         model = await self.select_model(params)
         chat_turn = len(messages) // 2
         if self._reasoning:
-            self.show_user_message(
-                str(message), f"{model} ({self._reasoning_effort})", chat_turn
-            )
+            self.show_user_message(str(message), f"{model} ({self._reasoning_effort})", chat_turn)
         else:
             self.show_user_message(str(message), model, chat_turn)
 
@@ -218,9 +198,7 @@
             self._log_chat_progress(chat_turn, model=model)
 
             if response_model is None:
-                executor_result = await self.executor.execute(
-                    openai_client.chat.completions.create, **arguments
-                )
+                executor_result = await self.executor.execute(openai_client.chat.completions.create, **arguments)
             else:
                 executor_result = await self.executor.execute(
                     openai_client.beta.chat.completions.parse,
@@ -238,8 +216,7 @@
             if isinstance(response, AuthenticationError):
                 raise ProviderKeyError(
                     "Invalid OpenAI API key",
-                    "The configured OpenAI API key was rejected.\n"
-                    "Please check that your API key is valid and not expired.",
+                    "The configured OpenAI API key was rejected.\n" "Please check that your API key is valid and not expired.",
                 ) from response
             elif isinstance(response, BaseException):
                 self.logger.error(f"Error: {response}")
@@ -255,21 +232,14 @@
             message = choice.message
             responses.append(message)
 
-            converted_message = self.convert_message_to_message_param(
-                message, name=self.name
-            )
+            converted_message = self.convert_message_to_message_param(message, name=self.name)
             messages.append(converted_message)
             message_text = converted_message.content
-            if (
-                choice.finish_reason in ["tool_calls", "function_call"]
-                and message.tool_calls
-            ):
+            if choice.finish_reason in ["tool_calls", "function_call"] and message.tool_calls:
                 if message_text:
                     await self.show_assistant_message(
                         message_text,
-                        message.tool_calls[
-                            0
-                        ].function.name,  # TODO support displaying multiple tool calls
+                        message.tool_calls[0].function.name,  # TODO support displaying multiple tool calls
                     )
                 else:
                     await self.show_assistant_message(
@@ -291,9 +261,7 @@
                         method="tools/call",
                         params=CallToolRequestParams(
                             name=tool_call.function.name,
-                            arguments=from_json(
-                                tool_call.function.arguments, allow_partial=True
-                            ),
+                            arguments=from_json(tool_call.function.arguments, allow_partial=True),
                         ),
                     )
                     result = await self.call_tool(tool_call_request, tool_call.id)
@@ -301,18 +269,12 @@
 
                     tool_results.append((tool_call.id, result))
 
-                messages.extend(
-                    OpenAIConverter.convert_function_results_to_openai(tool_results)
-                )
+                messages.extend(OpenAIConverter.convert_function_results_to_openai(tool_results))
 
-                self.logger.debug(
-                    f"Iteration {i}: Tool call results: {str(tool_results) if tool_results else 'None'}"
-                )
+                self.logger.debug(f"Iteration {i}: Tool call results: {str(tool_results) if tool_results else 'None'}")
             elif choice.finish_reason == "length":
                 # We have reached the max tokens limit
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'length'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'length'")
                 if request_params and request_params.maxTokens is not None:
                     message_text = Text(
                         f"the assistant has reached the maximum token limit ({request_params.maxTokens})",
@@ -329,15 +291,11 @@
                 break
             elif choice.finish_reason == "content_filter":
                 # The response was filtered by the content filter
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'content_filter'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'content_filter'")
                 # TODO: saqadri - would be useful to return the reason for stopping to the caller
                 break
             elif choice.finish_reason == "stop":
-                self.logger.debug(
-                    f"Iteration {i}: Stopping because finish_reason is 'stop'"
-                )
+                self.logger.debug(f"Iteration {i}: Stopping because finish_reason is 'stop'")
                 if message_text:
                     await self.show_assistant_message(message_text, "")
                 break
@@ -395,7 +353,9 @@
         return "\n".join(final_text)
 
     async def _apply_prompt_template_provider_specific(
-        self, multipart_messages: List["PromptMessageMultipart"]
+        self,
+        multipart_messages: List["PromptMessageMultipart"],
+        request_params: RequestParams | None = None,
     ) -> str:
         """
         OpenAI-specific implementation of apply_prompt_template that handles
@@ -415,11 +375,7 @@
         last_message = multipart_messages[-1]
 
         # Add all previous messages to history (or all messages if last is from assistant)
-        messages_to_add = (
-            multipart_messages[:-1]
-            if last_message.role == "user"
-            else multipart_messages
-        )
+        messages_to_add = multipart_messages[:-1] if last_message.role == "user" else multipart_messages
         converted = []
         for msg in messages_to_add:
             converted.append(OpenAIConverter.convert_to_openai(msg))
@@ -427,16 +383,12 @@
 
         if last_message.role == "user":
             # For user messages: Generate response to the last one
-            self.logger.debug(
-                "Last message in prompt is from user, generating assistant response"
-            )
+            self.logger.debug("Last message in prompt is from user, generating assistant response")
             message_param = OpenAIConverter.convert_to_openai(last_message)
-            return await self.generate_str(message_param)
+            return await self.generate_str(message_param, request_params)
         else:
             # For assistant messages: Return the last message content as text
-            self.logger.debug(
-                "Last message in prompt is from assistant, returning it directly"
-            )
+            self.logger.debug("Last message in prompt is from assistant, returning it directly")
             return str(last_message)
 
     async def _save_history_to_file(self, command: str) -> str:
@@ -461,12 +413,12 @@
         messages = self.history.get(include_history=True)
 
         # Import required utilities
-        from mcp_agent.workflows.llm.openai_utils import (
-            openai_message_param_to_prompt_message_multipart,
-        )
         from mcp_agent.mcp.prompt_serialization import (
             multipart_messages_to_delimited_format,
         )
+        from mcp_agent.workflows.llm.openai_utils import (
+            openai_message_param_to_prompt_message_multipart,
+        )
 
         # Convert message params to PromptMessageMultipart objects
         multipart_messages = []
@@ -476,9 +428,7 @@
                 continue
 
             # Convert the message to a multipart message
-            multipart_messages.append(
-                openai_message_param_to_prompt_message_multipart(msg)
-            )
+            multipart_messages.append(openai_message_param_to_prompt_message_multipart(msg))
 
         # Convert to delimited format
         delimited_content = multipart_messages_to_delimited_format(
@@ -511,18 +461,14 @@
             )
             return responses[0].parsed
 
-    async def generate_prompt(
-        self, prompt: "PromptMessageMultipart", request_params: RequestParams | None
-    ) -> str:
+    async def generate_prompt(self, prompt: "PromptMessageMultipart", request_params: RequestParams | None) -> str:
         converted_prompt = OpenAIConverter.convert_to_openai(prompt)
         return await self.generate_str(converted_prompt, request_params)
 
     async def pre_tool_call(self, tool_call_id: str | None, request: CallToolRequest):
         return request
 
-    async def post_tool_call(
-        self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult
-    ):
+    async def post_tool_call(self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult):
         return result
 
     def message_param_str(self, message: ChatCompletionMessageParam) -> str:
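`pre_tool_call` and `post_tool_call` are pass-through hooks here, which makes them natural override points. A hedged sketch of a subclass that logs tool traffic (the subclass and its logging are illustrative, not part of the package):

```python
from mcp.types import CallToolRequest, CallToolResult

from mcp_agent.workflows.llm.augmented_llm_openai import OpenAIAugmentedLLM


class LoggingOpenAILLM(OpenAIAugmentedLLM):  # hypothetical subclass
    async def pre_tool_call(self, tool_call_id: str | None, request: CallToolRequest):
        # Inspect or rewrite the request before it reaches the MCP server.
        self.logger.debug(f"tool call {tool_call_id}: {request.params.name}")
        return request

    async def post_tool_call(self, tool_call_id: str | None, request: CallToolRequest, result: CallToolResult):
        # Inspect or transform the result before it is fed back to the model.
        self.logger.debug(f"tool call {tool_call_id} returned isError={result.isError}")
        return result
```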