ursa-ai 0.0.3__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ursa-ai might be problematic. Click here for more details.

@@ -0,0 +1,10 @@
1
+ from .planning_agent import PlanningAgent, PlanningState
2
+ from .websearch_agent import WebSearchAgent, WebSearchState
3
+ from .execution_agent import ExecutionAgent, ExecutionState
4
+ from .code_review_agent import CodeReviewAgent, CodeReviewState
5
+ from .hypothesizer_agent import HypothesizerAgent, HypothesizerState
6
+ from .arxiv_agent import ArxivAgent, PaperState, PaperMetadata
7
+ from .recall_agent import RecallAgent
8
+ from .base import BaseAgent, BaseChatModel
9
+ from .mp_agent import MaterialsProjectAgent
10
+
@@ -0,0 +1,349 @@
1
+ import os
2
+ import pymupdf
3
+ import requests
4
+ import feedparser
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import base64
8
+ from urllib.parse import quote
9
+ from typing_extensions import TypedDict, List
10
+ from concurrent.futures import ThreadPoolExecutor, as_completed
11
+ from tqdm import tqdm
12
+ import statistics
13
+ import re
14
+
15
+ from langchain_community.document_loaders import PyPDFLoader
16
+ from langchain_core.output_parsers import StrOutputParser
17
+ from langchain_core.prompts import ChatPromptTemplate
18
+ from langgraph.graph import StateGraph, END, START
19
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
20
+ from langchain_chroma import Chroma
21
+
22
+ from .base import BaseAgent
23
+
24
# ``openai`` is an optional dependency: ``describe_image`` checks whether the
# name ``OpenAI`` exists in the module globals before using it. Only a missing
# package is tolerated here; any other failure is allowed to surface.
try:
    from openai import OpenAI
except ImportError:
    pass
28
+
29
+ # embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
30
+ # embeddings = OpenAIEmbeddings()
31
+
32
class PaperMetadata(TypedDict):
    """Record describing one paper found in the local PDF database."""

    # Identifier derived from the PDF file name (the part before ".pdf").
    arxiv_id: str
    # Text extracted from the PDF, optionally followed by image descriptions;
    # may be empty, or an "Error loading paper: ..." message if loading failed.
    full_text: str
35
+
36
class PaperState(TypedDict, total=False):
    """State passed between the graph nodes; every key is optional (total=False)."""

    # Search string sent to the arXiv API.
    query: str
    # Task description inserted into the summarization prompts.
    context: str
    # Papers produced by the fetch node.
    papers: List[PaperMetadata]
    # One summary per paper, in the same order as ``papers``.
    summaries: List[str]
    # Response generated from all per-paper summaries by the aggregate node.
    final_summary: str
42
+
43
+
44
def describe_image(image: Image.Image) -> str:
    """Ask an OpenAI vision model to describe a scientific image or plot.

    The image is encoded as PNG, base64-encoded and sent inline as a data URL
    in a chat-completion request.

    Args:
        image: PIL image to describe.

    Returns:
        The model's description with surrounding whitespace removed, or an
        empty string when the optional ``openai`` import did not succeed or
        the response carried no text.
    """
    # ``OpenAI`` is only present in the module globals when the optional
    # import at the top of the file succeeded.
    if 'OpenAI' not in globals():
        print("Vision transformer for summarizing images currently only implemented for OpenAI API.")
        return ""
    client = OpenAI()

    # Images embedded in PDFs are often CMYK (or another mode the PNG encoder
    # rejects); convert those to RGB so that ``save`` does not raise.
    if image.mode not in ("1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"):
        image = image.convert("RGB")

    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_base64 = base64.b64encode(buffered.getvalue()).decode()

    response = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=[
            {"role": "system", "content": "You are a scientific assistant who explains plots and scientific diagrams."},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this scientific image or plot in detail."},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_base64}"}},
                ],
            },
        ],
        max_tokens=500,
    )

    # The message content may be None (e.g. a refusal); treat that as "no text".
    content = response.choices[0].message.content
    return content.strip() if content else ""
69
+
70
+
71
def extract_and_describe_images(pdf_path: str, max_images: int = 5) -> List[str]:
    """Describe up to ``max_images`` images embedded in a PDF.

    Pages are visited in order and each embedded image is passed to
    ``describe_image``. A failure on one image (extraction, decoding or
    description) is recorded as an ``[Error: ...]`` entry instead of aborting
    the whole document.

    Args:
        pdf_path: Path of the PDF file to scan.
        max_images: Maximum number of images to process across all pages.

    Returns:
        One string per processed image, each prefixed with its 1-based page
        and image number.
    """
    descriptions = []
    image_count = 0

    # The context manager closes the document even if a page cannot be read.
    with pymupdf.open(pdf_path) as doc:
        for page_index, page in enumerate(doc):
            if image_count >= max_images:
                break

            for img_index, img in enumerate(page.get_images(full=True)):
                if image_count >= max_images:
                    break

                label = f"Page {page_index + 1}, Image {img_index + 1}"
                try:
                    # The first element of the image entry is its xref number.
                    base_image = doc.extract_image(img[0])
                    image = Image.open(BytesIO(base_image["image"]))
                    desc = describe_image(image)
                    descriptions.append(f"{label}: {desc}")
                except Exception as e:
                    descriptions.append(f"{label}: [Error: {e}]")
                # Failed images count towards the limit as well.
                image_count += 1

    return descriptions
98
+
99
+
100
def remove_surrogates(text: str) -> str:
    """Return ``text`` with every UTF-16 surrogate code point removed."""
    surrogate_pattern = re.compile(r'[\ud800-\udfff]')
    return surrogate_pattern.sub('', text)
102
+
103
+
104
class ArxivAgent(BaseAgent):
    """Agent that downloads arXiv papers for a query and optionally summarizes them.

    The compiled graph is either ``fetch_papers`` alone or
    ``fetch_papers -> summarize_each -> aggregate`` when ``summarize`` is true.
    """

    def __init__(self,
                 llm="openai/o3-mini",
                 summarize: bool = True,
                 process_images=True,
                 max_results: int = 3,
                 download_papers: bool = True,
                 rag_embedding=None,
                 database_path='arxiv_papers',
                 summaries_path='arxiv_generated_summaries',
                 vectorstore_path='arxiv_vectorstores', **kwargs):
        """Configure the agent and create its working directories.

        Args:
            llm: Model identifier or chat model instance, forwarded to ``BaseAgent``.
            summarize: Whether to summarize the papers after fetching them.
            process_images: Whether to append image descriptions to the paper text.
            max_results: Maximum number of results requested from the arXiv API.
            download_papers: Whether to query arXiv and download missing PDFs.
            rag_embedding: Embedding function for retrieval; when falsy the
                whole paper text is summarized instead of retrieved extracts.
            database_path: Directory holding the downloaded PDFs.
            summaries_path: Directory where summaries are written.
            vectorstore_path: Directory holding one vector store per paper.
            **kwargs: Forwarded to ``BaseAgent``.
        """
        super().__init__(llm, **kwargs)
        self.summarize = summarize
        self.process_images = process_images
        self.max_results = max_results
        self.database_path = database_path
        self.summaries_path = summaries_path
        self.vectorstore_path = vectorstore_path
        self.download_papers = download_papers
        self.rag_embedding = rag_embedding

        self.graph = self._build_graph()

        os.makedirs(self.database_path, exist_ok=True)
        os.makedirs(self.summaries_path, exist_ok=True)

    def _download_missing_papers(self, query: str) -> None:
        """Query the arXiv API and download every result not yet in the database."""
        encoded_query = quote(query)
        url = f"http://export.arxiv.org/api/query?search_query=all:{encoded_query}&start=0&max_results={self.max_results}"
        feed = feedparser.parse(url)

        for i, entry in enumerate(feed.entries):
            full_id = entry.id.split('/abs/')[-1]
            # Old-style identifiers contain a slash; only the last part is
            # used for the local file name.
            arxiv_id = full_id.split('/')[-1]
            title = entry.title.strip()
            pdf_url = f"https://arxiv.org/pdf/{full_id}.pdf"
            pdf_filename = os.path.join(self.database_path, f"{arxiv_id}.pdf")

            if os.path.exists(pdf_filename):
                print(f"Paper # {i+1}, Title: {title}, already exists in database")
                continue

            print(f"Downloading paper # {i+1}, Title: {title}")
            try:
                response = requests.get(pdf_url, timeout=60)
                # Without this check an HTTP error page would be stored as a PDF.
                response.raise_for_status()
            except requests.RequestException as e:
                print(f"Failed to download paper # {i+1}, Title: {title}: {e}")
                continue

            with open(pdf_filename, 'wb') as f:
                f.write(response.content)

    def _fetch_papers(self, query: str) -> List[PaperMetadata]:
        """Download papers for ``query`` (if enabled) and load the local database.

        Every PDF in ``database_path`` is returned, not only those matching
        the current query.

        Args:
            query: Search string for the arXiv API.

        Returns:
            One ``PaperMetadata`` per PDF. ``full_text`` is empty when
            summarization is disabled or a cached vector store will be used,
            and holds an error message when the PDF could not be loaded.
        """
        if self.download_papers:
            self._download_missing_papers(query)

        papers = []
        pdf_files = [f for f in os.listdir(self.database_path) if f.lower().endswith(".pdf")]

        for pdf_filename in pdf_files:
            full_text = ""
            arxiv_id = os.path.splitext(pdf_filename)[0]
            pdf_path = os.path.join(self.database_path, pdf_filename)
            vec_save_loc = os.path.join(self.vectorstore_path, arxiv_id)

            # The text may be skipped only when a cached vector store will
            # actually be used; without an embedding the summarizer needs it.
            vectorstore_cached = bool(self.rag_embedding) and os.path.exists(vec_save_loc)

            if self.summarize and not vectorstore_cached:
                try:
                    pages = PyPDFLoader(pdf_path).load()
                    full_text = "\n".join([p.page_content for p in pages])

                    if self.process_images:
                        image_descriptions = extract_and_describe_images(pdf_path)
                        full_text += "\n\n[Image Interpretations]\n" + "\n".join(image_descriptions)

                except Exception as e:
                    full_text = f"Error loading paper: {e}"

            papers.append({
                "arxiv_id": arxiv_id,
                "full_text": full_text,
            })

        return papers

    def _fetch_node(self, state: PaperState) -> PaperState:
        """Graph node: add the fetched papers to the state."""
        papers = self._fetch_papers(state["query"])
        return {**state, "papers": papers}

    def _get_or_build_vectorstore(self, paper_text: str, arxiv_id: str):
        """Return a retriever (k=5) for a paper, building its vector store if needed.

        Args:
            paper_text: Full paper text; only used when no store exists yet.
            arxiv_id: Identifier used as the store's sub-directory name.
        """
        os.makedirs(self.vectorstore_path, exist_ok=True)

        persist_directory = os.path.join(self.vectorstore_path, arxiv_id)

        if os.path.exists(persist_directory):
            vectorstore = Chroma(persist_directory=persist_directory, embedding_function=self.rag_embedding)
        else:
            splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            docs = splitter.create_documents([paper_text])
            vectorstore = Chroma.from_documents(docs, self.rag_embedding, persist_directory=persist_directory)

        return vectorstore.as_retriever(search_kwargs={"k": 5})

    def _summarize_node(self, state: PaperState) -> PaperState:
        """Graph node: summarize every paper concurrently.

        Each summary is also written to ``<arxiv_id>_summary.txt`` in
        ``summaries_path``. When no papers are available, ``summaries`` is
        set to ``None``.
        """
        # Check before touching state["papers"]: the key may be absent.
        papers = state.get("papers")
        if not papers:
            print("No papers retrieved - bad query or network connection to ArXiv?")
            return {**state, "summaries": None}

        prompt = ChatPromptTemplate.from_template("""
        You are a scientific assistant responsible for summarizing extracts from research papers, in the context of the following task: {context}

        Summarize the retrieved scientific content below.

        {retrieved_content}
        """)

        chain = prompt | self.llm | StrOutputParser()

        summaries = [None] * len(papers)
        relevancy_scores = [0.0] * len(papers)

        def process_paper(i, paper):
            """Summarize one paper; each worker writes only to its own index ``i``."""
            arxiv_id = paper["arxiv_id"]
            summary_filename = os.path.join(self.summaries_path, f"{arxiv_id}_summary.txt")

            try:
                cleaned_text = remove_surrogates(paper["full_text"])
                if self.rag_embedding:
                    retriever = self._get_or_build_vectorstore(cleaned_text, arxiv_id)

                    relevant_docs_with_scores = retriever.vectorstore.similarity_search_with_score(state["context"], k=5)

                    if relevant_docs_with_scores:
                        score = sum(s for _, s in relevant_docs_with_scores) / len(relevant_docs_with_scores)
                        relevancy_scores[i] = abs(1.0 - score)
                    else:
                        relevancy_scores[i] = 0.0

                    retrieved_content = "\n\n".join([doc.page_content for doc, _ in relevant_docs_with_scores])
                else:
                    retrieved_content = cleaned_text

                summary = chain.invoke({"retrieved_content": retrieved_content, "context": state["context"]})

            except Exception as e:
                summary = f"Error summarizing paper: {e}"
                relevancy_scores[i] = 0.0

            # Explicit encoding: summaries routinely contain non-ASCII symbols.
            with open(summary_filename, "w", encoding="utf-8") as f:
                f.write(summary)

            return i, summary

        with ThreadPoolExecutor(max_workers=min(32, len(papers))) as executor:
            futures = [executor.submit(process_paper, i, paper) for i, paper in enumerate(papers)]

            for future in tqdm(as_completed(futures), total=len(futures), desc="Summarizing Papers"):
                i, result = future.result()
                summaries[i] = result

        if self.rag_embedding:
            print(f"\nMax Relevancy Score: {max(relevancy_scores)}")
            print(f"Min Relevancy Score: {min(relevancy_scores)}")
            print(f"Median Relevancy Score: {statistics.median(relevancy_scores)}\n")

        return {**state, "summaries": summaries}

    def _aggregate_node(self, state: PaperState) -> PaperState:
        """Graph node: combine the per-paper summaries into one final response.

        Writes ``summaries_combined.txt`` and ``final_summary.txt`` to
        ``summaries_path``. ``final_summary`` is ``None`` when there are no
        summaries or papers to aggregate.
        """
        # Use .get so that a missing key is handled by the guard below.
        summaries = state.get("summaries")
        papers = state.get("papers")

        if summaries is None or papers is None:
            return {**state, "final_summary": None}

        formatted = []
        for i, (paper, summary) in enumerate(zip(papers, summaries)):
            citation = f"[{i+1}] Arxiv ID: {paper['arxiv_id']}"
            formatted.append(f"{citation}\n\nSummary:\n{summary}")

        combined = "\n\n" + ("\n\n" + "-" * 40 + "\n\n").join(formatted)

        with open(os.path.join(self.summaries_path, 'summaries_combined.txt'), "w", encoding="utf-8") as f:
            f.write(combined)

        prompt = ChatPromptTemplate.from_template("""
        You are a scientific assistant helping extract insights from summaries of research papers.

        Here are the summaries of a large number of extracts from scientific papers:

        {Summaries}

        Your task is to read all the summaries and provide a response to this task: {context}
        """)

        chain = prompt | self.llm | StrOutputParser()

        final_summary = chain.invoke({"Summaries": combined, "context": state["context"]})

        with open(os.path.join(self.summaries_path, 'final_summary.txt'), "w", encoding="utf-8") as f:
            f.write(final_summary)

        return {**state, "final_summary": final_summary}

    def _build_graph(self):
        """Build and compile the state graph according to ``self.summarize``."""
        builder = StateGraph(PaperState)
        builder.add_node("fetch_papers", self._fetch_node)
        builder.set_entry_point("fetch_papers")

        if self.summarize:
            builder.add_node("summarize_each", self._summarize_node)
            builder.add_node("aggregate", self._aggregate_node)

            builder.add_edge("fetch_papers", "summarize_each")
            builder.add_edge("summarize_each", "aggregate")
            builder.set_finish_point("aggregate")
        else:
            builder.set_finish_point("fetch_papers")

        return builder.compile()

    def run(self, arxiv_search_query: str, context: str) -> str:
        """Run the graph for a query.

        Args:
            arxiv_search_query: Search string for the arXiv API.
            context: Task description used in the summarization prompts.

        Returns:
            The final summary, ``"No summary generated."`` when none was
            produced, or a completion message when summarization is disabled.
        """
        result = self.graph.invoke({"query": arxiv_search_query, "context": context})

        if self.summarize:
            # The aggregate node may store an explicit None, which a plain
            # .get() default would not replace.
            return result.get("final_summary") or "No summary generated."
        return "\n\nFinished Fetching papers!"
339
+
340
+
341
+
342
if __name__ == "__main__":
    # Demo run with the default configuration.
    demo_query = "Experimental Constraints on neutron star radius"
    demo_context = "What are the constraints on the neutron star radius and what uncertainties are there on the constraints?"

    demo_agent = ArxivAgent()
    print(demo_agent.run(arxiv_search_query=demo_query, context=demo_context))
348
+
349
+
ursa/agents/base.py ADDED
@@ -0,0 +1,42 @@
1
+ from langchain_core.language_models.chat_models import BaseChatModel
2
+ from langchain_litellm import ChatLiteLLM
3
+ from langgraph.checkpoint.base import BaseCheckpointSaver
4
+ from langchain_core.load import dumps
5
+
6
+ import json
7
+
8
+ class BaseAgent:
9
+ # llm: BaseChatModel
10
+ # llm_with_tools: Runnable[LanguageModelInput, BaseMessage]
11
+
12
+ def __init__(
13
+ self,
14
+ llm: str | BaseChatModel,
15
+ checkpointer: BaseCheckpointSaver = None,
16
+ **kwargs,
17
+ ):
18
+ match llm:
19
+ case BaseChatModel():
20
+ self.llm = llm
21
+
22
+ case str():
23
+ self.llm_provider, self.llm_model = llm.split("/")
24
+ self.llm = ChatLiteLLM(
25
+ model=llm,
26
+ max_tokens=kwargs.pop("max_tokens", 10000),
27
+ max_retries=kwargs.pop("max_retries", 2),
28
+ **kwargs,
29
+ )
30
+
31
+ case _:
32
+ raise TypeError(
33
+ "llm argument must be a string with the provider and model, or a BaseChatModel instance."
34
+ )
35
+
36
+ self.checkpointer = checkpointer
37
+ self.thread_id = self.__class__.__name__
38
+
39
+ def write_state(self, filename, state):
40
+ json_state = dumps(state, ensure_ascii=False)
41
+ with open(filename, "w") as f:
42
+ f.write(json_state)