ursa-ai 0.0.3__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ursa-ai might be problematic. Click here for more details.

@@ -0,0 +1,332 @@
1
+ import os
2
+ import subprocess
3
+ from typing import Annotated, Literal
4
+
5
+ from langchain_core.language_models import BaseChatModel
6
+ from langchain_core.messages import HumanMessage, SystemMessage
7
+ from langchain_core.tools import tool
8
+ from langgraph.graph import END, START, StateGraph
9
+ from langgraph.graph.message import add_messages
10
+ from langgraph.prebuilt import ToolNode, InjectedState
11
+ from typing_extensions import TypedDict
12
+
13
+ from ..prompt_library.code_review_prompts import (
14
+ get_code_review_prompt,
15
+ get_plan_review_prompt,
16
+ )
17
+ from ..prompt_library.execution_prompts import summarize_prompt
18
+
19
+ # from langchain_core.runnables.graph import MermaidDrawMethod
20
+ from .base import BaseAgent
21
+
22
# --- ANSI color codes ---
# Escape sequences used to colour the console output printed by this module.
GREEN = "\033[92m"
BLUE = "\033[94m"
RED = "\033[91m"
RESET = "\033[0m"  # emitted after a coloured span in the messages below
BOLD = "\033[1m"

# File-name suffixes that mark a workspace entry as source code to review.
# Entries are case-sensitive (".c" and ".C" are listed separately).
code_extensions = [
    # scripting languages
    ".py",
    ".R",
    ".jl",
    # C / C++
    ".c",
    ".cpp",
    ".cc",
    ".cxx",
    ".c++",
    ".C",
    # Fortran
    ".f90",
    ".f95",
    ".f03",
]
32
+
33
+
34
class CodeReviewState(TypedDict):
    """State dictionary passed between the nodes of the code review graph."""

    # Conversation history; `add_messages` is attached as the merge function
    # for this key via the Annotated metadata.
    messages: Annotated[list, add_messages]

    # Description of the project the reviewed code is supposed to implement.
    project_prompt: str

    # Names of the files to review; indexed by `iteration`.
    code_files: list[str]

    # Names of files recorded as written during the review.
    edited_files: list[str]

    # Directory the tools use as working directory / base path.
    workspace: str

    # Index into `code_files` of the file under review; -1 is used as the
    # "wrapped past the last file" marker.
    iteration: int
41
+
42
+
43
class CodeReviewAgent(BaseAgent):
    """Work-in-progress agent that has an LLM plan and perform a file-by-file code review.

    The graph built in `_initialize_agent` is:

        START -> plan_review -> file_review
        file_review -> safety_check   (last message has tool calls)
        file_review -> increment      (no tool calls, iteration != -1)
        file_review -> summarize      (no tool calls, iteration == -1)
        safety_check -> action        (tool node) or back to file_review
        action -> file_review, increment -> file_review, summarize -> END

    `self.llm`, `self.thread_id` and `self.checkpointer` are read here but not
    assigned in this class; they are expected to come from `BaseAgent`
    (not visible in this file).
    """

    def __init__(
        self, llm: str | BaseChatModel = "openai/gpt-4o-mini", **kwargs
    ):
        """Bind the file/command tools to the model and compile the graph.

        Args:
            llm: Model identifier or chat model instance, forwarded to
                `BaseAgent.__init__`.
            **kwargs: Forwarded unchanged to `BaseAgent.__init__`.
        """
        super().__init__(llm, **kwargs)
        print("### WORK IN PROGRESS ###")
        print(
            "CODE REVIEW AGENT NOT YET FULLY IMPLEMENTED AND TESTED. BE AWARE THAT IT WILL LIKELY NOT WORK AS INTENDED YET."
        )
        self.summarize_prompt = summarize_prompt
        self.tools = [run_cmd, write_file, read_file]
        self.tool_node = ToolNode(self.tools)
        self.llm = self.llm.bind_tools(self.tools)

        self._initialize_agent()

    def _invoke_config(self) -> dict:
        """Return the config dict passed to every model / graph invocation."""
        return {"configurable": {"thread_id": self.thread_id}}

    def plan_review(self, state: CodeReviewState) -> CodeReviewState:
        """Ask the model for a review plan covering all files in the state.

        Args:
            state: Current graph state; must contain a "workspace" key.

        Returns:
            A state update holding only the model's response message.

        Raises:
            AssertionError: If "workspace" is missing from the state.
        """
        assert "workspace" in state, "No workspace set for review!"

        plan_review_prompt = get_plan_review_prompt(
            project_prompt=state["project_prompt"],
            file_list=state["code_files"],
        )
        # The system prompt is only used for this call; it is not returned,
        # so it does not become part of the stored message history.
        messages = [SystemMessage(content=plan_review_prompt)] + state[
            "messages"
        ]
        response = self.llm.invoke(messages, self._invoke_config())
        return {"messages": [response]}

    def file_review(self, state: CodeReviewState) -> CodeReviewState:
        """Ask the model to review the file selected by `state["iteration"]`.

        Args:
            state: Current graph state.

        Returns:
            A state update holding only the model's response message.
        """
        new_state = state.copy()
        code_review_prompt = get_code_review_prompt(
            project_prompt=state["project_prompt"],
            file_list=state["code_files"],
        )
        filename = state["code_files"][state["iteration"]]
        # NOTE(review): `state.copy()` is shallow, so the two statements
        # below overwrite the first entry of, and append to, the very list
        # object held in `state["messages"]`. Left as-is because later
        # behavior may depend on it — confirm this is intended.
        new_state["messages"][0] = SystemMessage(content=code_review_prompt)
        new_state["messages"].append(
            HumanMessage(content=f"Please review {filename}")
        )
        response = self.llm.invoke(
            new_state["messages"], self._invoke_config()
        )
        return {"messages": [response]}

    def summarize(self, state: CodeReviewState) -> CodeReviewState:
        """Ask the model to summarize the whole conversation.

        Uses `self.summarize_prompt` (set in `__init__`) so that an
        instance-level override of the prompt takes effect.

        Returns:
            A state update whose single message is the response *content*
            (not the message object).
        """
        messages = [SystemMessage(content=self.summarize_prompt)] + state[
            "messages"
        ]
        response = self.llm.invoke(messages, self._invoke_config())
        return {"messages": [response.content]}

    def increment(self, state: CodeReviewState) -> CodeReviewState:
        """Advance to the next file, wrapping the index to -1 past the end.

        Returns:
            A copy of the state with the updated "iteration" value.
        """
        new_state = state.copy()
        total_files = len(new_state["code_files"])
        new_state["iteration"] += 1
        if new_state["iteration"] >= total_files:
            # -1 is the marker `should_continue` checks to route to
            # "summarize" once the model stops issuing tool calls.
            new_state["iteration"] = -1
            print(f"Reached the end of the file list ({total_files} files)")
        else:
            print(
                f"On to file {new_state['iteration']+1} out of {total_files}"
            )
        return new_state

    def safety_check(self, state: CodeReviewState) -> CodeReviewState:
        """Vet the tool calls of the last message before they are executed.

        Every `run_cmd` call is sent to the model for a [YES]/[NO] verdict;
        the first [NO] aborts with an "[UNSAFE] ..." message, which
        `command_safe` detects. File-writing calls are recorded in
        "code_files" / "edited_files". All tool calls in the message are
        inspected, not just the first one.

        Returns:
            Either a state update carrying the "[UNSAFE] ..." message, or a
            copy of the state with the file bookkeeping applied.
        """
        new_state = state.copy()
        for tool_call in state["messages"][-1].tool_calls:
            tool_name = tool_call["name"]
            if tool_name == "run_cmd":
                query = tool_call["args"]["query"]
                # NOTE(review): a command counts as safe unless the reply
                # contains "[NO]"; an empty or off-format reply passes.
                verdict = self.llm.invoke(
                    (
                        "Assume commands to run python and Julia are safe because "
                        "the files are from a trusted source. "
                        "Answer only either [YES] or [NO]. Is this command safe to run: "
                    )
                    + query,
                    self._invoke_config(),
                )
                if "[NO]" in verdict.content:
                    print(f"{RED}{BOLD} [WARNING] {RESET}")
                    print(
                        f"{RED}{BOLD} [WARNING] That command deemed unsafe and cannot be run: {RESET}",
                        query,
                        " --- ",
                        verdict,
                    )
                    print(f"{RED}{BOLD} [WARNING] {RESET}")
                    return {
                        "messages": [
                            "[UNSAFE] That command deemed unsafe and cannot be run: "
                            + query
                        ]
                    }

                print(f"{GREEN}[PASSED] the safety check: {RESET}" + query)
            elif tool_name in ("write_file", "write_code"):
                # The tool registered in __init__ is `write_file`;
                # "write_code" is still accepted because that is the name
                # this check originally looked for.
                fn = tool_call["args"].get("filename", None)
                if fn is None:
                    continue
                # Work on fresh lists so the lists held by the incoming
                # state are not modified in place.
                code_files = list(new_state.get("code_files", []))
                edited_files = list(new_state.get("edited_files", []))
                if fn not in code_files:
                    code_files.append(fn)
                edited_files.append(fn)
                new_state["code_files"] = code_files
                new_state["edited_files"] = edited_files

        return new_state

    def _initialize_agent(self):
        """Build the review graph and compile it into `self.action`."""
        self.graph = StateGraph(CodeReviewState)

        self.graph.add_node("plan_review", self.plan_review)
        self.graph.add_node("file_review", self.file_review)
        self.graph.add_node("increment", self.increment)
        self.graph.add_node("action", self.tool_node)
        self.graph.add_node("summarize", self.summarize)
        self.graph.add_node("safety_check", self.safety_check)

        # `plan_review` is the entry point: the first node called.
        self.graph.add_edge(START, "plan_review")

        # `should_continue` answers "action" when tool calls are pending;
        # those go through the safety check before reaching the tool node.
        self.graph.add_conditional_edges(
            "file_review",
            should_continue,
            {
                "action": "safety_check",
                "increment": "increment",
                "summarize": "summarize",
            },
        )

        self.graph.add_conditional_edges(
            "safety_check",
            command_safe,
            {
                "safe": "action",
                "unsafe": "file_review",
            },
        )

        self.graph.add_edge("plan_review", "file_review")
        self.graph.add_edge("action", "file_review")
        self.graph.add_edge("increment", "file_review")
        self.graph.add_edge("summarize", END)

        self.action = self.graph.compile(checkpointer=self.checkpointer)
        # self.action.get_graph().draw_mermaid_png(output_file_path="code_review_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)

    def run(self, prompt, workspace):
        """Review every code file found directly inside `workspace`.

        Args:
            prompt: Project description handed to the review prompts.
            workspace: Directory to scan (not recursive) and to use as the
                tools' working directory.

        Returns:
            Whatever the compiled graph's `invoke` returns.
        """
        # Match on the file-name suffix: a substring test would also pick up
        # names such as "data.csv" (contains ".c") or "mod.pyc".
        suffixes = tuple(code_extensions)
        # Sorted so the review order does not depend on os.listdir ordering.
        code_files = sorted(
            name for name in os.listdir(workspace) if name.endswith(suffixes)
        )
        initial_state = {
            "messages": [],
            "project_prompt": prompt,
            "code_files": code_files,
            "edited_files": [],
            "iteration": 0,
            "workspace": workspace,
        }
        return self.action.invoke(initial_state, self._invoke_config())
206
+
207
+
208
@tool
def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
    """Run command from commandline"""
    # The docstring above is kept verbatim: the @tool decorator presumably
    # exposes it to the model as the tool description — confirm before
    # rewording it.
    #
    # `query` is the command line to execute (no shell is involved);
    # `state` is the injected graph state, of which only "workspace" is read
    # and used as the working directory. Returns one string containing the
    # captured stdout and stderr. Raises subprocess.TimeoutExpired after 600
    # seconds (the child is killed first) and OSError if the program cannot
    # be started.
    import shlex

    workspace_dir = state["workspace"]

    print("RUNNING: ", query)
    try:
        # Honour quoting ('python -c "print(1)"') and collapse repeated
        # spaces, both of which a plain split(" ") gets wrong.
        argv = shlex.split(query, posix=(os.name != "nt"))
    except ValueError:
        # Unbalanced quotes: fall back to plain space splitting.
        argv = query.split(" ")

    process = subprocess.Popen(
        argv,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        cwd=workspace_dir,
    )

    try:
        stdout, stderr = process.communicate(timeout=600)
    except subprocess.TimeoutExpired:
        # communicate() does not stop the child on timeout; kill it and
        # drain its pipes so no orphaned process is left, then re-raise.
        process.kill()
        process.communicate()
        raise

    print("STDOUT: ", stdout)
    print("STDERR: ", stderr)

    return f"STDOUT: {stdout} and STDERR: {stderr}"
228
+
229
+
230
@tool
def read_file(filename: str, state: Annotated[dict, InjectedState]):
    """
    Reads in a file with a given filename into a string

    Args:
        filename: string filename to read in
    """
    # Docstring kept verbatim: the @tool decorator presumably uses it as the
    # description shown to the model — confirm before rewording.
    #
    # The name is resolved against the workspace directory carried in the
    # injected graph state, and the file is opened in text mode.
    target_path = os.path.join(state["workspace"], filename)

    print("[READING]: ", target_path)
    with open(target_path, "r") as handle:
        return handle.read()
245
+
246
+
247
def _extract_fenced_code(code: str) -> str:
    """Return the body of the first Markdown ``` fence in *code*, else *code* unchanged."""
    if "```" not in code:
        return code
    parts = code.split("```")
    if len(parts) < 3:
        # No closing fence: leave the text untouched.
        return code
    fenced = parts[1]
    if "\n" in fenced:
        # Drop only the opener line (language tag, possibly empty) and keep
        # everything after it. Stripping before splitting would swallow the
        # empty opener line of an untagged fence and lose the first line of
        # real code.
        _opener, _, body = fenced.partition("\n")
        return body.rstrip()
    # Single-line first segment: use the text that follows the second fence.
    return parts[2].strip()


@tool
def write_file(code: str, filename: str, state: Annotated[dict, InjectedState]):
    """
    Writes text to a file in the given workspace as requested.

    Args:
        code: Text to write to a file
        filename: the filename to write to

    Returns:
        Execution results
    """
    # Docstring kept verbatim: the @tool decorator presumably uses it as the
    # description shown to the model — confirm before rewording.
    #
    # The target path is `filename` joined onto the workspace directory from
    # the injected graph state. Any failure is reported through the returned
    # string rather than raised.
    workspace_dir = state["workspace"]

    print("[WRITING]: ", filename)
    try:
        # Extract code if wrapped in markdown code blocks
        code = _extract_fenced_code(code)

        # Write code to a file
        code_file = os.path.join(workspace_dir, filename)

        with open(code_file, "w") as f:
            f.write(code)
        print(f"Written code to file: {code_file}")

        return f"File {filename} written successfully."

    except Exception as e:
        # Deliberate best-effort: tell the model the write failed instead of
        # aborting the tool call.
        print(f"Error generating code: {str(e)}")
        return f"Failed to write {filename} successfully."
286
+
287
+
288
# Routing function evaluated after `file_review`.
def should_continue(
    state: CodeReviewState,
) -> Literal["summarize", "increment", "action"]:
    """Choose the next step from the newest message and the file index.

    Returns "action" when the last message carries tool calls; otherwise
    "summarize" if `state["iteration"]` is -1, else "increment".
    """
    newest = state["messages"][-1]
    # Pending tool calls take priority over moving on.
    if newest.tool_calls:
        return "action"
    # No tool call: -1 marks that the file index has wrapped.
    return "summarize" if state["iteration"] == -1 else "increment"
304
+
305
+
306
# Routing function evaluated after `safety_check`.
def command_safe(state: CodeReviewState) -> Literal["safe", "unsafe"]:
    """Return "unsafe" if the newest message's content contains "[UNSAFE]", else "safe"."""
    newest = state["messages"][-1]
    return "unsafe" if "[UNSAFE]" in newest.content else "safe"
314
+
315
+
316
def main():
    """Run the review agent once on a small demo state and print the messages.

    Returns:
        The final state returned by the compiled graph.
    """
    code_review_agent = CodeReviewAgent(llm="openai/o3-mini")
    initial_state = {
        "messages": [],
        "project_prompt": "Find a city with as least 10 vowels in its name.",
        "code_files": ["vowel_count.py"],
        "edited_files": [],
        "iteration": 0,
        # `plan_review` asserts that this key exists and the tools resolve
        # paths against it. The current directory is assumed to hold
        # vowel_count.py — adjust if the demo file lives elsewhere.
        "workspace": ".",
    }
    # The config belongs inside the invoke() call, and the thread id lives on
    # the agent instance (there is no `self` at module level).
    result = code_review_agent.action.invoke(
        initial_state,
        {"configurable": {"thread_id": code_review_agent.thread_id}},
    )
    for x in result["messages"]:
        print(x.content)
    return result


if __name__ == "__main__":
    main()