fast-agent-mcp 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,7 +3,7 @@ import asyncio
  from mcp_agent.core.fastagent import FastAgent
  # from rich import print

- agents = FastAgent(name="Researcher")
+ agents = FastAgent(name="Researcher Agent")


  @agents.agent(
@@ -45,13 +45,14 @@ fast = FastAgent("Orchestrator-Workers")
  @fast.orchestrator(
  name="orchestrate",
  agents=["finder", "writer", "proofreader"],
- plan_type="full",
+ plan_type="iterative",
  )
  async def main():
  async with fast.run() as agent:
- await agent.author(
- "write a 250 word short story about kittens discovering a castle, and save it to short_story.md"
- )
+ await agent()
+ # await agent.author(
+ # "write a 250 word short story about kittens discovering a castle, and save it to short_story.md"
+ # )

  # The orchestrator can be used just like any other agent
  task = (
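
This hunk switches the orchestrator from up-front ("full") planning to "iterative" planning and replaces the hard-coded story request with an interactive session. A minimal sketch of the updated usage, assembled from the lines above (the decorator arguments and the agent() call come from the diff; the surrounding scaffolding is an assumption):

    import asyncio
    from mcp_agent.core.fastagent import FastAgent

    fast = FastAgent("Orchestrator-Workers")

    @fast.orchestrator(
        name="orchestrate",
        agents=["finder", "writer", "proofreader"],
        plan_type="iterative",  # plan one step at a time rather than all up front
    )
    async def main():
        async with fast.run() as agent:
            await agent()  # interactive session instead of a fixed authoring task

    if __name__ == "__main__":
        asyncio.run(main())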
@@ -25,7 +25,6 @@ SAMPLE_REQUESTS = [
  name="fetcher",
  instruction="""You are an agent, with a tool enabling you to fetch URLs.""",
  servers=["fetch"],
- model="haiku",
  )
  @fast.agent(
  name="code_expert",
@@ -33,7 +32,6 @@ SAMPLE_REQUESTS = [
  When asked about code, architecture, or development practices,
  you provide thorough and practical insights.""",
  servers=["filesystem"],
- model="gpt-4o",
  )
  @fast.agent(
  name="general_assistant",
@@ -12,6 +12,7 @@ from mcp_agent.workflows.llm.augmented_llm import (
  )
  from mcp_agent.agents.agent import Agent, AgentConfig
  from mcp_agent.logging.logger import get_logger
+ from mcp_agent.workflows.llm.augmented_llm_passthrough import PassthroughLLM

  if TYPE_CHECKING:
  from mcp_agent.context import Context
@@ -89,45 +90,33 @@ class EvaluatorOptimizerLLM(AugmentedLLM[MessageParamT, MessageT]):
  evaluator: str | Agent | AugmentedLLM,
  min_rating: QualityRating = QualityRating.GOOD,
  max_refinements: int = 3,
- llm_factory: Callable[[Agent], AugmentedLLM]
- | None = None, # TODO: Remove legacy - factory should only be needed for str evaluator
+ llm_factory: Callable[[Agent], AugmentedLLM] | None = None,
  context: Optional["Context"] = None,
- name: Optional[str] = None, # Allow overriding the name
- instruction: Optional[str] = None, # Allow overriding the instruction
+ name: Optional[str] = None,
+ instruction: Optional[str] = None,
  ):
  """
  Initialize the evaluator-optimizer workflow.

  Args:
- generator: The agent/LLM/workflow that generates responses. Can be:
- - An Agent that will be converted to an AugmentedLLM
- - An AugmentedLLM instance
- - An Orchestrator/Router/ParallelLLM workflow
- evaluator_agent: The agent/LLM that evaluates responses
- evaluation_criteria: Criteria for the evaluator to assess responses
+ generator: The agent/LLM/workflow that generates responses
+ evaluator: The evaluator (string instruction, Agent or AugmentedLLM)
  min_rating: Minimum acceptable quality rating
  max_refinements: Maximum refinement iterations
- llm_factory: Optional factory to create LLMs from agents
+ llm_factory: Factory to create LLMs from agents when needed
  name: Optional name for the workflow (defaults to generator's name)
  instruction: Optional instruction (defaults to generator's instruction)
-
- Note on History Management:
- This workflow manages two distinct history contexts:
- 1. Generator History: Controlled by the generator's use_history setting. When False,
- each refinement iteration starts fresh without previous context.
- 2. Evaluator History: Always disabled as each evaluation should be independent
- and based solely on the current response.
  """
- # Set up initial instance attributes - allow name override
- self.name = name or generator.name
+ # Set initial attributes
+ self.name = name or getattr(generator, "name", "EvaluatorOptimizer")
  self.llm_factory = llm_factory
  self.generator = generator
  self.evaluator = evaluator
  self.min_rating = min_rating
  self.max_refinements = max_refinements

- # Determine generator's history setting before super().__init__
-
+ # Determine generator's history setting directly based on type
+ self.generator_use_history = False
  if isinstance(generator, Agent):
  self.generator_use_history = generator.config.use_history
  elif isinstance(generator, AugmentedLLM):
@@ -135,90 +124,55 @@ class EvaluatorOptimizerLLM(AugmentedLLM[MessageParamT, MessageT]):
  generator.aggregator, Agent
  ):
  self.generator_use_history = generator.aggregator.config.use_history
- else:
+ elif hasattr(generator, "default_request_params"):
  self.generator_use_history = getattr(
- generator,
- "use_history",
- getattr(generator.default_request_params, "use_history", False),
+ generator.default_request_params, "use_history", False
  )
- # Handle ChainProxy with type checking
- elif hasattr(generator, "_sequence") and hasattr(generator, "_agent_proxies"):
- # This is how we detect a ChainProxy without directly importing it
- # For ChainProxy, we'll default use_history to False
- self.generator_use_history = False
- else:
- raise ValueError(f"Unsupported optimizer type: {type(generator)}")
-
- # Now we can call super().__init__ which will use generator_use_history
- super().__init__(context=context, name=name or generator.name)
+ # All other types default to False

- # Add a PassthroughLLM as _llm property for compatibility with Orchestrator
- from mcp_agent.workflows.llm.augmented_llm import PassthroughLLM
+ # Initialize parent class
+ super().__init__(context=context, name=name or getattr(generator, "name", None))

+ # Create a PassthroughLLM as _llm property
+ # TODO -- remove this when we fix/remove the inheritance hierarchy
  self._llm = PassthroughLLM(name=f"{self.name}_passthrough", context=context)

- # Set up the generator
-
+ # Set up the generator based on type
  if isinstance(generator, Agent):
  if not llm_factory:
- raise ValueError("llm_factory is required when using an Agent")
-
- # Only create new LLM if agent doesn't have one
- if hasattr(generator, "_llm") and generator._llm:
- self.generator_llm = generator._llm
- else:
- self.generator_llm = llm_factory(agent=generator)
+ raise ValueError(
+ "llm_factory is required when using an Agent generator"
+ )

+ # Use existing LLM if available, otherwise create new one
+ self.generator_llm = getattr(generator, "_llm", None) or llm_factory(
+ agent=generator
+ )
  self.aggregator = generator
- self.instruction = (
- instruction # Use provided instruction if any
- or (
- generator.instruction
- if isinstance(generator.instruction, str)
- else None
- ) # Fallback to generator's
+ self.instruction = instruction or (
+ generator.instruction
+ if isinstance(generator.instruction, str)
+ else None
  )
- elif hasattr(generator, "_sequence") and hasattr(generator, "_agent_proxies"):
- # For ChainProxy, use it directly for generation
+ elif isinstance(generator, AugmentedLLM):
+ self.generator_llm = generator
+ self.aggregator = getattr(generator, "aggregator", None)
+ self.instruction = instruction or generator.instruction
+ else:
+ # ChainProxy-like object
  self.generator_llm = generator
  self.aggregator = None
  self.instruction = (
  instruction or f"Chain of agents: {', '.join(generator._sequence)}"
  )

- elif isinstance(generator, AugmentedLLM):
- self.generator_llm = generator
- self.aggregator = generator.aggregator
- self.instruction = generator.instruction
-
- # Set up the evaluator - evaluations should be independent, so history is always disabled
- if isinstance(evaluator, AugmentedLLM):
- self.evaluator_llm = evaluator
- # Override evaluator's history setting
- if hasattr(evaluator, "default_request_params"):
- evaluator.default_request_params.use_history = False
- elif isinstance(evaluator, Agent):
- if not llm_factory:
- raise ValueError(
- "llm_factory is required when using an Agent evaluator"
- )
-
- # Create evaluator with history disabled
- if hasattr(evaluator, "_llm") and evaluator._llm:
- self.evaluator_llm = evaluator._llm
- if hasattr(self.evaluator_llm, "default_request_params"):
- self.evaluator_llm.default_request_params.use_history = False
- else:
- # Force history off in config before creating LLM
- evaluator.config.use_history = False
- self.evaluator_llm = llm_factory(agent=evaluator)
- elif isinstance(evaluator, str):
+ # Set up the evaluator - always disable history
+ if isinstance(evaluator, str):
  if not llm_factory:
  raise ValueError(
  "llm_factory is required when using a string evaluator"
  )

- # Create evaluator agent with history disabled
  evaluator_agent = Agent(
  name="Evaluator",
  instruction=evaluator,
@@ -226,17 +180,33 @@ class EvaluatorOptimizerLLM(AugmentedLLM[MessageParamT, MessageT]):
  name="Evaluator",
  instruction=evaluator,
  servers=[],
- use_history=False, # Force history off for evaluator
+ use_history=False,
  ),
  )
  self.evaluator_llm = llm_factory(agent=evaluator_agent)
+ elif isinstance(evaluator, Agent):
+ if not llm_factory:
+ raise ValueError(
+ "llm_factory is required when using an Agent evaluator"
+ )
+
+ # Disable history and use/create LLM
+ evaluator.config.use_history = False
+ self.evaluator_llm = getattr(evaluator, "_llm", None) or llm_factory(
+ agent=evaluator
+ )
+ elif isinstance(evaluator, AugmentedLLM):
+ self.evaluator_llm = evaluator
+ # Ensure history is disabled
+ if hasattr(self.evaluator_llm, "default_request_params"):
+ self.evaluator_llm.default_request_params.use_history = False
  else:
  raise ValueError(f"Unsupported evaluator type: {type(evaluator)}")

- # Track iteration history (for the workflow itself)
+ # Track iteration history
  self.refinement_history = []

- # Set up workflow's default params based on generator's history setting
+ # Set up workflow's default params
  self.default_request_params = self._initialize_default_params({})

  # Ensure evaluator's request params have history disabled
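
After this refactor the constructor accepts the generator and evaluator in several shapes (an Agent, an AugmentedLLM, or, for the evaluator, a plain instruction string) and only requires llm_factory when it actually has to build an LLM. A hedged construction sketch; the parameter names and QualityRating come from the code above, while the Agent arguments and my_llm_factory are illustrative assumptions:

    from mcp_agent.agents.agent import Agent

    generator = Agent(name="writer", instruction="Draft the requested text.")

    workflow = EvaluatorOptimizerLLM(
        generator=generator,
        # a plain string becomes an internal "Evaluator" agent with history disabled
        evaluator="Rate clarity and factual accuracy of the draft.",
        min_rating=QualityRating.GOOD,
        max_refinements=3,
        llm_factory=my_llm_factory,  # assumed: Callable[[Agent], AugmentedLLM]
    )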
@@ -1,7 +1,6 @@
  from abc import abstractmethod

  from typing import (
- Any,
  Generic,
  List,
  Optional,
@@ -9,7 +8,6 @@ from typing import (
  Type,
  TypeVar,
  TYPE_CHECKING,
- Union,
  )

  from pydantic import Field
@@ -567,7 +565,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
  text_parts.append(part.text)
  if text_parts:
  return "\n".join(text_parts)
-
+
  # For objects with content attribute
  if hasattr(message, "content"):
  content = message.content
@@ -575,7 +573,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
  return content
  elif hasattr(content, "text"):
  return content.text
-
+
  # Default fallback
  return str(message)

@@ -588,7 +586,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
  result = self.message_param_str(message)
  if result != str(message):
  return result
-
+
  # Additional handling for output-specific formats
  if hasattr(message, "content"):
  content = message.content
@@ -600,7 +598,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
  text_parts.append(block.text)
  if text_parts:
  return "\n".join(text_parts)
-
+
  # Default fallback
  return str(message)

@@ -650,7 +648,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
  ):
  """
  Display information about a loaded prompt template.
-
+
  Args:
  prompt_name: The name of the prompt
  description: Optional description of the prompt
@@ -679,11 +677,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
  prompt_name: The name of the prompt being applied

  Returns:
- String representation of the assistant's response if generated,
+ String representation of the assistant's response if generated,
  or the last assistant message in the prompt
  """
  prompt_messages: List[PromptMessage] = prompt_result.messages
-
+
  # Check if we have any messages
  if not prompt_messages:
  return "Prompt contains no messages"
@@ -698,14 +696,16 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
  message_count=len(prompt_messages),
  arguments=arguments,
  )
-
+
  # Check the last message role
  last_message = prompt_messages[-1]
-
+
  if last_message.role == "user":
  # For user messages: Add all previous messages to history, then generate response to the last one
- self.logger.debug("Last message in prompt is from user, generating assistant response")
-
+ self.logger.debug(
+ "Last message in prompt is from user, generating assistant response"
+ )
+
  # Add all but the last message to history
  if len(prompt_messages) > 1:
  previous_messages = prompt_messages[:-1]
@@ -713,87 +713,28 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, Message
  for msg in previous_messages:
  converted.append(self.type_converter.from_mcp_prompt_message(msg))
  self.history.extend(converted, is_prompt=True)
-
+
  # Extract the user's question and generate a response
  user_content = last_message.content
- user_text = user_content.text if hasattr(user_content, "text") else str(user_content)
-
+ user_text = (
+ user_content.text
+ if hasattr(user_content, "text")
+ else str(user_content)
+ )
+
  return await self.generate_str(user_text)
  else:
  # For assistant messages: Add all messages to history and return the last one
- self.logger.debug("Last message in prompt is from assistant, returning it directly")
-
+ self.logger.debug(
+ "Last message in prompt is from assistant, returning it directly"
+ )
+
  # Convert and add all messages to history
  converted = []
  for msg in prompt_messages:
  converted.append(self.type_converter.from_mcp_prompt_message(msg))
  self.history.extend(converted, is_prompt=True)
-
+
  # Return the assistant's message
  content = last_message.content
  return content.text if hasattr(content, "text") else str(content)
-
-
-
- class PassthroughLLM(AugmentedLLM):
- """
- A specialized LLM implementation that simply passes through input messages without modification.
-
- This is useful for cases where you need an object with the AugmentedLLM interface
- but want to preserve the original message without any processing, such as in a
- parallel workflow where no fan-in aggregation is needed.
- """
-
- def __init__(self, name: str = "Passthrough", context=None, **kwargs):
- super().__init__(name=name, context=context, **kwargs)
-
- async def generate(
- self,
- message: Union[str, MessageParamT, List[MessageParamT]],
- request_params: Optional[RequestParams] = None,
- ) -> Union[List[MessageT], Any]:
- """Simply return the input message as is."""
- # Return in the format expected by the caller
- return [message] if isinstance(message, list) else message
-
- async def generate_str(
- self,
- message: Union[str, MessageParamT, List[MessageParamT]],
- request_params: Optional[RequestParams] = None,
- ) -> str:
- """Return the input message as a string."""
- self.show_user_message(message, model="fastagent-passthrough", chat_turn=0)
- await self.show_assistant_message(message, title="ASSISTANT/PASSTHROUGH")
-
- return str(message)
-
- async def generate_structured(
- self,
- message: Union[str, MessageParamT, List[MessageParamT]],
- response_model: Type[ModelT],
- request_params: Optional[RequestParams] = None,
- ) -> ModelT:
- """
- Return the input message as the requested model type.
- This is a best-effort implementation - it may fail if the
- message cannot be converted to the requested model.
- """
- if isinstance(message, response_model):
- return message
- elif isinstance(message, dict):
- return response_model(**message)
- elif isinstance(message, str):
- try:
- # Try to parse as JSON if it's a string
- import json
-
- data = json.loads(message)
- return response_model(**data)
- except: # noqa: E722
- raise ValueError(
- f"Cannot convert message of type {type(message)} to {response_model}"
- )
- else:
- raise ValueError(
- f"Cannot convert message of type {type(message)} to {response_model}"
- )
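
Taken together with the new import at the top of the evaluator-optimizer module, PassthroughLLM appears to be relocated rather than deleted: it moves out of augmented_llm.py into mcp_agent.workflows.llm.augmented_llm_passthrough. A small usage sketch, assuming the class keeps the behaviour shown in the removed block (echoing its input rather than calling a model):

    import asyncio
    from mcp_agent.workflows.llm.augmented_llm_passthrough import PassthroughLLM

    async def main():
        llm = PassthroughLLM(name="debug_passthrough")
        # the passthrough simply returns whatever it is given
        print(await llm.generate_str("hello"))  # expected to print "hello"

    asyncio.run(main())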
@@ -4,7 +4,6 @@ from typing import Iterable, List, Type

  from pydantic import BaseModel

- import instructor
  from anthropic import Anthropic, AuthenticationError
  from anthropic.types import (
  ContentBlock,
@@ -27,8 +26,8 @@ from mcp.types import (
  TextContent,
  TextResourceContents,
  )
+ from pydantic_core import from_json

- from mcp_agent.workflows.router.router_llm import StructuredResponse
  from mcp_agent.workflows.llm.augmented_llm import (
  AugmentedLLM,
  ModelT,
@@ -96,7 +95,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
  "Please check that your API key is valid and not expired.",
  ) from e

- # Always include prompt messages, but only include conversation history
+ # Always include prompt messages, but only include conversation history
  # if use_history is True
  messages.extend(self.history.get(include_history=params.use_history))

@@ -295,10 +294,10 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
  if params.use_history:
  # Get current prompt messages
  prompt_messages = self.history.get(include_history=False)
-
+
  # Calculate new conversation messages (excluding prompts)
- new_messages = messages[len(prompt_messages):]
-
+ new_messages = messages[len(prompt_messages) :]
+
  # Update conversation history
  self.history.set(new_messages)

@@ -367,10 +366,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
  response_model: Type[ModelT],
  request_params: RequestParams | None = None,
  ) -> ModelT:
- # First we invoke the LLM to generate a string response
- # We need to do this in a two-step process because Instructor doesn't
- # know how to invoke MCP tools via call_tool, so we'll handle all the
- # processing first and then pass the final response through Instructor
+ # TODO -- simiar to the OAI version, we should create a tool call for the expected schema
  response = await self.generate_str(
  message=message,
  request_params=request_params,
@@ -378,27 +374,9 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
  # Don't try to parse if we got no response
  if not response:
  self.logger.error("No response from generate_str")
- return StructuredResponse(categories=[])
-
- # Next we pass the text through instructor to extract structured data
- client = instructor.from_anthropic(
- Anthropic(api_key=self._api_key(self.context.config)),
- )
+ return None

- params = self.get_request_params(request_params)
- model = await self.select_model(params)
-
- # Extract structured data from natural language
- structured_response = client.chat.completions.create(
- model=model,
- response_model=response_model,
- messages=[{"role": "user", "content": response}],
- max_tokens=params.maxTokens,
- )
- await self.show_assistant_message(
- str(structured_response), title="ASSISTANT/STRUCTURED"
- )
- return structured_response
+ return response_model.model_validate(from_json(response, allow_partial=True))

  @classmethod
  def convert_message_to_message_param(
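
With this change the Anthropic structured-output path no longer round-trips through Instructor: the plain-text reply from generate_str() is parsed straight into the requested Pydantic model via pydantic_core.from_json, with allow_partial=True tolerating truncated JSON. The parsing step in isolation looks roughly like this (WeatherReport and the JSON payload are invented for illustration):

    from pydantic import BaseModel
    from pydantic_core import from_json

    class WeatherReport(BaseModel):  # hypothetical response_model
        city: str
        temperature_c: float

    # stands in for the text returned by generate_str()
    response = '{"city": "Oslo", "temperature_c": -3.5}'
    report = WeatherReport.model_validate(from_json(response, allow_partial=True))
    print(report)  # city='Oslo' temperature_c=-3.5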
@@ -2,15 +2,16 @@ import json
  import os
  from typing import Iterable, List, Type
  from mcp.types import PromptMessage
- import instructor
  from openai import OpenAI, AuthenticationError
+
+ # from openai.types.beta.chat import
  from openai.types.chat import (
  ChatCompletionAssistantMessageParam,
+ ChatCompletionMessageParam,
  ChatCompletionContentPartParam,
  ChatCompletionContentPartTextParam,
  ChatCompletionContentPartRefusalParam,
  ChatCompletionMessage,
- ChatCompletionMessageParam,
  ChatCompletionSystemMessageParam,
  ChatCompletionToolParam,
  ChatCompletionToolMessageParam,
@@ -128,7 +129,12 @@ class OpenAIAugmentedLLM(
  self.context.config.openai.base_url if self.context.config.openai else None
  )

- async def generate(self, message, request_params: RequestParams | None = None):
+ async def generate(
+ self,
+ message,
+ request_params: RequestParams | None = None,
+ response_model: Type[ModelT] | None = None,
+ ) -> List[ChatCompletionMessage]:
  """
  Process a query using an LLM and available tools.
  The default implementation uses OpenAI's ChatCompletion as the LLM.
@@ -152,7 +158,7 @@ class OpenAIAugmentedLLM(
  ChatCompletionSystemMessageParam(role="system", content=system_prompt)
  )

- # Always include prompt messages, but only include conversation history
+ # Always include prompt messages, but only include conversation history
  # if use_history is True
  messages.extend(self.history.get(include_history=params.use_history))

@@ -179,7 +185,7 @@ class OpenAIAugmentedLLM(
  for tool in response.tools
  ]
  if not available_tools:
- available_tools = None
+ available_tools = []

  responses: List[ChatCompletionMessage] = []
  model = await self.select_model(params)
@@ -215,9 +221,16 @@ class OpenAIAugmentedLLM(
  self.logger.debug(f"{arguments}")
  self._log_chat_progress(chat_turn, model=model)

- executor_result = await self.executor.execute(
- openai_client.chat.completions.create, **arguments
- )
+ if response_model is None:
+ executor_result = await self.executor.execute(
+ openai_client.chat.completions.create, **arguments
+ )
+ else:
+ executor_result = await self.executor.execute(
+ openai_client.beta.chat.completions.parse,
+ **arguments,
+ response_format=response_model,
+ )

  response = executor_result[0]

@@ -334,10 +347,10 @@ class OpenAIAugmentedLLM(
  if params.use_history:
  # Get current prompt messages
  prompt_messages = self.history.get(include_history=False)
-
+
  # Calculate new conversation messages (excluding prompts)
- new_messages = messages[len(prompt_messages):]
-
+ new_messages = messages[len(prompt_messages) :]
+
  # Update conversation history
  self.history.set(new_messages)

@@ -379,40 +392,21 @@ class OpenAIAugmentedLLM(
  response_model: Type[ModelT],
  request_params: RequestParams | None = None,
  ) -> ModelT:
- # First we invoke the LLM to generate a string response
- # We need to do this in a two-step process because Instructor doesn't
- # know how to invoke MCP tools via call_tool, so we'll handle all the
- # processing first and then pass the final response through Instructor
- response = await self.generate_str(
+ responses = await self.generate(
  message=message,
  request_params=request_params,
- )
-
- # Next we pass the text through instructor to extract structured data
- client = instructor.from_openai(
- OpenAI(
- api_key=self._api_key(),
- base_url=self._base_url(),
- ),
- mode=instructor.Mode.TOOLS_STRICT,
- )
-
- params = self.get_request_params(request_params)
- model = await self.select_model(params)
-
- # Extract structured data from natural language
- structured_response = client.chat.completions.create(
- model=model,
  response_model=response_model,
- messages=[
- {"role": "user", "content": response},
- ],
- )
- await self.show_assistant_message(
- str(structured_response), title="ASSISTANT/STRUCTURED"
  )
+ return responses[0].parsed
+
+ # return response_model.model_validate(
+ # from_json(responses[0].content, allow_partial=True)
+ # )
+ # part1 = from_json(response, allow_partial=True)
+ # return response_model.model_validate(part1)

- return structured_response
+ # TODO -- would prefer to use the OpenAI message[0].parsed function here
+ # return response_model.model_validate(from_json(response, allow_partial=True))

  async def pre_tool_call(self, tool_call_id: str | None, request: CallToolRequest):
  return request
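
On the OpenAI side, structured generation now flows through the beta parse endpoint: generate() gains a response_model parameter, and when it is set the request goes to openai_client.beta.chat.completions.parse(..., response_format=response_model), with generate_structured() returning responses[0].parsed. Outside the executor wrapper, the underlying OpenAI call looks roughly like this (the model name, prompt, and TaskPlan schema are assumptions):

    from openai import OpenAI
    from pydantic import BaseModel

    class TaskPlan(BaseModel):  # hypothetical response_model
        steps: list[str]

    client = OpenAI()
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-mini",  # assumed model
        messages=[{"role": "user", "content": "Plan a three-step hello-world project."}],
        response_format=TaskPlan,  # Pydantic model used as the response schema
    )
    plan = completion.choices[0].message.parsed  # a TaskPlan instance (or None on refusal)
    print(plan)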