ursa-ai 0.2.5__tar.gz → 0.2.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ursa-ai has been flagged as possibly problematic.

Files changed (32)
  1. {ursa_ai-0.2.5/src/ursa_ai.egg-info → ursa_ai-0.2.7}/PKG-INFO +5 -5
  2. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/README.md +4 -4
  3. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/pyproject.toml +1 -0
  4. ursa_ai-0.2.7/src/ursa/agents/__init__.py +9 -0
  5. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/agents/arxiv_agent.py +184 -107
  6. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/agents/base.py +2 -1
  7. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/agents/code_review_agent.py +42 -14
  8. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/agents/execution_agent.py +24 -9
  9. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/agents/hypothesizer_agent.py +13 -6
  10. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/agents/mp_agent.py +73 -37
  11. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/agents/planning_agent.py +22 -6
  12. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/agents/recall_agent.py +1 -2
  13. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/agents/websearch_agent.py +55 -12
  14. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/prompt_library/code_review_prompts.py +5 -5
  15. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/prompt_library/execution_prompts.py +4 -4
  16. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/prompt_library/literature_prompts.py +4 -4
  17. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/prompt_library/planning_prompts.py +4 -4
  18. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/prompt_library/websearch_prompts.py +14 -14
  19. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/util/diff_renderer.py +10 -3
  20. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/util/memory_logger.py +9 -6
  21. {ursa_ai-0.2.5 → ursa_ai-0.2.7/src/ursa_ai.egg-info}/PKG-INFO +5 -5
  22. ursa_ai-0.2.5/src/ursa/agents/__init__.py +0 -10
  23. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/LICENSE +0 -0
  24. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/setup.cfg +0 -0
  25. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/prompt_library/hypothesizer_prompts.py +0 -0
  26. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/tools/run_command.py +0 -0
  27. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/tools/write_code.py +0 -0
  28. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/util/parse.py +0 -0
  29. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa_ai.egg-info/SOURCES.txt +0 -0
  30. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa_ai.egg-info/dependency_links.txt +0 -0
  31. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa_ai.egg-info/requires.txt +0 -0
  32. {ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa_ai.egg-info/top_level.txt +0 -0
{ursa_ai-0.2.5/src/ursa_ai.egg-info → ursa_ai-0.2.7}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ursa-ai
- Version: 0.2.5
+ Version: 0.2.7
  Summary: Agents for science at LANL
  Author-email: Mike Grosskopf <mikegros@lanl.gov>, Nathan Debardeleben <ndebard@lanl.gov>, Rahul Somasundaram <rsomasundaram@lanl.gov>, Isaac Michaud <imichaud@lanl.gov>, Avanish Mishra <avanish@lanl.gov>, Arthur Lui <alui@lanl.gov>, Russell Bent <rbent@lanl.gov>, Earl Lawrence <earl@lanl.gov>
  License-Expression: BSD-3-Clause
@@ -42,10 +42,10 @@ Dynamic: license-file
 
  # URSA - The Universal Research and Scientific Agent
 
- <img src="./logos/logo.png" alt="URSA Logo" width="200" height="200">
+ <img src="https://github.com/lanl/ursa/raw/main/logos/logo.png" alt="URSA Logo" width="200" height="200">
 
  [![PyPI Version][pypi-version]](https://pypi.org/project/ursa-ai/)
- [![PyPI Downloads][total-downloads]](https://pepy.tech/projects/ursa-ai)
+ [![PyPI Downloads][monthly-downloads]](https://pypistats.org/packages/ursa-ai)
 
  The flexible agentic workflow for accelerating scientific tasks.
  Composes information flow between agents for planning, code writing and execution, and online research to solve complex problems.
@@ -115,7 +115,7 @@ You have a duty for ensuring that you use URSA responsibly.
 
  URSA has been developed at Los Alamos National Laboratory as part of the ArtIMis project.
 
- <img src="./logos/artimis.png" alt="ArtIMis Logo" width="200" height="200">
+ <img src="https://github.com/lanl/ursa/raw/main/logos/artimis.png" alt="ArtIMis Logo" width="200" height="200">
 
  ### Notice of Copyright Assertion (O4958):
  *This program is Open-Source under the BSD-3 License.
@@ -127,4 +127,4 @@ Redistribution and use in source and binary forms, with or without modification,
  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
  [pypi-version]: https://img.shields.io/pypi/v/ursa-ai?style=flat-square&label=PyPI
- [total-downloads]: https://img.shields.io/pepy/dt/ursa-ai?style=flat-square&label=downloads&color=blue
+ [monthly-downloads]: https://img.shields.io/pypi/dm/ursa-ai?style=flat-square&label=Downloads&color=blue
{ursa_ai-0.2.5 → ursa_ai-0.2.7}/README.md

@@ -1,9 +1,9 @@
  # URSA - The Universal Research and Scientific Agent
 
- <img src="./logos/logo.png" alt="URSA Logo" width="200" height="200">
+ <img src="https://github.com/lanl/ursa/raw/main/logos/logo.png" alt="URSA Logo" width="200" height="200">
 
  [![PyPI Version][pypi-version]](https://pypi.org/project/ursa-ai/)
- [![PyPI Downloads][total-downloads]](https://pepy.tech/projects/ursa-ai)
+ [![PyPI Downloads][monthly-downloads]](https://pypistats.org/packages/ursa-ai)
 
  The flexible agentic workflow for accelerating scientific tasks.
  Composes information flow between agents for planning, code writing and execution, and online research to solve complex problems.
@@ -73,7 +73,7 @@ You have a duty for ensuring that you use URSA responsibly.
 
  URSA has been developed at Los Alamos National Laboratory as part of the ArtIMis project.
 
- <img src="./logos/artimis.png" alt="ArtIMis Logo" width="200" height="200">
+ <img src="https://github.com/lanl/ursa/raw/main/logos/artimis.png" alt="ArtIMis Logo" width="200" height="200">
 
  ### Notice of Copyright Assertion (O4958):
  *This program is Open-Source under the BSD-3 License.
@@ -85,4 +85,4 @@ Redistribution and use in source and binary forms, with or without modification,
  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
  [pypi-version]: https://img.shields.io/pypi/v/ursa-ai?style=flat-square&label=PyPI
- [total-downloads]: https://img.shields.io/pepy/dt/ursa-ai?style=flat-square&label=downloads&color=blue
+ [monthly-downloads]: https://img.shields.io/pypi/dm/ursa-ai?style=flat-square&label=Downloads&color=blue
{ursa_ai-0.2.5 → ursa_ai-0.2.7}/pyproject.toml

@@ -79,5 +79,6 @@ pycodestyle.max-doc-length = 80
  dev = [
      "langgraph-checkpoint-sqlite>=2.0.10",
      "notebook>=7.3.3",
+     "pre-commit>=4.3.0",
      "scikit-optimize>=0.10.2",
  ]
ursa_ai-0.2.7/src/ursa/agents/__init__.py

@@ -0,0 +1,9 @@
+ from .planning_agent import PlanningAgent, PlanningState
+ from .websearch_agent import WebSearchAgent, WebSearchState
+ from .execution_agent import ExecutionAgent, ExecutionState
+ from .code_review_agent import CodeReviewAgent, CodeReviewState
+ from .hypothesizer_agent import HypothesizerAgent, HypothesizerState
+ from .arxiv_agent import ArxivAgent, PaperState, PaperMetadata
+ from .recall_agent import RecallAgent
+ from .base import BaseAgent, BaseChatModel
+ from .mp_agent import MaterialsProjectAgent
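The new __init__.py re-exports every agent at the package level, so downstream code can import from ursa.agents directly instead of reaching into the individual modules. A minimal sketch of what this enables, using the ArxivAgent constructor and run() signature shown later in this diff; an API key for the default "openai/o3-mini" model is assumed, not something the diff configures:

# Hypothetical usage of the new package-level re-exports.
# ArxivAgent defaults (llm="openai/o3-mini", max_results=3) come from
# the arxiv_agent.py hunks below.
from ursa.agents import ArxivAgent

agent = ArxivAgent(max_results=3, download_papers=True)
summary = agent.run(
    arxiv_search_query="Experimental Constraints on neutron star radius",
    context="What are the constraints on the neutron star radius?",
)
print(summary)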
{ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/agents/arxiv_agent.py

@@ -1,5 +1,5 @@
  import os
- import pymupdf
+ import pymupdf
  import requests
  import feedparser
  from PIL import Image
@@ -29,10 +29,12 @@ except:
  # embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
  # embeddings = OpenAIEmbeddings()
 
+
  class PaperMetadata(TypedDict):
      arxiv_id: str
      full_text: str
 
+
  class PaperState(TypedDict, total=False):
      query: str
      context: str
@@ -42,11 +44,13 @@ class PaperState(TypedDict, total=False):
 
 
  def describe_image(image: Image.Image) -> str:
-     if 'OpenAI' not in globals():
-         print("Vision transformer for summarizing images currently only implemented for OpenAI API.")
+     if "OpenAI" not in globals():
+         print(
+             "Vision transformer for summarizing images currently only implemented for OpenAI API."
+         )
          return ""
      client = OpenAI()
- 
+ 
      buffered = BytesIO()
      image.save(buffered, format="PNG")
      img_base64 = base64.b64encode(buffered.getvalue()).decode()
@@ -54,12 +58,23 @@ def describe_image(image: Image.Image) -> str:
      response = client.chat.completions.create(
          model="gpt-4-vision-preview",
          messages=[
-             {"role": "system", "content": "You are a scientific assistant who explains plots and scientific diagrams."},
+             {
+                 "role": "system",
+                 "content": "You are a scientific assistant who explains plots and scientific diagrams.",
+             },
              {
                  "role": "user",
                  "content": [
-                     {"type": "text", "text": "Describe this scientific image or plot in detail."},
-                     {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_base64}"}}
+                     {
+                         "type": "text",
+                         "text": "Describe this scientific image or plot in detail.",
+                     },
+                     {
+                         "type": "image_url",
+                         "image_url": {
+                             "url": f"data:image/png;base64,{img_base64}"
+                         },
+                     },
                  ],
              },
          ],
@@ -68,7 +83,9 @@ def describe_image(image: Image.Image) -> str:
      return response.choices[0].message.content.strip()
 
 
- def extract_and_describe_images(pdf_path: str, max_images: int = 5) -> List[str]:
+ def extract_and_describe_images(
+     pdf_path: str, max_images: int = 5
+ ) -> List[str]:
      doc = pymupdf.open(pdf_path)
      descriptions = []
      image_count = 0
@@ -89,98 +106,117 @@ def extract_and_describe_images(pdf_path: str, max_images: int = 5) -> List[str]
 
              try:
                  desc = describe_image(image)
-                 descriptions.append(f"Page {page_index + 1}, Image {img_index + 1}: {desc}")
+                 descriptions.append(
+                     f"Page {page_index + 1}, Image {img_index + 1}: {desc}"
+                 )
              except Exception as e:
-                 descriptions.append(f"Page {page_index + 1}, Image {img_index + 1}: [Error: {e}]")
+                 descriptions.append(
+                     f"Page {page_index + 1}, Image {img_index + 1}: [Error: {e}]"
+                 )
              image_count += 1
 
      return descriptions
 
 
  def remove_surrogates(text: str) -> str:
-     return re.sub(r'[\ud800-\udfff]', '', text)
+     return re.sub(r"[\ud800-\udfff]", "", text)
 
 
  class ArxivAgent(BaseAgent):
-     def __init__(self,
-                  llm="openai/o3-mini",
-                  summarize: bool = True,
-                  process_images = True,
-                  max_results: int = 3,
-                  download_papers: bool = True,
-                  rag_embedding = None,
-                  database_path ='arxiv_papers',
-                  summaries_path ='arxiv_generated_summaries',
-                  vectorstore_path ='arxiv_vectorstores',**kwargs):
- 
+     def __init__(
+         self,
+         llm="openai/o3-mini",
+         summarize: bool = True,
+         process_images=True,
+         max_results: int = 3,
+         download_papers: bool = True,
+         rag_embedding=None,
+         database_path="arxiv_papers",
+         summaries_path="arxiv_generated_summaries",
+         vectorstore_path="arxiv_vectorstores",
+         **kwargs,
+     ):
          super().__init__(llm, **kwargs)
-         self.summarize = summarize
-         self.process_images = process_images
-         self.max_results = max_results
-         self.database_path = database_path
-         self.summaries_path = summaries_path
+         self.summarize = summarize
+         self.process_images = process_images
+         self.max_results = max_results
+         self.database_path = database_path
+         self.summaries_path = summaries_path
          self.vectorstore_path = vectorstore_path
-         self.download_papers = download_papers
-         self.rag_embedding = rag_embedding
- 
+         self.download_papers = download_papers
+         self.rag_embedding = rag_embedding
+ 
          self.graph = self._build_graph()
 
          os.makedirs(self.database_path, exist_ok=True)
 
          os.makedirs(self.summaries_path, exist_ok=True)
 
- 
      def _fetch_papers(self, query: str) -> List[PaperMetadata]:
- 
          if self.download_papers:
- 
              encoded_query = quote(query)
              url = f"http://export.arxiv.org/api/query?search_query=all:{encoded_query}&start=0&max_results={self.max_results}"
              feed = feedparser.parse(url)
- 
-             for i,entry in enumerate(feed.entries):
-                 full_id = entry.id.split('/abs/')[-1]
-                 arxiv_id = full_id.split('/')[-1]
+ 
+             for i, entry in enumerate(feed.entries):
+                 full_id = entry.id.split("/abs/")[-1]
+                 arxiv_id = full_id.split("/")[-1]
                  title = entry.title.strip()
                  authors = ", ".join(author.name for author in entry.authors)
                  pdf_url = f"https://arxiv.org/pdf/{full_id}.pdf"
-                 pdf_filename = os.path.join(self.database_path, f"{arxiv_id}.pdf")
- 
+                 pdf_filename = os.path.join(
+                     self.database_path, f"{arxiv_id}.pdf"
+                 )
+ 
                  if os.path.exists(pdf_filename):
-                     print(f"Paper # {i+1}, Title: {title}, already exists in database")
+                     print(
+                         f"Paper # {i + 1}, Title: {title}, already exists in database"
+                     )
                  else:
-                     print(f"Downloading paper # {i+1}, Title: {title}")
+                     print(f"Downloading paper # {i + 1}, Title: {title}")
                      response = requests.get(pdf_url)
-                     with open(pdf_filename, 'wb') as f:
+                     with open(pdf_filename, "wb") as f:
                          f.write(response.content)
- 
 
          papers = []
 
-         pdf_files = [f for f in os.listdir(self.database_path) if f.lower().endswith(".pdf")]
- 
-         for i,pdf_filename in enumerate(pdf_files):
+         pdf_files = [
+             f
+             for f in os.listdir(self.database_path)
+             if f.lower().endswith(".pdf")
+         ]
+ 
+         for i, pdf_filename in enumerate(pdf_files):
              full_text = ""
-             arxiv_id = pdf_filename.split('.pdf')[0]
-             vec_save_loc = self.vectorstore_path + '/' + arxiv_id
+             arxiv_id = pdf_filename.split(".pdf")[0]
+             vec_save_loc = self.vectorstore_path + "/" + arxiv_id
 
              if self.summarize and not os.path.exists(vec_save_loc):
                  try:
-                     loader = PyPDFLoader( os.path.join(self.database_path, pdf_filename) )
+                     loader = PyPDFLoader(
+                         os.path.join(self.database_path, pdf_filename)
+                     )
                      pages = loader.load()
                      full_text = "\n".join([p.page_content for p in pages])
- 
+ 
                      if self.process_images:
-                         image_descriptions = extract_and_describe_images( os.path.join(self.database_path, pdf_filename) )
-                         full_text += "\n\n[Image Interpretations]\n" + "\n".join(image_descriptions)
- 
+                         image_descriptions = extract_and_describe_images(
+                             os.path.join(self.database_path, pdf_filename)
+                         )
+                         full_text += (
+                             "\n\n[Image Interpretations]\n"
+                             + "\n".join(image_descriptions)
+                         )
+ 
                  except Exception as e:
                      full_text = f"Error loading paper: {e}"
- 
-             papers.append({
-                 "arxiv_id": arxiv_id,
-                 "full_text": full_text,
-             })
+ 
+             papers.append(
+                 {
+                     "arxiv_id": arxiv_id,
+                     "full_text": full_text,
+                 }
+             )
 
          return papers
 
@@ -188,24 +224,28 @@ class ArxivAgent(BaseAgent):
          papers = self._fetch_papers(state["query"])
          return {**state, "papers": papers}
 
- 
      def _get_or_build_vectorstore(self, paper_text: str, arxiv_id: str):
          os.makedirs(self.vectorstore_path, exist_ok=True)
- 
+ 
          persist_directory = os.path.join(self.vectorstore_path, arxiv_id)
- 
+ 
          if os.path.exists(persist_directory):
-             vectorstore = Chroma(persist_directory=persist_directory, embedding_function=self.rag_embedding)
+             vectorstore = Chroma(
+                 persist_directory=persist_directory,
+                 embedding_function=self.rag_embedding,
+             )
          else:
-             splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+             splitter = RecursiveCharacterTextSplitter(
+                 chunk_size=1000, chunk_overlap=200
+             )
              docs = splitter.create_documents([paper_text])
-             vectorstore = Chroma.from_documents(docs, self.rag_embedding, persist_directory=persist_directory)
- 
+             vectorstore = Chroma.from_documents(
+                 docs, self.rag_embedding, persist_directory=persist_directory
+             )
+ 
          return vectorstore.as_retriever(search_kwargs={"k": 5})
- 
 
      def _summarize_node(self, state: PaperState) -> PaperState:
- 
          prompt = ChatPromptTemplate.from_template("""
          You are a scientific assistant responsible for summarizing extracts from research papers, in the context of the following task: {context}
 
@@ -213,79 +253,115 @@ class ArxivAgent(BaseAgent):
 
          {retrieved_content}
          """)
- 
+ 
          chain = prompt | self.llm | StrOutputParser()
 
          summaries = [None] * len(state["papers"])
          relevancy_scores = [0.0] * len(state["papers"])
- 
+ 
          def process_paper(i, paper):
              arxiv_id = paper["arxiv_id"]
-             summary_filename = os.path.join(self.summaries_path, f"{arxiv_id}_summary.txt")
- 
+             summary_filename = os.path.join(
+                 self.summaries_path, f"{arxiv_id}_summary.txt"
+             )
+ 
              try:
                  cleaned_text = remove_surrogates(paper["full_text"])
                  if self.rag_embedding:
-                     retriever = self._get_or_build_vectorstore(cleaned_text, arxiv_id)
+                     retriever = self._get_or_build_vectorstore(
+                         cleaned_text, arxiv_id
+                     )
 
-                     relevant_docs_with_scores = retriever.vectorstore.similarity_search_with_score(state["context"], k=5)
+                     relevant_docs_with_scores = (
+                         retriever.vectorstore.similarity_search_with_score(
+                             state["context"], k=5
+                         )
+                     )
 
                      if relevant_docs_with_scores:
-                         score = sum([s for _, s in relevant_docs_with_scores]) / len(relevant_docs_with_scores)
+                         score = sum(
+                             [s for _, s in relevant_docs_with_scores]
+                         ) / len(relevant_docs_with_scores)
                          relevancy_scores[i] = abs(1.0 - score)
                      else:
                          relevancy_scores[i] = 0.0
- 
-                     retrieved_content = "\n\n".join([doc.page_content for doc, _ in relevant_docs_with_scores])
+ 
+                     retrieved_content = "\n\n".join(
+                         [
+                             doc.page_content
+                             for doc, _ in relevant_docs_with_scores
+                         ]
+                     )
                  else:
                      retrieved_content = cleaned_text
- 
-                 summary = chain.invoke({"retrieved_content": retrieved_content, "context": state["context"]})
- 
+ 
+                 summary = chain.invoke(
+                     {
+                         "retrieved_content": retrieved_content,
+                         "context": state["context"],
+                     }
+                 )
+ 
              except Exception as e:
                  summary = f"Error summarizing paper: {e}"
                  relevancy_scores[i] = 0.0
- 
+ 
              with open(summary_filename, "w") as f:
                  f.write(summary)
 
              return i, summary
- 
-         if ('papers' not in state or len(state['papers']) == 0):
-             print(f"No papers retrieved - bad query or network connection to ArXiv?")
-             return {**state, "summaries": None}
 
-         with ThreadPoolExecutor(max_workers=min(32, len(state["papers"]))) as executor:
-             futures = [executor.submit(process_paper, i, paper) for i, paper in enumerate(state["papers"])]
+         if "papers" not in state or len(state["papers"]) == 0:
+             print(
+                 f"No papers retrieved - bad query or network connection to ArXiv?"
+             )
+             return {**state, "summaries": None}
 
-             for future in tqdm(as_completed(futures), total=len(futures), desc="Summarizing Papers"):
+         with ThreadPoolExecutor(
+             max_workers=min(32, len(state["papers"]))
+         ) as executor:
+             futures = [
+                 executor.submit(process_paper, i, paper)
+                 for i, paper in enumerate(state["papers"])
+             ]
+ 
+             for future in tqdm(
+                 as_completed(futures),
+                 total=len(futures),
+                 desc="Summarizing Papers",
+             ):
                  i, result = future.result()
                  summaries[i] = result
 
          if self.rag_embedding:
              print(f"\nMax Relevancy Score: {max(relevancy_scores)}")
              print(f"Min Relevancy Score: {min(relevancy_scores)}")
-             print(f"Median Relevancy Score: {statistics.median(relevancy_scores)}\n")
- 
-         return {**state, "summaries": summaries}
+             print(
+                 f"Median Relevancy Score: {statistics.median(relevancy_scores)}\n"
+             )
 
+         return {**state, "summaries": summaries}
 
- 
      def _aggregate_node(self, state: PaperState) -> PaperState:
          summaries = state["summaries"]
          papers = state["papers"]
          formatted = []
 
-         if 'summaries' not in state or state['summaries'] is None or 'papers' not in state or state['papers'] is None:
+         if (
+             "summaries" not in state
+             or state["summaries"] is None
+             or "papers" not in state
+             or state["papers"] is None
+         ):
              return {**state, "final_summary": None}
 
          for i, (paper, summary) in enumerate(zip(papers, summaries)):
-             citation = f"[{i+1}] Arxiv ID: {paper['arxiv_id']}"
+             citation = f"[{i + 1}] Arxiv ID: {paper['arxiv_id']}"
              formatted.append(f"{citation}\n\nSummary:\n{summary}")
 
          combined = "\n\n" + ("\n\n" + "-" * 40 + "\n\n").join(formatted)
 
-         with open(self.summaries_path+'/summaries_combined.txt', "w") as f:
+         with open(self.summaries_path + "/summaries_combined.txt", "w") as f:
              f.write(combined)
 
          prompt = ChatPromptTemplate.from_template("""
@@ -300,15 +376,15 @@ class ArxivAgent(BaseAgent):
 
          chain = prompt | self.llm | StrOutputParser()
 
-         final_summary = chain.invoke({"Summaries": combined, "context":state["context"]})
+         final_summary = chain.invoke(
+             {"Summaries": combined, "context": state["context"]}
+         )
 
-         with open(self.summaries_path+'/final_summary.txt', "w") as f:
+         with open(self.summaries_path + "/final_summary.txt", "w") as f:
              f.write(final_summary)
 
          return {**state, "final_summary": final_summary}
 
- 
- 
      def _build_graph(self):
          builder = StateGraph(PaperState)
          builder.add_node("fetch_papers", self._fetch_node)
@@ -325,25 +401,26 @@ class ArxivAgent(BaseAgent):
          else:
              builder.set_entry_point("fetch_papers")
              builder.set_finish_point("fetch_papers")
- 
+ 
          graph = builder.compile()
          return graph
 
      def run(self, arxiv_search_query: str, context: str) -> str:
-         result = self.graph.invoke({"query": arxiv_search_query, "context":context})
+         result = self.graph.invoke(
+             {"query": arxiv_search_query, "context": context}
+         )
 
          if self.summarize:
              return result.get("final_summary", "No summary generated.")
          else:
              return "\n\nFinished Fetching papers!"
- 
- 
+ 
 
  if __name__ == "__main__":
      agent = ArxivAgent()
-     result = agent.run(arxiv_search_query="Experimental Constraints on neutron star radius",
-                        context="What are the constraints on the neutron star radius and what uncertainties are there on the constraints?")
- 
-     print(result)
- 
+     result = agent.run(
+         arxiv_search_query="Experimental Constraints on neutron star radius",
+         context="What are the constraints on the neutron star radius and what uncertainties are there on the constraints?",
+     )
 
+     print(result)
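The rag_embedding hook threaded through these hunks is optional: when it is set, _get_or_build_vectorstore chunks each paper (chunk_size=1000, chunk_overlap=200) into a persistent Chroma store and _summarize_node retrieves the five most relevant chunks per paper before summarizing. A hedged sketch of wiring in an embedding model; OpenAIEmbeddings is an assumed choice here, since this release leaves rag_embedding=None by default:

# Sketch only: any LangChain-compatible embedding object should work;
# OpenAIEmbeddings is an assumption, not something this diff pins down.
from langchain_openai import OpenAIEmbeddings
from ursa.agents import ArxivAgent

agent = ArxivAgent(
    max_results=3,
    rag_embedding=OpenAIEmbeddings(),  # enables the Chroma retrieval path
)
result = agent.run(
    arxiv_search_query="Experimental Constraints on neutron star radius",
    context="What are the constraints on the neutron star radius?",
)
print(result)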
{ursa_ai-0.2.5 → ursa_ai-0.2.7}/src/ursa/agents/base.py

@@ -5,6 +5,7 @@ from langchain_core.load import dumps
 
  import json
 
+
  class BaseAgent:
      # llm: BaseChatModel
      # llm_with_tools: Runnable[LanguageModelInput, BaseMessage]
@@ -35,7 +36,7 @@ class BaseAgent:
 
          self.checkpointer = checkpointer
          self.thread_id = self.__class__.__name__
- 
+ 
      def write_state(self, filename, state):
          json_state = dumps(state, ensure_ascii=False)
          with open(filename, "w") as f:
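For context, write_state serializes an agent's state to JSON using langchain_core's dumps, as the unchanged lines above show. A minimal sketch, assuming BaseAgent takes a model identifier as its first argument (as ArxivAgent's super().__init__(llm, **kwargs) call suggests) and that the state is any serializable mapping:

# Minimal sketch; BaseAgent's full constructor is not shown in this diff.
from ursa.agents import BaseAgent

agent = BaseAgent("openai/o3-mini")  # assumed: llm passed as first argument
agent.write_state("agent_state.json", {"query": "demo", "papers": []})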