fast-agent-mcp 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fast_agent_mcp-0.1.6.dist-info → fast_agent_mcp-0.1.8.dist-info}/METADATA +12 -6
- {fast_agent_mcp-0.1.6.dist-info → fast_agent_mcp-0.1.8.dist-info}/RECORD +23 -22
- mcp_agent/core/agent_app.py +38 -24
- mcp_agent/core/decorators.py +3 -2
- mcp_agent/core/enhanced_prompt.py +106 -20
- mcp_agent/core/factory.py +28 -66
- mcp_agent/human_input/handler.py +4 -1
- mcp_agent/mcp/mcp_aggregator.py +16 -12
- mcp_agent/resources/examples/researcher/researcher-eval.py +1 -1
- mcp_agent/resources/examples/researcher/researcher.py +1 -1
- mcp_agent/resources/examples/workflows/orchestrator.py +5 -4
- mcp_agent/resources/examples/workflows/router.py +0 -2
- mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py +57 -87
- mcp_agent/workflows/llm/augmented_llm.py +25 -84
- mcp_agent/workflows/llm/augmented_llm_anthropic.py +8 -30
- mcp_agent/workflows/llm/augmented_llm_openai.py +34 -40
- mcp_agent/workflows/llm/augmented_llm_passthrough.py +61 -0
- mcp_agent/workflows/llm/model_factory.py +5 -3
- mcp_agent/workflows/orchestrator/orchestrator.py +62 -153
- mcp_agent/workflows/router/router_llm.py +18 -24
- {fast_agent_mcp-0.1.6.dist-info → fast_agent_mcp-0.1.8.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.1.6.dist-info → fast_agent_mcp-0.1.8.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.1.6.dist-info → fast_agent_mcp-0.1.8.dist-info}/licenses/LICENSE +0 -0
mcp_agent/resources/examples/workflows/orchestrator.py

@@ -45,13 +45,14 @@ fast = FastAgent("Orchestrator-Workers")
 @fast.orchestrator(
     name="orchestrate",
     agents=["finder", "writer", "proofreader"],
-    plan_type="
+    plan_type="iterative",
 )
 async def main():
     async with fast.run() as agent:
-        await agent
-
-
+        await agent()
+        # await agent.author(
+        #     "write a 250 word short story about kittens discovering a castle, and save it to short_story.md"
+        # )

         # The orchestrator can be used just like any other agent
         task = (
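The updated example starts an interactive session with `await agent()` and sets `plan_type="iterative"` on the orchestrator. A minimal sketch of running the revised example; the `FastAgent` import path and the one-shot call form are assumptions, not confirmed by this diff:

```python
import asyncio

# Import path assumed from the package layout shown above; adjust if FastAgent lives elsewhere.
from mcp_agent.core.fastagent import FastAgent

fast = FastAgent("Orchestrator-Workers")


@fast.orchestrator(
    name="orchestrate",
    agents=["finder", "writer", "proofreader"],
    plan_type="iterative",  # the example now requests iterative planning
)
async def main():
    async with fast.run() as agent:
        await agent()  # opens the interactive prompt
        # await agent("summarise the README")  # presumably a one-shot request works too


if __name__ == "__main__":
    asyncio.run(main())
```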
mcp_agent/resources/examples/workflows/router.py

@@ -25,7 +25,6 @@ SAMPLE_REQUESTS = [
     name="fetcher",
     instruction="""You are an agent, with a tool enabling you to fetch URLs.""",
     servers=["fetch"],
-    model="haiku",
 )
 @fast.agent(
     name="code_expert",
@@ -33,7 +32,6 @@ SAMPLE_REQUESTS = [
     When asked about code, architecture, or development practices,
     you provide thorough and practical insights.""",
     servers=["filesystem"],
-    model="gpt-4o",
 )
 @fast.agent(
     name="general_assistant",
mcp_agent/workflows/evaluator_optimizer/evaluator_optimizer.py

@@ -12,6 +12,7 @@ from mcp_agent.workflows.llm.augmented_llm import (
 )
 from mcp_agent.agents.agent import Agent, AgentConfig
 from mcp_agent.logging.logger import get_logger
+from mcp_agent.workflows.llm.augmented_llm_passthrough import PassthroughLLM

 if TYPE_CHECKING:
     from mcp_agent.context import Context
@@ -89,45 +90,33 @@ class EvaluatorOptimizerLLM(AugmentedLLM[MessageParamT, MessageT]):
         evaluator: str | Agent | AugmentedLLM,
         min_rating: QualityRating = QualityRating.GOOD,
         max_refinements: int = 3,
-        llm_factory: Callable[[Agent], AugmentedLLM]
-        | None = None,  # TODO: Remove legacy - factory should only be needed for str evaluator
+        llm_factory: Callable[[Agent], AugmentedLLM] | None = None,
         context: Optional["Context"] = None,
-        name: Optional[str] = None,
-        instruction: Optional[str] = None,
+        name: Optional[str] = None,
+        instruction: Optional[str] = None,
     ):
         """
         Initialize the evaluator-optimizer workflow.

         Args:
-            generator: The agent/LLM/workflow that generates responses
-
-                - An AugmentedLLM instance
-                - An Orchestrator/Router/ParallelLLM workflow
-            evaluator_agent: The agent/LLM that evaluates responses
-            evaluation_criteria: Criteria for the evaluator to assess responses
+            generator: The agent/LLM/workflow that generates responses
+            evaluator: The evaluator (string instruction, Agent or AugmentedLLM)
             min_rating: Minimum acceptable quality rating
             max_refinements: Maximum refinement iterations
-            llm_factory:
+            llm_factory: Factory to create LLMs from agents when needed
             name: Optional name for the workflow (defaults to generator's name)
             instruction: Optional instruction (defaults to generator's instruction)
-
-        Note on History Management:
-            This workflow manages two distinct history contexts:
-            1. Generator History: Controlled by the generator's use_history setting. When False,
-               each refinement iteration starts fresh without previous context.
-            2. Evaluator History: Always disabled as each evaluation should be independent
-               and based solely on the current response.
         """
-        # Set
-        self.name = name or generator
+        # Set initial attributes
+        self.name = name or getattr(generator, "name", "EvaluatorOptimizer")
         self.llm_factory = llm_factory
         self.generator = generator
         self.evaluator = evaluator
         self.min_rating = min_rating
         self.max_refinements = max_refinements

-        # Determine generator's history setting
-
+        # Determine generator's history setting directly based on type
+        self.generator_use_history = False
         if isinstance(generator, Agent):
             self.generator_use_history = generator.config.use_history
         elif isinstance(generator, AugmentedLLM):
@@ -135,90 +124,55 @@ class EvaluatorOptimizerLLM(AugmentedLLM[MessageParamT, MessageT]):
             generator.aggregator, Agent
         ):
             self.generator_use_history = generator.aggregator.config.use_history
-
+        elif hasattr(generator, "default_request_params"):
             self.generator_use_history = getattr(
-                generator,
-                "use_history",
-                getattr(generator.default_request_params, "use_history", False),
+                generator.default_request_params, "use_history", False
             )
-        #
-        elif hasattr(generator, "_sequence") and hasattr(generator, "_agent_proxies"):
-            # This is how we detect a ChainProxy without directly importing it
-            # For ChainProxy, we'll default use_history to False
-            self.generator_use_history = False
-        else:
-            raise ValueError(f"Unsupported optimizer type: {type(generator)}")
-
-        # Now we can call super().__init__ which will use generator_use_history
-        super().__init__(context=context, name=name or generator.name)
+        # All other types default to False

-        #
-
+        # Initialize parent class
+        super().__init__(context=context, name=name or getattr(generator, "name", None))

+        # Create a PassthroughLLM as _llm property
+        # TODO -- remove this when we fix/remove the inheritance hierarchy
         self._llm = PassthroughLLM(name=f"{self.name}_passthrough", context=context)

-        # Set up the generator
-
+        # Set up the generator based on type
         if isinstance(generator, Agent):
             if not llm_factory:
-                raise ValueError(
-
-
-            if hasattr(generator, "_llm") and generator._llm:
-                self.generator_llm = generator._llm
-            else:
-                self.generator_llm = llm_factory(agent=generator)
+                raise ValueError(
+                    "llm_factory is required when using an Agent generator"
+                )

+            # Use existing LLM if available, otherwise create new one
+            self.generator_llm = getattr(generator, "_llm", None) or llm_factory(
+                agent=generator
+            )
             self.aggregator = generator
-            self.instruction = (
-                instruction
-
-
-                if isinstance(generator.instruction, str)
-                else None
-            )  # Fallback to generator's
+            self.instruction = instruction or (
+                generator.instruction
+                if isinstance(generator.instruction, str)
+                else None
             )
-        elif
-
+        elif isinstance(generator, AugmentedLLM):
+            self.generator_llm = generator
+            self.aggregator = getattr(generator, "aggregator", None)
+            self.instruction = instruction or generator.instruction
+        else:
+            # ChainProxy-like object
             self.generator_llm = generator
             self.aggregator = None
             self.instruction = (
                 instruction or f"Chain of agents: {', '.join(generator._sequence)}"
             )

-
-
-            self.aggregator = generator.aggregator
-            self.instruction = generator.instruction
-
-        # Set up the evaluator - evaluations should be independent, so history is always disabled
-        if isinstance(evaluator, AugmentedLLM):
-            self.evaluator_llm = evaluator
-            # Override evaluator's history setting
-            if hasattr(evaluator, "default_request_params"):
-                evaluator.default_request_params.use_history = False
-        elif isinstance(evaluator, Agent):
-            if not llm_factory:
-                raise ValueError(
-                    "llm_factory is required when using an Agent evaluator"
-                )
-
-            # Create evaluator with history disabled
-            if hasattr(evaluator, "_llm") and evaluator._llm:
-                self.evaluator_llm = evaluator._llm
-                if hasattr(self.evaluator_llm, "default_request_params"):
-                    self.evaluator_llm.default_request_params.use_history = False
-            else:
-                # Force history off in config before creating LLM
-                evaluator.config.use_history = False
-                self.evaluator_llm = llm_factory(agent=evaluator)
-        elif isinstance(evaluator, str):
+        # Set up the evaluator - always disable history
+        if isinstance(evaluator, str):
             if not llm_factory:
                 raise ValueError(
                     "llm_factory is required when using a string evaluator"
                 )

-            # Create evaluator agent with history disabled
             evaluator_agent = Agent(
                 name="Evaluator",
                 instruction=evaluator,
@@ -226,17 +180,33 @@ class EvaluatorOptimizerLLM(AugmentedLLM[MessageParamT, MessageT]):
                     name="Evaluator",
                     instruction=evaluator,
                     servers=[],
-                    use_history=False,
+                    use_history=False,
                 ),
             )
             self.evaluator_llm = llm_factory(agent=evaluator_agent)
+        elif isinstance(evaluator, Agent):
+            if not llm_factory:
+                raise ValueError(
+                    "llm_factory is required when using an Agent evaluator"
+                )
+
+            # Disable history and use/create LLM
+            evaluator.config.use_history = False
+            self.evaluator_llm = getattr(evaluator, "_llm", None) or llm_factory(
+                agent=evaluator
+            )
+        elif isinstance(evaluator, AugmentedLLM):
+            self.evaluator_llm = evaluator
+            # Ensure history is disabled
+            if hasattr(self.evaluator_llm, "default_request_params"):
+                self.evaluator_llm.default_request_params.use_history = False
         else:
             raise ValueError(f"Unsupported evaluator type: {type(evaluator)}")

-        # Track iteration history
+        # Track iteration history
         self.refinement_history = []

-        # Set up workflow's default params
+        # Set up workflow's default params
         self.default_request_params = self._initialize_default_params({})

         # Ensure evaluator's request params have history disabled
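The constructor refactor above flattens the generator/evaluator branches: `llm_factory` is now only required when the generator or evaluator is an `Agent` or a plain string instruction. A hedged sketch of the resulting call shape; `writer_agent` and `make_llm` are hypothetical placeholders, and the `QualityRating` import location is assumed:

```python
from mcp_agent.workflows.evaluator_optimizer.evaluator_optimizer import (
    EvaluatorOptimizerLLM,
    QualityRating,  # assumed to live in this module alongside the workflow
)

# writer_agent: an existing Agent (or AugmentedLLM / chain-like workflow)
# make_llm: a Callable[[Agent], AugmentedLLM], e.g. whatever factory the app already uses
eo = EvaluatorOptimizerLLM(
    generator=writer_agent,
    evaluator="Rate the draft for accuracy, completeness and tone.",  # str -> an Evaluator Agent is built internally
    min_rating=QualityRating.GOOD,
    max_refinements=3,
    llm_factory=make_llm,  # required here because the generator/evaluator are not already AugmentedLLMs
)
```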
mcp_agent/workflows/llm/augmented_llm.py

@@ -1,7 +1,6 @@
 from abc import abstractmethod

 from typing import (
-    Any,
     Generic,
     List,
     Optional,
@@ -9,7 +8,6 @@ from typing import (
     Type,
     TypeVar,
     TYPE_CHECKING,
-    Union,
 )

 from pydantic import Field
@@ -567,7 +565,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
                 text_parts.append(part.text)
             if text_parts:
                 return "\n".join(text_parts)
-
+
         # For objects with content attribute
         if hasattr(message, "content"):
             content = message.content
@@ -575,7 +573,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
                 return content
             elif hasattr(content, "text"):
                 return content.text
-
+
         # Default fallback
         return str(message)

@@ -588,7 +586,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
         result = self.message_param_str(message)
         if result != str(message):
             return result
-
+
         # Additional handling for output-specific formats
         if hasattr(message, "content"):
             content = message.content
@@ -600,7 +598,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
                     text_parts.append(block.text)
                 if text_parts:
                     return "\n".join(text_parts)
-
+
         # Default fallback
         return str(message)

@@ -650,7 +648,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
     ):
         """
         Display information about a loaded prompt template.
-
+
         Args:
             prompt_name: The name of the prompt
             description: Optional description of the prompt
@@ -679,11 +677,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
             prompt_name: The name of the prompt being applied

         Returns:
-            String representation of the assistant's response if generated,
+            String representation of the assistant's response if generated,
             or the last assistant message in the prompt
         """
         prompt_messages: List[PromptMessage] = prompt_result.messages
-
+
         # Check if we have any messages
         if not prompt_messages:
             return "Prompt contains no messages"
@@ -698,14 +696,16 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
             message_count=len(prompt_messages),
             arguments=arguments,
         )
-
+
         # Check the last message role
         last_message = prompt_messages[-1]
-
+
         if last_message.role == "user":
             # For user messages: Add all previous messages to history, then generate response to the last one
-            self.logger.debug(
-
+            self.logger.debug(
+                "Last message in prompt is from user, generating assistant response"
+            )
+
             # Add all but the last message to history
             if len(prompt_messages) > 1:
                 previous_messages = prompt_messages[:-1]
@@ -713,87 +713,28 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol[MessageParamT, MessageT]):
                 for msg in previous_messages:
                     converted.append(self.type_converter.from_mcp_prompt_message(msg))
                 self.history.extend(converted, is_prompt=True)
-
+
             # Extract the user's question and generate a response
             user_content = last_message.content
-            user_text =
-
+            user_text = (
+                user_content.text
+                if hasattr(user_content, "text")
+                else str(user_content)
+            )
+
             return await self.generate_str(user_text)
         else:
             # For assistant messages: Add all messages to history and return the last one
-            self.logger.debug(
-
+            self.logger.debug(
+                "Last message in prompt is from assistant, returning it directly"
+            )
+
             # Convert and add all messages to history
             converted = []
             for msg in prompt_messages:
                 converted.append(self.type_converter.from_mcp_prompt_message(msg))
             self.history.extend(converted, is_prompt=True)
-
+
             # Return the assistant's message
             content = last_message.content
             return content.text if hasattr(content, "text") else str(content)
-
-
-
-class PassthroughLLM(AugmentedLLM):
-    """
-    A specialized LLM implementation that simply passes through input messages without modification.
-
-    This is useful for cases where you need an object with the AugmentedLLM interface
-    but want to preserve the original message without any processing, such as in a
-    parallel workflow where no fan-in aggregation is needed.
-    """
-
-    def __init__(self, name: str = "Passthrough", context=None, **kwargs):
-        super().__init__(name=name, context=context, **kwargs)
-
-    async def generate(
-        self,
-        message: Union[str, MessageParamT, List[MessageParamT]],
-        request_params: Optional[RequestParams] = None,
-    ) -> Union[List[MessageT], Any]:
-        """Simply return the input message as is."""
-        # Return in the format expected by the caller
-        return [message] if isinstance(message, list) else message
-
-    async def generate_str(
-        self,
-        message: Union[str, MessageParamT, List[MessageParamT]],
-        request_params: Optional[RequestParams] = None,
-    ) -> str:
-        """Return the input message as a string."""
-        self.show_user_message(message, model="fastagent-passthrough", chat_turn=0)
-        await self.show_assistant_message(message, title="ASSISTANT/PASSTHROUGH")
-
-        return str(message)
-
-    async def generate_structured(
-        self,
-        message: Union[str, MessageParamT, List[MessageParamT]],
-        response_model: Type[ModelT],
-        request_params: Optional[RequestParams] = None,
-    ) -> ModelT:
-        """
-        Return the input message as the requested model type.
-        This is a best-effort implementation - it may fail if the
-        message cannot be converted to the requested model.
-        """
-        if isinstance(message, response_model):
-            return message
-        elif isinstance(message, dict):
-            return response_model(**message)
-        elif isinstance(message, str):
-            try:
-                # Try to parse as JSON if it's a string
-                import json
-
-                data = json.loads(message)
-                return response_model(**data)
-            except:  # noqa: E722
-                raise ValueError(
-                    f"Cannot convert message of type {type(message)} to {response_model}"
-                )
-        else:
-            raise ValueError(
-                f"Cannot convert message of type {type(message)} to {response_model}"
-            )
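`PassthroughLLM` is removed from `augmented_llm.py`; per the file summary it now lives in the new `mcp_agent/workflows/llm/augmented_llm_passthrough.py` module, which is what `evaluator_optimizer.py` imports above. Downstream code importing it from the old location would need the corresponding change (sketch only):

```python
# Before (0.1.6)
from mcp_agent.workflows.llm.augmented_llm import PassthroughLLM

# After (0.1.8), matching the import added in evaluator_optimizer.py
from mcp_agent.workflows.llm.augmented_llm_passthrough import PassthroughLLM
```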
mcp_agent/workflows/llm/augmented_llm_anthropic.py

@@ -4,7 +4,6 @@ from typing import Iterable, List, Type

 from pydantic import BaseModel

-import instructor
 from anthropic import Anthropic, AuthenticationError
 from anthropic.types import (
     ContentBlock,
@@ -27,8 +26,8 @@ from mcp.types import (
     TextContent,
     TextResourceContents,
 )
+from pydantic_core import from_json

-from mcp_agent.workflows.router.router_llm import StructuredResponse
 from mcp_agent.workflows.llm.augmented_llm import (
     AugmentedLLM,
     ModelT,
@@ -96,7 +95,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
                 "Please check that your API key is valid and not expired.",
             ) from e

-        # Always include prompt messages, but only include conversation history
+        # Always include prompt messages, but only include conversation history
         # if use_history is True
         messages.extend(self.history.get(include_history=params.use_history))

@@ -295,10 +294,10 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         if params.use_history:
             # Get current prompt messages
             prompt_messages = self.history.get(include_history=False)
-
+
             # Calculate new conversation messages (excluding prompts)
-            new_messages = messages[len(prompt_messages):]
-
+            new_messages = messages[len(prompt_messages) :]
+
             # Update conversation history
             self.history.set(new_messages)

@@ -367,10 +366,7 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         response_model: Type[ModelT],
         request_params: RequestParams | None = None,
     ) -> ModelT:
-        #
-        # We need to do this in a two-step process because Instructor doesn't
-        # know how to invoke MCP tools via call_tool, so we'll handle all the
-        # processing first and then pass the final response through Instructor
+        # TODO -- simiar to the OAI version, we should create a tool call for the expected schema
         response = await self.generate_str(
             message=message,
             request_params=request_params,
@@ -378,27 +374,9 @@ class AnthropicAugmentedLLM(AugmentedLLM[MessageParam, Message]):
         # Don't try to parse if we got no response
         if not response:
             self.logger.error("No response from generate_str")
-            return
-
-        # Next we pass the text through instructor to extract structured data
-        client = instructor.from_anthropic(
-            Anthropic(api_key=self._api_key(self.context.config)),
-        )
+            return None

-
-        model = await self.select_model(params)
-
-        # Extract structured data from natural language
-        structured_response = client.chat.completions.create(
-            model=model,
-            response_model=response_model,
-            messages=[{"role": "user", "content": response}],
-            max_tokens=params.maxTokens,
-        )
-        await self.show_assistant_message(
-            str(structured_response), title="ASSISTANT/STRUCTURED"
-        )
-        return structured_response
+        return response_model.model_validate(from_json(response, allow_partial=True))

     @classmethod
     def convert_message_to_message_param(
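`generate_structured` for Anthropic no longer round-trips through Instructor; it validates the plain-text response with `pydantic_core.from_json` plus Pydantic's `model_validate`. A standalone sketch of that parsing pattern (the `WeatherReport` model and the JSON string are illustrative only):

```python
from pydantic import BaseModel
from pydantic_core import from_json


class WeatherReport(BaseModel):
    city: str
    temperature_c: float


# Stand-in for the text returned by generate_str(); allow_partial tolerates
# responses that are cut off mid-object.
response = '{"city": "Oslo", "temperature_c": -3.5}'
report = WeatherReport.model_validate(from_json(response, allow_partial=True))
print(report)  # city='Oslo' temperature_c=-3.5
```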
mcp_agent/workflows/llm/augmented_llm_openai.py

@@ -2,15 +2,16 @@ import json
 import os
 from typing import Iterable, List, Type
 from mcp.types import PromptMessage
-import instructor
 from openai import OpenAI, AuthenticationError
+
+# from openai.types.beta.chat import
 from openai.types.chat import (
     ChatCompletionAssistantMessageParam,
+    ChatCompletionMessageParam,
     ChatCompletionContentPartParam,
     ChatCompletionContentPartTextParam,
     ChatCompletionContentPartRefusalParam,
     ChatCompletionMessage,
-    ChatCompletionMessageParam,
     ChatCompletionSystemMessageParam,
     ChatCompletionToolParam,
     ChatCompletionToolMessageParam,
@@ -128,7 +129,12 @@ class OpenAIAugmentedLLM(
             self.context.config.openai.base_url if self.context.config.openai else None
         )

-    async def generate(
+    async def generate(
+        self,
+        message,
+        request_params: RequestParams | None = None,
+        response_model: Type[ModelT] | None = None,
+    ) -> List[ChatCompletionMessage]:
         """
         Process a query using an LLM and available tools.
         The default implementation uses OpenAI's ChatCompletion as the LLM.
@@ -152,7 +158,7 @@ class OpenAIAugmentedLLM(
                 ChatCompletionSystemMessageParam(role="system", content=system_prompt)
             )

-        # Always include prompt messages, but only include conversation history
+        # Always include prompt messages, but only include conversation history
         # if use_history is True
         messages.extend(self.history.get(include_history=params.use_history))

@@ -179,7 +185,7 @@ class OpenAIAugmentedLLM(
             for tool in response.tools
         ]
         if not available_tools:
-            available_tools =
+            available_tools = []

         responses: List[ChatCompletionMessage] = []
         model = await self.select_model(params)
@@ -215,9 +221,16 @@ class OpenAIAugmentedLLM(
             self.logger.debug(f"{arguments}")
             self._log_chat_progress(chat_turn, model=model)

-
-
-
+            if response_model is None:
+                executor_result = await self.executor.execute(
+                    openai_client.chat.completions.create, **arguments
+                )
+            else:
+                executor_result = await self.executor.execute(
+                    openai_client.beta.chat.completions.parse,
+                    **arguments,
+                    response_format=response_model,
+                )

             response = executor_result[0]

@@ -334,10 +347,10 @@ class OpenAIAugmentedLLM(
         if params.use_history:
             # Get current prompt messages
             prompt_messages = self.history.get(include_history=False)
-
+
             # Calculate new conversation messages (excluding prompts)
-            new_messages = messages[len(prompt_messages):]
-
+            new_messages = messages[len(prompt_messages) :]
+
             # Update conversation history
             self.history.set(new_messages)

@@ -379,40 +392,21 @@ class OpenAIAugmentedLLM(
         response_model: Type[ModelT],
         request_params: RequestParams | None = None,
     ) -> ModelT:
-
-        # We need to do this in a two-step process because Instructor doesn't
-        # know how to invoke MCP tools via call_tool, so we'll handle all the
-        # processing first and then pass the final response through Instructor
-        response = await self.generate_str(
+        responses = await self.generate(
             message=message,
             request_params=request_params,
-        )
-
-        # Next we pass the text through instructor to extract structured data
-        client = instructor.from_openai(
-            OpenAI(
-                api_key=self._api_key(),
-                base_url=self._base_url(),
-            ),
-            mode=instructor.Mode.TOOLS_STRICT,
-        )
-
-        params = self.get_request_params(request_params)
-        model = await self.select_model(params)
-
-        # Extract structured data from natural language
-        structured_response = client.chat.completions.create(
-            model=model,
             response_model=response_model,
-            messages=[
-                {"role": "user", "content": response},
-            ],
-        )
-        await self.show_assistant_message(
-            str(structured_response), title="ASSISTANT/STRUCTURED"
         )
+        return responses[0].parsed
+
+        # return response_model.model_validate(
+        #     from_json(responses[0].content, allow_partial=True)
+        # )
+        # part1 = from_json(response, allow_partial=True)
+        # return response_model.model_validate(part1)

-
+        # TODO -- would prefer to use the OpenAI message[0].parsed function here
+        # return response_model.model_validate(from_json(response, allow_partial=True))

     async def pre_tool_call(self, tool_call_id: str | None, request: CallToolRequest):
         return request