ursa-ai 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. ursa/__init__.py +3 -0
  2. ursa/agents/__init__.py +32 -0
  3. ursa/agents/acquisition_agents.py +812 -0
  4. ursa/agents/arxiv_agent.py +429 -0
  5. ursa/agents/base.py +728 -0
  6. ursa/agents/chat_agent.py +60 -0
  7. ursa/agents/code_review_agent.py +341 -0
  8. ursa/agents/execution_agent.py +915 -0
  9. ursa/agents/hypothesizer_agent.py +614 -0
  10. ursa/agents/lammps_agent.py +465 -0
  11. ursa/agents/mp_agent.py +204 -0
  12. ursa/agents/optimization_agent.py +410 -0
  13. ursa/agents/planning_agent.py +219 -0
  14. ursa/agents/rag_agent.py +304 -0
  15. ursa/agents/recall_agent.py +54 -0
  16. ursa/agents/websearch_agent.py +196 -0
  17. ursa/cli/__init__.py +363 -0
  18. ursa/cli/hitl.py +516 -0
  19. ursa/cli/hitl_api.py +75 -0
  20. ursa/observability/metrics_charts.py +1279 -0
  21. ursa/observability/metrics_io.py +11 -0
  22. ursa/observability/metrics_session.py +750 -0
  23. ursa/observability/pricing.json +97 -0
  24. ursa/observability/pricing.py +321 -0
  25. ursa/observability/timing.py +1466 -0
  26. ursa/prompt_library/__init__.py +0 -0
  27. ursa/prompt_library/code_review_prompts.py +51 -0
  28. ursa/prompt_library/execution_prompts.py +50 -0
  29. ursa/prompt_library/hypothesizer_prompts.py +17 -0
  30. ursa/prompt_library/literature_prompts.py +11 -0
  31. ursa/prompt_library/optimization_prompts.py +131 -0
  32. ursa/prompt_library/planning_prompts.py +79 -0
  33. ursa/prompt_library/websearch_prompts.py +131 -0
  34. ursa/tools/__init__.py +0 -0
  35. ursa/tools/feasibility_checker.py +114 -0
  36. ursa/tools/feasibility_tools.py +1075 -0
  37. ursa/tools/run_command.py +27 -0
  38. ursa/tools/write_code.py +42 -0
  39. ursa/util/__init__.py +0 -0
  40. ursa/util/diff_renderer.py +128 -0
  41. ursa/util/helperFunctions.py +142 -0
  42. ursa/util/logo_generator.py +625 -0
  43. ursa/util/memory_logger.py +183 -0
  44. ursa/util/optimization_schema.py +78 -0
  45. ursa/util/parse.py +405 -0
  46. ursa_ai-0.9.1.dist-info/METADATA +304 -0
  47. ursa_ai-0.9.1.dist-info/RECORD +51 -0
  48. ursa_ai-0.9.1.dist-info/WHEEL +5 -0
  49. ursa_ai-0.9.1.dist-info/entry_points.txt +2 -0
  50. ursa_ai-0.9.1.dist-info/licenses/LICENSE +8 -0
  51. ursa_ai-0.9.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,915 @@
+ """Execution agent that builds a tool-enabled state graph to autonomously run tasks.
+
+ This module implements ExecutionAgent, a LangGraph-based agent that executes user
+ instructions by invoking LLM tool calls and coordinating a controlled workflow.
+
+ Key features:
+ - Workspace management with optional symlinking for external sources.
+ - Safety-checked shell execution via run_cmd with output size budgeting.
+ - Code authoring and edits through write_code and edit_code with rich previews.
+ - Web search capability through DuckDuckGoSearchResults.
+ - Summarization of the session and optional memory logging.
+ - Configurable graph with nodes for agent, safety_check, action, and summarize.
+
+ Implementation notes:
+ - LLM prompts are sourced from prompt_library.execution_prompts.
+ - Outputs from subprocess are trimmed under MAX_TOOL_MSG_CHARS to fit tool messages.
+ - The agent uses ToolNode and LangGraph StateGraph to loop until no tool calls remain.
+ - Safety gates block unsafe shell commands and surface the rationale to the user.
+
+ Environment:
+ - MAX_TOOL_MSG_CHARS caps combined stdout/stderr in tool responses.
+
+ Entry points:
+ - ExecutionAgent._invoke(...) runs the compiled graph.
+ """
+
+ # from langchain_core.runnables.graph import MermaidDrawMethod
+ import os
+ import subprocess
+ from pathlib import Path
+ from typing import Annotated, Any, Callable, Literal, Mapping, Optional
+
+ import randomname
+ from langchain.agents.middleware import SummarizationMiddleware
+ from langchain.chat_models import BaseChatModel
+ from langchain_community.tools import (
+     DuckDuckGoSearchResults,
+ )  # TavilySearchResults,
+ from langchain_core.messages import (
+     AIMessage,
+     AnyMessage,
+     SystemMessage,
+     ToolMessage,
+ )
+ from langchain_core.tools import InjectedToolCallId, StructuredTool, tool
+ from langchain_mcp_adapters.client import MultiServerMCPClient
+ from langgraph.graph import StateGraph
+ from langgraph.graph.message import add_messages
+ from langgraph.prebuilt import InjectedState, ToolNode
+ from langgraph.types import Command
+
+ # Rich
+ from rich import get_console
+ from rich.panel import Panel
+ from rich.syntax import Syntax
+ from typing_extensions import TypedDict
+
+ from ..prompt_library.execution_prompts import (
+     executor_prompt,
+     get_safety_prompt,
+     summarize_prompt,
+ )
+ from ..util.diff_renderer import DiffRenderer
+ from ..util.memory_logger import AgentMemory
+ from .base import BaseAgent
+
+ console = get_console()  # always returns the same instance
+
+ # --- ANSI color codes ---
+ GREEN = "\033[92m"
+ BLUE = "\033[94m"
+ RED = "\033[91m"
+ RESET = "\033[0m"
+ BOLD = "\033[1m"
+
+
+ # Global variables for the module.
+
+ # Set a limit on tool message characters. The user can override it in
+ # their environment; it could also be pulled from the LLM parameters.
+ MAX_TOOL_MSG_CHARS = int(os.getenv("MAX_TOOL_MSG_CHARS", "50000"))
+
+ # Set a search tool.
+ search_tool = DuckDuckGoSearchResults(output_format="json", num_results=10)
+ # search_tool = TavilySearchResults(
+ #     max_results=10,
+ #     search_depth="advanced",
+ #     include_answer=True)
+
+
+ # Classes for typing
+ class ExecutionState(TypedDict):
+     """TypedDict representing the execution agent's mutable run state used by nodes.
+
+     Fields:
+     - messages: list of messages (System/Human/AI/Tool) with add_messages metadata.
+     - current_progress: short status string describing agent progress.
+     - code_files: list of filenames created or edited in the workspace.
+     - workspace: path to the working directory where files and commands run.
+     - symlinkdir: optional dict describing a symlink operation (source, dest,
+       is_linked).
+     """
+
+     messages: Annotated[list[AnyMessage], add_messages]
+     current_progress: str
+     code_files: list[str]
+     workspace: str
+     symlinkdir: dict
+
+
+ # Helper functions
+ def convert_to_tool(fn):
+     """Wrap a plain callable as a StructuredTool; pass StructuredTools through."""
+     if isinstance(fn, StructuredTool):
+         return fn
+     else:
+         return StructuredTool.from_function(
+             func=fn, name=fn.__name__, description=fn.__doc__
+         )
+
+
+ def _strip_fences(snippet: str) -> str:
+     """Remove markdown fences from a code snippet.
+
+     This function strips leading triple backticks and any language
+     identifiers from a markdown-formatted code snippet and returns
+     only the contained code.
+
+     Args:
+         snippet: The markdown-formatted code snippet.
+
+     Returns:
+         The snippet content without leading markdown fences.
+     """
+     if "```" not in snippet:
+         return snippet
+
+     parts = snippet.split("```")
+     if len(parts) < 3:
+         return snippet
+
+     body = parts[1]
+     return "\n".join(body.split("\n")[1:]) if "\n" in body else body.strip()
+
+
+ def _snip_text(text: str, max_chars: int) -> tuple[str, bool]:
+     """Truncate text to a maximum length and indicate if truncation occurred.
+
+     Args:
+         text: The original text to potentially truncate.
+         max_chars: The maximum characters allowed in the output.
+
+     Returns:
+         A tuple of (possibly truncated text, boolean flag indicating
+         if truncation occurred).
+     """
+     if text is None:
+         return "", False
+     if max_chars <= 0:
+         return "", len(text) > 0
+     if len(text) <= max_chars:
+         return text, False
+     head = max_chars // 2
+     tail = max_chars - head
+     return (
+         text[:head]
+         + f"\n... [snipped {len(text) - max_chars} chars] ...\n"
+         + text[-tail:],
+         True,
+     )
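A worked example of the head/tail split (illustrative numbers); note that the snip marker is inserted on top of the max_chars budget:

text = "a" * 100
snipped, truncated = _snip_text(text, 10)
# head = 5, tail = 5; the marker reports 100 - 10 = 90 snipped chars.
assert truncated
assert snipped.startswith("aaaaa") and snipped.endswith("aaaaa")
assert "[snipped 90 chars]" in snipped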
+
+
+ def _fit_streams_to_budget(stdout: str, stderr: str, total_budget: int):
+     """Allocate and truncate stdout and stderr to fit a total character budget.
+
+     Args:
+         stdout: The original stdout string.
+         stderr: The original stderr string.
+         total_budget: The combined character budget for stdout and stderr.
+
+     Returns:
+         A tuple of (possibly truncated stdout, possibly truncated stderr).
+     """
+     label_overhead = len("STDOUT:\n") + len("\nSTDERR:\n")
+     budget = max(0, total_budget - label_overhead)
+
+     if len(stdout) + len(stderr) <= budget:
+         return stdout, stderr
+
+     total_len = max(1, len(stdout) + len(stderr))
+     stdout_budget = int(budget * (len(stdout) / total_len))
+     stderr_budget = budget - stdout_budget
+
+     stdout_snip, _ = _snip_text(stdout, stdout_budget)
+     stderr_snip, _ = _snip_text(stderr, stderr_budget)
+
+     return stdout_snip, stderr_snip
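A worked example of the proportional allocation (illustrative sizes):

# A 1000-char budget minus 17 chars of label overhead ("STDOUT:\n" is 8,
# "\nSTDERR:\n" is 9) leaves 983 chars. With a 3000-char stdout and a
# 1000-char stderr, stdout gets int(983 * 3000 / 4000) = 737 chars and
# stderr the remaining 246 (each stream also carries the short snip
# marker, so the totals run slightly over their budgets).
out, err = _fit_streams_to_budget("o" * 3000, "e" * 1000, 1000)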
+
+
+ def should_continue(state: ExecutionState) -> Literal["summarize", "continue"]:
+     """Return 'summarize' if no tool calls in the last message, else 'continue'.
+
+     Args:
+         state: The current execution state containing messages.
+
+     Returns:
+         A literal "summarize" if the last message has no tool calls,
+         otherwise "continue".
+     """
+     messages = state["messages"]
+     last_message = messages[-1]
+     # If there is no tool call, then we finish
+     if not last_message.tool_calls:
+         return "summarize"
+     # Otherwise if there is, we continue
+     else:
+         return "continue"
+
+
+ def command_safe(state: ExecutionState) -> Literal["safe", "unsafe"]:
+     """Return 'safe' if the last command was safe, otherwise 'unsafe'.
+
+     Args:
+         state: The current execution state containing messages and tool calls.
+
+     Returns:
+         A literal "safe" if no '[UNSAFE]' tags are in the last command,
+         otherwise "unsafe".
+     """
+     index = -1
+     message = state["messages"][index]
+     # Loop through all the consecutive tool messages in reverse order
+     while isinstance(message, ToolMessage):
+         if "[UNSAFE]" in message.content:
+             return "unsafe"
+
+         index -= 1
+         message = state["messages"][index]
+
+     return "safe"
+
+
+ # Tools for ExecutionAgent
+ @tool
+ def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
+     """Execute a shell command in the workspace and return its combined output.
+
+     Runs the specified command using subprocess.run in the given workspace
+     directory, captures stdout and stderr, enforces a maximum character budget,
+     and formats both streams into a single string. KeyboardInterrupt during
+     execution is caught and reported.
+
+     Args:
+         query: The shell command to execute.
+         state: A dict with injected state; must include the 'workspace' path.
+
+     Returns:
+         A formatted string with "STDOUT:" followed by the truncated stdout and
+         "STDERR:" followed by the truncated stderr.
+     """
+     workspace_dir = state["workspace"]
+
+     print("RUNNING: ", query)
+     try:
+         result = subprocess.run(
+             query,
+             text=True,
+             shell=True,
+             timeout=60000,  # seconds, i.e. just under 17 hours
+             capture_output=True,
+             cwd=workspace_dir,
+         )
+         stdout, stderr = result.stdout, result.stderr
+     except KeyboardInterrupt:
+         print("Keyboard Interrupt of command: ", query)
+         stdout, stderr = "", "KeyboardInterrupt:"
+
+     # Fit BOTH streams under a single overall cap
+     stdout_fit, stderr_fit = _fit_streams_to_budget(
+         stdout or "", stderr or "", MAX_TOOL_MSG_CHARS
+     )
+
+     print("STDOUT: ", stdout_fit)
+     print("STDERR: ", stderr_fit)
+
+     return f"STDOUT:\n{stdout_fit}\nSTDERR:\n{stderr_fit}"
+
+
+ @tool
+ def write_code(
+     code: str,
+     filename: str,
+     tool_call_id: Annotated[str, InjectedToolCallId],
+     state: Annotated[dict, InjectedState],
+ ) -> Command:
+     """Write source code to a file and update the agent’s workspace state.
+
+     Args:
+         code: The source code content to be written to disk.
+         filename: Name of the target file (including its extension).
+         tool_call_id: Identifier for this tool invocation.
+         state: Agent state dict holding workspace path and file list.
+
+     Returns:
+         Command: Contains an updated state (including code_files) and
+         a ToolMessage acknowledging success or failure.
+     """
+     # Determine the full path to the target file
+     workspace_dir = state["workspace"]
+     console.print("[cyan]Writing file:[/]", filename)
+
+     # Clean up markdown fences on submitted code.
+     code = _strip_fences(code)
+
+     # Show syntax-highlighted preview before writing to file
+     try:
+         lexer_name = Syntax.guess_lexer(filename, code)
+     except Exception:
+         lexer_name = "text"
+
+     console.print(
+         Panel(
+             Syntax(code, lexer_name, line_numbers=True),
+             title="File Preview",
+             border_style="cyan",
+         )
+     )
+
+     # Write cleaned code to disk
+     code_file = os.path.join(workspace_dir, filename)
+     try:
+         with open(code_file, "w", encoding="utf-8") as f:
+             f.write(code)
+     except Exception as exc:
+         console.print(
+             "[bold bright_white on red] :heavy_multiplication_x: [/] "
+             "[red]Failed to write file:[/]",
+             exc,
+         )
+         # A plain string returned here is wrapped into a ToolMessage by ToolNode.
+         return f"Failed to write {filename}."
+
+     console.print(
+         f"[bold bright_white on green] :heavy_check_mark: [/] "
+         f"[green]File written:[/] {code_file}"
+     )
+
+     # Append the file to the list in agent's state for later reference
+     file_list = state.get("code_files", [])
+     if filename not in file_list:
+         file_list.append(filename)
+
+     # Create a tool message to send back to acknowledge success.
+     msg = ToolMessage(
+         content=f"File {filename} written successfully.",
+         tool_call_id=tool_call_id,
+     )
+
+     # Return updated code files list & the message
+     return Command(
+         update={
+             "code_files": file_list,
+             "messages": [msg],
+         }
+     )
+
+
+ @tool
+ def edit_code(
+     old_code: str,
+     new_code: str,
+     filename: str,
+     state: Annotated[dict, InjectedState],
+ ) -> str:
+     """Replace the **first** occurrence of *old_code* with *new_code* in *filename*.
+
+     Args:
+         old_code: Code fragment to search for.
+         new_code: Replacement fragment.
+         filename: Target file inside the workspace.
+
+     Returns:
+         Success / failure message.
+     """
+     workspace_dir = state["workspace"]
+     console.print("[cyan]Editing file:[/cyan]", filename)
+
+     code_file = os.path.join(workspace_dir, filename)
+     try:
+         with open(code_file, "r", encoding="utf-8") as f:
+             content = f.read()
+     except FileNotFoundError:
+         console.print(
+             "[bold bright_white on red] :heavy_multiplication_x: [/] "
+             "[red]File not found:[/]",
+             filename,
+         )
+         return f"Failed: {filename} not found."
+
+     # Clean up markdown fences
+     old_code_clean = _strip_fences(old_code)
+     new_code_clean = _strip_fences(new_code)
+
+     if old_code_clean not in content:
+         console.print(
+             "[yellow] ⚠️ 'old_code' not found in file; no changes made.[/]"
+         )
+         return f"No changes made to {filename}: 'old_code' not found in file."
+
+     updated = content.replace(old_code_clean, new_code_clean, 1)
+
+     console.print(
+         Panel(
+             DiffRenderer(content, updated, filename),
+             title="Diff Preview",
+             border_style="cyan",
+         )
+     )
+
+     try:
+         with open(code_file, "w", encoding="utf-8") as f:
+             f.write(updated)
+     except Exception as exc:
+         console.print(
+             "[bold bright_white on red] :heavy_multiplication_x: [/] "
+             "[red]Failed to write file:[/]",
+             exc,
+         )
+         return f"Failed to edit {filename}."
+
+     console.print(
+         f"[bold bright_white on green] :heavy_check_mark: [/] "
+         f"[green]File updated:[/] {code_file}"
+     )
+     # Track the edited file by its workspace-relative name, matching write_code.
+     # Note: unlike write_code, no Command is returned, so this in-place state
+     # mutation is best-effort.
+     file_list = state.get("code_files", [])
+     if filename not in file_list:
+         file_list.append(filename)
+     state["code_files"] = file_list
+
+     return f"File {filename} updated successfully."
+
+
+ # Main module class
+ class ExecutionAgent(BaseAgent):
+     """Orchestrates model-driven code execution, tool calls, and state
+     management for iterative program synthesis and shell interaction.
+
+     This agent wraps an LLM with a small execution graph that alternates
+     between issuing model queries, invoking tools (run, write, edit, search),
+     performing safety checks, and summarizing progress. It manages a
+     workspace on disk, optional symlinks, and an optional memory backend to
+     persist summaries.
+
+     Args:
+         llm (BaseChatModel): Model identifier or bound chat model
+             instance. If a string is provided, the BaseAgent initializer will
+             resolve it.
+         agent_memory (Any | AgentMemory, optional): Memory backend used to
+             store summarized agent interactions. If provided, summaries are
+             saved here.
+         log_state (bool): When True, the agent writes intermediate JSON state
+             to disk for debugging and auditability.
+         extra_tools (list[Callable], optional): Additional tools appended to
+             the default toolset.
+         tokens_before_summarize (int): Approximate token count that triggers
+             conversation summarization.
+         messages_to_keep (int): Number of recent messages retained verbatim
+             after summarization.
+         safe_codes (list[str], optional): Trusted languages for the safety
+             check; defaults to ["python", "julia"].
+         **kwargs: Passed through to the BaseAgent constructor (e.g., model
+             configuration, checkpointer).
+
+     Attributes:
+         safe_codes (list[str]): List of trusted programming languages for the
+             agent. Defaults to python and julia.
+         executor_prompt (str): Prompt used when invoking the executor LLM
+             loop.
+         summarize_prompt (str): Prompt used to request concise summaries for
+             memory or final output.
+         tools (list[Tool]): Tools available to the agent (run_cmd, write_code,
+             edit_code, search_tool).
+         tool_node (ToolNode): Graph node that dispatches tool calls.
+         llm (BaseChatModel): LLM instance bound to the available tools.
+         _action (StateGraph): Compiled execution graph that implements the
+             main loop and branching logic.
+
+     Methods:
+         query_executor(state): Send messages to the executor LLM, ensure
+             workspace exists, and handle symlink setup before returning the
+             model response.
+         summarize(state): Produce and optionally persist a summary of recent
+             interactions to the memory backend.
+         safety_check(state): Validate pending run_cmd calls via the safety
+             prompt and append ToolMessages for unsafe commands.
+         get_safety_prompt(query, safe_codes, created_files): Build the LLM
+             prompt for safety_check, including the editable list of trusted
+             programming languages and the files the agent itself generated
+             (and can therefore trust).
+         _build_graph(): Construct and compile the StateGraph for the agent
+             loop.
+         _invoke(inputs, recursion_limit=...): Internal entry that invokes the
+             compiled graph with a given recursion limit.
+         action (property): Disabled; direct access is not supported. Use
+             invoke or stream entry points instead.
+
+     Raises:
+         AttributeError: Accessing the .action attribute raises to encourage
+             using .stream(...) or .invoke(...).
+     """
+
+     def __init__(
+         self,
+         llm: BaseChatModel,
+         agent_memory: Optional[Any | AgentMemory] = None,
+         log_state: bool = False,
+         extra_tools: Optional[list[Callable[..., Any]]] = None,
+         tokens_before_summarize: int = 50000,
+         messages_to_keep: int = 20,
+         safe_codes: Optional[list[str]] = None,
+         **kwargs,
+     ):
+         """ExecutionAgent class initialization."""
+         super().__init__(llm, **kwargs)
+         self.agent_memory = agent_memory
+         self.safe_codes = safe_codes or ["python", "julia"]
+         self.get_safety_prompt = get_safety_prompt
+         self.executor_prompt = executor_prompt
+         self.summarize_prompt = summarize_prompt
+         self.tools = [run_cmd, write_code, edit_code, search_tool]
+         self.extra_tools = extra_tools
+         if self.extra_tools is not None:
+             self.tools.extend(self.extra_tools)
+         self.tool_node = ToolNode(self.tools)
+         self.llm = self.llm.bind_tools(self.tools)
+         self.log_state = log_state
+         self._action = self._build_graph()
+         self.context_summarizer = SummarizationMiddleware(
+             model=self.llm,
+             max_tokens_before_summary=tokens_before_summarize,
+             messages_to_keep=messages_to_keep,
+         )
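The construction-time knobs above can be tuned per run; a sketch, reusing the llm handle from the earlier example (values are illustrative):

agent = ExecutionAgent(
    llm,
    log_state=True,                  # persist execution_agent.json snapshots
    tokens_before_summarize=30_000,  # summarize history sooner
    messages_to_keep=10,
    safe_codes=["python"],           # trust only python scripts
)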
+
+     # Check message history length and summarize to shorten the token usage:
+     def _summarize_context(self, state: ExecutionState) -> ExecutionState:
+         summarized_messages = self.context_summarizer.before_model(state, None)
+         if summarized_messages:
+             tokens_before_summarize = self.context_summarizer.token_counter(
+                 state["messages"]
+             )
+             state["messages"] = summarized_messages["messages"]
+             tokens_after_summarize = self.context_summarizer.token_counter(
+                 state["messages"][1:]
+             )
+             console.print(
+                 Panel(
+                     (
+                         f"Summarized Conversation History:\n"
+                         f"Approximate tokens before: {tokens_before_summarize}\n"
+                         f"Approximate tokens after: {tokens_after_summarize}\n"
+                     ),
+                     title="[bold yellow1 on black]:clipboard: Plan",
+                     border_style="yellow1",
+                     style="bold yellow1 on black",
+                 )
+             )
+         return state
+
+     # Define the function that calls the model
+     def query_executor(self, state: ExecutionState) -> ExecutionState:
+         """Prepare workspace, handle optional symlinks, and invoke the executor LLM.
+
+         This method copies the incoming state, ensures a workspace directory exists
+         (creating one with a random name when absent), optionally creates a symlink
+         described by state["symlinkdir"], sets or injects the executor system prompt
+         as the first message, and invokes the bound LLM. When logging is enabled,
+         it persists the pre-invocation state to disk.
+
+         Args:
+             state: The current execution state. Expected keys include:
+                 - "messages": Ordered list of System/Human/AI/Tool messages.
+                 - "workspace": Optional path to the working directory.
+                 - "symlinkdir": Optional dict with "source" and "dest" keys.
+
+         Returns:
+             ExecutionState: Partial state update containing:
+                 - "messages": A list with the model's response as the latest entry.
+                 - "workspace": The resolved workspace path.
+         """
+         new_state = state.copy()
+
+         # 1) Ensure a workspace directory exists, creating a named one if absent.
+         if "workspace" not in new_state:
+             new_state["workspace"] = randomname.get_name()
+             print(
+                 f"{RED}Creating the folder "
+                 f"{BLUE}{BOLD}{new_state['workspace']}{RESET}{RED} "
+                 f"for this project.{RESET}"
+             )
+         os.makedirs(new_state["workspace"], exist_ok=True)
+
+         # 1.5) Check message history length and summarize to shorten the token usage:
+         new_state = self._summarize_context(new_state)
+
+         # 2) Optionally create a symlink if symlinkdir is provided and not yet linked.
+         sd = new_state.get("symlinkdir")
+         if isinstance(sd, dict) and "is_linked" not in sd:
+             # symlinkdir structure: {"source": "/path/to/src", "dest": "link/name"}
+             symlinkdir = sd
+
+             src = Path(symlinkdir["source"]).expanduser().resolve()
+             workspace_root = Path(new_state["workspace"]).expanduser().resolve()
+             dst = (
+                 workspace_root / symlinkdir["dest"]
+             )  # Link lives inside workspace.
+
+             # If a file/link already exists at the destination, replace it.
+             if dst.exists() or dst.is_symlink():
+                 dst.unlink()
+
+             # Ensure parent directories for the link exist.
+             dst.parent.mkdir(parents=True, exist_ok=True)
+
+             # Create the symlink (tell pathlib if the target is a directory).
+             dst.symlink_to(src, target_is_directory=src.is_dir())
+             print(f"{RED}Symlinked {src} (source) --> {dst} (dest){RESET}")
+             new_state["symlinkdir"]["is_linked"] = True
+
+         # 3) Ensure the executor prompt is the first SystemMessage.
+         if isinstance(new_state["messages"][0], SystemMessage):
+             new_state["messages"][0] = SystemMessage(
+                 content=self.executor_prompt
+             )
+         else:
+             new_state["messages"] = [
+                 SystemMessage(content=self.executor_prompt)
+             ] + state["messages"]
+
+         # 4) Invoke the LLM with the prepared message sequence.
+         try:
+             response = self.llm.invoke(
+                 new_state["messages"], self.build_config(tags=["agent"])
+             )
+             new_state["messages"].append(response)
+         except Exception as e:
+             print("Error: ", e, " ", new_state["messages"][-1].content)
+             new_state["messages"].append(
+                 AIMessage(content=f"Response error {e}")
+             )
+
+         # 5) Optionally persist the pre-invocation state for audit/debugging.
+         if self.log_state:
+             self.write_state("execution_agent.json", new_state)
+
+         # Return the model's response and the workspace path as a partial state update.
+         return new_state
+
+     def summarize(self, state: ExecutionState) -> ExecutionState:
+         """Produce a concise summary of the conversation and optionally persist memory.
+
+         This method builds a summarization prompt, invokes the LLM to obtain a compact
+         summary of recent interactions, optionally logs salient details to the agent
+         memory backend, and writes debug state when logging is enabled.
+
+         Args:
+             state (ExecutionState): The execution state containing message history.
+
+         Returns:
+             ExecutionState: A partial update with a single string message containing
+             the summary.
+         """
+         new_state = state.copy()
+
+         # 0) Check message history length and summarize to shorten the token usage:
+         new_state = self._summarize_context(new_state)
+
+         # 1) Construct the summarization message list (summarize prompt followed
+         #    by prior messages), replacing any existing system prompt.
+         messages = (
+             [SystemMessage(content=self.summarize_prompt)]
+             + new_state["messages"][1:]
+             if isinstance(new_state["messages"][0], SystemMessage)
+             else [SystemMessage(content=self.summarize_prompt)]
+             + new_state["messages"]
+         )
+
+         # 2) Invoke the LLM to generate a summary; capture content even on failure.
+         response_content = ""
+         try:
+             response = self.llm.invoke(
+                 messages, self.build_config(tags=["summarize"])
+             )
+             response_content = response.content
+             new_state["messages"].append(response)
+         except Exception as e:
+             print("Error: ", e, " ", messages[-1].content)
+             new_state["messages"].append(
+                 AIMessage(content=f"Response error {e}")
+             )
+
+         # 3) Optionally persist salient details to the memory backend.
+         if self.agent_memory:
+             memories: list[str] = []
+             # Collect human/system/tool message content; for AI tool calls, store args.
+             for msg in new_state["messages"]:
+                 if not isinstance(msg, AIMessage):
+                     memories.append(msg.content)
+                 elif not msg.tool_calls:
+                     memories.append(msg.content)
+                 else:
+                     tool_strings = []
+                     # tool_call, not tool: avoid shadowing the imported decorator.
+                     for tool_call in msg.tool_calls:
+                         tool_strings.append("Tool Name: " + tool_call["name"])
+                         for arg_name in tool_call["args"]:
+                             tool_strings.append(
+                                 f"Arg: {str(arg_name)}\nValue: "
+                                 f"{str(tool_call['args'][arg_name])}"
+                             )
+                     memories.append("\n".join(tool_strings))
+             memories.append(response_content)
+             self.agent_memory.add_memories(memories)
+
+         # 4) Optionally write state to disk for debugging/auditing.
+         if self.log_state:
+             self.write_state("execution_agent.json", new_state)
+
+         # 5) Return a partial state update with only the summary content.
+         return new_state
+
+     def safety_check(self, state: ExecutionState) -> ExecutionState:
+         """Assess pending shell commands for safety and inject ToolMessages with results.
+
+         This method inspects the most recent AI tool calls, evaluates any run_cmd
+         queries against the safety prompt, and constructs ToolMessages that either
+         flag unsafe commands with reasons or confirm safe execution. If any command
+         is unsafe, the generated ToolMessages are appended to the state so the agent
+         can react without executing the command.
+
+         Args:
+             state (ExecutionState): Current execution state.
+
+         Returns:
+             ExecutionState: Either the unchanged state (all safe) or a copy with one
+             or more ToolMessages appended when unsafe commands are detected.
+         """
+         # 1) Work on a shallow copy; inspect the most recent model message.
+         new_state = state.copy()
+         last_msg = new_state["messages"][-1]
+
+         # 1.5) Check message history length and summarize to shorten the token usage:
+         new_state = self._summarize_context(new_state)
+
+         # 2) Evaluate any pending run_cmd tool calls for safety.
+         tool_responses: list[ToolMessage] = []
+         any_unsafe = False
+         for tool_call in last_msg.tool_calls:
+             if tool_call["name"] != "run_cmd":
+                 continue
+
+             query = tool_call["args"]["query"]
+             safety_result = self.llm.invoke(
+                 self.get_safety_prompt(
+                     query, self.safe_codes, new_state.get("code_files", [])
+                 ),
+                 self.build_config(tags=["safety_check"]),
+             )
+
+             if "[NO]" in safety_result.content:
+                 any_unsafe = True
+                 tool_response = (
+                     "[UNSAFE] That command `{q}` was deemed unsafe and cannot be run.\n"
+                     "For reason: {r}"
+                 ).format(q=query, r=safety_result.content)
+                 console.print(
+                     "[bold red][WARNING][/bold red] Command deemed unsafe:",
+                     query,
+                 )
+                 # Also surface the model's rationale for transparency.
+                 console.print(
+                     "[bold red][WARNING][/bold red] REASON:", tool_response
+                 )
+             else:
+                 tool_response = f"Command `{query}` passed safety check."
+                 console.print(
+                     f"[green]Command passed safety check:[/green] {query}"
+                 )
+
+             tool_responses.append(
+                 ToolMessage(
+                     content=tool_response,
+                     tool_call_id=tool_call["id"],
+                 )
+             )
+
+         # 3) If any command is unsafe, append all tool responses; otherwise keep state.
+         if any_unsafe:
+             new_state["messages"].extend(tool_responses)
+
+         return new_state
+
+     def _build_graph(self):
+         """Construct and compile the agent's LangGraph state machine."""
+         # Create a graph over the agent's execution state.
+         graph = StateGraph(ExecutionState)
+
+         # Register nodes:
+         # - "agent": LLM planning/execution step
+         # - "action": tool dispatch (run_cmd, write_code, etc.)
+         # - "summarize": summary/finalization step
+         # - "safety_check": gate for shell command safety
+         self.add_node(graph, self.query_executor, "agent")
+         self.add_node(graph, self.tool_node, "action")
+         self.add_node(graph, self.summarize, "summarize")
+         self.add_node(graph, self.safety_check, "safety_check")
+
+         # Set entrypoint: execution starts with the "agent" node.
+         graph.set_entry_point("agent")
+
+         # From "agent", either continue (tools) or finish (summarize),
+         # based on presence of tool calls in the last message.
+         graph.add_conditional_edges(
+             "agent",
+             self._wrap_cond(should_continue, "should_continue", "execution"),
+             {"continue": "safety_check", "summarize": "summarize"},
+         )
+
+         # From "safety_check", route to tools if safe, otherwise back to agent
+         # to revise the plan without executing unsafe commands.
+         graph.add_conditional_edges(
+             "safety_check",
+             self._wrap_cond(command_safe, "command_safe", "execution"),
+             {"safe": "action", "unsafe": "agent"},
+         )
+
+         # After tools run, return control to the agent for the next step.
+         graph.add_edge("action", "agent")
+
+         # The graph completes at the "summarize" node.
+         graph.set_finish_point("summarize")
+
+         # Compile and return the executable graph (optionally with a checkpointer).
+         return graph.compile(checkpointer=self.checkpointer)
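Read as a flow, the registered edges form this loop (derived directly from the calls above):

# START -> agent
# agent --(tool calls)-----> safety_check --(safe)--> action --> agent
# agent --(no tool calls)--> summarize --> END
# safety_check --(unsafe)--> agent   (no tool execution; agent revises)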
+
+     async def add_mcp_tool(self, mcp_tools: Mapping[str, Any]) -> None:
+         """Load tools from MCP server connection configs and register them.
+
+         The argument is the connections mapping expected by
+         MultiServerMCPClient (server name -> connection config).
+         """
+         client = MultiServerMCPClient(mcp_tools)
+         tools = await client.get_tools()
+         self.add_tool(tools)
+
+     def add_tool(
+         self, new_tools: Callable[..., Any] | list[Callable[..., Any]]
+     ) -> None:
+         """Register new tools, rebind the LLM, and rebuild the graph."""
+         if isinstance(new_tools, list):
+             self.tools.extend([convert_to_tool(x) for x in new_tools])
+         elif isinstance(new_tools, StructuredTool) or isinstance(
+             new_tools, Callable
+         ):
+             self.tools.append(convert_to_tool(new_tools))
+         else:
+             raise TypeError("Expected a callable or a list of callables.")
+         self.tool_node = ToolNode(self.tools)
+         self.llm = self.llm.bind_tools(self.tools)
+         self._action = self._build_graph()
+
+     def list_tools(self) -> None:
+         """Print the names of the currently registered tools."""
+         print(
+             f"Available tool names are: {', '.join([x.name for x in self.tools])}."
+         )
+
+     def remove_tool(self, cut_tools: str | list[str]) -> None:
+         """Remove tools by name, rebind the LLM, and rebuild the graph."""
+         if isinstance(cut_tools, str):
+             self.remove_tool([cut_tools])
+         elif isinstance(cut_tools, list):
+             self.tools = [x for x in self.tools if x.name not in cut_tools]
+             self.tool_node = ToolNode(self.tools)
+             self.llm = self.llm.bind_tools(self.tools)
+             self._action = self._build_graph()
+         else:
+             raise TypeError(
+                 "Expected a string or a list of strings describing the tools to remove."
+             )
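To illustrate the tool-management API, a sketch with a hypothetical word_count helper (the name and body are invented for the example; the docstring becomes the tool description shown to the model):

def word_count(text: str) -> int:
    """Count whitespace-separated words in text."""
    return len(text.split())

agent.add_tool(word_count)
agent.list_tools()              # now includes "word_count"
agent.remove_tool("word_count")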
+
+     def _invoke(
+         self, inputs: Mapping[str, Any], recursion_limit: int = 999_999, **_
+     ):
+         """Invoke the compiled graph with inputs under a specified recursion limit.
+
+         This method builds a LangGraph config with the provided recursion limit
+         and a "graph" tag, then delegates to the compiled graph's invoke method.
+         """
+         # Build invocation config with a generous recursion limit for long runs.
+         config = self.build_config(
+             recursion_limit=recursion_limit, tags=["graph"]
+         )
+
+         # Delegate execution to the compiled graph.
+         return self._action.invoke(inputs, config)
+
+     def _ainvoke(
+         self, inputs: Mapping[str, Any], recursion_limit: int = 999_999, **_
+     ):
+         """Asynchronously invoke the compiled graph under a specified recursion limit.
+
+         This method builds a LangGraph config with the provided recursion limit
+         and a "graph" tag, then delegates to the compiled graph's ainvoke method,
+         returning an awaitable.
+         """
+         # Build invocation config with a generous recursion limit for long runs.
+         config = self.build_config(
+             recursion_limit=recursion_limit, tags=["graph"]
+         )
+
+         # Delegate execution to the compiled graph.
+         return self._action.ainvoke(inputs, config)
+
+     # This property stops callers from bypassing the invoke/stream entry points.
+     @property
+     def action(self):
+         """Raise AttributeError: direct `.action` access is unsupported."""
+         raise AttributeError(
+             "Use .stream(...) or .invoke(...); direct .action access is unsupported."
+         )
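Finally, a sketch of the async path, assuming BaseAgent wraps _ainvoke in a public ainvoke the same way invoke wraps _invoke (that wrapper lives in ursa/agents/base.py and is not shown in this hunk):

import asyncio

async def run_demo() -> None:
    # Reuses the agent built in the earlier sketch.
    result = await agent.ainvoke(
        {"messages": [HumanMessage(content="Write a hello.py script and run it.")]}
    )
    print(result["messages"][-1].content)

asyncio.run(run_demo())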