ursa-ai 0.2.5__tar.gz → 0.2.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- {ursa_ai-0.2.5/src/ursa_ai.egg-info → ursa_ai-0.2.6}/PKG-INFO +3 -3
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/README.md +2 -2
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/pyproject.toml +1 -0
- ursa_ai-0.2.6/src/ursa/agents/__init__.py +9 -0
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/agents/arxiv_agent.py +184 -107
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/agents/base.py +2 -1
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/agents/code_review_agent.py +42 -14
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/agents/execution_agent.py +24 -9
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/agents/hypothesizer_agent.py +13 -6
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/agents/mp_agent.py +73 -37
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/agents/planning_agent.py +22 -6
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/agents/recall_agent.py +1 -2
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/agents/websearch_agent.py +55 -12
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/prompt_library/code_review_prompts.py +5 -5
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/prompt_library/execution_prompts.py +4 -4
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/prompt_library/literature_prompts.py +4 -4
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/prompt_library/planning_prompts.py +4 -4
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/prompt_library/websearch_prompts.py +14 -14
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/util/diff_renderer.py +10 -3
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/util/memory_logger.py +9 -6
- {ursa_ai-0.2.5 → ursa_ai-0.2.6/src/ursa_ai.egg-info}/PKG-INFO +3 -3
- ursa_ai-0.2.5/src/ursa/agents/__init__.py +0 -10
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/LICENSE +0 -0
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/setup.cfg +0 -0
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/prompt_library/hypothesizer_prompts.py +0 -0
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/tools/run_command.py +0 -0
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/tools/write_code.py +0 -0
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/util/parse.py +0 -0
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa_ai.egg-info/SOURCES.txt +0 -0
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa_ai.egg-info/dependency_links.txt +0 -0
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa_ai.egg-info/requires.txt +0 -0
- {ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa_ai.egg-info/top_level.txt +0 -0
{ursa_ai-0.2.5/src/ursa_ai.egg-info → ursa_ai-0.2.6}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ursa-ai
-Version: 0.2.5
+Version: 0.2.6
 Summary: Agents for science at LANL
 Author-email: Mike Grosskopf <mikegros@lanl.gov>, Nathan Debardeleben <ndebard@lanl.gov>, Rahul Somasundaram <rsomasundaram@lanl.gov>, Isaac Michaud <imichaud@lanl.gov>, Avanish Mishra <avanish@lanl.gov>, Arthur Lui <alui@lanl.gov>, Russell Bent <rbent@lanl.gov>, Earl Lawrence <earl@lanl.gov>
 License-Expression: BSD-3-Clause

@@ -42,7 +42,7 @@ Dynamic: license-file

 # URSA - The Universal Research and Scientific Agent

-<img src="
+<img src="https://github.com/lanl/ursa/raw/main/logos/logo.png" alt="URSA Logo" width="200" height="200">

 [![PyPI Version][pypi-version]](https://pypi.org/project/ursa-ai/)
 [![PyPI Downloads][total-downloads]](https://pepy.tech/projects/ursa-ai)

@@ -115,7 +115,7 @@ You have a duty for ensuring that you use URSA responsibly.

 URSA has been developed at Los Alamos National Laboratory as part of the ArtIMis project.

-<img src="
+<img src="https://github.com/lanl/ursa/raw/main/logos/artimis.png" alt="ArtIMis Logo" width="200" height="200">

 ### Notice of Copyright Assertion (O4958):
 *This program is Open-Source under the BSD-3 License.
{ursa_ai-0.2.5 → ursa_ai-0.2.6}/README.md

@@ -1,6 +1,6 @@
 # URSA - The Universal Research and Scientific Agent

-<img src="
+<img src="https://github.com/lanl/ursa/raw/main/logos/logo.png" alt="URSA Logo" width="200" height="200">

 [![PyPI Version][pypi-version]](https://pypi.org/project/ursa-ai/)
 [![PyPI Downloads][total-downloads]](https://pepy.tech/projects/ursa-ai)

@@ -73,7 +73,7 @@ You have a duty for ensuring that you use URSA responsibly.

 URSA has been developed at Los Alamos National Laboratory as part of the ArtIMis project.

-<img src="
+<img src="https://github.com/lanl/ursa/raw/main/logos/artimis.png" alt="ArtIMis Logo" width="200" height="200">

 ### Notice of Copyright Assertion (O4958):
 *This program is Open-Source under the BSD-3 License.
ursa_ai-0.2.6/src/ursa/agents/__init__.py (new file)

@@ -0,0 +1,9 @@
+from .planning_agent import PlanningAgent, PlanningState
+from .websearch_agent import WebSearchAgent, WebSearchState
+from .execution_agent import ExecutionAgent, ExecutionState
+from .code_review_agent import CodeReviewAgent, CodeReviewState
+from .hypothesizer_agent import HypothesizerAgent, HypothesizerState
+from .arxiv_agent import ArxivAgent, PaperState, PaperMetadata
+from .recall_agent import RecallAgent
+from .base import BaseAgent, BaseChatModel
+from .mp_agent import MaterialsProjectAgent
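The new `src/ursa/agents/__init__.py` re-exports the agent classes at package level, so downstream code can import them from `ursa.agents` rather than from the individual modules. A minimal usage sketch under that assumption; it relies only on names and defaults visible in this diff, and presumes credentials for the default "openai/o3-mini" model string are available in the environment:

```python
# Minimal sketch: assumes `ursa-ai` 0.2.6 is installed and credentials for the
# default "openai/o3-mini" model string are configured in the environment.
from ursa.agents import ArxivAgent  # PlanningAgent, ExecutionAgent, etc. are exported the same way

# Constructor arguments and defaults as shown in the arxiv_agent.py diff below.
agent = ArxivAgent(max_results=3, download_papers=True)

summary = agent.run(
    arxiv_search_query="Experimental Constraints on neutron star radius",
    context="What are the constraints on the neutron star radius?",
)
print(summary)
```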
{ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/agents/arxiv_agent.py

@@ -1,5 +1,5 @@
 import os
-import pymupdf
+import pymupdf
 import requests
 import feedparser
 from PIL import Image
@@ -29,10 +29,12 @@ except:
 # embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
 # embeddings = OpenAIEmbeddings()

+
 class PaperMetadata(TypedDict):
     arxiv_id: str
     full_text: str

+
 class PaperState(TypedDict, total=False):
     query: str
     context: str
@@ -42,11 +44,13 @@ class PaperState(TypedDict, total=False):


 def describe_image(image: Image.Image) -> str:
-    if
-    print(
+    if "OpenAI" not in globals():
+        print(
+            "Vision transformer for summarizing images currently only implemented for OpenAI API."
+        )
         return ""
     client = OpenAI()
-
+
     buffered = BytesIO()
     image.save(buffered, format="PNG")
     img_base64 = base64.b64encode(buffered.getvalue()).decode()
@@ -54,12 +58,23 @@ def describe_image(image: Image.Image) -> str:
     response = client.chat.completions.create(
         model="gpt-4-vision-preview",
         messages=[
-            {
+            {
+                "role": "system",
+                "content": "You are a scientific assistant who explains plots and scientific diagrams.",
+            },
             {
                 "role": "user",
                 "content": [
-                    {
-
+                    {
+                        "type": "text",
+                        "text": "Describe this scientific image or plot in detail.",
+                    },
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/png;base64,{img_base64}"
+                        },
+                    },
                 ],
             },
         ],
@@ -68,7 +83,9 @@ def describe_image(image: Image.Image) -> str:
     return response.choices[0].message.content.strip()


-def extract_and_describe_images(pdf_path: str, max_images: int = 5) -> List[str]:
+def extract_and_describe_images(
+    pdf_path: str, max_images: int = 5
+) -> List[str]:
     doc = pymupdf.open(pdf_path)
     descriptions = []
     image_count = 0
@@ -89,98 +106,117 @@ def extract_and_describe_images(pdf_path: str, max_images: int = 5) -> List[str]

             try:
                 desc = describe_image(image)
-                descriptions.append(
+                descriptions.append(
+                    f"Page {page_index + 1}, Image {img_index + 1}: {desc}"
+                )
             except Exception as e:
-                descriptions.append(
+                descriptions.append(
+                    f"Page {page_index + 1}, Image {img_index + 1}: [Error: {e}]"
+                )
             image_count += 1

     return descriptions


 def remove_surrogates(text: str) -> str:
-    return re.sub(r
+    return re.sub(r"[\ud800-\udfff]", "", text)


 class ArxivAgent(BaseAgent):
-    def __init__(
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        llm="openai/o3-mini",
+        summarize: bool = True,
+        process_images=True,
+        max_results: int = 3,
+        download_papers: bool = True,
+        rag_embedding=None,
+        database_path="arxiv_papers",
+        summaries_path="arxiv_generated_summaries",
+        vectorstore_path="arxiv_vectorstores",
+        **kwargs,
+    ):
         super().__init__(llm, **kwargs)
-        self.summarize
-        self.process_images
-        self.max_results
-        self.database_path
-        self.summaries_path
+        self.summarize = summarize
+        self.process_images = process_images
+        self.max_results = max_results
+        self.database_path = database_path
+        self.summaries_path = summaries_path
         self.vectorstore_path = vectorstore_path
-        self.download_papers
-        self.rag_embedding
-
+        self.download_papers = download_papers
+        self.rag_embedding = rag_embedding
+
         self.graph = self._build_graph()

         os.makedirs(self.database_path, exist_ok=True)

         os.makedirs(self.summaries_path, exist_ok=True)

-
     def _fetch_papers(self, query: str) -> List[PaperMetadata]:
-
         if self.download_papers:
-
             encoded_query = quote(query)
             url = f"http://export.arxiv.org/api/query?search_query=all:{encoded_query}&start=0&max_results={self.max_results}"
             feed = feedparser.parse(url)
-
-            for i,entry in enumerate(feed.entries):
-                full_id = entry.id.split(
-                arxiv_id = full_id.split(
+
+            for i, entry in enumerate(feed.entries):
+                full_id = entry.id.split("/abs/")[-1]
+                arxiv_id = full_id.split("/")[-1]
                 title = entry.title.strip()
                 authors = ", ".join(author.name for author in entry.authors)
                 pdf_url = f"https://arxiv.org/pdf/{full_id}.pdf"
-                pdf_filename = os.path.join(
-
+                pdf_filename = os.path.join(
+                    self.database_path, f"{arxiv_id}.pdf"
+                )
+
                 if os.path.exists(pdf_filename):
-                    print(
+                    print(
+                        f"Paper # {i + 1}, Title: {title}, already exists in database"
+                    )
                 else:
-                    print(f"Downloading paper # {i+1}, Title: {title}")
+                    print(f"Downloading paper # {i + 1}, Title: {title}")
                     response = requests.get(pdf_url)
-                    with open(pdf_filename,
+                    with open(pdf_filename, "wb") as f:
                         f.write(response.content)
-

         papers = []

-        pdf_files = [
-
-
+        pdf_files = [
+            f
+            for f in os.listdir(self.database_path)
+            if f.lower().endswith(".pdf")
+        ]
+
+        for i, pdf_filename in enumerate(pdf_files):
             full_text = ""
-            arxiv_id = pdf_filename.split(
-            vec_save_loc =
+            arxiv_id = pdf_filename.split(".pdf")[0]
+            vec_save_loc = self.vectorstore_path + "/" + arxiv_id

             if self.summarize and not os.path.exists(vec_save_loc):
                 try:
-                    loader = PyPDFLoader(
+                    loader = PyPDFLoader(
+                        os.path.join(self.database_path, pdf_filename)
+                    )
                     pages = loader.load()
                     full_text = "\n".join([p.page_content for p in pages])
-
+
                     if self.process_images:
-                        image_descriptions = extract_and_describe_images(
-
-
+                        image_descriptions = extract_and_describe_images(
+                            os.path.join(self.database_path, pdf_filename)
+                        )
+                        full_text += (
+                            "\n\n[Image Interpretations]\n"
+                            + "\n".join(image_descriptions)
+                        )
+
                 except Exception as e:
                     full_text = f"Error loading paper: {e}"
-
-            papers.append(
-
-
-
+
+            papers.append(
+                {
+                    "arxiv_id": arxiv_id,
+                    "full_text": full_text,
+                }
+            )

         return papers

@@ -188,24 +224,28 @@ class ArxivAgent(BaseAgent):
         papers = self._fetch_papers(state["query"])
         return {**state, "papers": papers}

-
     def _get_or_build_vectorstore(self, paper_text: str, arxiv_id: str):
         os.makedirs(self.vectorstore_path, exist_ok=True)
-
+
         persist_directory = os.path.join(self.vectorstore_path, arxiv_id)
-
+
         if os.path.exists(persist_directory):
-            vectorstore = Chroma(
+            vectorstore = Chroma(
+                persist_directory=persist_directory,
+                embedding_function=self.rag_embedding,
+            )
         else:
-            splitter = RecursiveCharacterTextSplitter(
+            splitter = RecursiveCharacterTextSplitter(
+                chunk_size=1000, chunk_overlap=200
+            )
             docs = splitter.create_documents([paper_text])
-            vectorstore = Chroma.from_documents(
-
+            vectorstore = Chroma.from_documents(
+                docs, self.rag_embedding, persist_directory=persist_directory
+            )
+
         return vectorstore.as_retriever(search_kwargs={"k": 5})
-

     def _summarize_node(self, state: PaperState) -> PaperState:
-
         prompt = ChatPromptTemplate.from_template("""
 You are a scientific assistant responsible for summarizing extracts from research papers, in the context of the following task: {context}
@@ -213,79 +253,115 @@ class ArxivAgent(BaseAgent):

 {retrieved_content}
 """)
-
+
         chain = prompt | self.llm | StrOutputParser()

         summaries = [None] * len(state["papers"])
         relevancy_scores = [0.0] * len(state["papers"])
-
+
         def process_paper(i, paper):
             arxiv_id = paper["arxiv_id"]
-            summary_filename = os.path.join(
-
+            summary_filename = os.path.join(
+                self.summaries_path, f"{arxiv_id}_summary.txt"
+            )
+
             try:
                 cleaned_text = remove_surrogates(paper["full_text"])
                 if self.rag_embedding:
-                    retriever = self._get_or_build_vectorstore(
+                    retriever = self._get_or_build_vectorstore(
+                        cleaned_text, arxiv_id
+                    )

-                    relevant_docs_with_scores =
+                    relevant_docs_with_scores = (
+                        retriever.vectorstore.similarity_search_with_score(
+                            state["context"], k=5
+                        )
+                    )

                     if relevant_docs_with_scores:
-                        score = sum(
+                        score = sum(
+                            [s for _, s in relevant_docs_with_scores]
+                        ) / len(relevant_docs_with_scores)
                         relevancy_scores[i] = abs(1.0 - score)
                     else:
                         relevancy_scores[i] = 0.0
-
-                    retrieved_content = "\n\n".join(
+
+                    retrieved_content = "\n\n".join(
+                        [
+                            doc.page_content
+                            for doc, _ in relevant_docs_with_scores
+                        ]
+                    )
                 else:
                     retrieved_content = cleaned_text
-
-                summary = chain.invoke(
-
+
+                summary = chain.invoke(
+                    {
+                        "retrieved_content": retrieved_content,
+                        "context": state["context"],
+                    }
+                )
+
             except Exception as e:
                 summary = f"Error summarizing paper: {e}"
                 relevancy_scores[i] = 0.0
-
+
             with open(summary_filename, "w") as f:
                 f.write(summary)

             return i, summary
-
-        if ('papers' not in state or len(state['papers']) == 0):
-            print(f"No papers retrieved - bad query or network connection to ArXiv?")
-            return {**state, "summaries": None}

-
-
+        if "papers" not in state or len(state["papers"]) == 0:
+            print(
+                f"No papers retrieved - bad query or network connection to ArXiv?"
+            )
+            return {**state, "summaries": None}

-
+        with ThreadPoolExecutor(
+            max_workers=min(32, len(state["papers"]))
+        ) as executor:
+            futures = [
+                executor.submit(process_paper, i, paper)
+                for i, paper in enumerate(state["papers"])
+            ]
+
+            for future in tqdm(
+                as_completed(futures),
+                total=len(futures),
+                desc="Summarizing Papers",
+            ):
                 i, result = future.result()
                 summaries[i] = result

         if self.rag_embedding:
             print(f"\nMax Relevancy Score: {max(relevancy_scores)}")
             print(f"Min Relevancy Score: {min(relevancy_scores)}")
-            print(
-
-
+            print(
+                f"Median Relevancy Score: {statistics.median(relevancy_scores)}\n"
+            )

+        return {**state, "summaries": summaries}

-
     def _aggregate_node(self, state: PaperState) -> PaperState:
         summaries = state["summaries"]
         papers = state["papers"]
         formatted = []

-        if
+        if (
+            "summaries" not in state
+            or state["summaries"] is None
+            or "papers" not in state
+            or state["papers"] is None
+        ):
             return {**state, "final_summary": None}

         for i, (paper, summary) in enumerate(zip(papers, summaries)):
-            citation = f"[{i+1}] Arxiv ID: {paper['arxiv_id']}"
+            citation = f"[{i + 1}] Arxiv ID: {paper['arxiv_id']}"
             formatted.append(f"{citation}\n\nSummary:\n{summary}")

         combined = "\n\n" + ("\n\n" + "-" * 40 + "\n\n").join(formatted)

-        with open(self.summaries_path+
+        with open(self.summaries_path + "/summaries_combined.txt", "w") as f:
             f.write(combined)

         prompt = ChatPromptTemplate.from_template("""
@@ -300,15 +376,15 @@ class ArxivAgent(BaseAgent):

         chain = prompt | self.llm | StrOutputParser()

-        final_summary = chain.invoke(
+        final_summary = chain.invoke(
+            {"Summaries": combined, "context": state["context"]}
+        )

-        with open(self.summaries_path+
+        with open(self.summaries_path + "/final_summary.txt", "w") as f:
             f.write(final_summary)

         return {**state, "final_summary": final_summary}

-
-
     def _build_graph(self):
         builder = StateGraph(PaperState)
         builder.add_node("fetch_papers", self._fetch_node)
@@ -325,25 +401,26 @@ class ArxivAgent(BaseAgent):
         else:
             builder.set_entry_point("fetch_papers")
             builder.set_finish_point("fetch_papers")
-
+
         graph = builder.compile()
         return graph

     def run(self, arxiv_search_query: str, context: str) -> str:
-        result = self.graph.invoke(
+        result = self.graph.invoke(
+            {"query": arxiv_search_query, "context": context}
+        )

         if self.summarize:
             return result.get("final_summary", "No summary generated.")
         else:
             return "\n\nFinished Fetching papers!"
-
-
+

 if __name__ == "__main__":
     agent = ArxivAgent()
-    result = agent.run(
-
-
-
-
+    result = agent.run(
+        arxiv_search_query="Experimental Constraints on neutron star radius",
+        context="What are the constraints on the neutron star radius and what uncertainties are there on the constraints?",
+    )

+    print(result)
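Because the constructor accepts an optional `rag_embedding`, and `_get_or_build_vectorstore` uses it to build a persistent Chroma store per paper while `_summarize_node` then retrieves the top-k chunks instead of the full text, retrieval-augmented summarization can be enabled by passing a LangChain-compatible embedding object. A hedged sketch of that mode; the `langchain_openai` import path is an assumption (the module only shows a commented-out `OpenAIEmbeddings()` line), and any compatible embedding class should work:

```python
# Sketch only: OpenAIEmbeddings is an assumed embedding backend; ArxivAgent
# forwards rag_embedding to Chroma as the embedding function, so any
# LangChain-compatible embedding object should be acceptable.
from langchain_openai import OpenAIEmbeddings

from ursa.agents import ArxivAgent

agent = ArxivAgent(
    max_results=3,
    rag_embedding=OpenAIEmbeddings(),       # enables the per-paper Chroma retrieval path
    vectorstore_path="arxiv_vectorstores",  # default path shown in the diff above
)

report = agent.run(
    arxiv_search_query="Experimental Constraints on neutron star radius",
    context="What are the constraints on the neutron star radius?",
)
print(report)
```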
{ursa_ai-0.2.5 → ursa_ai-0.2.6}/src/ursa/agents/base.py

@@ -5,6 +5,7 @@ from langchain_core.load import dumps

 import json

+
 class BaseAgent:
     # llm: BaseChatModel
     # llm_with_tools: Runnable[LanguageModelInput, BaseMessage]
@@ -35,7 +36,7 @@ class BaseAgent:

         self.checkpointer = checkpointer
         self.thread_id = self.__class__.__name__
-
+
     def write_state(self, filename, state):
         json_state = dumps(state, ensure_ascii=False)
         with open(filename, "w") as f: