ursa-ai 0.0.3__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,497 @@
+ import os
+
+ # from langchain_core.runnables.graph import MermaidDrawMethod
+ import subprocess
+ from pathlib import Path
+ from typing import Annotated, Any, Literal, Optional
+
+ import coolname
+ from langchain_community.tools import DuckDuckGoSearchResults  # TavilySearchResults,
+ from langchain_core.language_models import BaseChatModel
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
+ from langchain_core.tools import InjectedToolCallId, tool
+ from langgraph.graph import END, START, StateGraph
+ from langgraph.graph.message import add_messages
+ from langgraph.prebuilt import InjectedState, ToolNode
+ from langgraph.types import Command
+ from litellm import ContentPolicyViolationError
+
+ # Rich
+ from rich import get_console
+ from rich.panel import Panel
+ from rich.syntax import Syntax
+ from typing_extensions import TypedDict
+
+ from ..prompt_library.execution_prompts import executor_prompt, summarize_prompt
+ from ..util.diff_renderer import DiffRenderer
+ from ..util.memory_logger import AgentMemory
+ from .base import BaseAgent
+
+ console = get_console()  # always returns the same instance
+
+ # --- ANSI color codes ---
+ GREEN = "\033[92m"
+ BLUE = "\033[94m"
+ RED = "\033[91m"
+ RESET = "\033[0m"
+ BOLD = "\033[1m"
+
+
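+ # Shared graph state. `workspace` is created on the first agent turn if absent;
+ # `symlinkdir` is an optional {"source": ..., "dest": ...} mapping that gains an
+ # "is_linked" flag once the symlink has been created.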
+ class ExecutionState(TypedDict):
+     messages: Annotated[list, add_messages]
+     current_progress: str
+     code_files: list[str]
+     workspace: str
+     symlinkdir: dict
+
+
+ class ExecutionAgent(BaseAgent):
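+     """Agent that writes, safety-checks, executes, and summarizes code inside a workspace."""
+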
+     def __init__(
+         self,
+         llm: str | BaseChatModel = "openai/gpt-4o-mini",
+         agent_memory: Optional[Any | AgentMemory] = None,
+         log_state: bool = False,
+         **kwargs,
+     ):
+         super().__init__(llm, **kwargs)
+         self.agent_memory = agent_memory
+         self.executor_prompt = executor_prompt
+         self.summarize_prompt = summarize_prompt
+         self.tools = [run_cmd, write_code, edit_code, search_tool]
+         self.tool_node = ToolNode(self.tools)
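+         # bind_tools advertises each tool's schema to the model so it can emit tool calls.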
+         self.llm = self.llm.bind_tools(self.tools)
+         self.log_state = log_state
+
+         self._initialize_agent()
+
+     # Node: call the model with the executor system prompt
+     def query_executor(self, state: ExecutionState) -> ExecutionState:
+         new_state = state.copy()
+         if "workspace" not in new_state.keys():
+             new_state["workspace"] = coolname.generate_slug(2)
+             print(
+                 f"{RED}Creating the folder {BLUE}{BOLD}{new_state['workspace']}{RESET}{RED} for this project.{RESET}"
+             )
+             os.makedirs(new_state["workspace"], exist_ok=True)
+
+         # code related to symlink
+         if (
+             "symlinkdir" in new_state.keys()
+             and "is_linked" not in new_state["symlinkdir"].keys()
+         ):
+             # symlinkdir = {"source": "foo", "dest": "bar"}
+             symlinkdir = new_state["symlinkdir"]
+             # user provided a symlinkdir key - let's do the linking!
+
+             src = Path(symlinkdir["source"]).expanduser().resolve()
+             workspace_root = Path(new_state["workspace"]).expanduser().resolve()
+             dst = workspace_root / symlinkdir["dest"]  # prepend workspace
+
+             # if you want to replace an existing link/file, unlink it first
+             if dst.exists() or dst.is_symlink():
+                 dst.unlink()
+
+             # create parent dirs for the link location if they don't exist
+             dst.parent.mkdir(parents=True, exist_ok=True)
+
+             # actually make the link (tell pathlib it's a directory target)
+             dst.symlink_to(src, target_is_directory=src.is_dir())
+             print(f"{RED}Symlinked {src} (source) --> {dst} (dest){RESET}")
+             # note that we've done the symlink now, so we don't need to do it later
+             new_state["symlinkdir"]["is_linked"] = True
+
+         if isinstance(new_state["messages"][0], SystemMessage):
+             new_state["messages"][0] = SystemMessage(content=self.executor_prompt)
+         else:
+             new_state["messages"] = [
+                 SystemMessage(content=self.executor_prompt)
+             ] + state["messages"]
+         try:
+             response = self.llm.invoke(
+                 new_state["messages"], {"configurable": {"thread_id": self.thread_id}}
+             )
+         except ContentPolicyViolationError as e:
+             print("Error: ", e, " ", new_state["messages"][-1].content)
+             # Fall back to a plain AIMessage so `response` is always bound below.
+             response = AIMessage(content=f"Content policy violation: {e}")
+         if self.log_state:
+             self.write_state("execution_agent.json", new_state)
+         return {"messages": [response], "workspace": new_state["workspace"]}
+
+     # Node: summarize the conversation and persist memories
+     def summarize(self, state: ExecutionState) -> ExecutionState:
+         messages = [SystemMessage(content=summarize_prompt)] + state["messages"]
+         try:
+             response = self.llm.invoke(
+                 messages, {"configurable": {"thread_id": self.thread_id}}
+             )
+         except ContentPolicyViolationError as e:
+             print("Error: ", e, " ", messages[-1].content)
+             # Fall back to a plain AIMessage so `response` is always bound below.
+             response = AIMessage(content=f"Content policy violation: {e}")
+         if self.agent_memory:
+             memories = []
+             # Serialize each message to text; tool calls become name/arg/value strings
+             for x in state["messages"]:
+                 if not isinstance(x, AIMessage):
+                     memories.append(x.content)
+                 elif not x.tool_calls:
+                     memories.append(x.content)
+                 else:
+                     tool_strings = []
+                     for tool_call in x.tool_calls:
+                         tool_name = "Tool Name: " + tool_call["name"]
+                         tool_strings.append(tool_name)
+                         for y in tool_call["args"]:
+                             tool_strings.append(
+                                 f'Arg: {str(y)}\nValue: {str(tool_call["args"][y])}'
+                             )
+                     memories.append("\n".join(tool_strings))
+             memories.append(response.content)
+             self.agent_memory.add_memories(memories)
+         save_state = state.copy()
+         save_state["messages"].append(response)
+         if self.log_state:
+             self.write_state("execution_agent.json", save_state)
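+         # The add_messages reducer coerces this bare string into a message object.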
+         return {"messages": [response.content]}
+
+     # Node: screen pending run_cmd tool calls before they are executed
+     def safety_check(self, state: ExecutionState) -> ExecutionState:
+         """
+         Validate the safety of a pending shell command.
+
+         Args:
+             state: Current execution state.
+
+         Returns:
+             Either the unchanged state (safe) or a state with tool message(s) (unsafe).
+         """
+         new_state = state.copy()
+         last_msg = new_state["messages"][-1]
+
+         tool_responses = []
+         tool_failed = False
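+         # Only run_cmd calls are screened here; write_code / edit_code calls
+         # pass straight through to the action node.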
+         for tool_call in last_msg.tool_calls:
+             call_name = tool_call["name"]
+
+             if call_name == "run_cmd":
+                 query = tool_call["args"]["query"]
+                 safety_check = self.llm.invoke(
+                     (
+                         "Assume commands to run/install python and Julia files are safe because "
+                         "the files are from a trusted source. "
+                         f"Explain why, followed by an answer [YES] or [NO]. Is this command safe to run: {query}"
+                     )
+                 )
+
+                 if "[NO]" in safety_check.content:
+                     tool_failed = True
+
+                     tool_response = f"""
+                     [UNSAFE] That command `{query}` was deemed unsafe and cannot be run.
+                     For reason: {safety_check.content}
+                     """
+                     console.print(
+                         "[bold red][WARNING][/bold red] Command deemed unsafe:", query
+                     )
+                     # and tell the user the reason
+                     console.print(
+                         "[bold red][WARNING][/bold red] REASON:", tool_response
+                     )
+
+                 else:
+                     tool_response = f"Command `{query}` passed safety check."
+                     console.print(
+                         f"[green]Command passed safety check:[/green] {query}"
+                     )
+
+                 tool_responses.append(
+                     ToolMessage(
+                         content=tool_response,
+                         tool_call_id=tool_call["id"],
+                     )
+                 )
+
+         if tool_failed:
+             new_state["messages"].extend(tool_responses)
+
+         return new_state
+
+     def _initialize_agent(self):
+         self.graph = StateGraph(ExecutionState)
+
+         self.graph.add_node("agent", self.query_executor)
+         self.graph.add_node("action", self.tool_node)
+         self.graph.add_node("summarize", self.summarize)
+         self.graph.add_node("safety_check", self.safety_check)
+
+         # Set the entrypoint as `agent`
+         # This means that this node is the first one called
+         self.graph.add_edge(START, "agent")
+
+         self.graph.add_conditional_edges(
+             "agent",
+             should_continue,
+             {
+                 "continue": "safety_check",
+                 "summarize": "summarize",
+             },
+         )
+
+         self.graph.add_conditional_edges(
+             "safety_check",
+             command_safe,
+             {
+                 "safe": "action",
+                 "unsafe": "agent",
+             },
+         )
+
+         self.graph.add_edge("action", "agent")
+         self.graph.add_edge("summarize", END)
+
+         self.action = self.graph.compile(checkpointer=self.checkpointer)
+         # self.action.get_graph().draw_mermaid_png(output_file_path="execution_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)
+
+     def run(self, prompt, recursion_limit=1000):
+         inputs = {"messages": [HumanMessage(content=prompt)]}
+         return self.action.invoke(
+             inputs,
+             {
+                 "recursion_limit": recursion_limit,
+                 "configurable": {"thread_id": self.thread_id},
+             },
+         )
+
+
+ @tool
+ def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
+     """
+     Run a command-line command using the subprocess package in Python.
+
+     Args:
+         query: command-line command, passed as a string to subprocess.run.
+     """
+     workspace_dir = state["workspace"]
+     print("RUNNING: ", query)
+     try:
+         result = subprocess.run(
+             query,
+             text=True,
+             shell=True,
+             timeout=60000,
+             capture_output=True,
+             cwd=workspace_dir,
+         )
+         stdout, stderr = result.stdout, result.stderr
+     except KeyboardInterrupt:
+         print("Keyboard Interrupt of command: ", query)
+         stdout, stderr = "", "KeyboardInterrupt:"
+
+     print("STDOUT: ", stdout)
+     print("STDERR: ", stderr)
+
+     return f"STDOUT: {stdout} and STDERR: {stderr}"
+
+
+ def _strip_fences(snippet: str) -> str:
+     """
+     Extract the body of the first markdown ``` fence, dropping the language line;
+     return the snippet unchanged if it is not fenced.
+     """
+     if "```" not in snippet:
+         return snippet
+
+     parts = snippet.split("```")
+     if len(parts) < 3:
+         return snippet
+
+     body = parts[1]
+     return "\n".join(body.split("\n")[1:]) if "\n" in body else body.strip()
+
+
+ @tool
+ def write_code(
+     code: str,
+     filename: str,
+     tool_call_id: Annotated[str, InjectedToolCallId],
+     state: Annotated[dict, InjectedState],
+ ) -> Command:
+     """Write *code* to *filename*.
+
+     Args:
+         code: Source code as a string.
+         filename: Target filename (including extension).
+
+     Returns:
+         Success / failure message.
+     """
+     workspace_dir = state["workspace"]
+     console.print("[cyan]Writing file:[/]", filename)
+
+     # Clean up markdown fences
+     code = _strip_fences(code)
+
+     # Syntax-highlighted preview
+     try:
+         lexer_name = Syntax.guess_lexer(filename, code)
+     except Exception:
+         lexer_name = "text"
+
+     console.print(
+         Panel(
+             Syntax(code, lexer_name, line_numbers=True),
+             title="File Preview",
+             border_style="cyan",
+         )
+     )
+
+     code_file = os.path.join(workspace_dir, filename)
+     try:
+         with open(code_file, "w", encoding="utf-8") as f:
+             f.write(code)
+     except Exception as exc:
+         console.print(
+             "[bold bright_white on red] :heavy_multiplication_x: [/] "
+             "[red]Failed to write file:[/]",
+             exc,
+         )
+         return f"Failed to write {filename}."
+
+     console.print(
+         f"[bold bright_white on green] :heavy_check_mark: [/] "
+         f"[green]File written:[/] {code_file}"
+     )
+
+     # Append the file to the list in state
+     file_list = state.get("code_files", [])
+     file_list.append(filename)
+
+     # Create a tool message to send back
+     msg = ToolMessage(
+         content=f"File {filename} written successfully.",
+         tool_call_id=tool_call_id,
+     )
+
+     # Return updated code files list & the message
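+     # Command(update=...) merges these keys back into the graph state,
+     # so the tool both records the new file and answers the tool call.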
+     return Command(
+         update={
+             "code_files": file_list,
+             "messages": [msg],
+         }
+     )
+
+
+ @tool
+ def edit_code(
+     old_code: str,
+     new_code: str,
+     filename: str,
+     state: Annotated[dict, InjectedState],
+ ) -> str:
+     """Replace the **first** occurrence of *old_code* with *new_code* in *filename*.
+
+     Args:
+         old_code: Code fragment to search for.
+         new_code: Replacement fragment.
+         filename: Target file inside the workspace.
+
+     Returns:
+         Success / failure message.
+     """
+     workspace_dir = state["workspace"]
+     console.print("[cyan]Editing file:[/cyan]", filename)
+
+     code_file = os.path.join(workspace_dir, filename)
+     try:
+         with open(code_file, "r", encoding="utf-8") as f:
+             content = f.read()
+     except FileNotFoundError:
+         console.print(
+             "[bold bright_white on red] :heavy_multiplication_x: [/] "
+             "[red]File not found:[/]",
+             filename,
+         )
+         return f"Failed: {filename} not found."
+
+     # Clean up markdown fences
+     old_code_clean = _strip_fences(old_code)
+     new_code_clean = _strip_fences(new_code)
+
+     if old_code_clean not in content:
+         console.print("[yellow]⚠️ 'old_code' not found in file; no changes made.[/]")
+         return f"No changes made to {filename}: 'old_code' not found in file."
+
+     updated = content.replace(old_code_clean, new_code_clean, 1)
+
+     console.print(
+         Panel(
+             DiffRenderer(content, updated, filename),
+             title="Diff Preview",
+             border_style="cyan",
+         )
+     )
+
+     try:
+         with open(code_file, "w", encoding="utf-8") as f:
+             f.write(updated)
+     except Exception as exc:
+         console.print(
+             "[bold bright_white on red] :heavy_multiplication_x: [/] "
+             "[red]Failed to write file:[/]",
+             exc,
+         )
+         return f"Failed to edit {filename}."
+
+     console.print(
+         f"[bold bright_white on green] :heavy_check_mark: [/] "
+         f"[green]File updated:[/] {code_file}"
+     )
+     return f"File {filename} updated successfully."
+
+
+ search_tool = DuckDuckGoSearchResults(output_format="json", num_results=10)
+ # search_tool = TavilySearchResults(max_results=10, search_depth="advanced", include_answer=True)
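+ # Swapping in TavilySearchResults (commented out above) would also need its import
+ # restored and, typically, a TAVILY_API_KEY set in the environment.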
+
+
+ # Routing function: continue with tool execution or summarize and finish
+ def should_continue(state: ExecutionState) -> Literal["summarize", "continue"]:
+     messages = state["messages"]
+     last_message = messages[-1]
+     # If there is no tool call, then we finish
+     if not last_message.tool_calls:
+         return "summarize"
+     # Otherwise if there is, we continue
+     else:
+         return "continue"
+
+
+ # Routing function: route to "action" only if the pending commands passed the safety check
+ def command_safe(state: ExecutionState) -> Literal["safe", "unsafe"]:
+     """
+     Return graph edge "safe" if the last command was safe, otherwise return edge "unsafe".
+     """
+
+     index = -1
+     message = state["messages"][index]
+     # Walk the consecutive tool messages in reverse order; the scan stops at the
+     # first non-ToolMessage (e.g. the AIMessage that issued the tool calls).
+     while isinstance(message, ToolMessage):
+         if "[UNSAFE]" in message.content:
+             return "unsafe"
+
+         index -= 1
+         message = state["messages"][index]
+
+     return "safe"
+
+
+ def main():
+     execution_agent = ExecutionAgent()
+     problem_string = "Write and execute a python script to print the first 10 integers."
+     inputs = {
+         "messages": [HumanMessage(content=problem_string)]
+     }  # , "workspace":"dummy_test"}
+     result = execution_agent.action.invoke(
+         inputs, {"configurable": {"thread_id": execution_agent.thread_id}}
+     )
+     print(result["messages"][-1].content)
+     return result
+
+
+ if __name__ == "__main__":
+     main()