ursa-ai 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ursa/__init__.py +3 -0
- ursa/agents/__init__.py +32 -0
- ursa/agents/acquisition_agents.py +812 -0
- ursa/agents/arxiv_agent.py +429 -0
- ursa/agents/base.py +728 -0
- ursa/agents/chat_agent.py +60 -0
- ursa/agents/code_review_agent.py +341 -0
- ursa/agents/execution_agent.py +915 -0
- ursa/agents/hypothesizer_agent.py +614 -0
- ursa/agents/lammps_agent.py +465 -0
- ursa/agents/mp_agent.py +204 -0
- ursa/agents/optimization_agent.py +410 -0
- ursa/agents/planning_agent.py +219 -0
- ursa/agents/rag_agent.py +304 -0
- ursa/agents/recall_agent.py +54 -0
- ursa/agents/websearch_agent.py +196 -0
- ursa/cli/__init__.py +363 -0
- ursa/cli/hitl.py +516 -0
- ursa/cli/hitl_api.py +75 -0
- ursa/observability/metrics_charts.py +1279 -0
- ursa/observability/metrics_io.py +11 -0
- ursa/observability/metrics_session.py +750 -0
- ursa/observability/pricing.json +97 -0
- ursa/observability/pricing.py +321 -0
- ursa/observability/timing.py +1466 -0
- ursa/prompt_library/__init__.py +0 -0
- ursa/prompt_library/code_review_prompts.py +51 -0
- ursa/prompt_library/execution_prompts.py +50 -0
- ursa/prompt_library/hypothesizer_prompts.py +17 -0
- ursa/prompt_library/literature_prompts.py +11 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/prompt_library/planning_prompts.py +79 -0
- ursa/prompt_library/websearch_prompts.py +131 -0
- ursa/tools/__init__.py +0 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/tools/run_command.py +27 -0
- ursa/tools/write_code.py +42 -0
- ursa/util/__init__.py +0 -0
- ursa/util/diff_renderer.py +128 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/logo_generator.py +625 -0
- ursa/util/memory_logger.py +183 -0
- ursa/util/optimization_schema.py +78 -0
- ursa/util/parse.py +405 -0
- ursa_ai-0.9.1.dist-info/METADATA +304 -0
- ursa_ai-0.9.1.dist-info/RECORD +51 -0
- ursa_ai-0.9.1.dist-info/WHEEL +5 -0
- ursa_ai-0.9.1.dist-info/entry_points.txt +2 -0
- ursa_ai-0.9.1.dist-info/licenses/LICENSE +8 -0
- ursa_ai-0.9.1.dist-info/top_level.txt +1 -0
ursa/agents/execution_agent.py
@@ -0,0 +1,915 @@
"""Execution agent that builds a tool-enabled state graph to autonomously run tasks.

This module implements ExecutionAgent, a LangGraph-based agent that executes user
instructions by invoking LLM tool calls and coordinating a controlled workflow.

Key features:
- Workspace management with optional symlinking for external sources.
- Safety-checked shell execution via run_cmd with output size budgeting.
- Code authoring and edits through write_code and edit_code with rich previews.
- Web search capability through DuckDuckGoSearchResults.
- Summarization of the session and optional memory logging.
- Configurable graph with nodes for agent, safety_check, action, and summarize.

Implementation notes:
- LLM prompts are sourced from prompt_library.execution_prompts.
- Outputs from subprocess are trimmed under MAX_TOOL_MSG_CHARS to fit tool messages.
- The agent uses ToolNode and LangGraph StateGraph to loop until no tool calls remain.
- Safety gates block unsafe shell commands and surface the rationale to the user.

Environment:
- MAX_TOOL_MSG_CHARS caps combined stdout/stderr in tool responses.

Entry points:
- ExecutionAgent._invoke(...) runs the compiled graph.
- main() shows a minimal demo that writes and runs a script.
"""
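# Illustrative usage (a minimal sketch; ChatOpenAI and the task text are
# hypothetical stand-ins, and it assumes BaseAgent exposes .invoke(...) as the
# class docstring below suggests):
#
#     from langchain_openai import ChatOpenAI
#     from ursa.agents import ExecutionAgent
#
#     agent = ExecutionAgent(llm=ChatOpenAI(model="gpt-4o"))
#     result = agent.invoke(
#         {"messages": [("user", "Write hello.py and run it.")]}
#     )
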
# from langchain_core.runnables.graph import MermaidDrawMethod
import os
import subprocess
from pathlib import Path
from typing import Annotated, Any, Callable, Literal, Mapping, Optional

import randomname
from langchain.agents.middleware import SummarizationMiddleware
from langchain.chat_models import BaseChatModel
from langchain_community.tools import (
    DuckDuckGoSearchResults,
)  # TavilySearchResults,
from langchain_core.messages import (
    AIMessage,
    AnyMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.tools import InjectedToolCallId, StructuredTool, tool
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import InjectedState, ToolNode
from langgraph.types import Command

# Rich
from rich import get_console
from rich.panel import Panel
from rich.syntax import Syntax
from typing_extensions import TypedDict

from ..prompt_library.execution_prompts import (
    executor_prompt,
    get_safety_prompt,
    summarize_prompt,
)
from ..util.diff_renderer import DiffRenderer
from ..util.memory_logger import AgentMemory
from .base import BaseAgent

console = get_console()  # always returns the same instance

# --- ANSI color codes ---
GREEN = "\033[92m"
BLUE = "\033[94m"
RED = "\033[91m"
RESET = "\033[0m"
BOLD = "\033[1m"

# Global variables for the module.

# Set a limit for message characters - the user could override
# that in their env, or maybe we could pull this out of the LLM parameters
MAX_TOOL_MSG_CHARS = int(os.getenv("MAX_TOOL_MSG_CHARS", "50000"))

# Set a search tool.
search_tool = DuckDuckGoSearchResults(output_format="json", num_results=10)
# search_tool = TavilySearchResults(
#     max_results=10,
#     search_depth="advanced",
#     include_answer=True)

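# Illustration (hypothetical value): the cap can be raised or lowered per run
# from the environment before the process starts, e.g.
#
#     MAX_TOOL_MSG_CHARS=20000 python run_agent.py
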
# Classes for typing
class ExecutionState(TypedDict):
    """TypedDict representing the execution agent's mutable run state used by nodes.

    Fields:
    - messages: list of messages (System/Human/AI/Tool) with add_messages metadata.
    - current_progress: short status string describing agent progress.
    - code_files: list of filenames created or edited in the workspace.
    - workspace: path to the working directory where files and commands run.
    - symlinkdir: optional dict describing a symlink operation (source, dest,
      is_linked).
    """

    messages: Annotated[list[AnyMessage], add_messages]
    current_progress: str
    code_files: list[str]
    workspace: str
    symlinkdir: dict

# Helper functions
def convert_to_tool(fn):
    if isinstance(fn, StructuredTool):
        return fn
    else:
        return StructuredTool.from_function(
            func=fn, name=fn.__name__, description=fn.__doc__
        )

def _strip_fences(snippet: str) -> str:
    """Remove markdown fences from a code snippet.

    This function strips leading triple backticks and any language
    identifiers from a markdown-formatted code snippet and returns
    only the contained code.

    Args:
        snippet: The markdown-formatted code snippet.

    Returns:
        The snippet content without leading markdown fences.
    """
    if "```" not in snippet:
        return snippet

    parts = snippet.split("```")
    if len(parts) < 3:
        return snippet

    body = parts[1]
    return "\n".join(body.split("\n")[1:]) if "\n" in body else body.strip()

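# Doctest-style illustration (hypothetical input): the language tag on the
# opening fence is dropped along with the fences themselves.
#
# >>> _strip_fences("```python\nprint('hi')\n```")
# "print('hi')\n"
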
def _snip_text(text: str, max_chars: int) -> tuple[str, bool]:
    """Truncate text to a maximum length and indicate if truncation occurred.

    Args:
        text: The original text to potentially truncate.
        max_chars: The maximum characters allowed in the output.

    Returns:
        A tuple of (possibly truncated text, boolean flag indicating
        if truncation occurred).
    """
    if text is None:
        return "", False
    if max_chars <= 0:
        return "", len(text) > 0
    if len(text) <= max_chars:
        return text, False
    head = max_chars // 2
    tail = max_chars - head
    return (
        text[:head]
        + f"\n... [snipped {len(text) - max_chars} chars] ...\n"
        + text[-tail:],
        True,
    )

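# Doctest-style illustration (hypothetical values): the head and tail survive,
# and the middle is replaced by a marker naming how much was cut.
#
# >>> _snip_text("abcdefghij", 6)
# ('abc\n... [snipped 4 chars] ...\nhij', True)
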
def _fit_streams_to_budget(stdout: str, stderr: str, total_budget: int):
    """Allocate and truncate stdout and stderr to fit a total character budget.

    Args:
        stdout: The original stdout string.
        stderr: The original stderr string.
        total_budget: The combined character budget for stdout and stderr.

    Returns:
        A tuple of (possibly truncated stdout, possibly truncated stderr).
    """
    label_overhead = len("STDOUT:\n") + len("\nSTDERR:\n")
    budget = max(0, total_budget - label_overhead)

    if len(stdout) + len(stderr) <= budget:
        return stdout, stderr

    total_len = max(1, len(stdout) + len(stderr))
    stdout_budget = int(budget * (len(stdout) / total_len))
    stderr_budget = budget - stdout_budget

    stdout_snip, _ = _snip_text(stdout, stdout_budget)
    stderr_snip, _ = _snip_text(stderr, stderr_budget)

    return stdout_snip, stderr_snip

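# Illustration (hypothetical numbers): with total_budget=1000 and streams of
# 3000 and 1000 chars, the post-label budget of 983 splits proportionally into
# 737 chars for stdout and 246 for stderr, each then trimmed via _snip_text.
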
def should_continue(state: ExecutionState) -> Literal["summarize", "continue"]:
    """Return 'summarize' if no tool calls in the last message, else 'continue'.

    Args:
        state: The current execution state containing messages.

    Returns:
        A literal "summarize" if the last message has no tool calls,
        otherwise "continue".
    """
    messages = state["messages"]
    last_message = messages[-1]
    # If there is no tool call, then we finish
    if not last_message.tool_calls:
        return "summarize"
    # Otherwise if there is, we continue
    else:
        return "continue"

def command_safe(state: ExecutionState) -> Literal["safe", "unsafe"]:
    """Return 'safe' if the last command was safe, otherwise 'unsafe'.

    Args:
        state: The current execution state containing messages and tool calls.

    Returns:
        A literal "safe" if no '[UNSAFE]' tags are in the last command,
        otherwise "unsafe".
    """
    index = -1
    message = state["messages"][index]
    # Loop through all the consecutive tool messages in reverse order
    while isinstance(message, ToolMessage):
        if "[UNSAFE]" in message.content:
            return "unsafe"

        index -= 1
        message = state["messages"][index]

    return "safe"

# Tools for ExecutionAgent
@tool
def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
    """Execute a shell command in the workspace and return its combined output.

    Runs the specified command using subprocess.run in the given workspace
    directory, captures stdout and stderr, enforces a maximum character budget,
    and formats both streams into a single string. KeyboardInterrupt during
    execution is caught and reported.

    Args:
        query: The shell command to execute.
        state: A dict with injected state; must include the 'workspace' path.

    Returns:
        A formatted string with "STDOUT:" followed by the truncated stdout and
        "STDERR:" followed by the truncated stderr.
    """
    workspace_dir = state["workspace"]

    print("RUNNING: ", query)
    try:
        result = subprocess.run(
            query,
            text=True,
            shell=True,
            timeout=60000,
            capture_output=True,
            cwd=workspace_dir,
        )
        stdout, stderr = result.stdout, result.stderr
    except KeyboardInterrupt:
        print("Keyboard Interrupt of command: ", query)
        stdout, stderr = "", "KeyboardInterrupt:"

    # Fit BOTH streams under a single overall cap
    stdout_fit, stderr_fit = _fit_streams_to_budget(
        stdout or "", stderr or "", MAX_TOOL_MSG_CHARS
    )

    print("STDOUT: ", stdout_fit)
    print("STDERR: ", stderr_fit)

    return f"STDOUT:\n{stdout_fit}\nSTDERR:\n{stderr_fit}"

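# Note: subprocess.run's timeout argument is in seconds, so 60000 above allows
# roughly 16.7 hours; a TimeoutExpired raised past that point is not caught by
# the try block, which handles only KeyboardInterrupt.
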
@tool
def write_code(
    code: str,
    filename: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
    state: Annotated[dict, InjectedState],
) -> Command:
    """Write source code to a file and update the agent’s workspace state.

    Args:
        code: The source code content to be written to disk.
        filename: Name of the target file (including its extension).
        tool_call_id: Identifier for this tool invocation.
        state: Agent state dict holding workspace path and file list.

    Returns:
        Command: Contains an updated state (including code_files) and
        a ToolMessage acknowledging success or failure.
    """
    # Determine the full path to the target file
    workspace_dir = state["workspace"]
    console.print("[cyan]Writing file:[/]", filename)

    # Clean up markdown fences on submitted code.
    code = _strip_fences(code)

    # Show syntax-highlighted preview before writing to file
    try:
        lexer_name = Syntax.guess_lexer(filename, code)
    except Exception:
        lexer_name = "text"

    console.print(
        Panel(
            Syntax(code, lexer_name, line_numbers=True),
            title="File Preview",
            border_style="cyan",
        )
    )

    # Write cleaned code to disk
    code_file = os.path.join(workspace_dir, filename)
    try:
        with open(code_file, "w", encoding="utf-8") as f:
            f.write(code)
    except Exception as exc:
        console.print(
            "[bold bright_white on red] :heavy_multiplication_x: [/] "
            "[red]Failed to write file:[/]",
            exc,
        )
        return f"Failed to write {filename}."

    console.print(
        f"[bold bright_white on green] :heavy_check_mark: [/] "
        f"[green]File written:[/] {code_file}"
    )

    # Append the file to the list in agent's state for later reference
    file_list = state.get("code_files", [])
    if filename not in file_list:
        file_list.append(filename)

    # Create a tool message to send back to acknowledge success.
    msg = ToolMessage(
        content=f"File {filename} written successfully.",
        tool_call_id=tool_call_id,
    )

    # Return updated code files list & the message
    return Command(
        update={
            "code_files": file_list,
            "messages": [msg],
        }
    )

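# Note: returning a Command lets the tool patch graph state directly; the
# update above both records the new file in code_files and emits the
# ToolMessage, so no separate state-merge step is needed.
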
@tool
def edit_code(
    old_code: str,
    new_code: str,
    filename: str,
    state: Annotated[dict, InjectedState],
) -> str:
    """Replace the **first** occurrence of *old_code* with *new_code* in *filename*.

    Args:
        old_code: Code fragment to search for.
        new_code: Replacement fragment.
        filename: Target file inside the workspace.

    Returns:
        Success / failure message.
    """
    workspace_dir = state["workspace"]
    console.print("[cyan]Editing file:[/cyan]", filename)

    code_file = os.path.join(workspace_dir, filename)
    try:
        with open(code_file, "r", encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        console.print(
            "[bold bright_white on red] :heavy_multiplication_x: [/] "
            "[red]File not found:[/]",
            filename,
        )
        return f"Failed: {filename} not found."

    # Clean up markdown fences
    old_code_clean = _strip_fences(old_code)
    new_code_clean = _strip_fences(new_code)

    if old_code_clean not in content:
        console.print(
            "[yellow] ⚠️ 'old_code' not found in file; no changes made.[/]"
        )
        return f"No changes made to {filename}: 'old_code' not found in file."

    updated = content.replace(old_code_clean, new_code_clean, 1)

    console.print(
        Panel(
            DiffRenderer(content, updated, filename),
            title="Diff Preview",
            border_style="cyan",
        )
    )

    try:
        with open(code_file, "w", encoding="utf-8") as f:
            f.write(updated)
    except Exception as exc:
        console.print(
            "[bold bright_white on red] :heavy_multiplication_x: [/] "
            "[red]Failed to write file:[/]",
            exc,
        )
        return f"Failed to edit {filename}."

    console.print(
        f"[bold bright_white on green] :heavy_check_mark: [/] "
        f"[green]File updated:[/] {code_file}"
    )
    file_list = state.get("code_files", [])
    if code_file not in file_list:
        file_list.append(filename)
    state["code_files"] = file_list

    return f"File {filename} updated successfully."

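# Illustration: because replace(..., 1) rewrites only the first match, a
# caller editing a repeated snippet should include enough surrounding context
# in old_code to pin down the intended occurrence.
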
# Main module class
class ExecutionAgent(BaseAgent):
    """Orchestrates model-driven code execution, tool calls, and state management.

    Orchestrates model-driven code execution, tool calls, and state management for
    iterative program synthesis and shell interaction.

    This agent wraps an LLM with a small execution graph that alternates
    between issuing model queries, invoking tools (run, write, edit, search),
    performing safety checks, and summarizing progress. It manages a
    workspace on disk, optional symlinks, and an optional memory backend to
    persist summaries.

    Args:
        llm (BaseChatModel): Model identifier or bound chat model
            instance. If a string is provided, the BaseAgent initializer will
            resolve it.
        agent_memory (Any | AgentMemory, optional): Memory backend used to
            store summarized agent interactions. If provided, summaries are
            saved here.
        log_state (bool): When True, the agent writes intermediate json state
            to disk for debugging and auditability.
        **kwargs: Passed through to the BaseAgent constructor (e.g., model
            configuration, checkpointer).

    Attributes:
        safe_codes (list[str]): List of trusted programming languages for the
            agent. Defaults to python and julia.
        executor_prompt (str): Prompt used when invoking the executor LLM
            loop.
        summarize_prompt (str): Prompt used to request concise summaries for
            memory or final output.
        tools (list[Tool]): Tools available to the agent (run_cmd, write_code,
            edit_code, search_tool).
        tool_node (ToolNode): Graph node that dispatches tool calls.
        llm (BaseChatModel): LLM instance bound to the available tools.
        _action (StateGraph): Compiled execution graph that implements the
            main loop and branching logic.

    Methods:
        query_executor(state): Send messages to the executor LLM, ensure
            workspace exists, and handle symlink setup before returning the
            model response.
        summarize(state): Produce and optionally persist a summary of recent
            interactions to the memory backend.
        safety_check(state): Validate pending run_cmd calls via the safety
            prompt and append ToolMessages for unsafe commands.
        get_safety_prompt(query, safe_codes, created_files): Get the LLM
            prompt for safety_check that includes an editable list of
            available programming languages and the context of files that the
            agent has generated and can trust.
        _build_graph(): Construct and compile the StateGraph for the agent
            loop.
        _invoke(inputs, recursion_limit=...): Internal entry that invokes the
            compiled graph with a given recursion limit.
        action (property): Disabled; direct access is not supported. Use
            invoke or stream entry points instead.

    Raises:
        AttributeError: Accessing the .action attribute raises to encourage
            using .stream(...) or .invoke(...).
    """

    def __init__(
        self,
        llm: BaseChatModel,
        agent_memory: Optional[Any | AgentMemory] = None,
        log_state: bool = False,
        extra_tools: Optional[list[Callable[..., Any]]] = None,
        tokens_before_summarize: int = 50000,
        messages_to_keep: int = 20,
        safe_codes: Optional[list[str]] = None,
        **kwargs,
    ):
        """ExecutionAgent class initialization."""
        super().__init__(llm, **kwargs)
        self.agent_memory = agent_memory
        self.safe_codes = safe_codes or ["python", "julia"]
        self.get_safety_prompt = get_safety_prompt
        self.executor_prompt = executor_prompt
        self.summarize_prompt = summarize_prompt
        self.tools = [run_cmd, write_code, edit_code, search_tool]
        self.extra_tools = extra_tools
        if self.extra_tools is not None:
            self.tools.extend(self.extra_tools)
        self.tool_node = ToolNode(self.tools)
        self.llm = self.llm.bind_tools(self.tools)
        self.log_state = log_state
        self._action = self._build_graph()
        self.context_summarizer = SummarizationMiddleware(
            model=self.llm,
            max_tokens_before_summary=tokens_before_summarize,
            messages_to_keep=messages_to_keep,
        )

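    # Illustrative sketch (word_count is a hypothetical tool, not part of this
    # module): both the extra_tools argument here and add_tool() below accept
    # plain callables alongside the built-in tools.
    #
    #     def word_count(text: str) -> int:
    #         "Count whitespace-delimited words."
    #         return len(text.split())
    #
    #     agent = ExecutionAgent(llm=chat_model, extra_tools=[word_count])
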
    # Check message history length and summarize to shorten the token usage:
    def _summarize_context(self, state: ExecutionState) -> ExecutionState:
        summarized_messages = self.context_summarizer.before_model(state, None)
        if summarized_messages:
            tokens_before_summarize = self.context_summarizer.token_counter(
                state["messages"]
            )
            state["messages"] = summarized_messages["messages"]
            tokens_after_summarize = self.context_summarizer.token_counter(
                state["messages"][1:]
            )
            console.print(
                Panel(
                    (
                        f"Summarized Conversation History:\n"
                        f"Approximate tokens before: {tokens_before_summarize}\n"
                        f"Approximate tokens after: {tokens_after_summarize}\n"
                    ),
                    title="[bold yellow1 on black]:clipboard: Plan",
                    border_style="yellow1",
                    style="bold yellow1 on black",
                )
            )
        else:
            tokens_after_summarize = self.context_summarizer.token_counter(
                state["messages"]
            )
        return state

    # Define the function that calls the model
    def query_executor(self, state: ExecutionState) -> ExecutionState:
        """Prepare workspace, handle optional symlinks, and invoke the executor LLM.

        This method copies the incoming state, ensures a workspace directory exists
        (creating one with a random name when absent), optionally creates a symlink
        described by state["symlinkdir"], sets or injects the executor system prompt
        as the first message, and invokes the bound LLM. When logging is enabled,
        it persists the pre-invocation state to disk.

        Args:
            state: The current execution state. Expected keys include:
                - "messages": Ordered list of System/Human/AI/Tool messages.
                - "workspace": Optional path to the working directory.
                - "symlinkdir": Optional dict with "source" and "dest" keys.

        Returns:
            ExecutionState: Partial state update containing:
                - "messages": A list with the model's response as the latest entry.
                - "workspace": The resolved workspace path.
        """
        new_state = state.copy()

        # 1) Ensure a workspace directory exists, creating a named one if absent.
        if "workspace" not in new_state.keys():
            new_state["workspace"] = randomname.get_name()
            print(
                f"{RED}Creating the folder "
                f"{BLUE}{BOLD}{new_state['workspace']}{RESET}{RED} "
                f"for this project.{RESET}"
            )
        os.makedirs(new_state["workspace"], exist_ok=True)

        # 1.5) Check message history length and summarize to shorten the token usage:
        new_state = self._summarize_context(new_state)

        # 2) Optionally create a symlink if symlinkdir is provided and not yet linked.
        sd = new_state.get("symlinkdir")
        if isinstance(sd, dict) and "is_linked" not in sd:
            # symlinkdir structure: {"source": "/path/to/src", "dest": "link/name"}
            symlinkdir = sd

            src = Path(symlinkdir["source"]).expanduser().resolve()
            workspace_root = Path(new_state["workspace"]).expanduser().resolve()
            dst = (
                workspace_root / symlinkdir["dest"]
            )  # Link lives inside workspace.

            # If a file/link already exists at the destination, replace it.
            if dst.exists() or dst.is_symlink():
                dst.unlink()

            # Ensure parent directories for the link exist.
            dst.parent.mkdir(parents=True, exist_ok=True)

            # Create the symlink (tell pathlib if the target is a directory).
            dst.symlink_to(src, target_is_directory=src.is_dir())
            print(f"{RED}Symlinked {src} (source) --> {dst} (dest)")
            new_state["symlinkdir"]["is_linked"] = True

        # 3) Ensure the executor prompt is the first SystemMessage.
        if isinstance(new_state["messages"][0], SystemMessage):
            new_state["messages"][0] = SystemMessage(
                content=self.executor_prompt
            )
        else:
            new_state["messages"] = [
                SystemMessage(content=self.executor_prompt)
            ] + state["messages"]

        # 4) Invoke the LLM with the prepared message sequence.
        try:
            response = self.llm.invoke(
                new_state["messages"], self.build_config(tags=["agent"])
            )
            new_state["messages"].append(response)
        except Exception as e:
            print("Error: ", e, " ", new_state["messages"][-1].content)
            new_state["messages"].append(
                AIMessage(content=f"Response error {e}")
            )

        # 5) Optionally persist the pre-invocation state for audit/debugging.
        if self.log_state:
            self.write_state("execution_agent.json", new_state)

        # Return the model's response and the workspace path as a partial state update.
        return new_state

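    # Illustration (hypothetical paths): callers can mount external material
    # into the workspace by seeding the input state with
    #
    #     {"symlinkdir": {"source": "~/datasets/raw", "dest": "data"}}
    #
    # which resolves to <workspace>/data -> ~/datasets/raw; is_linked is then
    # set so the link is created only once per run.
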
    def summarize(self, state: ExecutionState) -> ExecutionState:
        """Produce a concise summary of the conversation and optionally persist memory.

        This method builds a summarization prompt, invokes the LLM to obtain a compact
        summary of recent interactions, optionally logs salient details to the agent
        memory backend, and writes debug state when logging is enabled.

        Args:
            state (ExecutionState): The execution state containing message history.

        Returns:
            ExecutionState: A partial update with a single string message containing
            the summary.
        """
        new_state = state.copy()

        # 0) Check message history length and summarize to shorten the token usage:
        new_state = self._summarize_context(new_state)

        # 1) Construct the summarization message list (system prompt + prior messages).
        messages = (
            new_state["messages"]
            if isinstance(new_state["messages"][0], SystemMessage)
            else [SystemMessage(content=summarize_prompt)]
            + new_state["messages"]
        )

        # 2) Invoke the LLM to generate a summary; capture content even on failure.
        response_content = ""
        try:
            response = self.llm.invoke(
                messages, self.build_config(tags=["summarize"])
            )
            response_content = response.content
            new_state["messages"].append(response)
        except Exception as e:
            print("Error: ", e, " ", messages[-1].content)
            new_state["messages"].append(
                AIMessage(content=f"Response error {e}")
            )

        # 3) Optionally persist salient details to the memory backend.
        if self.agent_memory:
            memories: list[str] = []
            # Collect human/system/tool message content; for AI tool calls, store args.
            for msg in new_state["messages"]:
                if not isinstance(msg, AIMessage):
                    memories.append(msg.content)
                elif not msg.tool_calls:
                    memories.append(msg.content)
                else:
                    tool_strings = []
                    for tool in msg.tool_calls:
                        tool_strings.append("Tool Name: " + tool["name"])
                        for arg_name in tool["args"]:
                            tool_strings.append(
                                f"Arg: {str(arg_name)}\nValue: "
                                f"{str(tool['args'][arg_name])}"
                            )
                    memories.append("\n".join(tool_strings))
            memories.append(response_content)
            self.agent_memory.add_memories(memories)

        # 4) Optionally write state to disk for debugging/auditing.
        if self.log_state:
            self.write_state("execution_agent.json", new_state)

        # 5) Return a partial state update with only the summary content.
        return new_state

    def safety_check(self, state: ExecutionState) -> ExecutionState:
        """Assess pending shell commands for safety and inject ToolMessages with results.

        This method inspects the most recent AI tool calls, evaluates any run_cmd
        queries against the safety prompt, and constructs ToolMessages that either
        flag unsafe commands with reasons or confirm safe execution. If any command
        is unsafe, the generated ToolMessages are appended to the state so the agent
        can react without executing the command.

        Args:
            state (ExecutionState): Current execution state.

        Returns:
            ExecutionState: Either the unchanged state (all safe) or a copy with one
            or more ToolMessages appended when unsafe commands are detected.
        """
        # 1) Work on a shallow copy; inspect the most recent model message.
        new_state = state.copy()
        last_msg = new_state["messages"][-1]

        # 1.5) Check message history length and summarize to shorten the token usage:
        new_state = self._summarize_context(new_state)

        # 2) Evaluate any pending run_cmd tool calls for safety.
        tool_responses: list[ToolMessage] = []
        any_unsafe = False
        for tool_call in last_msg.tool_calls:
            if tool_call["name"] != "run_cmd":
                continue

            query = tool_call["args"]["query"]
            safety_result = self.llm.invoke(
                self.get_safety_prompt(
                    query, self.safe_codes, new_state.get("code_files", [])
                ),
                self.build_config(tags=["safety_check"]),
            )

            if "[NO]" in safety_result.content:
                any_unsafe = True
                tool_response = (
                    "[UNSAFE] That command `{q}` was deemed unsafe and cannot be run.\n"
                    "For reason: {r}"
                ).format(q=query, r=safety_result.content)
                console.print(
                    "[bold red][WARNING][/bold red] Command deemed unsafe:",
                    query,
                )
                # Also surface the model's rationale for transparency.
                console.print(
                    "[bold red][WARNING][/bold red] REASON:", tool_response
                )
            else:
                tool_response = f"Command `{query}` passed safety check."
                console.print(
                    f"[green]Command passed safety check:[/green] {query}"
                )

            tool_responses.append(
                ToolMessage(
                    content=tool_response,
                    tool_call_id=tool_call["id"],
                )
            )

        # 3) If any command is unsafe, append all tool responses; otherwise keep state.
        if any_unsafe:
            new_state["messages"].extend(tool_responses)

        return new_state

    def _build_graph(self):
        """Construct and compile the agent's LangGraph state machine."""
        # Create a graph over the agent's execution state.
        graph = StateGraph(ExecutionState)

        # Register nodes:
        # - "agent": LLM planning/execution step
        # - "action": tool dispatch (run_cmd, write_code, etc.)
        # - "summarize": summary/finalization step
        # - "safety_check": gate for shell command safety
        self.add_node(graph, self.query_executor, "agent")
        self.add_node(graph, self.tool_node, "action")
        self.add_node(graph, self.summarize, "summarize")
        self.add_node(graph, self.safety_check, "safety_check")

        # Set entrypoint: execution starts with the "agent" node.
        graph.set_entry_point("agent")

        # From "agent", either continue (tools) or finish (summarize),
        # based on presence of tool calls in the last message.
        graph.add_conditional_edges(
            "agent",
            self._wrap_cond(should_continue, "should_continue", "execution"),
            {"continue": "safety_check", "summarize": "summarize"},
        )

        # From "safety_check", route to tools if safe, otherwise back to agent
        # to revise the plan without executing unsafe commands.
        graph.add_conditional_edges(
            "safety_check",
            self._wrap_cond(command_safe, "command_safe", "execution"),
            {"safe": "action", "unsafe": "agent"},
        )

        # After tools run, return control to the agent for the next step.
        graph.add_edge("action", "agent")

        # The graph completes at the "summarize" node.
        graph.set_finish_point("summarize")

        # Compile and return the executable graph (optionally with a checkpointer).
        return graph.compile(checkpointer=self.checkpointer)

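    # Topology sketch (read off the edges above):
    #
    #     agent --(tool calls)--> safety_check --(safe)--> action --> agent
    #     agent --(no tool calls)--> summarize --> END
    #     safety_check --(unsafe)--> agent
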
    async def add_mcp_tool(
        self, mcp_tools: Callable[..., Any] | list[Callable[..., Any]]
    ) -> None:
        client = MultiServerMCPClient(mcp_tools)
        tools = await client.get_tools()
        self.add_tool(tools)

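    # Illustration (hypothetical server config, treat the exact schema as an
    # assumption): MultiServerMCPClient is typically constructed from a mapping
    # of server names to connection settings rather than bare callables, e.g.
    #
    #     await agent.add_mcp_tool({
    #         "math": {
    #             "command": "python",
    #             "args": ["math_server.py"],
    #             "transport": "stdio",
    #         }
    #     })
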
    def add_tool(
        self, new_tools: Callable[..., Any] | list[Callable[..., Any]]
    ) -> None:
        if isinstance(new_tools, list):
            self.tools.extend([convert_to_tool(x) for x in new_tools])
        elif isinstance(new_tools, StructuredTool) or isinstance(
            new_tools, Callable
        ):
            self.tools.append(convert_to_tool(new_tools))
        else:
            raise TypeError("Expected a callable or a list of callables.")
        self.tool_node = ToolNode(self.tools)
        self.llm = self.llm.bind_tools(self.tools)
        self._action = self._build_graph()

    def list_tools(self) -> None:
        print(
            f"Available tool names are: {', '.join([x.name for x in self.tools])}."
        )

    def remove_tool(self, cut_tools: str | list[str]) -> None:
        if isinstance(cut_tools, str):
            self.remove_tool([cut_tools])
        elif isinstance(cut_tools, list):
            self.tools = [x for x in self.tools if x.name not in cut_tools]
            self.tool_node = ToolNode(self.tools)
            self.llm = self.llm.bind_tools(self.tools)
            self._action = self._build_graph()
        else:
            raise TypeError(
                "Expected a string or a list of strings describing the tools to remove."
            )

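    # Usage sketch (the tool name shown is DuckDuckGoSearchResults' default and
    # should be treated as an assumption): each mutation rebinds the LLM and
    # recompiles the graph.
    #
    #     agent.remove_tool("duckduckgo_results_json")
    #     agent.list_tools()
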
    def _invoke(
        self, inputs: Mapping[str, Any], recursion_limit: int = 999_999, **_
    ):
        """Invoke the compiled graph with inputs under a specified recursion limit.

        This method builds a LangGraph config with the provided recursion limit
        and a "graph" tag, then delegates to the compiled graph's invoke method.
        """
        # Build invocation config with a generous recursion limit for long runs.
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )

        # Delegate execution to the compiled graph.
        return self._action.invoke(inputs, config)

    def _ainvoke(
        self, inputs: Mapping[str, Any], recursion_limit: int = 999_999, **_
    ):
        """Asynchronously invoke the compiled graph under a specified recursion limit.

        This method builds a LangGraph config with the provided recursion limit
        and a "graph" tag, then delegates to the compiled graph's ainvoke method.
        """
        # Build invocation config with a generous recursion limit for long runs.
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )

        # Delegate execution to the compiled graph.
        return self._action.ainvoke(inputs, config)

    # This property is trying to stop people bypassing invoke
    @property
    def action(self):
        """Property used to affirm `action` attribute is unsupported."""
        raise AttributeError(
            "Use .stream(...) or .invoke(...); direct .action access is unsupported."
        )