quantalogic 0.59.3__py3-none-any.whl → 0.61.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quantalogic/agent.py +268 -24
- quantalogic/agent_config.py +5 -5
- quantalogic/agent_factory.py +2 -2
- quantalogic/codeact/__init__.py +0 -0
- quantalogic/codeact/agent.py +499 -0
- quantalogic/codeact/cli.py +232 -0
- quantalogic/codeact/constants.py +9 -0
- quantalogic/codeact/events.py +78 -0
- quantalogic/codeact/llm_util.py +76 -0
- quantalogic/codeact/prompts/error_format.j2 +11 -0
- quantalogic/codeact/prompts/generate_action.j2 +26 -0
- quantalogic/codeact/prompts/generate_program.j2 +39 -0
- quantalogic/codeact/prompts/response_format.j2 +11 -0
- quantalogic/codeact/tools_manager.py +135 -0
- quantalogic/codeact/utils.py +135 -0
- quantalogic/coding_agent.py +2 -2
- quantalogic/create_custom_agent.py +26 -78
- quantalogic/prompts/chat_system_prompt.j2 +10 -7
- quantalogic/prompts/code_2_system_prompt.j2 +190 -0
- quantalogic/prompts/code_system_prompt.j2 +142 -0
- quantalogic/prompts/doc_system_prompt.j2 +178 -0
- quantalogic/prompts/legal_2_system_prompt.j2 +218 -0
- quantalogic/prompts/legal_system_prompt.j2 +140 -0
- quantalogic/prompts/system_prompt.j2 +6 -2
- quantalogic/prompts/tools_prompt.j2 +2 -4
- quantalogic/prompts.py +23 -4
- quantalogic/python_interpreter/__init__.py +23 -0
- quantalogic/python_interpreter/assignment_visitors.py +63 -0
- quantalogic/python_interpreter/base_visitors.py +20 -0
- quantalogic/python_interpreter/class_visitors.py +22 -0
- quantalogic/python_interpreter/comprehension_visitors.py +172 -0
- quantalogic/python_interpreter/context_visitors.py +59 -0
- quantalogic/python_interpreter/control_flow_visitors.py +88 -0
- quantalogic/python_interpreter/exception_visitors.py +109 -0
- quantalogic/python_interpreter/exceptions.py +39 -0
- quantalogic/python_interpreter/execution.py +202 -0
- quantalogic/python_interpreter/function_utils.py +386 -0
- quantalogic/python_interpreter/function_visitors.py +209 -0
- quantalogic/python_interpreter/import_visitors.py +28 -0
- quantalogic/python_interpreter/interpreter_core.py +358 -0
- quantalogic/python_interpreter/literal_visitors.py +74 -0
- quantalogic/python_interpreter/misc_visitors.py +148 -0
- quantalogic/python_interpreter/operator_visitors.py +108 -0
- quantalogic/python_interpreter/scope.py +10 -0
- quantalogic/python_interpreter/visit_handlers.py +110 -0
- quantalogic/server/agent_server.py +1 -1
- quantalogic/tools/__init__.py +6 -3
- quantalogic/tools/action_gen.py +366 -0
- quantalogic/tools/duckduckgo_search_tool.py +1 -0
- quantalogic/tools/execute_bash_command_tool.py +114 -57
- quantalogic/tools/file_tracker_tool.py +49 -0
- quantalogic/tools/google_packages/google_news_tool.py +3 -0
- quantalogic/tools/image_generation/dalle_e.py +89 -137
- quantalogic/tools/python_tool.py +13 -0
- quantalogic/tools/rag_tool/__init__.py +2 -9
- quantalogic/tools/rag_tool/document_rag_sources_.py +728 -0
- quantalogic/tools/rag_tool/ocr_pdf_markdown.py +144 -0
- quantalogic/tools/replace_in_file_tool.py +1 -1
- quantalogic/tools/{search_definition_names.py → search_definition_names_tool.py} +2 -2
- quantalogic/tools/terminal_capture_tool.py +293 -0
- quantalogic/tools/tool.py +120 -22
- quantalogic/tools/utilities/__init__.py +2 -0
- quantalogic/tools/utilities/download_file_tool.py +3 -5
- quantalogic/tools/utilities/llm_tool.py +283 -0
- quantalogic/tools/utilities/selenium_tool.py +296 -0
- quantalogic/tools/utilities/vscode_tool.py +1 -1
- quantalogic/tools/web_navigation/__init__.py +5 -0
- quantalogic/tools/web_navigation/web_tool.py +145 -0
- quantalogic/tools/write_file_tool.py +72 -36
- quantalogic/utils/__init__.py +0 -1
- quantalogic/utils/test_python_interpreter.py +119 -0
- {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/METADATA +7 -2
- {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/RECORD +76 -35
- quantalogic/tools/rag_tool/document_metadata.py +0 -15
- quantalogic/tools/rag_tool/query_response.py +0 -20
- quantalogic/tools/rag_tool/rag_tool.py +0 -566
- quantalogic/tools/rag_tool/rag_tool_beta.py +0 -264
- quantalogic/utils/python_interpreter.py +0 -905
- {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/LICENSE +0 -0
- {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/WHEEL +0 -0
- {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,499 @@
|
|
1
|
+
import asyncio
|
2
|
+
import time
|
3
|
+
from typing import Callable, Dict, List, Optional, Tuple
|
4
|
+
|
5
|
+
from jinja2 import Environment, FileSystemLoader
|
6
|
+
from loguru import logger
|
7
|
+
from lxml import etree
|
8
|
+
|
9
|
+
from quantalogic.python_interpreter import execute_async
|
10
|
+
from quantalogic.tools import Tool
|
11
|
+
|
12
|
+
from .constants import MAX_GENERATE_PROGRAM_TOKENS, MAX_HISTORY_TOKENS, MAX_TOKENS, TEMPLATE_DIR
|
13
|
+
from .events import (
|
14
|
+
ActionExecutedEvent,
|
15
|
+
ActionGeneratedEvent,
|
16
|
+
ErrorOccurredEvent,
|
17
|
+
StepCompletedEvent,
|
18
|
+
StepStartedEvent,
|
19
|
+
TaskCompletedEvent,
|
20
|
+
TaskStartedEvent,
|
21
|
+
ThoughtGeneratedEvent,
|
22
|
+
ToolExecutionCompletedEvent,
|
23
|
+
ToolExecutionErrorEvent,
|
24
|
+
ToolExecutionStartedEvent,
|
25
|
+
)
|
26
|
+
from .llm_util import litellm_completion
|
27
|
+
from .tools_manager import RetrieveStepTool, get_default_tools
|
28
|
+
from .utils import XMLResultHandler, validate_code, validate_xml
|
29
|
+
|
30
|
+
# Shared Jinja2 environment used to render every prompt template in TEMPLATE_DIR.
# trim_blocks / lstrip_blocks remove the whitespace around {% ... %} block tags so
# the rendered prompts do not accumulate stray blank lines and indentation.
jinja_env = Environment(loader=FileSystemLoader(TEMPLATE_DIR), trim_blocks=True, lstrip_blocks=True)
|
31
|
+
|
32
|
+
def _strip_code_fence(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence, if it has one.

    A ```python fence is removed exactly as before. Any other fully fenced block
    (bare ``` or a different language tag on the opening line) is unwrapped as
    well. Text that is not wrapped in a fence is returned unchanged.
    """
    if text.startswith("```python") and text.endswith("```"):
        return text[9:-3].strip()
    if len(text) >= 6 and text.startswith("```") and text.endswith("```"):
        inner = text[3:-3]
        tag, newline, body = inner.partition("\n")
        # The opening fence line may carry a language tag (e.g. "py"); drop it.
        if newline and (not tag.strip() or tag.strip().isidentifier()):
            inner = body
        return inner.strip()
    return text


async def generate_program(
    task_description: str,
    tools: List[Tool],
    model: str,
    max_tokens: int,
    step: int,
    notify_event: Callable,
    streaming: bool = False,
) -> str:
    """Generate a Python program for *task_description* using *model*.

    The docstrings of *tools* are rendered into the ``generate_program.j2``
    prompt, which is sent to the model. The call is attempted up to three
    times, sleeping 1s and then 2s between failed attempts.

    Args:
        task_description: Text rendered into the prompt as the task.
        tools: Tools whose ``to_docstring()`` output is shown to the model.
        model: Model identifier passed to ``litellm_completion``.
        max_tokens: Completion token limit.
        step: Current step number, forwarded to ``litellm_completion``.
        notify_event: Event callback forwarded to ``litellm_completion``.
        streaming: Forwarded to ``litellm_completion`` as ``stream``.

    Returns:
        The generated source code, with any surrounding Markdown fence removed.

    Raises:
        Exception: If every attempt fails; the last underlying error is chained
            as ``__cause__``.
    """
    tool_docstrings = "\n\n".join(tool.to_docstring() for tool in tools)
    prompt = jinja_env.get_template("generate_program.j2").render(
        task_description=task_description,
        tool_docstrings=tool_docstrings,
    )

    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            response = await litellm_completion(
                model=model,
                messages=[
                    {"role": "system", "content": "You are a Python code generator."},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=max_tokens,
                temperature=0.3,
                stream=streaming,
                step=step,
                notify_event=notify_event,
            )
            return _strip_code_fence(response.strip())
        except Exception as e:
            if attempt < max_attempts - 1:
                logger.warning(f"Code generation attempt {attempt + 1} with {model} failed: {e}; retrying")
                await asyncio.sleep(2 ** attempt)  # exponential backoff: 1s, 2s
            else:
                raise Exception(f"Code generation failed with {model}: {e}") from e
|
69
|
+
|
70
|
+
class Reasoner:
    """Generates the next action (a thought plus a Python program) with the language model."""

    def __init__(self, model: str, tools: List[Tool]):
        """Store the model identifier and the tools advertised to the model.

        Args:
            model: Model identifier passed to ``generate_program``.
            tools: Tools whose docstrings are included in the code-generation prompt.
        """
        self.model = model
        self.tools = tools

    async def generate_action(
        self,
        task: str,
        history_str: str,
        step: int,
        max_iterations: int,
        system_prompt: Optional[str] = None,
        notify_event: Optional[Callable] = None,
        streaming: bool = False,
    ) -> str:
        """Produce the XML response describing the next action.

        The ``generate_action.j2`` prompt is rendered with the task (prefixed by
        *system_prompt* when given), the formatted history, and the step counters.
        The resulting program is then wrapped with ``response_format.j2``.

        Args:
            task: Task description.
            history_str: Pre-formatted history of previous steps.
            step: Current step number.
            max_iterations: Maximum number of steps allowed for the task.
            system_prompt: Optional text prepended to the task in the prompt.
            notify_event: Event callback forwarded to ``generate_program``.
            streaming: Whether to request a streamed completion.

        Returns:
            The rendered response XML. If anything fails (generation error or
            invalid XML), returns the ``error_format.j2`` rendering of the error
            instead of raising.
        """
        try:
            task_prompt = jinja_env.get_template("generate_action.j2").render(
                task=f"{system_prompt}\nTask: {task}" if system_prompt else task,
                history_str=history_str,
                current_step=step,
                max_iterations=max_iterations,
            )
            program = await generate_program(
                task_prompt,
                self.tools,
                self.model,
                MAX_GENERATE_PROGRAM_TOKENS,
                step,
                notify_event,
                streaming=streaming,
            )
            response = jinja_env.get_template("response_format.j2").render(
                task=task,
                history_str=history_str,
                program=program,
                current_step=step,
                max_iterations=max_iterations,
            )
            if not validate_xml(response):
                raise ValueError("Invalid XML generated")
            return response
        except Exception as e:
            # Failures are reported to the agent loop as an error document rather
            # than raised; log them so they are not lost silently.
            logger.error(f"Action generation failed at step {step}: {e}")
            return jinja_env.get_template("error_format.j2").render(error=str(e))
|
107
|
+
|
108
|
+
class Executor:
    """Executes generated programs and maintains the interpreter namespace.

    Every tool is exposed to generated code as an async function of the same
    name. The function emits ToolExecutionStarted and then ToolExecutionCompleted
    or ToolExecutionError events around the real call.
    """

    def __init__(self, tools: List[Tool], notify_event: Callable):
        """Create an executor over *tools*.

        Args:
            tools: Tools to expose. Tools appended to this list later are picked
                up on the next ``execute_action`` call.
            notify_event: Async callback awaited with each tool event.
        """
        self.tools = tools
        self.notify_event = notify_event
        self.tool_namespace = self._build_tool_namespace()

    @staticmethod
    def _summarize(value, limit: int = 100) -> str:
        """Return ``str(value)``, truncated to *limit* characters plus "..." when longer."""
        text = str(value)
        return text[:limit] + "..." if len(text) > limit else text

    def _wrap_tool(self, tool: Tool) -> Callable:
        """Return an async wrapper around *tool* that emits lifecycle events."""

        async def wrapped_tool(**kwargs):
            # Read the step at call time; execute_action stores it in the namespace.
            current_step = self.tool_namespace.get("current_step", None)
            await self.notify_event(ToolExecutionStartedEvent(
                event_type="ToolExecutionStarted",
                step_number=current_step,
                tool_name=tool.name,
                parameters_summary={k: self._summarize(v) for k, v in kwargs.items()},
            ))
            try:
                result = await tool.async_execute(**kwargs)
            except Exception as e:
                await self.notify_event(ToolExecutionErrorEvent(
                    event_type="ToolExecutionError",
                    step_number=current_step,
                    tool_name=tool.name,
                    error=str(e),
                ))
                raise
            await self.notify_event(ToolExecutionCompletedEvent(
                event_type="ToolExecutionCompleted",
                step_number=current_step,
                tool_name=tool.name,
                result_summary=self._summarize(result),
            ))
            return result

        return wrapped_tool

    def _build_tool_namespace(self) -> Dict:
        """Build a fresh namespace: ``asyncio``, empty ``context_vars``, and one wrapper per tool."""
        return {
            "asyncio": asyncio,
            "context_vars": {},  # replaced on every execute_action call
            **{tool.name: self._wrap_tool(tool) for tool in self.tools},
        }

    def _refresh_tool_bindings(self) -> None:
        """Bind every tool currently in ``self.tools`` into the namespace.

        This makes tools appended to ``self.tools`` after construction callable,
        and restores any tool name a previous program may have overwritten.
        """
        self.tool_namespace.update({tool.name: self._wrap_tool(tool) for tool in self.tools})

    async def execute_action(self, code: str, context_vars: Dict, step: int, timeout: int = 300) -> str:
        """Run *code*'s ``main()`` and return the execution result as XML.

        Non-callable, non-dunder local variables produced by the program are
        merged into *context_vars* (mutated in place).

        Args:
            code: Program source; must pass ``validate_code``.
            context_vars: Variables shared across steps; exposed as ``context_vars``.
            step: Current step number, made visible to tool wrappers.
            timeout: Execution timeout forwarded to ``execute_async``.

        Returns:
            The ``XMLResultHandler`` formatting of the result. If validation or
            execution fails, returns an ``<ExecutionResult status="Error">`` element.
        """
        self._refresh_tool_bindings()
        self.tool_namespace["context_vars"] = context_vars
        self.tool_namespace["current_step"] = step
        if not validate_code(code):
            return etree.tostring(
                etree.Element("ExecutionResult", status="Error", message="Code lacks async main()"),
                encoding="unicode",
            )

        try:
            result = await execute_async(
                code=code,
                timeout=timeout,
                entry_point="main",
                allowed_modules=["asyncio"],
                namespace=self.tool_namespace,
            )
            if result.local_variables:
                context_vars.update({
                    k: v
                    for k, v in result.local_variables.items()
                    if not k.startswith("__") and not callable(v)
                })
            return XMLResultHandler.format_execution_result(result)
        except Exception as e:
            logger.error(f"Execution error at step {step}: {e}")
            return etree.tostring(
                etree.Element("ExecutionResult", status="Error", message=f"Execution error: {e}"),
                encoding="unicode",
            )
|
188
|
+
|
189
|
+
class ReActAgent:
    """Core agent implementing the ReAct loop: generate code, execute it, check for completion."""

    def __init__(self, model: str, tools: List[Tool], max_iterations: int = 5, max_history_tokens: int = 2000):
        """Create the reasoner/executor pair and empty per-task state.

        Args:
            model: Model identifier used for generation and verification.
            tools: Tools available to generated code. The same list object is
                handed to both the Reasoner and the Executor.
            max_iterations: Default step limit for ``solve``.
            max_history_tokens: Budget (approximate, whitespace-split words)
                for the history included in prompts.
        """
        self.reasoner = Reasoner(model, tools)
        self.executor = Executor(tools, notify_event=self._notify_observers)
        self.max_iterations = max_iterations
        self.max_history_tokens = max_history_tokens
        self.context_vars: Dict = {}  # variables carried between executed programs
        self._observers: List[Tuple[Callable, List[str]]] = []
        # Every step of the current task. solve() clears this list in place and
        # never rebinds it, so holders of a reference (e.g. RetrieveStepTool)
        # always see the live history.
        self.history_store: List[Dict] = []

    def add_observer(self, observer: Callable, event_types: List[str]) -> 'ReActAgent':
        """Subscribe *observer* (an async callable) to the given event types; returns self."""
        self._observers.append((observer, event_types))
        return self

    async def _notify_observers(self, event):
        """Deliver *event* concurrently to every observer subscribed to its type.

        An observer that fails does not interrupt the agent; its exception is logged.
        """
        subscribed = [observer for observer, types in self._observers if event.event_type in types]
        outcomes = await asyncio.gather(*(observer(event) for observer in subscribed), return_exceptions=True)
        for observer, outcome in zip(subscribed, outcomes):
            if isinstance(outcome, BaseException):
                logger.warning(f"Observer {observer!r} failed handling {event.event_type}: {outcome!r}")

    async def generate_action(
        self,
        task: str,
        history: List[Dict],
        step: int,
        max_iterations: int,
        system_prompt: Optional[str] = None,
        streaming: bool = False,
    ) -> str:
        """Generate the next action with the Reasoner and emit Thought/Action events.

        Returns:
            The Reasoner's response XML (an error document if generation failed).
        """
        history_str = self._format_history(history, max_iterations)
        start = time.perf_counter()
        response = await self.reasoner.generate_action(
            task, history_str, step, max_iterations, system_prompt, self._notify_observers, streaming=streaming
        )
        thought, code = XMLResultHandler.parse_response(response)
        gen_time = time.perf_counter() - start
        await self._notify_observers(ThoughtGeneratedEvent(
            event_type="ThoughtGenerated", step_number=step, thought=thought, generation_time=gen_time
        ))
        await self._notify_observers(ActionGeneratedEvent(
            event_type="ActionGenerated", step_number=step, action_code=code, generation_time=gen_time
        ))
        if not response.endswith("</Code>"):
            logger.warning(f"Response might be truncated at step {step}")
        return response

    async def execute_action(self, code: str, step: int, timeout: int = 300) -> str:
        """Execute *code* with the Executor, emit an ActionExecuted event, and return the result XML."""
        start = time.perf_counter()
        result_xml = await self.executor.execute_action(code, self.context_vars, step, timeout)
        execution_time = time.perf_counter() - start
        await self._notify_observers(ActionExecutedEvent(
            event_type="ActionExecuted", step_number=step, result_xml=result_xml, execution_time=execution_time
        ))
        return result_xml

    def _format_history(self, history: List[Dict], max_iterations: int) -> str:
        """Render previous steps as prompt text, trimmed to ``self.max_history_tokens``.

        Steps are taken newest-first until the next one would exceed the budget
        (tokens approximated by whitespace-split words). The kept steps are then
        emitted oldest-first.
        """
        included_steps: List[str] = []
        total_tokens = 0
        for entry in reversed(history):
            # Variable names come from the <Variables> element of the step's result XML.
            try:
                root = etree.fromstring(entry["result"])
                vars_elem = root.find("Variables")
                available_vars = (
                    [name for var in vars_elem.findall("Variable") if (name := var.get("name"))]
                    if vars_elem is not None else []
                )
            except (etree.XMLSyntaxError, ValueError):
                available_vars = []

            step_str = (
                f"===== Step {entry['step_number']} of {max_iterations} max =====\n"
                f"Thought:\n{entry['thought']}\n\n"
                f"Action:\n{entry['action']}\n\n"
                f"Result:\n{XMLResultHandler.format_result_summary(entry['result'])}\n"
                f"Available variables: {', '.join(available_vars) or 'None'}"
            )
            step_tokens = len(step_str.split())  # approximate token count
            if total_tokens + step_tokens > self.max_history_tokens:
                break
            included_steps.append(step_str)
            total_tokens += step_tokens
        return "\n".join(reversed(included_steps)) or "No previous steps"

    async def _verify_final_answer(self, task: str, history: List[Dict], final_answer: str) -> bool:
        """Ask the model whether *final_answer* solves *task*.

        The outcome is advisory: it is logged, and failures of the verification
        call itself are logged rather than raised.
        """
        try:
            verification = await litellm_completion(
                model=self.reasoner.model,
                messages=[{
                    "role": "user",
                    "content": f"Does '{final_answer}' solve '{task}' given history:\n{self._format_history(history, self.max_iterations)}?"
                }],
                max_tokens=100,
                temperature=0.1,
                stream=False,
            )
            verified = "yes" in verification.lower()
        except Exception as e:
            logger.warning(f"Final answer verification failed: {e}")
            return False
        if not verified:
            logger.warning("Model did not confirm that the final answer solves the task")
        return verified

    async def is_task_complete(self, task: str, history: List[Dict], result: str, success_criteria: Optional[str]) -> Tuple[bool, str]:
        """Decide whether *result* completes the task.

        The task is complete when the result XML has ``<Completed>true</Completed>``;
        the answer is then its ``<FinalAnswer>`` text. That answer is still
        accepted if model verification does not confirm it. Otherwise the task
        is complete when *success_criteria* occurs in the extracted result value.

        Returns:
            ``(is_complete, final_answer)``; ``(False, "")`` when not complete.
        """
        try:
            root = etree.fromstring(result)
        except (etree.XMLSyntaxError, ValueError):
            root = None

        if root is not None and root.findtext("Completed") == "true":
            final_answer = root.findtext("FinalAnswer") or ""
            await self._verify_final_answer(task, history, final_answer)
            return True, final_answer

        if success_criteria and (result_value := XMLResultHandler.extract_result_value(result)) and success_criteria in result_value:
            return True, result_value
        return False, ""

    @staticmethod
    def _attach_final_answer(result_xml: str, final_answer: str) -> str:
        """Return *result_xml* with a ``<FinalAnswer>`` child holding *final_answer*.

        The element is placed inside the root so the document stays well-formed.
        An existing ``<FinalAnswer>`` child is overwritten. Non-XML input falls
        back to appending the element as text.
        """
        try:
            root = etree.fromstring(result_xml, etree.XMLParser(strip_cdata=False))
        except (etree.XMLSyntaxError, ValueError):
            return f"{result_xml}\n<FinalAnswer><![CDATA[\n{final_answer}\n]]></FinalAnswer>"
        element = root.find("FinalAnswer")
        if element is None:
            element = etree.SubElement(root, "FinalAnswer")
        payload = f"\n{final_answer}\n"
        # CDATA cannot contain "]]>"; fall back to escaped text in that case.
        element.text = etree.CDATA(payload) if "]]>" not in payload else payload
        return etree.tostring(root, encoding="unicode")

    async def solve(
        self,
        task: str,
        success_criteria: Optional[str] = None,
        system_prompt: Optional[str] = None,
        max_iterations: Optional[int] = None,
        streaming: bool = False,
    ) -> List[Dict]:
        """Run the ReAct loop on *task* until completion, error, or the step limit.

        Returns:
            One dict per executed step with keys ``step_number``, ``thought``,
            ``action`` and ``result``. The completing step's result carries a
            ``<FinalAnswer>`` element.
        """
        max_iters = max_iterations if max_iterations is not None else self.max_iterations
        history: List[Dict] = []
        # Clear in place rather than rebinding: tools may hold a reference to this list.
        self.history_store.clear()
        await self._notify_observers(TaskStartedEvent(event_type="TaskStarted", task_description=task))

        completed = False
        failed = False
        for step in range(1, max_iters + 1):
            await self._notify_observers(StepStartedEvent(event_type="StepStarted", step_number=step))
            try:
                response = await self.generate_action(task, history, step, max_iters, system_prompt, streaming=streaming)
                thought, code = XMLResultHandler.parse_response(response)
                result = await self.execute_action(code, step)
                step_data = {"step_number": step, "thought": thought, "action": code, "result": result}
                history.append(step_data)
                self.history_store.append(step_data)  # same dict object as in history

                is_complete, final_answer = await self.is_task_complete(task, history, result, success_criteria)
                if is_complete:
                    step_data["result"] = self._attach_final_answer(result, final_answer)
                    completed = True

                await self._notify_observers(StepCompletedEvent(
                    event_type="StepCompleted", step_number=step, thought=thought,
                    action=code, result=step_data["result"], is_complete=is_complete,
                    final_answer=final_answer if is_complete else None
                ))

                if is_complete:
                    await self._notify_observers(TaskCompletedEvent(
                        event_type="TaskCompleted", final_answer=final_answer, reason="success"
                    ))
                    break
            except Exception as e:
                failed = True
                await self._notify_observers(ErrorOccurredEvent(
                    event_type="ErrorOccurred", error_message=str(e), step_number=step
                ))
                break

        if not completed:
            await self._notify_observers(TaskCompletedEvent(
                event_type="TaskCompleted", final_answer=None,
                reason="error" if failed else "max_iterations_reached"
            ))
        return history
|
354
|
+
|
355
|
+
class Agent:
    """High-level interface for the Quantalogic agent, providing chat and solve functionality."""

    def __init__(
        self,
        model: str = "gemini/gemini-2.0-flash",
        tools: Optional[List[Tool]] = None,
        max_iterations: int = 5,
        personality: Optional[str] = None,
        backstory: Optional[str] = None,
        sop: Optional[str] = None,
        max_history_tokens: int = MAX_HISTORY_TOKENS,
    ):
        """Configure the agent.

        Args:
            model: Model identifier used for every LLM call.
            tools: Default tools; ``get_default_tools(model)`` when omitted.
            max_iterations: Default step limit for ``solve``.
            personality: Optional personality text for the system prompt.
            backstory: Optional backstory text for the system prompt.
            sop: Optional standard operating procedure for the system prompt.
            max_history_tokens: History budget passed to each ReActAgent.
        """
        self.model = model
        self.default_tools = tools if tools is not None else get_default_tools(model)
        self.max_iterations = max_iterations
        self.personality = personality
        self.backstory = backstory
        self.sop = sop
        self.max_history_tokens = max_history_tokens
        self._observers: List[Tuple[Callable, List[str]]] = []
        # Context variables left over from the most recent solve() call.
        self.last_solve_context_vars: Dict = {}

    def _build_system_prompt(self) -> str:
        """Build the system prompt from personality, backstory, and SOP."""
        prompt = "You are an AI assistant."
        if self.personality:
            prompt += f" You have a {self.personality} personality."
        if self.backstory:
            prompt += f" Your backstory is: {self.backstory}"
        if self.sop:
            prompt += f" Follow this standard operating procedure: {self.sop}"
        return prompt

    def _create_react_agent(self, tools: List[Tool], max_iterations: int) -> ReActAgent:
        """Build a ReActAgent over a private copy of *tools* plus a RetrieveStepTool.

        Copying the list ensures the per-run RetrieveStepTool is never appended to
        ``self.default_tools`` (or to a caller-supplied list).
        """
        react_agent = ReActAgent(
            model=self.model,
            tools=list(tools),
            max_iterations=max_iterations,
            max_history_tokens=self.max_history_tokens,
        )
        # ReActAgent hands the same list to its Reasoner and Executor, so the
        # appended tool is documented in the prompt as well.
        react_agent.executor.tools.append(RetrieveStepTool(react_agent.history_store))
        # The Executor builds its namespace at construction time; rebuild it so
        # the newly appended tool is callable from generated code.
        react_agent.executor.tool_namespace = react_agent.executor._build_tool_namespace()
        for observer, event_types in self._observers:
            react_agent.add_observer(observer, event_types)
        return react_agent

    async def chat(
        self,
        message: str,
        use_tools: bool = False,
        tools: Optional[List[Tool]] = None,
        timeout: int = 30,
        max_tokens: int = MAX_TOKENS,
        temperature: float = 0.7,
        streaming: bool = False,
    ) -> str:
        """Single-step interaction, optionally through a one-step tool-using ReAct run.

        Args:
            message: User message.
            use_tools: When true, run a single ReAct step with *tools* (or the
                default tools); otherwise make one plain completion call.
            tools: Tools for the ReAct run; ``self.default_tools`` when omitted.
            timeout: Accepted for API compatibility; not currently applied.
            max_tokens: Completion token limit (plain completion only).
            temperature: Sampling temperature (plain completion only).
            streaming: Whether to request a streamed completion.

        Returns:
            The response text.
        """
        system_prompt = self._build_system_prompt()
        if use_tools:
            chat_tools = tools if tools is not None else self.default_tools
            chat_agent = self._create_react_agent(chat_tools, max_iterations=1)
            history = await chat_agent.solve(message, system_prompt=system_prompt, streaming=streaming)
            return self._extract_response(history)

        response = await litellm_completion(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": message}
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            stream=streaming,
            notify_event=self._notify_observers if streaming else None
        )
        return response.strip()

    def sync_chat(self, message: str, timeout: int = 30) -> str:
        """Synchronous wrapper for ``chat`` (runs a new event loop via ``asyncio.run``)."""
        return asyncio.run(self.chat(message, timeout=timeout))

    async def solve(
        self,
        task: str,
        success_criteria: Optional[str] = None,
        max_iterations: Optional[int] = None,
        tools: Optional[List[Tool]] = None,
        timeout: int = 300,
        streaming: bool = False,
    ) -> List[Dict]:
        """Solve *task* over multiple ReAct steps.

        Args:
            task: Task description.
            success_criteria: Optional substring that marks the result as complete.
            max_iterations: Step limit; ``self.max_iterations`` when omitted.
            tools: Tools for this run; ``self.default_tools`` when omitted.
            timeout: Accepted for API compatibility; not currently forwarded to
                the executor, which uses its own default.
            streaming: Whether to request streamed completions.

        Returns:
            The step history produced by ``ReActAgent.solve``. The final context
            variables are copied into ``self.last_solve_context_vars``.
        """
        system_prompt = self._build_system_prompt()
        solve_tools = tools if tools is not None else self.default_tools
        solve_agent = self._create_react_agent(
            solve_tools,
            max_iterations=max_iterations if max_iterations is not None else self.max_iterations,
        )
        history = await solve_agent.solve(
            task, success_criteria, system_prompt=system_prompt,
            max_iterations=max_iterations, streaming=streaming
        )
        self.last_solve_context_vars = solve_agent.context_vars.copy()
        return history

    def sync_solve(self, task: str, success_criteria: Optional[str] = None, timeout: int = 300) -> List[Dict]:
        """Synchronous wrapper for ``solve`` (runs a new event loop via ``asyncio.run``)."""
        return asyncio.run(self.solve(task, success_criteria, timeout=timeout))

    def add_observer(self, observer: Callable, event_types: List[str]) -> 'Agent':
        """Register an observer to attach to every agent created by chat and solve; returns self."""
        self._observers.append((observer, event_types))
        return self

    def list_tools(self) -> List[str]:
        """Return the names of the default tools."""
        return [tool.name for tool in self.default_tools]

    def get_context_vars(self) -> Dict:
        """Return the context variables from the last solve call."""
        return self.last_solve_context_vars

    def _extract_response(self, history: List[Dict]) -> str:
        """Extract a clean answer from the last step of *history*.

        Prefers ``<FinalAnswer>``, then ``<Value>``, when ``<Status>`` is
        "Success". Returns an "Error: ..." string for other statuses. Returns
        the raw result when it is not XML.
        """
        if not history:
            return "No response generated."
        last_result = history[-1]["result"]
        try:
            root = etree.fromstring(last_result)
        except (etree.XMLSyntaxError, ValueError):
            return self._recover_final_answer(last_result)
        if root.findtext("Status") == "Success":
            value = root.findtext("Value") or ""
            final_answer = root.findtext("FinalAnswer")
            return final_answer.strip() if final_answer else value.strip()
        return f"Error: {root.findtext('Value') or 'Unknown error'}"

    @staticmethod
    def _recover_final_answer(raw_result: str) -> str:
        """Return a top-level ``<FinalAnswer>`` from a result that is not one XML document.

        A result with ``<FinalAnswer>`` appended after its root element has two
        top-level elements. Wrapping it in a synthetic root makes it parseable.
        Returns *raw_result* unchanged when no answer can be recovered.
        """
        try:
            wrapper = etree.fromstring(f"<Response>{raw_result}</Response>")
        except etree.XMLSyntaxError:
            return raw_result
        final_answer = wrapper.findtext("FinalAnswer")
        return final_answer.strip() if final_answer else raw_result

    async def _notify_observers(self, event):
        """Deliver *event* concurrently to subscribed observers, logging any failures."""
        subscribed = [observer for observer, types in self._observers if event.event_type in types]
        outcomes = await asyncio.gather(*(observer(event) for observer in subscribed), return_exceptions=True)
        for observer, outcome in zip(subscribed, outcomes):
            if isinstance(outcome, BaseException):
                logger.warning(f"Observer {observer!r} failed handling {event.event_type}: {outcome!r}")
|