ursa-ai 0.0.3__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ursa-ai might be problematic. Click here for more details.
- ursa/agents/__init__.py +10 -0
- ursa/agents/arxiv_agent.py +349 -0
- ursa/agents/base.py +42 -0
- ursa/agents/code_review_agent.py +332 -0
- ursa/agents/execution_agent.py +497 -0
- ursa/agents/hypothesizer_agent.py +597 -0
- ursa/agents/mp_agent.py +257 -0
- ursa/agents/planning_agent.py +138 -0
- ursa/agents/recall_agent.py +25 -0
- ursa/agents/websearch_agent.py +193 -0
- ursa/prompt_library/code_review_prompts.py +51 -0
- ursa/prompt_library/execution_prompts.py +36 -0
- ursa/prompt_library/hypothesizer_prompts.py +17 -0
- ursa/prompt_library/literature_prompts.py +11 -0
- ursa/prompt_library/planning_prompts.py +79 -0
- ursa/prompt_library/websearch_prompts.py +131 -0
- ursa/tools/run_command.py +27 -0
- ursa/tools/write_code.py +42 -0
- ursa/util/diff_renderer.py +121 -0
- ursa/util/memory_logger.py +171 -0
- ursa/util/parse.py +89 -0
- ursa_ai-0.2.2.dist-info/METADATA +130 -0
- ursa_ai-0.2.2.dist-info/RECORD +26 -0
- ursa_ai-0.2.2.dist-info/licenses/LICENSE +8 -0
- ursa/__init__.py +0 -2
- ursa/py.typed +0 -0
- ursa_ai-0.0.3.dist-info/METADATA +0 -7
- ursa_ai-0.0.3.dist-info/RECORD +0 -6
- {ursa_ai-0.0.3.dist-info → ursa_ai-0.2.2.dist-info}/WHEEL +0 -0
- {ursa_ai-0.0.3.dist-info → ursa_ai-0.2.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,497 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
# from langchain_core.runnables.graph import MermaidDrawMethod
|
|
4
|
+
import subprocess
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Annotated, Any, Literal, Optional
|
|
7
|
+
|
|
8
|
+
import coolname
|
|
9
|
+
from langchain_community.tools import DuckDuckGoSearchResults # TavilySearchResults,
|
|
10
|
+
from langchain_core.language_models import BaseChatModel
|
|
11
|
+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
|
12
|
+
from langchain_core.tools import InjectedToolCallId, tool
|
|
13
|
+
from langgraph.graph import END, START, StateGraph
|
|
14
|
+
from langgraph.graph.message import add_messages
|
|
15
|
+
from langgraph.prebuilt import InjectedState, ToolNode
|
|
16
|
+
from langgraph.types import Command
|
|
17
|
+
from litellm import ContentPolicyViolationError
|
|
18
|
+
|
|
19
|
+
# Rich
|
|
20
|
+
from rich import get_console
|
|
21
|
+
from rich.panel import Panel
|
|
22
|
+
from rich.syntax import Syntax
|
|
23
|
+
from typing_extensions import TypedDict
|
|
24
|
+
|
|
25
|
+
from ..prompt_library.execution_prompts import executor_prompt, summarize_prompt
|
|
26
|
+
from ..util.diff_renderer import DiffRenderer
|
|
27
|
+
from ..util.memory_logger import AgentMemory
|
|
28
|
+
from .base import BaseAgent
|
|
29
|
+
|
|
30
|
+
# Module-wide rich console used by the file tools and the safety check for coloured output.
console = get_console()  # rich.get_console() returns the same global Console on every call
|
|
31
|
+
|
|
32
|
+
# --- ANSI escape sequences for plain print() colouring ---
RESET = "\033[0m"  # clear all attributes
BOLD = "\033[1m"
RED = "\033[91m"  # bright red
GREEN = "\033[92m"  # bright green
BLUE = "\033[94m"  # bright blue
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ExecutionState(TypedDict):
    """Graph state shared by ExecutionAgent's nodes and its injected-state tools."""

    # Conversation history; node updates are merged in by ``add_messages``.
    messages: Annotated[list, add_messages]
    # Free-form progress note (not read or written anywhere in this module).
    current_progress: str
    # Filenames written by ``write_code``, relative to ``workspace``.
    code_files: list[str]
    # Directory in which commands run and files are written.
    workspace: str
    # Optional {"source": ..., "dest": ...} link spec; "is_linked" is added once linked.
    symlinkdir: dict
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class ExecutionAgent(BaseAgent):
    """LangGraph agent that writes, edits and runs code inside a workspace.

    Graph: ``agent`` calls the tool-bound LLM. If it requested tools, the
    request goes to ``safety_check`` (which vets every ``run_cmd`` call), then
    to ``action`` (the ToolNode), or back to ``agent`` when a command was
    rejected. Once the LLM stops requesting tools, ``summarize`` produces the
    final answer.
    """

    def __init__(
        self,
        llm: str | BaseChatModel = "openai/gpt-4o-mini",
        agent_memory: Optional[Any | AgentMemory] = None,
        log_state: bool = False,
        **kwargs,
    ):
        """
        Args:
            llm: Chat model, or a model spec string passed to ``BaseAgent``
                (presumably resolved there — see base.py).
            agent_memory: Optional store; ``summarize`` calls its
                ``add_memories`` with the flattened conversation.
            log_state: When true, nodes dump their state to
                ``execution_agent.json`` via ``write_state``.
            **kwargs: Forwarded to ``BaseAgent``.
        """
        super().__init__(llm, **kwargs)
        self.agent_memory = agent_memory
        self.executor_prompt = executor_prompt
        self.summarize_prompt = summarize_prompt
        self.tools = [run_cmd, write_code, edit_code, search_tool]
        self.tool_node = ToolNode(self.tools)
        # From here on every LLM call in this class (safety check included)
        # goes through the tool-bound model.
        self.llm = self.llm.bind_tools(self.tools)
        self.log_state = log_state

        self._initialize_agent()

    def query_executor(self, state: ExecutionState) -> ExecutionState:
        """Call the tool-bound LLM on the conversation (the ``agent`` node).

        On each call this also ensures the workspace directory exists (a random
        two-word slug is used when none was supplied). If ``symlinkdir`` is set
        and not yet linked, it creates that symlink inside the workspace.

        Returns:
            A partial state update: the model response, the workspace name and,
            when present, the ``symlinkdir`` mapping.

        Raises:
            ContentPolicyViolationError: re-raised after logging when the
                provider rejects the request.
        """
        new_state = state.copy()
        if "workspace" not in new_state:
            new_state["workspace"] = coolname.generate_slug(2)
            print(
                f"{RED}Creating the folder {BLUE}{BOLD}{new_state['workspace']}{RESET}{RED} for this project.{RESET}"
            )
        os.makedirs(new_state["workspace"], exist_ok=True)

        symlinkdir = new_state.get("symlinkdir")
        if symlinkdir and "is_linked" not in symlinkdir:
            self._link_symlinkdir(symlinkdir, new_state["workspace"])

        # Work on a fresh list so the message list held by the graph is never
        # mutated in place.
        messages = list(state["messages"])
        system = SystemMessage(content=self.executor_prompt)
        if messages and isinstance(messages[0], SystemMessage):
            messages[0] = system
        else:
            messages.insert(0, system)
        new_state["messages"] = messages

        try:
            response = self.llm.invoke(
                messages, {"configurable": {"thread_id": self.thread_id}}
            )
        except ContentPolicyViolationError as e:
            print("Error: ", e, " ", messages[-1].content)
            # Without a response there is nothing to return; surface the
            # provider error rather than failing on an unbound name.
            raise

        if self.log_state:
            self.write_state("execution_agent.json", new_state)

        update = {"messages": [response], "workspace": new_state["workspace"]}
        if "symlinkdir" in new_state:
            # Returned so the checkpointed state also records "is_linked".
            update["symlinkdir"] = new_state["symlinkdir"]
        return update

    @staticmethod
    def _link_symlinkdir(symlinkdir: dict, workspace: str) -> None:
        """Link ``symlinkdir["source"]`` at ``<workspace>/<symlinkdir["dest"]>``.

        Sets ``symlinkdir["is_linked"] = True`` in place, so callers holding
        the same mapping see that the link has already been made.
        """
        src = Path(symlinkdir["source"]).expanduser().resolve()
        workspace_root = Path(workspace).expanduser().resolve()
        dst = workspace_root / symlinkdir["dest"]  # dest is relative to the workspace

        # Replace an existing link or file at the destination.
        # NOTE: a real directory there makes unlink() raise.
        if dst.exists() or dst.is_symlink():
            dst.unlink()

        # Create parent dirs for the link location if they don't exist.
        dst.parent.mkdir(parents=True, exist_ok=True)

        # Make the link, telling pathlib whether the target is a directory.
        dst.symlink_to(src, target_is_directory=src.is_dir())
        print(f"{RED}Symlinked {src} (source) --> {dst} (dest){RESET}")
        symlinkdir["is_linked"] = True

    def summarize(self, state: ExecutionState) -> ExecutionState:
        """Produce the final answer from the whole conversation (the ``summarize`` node).

        When ``agent_memory`` is configured, the flattened conversation plus the
        summary are stored via ``add_memories``.

        Returns:
            ``{"messages": [summary_text]}``; langgraph's ``add_messages``
            coerces the bare string into a message when merging.

        Raises:
            ContentPolicyViolationError: re-raised after logging.
        """
        messages = [SystemMessage(content=self.summarize_prompt), *state["messages"]]
        try:
            response = self.llm.invoke(
                messages, {"configurable": {"thread_id": self.thread_id}}
            )
        except ContentPolicyViolationError as e:
            print("Error: ", e, " ", messages[-1].content)
            raise

        if self.agent_memory:
            memories = [self._memory_text(message) for message in state["messages"]]
            memories.append(response.content)
            self.agent_memory.add_memories(memories)

        if self.log_state:
            # Log a copy that includes the summary without appending to the
            # live message list owned by the graph.
            save_state = state.copy()
            save_state["messages"] = [*state["messages"], response]
            self.write_state("execution_agent.json", save_state)
        return {"messages": [response.content]}

    @staticmethod
    def _memory_text(message) -> str:
        """Flatten one message into memory text; AI tool calls become name/arg listings."""
        if not isinstance(message, AIMessage) or not message.tool_calls:
            return message.content
        lines = []
        for call in message.tool_calls:
            lines.append("Tool Name: " + call["name"])
            lines.extend(
                f"Arg: {arg}\nValue: {value}" for arg, value in call["args"].items()
            )
        return "\n".join(lines)

    def safety_check(self, state: ExecutionState) -> ExecutionState:
        """
        Validate the safety of pending shell commands (the ``safety_check`` node).

        Args:
            state: Current execution state; its last message is the AI message
                whose tool calls are about to run.

        Returns:
            ``{"messages": []}`` when every ``run_cmd`` call passes.

            Otherwise, one ToolMessage per pending tool call. Rejected commands
            carry ``[UNSAFE]``, which makes ``command_safe`` route back to the
            agent. Every other call is reported as not run, so that each
            tool-call id still receives a response.
        """
        last_msg = state["messages"][-1]

        rejections = []  # (tool_call, rejection text or None)
        for tool_call in last_msg.tool_calls:
            rejection = None
            if tool_call["name"] == "run_cmd":
                rejection = self._vet_command(tool_call["args"]["query"])
            rejections.append((tool_call, rejection))

        if all(rejection is None for _, rejection in rejections):
            return {"messages": []}

        tool_responses = []
        for tool_call, rejection in rejections:
            if rejection is None:
                rejection = (
                    f"Tool call `{tool_call['name']}` was not run because another "
                    "command in the same request was deemed unsafe."
                )
            tool_responses.append(
                ToolMessage(content=rejection, tool_call_id=tool_call["id"])
            )
        return {"messages": tool_responses}

    def _vet_command(self, query: str) -> str | None:
        """Ask the LLM whether *query* is safe; return an ``[UNSAFE]`` message or None."""
        # Shell text and LLM prose may contain "[...]"; escape it so rich does
        # not interpret it as markup (a stray "[/x]" would raise MarkupError).
        from rich.markup import escape

        verdict = self.llm.invoke(
            (
                "Assume commands to run/install python and Julia files are safe because "
                "the files are from a trusted source. "
                f"Explain why, followed by an answer [YES] or [NO]. Is this command safe to run: {query}"
            )
        )
        if "[NO]" not in verdict.content:
            console.print(f"[green]Command passed safety check:[/green] {escape(query)}")
            return None

        rejection = (
            f"[UNSAFE] That command `{query}` was deemed unsafe and cannot be run.\n"
            f"For reason: {verdict.content}"
        )
        console.print(
            "[bold red][WARNING][/bold red] Command deemed unsafe:", escape(query)
        )
        # and tell the user the reason
        console.print("[bold red][WARNING][/bold red] REASON:", escape(rejection))
        return rejection

    def _initialize_agent(self):
        """Build and compile the agent -> safety_check -> action loop."""
        self.graph = StateGraph(ExecutionState)

        self.graph.add_node("agent", self.query_executor)
        self.graph.add_node("action", self.tool_node)
        self.graph.add_node("summarize", self.summarize)
        self.graph.add_node("safety_check", self.safety_check)

        # `agent` is the entry point.
        self.graph.add_edge(START, "agent")

        # Tool requests are vetted before running; a plain reply ends the loop.
        self.graph.add_conditional_edges(
            "agent",
            should_continue,
            {
                "continue": "safety_check",
                "summarize": "summarize",
            },
        )

        # Rejected commands go back to the agent along with the reason.
        self.graph.add_conditional_edges(
            "safety_check",
            command_safe,
            {
                "safe": "action",
                "unsafe": "agent",
            },
        )

        self.graph.add_edge("action", "agent")
        self.graph.add_edge("summarize", END)

        self.action = self.graph.compile(checkpointer=self.checkpointer)
        # self.action.get_graph().draw_mermaid_png(output_file_path="execution_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)

    def run(self, prompt, recursion_limit=1000):
        """Run the graph on *prompt* and return the final state.

        Args:
            prompt: Task description sent as the first human message.
            recursion_limit: LangGraph step limit for this invocation.
        """
        inputs = {"messages": [HumanMessage(content=prompt)]}
        return self.action.invoke(
            inputs,
            {
                "recursion_limit": recursion_limit,
                "configurable": {"thread_id": self.thread_id},
            },
        )
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _output_text(data) -> str:
    """Return captured subprocess output as text ("" when nothing was captured)."""
    if data is None:
        return ""
    if isinstance(data, bytes):
        # TimeoutExpired carries bytes even when run() was called with text=True.
        return data.decode(errors="replace")
    return data


@tool
def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
    """
    Run a commandline command from using the subprocess package in python

    Args:
        query: commandline command to be run as a string given to the subprocess.run command.
    """
    # shell=True runs model-generated text verbatim. ExecutionAgent routes every
    # run_cmd call through its LLM safety check before this tool executes.
    workspace_dir = state["workspace"]
    print("RUNNING: ", query)
    try:
        result = subprocess.run(
            query,
            text=True,
            shell=True,
            timeout=60000,  # seconds (~16.7 h); NOTE(review): confirm this limit is intended
            capture_output=True,
            cwd=workspace_dir,
        )
        stdout, stderr = result.stdout, result.stderr
    except KeyboardInterrupt:
        # Report the interruption to the model instead of aborting the graph.
        print("Keyboard Interrupt of command: ", query)
        stdout, stderr = "", "KeyboardInterrupt:"
    except subprocess.TimeoutExpired as exc:
        # Report the timeout (plus any partial output) instead of letting the
        # exception escape the tool.
        print("Timeout of command: ", query)
        stdout = _output_text(exc.stdout)
        stderr = (
            _output_text(exc.stderr)
            + f"\nTimeoutExpired: command exceeded {exc.timeout} seconds"
        )

    print("STDOUT: ", stdout)
    print("STDERR: ", stderr)

    return f"STDOUT: {stdout} and STDERR: {stderr}"
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
def _strip_fences(snippet: str) -> str:
|
|
294
|
+
"""
|
|
295
|
+
Remove leading markdown ``` fence
|
|
296
|
+
"""
|
|
297
|
+
if "```" not in snippet:
|
|
298
|
+
return snippet
|
|
299
|
+
|
|
300
|
+
parts = snippet.split("```")
|
|
301
|
+
if len(parts) < 3:
|
|
302
|
+
return snippet
|
|
303
|
+
|
|
304
|
+
body = parts[1]
|
|
305
|
+
return "\n".join(body.split("\n")[1:]) if "\n" in body else body.strip()
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
@tool
def write_code(
    code: str,
    filename: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
    state: Annotated[dict, InjectedState],
) -> Command | str:
    """Write *code* to *filename*.

    Args:
        code: Source code as a string.
        filename: Target filename (including extension).

    Returns:
        Success / failure message.
    """
    # With @tool the docstring above doubles as the description shown to the
    # model. On success a Command updates ``code_files`` and answers the tool
    # call; on failure a plain string is returned.
    workspace_dir = state["workspace"]
    console.print("[cyan]Writing file:[/]", filename)

    # Clean up markdown fences
    code = _strip_fences(code)

    # Syntax-highlighted preview
    try:
        lexer_name = Syntax.guess_lexer(filename, code)
    except Exception:
        lexer_name = "text"

    console.print(
        Panel(
            Syntax(code, lexer_name, line_numbers=True),
            title="File Preview",
            border_style="cyan",
        )
    )

    code_file = os.path.join(workspace_dir, filename)
    try:
        # Create missing parent directories so names like "src/main.py" work.
        os.makedirs(os.path.dirname(code_file) or ".", exist_ok=True)
        with open(code_file, "w", encoding="utf-8") as f:
            f.write(code)
    except Exception as exc:
        console.print(
            "[bold bright_white on red] :heavy_multiplication_x: [/] "
            "[red]Failed to write file:[/]",
            exc,
        )
        return f"Failed to write {filename}."

    console.print(
        "[bold bright_white on green] :heavy_check_mark: [/] "
        f"[green]File written:[/] {code_file}"
    )

    # Build a new list: appending would mutate the list object held in the
    # live graph state.
    file_list = [*(state.get("code_files") or []), filename]

    # Tool message answering this call
    msg = ToolMessage(
        content=f"File {filename} written successfully.",
        tool_call_id=tool_call_id,
    )

    # Return updated code files list & the message
    return Command(
        update={
            "code_files": file_list,
            "messages": [msg],
        }
    )
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
@tool
def edit_code(
    old_code: str,
    new_code: str,
    filename: str,
    state: Annotated[dict, InjectedState],
) -> str:
    """Replace the **first** occurrence of *old_code* with *new_code* in *filename*.

    Args:
        old_code: Code fragment to search for.
        new_code: Replacement fragment.
        filename: Target file inside the workspace.

    Returns:
        Success / failure message.
    """
    # With @tool the docstring above doubles as the description shown to the model.
    workspace_dir = state["workspace"]
    console.print("[cyan]Editing file:[/cyan]", filename)

    code_file = os.path.join(workspace_dir, filename)
    try:
        with open(code_file, "r", encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        console.print(
            "[bold bright_white on red] :heavy_multiplication_x: [/] "
            "[red]File not found:[/]",
            filename,
        )
        return f"Failed: {filename} not found."
    except (OSError, UnicodeDecodeError) as exc:
        # e.g. a directory, missing permissions, or a non-UTF-8 file
        console.print(
            "[bold bright_white on red] :heavy_multiplication_x: [/] "
            "[red]Failed to read file:[/]",
            exc,
        )
        return f"Failed to read {filename}."

    # Clean up markdown fences
    old_code_clean = _strip_fences(old_code)
    new_code_clean = _strip_fences(new_code)

    # An empty fragment "matches" at offset 0 and would silently prepend new_code.
    if not old_code_clean:
        console.print("[yellow] ⚠️ 'old_code' is empty; no changes made.[/]")
        return f"No changes made to {filename}: 'old_code' is empty."

    if old_code_clean not in content:
        console.print("[yellow] ⚠️ 'old_code' not found in file; no changes made.[/]")
        return f"No changes made to {filename}: 'old_code' not found in file."

    updated = content.replace(old_code_clean, new_code_clean, 1)

    console.print(
        Panel(
            DiffRenderer(content, updated, filename),
            title="Diff Preview",
            border_style="cyan",
        )
    )

    try:
        with open(code_file, "w", encoding="utf-8") as f:
            f.write(updated)
    except Exception as exc:
        console.print(
            "[bold bright_white on red] :heavy_multiplication_x: [/] "
            "[red]Failed to write file:[/]",
            exc,
        )
        return f"Failed to edit {filename}."

    console.print(
        "[bold bright_white on green] :heavy_check_mark: [/] "
        f"[green]File updated:[/] {code_file}"
    )
    return f"File {filename} updated successfully."
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
# Web search tool exposed to the model; results come back as JSON.
# Alternative: TavilySearchResults(max_results=10, search_depth="advanced", include_answer=True)
search_tool = DuckDuckGoSearchResults(num_results=10, output_format="json")
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
# Routing after the agent node: keep looping while the model requests tools.
def should_continue(state: ExecutionState) -> Literal["summarize", "continue"]:
    """Return "continue" if the latest message has tool calls, else "summarize"."""
    latest = state["messages"][-1]
    return "continue" if latest.tool_calls else "summarize"
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
# Routing after the safety_check node: run the tools or return to the agent.
def command_safe(state: ExecutionState) -> Literal["safe", "unsafe"]:
    """
    Return graph edge "safe" if the last command was safe, otherwise return edge "unsafe".

    Only the trailing run of ToolMessages (those appended after the most recent
    non-tool message) is inspected; any of them containing "[UNSAFE]" makes the
    result "unsafe".
    """
    # Iterating in reverse avoids the out-of-range indexing the manual walk hit
    # when every message (or none) was a ToolMessage.
    for message in reversed(state["messages"]):
        if not isinstance(message, ToolMessage):
            break
        if "[UNSAFE]" in message.content:
            return "unsafe"
    return "safe"
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def main():
    """Demo: ask the agent to write and run a tiny Python script, then print its answer."""
    agent = ExecutionAgent()
    task = "Write and execute a python script to print the first 10 integers."
    # Optionally add "workspace": "dummy_test" to pin the output folder.
    graph_input = {"messages": [HumanMessage(content=task)]}
    config = {"configurable": {"thread_id": agent.thread_id}}
    result = agent.action.invoke(graph_input, config)
    print(result["messages"][-1].content)
    return result


if __name__ == "__main__":
    main()
|