ursa-ai 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ursa/__init__.py +3 -0
- ursa/agents/__init__.py +32 -0
- ursa/agents/acquisition_agents.py +812 -0
- ursa/agents/arxiv_agent.py +429 -0
- ursa/agents/base.py +728 -0
- ursa/agents/chat_agent.py +60 -0
- ursa/agents/code_review_agent.py +341 -0
- ursa/agents/execution_agent.py +915 -0
- ursa/agents/hypothesizer_agent.py +614 -0
- ursa/agents/lammps_agent.py +465 -0
- ursa/agents/mp_agent.py +204 -0
- ursa/agents/optimization_agent.py +410 -0
- ursa/agents/planning_agent.py +219 -0
- ursa/agents/rag_agent.py +304 -0
- ursa/agents/recall_agent.py +54 -0
- ursa/agents/websearch_agent.py +196 -0
- ursa/cli/__init__.py +363 -0
- ursa/cli/hitl.py +516 -0
- ursa/cli/hitl_api.py +75 -0
- ursa/observability/metrics_charts.py +1279 -0
- ursa/observability/metrics_io.py +11 -0
- ursa/observability/metrics_session.py +750 -0
- ursa/observability/pricing.json +97 -0
- ursa/observability/pricing.py +321 -0
- ursa/observability/timing.py +1466 -0
- ursa/prompt_library/__init__.py +0 -0
- ursa/prompt_library/code_review_prompts.py +51 -0
- ursa/prompt_library/execution_prompts.py +50 -0
- ursa/prompt_library/hypothesizer_prompts.py +17 -0
- ursa/prompt_library/literature_prompts.py +11 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/prompt_library/planning_prompts.py +79 -0
- ursa/prompt_library/websearch_prompts.py +131 -0
- ursa/tools/__init__.py +0 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/tools/run_command.py +27 -0
- ursa/tools/write_code.py +42 -0
- ursa/util/__init__.py +0 -0
- ursa/util/diff_renderer.py +128 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/logo_generator.py +625 -0
- ursa/util/memory_logger.py +183 -0
- ursa/util/optimization_schema.py +78 -0
- ursa/util/parse.py +405 -0
- ursa_ai-0.9.1.dist-info/METADATA +304 -0
- ursa_ai-0.9.1.dist-info/RECORD +51 -0
- ursa_ai-0.9.1.dist-info/WHEEL +5 -0
- ursa_ai-0.9.1.dist-info/entry_points.txt +2 -0
- ursa_ai-0.9.1.dist-info/licenses/LICENSE +8 -0
- ursa_ai-0.9.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import subprocess
|
|
2
|
+
|
|
3
|
+
from langchain_core.tools import tool
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@tool
def run_cmd(query: str, workspace_dir: str) -> str:
    """Run command from commandline in the directory workspace_dir

    Args:
        query: The command line to execute. It is tokenized with POSIX shell
            quoting rules, but no shell is invoked, so pipes, redirection and
            variable expansion are not available.
        workspace_dir: Directory the command is executed in.

    Returns:
        The captured standard output and standard error of the command.
    """
    # Local import keeps this tool self-contained.
    import shlex

    # Seconds to wait for the command before it is killed.
    timeout_s = 600

    print("RUNNING: ", query)
    print(
        "DANGER DANGER DANGER - THERE IS NO GUARDRAIL FOR SAFETY IN THIS IMPLEMENTATION - DANGER DANGER DANGER"
    )

    # shlex keeps quoted arguments together ("a b" -> one argv item) and
    # ignores repeated whitespace; str.split(" ") split quoted arguments apart
    # and produced empty argv items for double spaces.
    try:
        argv = shlex.split(query)
    except ValueError as exc:
        # e.g. an unbalanced quote in the command line
        return f"STDOUT:  and STDERR: could not parse command: {exc}"
    if not argv:
        return "STDOUT:  and STDERR: empty command"

    process = subprocess.Popen(
        argv,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        cwd=workspace_dir,
    )

    try:
        stdout, stderr = process.communicate(timeout=timeout_s)
    except subprocess.TimeoutExpired:
        # communicate() does not kill the child on timeout: kill it so the
        # process is not leaked, then collect whatever it already produced.
        process.kill()
        stdout, stderr = process.communicate()
        stderr = (
            f"{stderr}\nCommand timed out after {timeout_s} seconds and was killed."
        )

    print("STDOUT: ", stdout)
    print("STDERR: ", stderr)

    return f"STDOUT: {stdout} and STDERR: {stderr}"
|
ursa/tools/write_code.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
from langchain_core.tools import tool
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@tool
def write_python(code: str, filename: str, workspace_dir: str) -> str:
    """
    Writes code to a file in the given workspace.

    Args:
        code: The code to write
        filename: the filename to write
        workspace_dir: the directory the file is written into

    Returns:
        File writing status: string
    """
    print("Writing filename ", filename)
    try:
        # Extract code if wrapped in markdown code blocks
        if "```" in code:
            code_parts = code.split("```")
            if len(code_parts) >= 3:
                fenced = code_parts[1]
                if "\n" in fenced:
                    # The fence's first line is its (possibly empty) language
                    # tag and everything after it is code. Splitting before any
                    # stripping keeps the first code line of an untagged fence.
                    code = fenced.split("\n", 1)[1].rstrip()
                else:
                    # Inline fence such as ```x = 1```: the code is the fenced
                    # text itself, not whatever follows the closing fence.
                    code = fenced.strip()

        # Write code to a file. NOTE(review): filename is joined as given, so
        # absolute or "../" names can land outside workspace_dir.
        code_file = os.path.join(workspace_dir, filename)

        # Explicit UTF-8 so non-ASCII source does not depend on the platform's
        # default text encoding.
        with open(code_file, "w", encoding="utf-8") as f:
            f.write(code)
        print(f"Written code to file: {code_file}")

        return f"File {filename} written successfully."

    except Exception as e:
        # Report the failure to the caller (the agent) instead of raising.
        print(f"Error writing code: {str(e)}")
        return f"Failed to write {filename} successfully."
|
ursa/util/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
import difflib
|
|
2
|
+
import re
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
from rich.console import Console, ConsoleOptions, RenderResult
|
|
6
|
+
from rich.syntax import Syntax
|
|
7
|
+
from rich.text import Text
|
|
8
|
+
|
|
9
|
+
# Unified-diff hunk header regex, e.g. "@@ -12,5 +12,7 @@".
# Group 1 is the starting line in the old file and group 2 the starting line in
# the new file; the optional ",count" parts are matched but not captured.
_HUNK_RE = re.compile(r"^@@ -(\d+)(?:,\d+)? \+(\d+)(?:,\d+)? @@")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
|
|
14
|
+
class _LineStyle:
|
|
15
|
+
prefix: str
|
|
16
|
+
bg: str
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Presentation for each diff line kind: "add" (inserted), "del" (removed) and
# "ctx" (unchanged context). Every prefix is two characters wide so the code
# column lines up across kinds and with the hunk-header rows, whose gutter is
# padded for a two-character marker.
_STYLE = {
    "add": _LineStyle("+ ", "on #003000"),
    "del": _LineStyle("- ", "on #300000"),
    "ctx": _LineStyle("  ", "on grey15"),
}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DiffRenderer:
    """Renderable diff—`console.print(DiffRenderer(...))`

    Shows a unified diff of two versions of a file as one full-width row per
    diff line: old and new line numbers, a +/- marker, the syntax-highlighted
    code, and a background colour that depends on the line kind.
    """

    def __init__(self, content: str, updated: str, filename: str):
        """Compute the diff up front so rendering only has to format it.

        Args:
            content: Original text of the file.
            updated: Modified text of the file.
            filename: Shown in the diff headers and used to guess the lexer.
        """
        old_lines = content.splitlines()
        new_lines = updated.splitlines()

        # total lines in each version
        self._old_total = len(old_lines)
        self._new_total = len(new_lines)

        # width of one line-number column: digits of the largest count plus
        # two characters of padding
        self._num_width = len(str(max(self._old_total, self._new_total))) + 2

        # lineterm="" keeps the produced diff lines free of trailing newlines
        self._diff_lines = list(
            difflib.unified_diff(
                old_lines,
                new_lines,
                fromfile=f"{filename} (original)",
                tofile=f"{filename} (modified)",
                lineterm="",
            )
        )

        # lexer for syntax highlighting; plain text if it cannot be guessed
        try:
            self._lexer_name = Syntax.guess_lexer(filename, updated)
        except Exception:
            self._lexer_name = "text"

    def __rich_console__(
        self, console: Console, opts: ConsoleOptions
    ) -> RenderResult:
        # Line counters are seeded by each hunk header; None means no hunk
        # header has been seen yet.
        old_line = new_line = None
        # Use the width rich offers this renderable, which is narrower than the
        # console when the diff is nested inside a Panel, Table, etc.
        width = opts.max_width
        n = self._num_width

        for raw in self._diff_lines:
            # grab line numbers from hunk header
            if m := _HUNK_RE.match(raw):
                old_line, new_line = map(int, m.groups())
                # dotted marker under both number columns, padded to the full
                # gutter width (two number columns + one space + 2-char marker)
                tick_col = "." * (n - 1)
                full_indent = f" {tick_col} {tick_col}".ljust(2 * n + 3)
                yield Text(
                    f"{full_indent}{raw}".ljust(width), style="white on grey30"
                )
                continue

            # Everything before the first hunk is the "---"/"+++" file header.
            # Testing position instead of prefix keeps removed lines whose text
            # starts with "--" (or added lines starting with "++") visible.
            if old_line is None:
                continue

            # classify the line and advance the matching counters
            marker = raw[:1]
            if marker == "+":
                style = _STYLE["add"]
                old_num, new_num = None, new_line
                new_line += 1
            elif marker == "-":
                style = _STYLE["del"]
                old_num, new_num = old_line, None
                old_line += 1
            else:
                style = _STYLE["ctx"]
                old_num, new_num = old_line, new_line
                old_line += 1
                new_line += 1
            code = raw[1:] if marker in ("+", "-", " ") else raw

            old_str = " " if old_num is None else str(old_num)
            new_str = " " if new_num is None else str(new_num)

            # Syntax-highlight the code part
            syntax = Syntax(
                code, self._lexer_name, line_numbers=False, word_wrap=False
            )
            text_code: Text = syntax.highlight(code)
            # drop a trailing newline so the row stays on a single line
            if text_code.plain.endswith("\n"):
                text_code = text_code[:-1]
            # apply background
            text_code.stylize(style.bg)

            # line numbers + marker + code
            nums = Text(
                f"{old_str:>{n}}{new_str:>{n}} ",
                style=f"white {style.bg}",
            )
            diff_mark = Text(style.prefix, style=f"bright_white {style.bg}")
            line_text = nums + diff_mark + text_code

            # pad to the available width so the background spans the whole row
            pad_len = width - line_text.cell_len
            if pad_len > 0:
                line_text.append(" " * pad_len, style=style.bg)

            yield line_text
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import uuid
|
|
5
|
+
from typing import Any, Callable, Dict, Iterable, List, Union
|
|
6
|
+
|
|
7
|
+
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
|
|
8
|
+
from langchain_core.runnables import Runnable
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# --- if you already have your own versions, reuse them ---
|
|
12
|
+
def _parse_args(v: Any) -> Dict[str, Any]:
|
|
13
|
+
if v is None:
|
|
14
|
+
return {}
|
|
15
|
+
if isinstance(v, dict):
|
|
16
|
+
return v
|
|
17
|
+
if isinstance(v, str):
|
|
18
|
+
try:
|
|
19
|
+
return json.loads(v)
|
|
20
|
+
except Exception:
|
|
21
|
+
return {"_raw": v}
|
|
22
|
+
return {"_raw": v}
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _tool_call_field(tc: Any, key: str) -> Any:
    """Read ``key`` from a tool call given either as a dict or as an object."""
    if isinstance(tc, dict):
        return tc.get(key)
    return getattr(tc, key, None)


def extract_tool_calls(msg: AIMessage) -> List[Dict[str, Any]]:
    """Return the tool calls requested by ``msg`` in a uniform shape.

    Sources are tried in order: the normalized ``msg.tool_calls`` field, then
    raw OpenAI payloads in ``msg.additional_kwargs`` (``tool_calls``, then the
    legacy single ``function_call``).

    Returns:
        A list of ``{"name": ..., "args": dict, "id": ...}`` entries, where
        ``id`` is ``None`` for a legacy ``function_call``. Empty when the
        message requests no tools.
    """
    # Prefer normalized field. Entries may be dicts or attribute objects; the
    # helper reads either without assuming a .get() method exists.
    if msg.tool_calls:
        return [
            {
                "name": _tool_call_field(tc, "name"),
                "args": _parse_args(_tool_call_field(tc, "args")),
                "id": _tool_call_field(tc, "id"),
            }
            for tc in msg.tool_calls
        ]

    # Fallbacks (OpenAI raw payloads)
    ak = msg.additional_kwargs or {}
    if ak.get("tool_calls"):
        out = []
        for tc in ak["tool_calls"]:
            fn = tc.get("function", {}) or {}
            out.append({
                "name": fn.get("name"),
                "args": _parse_args(fn.get("arguments")),
                "id": tc.get("id"),
            })
        return out

    if ak.get("function_call"):
        fn = ak["function_call"]
        return [
            {
                "name": fn.get("name"),
                "args": _parse_args(fn.get("arguments")),
                "id": None,
            }
        ]
    return []
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# -----------------------------------------------------------------------------

# Name -> tool lookup table consumed by run_tool_calls. Each value is either a
# LangChain Runnable or a plain Python callable.
ToolRegistry = Dict[str, Union[Runnable, Callable[..., Any]]]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _stringify_output(x: Any) -> str:
|
|
68
|
+
if isinstance(x, str):
|
|
69
|
+
return x
|
|
70
|
+
try:
|
|
71
|
+
return json.dumps(x, ensure_ascii=False)
|
|
72
|
+
except Exception:
|
|
73
|
+
return str(x)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _invoke_tool(
    tool: Union[Runnable, Callable[..., Any]], args: Dict[str, Any]
) -> Any:
    """Call ``tool`` with the parsed arguments of one tool call.

    Runnables (LangChain tools and chains) receive the whole dict via
    ``.invoke``. Plain callables receive the dict as keyword arguments when
    their signature accepts them, otherwise as one positional payload.

    The tool body runs at most once: a ``TypeError`` raised *inside* the tool
    propagates instead of triggering a second, positional call.
    """
    import inspect

    # Runnable (LangChain tools & chains)
    if isinstance(tool, Runnable):
        return tool.invoke(args)

    try:
        signature = inspect.signature(tool)
    except (TypeError, ValueError):
        # No introspectable signature (some builtins / C callables): keep the
        # try-kwargs-then-positional behaviour since we cannot decide upfront.
        try:
            return tool(**args)
        except TypeError:
            return tool(args)

    # Decide the call form without executing the tool.
    try:
        signature.bind(**args)
    except TypeError:
        # Some tools expect a single positional payload
        return tool(args)
    return tool(**args)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def run_tool_calls(
    ai_msg: AIMessage,
    tools: Union[ToolRegistry, Iterable[Union[Runnable, Callable[..., Any]]]],
) -> List[BaseMessage]:
    """
    Execute every tool call requested by ``ai_msg``.

    Args:
        ai_msg: The LLM's AIMessage containing tool calls.
        tools: Either a dict {name: tool} or an iterable of tools (must have `.name`
            or `__name__` for mapping). Each tool can be a Runnable or a plain callable.

    Returns:
        out: list[BaseMessage] to feed back to the model: ``ai_msg`` once,
        followed by one ToolMessage per tool call in call order. Empty if the
        message requests no tools. Tool failures and unknown tool names are
        reported as "ERROR: ..." ToolMessage content, not raised.

    Raises:
        ValueError: If a tool in an iterable has no discoverable name.
    """
    # Build a name->tool map
    if isinstance(tools, dict):
        registry: ToolRegistry = tools  # type: ignore
    else:
        registry = {}
        for t in tools:
            name = getattr(t, "name", None) or getattr(t, "__name__", None)
            if not name:
                raise ValueError(f"Tool {t!r} has no discoverable name.")
            registry[name] = t  # type: ignore

    calls = extract_tool_calls(ai_msg)
    if not calls:
        return []

    # 1) The AIMessage that generated the calls, exactly once. OpenAI-style chat
    #    APIs expect a single assistant turn followed by one tool reply per
    #    tool_call_id; repeating the assistant message per call breaks that.
    out: List[BaseMessage] = [ai_msg]
    for call in calls:
        name = call.get("name")
        args = call.get("args", {}) or {}
        # Generate an id when the provider did not supply one.
        call_id = call.get("id") or f"call_{uuid.uuid4().hex}"

        # 2) The ToolMessage with the execution result (or error)
        if name not in registry:
            content = f"ERROR: unknown tool '{name}'."
        else:
            try:
                result = _invoke_tool(registry[name], args)
                content = _stringify_output(result)
            except Exception as e:
                content = f"ERROR: {type(e).__name__}: {e}"

        out.append(ToolMessage(content=content, tool_call_id=call_id, name=name))

    return out
|