ursa_ai-0.9.1-py3-none-any.whl
This diff shows the content of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the package as it appears in its public registry.
- ursa/__init__.py +3 -0
- ursa/agents/__init__.py +32 -0
- ursa/agents/acquisition_agents.py +812 -0
- ursa/agents/arxiv_agent.py +429 -0
- ursa/agents/base.py +728 -0
- ursa/agents/chat_agent.py +60 -0
- ursa/agents/code_review_agent.py +341 -0
- ursa/agents/execution_agent.py +915 -0
- ursa/agents/hypothesizer_agent.py +614 -0
- ursa/agents/lammps_agent.py +465 -0
- ursa/agents/mp_agent.py +204 -0
- ursa/agents/optimization_agent.py +410 -0
- ursa/agents/planning_agent.py +219 -0
- ursa/agents/rag_agent.py +304 -0
- ursa/agents/recall_agent.py +54 -0
- ursa/agents/websearch_agent.py +196 -0
- ursa/cli/__init__.py +363 -0
- ursa/cli/hitl.py +516 -0
- ursa/cli/hitl_api.py +75 -0
- ursa/observability/metrics_charts.py +1279 -0
- ursa/observability/metrics_io.py +11 -0
- ursa/observability/metrics_session.py +750 -0
- ursa/observability/pricing.json +97 -0
- ursa/observability/pricing.py +321 -0
- ursa/observability/timing.py +1466 -0
- ursa/prompt_library/__init__.py +0 -0
- ursa/prompt_library/code_review_prompts.py +51 -0
- ursa/prompt_library/execution_prompts.py +50 -0
- ursa/prompt_library/hypothesizer_prompts.py +17 -0
- ursa/prompt_library/literature_prompts.py +11 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/prompt_library/planning_prompts.py +79 -0
- ursa/prompt_library/websearch_prompts.py +131 -0
- ursa/tools/__init__.py +0 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/tools/run_command.py +27 -0
- ursa/tools/write_code.py +42 -0
- ursa/util/__init__.py +0 -0
- ursa/util/diff_renderer.py +128 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/logo_generator.py +625 -0
- ursa/util/memory_logger.py +183 -0
- ursa/util/optimization_schema.py +78 -0
- ursa/util/parse.py +405 -0
- ursa_ai-0.9.1.dist-info/METADATA +304 -0
- ursa_ai-0.9.1.dist-info/RECORD +51 -0
- ursa_ai-0.9.1.dist-info/WHEEL +5 -0
- ursa_ai-0.9.1.dist-info/entry_points.txt +2 -0
- ursa_ai-0.9.1.dist-info/licenses/LICENSE +8 -0
- ursa_ai-0.9.1.dist-info/top_level.txt +1 -0
ursa/agents/arxiv_agent.py
@@ -0,0 +1,429 @@
+import base64
+import os
+import re
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from io import BytesIO
+from typing import Any, Mapping, TypedDict
+from urllib.parse import quote
+
+import feedparser
+import pymupdf
+import requests
+from langchain.chat_models import BaseChatModel
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langgraph.graph import StateGraph
+from PIL import Image
+from tqdm import tqdm
+
+from ursa.agents.base import BaseAgent
+from ursa.agents.rag_agent import RAGAgent
+
+try:
+    from openai import OpenAI
+except Exception:
+    pass
+
+
+class PaperMetadata(TypedDict):
+    arxiv_id: str
+    full_text: str
+
+
+class PaperState(TypedDict, total=False):
+    query: str
+    context: str
+    papers: list[PaperMetadata]
+    summaries: list[str]
+    final_summary: str
+
+
+def describe_image(image: Image.Image) -> str:
+    if "OpenAI" not in globals():
+        print(
+            "Vision transformer for summarizing images currently only implemented for OpenAI API."
+        )
+        return ""
+    client = OpenAI()
+
+    buffered = BytesIO()
+    image.save(buffered, format="PNG")
+    img_base64 = base64.b64encode(buffered.getvalue()).decode()
+
+    response = client.chat.completions.create(
+        model="gpt-4-vision-preview",
+        messages=[
+            {
+                "role": "system",
+                "content": "You are a scientific assistant who explains plots and scientific diagrams.",
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "Describe this scientific image or plot in detail.",
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/png;base64,{img_base64}"
+                        },
+                    },
+                ],
+            },
+        ],
+        max_tokens=500,
+    )
+    return response.choices[0].message.content.strip()
+
+
+def extract_and_describe_images(
+    pdf_path: str, max_images: int = 5
+) -> list[str]:
+    doc = pymupdf.open(pdf_path)
+    descriptions = []
+    image_count = 0
+
+    for page_index in range(len(doc)):
+        if image_count >= max_images:
+            break
+        page = doc[page_index]
+        images = page.get_images(full=True)
+
+        for img_index, img in enumerate(images):
+            if image_count >= max_images:
+                break
+            xref = img[0]
+            base_image = doc.extract_image(xref)
+            image_bytes = base_image["image"]
+            image = Image.open(BytesIO(image_bytes))
+
+            try:
+                desc = describe_image(image)
+                descriptions.append(
+                    f"Page {page_index + 1}, Image {img_index + 1}: {desc}"
+                )
+            except Exception as e:
+                descriptions.append(
+                    f"Page {page_index + 1}, Image {img_index + 1}: [Error: {e}]"
+                )
+            image_count += 1
+
+    return descriptions
+
+
+def remove_surrogates(text: str) -> str:
+    return re.sub(r"[\ud800-\udfff]", "", text)
+
+
+class ArxivAgentLegacy(BaseAgent):
+    def __init__(
+        self,
+        llm: BaseChatModel,
+        summarize: bool = True,
+        process_images=True,
+        max_results: int = 3,
+        download_papers: bool = True,
+        rag_embedding=None,
+        database_path="arxiv_papers",
+        summaries_path="arxiv_generated_summaries",
+        vectorstore_path="arxiv_vectorstores",
+        **kwargs,
+    ):
+        super().__init__(llm, **kwargs)
+        self.summarize = summarize
+        self.process_images = process_images
+        self.max_results = max_results
+        self.database_path = database_path
+        self.summaries_path = summaries_path
+        self.vectorstore_path = vectorstore_path
+        self.download_papers = download_papers
+        self.rag_embedding = rag_embedding
+
+        self._action = self._build_graph()
+
+        os.makedirs(self.database_path, exist_ok=True)
+
+        os.makedirs(self.summaries_path, exist_ok=True)
+
+    def _fetch_papers(self, query: str) -> list[PaperMetadata]:
+        if self.download_papers:
+            encoded_query = quote(query)
+            url = f"http://export.arxiv.org/api/query?search_query=all:{encoded_query}&start=0&max_results={self.max_results}"
+            # print(f"URL is {url}") # if verbose
+            entries = []
+            try:
+                response = requests.get(url, timeout=10)
+                response.raise_for_status()
+
+                feed = feedparser.parse(response.content)
+                # print(f"parsed response status is {feed.status}") # if verbose
+                entries = feed.entries
+                if feed.bozo:
+                    raise Exception("Feed from arXiv looks like garbage =(")
+            except requests.exceptions.Timeout:
+                print("Request timed out while fetching papers.")
+            except requests.exceptions.RequestException as e:
+                print(f"Request error encountered while fetching papers: {e}")
+            except ValueError as ve:
+                print(f"Value error occurred while fetching papers: {ve}")
+            except Exception as e:
+                print(
+                    f"An unexpected error occurred while fetching papers: {e}"
+                )
+
+            for i, entry in enumerate(entries):
+                full_id = entry.id.split("/abs/")[-1]
+                arxiv_id = full_id.split("/")[-1]
+                title = entry.title.strip()
+                # authors = ", ".join(author.name for author in entry.authors)
+                pdf_url = f"https://arxiv.org/pdf/{full_id}.pdf"
+                pdf_filename = os.path.join(
+                    self.database_path, f"{arxiv_id}.pdf"
+                )
+
+                if os.path.exists(pdf_filename):
+                    print(
+                        f"Paper # {i + 1}, Title: {title}, already exists in database"
+                    )
+                else:
+                    print(f"Downloading paper # {i + 1}, Title: {title}")
+                    response = requests.get(pdf_url)
+                    with open(pdf_filename, "wb") as f:
+                        f.write(response.content)
+
+        papers = []
+
+        pdf_files = [
+            f
+            for f in os.listdir(self.database_path)
+            if f.lower().endswith(".pdf")
+        ]
+
+        for i, pdf_filename in enumerate(pdf_files):
+            full_text = ""
+            arxiv_id = pdf_filename.split(".pdf")[0]
+            vec_save_loc = self.vectorstore_path + "/" + arxiv_id
+
+            if self.summarize and not os.path.exists(vec_save_loc):
+                try:
+                    loader = PyPDFLoader(
+                        os.path.join(self.database_path, pdf_filename)
+                    )
+                    pages = loader.load()
+                    full_text = "\n".join([p.page_content for p in pages])
+
+                    if self.process_images:
+                        image_descriptions = extract_and_describe_images(
+                            os.path.join(self.database_path, pdf_filename)
+                        )
+                        full_text += (
+                            "\n\n[Image Interpretations]\n"
+                            + "\n".join(image_descriptions)
+                        )
+
+                except Exception as e:
+                    full_text = f"Error loading paper: {e}"
+
+            papers.append({
+                "arxiv_id": arxiv_id,
+                "full_text": full_text,
+            })
+
+        return papers
+
+    def _fetch_node(self, state: PaperState) -> PaperState:
+        papers = self._fetch_papers(state["query"])
+        return {**state, "papers": papers}
+
+    def _summarize_node(self, state: PaperState) -> PaperState:
+        prompt = ChatPromptTemplate.from_template("""
+You are a scientific assistant responsible for summarizing extracts from research papers, in the context of the following task: {context}
+
+Summarize the retrieved scientific content below.
+
+{retrieved_content}
+        """)
+
+        chain = prompt | self.llm | StrOutputParser()
+
+        summaries = [None] * len(state["papers"])
+        relevancy_scores = [0.0] * len(state["papers"])
+
+        def process_paper(i, paper):
+            arxiv_id = paper["arxiv_id"]
+            summary_filename = os.path.join(
+                self.summaries_path, f"{arxiv_id}_summary.txt"
+            )
+
+            try:
+                cleaned_text = remove_surrogates(paper["full_text"])
+                summary = chain.invoke(
+                    {
+                        "retrieved_content": cleaned_text,
+                        "context": state["context"],
+                    },
+                    config=self.build_config(tags=["arxiv", "summarize_each"]),
+                )
+
+            except Exception as e:
+                summary = f"Error summarizing paper: {e}"
+                relevancy_scores[i] = 0.0
+
+            with open(summary_filename, "w") as f:
+                f.write(summary)
+
+            return i, summary
+
+        if "papers" not in state or len(state["papers"]) == 0:
+            print(
+                "No papers retrieved - bad query or network connection to ArXiv?"
+            )
+            return {**state, "summaries": None}
+
+        with ThreadPoolExecutor(
+            max_workers=min(32, len(state["papers"]))
+        ) as executor:
+            futures = [
+                executor.submit(process_paper, i, paper)
+                for i, paper in enumerate(state["papers"])
+            ]
+
+            for future in tqdm(
+                as_completed(futures),
+                total=len(futures),
+                desc="Summarizing Papers",
+            ):
+                i, result = future.result()
+                summaries[i] = result
+
+        return {**state, "summaries": summaries}
+
+    def _rag_node(self, state: PaperState) -> PaperState:
+        new_state = state.copy()
+        rag_agent = RAGAgent(
+            llm=self.llm,
+            embedding=self.rag_embedding,
+            database_path=self.database_path,
+        )
+        new_state["final_summary"] = rag_agent.invoke(context=state["context"])[
+            "summary"
+        ]
+        return new_state
+
+    def _aggregate_node(self, state: PaperState) -> PaperState:
+        summaries = state["summaries"]
+        papers = state["papers"]
+        formatted = []
+
+        if (
+            "summaries" not in state
+            or state["summaries"] is None
+            or "papers" not in state
+            or state["papers"] is None
+        ):
+            return {**state, "final_summary": None}
+
+        for i, (paper, summary) in enumerate(zip(papers, summaries)):
+            citation = f"[{i + 1}] Arxiv ID: {paper['arxiv_id']}"
+            formatted.append(f"{citation}\n\nSummary:\n{summary}")
+
+        combined = "\n\n" + ("\n\n" + "-" * 40 + "\n\n").join(formatted)
+
+        with open(self.summaries_path + "/summaries_combined.txt", "w") as f:
+            f.write(combined)
+
+        prompt = ChatPromptTemplate.from_template("""
+You are a scientific assistant helping extract insights from summaries of research papers.
+
+Here are the summaries of a large number of extracts from scientific papers:
+
+{Summaries}
+
+Your task is to read all the summaries and provide a response to this task: {context}
+        """)
+
+        chain = prompt | self.llm | StrOutputParser()
+
+        final_summary = chain.invoke(
+            {
+                "Summaries": combined,
+                "context": state["context"],
+            },
+            config=self.build_config(tags=["arxiv", "aggregate"]),
+        )
+
+        with open(self.summaries_path + "/final_summary.txt", "w") as f:
+            f.write(final_summary)
+
+        return {**state, "final_summary": final_summary}
+
+    def _build_graph(self):
+        graph = StateGraph(PaperState)
+
+        self.add_node(graph, self._fetch_node)
+        if self.summarize:
+            if self.rag_embedding:
+                self.add_node(graph, self._rag_node)
+                graph.set_entry_point("_fetch_node")
+                graph.add_edge("_fetch_node", "_rag_node")
+                graph.set_finish_point("_rag_node")
+            else:
+                self.add_node(graph, self._summarize_node)
+                self.add_node(graph, self._aggregate_node)
+
+                graph.set_entry_point("_fetch_node")
+                graph.add_edge("_fetch_node", "_summarize_node")
+                graph.add_edge("_summarize_node", "_aggregate_node")
+                graph.set_finish_point("_aggregate_node")
+        else:
+            graph.set_entry_point("_fetch_node")
+            graph.set_finish_point("_fetch_node")
+
+        return graph.compile(checkpointer=self.checkpointer)
+
+    def _invoke(
+        self,
+        inputs: Mapping[str, Any],
+        *,
+        summarize: bool | None = None,
+        recursion_limit: int = 1000,
+        **_,
+    ) -> str:
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+
+        # this seems dumb, but it's b/c sometimes we had referred to the value as
+        # 'query' other times as 'arxiv_search_query' so trying to keep it compatible
+        # aliasing: accept arxiv_search_query -> query
+        if "query" not in inputs:
+            if "arxiv_search_query" in inputs:
+                # make a shallow copy and rename the key
+                inputs = dict(inputs)
+                inputs["query"] = inputs.pop("arxiv_search_query")
+            else:
+                raise KeyError(
+                    "Missing 'query' in inputs (alias 'arxiv_search_query' also accepted)."
+                )
+
+        result = self._action.invoke(inputs, config)
+
+        use_summary = self.summarize if summarize is None else summarize
+
+        return (
+            result.get("final_summary", "No summary generated.")
+            if use_summary
+            else "\n\nFinished Fetching papers!"
+        )
+
+
+# NOTE: Run test in `tests/agents/test_arxiv_agent/test_arxiv_agent.py` via:
+#
+# pytest -s tests/agents/test_arxiv_agent
+#
+# OR
+#
+# uv run pytest -s tests/agents/test_arxiv_agent