ursa-ai 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ursa/__init__.py +3 -0
- ursa/agents/__init__.py +32 -0
- ursa/agents/acquisition_agents.py +812 -0
- ursa/agents/arxiv_agent.py +429 -0
- ursa/agents/base.py +728 -0
- ursa/agents/chat_agent.py +60 -0
- ursa/agents/code_review_agent.py +341 -0
- ursa/agents/execution_agent.py +915 -0
- ursa/agents/hypothesizer_agent.py +614 -0
- ursa/agents/lammps_agent.py +465 -0
- ursa/agents/mp_agent.py +204 -0
- ursa/agents/optimization_agent.py +410 -0
- ursa/agents/planning_agent.py +219 -0
- ursa/agents/rag_agent.py +304 -0
- ursa/agents/recall_agent.py +54 -0
- ursa/agents/websearch_agent.py +196 -0
- ursa/cli/__init__.py +363 -0
- ursa/cli/hitl.py +516 -0
- ursa/cli/hitl_api.py +75 -0
- ursa/observability/metrics_charts.py +1279 -0
- ursa/observability/metrics_io.py +11 -0
- ursa/observability/metrics_session.py +750 -0
- ursa/observability/pricing.json +97 -0
- ursa/observability/pricing.py +321 -0
- ursa/observability/timing.py +1466 -0
- ursa/prompt_library/__init__.py +0 -0
- ursa/prompt_library/code_review_prompts.py +51 -0
- ursa/prompt_library/execution_prompts.py +50 -0
- ursa/prompt_library/hypothesizer_prompts.py +17 -0
- ursa/prompt_library/literature_prompts.py +11 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/prompt_library/planning_prompts.py +79 -0
- ursa/prompt_library/websearch_prompts.py +131 -0
- ursa/tools/__init__.py +0 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/tools/run_command.py +27 -0
- ursa/tools/write_code.py +42 -0
- ursa/util/__init__.py +0 -0
- ursa/util/diff_renderer.py +128 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/logo_generator.py +625 -0
- ursa/util/memory_logger.py +183 -0
- ursa/util/optimization_schema.py +78 -0
- ursa/util/parse.py +405 -0
- ursa_ai-0.9.1.dist-info/METADATA +304 -0
- ursa_ai-0.9.1.dist-info/RECORD +51 -0
- ursa_ai-0.9.1.dist-info/WHEEL +5 -0
- ursa_ai-0.9.1.dist-info/entry_points.txt +2 -0
- ursa_ai-0.9.1.dist-info/licenses/LICENSE +8 -0
- ursa_ai-0.9.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
from typing import Annotated, Any, Mapping
|
|
2
|
+
|
|
3
|
+
from langchain.chat_models import BaseChatModel
|
|
4
|
+
from langchain_openai import ChatOpenAI
|
|
5
|
+
from langgraph.graph import StateGraph
|
|
6
|
+
from langgraph.graph.message import add_messages
|
|
7
|
+
from typing_extensions import TypedDict
|
|
8
|
+
|
|
9
|
+
from .base import BaseAgent
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ChatState(TypedDict):
    """State carried through the single-node chat graph."""

    # Conversation history; the ``add_messages`` reducer merges node output
    # into the existing list instead of overwriting it.
    messages: Annotated[list, add_messages]
    # Identifier of the conversation thread (declared here; not read by any
    # code visible in this module).
    thread_id: str
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ChatAgent(BaseAgent):
    """Minimal conversational agent: one LLM call per graph invocation.

    The compiled graph contains a single node, ``_response_node``, which is
    both the entry point and the finish point.
    """

    def __init__(
        self,
        llm: BaseChatModel,
        **kwargs,
    ):
        """Initialise the base agent, then compile the one-node graph.

        Args:
            llm: Chat model used to answer.
            **kwargs: Forwarded unchanged to ``BaseAgent.__init__``.
        """
        super().__init__(llm, **kwargs)
        self._build_graph()

    def _response_node(self, state: ChatState) -> ChatState:
        """Send the conversation so far to the LLM and return its reply.

        Only the new reply is returned; the ``add_messages`` reducer on
        ``ChatState.messages`` appends it to the stored history.
        """
        history = state["messages"]
        llm_config = {"configurable": {"thread_id": self.thread_id}}
        reply = self.llm.invoke(history, llm_config)
        return {"messages": [reply]}

    def _build_graph(self):
        """Assemble and compile the graph, storing it on ``self._action``."""
        builder = StateGraph(ChatState)
        # ``add_node`` comes from BaseAgent; it presumably registers the node
        # under the method's name, which the entry/finish points below use.
        self.add_node(builder, self._response_node)
        node_name = "_response_node"
        builder.set_entry_point(node_name)
        builder.set_finish_point(node_name)
        self._action = builder.compile(checkpointer=self.checkpointer)

    def _invoke(
        self, inputs: Mapping[str, Any], recursion_limit: int = 1000, **_
    ):
        """Run the compiled graph on ``inputs``.

        Args:
            inputs: Initial graph state.
            recursion_limit: Passed to ``build_config``.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            Whatever the compiled graph's ``invoke`` returns.
        """
        run_config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )
        return self._action.invoke(inputs, run_config)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def main():
    """Smoke-test the ChatAgent with a fixed prompt and return the reply text."""
    chat_model = ChatOpenAI(
        model="gpt-5-mini", max_tokens=10000, timeout=None, max_retries=2
    )
    chat_agent = ChatAgent(llm=chat_model)
    problem_string = "What is your name?"
    print("Prompt: ", problem_string)
    # ``invoke`` is provided by BaseAgent; the reply is the last message of
    # the returned state.
    final_state = chat_agent.invoke(problem_string)
    return final_state["messages"][-1].content


if __name__ == "__main__":
    print("Response: ", main())
|
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import subprocess
|
|
3
|
+
from typing import Annotated, Literal, TypedDict
|
|
4
|
+
|
|
5
|
+
from langchain.chat_models import BaseChatModel
|
|
6
|
+
from langchain_core.messages import HumanMessage, SystemMessage
|
|
7
|
+
from langchain_core.tools import tool
|
|
8
|
+
from langgraph.graph import END, START, StateGraph
|
|
9
|
+
from langgraph.graph.message import add_messages
|
|
10
|
+
from langgraph.prebuilt import InjectedState, ToolNode
|
|
11
|
+
|
|
12
|
+
from ..prompt_library.code_review_prompts import (
|
|
13
|
+
get_code_review_prompt,
|
|
14
|
+
get_plan_review_prompt,
|
|
15
|
+
)
|
|
16
|
+
from ..prompt_library.execution_prompts import summarize_prompt
|
|
17
|
+
|
|
18
|
+
# from langchain_core.runnables.graph import MermaidDrawMethod
|
|
19
|
+
from .base import BaseAgent
|
|
20
|
+
|
|
21
|
+
# --- ANSI escape sequences used to colour console output ---
GREEN, BLUE, RED = "\033[92m", "\033[94m", "\033[91m"
RESET, BOLD = "\033[0m", "\033[1m"

# File-name extensions treated as source code when a workspace is scanned.
code_extensions = [
    # Python, R, Julia
    ".py", ".R", ".jl",
    # C / C++
    ".c", ".cpp", ".cc", ".cxx", ".c++", ".C",
    # Fortran
    ".f90", ".f95", ".f03",
]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class CodeReviewState(TypedDict):
    """State passed between the nodes of the code-review graph."""

    # Conversation history, merged through the ``add_messages`` reducer.
    messages: Annotated[list, add_messages]
    # Description of the project, interpolated into the review prompts.
    project_prompt: str
    # Names of the files to review.
    code_files: list[str]
    # Names of files recorded as written during the review.
    edited_files: list[str]
    # Directory the tools run in and resolve file names against.
    workspace: str
    # Index into ``code_files`` of the file under review; set to -1 once
    # every file has been visited.
    iteration: int
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class CodeReviewAgent(BaseAgent):
    """Work-in-progress agent that reviews the code files of a workspace.

    Graph layout (built in ``_initialize_agent``)::

        START        -> plan_review -> file_review
        file_review  -> safety_check | increment | summarize
        safety_check -> action | file_review
        action       -> file_review
        increment    -> file_review
        summarize    -> END

    The LLM is bound to the ``run_cmd``, ``write_file`` and ``read_file``
    tools defined in this module.
    """

    def __init__(
        self,
        llm: BaseChatModel,
        **kwargs,
    ):
        """Bind the tools to the LLM and compile the review graph.

        Args:
            llm: Chat model driving the review.
            **kwargs: Forwarded unchanged to ``BaseAgent.__init__``.
        """
        super().__init__(llm, **kwargs)
        print("### WORK IN PROGRESS ###")
        print(
            "CODE REVIEW AGENT NOT YET FULLY IMPLEMENTED AND TESTED. BE AWARE THAT IT WILL LIKELY NOT WORK AS INTENDED YET."
        )
        self.summarize_prompt = summarize_prompt
        self.tools = [run_cmd, write_file, read_file]
        self.tool_node = ToolNode(self.tools)
        self.llm = self.llm.bind_tools(self.tools)

        self._initialize_agent()

    def plan_review(self, state: CodeReviewState) -> CodeReviewState:
        """Ask the LLM for a review plan covering every file in the state.

        Returns:
            A state update holding only the LLM response; the system prompt
            built here is not stored in the state.
        """
        assert "workspace" in state.keys(), "No workspace set for review!"

        plan_review_prompt = get_plan_review_prompt(
            project_prompt=state["project_prompt"],
            file_list=state["code_files"],
        )
        messages = [SystemMessage(content=plan_review_prompt)] + state[
            "messages"
        ]
        response = self.llm.invoke(
            messages,
            {"configurable": {"thread_id": self.thread_id}},
        )
        return {"messages": [response]}

    def file_review(self, state: CodeReviewState) -> CodeReviewState:
        """Ask the LLM to review the file selected by ``state["iteration"]``.

        When the index does not point at a file (the -1 "finished" sentinel
        written by ``increment``, or an empty file list) no LLM call is made.
        A plain AI message without tool calls is returned instead, so the
        router can move on to ``increment``/``summarize``. Previously the
        sentinel indexed ``code_files[-1]`` and the last file was reviewed a
        second time, and an empty list raised ``IndexError``.
        """
        from langchain_core.messages import AIMessage

        code_files = state["code_files"]
        index = state["iteration"]
        if not 0 <= index < len(code_files):
            return {
                "messages": [AIMessage(content="No further files to review.")]
            }

        new_state = state.copy()
        code_review_prompt = get_code_review_prompt(
            project_prompt=state["project_prompt"],
            file_list=code_files,
        )
        filename = code_files[index]
        # NOTE: ``state.copy()`` is shallow, so the two edits below touch the
        # same list object that ``state["messages"]`` refers to. Slot 0 is
        # replaced by the review system prompt.
        new_state["messages"][0] = SystemMessage(content=code_review_prompt)
        new_state["messages"].append(
            HumanMessage(content=f"Please review {filename}")
        )
        response = self.llm.invoke(
            new_state["messages"],
            {"configurable": {"thread_id": self.thread_id}},
        )
        return {"messages": [response]}

    def summarize(self, state: CodeReviewState) -> CodeReviewState:
        """Ask the LLM to summarise the whole review conversation.

        Uses ``self.summarize_prompt`` (set in ``__init__``) so the stored
        prompt is the one actually sent.
        """
        messages = [SystemMessage(content=self.summarize_prompt)] + state[
            "messages"
        ]
        response = self.llm.invoke(
            messages, {"configurable": {"thread_id": self.thread_id}}
        )
        return {"messages": [response.content]}

    def increment(self, state: CodeReviewState) -> CodeReviewState:
        """Advance to the next file, or set ``iteration`` to -1 when done."""
        new_state = state.copy()
        total = len(new_state["code_files"])
        new_state["iteration"] += 1
        if new_state["iteration"] >= total:
            # -1 is the "all files visited" sentinel read by should_continue.
            new_state["iteration"] = -1
            print(f"All {total} files reviewed")
        else:
            print(f"On to file {new_state['iteration'] + 1} out of {total}")
        return new_state

    def safety_check(self, state: CodeReviewState) -> CodeReviewState:
        """Vet the pending tool calls before they are executed.

        Every tool call on the last message is inspected, not just the
        first, because the tool node executes all of them:

        * ``run_cmd``: the LLM is asked whether the command is safe. On a
          ``[NO]`` verdict nothing is executed and each pending call gets a
          ``ToolMessage`` starting with ``[UNSAFE]``, so the assistant's tool
          calls are answered and ``command_safe`` routes back to
          ``file_review``.
        * ``write_file`` (``write_code`` is still accepted as the name the
          original check used): the target file is recorded in
          ``edited_files`` and added to ``code_files`` if new.

        Returns:
            The refusal messages when a command is unsafe, otherwise a copy
            of the state with the file bookkeeping applied.
        """
        from langchain_core.messages import ToolMessage

        new_state = state.copy()
        tool_calls = state["messages"][-1].tool_calls

        for call in tool_calls:
            tool_name = call["name"]
            if tool_name == "run_cmd":
                query = call["args"]["query"]
                verdict = self.llm.invoke(
                    (
                        "Assume commands to run python and Julia are safe because "
                        "the files are from a trusted source. "
                        "Answer only either [YES] or [NO]. Is this command safe to run: "
                    )
                    + query,
                    {"configurable": {"thread_id": self.thread_id}},
                )
                if "[NO]" in verdict.content:
                    print(f"{RED}{BOLD} [WARNING] {RESET}")
                    print(
                        f"{RED}{BOLD} [WARNING] That command deemed unsafe and cannot be run: {RESET}",
                        query,
                        " --- ",
                        verdict,
                    )
                    print(f"{RED}{BOLD} [WARNING] {RESET}")
                    refusal = (
                        "[UNSAFE] That command deemed unsafe and cannot be run: "
                        + query
                    )
                    # Answer every pending call: none of them will run.
                    return {
                        "messages": [
                            ToolMessage(
                                content=refusal, tool_call_id=pending["id"]
                            )
                            for pending in tool_calls
                        ]
                    }

                print(f"{GREEN}[PASSED] the safety check: {RESET}" + query)
            elif tool_name in ("write_file", "write_code"):
                fn = call["args"].get("filename", None)
                if fn is None:
                    # Nothing to record for a call without a filename.
                    continue
                # Build new lists rather than appending to the ones shared
                # with the incoming state.
                if "code_files" in new_state:
                    if fn not in new_state["code_files"]:
                        new_state["code_files"] = [
                            *new_state["code_files"],
                            fn,
                        ]
                    new_state["edited_files"] = [
                        *new_state.get("edited_files", []),
                        fn,
                    ]
                else:
                    new_state["code_files"] = [fn]

        return new_state

    def _initialize_agent(self):
        """Build the review graph and compile it into ``self.action``."""
        self.graph = StateGraph(CodeReviewState)

        self.graph.add_node("plan_review", self.plan_review)
        self.graph.add_node("file_review", self.file_review)
        self.graph.add_node("increment", self.increment)
        self.graph.add_node("action", self.tool_node)
        self.graph.add_node("summarize", self.summarize)
        self.graph.add_node("safety_check", self.safety_check)

        # ``plan_review`` is the first node called.
        self.graph.add_edge(START, "plan_review")

        # A requested tool call ("action") is vetted by safety_check first.
        self.graph.add_conditional_edges(
            "file_review",
            should_continue,
            {
                "action": "safety_check",
                "increment": "increment",
                "summarize": "summarize",
            },
        )

        self.graph.add_conditional_edges(
            "safety_check",
            command_safe,
            {
                "safe": "action",
                "unsafe": "file_review",
            },
        )

        self.graph.add_edge("plan_review", "file_review")
        self.graph.add_edge("action", "file_review")
        self.graph.add_edge("increment", "file_review")
        self.graph.add_edge("summarize", END)

        self.action = self.graph.compile(checkpointer=self.checkpointer)

    def run(self, prompt, workspace):
        """Review every code file found directly inside ``workspace``.

        A file counts as code when its name ends with one of
        ``code_extensions``. The original substring test also matched names
        such as ``data.csv`` (contains ``.c``) or ``module.pyc``.

        Args:
            prompt: Project description handed to the review prompts.
            workspace: Directory to scan; also where the tools operate.

        Returns:
            The final graph state returned by the compiled graph.
        """
        suffixes = tuple(code_extensions)
        code_files = [
            name for name in os.listdir(workspace) if name.endswith(suffixes)
        ]
        initial_state = {
            "messages": [],
            "project_prompt": prompt,
            "code_files": code_files,
            "edited_files": [],
            "iteration": 0,
            "workspace": workspace,
        }
        return self.action.invoke(
            initial_state, {"configurable": {"thread_id": self.thread_id}}
        )
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
@tool
def run_cmd(query: str, state: Annotated[dict, InjectedState]) -> str:
    """Run command from commandline"""
    # Local import keeps this block self-contained.
    import shlex

    workspace_dir = state["workspace"]

    print("RUNNING: ", query)
    try:
        # shlex honours quoting and collapses repeated whitespace, unlike
        # ``query.split(" ")``.
        argv = shlex.split(query)
        if not argv:
            return "STDOUT:  and STDERR: No command was given."
        # No shell is involved; subprocess.run kills the child when the
        # timeout expires.
        completed = subprocess.run(
            argv,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            cwd=workspace_dir,
            timeout=600,
        )
    except subprocess.TimeoutExpired:
        print("STDERR: ", "command timed out")
        return "STDOUT:  and STDERR: Command timed out after 600 seconds."
    except (OSError, ValueError) as err:
        # OSError: executable missing or not runnable; ValueError: the
        # command line could not be parsed (e.g. unbalanced quotes).
        print("STDERR: ", err)
        return f"STDOUT:  and STDERR: {err}"

    stdout, stderr = completed.stdout, completed.stderr

    print("STDOUT: ", stdout)
    print("STDERR: ", stderr)

    return f"STDOUT: {stdout} and STDERR: {stderr}"
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
@tool
def read_file(filename: str, state: Annotated[dict, InjectedState]):
    """
    Reads in a file with a given filename into a string

    Args:
        filename: string filename to read in
    """
    # The file name is resolved against the workspace stored in the state.
    target_path = os.path.join(state["workspace"], filename)

    print("[READING]: ", target_path)
    with open(target_path, "r") as handle:
        return handle.read()
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
@tool
def write_file(
    code: str, filename: str, state: Annotated[dict, InjectedState]
) -> str:
    """
    Writes text to a file in the given workspace as requested.

    Args:
        code: Text to write to a file
        filename: the filename to write to

    Returns:
        Execution results
    """
    workspace_dir = state["workspace"]

    print("[WRITING]: ", filename)

    # Extract code if wrapped in markdown code blocks
    if "```" in code:
        code_parts = code.split("```")
        if len(code_parts) >= 3:
            fenced = code_parts[1]
            if "\n" in fenced:
                # Everything up to the first newline is the fence's info
                # string (e.g. "python") and may be empty. Cutting at that
                # newline keeps the first real code line; the original
                # stripped first and then always dropped one line, losing
                # code whenever the fence had no language tag.
                _info, _, body = fenced.partition("\n")
                code = body.rstrip()
            else:
                code = code_parts[2].strip()

    code_file = os.path.join(workspace_dir, filename)
    try:
        # Create missing parent directories for names like "src/x.py".
        parent_dir = os.path.dirname(code_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with open(code_file, "w") as f:
            f.write(code)
    except (OSError, UnicodeError) as e:
        # Report the failure to the caller instead of raising.
        print(f"Error writing file: {str(e)}")
        return f"Failed to write {filename} successfully."

    print(f"Written code to file: {code_file}")
    return f"File {filename} written successfully."
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
# Router evaluated after ``file_review``: picks the next node of the graph.
def should_continue(
    state: CodeReviewState,
) -> Literal["summarize", "increment", "action"]:
    """Choose where to go after the LLM has responded.

    Returns:
        ``"action"`` when the last message requests a tool call;
        otherwise ``"summarize"`` if ``iteration`` is -1 (all files visited)
        and ``"increment"`` in every other case.
    """
    latest = state["messages"][-1]
    if latest.tool_calls:
        return "action"
    if state["iteration"] == -1:
        return "summarize"
    return "increment"
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
# Router evaluated after ``safety_check``: was the pending command refused?
def command_safe(state: CodeReviewState) -> Literal["safe", "unsafe"]:
    """Return ``"unsafe"`` when the last message carries the ``[UNSAFE]`` tag.

    Any other last message yields ``"safe"``.
    """
    newest = state["messages"][-1]
    return "unsafe" if "[UNSAFE]" in newest.content else "safe"
|