ursa-ai 0.7.0rc1__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ursa-ai might be problematic.
- ursa/agents/__init__.py +13 -2
- ursa/agents/acquisition_agents.py +812 -0
- ursa/agents/arxiv_agent.py +1 -1
- ursa/agents/base.py +352 -91
- ursa/agents/chat_agent.py +58 -0
- ursa/agents/execution_agent.py +506 -260
- ursa/agents/lammps_agent.py +81 -31
- ursa/agents/planning_agent.py +27 -2
- ursa/agents/websearch_agent.py +2 -2
- ursa/cli/__init__.py +5 -1
- ursa/cli/hitl.py +46 -34
- ursa/observability/pricing.json +85 -0
- ursa/observability/pricing.py +20 -18
- ursa/util/parse.py +316 -0
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/METADATA +5 -1
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/RECORD +20 -17
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/WHEEL +0 -0
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/entry_points.txt +0 -0
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/licenses/LICENSE +0 -0
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/top_level.txt +0 -0
ursa/agents/execution_agent.py
CHANGED
@@ -1,3 +1,30 @@
+"""Execution agent that builds a tool-enabled state graph to autonomously run tasks.
+
+This module implements ExecutionAgent, a LangGraph-based agent that executes user
+instructions by invoking LLM tool calls and coordinating a controlled workflow.
+
+Key features:
+- Workspace management with optional symlinking for external sources.
+- Safety-checked shell execution via run_cmd with output size budgeting.
+- Code authoring and edits through write_code and edit_code with rich previews.
+- Web search capability through DuckDuckGoSearchResults.
+- Summarization of the session and optional memory logging.
+- Configurable graph with nodes for agent, safety_check, action, and summarize.
+
+Implementation notes:
+- LLM prompts are sourced from prompt_library.execution_prompts.
+- Outputs from subprocess are trimmed under MAX_TOOL_MSG_CHARS to fit tool messages.
+- The agent uses ToolNode and LangGraph StateGraph to loop until no tool calls remain.
+- Safety gates block unsafe shell commands and surface the rationale to the user.
+
+Environment:
+- MAX_TOOL_MSG_CHARS caps combined stdout/stderr in tool responses.
+
+Entry points:
+- ExecutionAgent._invoke(...) runs the compiled graph.
+- main() shows a minimal demo that writes and runs a script.
+"""
+
 import os
 
 # from langchain_core.runnables.graph import MermaidDrawMethod
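The module docstring above describes the agent's entry points only in prose. The following is a rough usage sketch inferred from the docstrings in this diff; the public invoke(...) wrapper, the message type, and the state keys shown here are assumptions, not a verified API.

# Hypothetical driver script based on the docstrings in this diff; the real
# entry point and state keys may differ in the released package.
from langchain_core.messages import HumanMessage

from ursa.agents.execution_agent import ExecutionAgent

agent = ExecutionAgent(llm="openai/gpt-4o-mini")
result = agent.invoke(
    {
        "messages": [HumanMessage(content="Write and run a hello-world script.")],
        "workspace": "demo_workspace",  # optional; a random name is generated if omitted
    }
)
print(result["messages"][-1])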
@@ -48,7 +75,33 @@ RESET = "\033[0m"
 BOLD = "\033[1m"
 
 
+# Global variables for the module.
+
+# Set a limit for message characters - the user could overload
+# that in their env, or maybe we could pull this out of the LLM parameters
+MAX_TOOL_MSG_CHARS = int(os.getenv("MAX_TOOL_MSG_CHARS", "50000"))
+
+# Set a search tool.
+search_tool = DuckDuckGoSearchResults(output_format="json", num_results=10)
+# search_tool = TavilySearchResults(
+#     max_results=10,
+#     search_depth="advanced",
+#     include_answer=True)
+
+
+# Classes for typing
 class ExecutionState(TypedDict):
+    """TypedDict representing the execution agent's mutable run state used by nodes.
+
+    Fields:
+    - messages: list of messages (System/Human/AI/Tool) with add_messages metadata.
+    - current_progress: short status string describing agent progress.
+    - code_files: list of filenames created or edited in the workspace.
+    - workspace: path to the working directory where files and commands run.
+    - symlinkdir: optional dict describing a symlink operation (source, dest,
+      is_linked).
+    """
+
     messages: Annotated[list, add_messages]
     current_progress: str
     code_files: list[str]
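Because MAX_TOOL_MSG_CHARS is read from the environment once at module import, the 50,000-character default can be overridden before the module is imported. A minimal sketch (the variable name comes from this diff; the surrounding script is illustrative only):

import os

# Override the output budget before ursa.agents.execution_agent is imported,
# since the module reads MAX_TOOL_MSG_CHARS a single time at import.
os.environ["MAX_TOOL_MSG_CHARS"] = "20000"

from ursa.agents.execution_agent import ExecutionAgent  # noqa: E402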
@@ -56,220 +109,42 @@ class ExecutionState(TypedDict):
     symlinkdir: dict
 
 
-
-
-
-        llm: str | BaseChatModel = "openai/gpt-4o-mini",
-        agent_memory: Optional[Any | AgentMemory] = None,
-        log_state: bool = False,
-        **kwargs,
-    ):
-        super().__init__(llm, **kwargs)
-        self.agent_memory = agent_memory
-        self.safety_prompt = safety_prompt
-        self.executor_prompt = executor_prompt
-        self.summarize_prompt = summarize_prompt
-        self.tools = [run_cmd, write_code, edit_code, search_tool]
-        self.tool_node = ToolNode(self.tools)
-        self.llm = self.llm.bind_tools(self.tools)
-        self.log_state = log_state
-
-        self._action = self._build_graph()
-
-    # Define the function that calls the model
-    def query_executor(self, state: ExecutionState) -> ExecutionState:
-        new_state = state.copy()
-        if "workspace" not in new_state.keys():
-            new_state["workspace"] = randomname.get_name()
-            print(
-                f"{RED}Creating the folder {BLUE}{BOLD}{new_state['workspace']}{RESET}{RED} for this project.{RESET}"
-            )
-            os.makedirs(new_state["workspace"], exist_ok=True)
-
-        # code related to symlink
-        sd = new_state.get("symlinkdir")
-        if isinstance(sd, dict) and "is_linked" not in sd:
-            # symlinkdir = {"source": "foo", "dest": "bar"}
-            symlinkdir = new_state["symlinkdir"]
-            # user provided a symlinkdir key - let's do the linking!
-
-            src = Path(symlinkdir["source"]).expanduser().resolve()
-            workspace_root = Path(new_state["workspace"]).expanduser().resolve()
-            dst = workspace_root / symlinkdir["dest"]  # prepend workspace
-
-            # if you want to replace an existing link/file, unlink it first
-            if dst.exists() or dst.is_symlink():
-                dst.unlink()
-
-            # create parent dirs for the link location if they don’t exist
-            dst.parent.mkdir(parents=True, exist_ok=True)
-
-            # actually make the link (tell pathlib it’s a directory target)
-            dst.symlink_to(src, target_is_directory=src.is_dir())
-            print(f"{RED}Symlinked {src} (source) --> {dst} (dest)")
-            # note that we've done the symlink now, so don't need to do it later
-            new_state["symlinkdir"]["is_linked"] = True
-
-        if isinstance(new_state["messages"][0], SystemMessage):
-            new_state["messages"][0] = SystemMessage(
-                content=self.executor_prompt
-            )
-        else:
-            new_state["messages"] = [
-                SystemMessage(content=self.executor_prompt)
-            ] + state["messages"]
-        try:
-            response = self.llm.invoke(
-                new_state["messages"], self.build_config(tags=["agent"])
-            )
-        except ContentPolicyViolationError as e:
-            print("Error: ", e, " ", new_state["messages"][-1].content)
-        if self.log_state:
-            self.write_state("execution_agent.json", new_state)
-        return {"messages": [response], "workspace": new_state["workspace"]}
-
-    # Define the function that calls the model
-    def summarize(self, state: ExecutionState) -> ExecutionState:
-        messages = [SystemMessage(content=summarize_prompt)] + state["messages"]
-        try:
-            response = self.llm.invoke(
-                messages, self.build_config(tags=["summarize"])
-            )
-        except ContentPolicyViolationError as e:
-            print("Error: ", e, " ", messages[-1].content)
-        if self.agent_memory:
-            memories = []
-            # Handle looping through the messages
-            for x in state["messages"]:
-                if not isinstance(x, AIMessage):
-                    memories.append(x.content)
-                elif not x.tool_calls:
-                    memories.append(x.content)
-                else:
-                    tool_strings = []
-                    for tool in x.tool_calls:
-                        tool_name = "Tool Name: " + tool["name"]
-                        tool_strings.append(tool_name)
-                        for y in tool["args"]:
-                            tool_strings.append(
-                                f"Arg: {str(y)}\nValue: {str(tool['args'][y])}"
-                            )
-                    memories.append("\n".join(tool_strings))
-            memories.append(response.content)
-            self.agent_memory.add_memories(memories)
-        save_state = state.copy()
-        save_state["messages"].append(response)
-        if self.log_state:
-            self.write_state("execution_agent.json", save_state)
-        return {"messages": [response.content]}
-
-    # Define the function that calls the model
-    def safety_check(self, state: ExecutionState) -> ExecutionState:
-        """
-        Validate the safety of a pending shell command.
-
-        Args:
-            state: Current execution state.
-
-        Returns:
-            Either the unchanged state (safe) or a state with tool message(s) (unsafe).
-        """
-        new_state = state.copy()
-        last_msg = new_state["messages"][-1]
-
-        tool_responses = []
-        tool_failed = False
-        for tool_call in last_msg.tool_calls:
-            call_name = tool_call["name"]
-
-            if call_name == "run_cmd":
-                query = tool_call["args"]["query"]
-                safety_check = self.llm.invoke(
-                    self.safety_prompt + query,
-                    self.build_config(tags=["safety_check"]),
-                )
-
-                if "[NO]" in safety_check.content:
-                    tool_failed = True
-
-                    tool_response = f"""
-                    [UNSAFE] That command `{query}` was deemed unsafe and cannot be run.
-                    For reason: {safety_check.content}
-                    """
-                    console.print(
-                        "[bold red][WARNING][/bold red] Command deemed unsafe:",
-                        query,
-                    )
-                    # and tell the user the reason
-                    console.print(
-                        "[bold red][WARNING][/bold red] REASON:", tool_response
-                    )
-
-                else:
-                    tool_response = f"Command `{query}` passed safety check."
-                    console.print(
-                        f"[green]Command passed safety check:[/green] {query}"
-                    )
-
-            tool_responses.append(
-                ToolMessage(
-                    content=tool_response,
-                    tool_call_id=tool_call["id"],
-                )
-            )
-
-        if tool_failed:
-            new_state["messages"].extend(tool_responses)
-
-        return new_state
-
-    def _build_graph(self):
-        graph = StateGraph(ExecutionState)
-
-        self.add_node(graph, self.query_executor, "agent")
-        self.add_node(graph, self.tool_node, "action")
-        self.add_node(graph, self.summarize, "summarize")
-        self.add_node(graph, self.safety_check, "safety_check")
+# Helper functions
+def _strip_fences(snippet: str) -> str:
+    """Remove markdown fences from a code snippet.
 
-
-
+    This function strips leading triple backticks and any language
+    identifiers from a markdown-formatted code snippet and returns
+    only the contained code.
 
-
-
-            self._wrap_cond(should_continue, "should_continue", "execution"),
-            {"continue": "safety_check", "summarize": "summarize"},
-        )
+    Args:
+        snippet: The markdown-formatted code snippet.
 
-
-
-
-
-
+    Returns:
+        The snippet content without leading markdown fences.
+    """
+    if "```" not in snippet:
+        return snippet
 
-
-
+    parts = snippet.split("```")
+    if len(parts) < 3:
+        return snippet
 
-
-
+    body = parts[1]
+    return "\n".join(body.split("\n")[1:]) if "\n" in body else body.strip()
 
-    def _invoke(
-        self, inputs: Mapping[str, Any], recursion_limit: int = 999_999, **_
-    ):
-        config = self.build_config(
-            recursion_limit=recursion_limit, tags=["graph"]
-        )
-        return self._action.invoke(inputs, config)
 
-
-
-    def action(self):
-        raise AttributeError(
-            "Use .stream(...) or .invoke(...); direct .action access is unsupported."
-        )
+def _snip_text(text: str, max_chars: int) -> tuple[str, bool]:
+    """Truncate text to a maximum length and indicate if truncation occurred.
 
+    Args:
+        text: The original text to potentially truncate.
+        max_chars: The maximum characters allowed in the output.
 
-
+    Returns:
+        A tuple of (possibly truncated text, boolean flag indicating
+        if truncation occurred).
+    """
     if text is None:
         return "", False
     if max_chars <= 0:
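The relocated _strip_fences helper only unwraps the first fenced block and drops its language tag; unfenced input passes through untouched. A small self-contained illustration of that behavior (the function body is copied from this diff so the snippet runs on its own):

# Behavior sketch for the fence-stripping helper shown above.
def _strip_fences(snippet: str) -> str:
    if "```" not in snippet:
        return snippet
    parts = snippet.split("```")
    if len(parts) < 3:
        return snippet
    body = parts[1]
    return "\n".join(body.split("\n")[1:]) if "\n" in body else body.strip()


print(_strip_fences("```python\nprint('hello')\n```"))  # -> print('hello')
print(_strip_fences("print('hello')"))  # unchanged: no fence present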
@@ -287,6 +162,16 @@ def _snip_text(text: str, max_chars: int) -> tuple[str, bool]:
 
 
 def _fit_streams_to_budget(stdout: str, stderr: str, total_budget: int):
+    """Allocate and truncate stdout and stderr to fit a total character budget.
+
+    Args:
+        stdout: The original stdout string.
+        stderr: The original stderr string.
+        total_budget: The combined character budget for stdout and stderr.
+
+    Returns:
+        A tuple of (possibly truncated stdout, possibly truncated stderr).
+    """
     label_overhead = len("STDOUT:\n") + len("\nSTDERR:\n")
     budget = max(0, total_budget - label_overhead)
 
@@ -299,23 +184,72 @@ def _fit_streams_to_budget(stdout: str, stderr: str, total_budget: int):
 
     stdout_snip, _ = _snip_text(stdout, stdout_budget)
     stderr_snip, _ = _snip_text(stderr, stderr_budget)
+
     return stdout_snip, stderr_snip
 
 
-
-
-
+def should_continue(state: ExecutionState) -> Literal["summarize", "continue"]:
+    """Return 'summarize' if no tool calls in the last message, else 'continue'.
+
+    Args:
+        state: The current execution state containing messages.
+
+    Returns:
+        A literal "summarize" if the last message has no tool calls,
+        otherwise "continue".
+    """
+    messages = state["messages"]
+    last_message = messages[-1]
+    # If there is no tool call, then we finish
+    if not last_message.tool_calls:
+        return "summarize"
+    # Otherwise if there is, we continue
+    else:
+        return "continue"
+
+
+def command_safe(state: ExecutionState) -> Literal["safe", "unsafe"]:
+    """Return 'safe' if the last command was safe, otherwise 'unsafe'.
+
+    Args:
+        state: The current execution state containing messages and tool calls.
+    Returns:
+        A literal "safe" if no '[UNSAFE]' tags are in the last command,
+        otherwise "unsafe".
+    """
+    index = -1
+    message = state["messages"][index]
+    # Loop through all the consecutive tool messages in reverse order
+    while isinstance(message, ToolMessage):
+        if "[UNSAFE]" in message.content:
+            return "unsafe"
 
+        index -= 1
+        message = state["messages"][index]
+
+    return "safe"
 
+
+# Tools for ExecutionAgent
 @tool
 def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
-    """
-
+    """Execute a shell command in the workspace and return its combined output.
+
+    Runs the specified command using subprocess.run in the given workspace
+    directory, captures stdout and stderr, enforces a maximum character budget,
+    and formats both streams into a single string. KeyboardInterrupt during
+    execution is caught and reported.
 
     Args:
-        query:
+        query: The shell command to execute.
+        state: A dict with injected state; must include the 'workspace' path.
+
+    Returns:
+        A formatted string with "STDOUT:" followed by the truncated stdout and
+        "STDERR:" followed by the truncated stderr.
     """
     workspace_dir = state["workspace"]
+
     print("RUNNING: ", query)
     try:
         result = subprocess.run(
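should_continue and command_safe are plain functions over the state dict, so their routing behavior can be exercised without compiling the graph. A small illustrative check (the hand-built messages here are hypothetical; in the real loop they come from the LLM and the safety node):

# Illustrative-only exercise of the routing predicates added in this release.
from langchain_core.messages import AIMessage, ToolMessage

from ursa.agents.execution_agent import command_safe, should_continue

# An AI message without tool calls routes the graph to the summarize node.
finished = {"messages": [AIMessage(content="All tasks are complete.")]}
assert should_continue(finished) == "summarize"

# A trailing ToolMessage carrying the [UNSAFE] tag routes back to the agent.
blocked = {
    "messages": [
        ToolMessage(
            content="[UNSAFE] That command `rm -rf /` was deemed unsafe and cannot be run.",
            tool_call_id="call_1",
        )
    ]
}
assert command_safe(blocked) == "unsafe"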
@@ -342,21 +276,6 @@ def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
     return f"STDOUT:\n{stdout_fit}\nSTDERR:\n{stderr_fit}"
 
 
-def _strip_fences(snippet: str) -> str:
-    """
-    Remove leading markdown ``` fence
-    """
-    if "```" not in snippet:
-        return snippet
-
-    parts = snippet.split("```")
-    if len(parts) < 3:
-        return snippet
-
-    body = parts[1]
-    return "\n".join(body.split("\n")[1:]) if "\n" in body else body.strip()
-
-
 @tool
 def write_code(
     code: str,
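run_cmd returns both streams in a single STDOUT:/STDERR: block, trimmed by _snip_text and _fit_streams_to_budget so the tool message stays under MAX_TOOL_MSG_CHARS. A simplified sketch of that budgeting idea (not the library code; the real helpers also report whether truncation occurred and compute per-stream budgets not shown in this hunk):

# Simplified, illustrative re-creation of the output-budgeting idea.
MAX_TOOL_MSG_CHARS = 50_000  # same default the module reads from the environment


def fit_output(stdout: str, stderr: str, budget: int = MAX_TOOL_MSG_CHARS) -> str:
    label_overhead = len("STDOUT:\n") + len("\nSTDERR:\n")
    usable = max(0, budget - label_overhead)
    half = usable // 2  # naive even split between the two streams
    return f"STDOUT:\n{stdout[:half]}\nSTDERR:\n{stderr[:half]}"


combined = fit_output("x" * 100_000, "warning: deprecated flag\n")
assert len(combined) <= MAX_TOOL_MSG_CHARS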
@@ -364,22 +283,26 @@ def write_code(
     tool_call_id: Annotated[str, InjectedToolCallId],
     state: Annotated[dict, InjectedState],
 ) -> Command:
-    """Write
+    """Write source code to a file and update the agent’s workspace state.
 
     Args:
-        code:
-        filename:
+        code: The source code content to be written to disk.
+        filename: Name of the target file (including its extension).
+        tool_call_id: Identifier for this tool invocation.
+        state: Agent state dict holding workspace path and file list.
 
     Returns:
-
+        Command: Contains an updated state (including code_files) and
+        a ToolMessage acknowledging success or failure.
     """
+    # Determine the full path to the target file
     workspace_dir = state["workspace"]
     console.print("[cyan]Writing file:[/]", filename)
 
-    # Clean up markdown fences
+    # Clean up markdown fences on submitted code.
     code = _strip_fences(code)
 
-    #
+    # Show syntax-highlighted preview before writing to file
     try:
         lexer_name = Syntax.guess_lexer(filename, code)
     except Exception:
@@ -393,6 +316,7 @@ def write_code(
         )
     )
 
+    # Write cleaned code to disk
     code_file = os.path.join(workspace_dir, filename)
     try:
         with open(code_file, "w", encoding="utf-8") as f:
@@ -410,11 +334,11 @@
         f"[green]File written:[/] {code_file}"
     )
 
-    # Append the file to the list in state
+    # Append the file to the list in agent's state for later reference
     file_list = state.get("code_files", [])
     file_list.append(filename)
 
-    # Create a tool message to send back
+    # Create a tool message to send back to acknowledge success.
     msg = ToolMessage(
         content=f"File {filename} written successfully.",
         tool_call_id=tool_call_id,
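write_code follows the LangGraph pattern of returning a Command so a tool can update shared state and answer its tool call in one step. A hedged, generic sketch of that pattern (the tool itself is hypothetical and not part of ursa-ai; import paths follow current langchain-core/langgraph conventions and may differ from the versions pinned by this package):

# Hypothetical tool illustrating the Command + ToolMessage pattern.
from typing import Annotated

from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId, tool
from langgraph.prebuilt import InjectedState
from langgraph.types import Command


@tool
def record_note(
    note: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
    state: Annotated[dict, InjectedState],
) -> Command:
    """Append a note to the shared state and acknowledge the tool call."""
    notes = state.get("notes", []) + [note]
    return Command(
        update={
            "notes": notes,
            "messages": [
                ToolMessage(content="Note recorded.", tool_call_id=tool_call_id)
            ],
        }
    )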
@@ -499,41 +423,363 @@ def edit_code(
     return f"File {filename} updated successfully."
 
 
-
-
+# Main module class
+class ExecutionAgent(BaseAgent):
+    """Orchestrates model-driven code execution, tool calls, and state management.
 
+    Orchestrates model-driven code execution, tool calls, and state management for
+    iterative program synthesis and shell interaction.
 
-
-
-
-
-
-    if not last_message.tool_calls:
-        return "summarize"
-    # Otherwise if there is, we continue
-    else:
-        return "continue"
-
+    This agent wraps an LLM with a small execution graph that alternates
+    between issuing model queries, invoking tools (run, write, edit, search),
+    performing safety checks, and summarizing progress. It manages a
+    workspace on disk, optional symlinks, and an optional memory backend to
+    persist summaries.
 
-
-
-
-
+    Args:
+        llm (str | BaseChatModel): Model identifier or bound chat model
+            instance. If a string is provided, the BaseAgent initializer will
+            resolve it.
+        agent_memory (Any | AgentMemory, optional): Memory backend used to
+            store summarized agent interactions. If provided, summaries are
+            saved here.
+        log_state (bool): When True, the agent writes intermediate json state
+            to disk for debugging and auditability.
+        **kwargs: Passed through to the BaseAgent constructor (e.g., model
+            configuration, checkpointer).
+
+    Attributes:
+        safety_prompt (str): Prompt used to evaluate safety of shell
+            commands.
+        executor_prompt (str): Prompt used when invoking the executor LLM
+            loop.
+        summarize_prompt (str): Prompt used to request concise summaries for
+            memory or final output.
+        tools (list[Tool]): Tools available to the agent (run_cmd, write_code,
+            edit_code, search_tool).
+        tool_node (ToolNode): Graph node that dispatches tool calls.
+        llm (BaseChatModel): LLM instance bound to the available tools.
+        _action (StateGraph): Compiled execution graph that implements the
+            main loop and branching logic.
+
+    Methods:
+        query_executor(state): Send messages to the executor LLM, ensure
+            workspace exists, and handle symlink setup before returning the
+            model response.
+        summarize(state): Produce and optionally persist a summary of recent
+            interactions to the memory backend.
+        safety_check(state): Validate pending run_cmd calls via the safety
+            prompt and append ToolMessages for unsafe commands.
+        _build_graph(): Construct and compile the StateGraph for the agent
+            loop.
+        _invoke(inputs, recursion_limit=...): Internal entry that invokes the
+            compiled graph with a given recursion limit.
+        action (property): Disabled; direct access is not supported. Use
+            invoke or stream entry points instead.
+
+    Raises:
+        AttributeError: Accessing the .action attribute raises to encourage
+            using .stream(...) or .invoke(...).
     """
 
-
-
-
-
-
-
+    def __init__(
+        self,
+        llm: str | BaseChatModel = "openai/gpt-4o-mini",
+        agent_memory: Optional[Any | AgentMemory] = None,
+        log_state: bool = False,
+        **kwargs,
+    ):
+        """ExecutionAgent class initialization."""
+        super().__init__(llm, **kwargs)
+        self.agent_memory = agent_memory
+        self.safety_prompt = safety_prompt
+        self.executor_prompt = executor_prompt
+        self.summarize_prompt = summarize_prompt
+        self.tools = [run_cmd, write_code, edit_code, search_tool]
+        self.tool_node = ToolNode(self.tools)
+        self.llm = self.llm.bind_tools(self.tools)
+        self.log_state = log_state
 
-
-        message = state["messages"][index]
+        self._action = self._build_graph()
 
-
+    # Define the function that calls the model
+    def query_executor(self, state: ExecutionState) -> ExecutionState:
+        """Prepare workspace, handle optional symlinks, and invoke the executor LLM.
+
+        This method copies the incoming state, ensures a workspace directory exists
+        (creating one with a random name when absent), optionally creates a symlink
+        described by state["symlinkdir"], sets or injects the executor system prompt
+        as the first message, and invokes the bound LLM. When logging is enabled,
+        it persists the pre-invocation state to disk.
+
+        Args:
+            state: The current execution state. Expected keys include:
+                - "messages": Ordered list of System/Human/AI/Tool messages.
+                - "workspace": Optional path to the working directory.
+                - "symlinkdir": Optional dict with "source" and "dest" keys.
+
+        Returns:
+            ExecutionState: Partial state update containing:
+                - "messages": A list with the model's response as the latest entry.
+                - "workspace": The resolved workspace path.
+        """
+        new_state = state.copy()
+
+        # 1) Ensure a workspace directory exists, creating a named one if absent.
+        if "workspace" not in new_state.keys():
+            new_state["workspace"] = randomname.get_name()
+            print(
+                f"{RED}Creating the folder "
+                f"{BLUE}{BOLD}{new_state['workspace']}{RESET}{RED} "
+                f"for this project.{RESET}"
+            )
+            os.makedirs(new_state["workspace"], exist_ok=True)
+
+        # 2) Optionally create a symlink if symlinkdir is provided and not yet linked.
+        sd = new_state.get("symlinkdir")
+        if isinstance(sd, dict) and "is_linked" not in sd:
+            # symlinkdir structure: {"source": "/path/to/src", "dest": "link/name"}
+            symlinkdir = sd
+
+            src = Path(symlinkdir["source"]).expanduser().resolve()
+            workspace_root = Path(new_state["workspace"]).expanduser().resolve()
+            dst = (
+                workspace_root / symlinkdir["dest"]
+            )  # Link lives inside workspace.
+
+            # If a file/link already exists at the destination, replace it.
+            if dst.exists() or dst.is_symlink():
+                dst.unlink()
+
+            # Ensure parent directories for the link exist.
+            dst.parent.mkdir(parents=True, exist_ok=True)
+
+            # Create the symlink (tell pathlib if the target is a directory).
+            dst.symlink_to(src, target_is_directory=src.is_dir())
+            print(f"{RED}Symlinked {src} (source) --> {dst} (dest)")
+            new_state["symlinkdir"]["is_linked"] = True
+
+        # 3) Ensure the executor prompt is the first SystemMessage.
+        if isinstance(new_state["messages"][0], SystemMessage):
+            new_state["messages"][0] = SystemMessage(
+                content=self.executor_prompt
+            )
+        else:
+            new_state["messages"] = [
+                SystemMessage(content=self.executor_prompt)
+            ] + state["messages"]
+
+        # 4) Invoke the LLM with the prepared message sequence.
+        try:
+            response = self.llm.invoke(
+                new_state["messages"], self.build_config(tags=["agent"])
+            )
+        except ContentPolicyViolationError as e:
+            print("Error: ", e, " ", new_state["messages"][-1].content)
+
+        # 5) Optionally persist the pre-invocation state for audit/debugging.
+        if self.log_state:
+            self.write_state("execution_agent.json", new_state)
+
+        # Return the model's response and the workspace path as a partial state update.
+        return {"messages": [response], "workspace": new_state["workspace"]}
+
+    def summarize(self, state: ExecutionState) -> ExecutionState:
+        """Produce a concise summary of the conversation and optionally persist memory.
+
+        This method builds a summarization prompt, invokes the LLM to obtain a compact
+        summary of recent interactions, optionally logs salient details to the agent
+        memory backend, and writes debug state when logging is enabled.
+
+        Args:
+            state (ExecutionState): The execution state containing message history.
+
+        Returns:
+            ExecutionState: A partial update with a single string message containing
+                the summary.
+        """
+        # 1) Construct the summarization message list (system prompt + prior messages).
+        messages = [SystemMessage(content=summarize_prompt)] + state["messages"]
+
+        # 2) Invoke the LLM to generate a summary; capture content even on failure.
+        response_content = ""
+        try:
+            response = self.llm.invoke(
+                messages, self.build_config(tags=["summarize"])
+            )
+            response_content = response.content
+        except ContentPolicyViolationError as e:
+            print("Error: ", e, " ", messages[-1].content)
+
+        # 3) Optionally persist salient details to the memory backend.
+        if self.agent_memory:
+            memories: list[str] = []
+            # Collect human/system/tool message content; for AI tool calls, store args.
+            for msg in state["messages"]:
+                if not isinstance(msg, AIMessage):
+                    memories.append(msg.content)
+                elif not msg.tool_calls:
+                    memories.append(msg.content)
+                else:
+                    tool_strings = []
+                    for tool in msg.tool_calls:
+                        tool_strings.append("Tool Name: " + tool["name"])
+                        for arg_name in tool["args"]:
+                            tool_strings.append(
+                                f"Arg: {str(arg_name)}\nValue: "
+                                f"{str(tool['args'][arg_name])}"
+                            )
+                    memories.append("\n".join(tool_strings))
+            memories.append(response_content)
+            self.agent_memory.add_memories(memories)
+
+        # 4) Optionally write state to disk for debugging/auditing.
+        if self.log_state:
+            save_state = state.copy()
+            # Append the summary as an AI message for a complete trace.
+            save_state["messages"] = save_state["messages"] + [
+                AIMessage(content=response_content)
+            ]
+            self.write_state("execution_agent.json", save_state)
+
+        # 5) Return a partial state update with only the summary content.
+        return {"messages": [response_content]}
+
+    def safety_check(self, state: ExecutionState) -> ExecutionState:
+        """Assess pending shell commands for safety and inject ToolMessages with results.
+
+        This method inspects the most recent AI tool calls, evaluates any run_cmd
+        queries against the safety prompt, and constructs ToolMessages that either
+        flag unsafe commands with reasons or confirm safe execution. If any command
+        is unsafe, the generated ToolMessages are appended to the state so the agent
+        can react without executing the command.
+
+        Args:
+            state (ExecutionState): Current execution state.
+
+        Returns:
+            ExecutionState: Either the unchanged state (all safe) or a copy with one
+                or more ToolMessages appended when unsafe commands are detected.
+        """
+        # 1) Work on a shallow copy; inspect the most recent model message.
+        new_state = state.copy()
+        last_msg = new_state["messages"][-1]
+
+        # 2) Evaluate any pending run_cmd tool calls for safety.
+        tool_responses: list[ToolMessage] = []
+        any_unsafe = False
+        for tool_call in last_msg.tool_calls:
+            if tool_call["name"] != "run_cmd":
+                continue
+
+            query = tool_call["args"]["query"]
+            safety_result = self.llm.invoke(
+                self.safety_prompt + query,
+                self.build_config(tags=["safety_check"]),
+            )
+
+            if "[NO]" in safety_result.content:
+                any_unsafe = True
+                tool_response = (
+                    "[UNSAFE] That command `{q}` was deemed unsafe and cannot be run.\n"
+                    "For reason: {r}"
+                ).format(q=query, r=safety_result.content)
+                console.print(
+                    "[bold red][WARNING][/bold red] Command deemed unsafe:",
+                    query,
+                )
+                # Also surface the model's rationale for transparency.
+                console.print(
+                    "[bold red][WARNING][/bold red] REASON:", tool_response
+                )
+            else:
+                tool_response = f"Command `{query}` passed safety check."
+                console.print(
+                    f"[green]Command passed safety check:[/green] {query}"
+                )
+
+            tool_responses.append(
+                ToolMessage(
+                    content=tool_response,
+                    tool_call_id=tool_call["id"],
+                )
+            )
+
+        # 3) If any command is unsafe, append all tool responses; otherwise keep state.
+        if any_unsafe:
+            new_state["messages"].extend(tool_responses)
+
+        return new_state
+
+    def _build_graph(self):
+        """Construct and compile the agent's LangGraph state machine."""
+        # Create a graph over the agent's execution state.
+        graph = StateGraph(ExecutionState)
+
+        # Register nodes:
+        # - "agent": LLM planning/execution step
+        # - "action": tool dispatch (run_cmd, write_code, etc.)
+        # - "summarize": summary/finalization step
+        # - "safety_check": gate for shell command safety
+        self.add_node(graph, self.query_executor, "agent")
+        self.add_node(graph, self.tool_node, "action")
+        self.add_node(graph, self.summarize, "summarize")
+        self.add_node(graph, self.safety_check, "safety_check")
+
+        # Set entrypoint: execution starts with the "agent" node.
+        graph.set_entry_point("agent")
+
+        # From "agent", either continue (tools) or finish (summarize),
+        # based on presence of tool calls in the last message.
+        graph.add_conditional_edges(
+            "agent",
+            self._wrap_cond(should_continue, "should_continue", "execution"),
+            {"continue": "safety_check", "summarize": "summarize"},
+        )
+
+        # From "safety_check", route to tools if safe, otherwise back to agent
+        # to revise the plan without executing unsafe commands.
+        graph.add_conditional_edges(
+            "safety_check",
+            self._wrap_cond(command_safe, "command_safe", "execution"),
+            {"safe": "action", "unsafe": "agent"},
+        )
+
+        # After tools run, return control to the agent for the next step.
+        graph.add_edge("action", "agent")
+
+        # The graph completes at the "summarize" node.
+        graph.set_finish_point("summarize")
+
+        # Compile and return the executable graph (optionally with a checkpointer).
+        return graph.compile(checkpointer=self.checkpointer)
+
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 999_999, **_
+    ):
+        """Invoke the compiled graph with inputs under a specified recursion limit.
+
+        This method builds a LangGraph config with the provided recursion limit
+        and a "graph" tag, then delegates to the compiled graph's invoke method.
+        """
+        # Build invocation config with a generous recursion limit for long runs.
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+
+        # Delegate execution to the compiled graph.
+        return self._action.invoke(inputs, config)
+
+    # This property is trying to stop people bypassing invoke
+    @property
+    def action(self):
+        """Property used to affirm `action` attribute is unsupported."""
+        raise AttributeError(
+            "Use .stream(...) or .invoke(...); direct .action access is unsupported."
+        )
 
 
+# Single module test execution
 def main():
     execution_agent = ExecutionAgent()
     problem_string = (