ursa-ai 0.0.3__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ursa-ai might be problematic.
- ursa/agents/__init__.py +10 -0
- ursa/agents/arxiv_agent.py +349 -0
- ursa/agents/base.py +42 -0
- ursa/agents/code_review_agent.py +332 -0
- ursa/agents/execution_agent.py +497 -0
- ursa/agents/hypothesizer_agent.py +597 -0
- ursa/agents/mp_agent.py +257 -0
- ursa/agents/planning_agent.py +138 -0
- ursa/agents/recall_agent.py +25 -0
- ursa/agents/websearch_agent.py +193 -0
- ursa/prompt_library/code_review_prompts.py +51 -0
- ursa/prompt_library/execution_prompts.py +36 -0
- ursa/prompt_library/hypothesizer_prompts.py +17 -0
- ursa/prompt_library/literature_prompts.py +11 -0
- ursa/prompt_library/planning_prompts.py +79 -0
- ursa/prompt_library/websearch_prompts.py +131 -0
- ursa/tools/run_command.py +27 -0
- ursa/tools/write_code.py +42 -0
- ursa/util/diff_renderer.py +121 -0
- ursa/util/memory_logger.py +171 -0
- ursa/util/parse.py +89 -0
- ursa_ai-0.2.2.dist-info/METADATA +130 -0
- ursa_ai-0.2.2.dist-info/RECORD +26 -0
- ursa_ai-0.2.2.dist-info/licenses/LICENSE +8 -0
- ursa/__init__.py +0 -2
- ursa/py.typed +0 -0
- ursa_ai-0.0.3.dist-info/METADATA +0 -7
- ursa_ai-0.0.3.dist-info/RECORD +0 -6
- {ursa_ai-0.0.3.dist-info → ursa_ai-0.2.2.dist-info}/WHEEL +0 -0
- {ursa_ai-0.0.3.dist-info → ursa_ai-0.2.2.dist-info}/top_level.txt +0 -0
ursa/agents/__init__.py
ADDED
@@ -0,0 +1,10 @@
+from .planning_agent import PlanningAgent, PlanningState
+from .websearch_agent import WebSearchAgent, WebSearchState
+from .execution_agent import ExecutionAgent, ExecutionState
+from .code_review_agent import CodeReviewAgent, CodeReviewState
+from .hypothesizer_agent import HypothesizerAgent, HypothesizerState
+from .arxiv_agent import ArxivAgent, PaperState, PaperMetadata
+from .recall_agent import RecallAgent
+from .base import BaseAgent, BaseChatModel
+from .mp_agent import MaterialsProjectAgent
+
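These re-exports define the package's public import surface, so downstream code can pull every agent from ursa.agents directly rather than from the individual submodules, e.g.:

from ursa.agents import ArxivAgent, PlanningAgent, BaseAgent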
ursa/agents/arxiv_agent.py
ADDED
@@ -0,0 +1,349 @@
+import os
+import pymupdf
+import requests
+import feedparser
+from PIL import Image
+from io import BytesIO
+import base64
+from urllib.parse import quote
+from typing_extensions import TypedDict, List
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from tqdm import tqdm
+import statistics
+import re
+
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langgraph.graph import StateGraph, END, START
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_chroma import Chroma
+
+from .base import BaseAgent
+
+try:
+    from openai import OpenAI
+except:
+    pass
+
+# embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
+# embeddings = OpenAIEmbeddings()
+
+class PaperMetadata(TypedDict):
+    arxiv_id: str
+    full_text: str
+
+class PaperState(TypedDict, total=False):
+    query: str
+    context: str
+    papers: List[PaperMetadata]
+    summaries: List[str]
+    final_summary: str
+
+
+def describe_image(image: Image.Image) -> str:
+    if 'OpenAI' not in globals():
+        print("Vision transformer for summarizing images currently only implemented for OpenAI API.")
+        return ""
+    client = OpenAI()
+
+    buffered = BytesIO()
+    image.save(buffered, format="PNG")
+    img_base64 = base64.b64encode(buffered.getvalue()).decode()
+
+    response = client.chat.completions.create(
+        model="gpt-4-vision-preview",
+        messages=[
+            {"role": "system", "content": "You are a scientific assistant who explains plots and scientific diagrams."},
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Describe this scientific image or plot in detail."},
+                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_base64}"}}
+                ],
+            },
+        ],
+        max_tokens=500,
+    )
+    return response.choices[0].message.content.strip()
+
+
+def extract_and_describe_images(pdf_path: str, max_images: int = 5) -> List[str]:
+    doc = pymupdf.open(pdf_path)
+    descriptions = []
+    image_count = 0
+
+    for page_index in range(len(doc)):
+        if image_count >= max_images:
+            break
+        page = doc[page_index]
+        images = page.get_images(full=True)
+
+        for img_index, img in enumerate(images):
+            if image_count >= max_images:
+                break
+            xref = img[0]
+            base_image = doc.extract_image(xref)
+            image_bytes = base_image["image"]
+            image = Image.open(BytesIO(image_bytes))
+
+            try:
+                desc = describe_image(image)
+                descriptions.append(f"Page {page_index + 1}, Image {img_index + 1}: {desc}")
+            except Exception as e:
+                descriptions.append(f"Page {page_index + 1}, Image {img_index + 1}: [Error: {e}]")
+            image_count += 1
+
+    return descriptions
+
+
+def remove_surrogates(text: str) -> str:
+    return re.sub(r'[\ud800-\udfff]', '', text)
+
+
+class ArxivAgent(BaseAgent):
+    def __init__(self,
+                 llm="openai/o3-mini",
+                 summarize: bool = True,
+                 process_images = True,
+                 max_results: int = 3,
+                 download_papers: bool = True,
+                 rag_embedding = None,
+                 database_path ='arxiv_papers',
+                 summaries_path ='arxiv_generated_summaries',
+                 vectorstore_path ='arxiv_vectorstores',**kwargs):
+
+        super().__init__(llm, **kwargs)
+        self.summarize = summarize
+        self.process_images = process_images
+        self.max_results = max_results
+        self.database_path = database_path
+        self.summaries_path = summaries_path
+        self.vectorstore_path = vectorstore_path
+        self.download_papers = download_papers
+        self.rag_embedding = rag_embedding
+
+        self.graph = self._build_graph()
+
+        os.makedirs(self.database_path, exist_ok=True)
+
+        os.makedirs(self.summaries_path, exist_ok=True)
+
+
+    def _fetch_papers(self, query: str) -> List[PaperMetadata]:
+
+        if self.download_papers:
+
+            encoded_query = quote(query)
+            url = f"http://export.arxiv.org/api/query?search_query=all:{encoded_query}&start=0&max_results={self.max_results}"
+            feed = feedparser.parse(url)
+
+            for i,entry in enumerate(feed.entries):
+                full_id = entry.id.split('/abs/')[-1]
+                arxiv_id = full_id.split('/')[-1]
+                title = entry.title.strip()
+                authors = ", ".join(author.name for author in entry.authors)
+                pdf_url = f"https://arxiv.org/pdf/{full_id}.pdf"
+                pdf_filename = os.path.join(self.database_path, f"{arxiv_id}.pdf")
+
+                if os.path.exists(pdf_filename):
+                    print(f"Paper # {i+1}, Title: {title}, already exists in database")
+                else:
+                    print(f"Downloading paper # {i+1}, Title: {title}")
+                    response = requests.get(pdf_url)
+                    with open(pdf_filename, 'wb') as f:
+                        f.write(response.content)
+
+
+        papers = []
+
+        pdf_files = [f for f in os.listdir(self.database_path) if f.lower().endswith(".pdf")]
+
+        for i,pdf_filename in enumerate(pdf_files):
+            full_text = ""
+            arxiv_id = pdf_filename.split('.pdf')[0]
+            vec_save_loc = self.vectorstore_path + '/' + arxiv_id
+
+            if self.summarize and not os.path.exists(vec_save_loc):
+                try:
+                    loader = PyPDFLoader( os.path.join(self.database_path, pdf_filename) )
+                    pages = loader.load()
+                    full_text = "\n".join([p.page_content for p in pages])
+
+                    if self.process_images:
+                        image_descriptions = extract_and_describe_images( os.path.join(self.database_path, pdf_filename) )
+                        full_text += "\n\n[Image Interpretations]\n" + "\n".join(image_descriptions)
+
+                except Exception as e:
+                    full_text = f"Error loading paper: {e}"
+
+            papers.append({
+                "arxiv_id": arxiv_id,
+                "full_text": full_text,
+            })
+
+        return papers
+
+    def _fetch_node(self, state: PaperState) -> PaperState:
+        papers = self._fetch_papers(state["query"])
+        return {**state, "papers": papers}
+
+
+    def _get_or_build_vectorstore(self, paper_text: str, arxiv_id: str):
+        os.makedirs(self.vectorstore_path, exist_ok=True)
+
+        persist_directory = os.path.join(self.vectorstore_path, arxiv_id)
+
+        if os.path.exists(persist_directory):
+            vectorstore = Chroma(persist_directory=persist_directory, embedding_function=self.rag_embedding)
+        else:
+            splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+            docs = splitter.create_documents([paper_text])
+            vectorstore = Chroma.from_documents(docs, self.rag_embedding, persist_directory=persist_directory)
+
+        return vectorstore.as_retriever(search_kwargs={"k": 5})
+
+
+    def _summarize_node(self, state: PaperState) -> PaperState:
+
+        prompt = ChatPromptTemplate.from_template("""
+        You are a scientific assistant responsible for summarizing extracts from research papers, in the context of the following task: {context}
+
+        Summarize the retrieved scientific content below.
+
+        {retrieved_content}
+        """)
+
+        chain = prompt | self.llm | StrOutputParser()
+
+        summaries = [None] * len(state["papers"])
+        relevancy_scores = [0.0] * len(state["papers"])
+
+        def process_paper(i, paper):
+            arxiv_id = paper["arxiv_id"]
+            summary_filename = os.path.join(self.summaries_path, f"{arxiv_id}_summary.txt")
+
+            try:
+                cleaned_text = remove_surrogates(paper["full_text"])
+                if self.rag_embedding:
+                    retriever = self._get_or_build_vectorstore(cleaned_text, arxiv_id)
+
+                    relevant_docs_with_scores = retriever.vectorstore.similarity_search_with_score(state["context"], k=5)
+
+                    if relevant_docs_with_scores:
+                        score = sum([s for _, s in relevant_docs_with_scores]) / len(relevant_docs_with_scores)
+                        relevancy_scores[i] = abs(1.0 - score)
+                    else:
+                        relevancy_scores[i] = 0.0
+
+                    retrieved_content = "\n\n".join([doc.page_content for doc, _ in relevant_docs_with_scores])
+                else:
+                    retrieved_content = cleaned_text
+
+                summary = chain.invoke({"retrieved_content": retrieved_content, "context": state["context"]})
+
+            except Exception as e:
+                summary = f"Error summarizing paper: {e}"
+                relevancy_scores[i] = 0.0
+
+            with open(summary_filename, "w") as f:
+                f.write(summary)
+
+            return i, summary
+
+        if ('papers' not in state or len(state['papers']) == 0):
+            print(f"No papers retrieved - bad query or network connection to ArXiv?")
+            return {**state, "summaries": None}
+
+        with ThreadPoolExecutor(max_workers=min(32, len(state["papers"]))) as executor:
+            futures = [executor.submit(process_paper, i, paper) for i, paper in enumerate(state["papers"])]
+
+            for future in tqdm(as_completed(futures), total=len(futures), desc="Summarizing Papers"):
+                i, result = future.result()
+                summaries[i] = result
+
+        if self.rag_embedding:
+            print(f"\nMax Relevancy Score: {max(relevancy_scores)}")
+            print(f"Min Relevancy Score: {min(relevancy_scores)}")
+            print(f"Median Relevancy Score: {statistics.median(relevancy_scores)}\n")
+
+        return {**state, "summaries": summaries}
+
+
+
+    def _aggregate_node(self, state: PaperState) -> PaperState:
+        summaries = state["summaries"]
+        papers = state["papers"]
+        formatted = []
+
+        if 'summaries' not in state or state['summaries'] is None or 'papers' not in state or state['papers'] is None:
+            return {**state, "final_summary": None}
+
+        for i, (paper, summary) in enumerate(zip(papers, summaries)):
+            citation = f"[{i+1}] Arxiv ID: {paper['arxiv_id']}"
+            formatted.append(f"{citation}\n\nSummary:\n{summary}")
+
+        combined = "\n\n" + ("\n\n" + "-" * 40 + "\n\n").join(formatted)
+
+        with open(self.summaries_path+'/summaries_combined.txt', "w") as f:
+            f.write(combined)
+
+        prompt = ChatPromptTemplate.from_template("""
+        You are a scientific assistant helping extract insights from summaries of research papers.
+
+        Here are the summaries of a large number of extracts from scientific papers:
+
+        {Summaries}
+
+        Your task is to read all the summaries and provide a response to this task: {context}
+        """)
+
+        chain = prompt | self.llm | StrOutputParser()
+
+        final_summary = chain.invoke({"Summaries": combined, "context":state["context"]})
+
+        with open(self.summaries_path+'/final_summary.txt', "w") as f:
+            f.write(final_summary)
+
+        return {**state, "final_summary": final_summary}
+
+
+
+    def _build_graph(self):
+        builder = StateGraph(PaperState)
+        builder.add_node("fetch_papers", self._fetch_node)
+
+        if self.summarize:
+            builder.add_node("summarize_each", self._summarize_node)
+            builder.add_node("aggregate", self._aggregate_node)
+
+            builder.set_entry_point("fetch_papers")
+            builder.add_edge("fetch_papers", "summarize_each")
+            builder.add_edge("summarize_each", "aggregate")
+            builder.set_finish_point("aggregate")
+
+        else:
+            builder.set_entry_point("fetch_papers")
+            builder.set_finish_point("fetch_papers")
+
+        graph = builder.compile()
+        return graph
+
+    def run(self, arxiv_search_query: str, context: str) -> str:
+        result = self.graph.invoke({"query": arxiv_search_query, "context":context})
+
+        if self.summarize:
+            return result.get("final_summary", "No summary generated.")
+        else:
+            return "\n\nFinished Fetching papers!"
+
+
+
+if __name__ == "__main__":
+    agent = ArxivAgent()
+    result = agent.run(arxiv_search_query="Experimental Constraints on neutron star radius",
+                       context="What are the constraints on the neutron star radius and what uncertainties are there on the constraints?")
+
+    print(result)
+
+
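A note on the retrieval hook above: rag_embedding defaults to None, in which case _summarize_node feeds each paper's full text to the LLM; when an embedding object is supplied, it is passed straight to Chroma and the per-paper vectorstore branch is used instead. A minimal sketch of the RAG-enabled call, assuming an OpenAI embeddings backend (langchain_openai and its API-key setup are illustrative assumptions, not dependencies declared in this diff):

# Hypothetical usage sketch; the embedding backend is an assumption.
from langchain_openai import OpenAIEmbeddings
from ursa.agents import ArxivAgent

agent = ArxivAgent(
    llm="openai/o3-mini",               # default model string shown in __init__ above
    rag_embedding=OpenAIEmbeddings(),   # enables the Chroma retrieval branch in _summarize_node
    max_results=3,                      # caps the arXiv API query in _fetch_papers
)
print(agent.run(
    arxiv_search_query="neutron star radius constraints",
    context="What constraints exist on the neutron star radius?",
))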
ursa/agents/base.py
ADDED
@@ -0,0 +1,42 @@
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_litellm import ChatLiteLLM
+from langgraph.checkpoint.base import BaseCheckpointSaver
+from langchain_core.load import dumps
+
+import json
+
+class BaseAgent:
+    # llm: BaseChatModel
+    # llm_with_tools: Runnable[LanguageModelInput, BaseMessage]
+
+    def __init__(
+        self,
+        llm: str | BaseChatModel,
+        checkpointer: BaseCheckpointSaver = None,
+        **kwargs,
+    ):
+        match llm:
+            case BaseChatModel():
+                self.llm = llm
+
+            case str():
+                self.llm_provider, self.llm_model = llm.split("/")
+                self.llm = ChatLiteLLM(
+                    model=llm,
+                    max_tokens=kwargs.pop("max_tokens", 10000),
+                    max_retries=kwargs.pop("max_retries", 2),
+                    **kwargs,
+                )
+
+            case _:
+                raise TypeError(
+                    "llm argument must be a string with the provider and model, or a BaseChatModel instance."
+                )
+
+        self.checkpointer = checkpointer
+        self.thread_id = self.__class__.__name__
+
+    def write_state(self, filename, state):
+        json_state = dumps(state, ensure_ascii=False)
+        with open(filename, "w") as f:
+            f.write(json_state)
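The match statement above is the construction contract every agent in this release inherits: llm may be either a LiteLLM-style "provider/model" string, which is wrapped in ChatLiteLLM, or a ready-made BaseChatModel instance. A minimal sketch of both paths (the model names and the langchain_openai import are illustrative assumptions):

# Hypothetical usage sketch; model names are assumptions.
from ursa.agents import BaseAgent

# String path: routed through ChatLiteLLM; max_tokens/max_retries are popped from kwargs.
agent = BaseAgent("openai/gpt-4o-mini", max_tokens=2000)

# Instance path: any BaseChatModel is used as-is.
from langchain_openai import ChatOpenAI
agent = BaseAgent(ChatOpenAI(model="gpt-4o-mini"))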