ursa-ai 0.7.0rc1__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ursa-ai might be problematic.

@@ -1,3 +1,30 @@
1
+ """Execution agent that builds a tool-enabled state graph to autonomously run tasks.
2
+
3
+ This module implements ExecutionAgent, a LangGraph-based agent that executes user
4
+ instructions by invoking LLM tool calls and coordinating a controlled workflow.
5
+
6
+ Key features:
7
+ - Workspace management with optional symlinking for external sources.
8
+ - Safety-checked shell execution via run_cmd with output size budgeting.
9
+ - Code authoring and edits through write_code and edit_code with rich previews.
10
+ - Web search capability through DuckDuckGoSearchResults.
11
+ - Summarization of the session and optional memory logging.
12
+ - Configurable graph with nodes for agent, safety_check, action, and summarize.
13
+
14
+ Implementation notes:
15
+ - LLM prompts are sourced from prompt_library.execution_prompts.
16
+ - Outputs from subprocess are trimmed under MAX_TOOL_MSG_CHARS to fit tool messages.
17
+ - The agent uses ToolNode and LangGraph StateGraph to loop until no tool calls remain.
18
+ - Safety gates block unsafe shell commands and surface the rationale to the user.
19
+
20
+ Environment:
21
+ - MAX_TOOL_MSG_CHARS caps combined stdout/stderr in tool responses.
22
+
23
+ Entry points:
24
+ - ExecutionAgent._invoke(...) runs the compiled graph.
25
+ - main() shows a minimal demo that writes and runs a script.
26
+ """
27
+
1
28
  import os
2
29
 
3
30
  # from langchain_core.runnables.graph import MermaidDrawMethod
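A quick usage sketch of the module this docstring describes (annotation, not part of the diff). It assumes BaseAgent exposes the public `.invoke(...)` mentioned by the `action` property further down, that the initial state only needs a `messages` list, and that the import path is as guessed in the comment:

```python
from langchain_core.messages import HumanMessage
# from ursa.agents.execution_agent import ExecutionAgent  # import path assumed

# Hypothetical driver, mirroring what main() reportedly does.
agent = ExecutionAgent(llm="openai/gpt-4o-mini", log_state=True)
result = agent.invoke(
    {"messages": [HumanMessage("Write and run a script that prints the first 10 primes.")]}
)
print(result["messages"][-1])  # summary produced by the summarize node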
@@ -48,7 +75,33 @@ RESET = "\033[0m"
48
75
  BOLD = "\033[1m"
49
76
 
50
77
 
78
+ # Global variables for the module.
79
+
80
+ # Set a limit for message characters - the user can override
81
+ # this in their env, or we could pull it from the LLM parameters
82
+ MAX_TOOL_MSG_CHARS = int(os.getenv("MAX_TOOL_MSG_CHARS", "50000"))
83
+
84
+ # Set a search tool.
85
+ search_tool = DuckDuckGoSearchResults(output_format="json", num_results=10)
86
+ # search_tool = TavilySearchResults(
87
+ # max_results=10,
88
+ # search_depth="advanced",
89
+ # include_answer=True)
90
+
91
+
92
+ # Classes for typing
51
93
  class ExecutionState(TypedDict):
94
+ """TypedDict representing the execution agent's mutable run state used by nodes.
95
+
96
+ Fields:
97
+ - messages: list of messages (System/Human/AI/Tool) with add_messages metadata.
98
+ - current_progress: short status string describing agent progress.
99
+ - code_files: list of filenames created or edited in the workspace.
100
+ - workspace: path to the working directory where files and commands run.
101
+ - symlinkdir: optional dict describing a symlink operation (source, dest,
102
+ is_linked).
103
+ """
104
+
52
105
  messages: Annotated[list, add_messages]
53
106
  current_progress: str
54
107
  code_files: list[str]
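For orientation (annotation, not part of the diff), a sketch of how the new module-level budget and the ExecutionState fields above might be used; the values are illustrative and the symlinkdir keys follow the docstring:

```python
import os

# MAX_TOOL_MSG_CHARS is read once at import time, so set it before importing the module.
os.environ["MAX_TOOL_MSG_CHARS"] = "20000"

# A minimal initial state matching ExecutionState; "workspace" is created with a
# random name when omitted, and symlinkdir uses the documented source/dest keys.
initial_state = {
    "messages": [],
    "current_progress": "",
    "code_files": [],
    "workspace": "demo-workspace",
    "symlinkdir": {"source": "~/datasets/run1", "dest": "data"},
}
```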
@@ -56,220 +109,42 @@ class ExecutionState(TypedDict):
56
109
  symlinkdir: dict
57
110
 
58
111
 
59
- class ExecutionAgent(BaseAgent):
60
- def __init__(
61
- self,
62
- llm: str | BaseChatModel = "openai/gpt-4o-mini",
63
- agent_memory: Optional[Any | AgentMemory] = None,
64
- log_state: bool = False,
65
- **kwargs,
66
- ):
67
- super().__init__(llm, **kwargs)
68
- self.agent_memory = agent_memory
69
- self.safety_prompt = safety_prompt
70
- self.executor_prompt = executor_prompt
71
- self.summarize_prompt = summarize_prompt
72
- self.tools = [run_cmd, write_code, edit_code, search_tool]
73
- self.tool_node = ToolNode(self.tools)
74
- self.llm = self.llm.bind_tools(self.tools)
75
- self.log_state = log_state
76
-
77
- self._action = self._build_graph()
78
-
79
- # Define the function that calls the model
80
- def query_executor(self, state: ExecutionState) -> ExecutionState:
81
- new_state = state.copy()
82
- if "workspace" not in new_state.keys():
83
- new_state["workspace"] = randomname.get_name()
84
- print(
85
- f"{RED}Creating the folder {BLUE}{BOLD}{new_state['workspace']}{RESET}{RED} for this project.{RESET}"
86
- )
87
- os.makedirs(new_state["workspace"], exist_ok=True)
88
-
89
- # code related to symlink
90
- sd = new_state.get("symlinkdir")
91
- if isinstance(sd, dict) and "is_linked" not in sd:
92
- # symlinkdir = {"source": "foo", "dest": "bar"}
93
- symlinkdir = new_state["symlinkdir"]
94
- # user provided a symlinkdir key - let's do the linking!
95
-
96
- src = Path(symlinkdir["source"]).expanduser().resolve()
97
- workspace_root = Path(new_state["workspace"]).expanduser().resolve()
98
- dst = workspace_root / symlinkdir["dest"] # prepend workspace
99
-
100
- # if you want to replace an existing link/file, unlink it first
101
- if dst.exists() or dst.is_symlink():
102
- dst.unlink()
103
-
104
- # create parent dirs for the link location if they don’t exist
105
- dst.parent.mkdir(parents=True, exist_ok=True)
106
-
107
- # actually make the link (tell pathlib it’s a directory target)
108
- dst.symlink_to(src, target_is_directory=src.is_dir())
109
- print(f"{RED}Symlinked {src} (source) --> {dst} (dest)")
110
- # note that we've done the symlink now, so don't need to do it later
111
- new_state["symlinkdir"]["is_linked"] = True
112
-
113
- if isinstance(new_state["messages"][0], SystemMessage):
114
- new_state["messages"][0] = SystemMessage(
115
- content=self.executor_prompt
116
- )
117
- else:
118
- new_state["messages"] = [
119
- SystemMessage(content=self.executor_prompt)
120
- ] + state["messages"]
121
- try:
122
- response = self.llm.invoke(
123
- new_state["messages"], self.build_config(tags=["agent"])
124
- )
125
- except ContentPolicyViolationError as e:
126
- print("Error: ", e, " ", new_state["messages"][-1].content)
127
- if self.log_state:
128
- self.write_state("execution_agent.json", new_state)
129
- return {"messages": [response], "workspace": new_state["workspace"]}
130
-
131
- # Define the function that calls the model
132
- def summarize(self, state: ExecutionState) -> ExecutionState:
133
- messages = [SystemMessage(content=summarize_prompt)] + state["messages"]
134
- try:
135
- response = self.llm.invoke(
136
- messages, self.build_config(tags=["summarize"])
137
- )
138
- except ContentPolicyViolationError as e:
139
- print("Error: ", e, " ", messages[-1].content)
140
- if self.agent_memory:
141
- memories = []
142
- # Handle looping through the messages
143
- for x in state["messages"]:
144
- if not isinstance(x, AIMessage):
145
- memories.append(x.content)
146
- elif not x.tool_calls:
147
- memories.append(x.content)
148
- else:
149
- tool_strings = []
150
- for tool in x.tool_calls:
151
- tool_name = "Tool Name: " + tool["name"]
152
- tool_strings.append(tool_name)
153
- for y in tool["args"]:
154
- tool_strings.append(
155
- f"Arg: {str(y)}\nValue: {str(tool['args'][y])}"
156
- )
157
- memories.append("\n".join(tool_strings))
158
- memories.append(response.content)
159
- self.agent_memory.add_memories(memories)
160
- save_state = state.copy()
161
- save_state["messages"].append(response)
162
- if self.log_state:
163
- self.write_state("execution_agent.json", save_state)
164
- return {"messages": [response.content]}
165
-
166
- # Define the function that calls the model
167
- def safety_check(self, state: ExecutionState) -> ExecutionState:
168
- """
169
- Validate the safety of a pending shell command.
170
-
171
- Args:
172
- state: Current execution state.
173
-
174
- Returns:
175
- Either the unchanged state (safe) or a state with tool message(s) (unsafe).
176
- """
177
- new_state = state.copy()
178
- last_msg = new_state["messages"][-1]
179
-
180
- tool_responses = []
181
- tool_failed = False
182
- for tool_call in last_msg.tool_calls:
183
- call_name = tool_call["name"]
184
-
185
- if call_name == "run_cmd":
186
- query = tool_call["args"]["query"]
187
- safety_check = self.llm.invoke(
188
- self.safety_prompt + query,
189
- self.build_config(tags=["safety_check"]),
190
- )
191
-
192
- if "[NO]" in safety_check.content:
193
- tool_failed = True
194
-
195
- tool_response = f"""
196
- [UNSAFE] That command `{query}` was deemed unsafe and cannot be run.
197
- For reason: {safety_check.content}
198
- """
199
- console.print(
200
- "[bold red][WARNING][/bold red] Command deemed unsafe:",
201
- query,
202
- )
203
- # and tell the user the reason
204
- console.print(
205
- "[bold red][WARNING][/bold red] REASON:", tool_response
206
- )
207
-
208
- else:
209
- tool_response = f"Command `{query}` passed safety check."
210
- console.print(
211
- f"[green]Command passed safety check:[/green] {query}"
212
- )
213
-
214
- tool_responses.append(
215
- ToolMessage(
216
- content=tool_response,
217
- tool_call_id=tool_call["id"],
218
- )
219
- )
220
-
221
- if tool_failed:
222
- new_state["messages"].extend(tool_responses)
223
-
224
- return new_state
225
-
226
- def _build_graph(self):
227
- graph = StateGraph(ExecutionState)
228
-
229
- self.add_node(graph, self.query_executor, "agent")
230
- self.add_node(graph, self.tool_node, "action")
231
- self.add_node(graph, self.summarize, "summarize")
232
- self.add_node(graph, self.safety_check, "safety_check")
112
+ # Helper functions
113
+ def _strip_fences(snippet: str) -> str:
114
+ """Remove markdown fences from a code snippet.
233
115
 
234
- # Set the entrypoint as `agent`
235
- # This means that this node is the first one called
236
- graph.set_entry_point("agent")
116
+ This function strips leading triple backticks and any language
117
+ identifiers from a markdown-formatted code snippet and returns
118
+ only the contained code.
237
119
 
238
- graph.add_conditional_edges(
239
- "agent",
240
- self._wrap_cond(should_continue, "should_continue", "execution"),
241
- {"continue": "safety_check", "summarize": "summarize"},
242
- )
120
+ Args:
121
+ snippet: The markdown-formatted code snippet.
243
122
 
244
- graph.add_conditional_edges(
245
- "safety_check",
246
- self._wrap_cond(command_safe, "command_safe", "execution"),
247
- {"safe": "action", "unsafe": "agent"},
248
- )
123
+ Returns:
124
+ The snippet content without leading markdown fences.
125
+ """
126
+ if "```" not in snippet:
127
+ return snippet
249
128
 
250
- graph.add_edge("action", "agent")
251
- graph.set_finish_point("summarize")
129
+ parts = snippet.split("```")
130
+ if len(parts) < 3:
131
+ return snippet
252
132
 
253
- return graph.compile(checkpointer=self.checkpointer)
254
- # self.action.get_graph().draw_mermaid_png(output_file_path="execution_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)
133
+ body = parts[1]
134
+ return "\n".join(body.split("\n")[1:]) if "\n" in body else body.strip()
255
135
 
256
- def _invoke(
257
- self, inputs: Mapping[str, Any], recursion_limit: int = 999_999, **_
258
- ):
259
- config = self.build_config(
260
- recursion_limit=recursion_limit, tags=["graph"]
261
- )
262
- return self._action.invoke(inputs, config)
263
136
 
264
- # this is trying to stop people bypassing invoke
265
- @property
266
- def action(self):
267
- raise AttributeError(
268
- "Use .stream(...) or .invoke(...); direct .action access is unsupported."
269
- )
137
+ def _snip_text(text: str, max_chars: int) -> tuple[str, bool]:
138
+ """Truncate text to a maximum length and indicate if truncation occurred.
270
139
 
140
+ Args:
141
+ text: The original text to potentially truncate.
142
+ max_chars: The maximum characters allowed in the output.
271
143
 
272
- def _snip_text(text: str, max_chars: int) -> tuple[str, bool]:
144
+ Returns:
145
+ A tuple of (possibly truncated text, boolean flag indicating
146
+ if truncation occurred).
147
+ """
273
148
  if text is None:
274
149
  return "", False
275
150
  if max_chars <= 0:
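Illustration of the relocated `_strip_fences` helper shown earlier in this hunk (annotation, not part of the diff); the snippet is built without literal backticks so the example stays self-contained:

```python
fence = "`" * 3
snippet = f"{fence}python\nprint('hello')\n{fence}"

# The leading fence and its language identifier are dropped; only the body remains.
assert _strip_fences(snippet) == "print('hello')\n"
```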
@@ -287,6 +162,16 @@ def _snip_text(text: str, max_chars: int) -> tuple[str, bool]:
287
162
 
288
163
 
289
164
  def _fit_streams_to_budget(stdout: str, stderr: str, total_budget: int):
165
+ """Allocate and truncate stdout and stderr to fit a total character budget.
166
+
167
+ Args:
168
+ stdout: The original stdout string.
169
+ stderr: The original stderr string.
170
+ total_budget: The combined character budget for stdout and stderr.
171
+
172
+ Returns:
173
+ A tuple of (possibly truncated stdout, possibly truncated stderr).
174
+ """
290
175
  label_overhead = len("STDOUT:\n") + len("\nSTDERR:\n")
291
176
  budget = max(0, total_budget - label_overhead)
292
177
 
@@ -299,23 +184,72 @@ def _fit_streams_to_budget(stdout: str, stderr: str, total_budget: int):
299
184
 
300
185
  stdout_snip, _ = _snip_text(stdout, stdout_budget)
301
186
  stderr_snip, _ = _snip_text(stderr, stderr_budget)
187
+
302
188
  return stdout_snip, stderr_snip
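A small sketch (annotation, not part of the diff) of how the two budgeting helpers combine; the exact per-stream split is elided from this hunk, so only the overall behaviour is shown:

```python
out = "x" * 200_000                 # oversized stdout from a chatty command
err = "warning: deprecated\n" * 5   # short stderr

# Both streams are clipped so the labelled message stays within the budget.
out_fit, err_fit = _fit_streams_to_budget(out, err, MAX_TOOL_MSG_CHARS)
tool_msg = f"STDOUT:\n{out_fit}\nSTDERR:\n{err_fit}"
print(len(tool_msg))  # roughly MAX_TOOL_MSG_CHARS (50_000 by default) or less
```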
303
189
 
304
190
 
305
- # the idea here is that we just set a limit - the user could overload
306
- # that in their env, or maybe we could pull this out of the LLM parameters
307
- MAX_TOOL_MSG_CHARS = int(os.getenv("MAX_TOOL_MSG_CHARS", "50000"))
191
+ def should_continue(state: ExecutionState) -> Literal["summarize", "continue"]:
192
+ """Return 'summarize' if no tool calls in the last message, else 'continue'.
193
+
194
+ Args:
195
+ state: The current execution state containing messages.
196
+
197
+ Returns:
198
+ A literal "summarize" if the last message has no tool calls,
199
+ otherwise "continue".
200
+ """
201
+ messages = state["messages"]
202
+ last_message = messages[-1]
203
+ # If there is no tool call, then we finish
204
+ if not last_message.tool_calls:
205
+ return "summarize"
206
+ # Otherwise if there is, we continue
207
+ else:
208
+ return "continue"
209
+
210
+
211
+ def command_safe(state: ExecutionState) -> Literal["safe", "unsafe"]:
212
+ """Return 'safe' if the last command was safe, otherwise 'unsafe'.
213
+
214
+ Args:
215
+ state: The current execution state containing messages and tool calls.
216
+ Returns:
217
+ A literal "safe" if no '[UNSAFE]' tags are in the last command,
218
+ otherwise "unsafe".
219
+ """
220
+ index = -1
221
+ message = state["messages"][index]
222
+ # Loop through all the consecutive tool messages in reverse order
223
+ while isinstance(message, ToolMessage):
224
+ if "[UNSAFE]" in message.content:
225
+ return "unsafe"
308
226
 
227
+ index -= 1
228
+ message = state["messages"][index]
229
+
230
+ return "safe"
309
231
 
232
+
233
+ # Tools for ExecutionAgent
310
234
  @tool
311
235
  def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
312
- """
313
- Run a commandline command from using the subprocess package in python
236
+ """Execute a shell command in the workspace and return its combined output.
237
+
238
+ Runs the specified command using subprocess.run in the given workspace
239
+ directory, captures stdout and stderr, enforces a maximum character budget,
240
+ and formats both streams into a single string. KeyboardInterrupt during
241
+ execution is caught and reported.
314
242
 
315
243
  Args:
316
- query: commandline command to be run as a string given to the subprocess.run command.
244
+ query: The shell command to execute.
245
+ state: A dict with injected state; must include the 'workspace' path.
246
+
247
+ Returns:
248
+ A formatted string with "STDOUT:" followed by the truncated stdout and
249
+ "STDERR:" followed by the truncated stderr.
317
250
  """
318
251
  workspace_dir = state["workspace"]
252
+
319
253
  print("RUNNING: ", query)
320
254
  try:
321
255
  result = subprocess.run(
@@ -342,21 +276,6 @@ def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
342
276
  return f"STDOUT:\n{stdout_fit}\nSTDERR:\n{stderr_fit}"
343
277
 
344
278
 
345
- def _strip_fences(snippet: str) -> str:
346
- """
347
- Remove leading markdown ``` fence
348
- """
349
- if "```" not in snippet:
350
- return snippet
351
-
352
- parts = snippet.split("```")
353
- if len(parts) < 3:
354
- return snippet
355
-
356
- body = parts[1]
357
- return "\n".join(body.split("\n")[1:]) if "\n" in body else body.strip()
358
-
359
-
360
279
  @tool
361
280
  def write_code(
362
281
  code: str,
@@ -364,22 +283,26 @@ def write_code(
364
283
  tool_call_id: Annotated[str, InjectedToolCallId],
365
284
  state: Annotated[dict, InjectedState],
366
285
  ) -> Command:
367
- """Write *code* to *filename*.
286
+ """Write source code to a file and update the agent’s workspace state.
368
287
 
369
288
  Args:
370
- code: Source code as a string.
371
- filename: Target filename (including extension).
289
+ code: The source code content to be written to disk.
290
+ filename: Name of the target file (including its extension).
291
+ tool_call_id: Identifier for this tool invocation.
292
+ state: Agent state dict holding workspace path and file list.
372
293
 
373
294
  Returns:
374
- Success / failure message.
295
+ Command: Contains an updated state (including code_files) and
296
+ a ToolMessage acknowledging success or failure.
375
297
  """
298
+ # Determine the full path to the target file
376
299
  workspace_dir = state["workspace"]
377
300
  console.print("[cyan]Writing file:[/]", filename)
378
301
 
379
- # Clean up markdown fences
302
+ # Clean up markdown fences on submitted code.
380
303
  code = _strip_fences(code)
381
304
 
382
- # Syntax-highlighted preview
305
+ # Show syntax-highlighted preview before writing to file
383
306
  try:
384
307
  lexer_name = Syntax.guess_lexer(filename, code)
385
308
  except Exception:
@@ -393,6 +316,7 @@ def write_code(
393
316
  )
394
317
  )
395
318
 
319
+ # Write cleaned code to disk
396
320
  code_file = os.path.join(workspace_dir, filename)
397
321
  try:
398
322
  with open(code_file, "w", encoding="utf-8") as f:
@@ -410,11 +334,11 @@ def write_code(
410
334
  f"[green]File written:[/] {code_file}"
411
335
  )
412
336
 
413
- # Append the file to the list in state
337
+ # Append the file to the list in agent's state for later reference
414
338
  file_list = state.get("code_files", [])
415
339
  file_list.append(filename)
416
340
 
417
- # Create a tool message to send back
341
+ # Create a tool message to send back to acknowledge success.
418
342
  msg = ToolMessage(
419
343
  content=f"File {filename} written successfully.",
420
344
  tool_call_id=tool_call_id,
@@ -499,41 +423,363 @@ def edit_code(
499
423
  return f"File {filename} updated successfully."
500
424
 
501
425
 
502
- search_tool = DuckDuckGoSearchResults(output_format="json", num_results=10)
503
- # search_tool = TavilySearchResults(max_results=10, search_depth="advanced", include_answer=True)
426
+ # Main module class
427
+ class ExecutionAgent(BaseAgent):
428
+ """Orchestrates model-driven code execution, tool calls, and state management.
504
429
 
430
+ Coordinates the full loop of planning, tool execution, and state updates for
431
+ iterative program synthesis and shell interaction.
505
432
 
506
- # Define the function that determines whether to continue or not
507
- def should_continue(state: ExecutionState) -> Literal["summarize", "continue"]:
508
- messages = state["messages"]
509
- last_message = messages[-1]
510
- # If there is no tool call, then we finish
511
- if not last_message.tool_calls:
512
- return "summarize"
513
- # Otherwise if there is, we continue
514
- else:
515
- return "continue"
516
-
433
+ This agent wraps an LLM with a small execution graph that alternates
434
+ between issuing model queries, invoking tools (run, write, edit, search),
435
+ performing safety checks, and summarizing progress. It manages a
436
+ workspace on disk, optional symlinks, and an optional memory backend to
437
+ persist summaries.
517
438
 
518
- # Define the function that determines whether to continue or not
519
- def command_safe(state: ExecutionState) -> Literal["safe", "unsafe"]:
520
- """
521
- Return graph edge "safe" if the last command was safe, otherwise return edge "unsafe"
439
+ Args:
440
+ llm (str | BaseChatModel): Model identifier or bound chat model
441
+ instance. If a string is provided, the BaseAgent initializer will
442
+ resolve it.
443
+ agent_memory (Any | AgentMemory, optional): Memory backend used to
444
+ store summarized agent interactions. If provided, summaries are
445
+ saved here.
446
+ log_state (bool): When True, the agent writes intermediate json state
447
+ to disk for debugging and auditability.
448
+ **kwargs: Passed through to the BaseAgent constructor (e.g., model
449
+ configuration, checkpointer).
450
+
451
+ Attributes:
452
+ safety_prompt (str): Prompt used to evaluate safety of shell
453
+ commands.
454
+ executor_prompt (str): Prompt used when invoking the executor LLM
455
+ loop.
456
+ summarize_prompt (str): Prompt used to request concise summaries for
457
+ memory or final output.
458
+ tools (list[Tool]): Tools available to the agent (run_cmd, write_code,
459
+ edit_code, search_tool).
460
+ tool_node (ToolNode): Graph node that dispatches tool calls.
461
+ llm (BaseChatModel): LLM instance bound to the available tools.
462
+ _action (StateGraph): Compiled execution graph that implements the
463
+ main loop and branching logic.
464
+
465
+ Methods:
466
+ query_executor(state): Send messages to the executor LLM, ensure
467
+ workspace exists, and handle symlink setup before returning the
468
+ model response.
469
+ summarize(state): Produce and optionally persist a summary of recent
470
+ interactions to the memory backend.
471
+ safety_check(state): Validate pending run_cmd calls via the safety
472
+ prompt and append ToolMessages for unsafe commands.
473
+ _build_graph(): Construct and compile the StateGraph for the agent
474
+ loop.
475
+ _invoke(inputs, recursion_limit=...): Internal entry that invokes the
476
+ compiled graph with a given recursion limit.
477
+ action (property): Disabled; direct access is not supported. Use
478
+ invoke or stream entry points instead.
479
+
480
+ Raises:
481
+ AttributeError: Accessing the .action attribute raises to encourage
482
+ using .stream(...) or .invoke(...).
522
483
  """
523
484
 
524
- index = -1
525
- message = state["messages"][index]
526
- # Loop through all the consecutive tool messages in reverse order
527
- while isinstance(message, ToolMessage):
528
- if "[UNSAFE]" in message.content:
529
- return "unsafe"
485
+ def __init__(
486
+ self,
487
+ llm: str | BaseChatModel = "openai/gpt-4o-mini",
488
+ agent_memory: Optional[Any | AgentMemory] = None,
489
+ log_state: bool = False,
490
+ **kwargs,
491
+ ):
492
+ """ExecutionAgent class initialization."""
493
+ super().__init__(llm, **kwargs)
494
+ self.agent_memory = agent_memory
495
+ self.safety_prompt = safety_prompt
496
+ self.executor_prompt = executor_prompt
497
+ self.summarize_prompt = summarize_prompt
498
+ self.tools = [run_cmd, write_code, edit_code, search_tool]
499
+ self.tool_node = ToolNode(self.tools)
500
+ self.llm = self.llm.bind_tools(self.tools)
501
+ self.log_state = log_state
530
502
 
531
- index -= 1
532
- message = state["messages"][index]
503
+ self._action = self._build_graph()
533
504
 
534
- return "safe"
505
+ # Define the function that calls the model
506
+ def query_executor(self, state: ExecutionState) -> ExecutionState:
507
+ """Prepare workspace, handle optional symlinks, and invoke the executor LLM.
508
+
509
+ This method copies the incoming state, ensures a workspace directory exists
510
+ (creating one with a random name when absent), optionally creates a symlink
511
+ described by state["symlinkdir"], sets or injects the executor system prompt
512
+ as the first message, and invokes the bound LLM. When logging is enabled,
513
+ it persists the pre-invocation state to disk.
514
+
515
+ Args:
516
+ state: The current execution state. Expected keys include:
517
+ - "messages": Ordered list of System/Human/AI/Tool messages.
518
+ - "workspace": Optional path to the working directory.
519
+ - "symlinkdir": Optional dict with "source" and "dest" keys.
520
+
521
+ Returns:
522
+ ExecutionState: Partial state update containing:
523
+ - "messages": A list with the model's response as the latest entry.
524
+ - "workspace": The resolved workspace path.
525
+ """
526
+ new_state = state.copy()
527
+
528
+ # 1) Ensure a workspace directory exists, creating a named one if absent.
529
+ if "workspace" not in new_state.keys():
530
+ new_state["workspace"] = randomname.get_name()
531
+ print(
532
+ f"{RED}Creating the folder "
533
+ f"{BLUE}{BOLD}{new_state['workspace']}{RESET}{RED} "
534
+ f"for this project.{RESET}"
535
+ )
536
+ os.makedirs(new_state["workspace"], exist_ok=True)
537
+
538
+ # 2) Optionally create a symlink if symlinkdir is provided and not yet linked.
539
+ sd = new_state.get("symlinkdir")
540
+ if isinstance(sd, dict) and "is_linked" not in sd:
541
+ # symlinkdir structure: {"source": "/path/to/src", "dest": "link/name"}
542
+ symlinkdir = sd
543
+
544
+ src = Path(symlinkdir["source"]).expanduser().resolve()
545
+ workspace_root = Path(new_state["workspace"]).expanduser().resolve()
546
+ dst = (
547
+ workspace_root / symlinkdir["dest"]
548
+ ) # Link lives inside workspace.
549
+
550
+ # If a file/link already exists at the destination, replace it.
551
+ if dst.exists() or dst.is_symlink():
552
+ dst.unlink()
553
+
554
+ # Ensure parent directories for the link exist.
555
+ dst.parent.mkdir(parents=True, exist_ok=True)
556
+
557
+ # Create the symlink (tell pathlib if the target is a directory).
558
+ dst.symlink_to(src, target_is_directory=src.is_dir())
559
+ print(f"{RED}Symlinked {src} (source) --> {dst} (dest)")
560
+ new_state["symlinkdir"]["is_linked"] = True
561
+
562
+ # 3) Ensure the executor prompt is the first SystemMessage.
563
+ if isinstance(new_state["messages"][0], SystemMessage):
564
+ new_state["messages"][0] = SystemMessage(
565
+ content=self.executor_prompt
566
+ )
567
+ else:
568
+ new_state["messages"] = [
569
+ SystemMessage(content=self.executor_prompt)
570
+ ] + state["messages"]
571
+
572
+ # 4) Invoke the LLM with the prepared message sequence.
573
+ try:
574
+ response = self.llm.invoke(
575
+ new_state["messages"], self.build_config(tags=["agent"])
576
+ )
577
+ except ContentPolicyViolationError as e:
578
+ print("Error: ", e, " ", new_state["messages"][-1].content)
579
+
580
+ # 5) Optionally persist the pre-invocation state for audit/debugging.
581
+ if self.log_state:
582
+ self.write_state("execution_agent.json", new_state)
583
+
584
+ # Return the model's response and the workspace path as a partial state update.
585
+ return {"messages": [response], "workspace": new_state["workspace"]}
586
+
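An illustrative input (annotation, not part of the diff) showing how a caller would request the symlink handling implemented above; the path and task are hypothetical:

```python
from langchain_core.messages import HumanMessage

inputs = {
    "messages": [HumanMessage("Analyze the CSV files under data/ and report basic stats.")],
    "symlinkdir": {"source": "~/projects/run42/data", "dest": "data"},
}
# After the first pass through query_executor, <workspace>/data points at the
# source directory and state["symlinkdir"]["is_linked"] is set to True.
```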
587
+ def summarize(self, state: ExecutionState) -> ExecutionState:
588
+ """Produce a concise summary of the conversation and optionally persist memory.
589
+
590
+ This method builds a summarization prompt, invokes the LLM to obtain a compact
591
+ summary of recent interactions, optionally logs salient details to the agent
592
+ memory backend, and writes debug state when logging is enabled.
593
+
594
+ Args:
595
+ state (ExecutionState): The execution state containing message history.
596
+
597
+ Returns:
598
+ ExecutionState: A partial update with a single string message containing
599
+ the summary.
600
+ """
601
+ # 1) Construct the summarization message list (system prompt + prior messages).
602
+ messages = [SystemMessage(content=summarize_prompt)] + state["messages"]
603
+
604
+ # 2) Invoke the LLM to generate a summary; capture content even on failure.
605
+ response_content = ""
606
+ try:
607
+ response = self.llm.invoke(
608
+ messages, self.build_config(tags=["summarize"])
609
+ )
610
+ response_content = response.content
611
+ except ContentPolicyViolationError as e:
612
+ print("Error: ", e, " ", messages[-1].content)
613
+
614
+ # 3) Optionally persist salient details to the memory backend.
615
+ if self.agent_memory:
616
+ memories: list[str] = []
617
+ # Collect human/system/tool message content; for AI tool calls, store args.
618
+ for msg in state["messages"]:
619
+ if not isinstance(msg, AIMessage):
620
+ memories.append(msg.content)
621
+ elif not msg.tool_calls:
622
+ memories.append(msg.content)
623
+ else:
624
+ tool_strings = []
625
+ for tool in msg.tool_calls:
626
+ tool_strings.append("Tool Name: " + tool["name"])
627
+ for arg_name in tool["args"]:
628
+ tool_strings.append(
629
+ f"Arg: {str(arg_name)}\nValue: "
630
+ f"{str(tool['args'][arg_name])}"
631
+ )
632
+ memories.append("\n".join(tool_strings))
633
+ memories.append(response_content)
634
+ self.agent_memory.add_memories(memories)
635
+
636
+ # 4) Optionally write state to disk for debugging/auditing.
637
+ if self.log_state:
638
+ save_state = state.copy()
639
+ # Append the summary as an AI message for a complete trace.
640
+ save_state["messages"] = save_state["messages"] + [
641
+ AIMessage(content=response_content)
642
+ ]
643
+ self.write_state("execution_agent.json", save_state)
644
+
645
+ # 5) Return a partial state update with only the summary content.
646
+ return {"messages": [response_content]}
647
+
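A small sketch (annotation, not part of the diff) of how one AI tool call is flattened into a memory string by the loop above; the tool-call dict follows LangChain's name/args/id shape and the values are hypothetical:

```python
tool_call = {
    "name": "write_code",
    "args": {"filename": "primes.py", "code": "print('todo')"},
    "id": "call_1",
}
tool_strings = ["Tool Name: " + tool_call["name"]]
for arg_name, value in tool_call["args"].items():
    tool_strings.append(f"Arg: {arg_name}\nValue: {value}")
memory_entry = "\n".join(tool_strings)  # later stored via agent_memory.add_memories([...])
```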
648
+ def safety_check(self, state: ExecutionState) -> ExecutionState:
649
+ """Assess pending shell commands for safety and inject ToolMessages with results.
650
+
651
+ This method inspects the most recent AI tool calls, evaluates any run_cmd
652
+ queries against the safety prompt, and constructs ToolMessages that either
653
+ flag unsafe commands with reasons or confirm safe execution. If any command
654
+ is unsafe, the generated ToolMessages are appended to the state so the agent
655
+ can react without executing the command.
656
+
657
+ Args:
658
+ state (ExecutionState): Current execution state.
659
+
660
+ Returns:
661
+ ExecutionState: Either the unchanged state (all safe) or a copy with one
662
+ or more ToolMessages appended when unsafe commands are detected.
663
+ """
664
+ # 1) Work on a shallow copy; inspect the most recent model message.
665
+ new_state = state.copy()
666
+ last_msg = new_state["messages"][-1]
667
+
668
+ # 2) Evaluate any pending run_cmd tool calls for safety.
669
+ tool_responses: list[ToolMessage] = []
670
+ any_unsafe = False
671
+ for tool_call in last_msg.tool_calls:
672
+ if tool_call["name"] != "run_cmd":
673
+ continue
674
+
675
+ query = tool_call["args"]["query"]
676
+ safety_result = self.llm.invoke(
677
+ self.safety_prompt + query,
678
+ self.build_config(tags=["safety_check"]),
679
+ )
680
+
681
+ if "[NO]" in safety_result.content:
682
+ any_unsafe = True
683
+ tool_response = (
684
+ "[UNSAFE] That command `{q}` was deemed unsafe and cannot be run.\n"
685
+ "For reason: {r}"
686
+ ).format(q=query, r=safety_result.content)
687
+ console.print(
688
+ "[bold red][WARNING][/bold red] Command deemed unsafe:",
689
+ query,
690
+ )
691
+ # Also surface the model's rationale for transparency.
692
+ console.print(
693
+ "[bold red][WARNING][/bold red] REASON:", tool_response
694
+ )
695
+ else:
696
+ tool_response = f"Command `{query}` passed safety check."
697
+ console.print(
698
+ f"[green]Command passed safety check:[/green] {query}"
699
+ )
700
+
701
+ tool_responses.append(
702
+ ToolMessage(
703
+ content=tool_response,
704
+ tool_call_id=tool_call["id"],
705
+ )
706
+ )
707
+
708
+ # 3) If any command is unsafe, append all tool responses; otherwise keep state.
709
+ if any_unsafe:
710
+ new_state["messages"].extend(tool_responses)
711
+
712
+ return new_state
713
+
714
+ def _build_graph(self):
715
+ """Construct and compile the agent's LangGraph state machine."""
716
+ # Create a graph over the agent's execution state.
717
+ graph = StateGraph(ExecutionState)
718
+
719
+ # Register nodes:
720
+ # - "agent": LLM planning/execution step
721
+ # - "action": tool dispatch (run_cmd, write_code, etc.)
722
+ # - "summarize": summary/finalization step
723
+ # - "safety_check": gate for shell command safety
724
+ self.add_node(graph, self.query_executor, "agent")
725
+ self.add_node(graph, self.tool_node, "action")
726
+ self.add_node(graph, self.summarize, "summarize")
727
+ self.add_node(graph, self.safety_check, "safety_check")
728
+
729
+ # Set entrypoint: execution starts with the "agent" node.
730
+ graph.set_entry_point("agent")
731
+
732
+ # From "agent", either continue (tools) or finish (summarize),
733
+ # based on presence of tool calls in the last message.
734
+ graph.add_conditional_edges(
735
+ "agent",
736
+ self._wrap_cond(should_continue, "should_continue", "execution"),
737
+ {"continue": "safety_check", "summarize": "summarize"},
738
+ )
739
+
740
+ # From "safety_check", route to tools if safe, otherwise back to agent
741
+ # to revise the plan without executing unsafe commands.
742
+ graph.add_conditional_edges(
743
+ "safety_check",
744
+ self._wrap_cond(command_safe, "command_safe", "execution"),
745
+ {"safe": "action", "unsafe": "agent"},
746
+ )
747
+
748
+ # After tools run, return control to the agent for the next step.
749
+ graph.add_edge("action", "agent")
750
+
751
+ # The graph completes at the "summarize" node.
752
+ graph.set_finish_point("summarize")
753
+
754
+ # Compile and return the executable graph (optionally with a checkpointer).
755
+ return graph.compile(checkpointer=self.checkpointer)
756
+
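As a reading aid (annotation, not part of the diff), the control flow produced by the edges registered above:

```python
# agent --(tool calls pending)--> safety_check --(safe)--> action --> agent
# agent --(no tool calls)-------> summarize   (finish point)
# safety_check --(unsafe)-------> agent       (revise the plan; nothing is executed)
```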
757
+ def _invoke(
758
+ self, inputs: Mapping[str, Any], recursion_limit: int = 999_999, **_
759
+ ):
760
+ """Invoke the compiled graph with inputs under a specified recursion limit.
761
+
762
+ This method builds a LangGraph config with the provided recursion limit
763
+ and a "graph" tag, then delegates to the compiled graph's invoke method.
764
+ """
765
+ # Build invocation config with a generous recursion limit for long runs.
766
+ config = self.build_config(
767
+ recursion_limit=recursion_limit, tags=["graph"]
768
+ )
769
+
770
+ # Delegate execution to the compiled graph.
771
+ return self._action.invoke(inputs, config)
772
+
773
+ # This property prevents callers from bypassing invoke
774
+ @property
775
+ def action(self):
776
+ """Property used to affirm `action` attribute is unsupported."""
777
+ raise AttributeError(
778
+ "Use .stream(...) or .invoke(...); direct .action access is unsupported."
779
+ )
535
780
 
536
781
 
782
+ # Demo entry point when the module is run directly
537
783
  def main():
538
784
  execution_agent = ExecutionAgent()
539
785
  problem_string = (