ursa-ai 0.4.2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ursa-ai has been flagged as possibly problematic by the registry diff service.

ursa/agents/__init__.py CHANGED
@@ -14,6 +14,8 @@ from .lammps_agent import LammpsState as LammpsState
  from .mp_agent import MaterialsProjectAgent as MaterialsProjectAgent
  from .planning_agent import PlanningAgent as PlanningAgent
  from .planning_agent import PlanningState as PlanningState
+ from .rag_agent import RAGAgent as RAGAgent
+ from .rag_agent import RAGState as RAGState
  from .recall_agent import RecallAgent as RecallAgent
  from .websearch_agent import WebSearchAgent as WebSearchAgent
  from .websearch_agent import WebSearchState as WebSearchState
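
The new RAGAgent and RAGState symbols are re-exported from ursa.agents alongside the existing agents, so downstream code can import them directly. A minimal sketch of the new import surface (nothing beyond what the __init__.py change above implies):

    from ursa.agents import RAGAgent, RAGState

    # RAGState is the TypedDict that flows through the agent's LangGraph
    # (context, doc_texts, doc_ids, summary); RAGAgent is the new
    # retrieval-augmented summarization agent defined in rag_agent.py below.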
ursa/agents/arxiv_agent.py CHANGED
@@ -1,7 +1,6 @@
  import base64
  import os
  import re
- import statistics
  from concurrent.futures import ThreadPoolExecutor, as_completed
  from io import BytesIO
  from urllib.parse import quote
@@ -9,8 +8,6 @@ from urllib.parse import quote
  import feedparser
  import pymupdf
  import requests
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_chroma import Chroma
  from langchain_community.document_loaders import PyPDFLoader
  from langchain_core.output_parsers import StrOutputParser
  from langchain_core.prompts import ChatPromptTemplate
@@ -20,15 +17,13 @@ from tqdm import tqdm
  from typing_extensions import List, TypedDict

  from .base import BaseAgent
+ from .rag_agent import RAGAgent

  try:
      from openai import OpenAI
  except Exception:
      pass

- # embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
- # embeddings = OpenAIEmbeddings()
-

  class PaperMetadata(TypedDict):
      arxiv_id: str
@@ -242,27 +237,6 @@ class ArxivAgent(BaseAgent):
          papers = self._fetch_papers(state["query"])
          return {**state, "papers": papers}

-     def _get_or_build_vectorstore(self, paper_text: str, arxiv_id: str):
-         os.makedirs(self.vectorstore_path, exist_ok=True)
-
-         persist_directory = os.path.join(self.vectorstore_path, arxiv_id)
-
-         if os.path.exists(persist_directory):
-             vectorstore = Chroma(
-                 persist_directory=persist_directory,
-                 embedding_function=self.rag_embedding,
-             )
-         else:
-             splitter = RecursiveCharacterTextSplitter(
-                 chunk_size=1000, chunk_overlap=200
-             )
-             docs = splitter.create_documents([paper_text])
-             vectorstore = Chroma.from_documents(
-                 docs, self.rag_embedding, persist_directory=persist_directory
-             )
-
-         return vectorstore.as_retriever(search_kwargs={"k": 5})
-
      def _summarize_node(self, state: PaperState) -> PaperState:
          prompt = ChatPromptTemplate.from_template("""
  You are a scientific assistant responsible for summarizing extracts from research papers, in the context of the following task: {context}
@@ -285,33 +259,8 @@ class ArxivAgent(BaseAgent):

          try:
              cleaned_text = remove_surrogates(paper["full_text"])
-             if self.rag_embedding:
-                 retriever = self._get_or_build_vectorstore(
-                     cleaned_text, arxiv_id
-                 )
-
-                 relevant_docs_with_scores = (
-                     retriever.vectorstore.similarity_search_with_score(
-                         state["context"], k=5
-                     )
-                 )
-
-                 if relevant_docs_with_scores:
-                     score = sum([
-                         s for _, s in relevant_docs_with_scores
-                     ]) / len(relevant_docs_with_scores)
-                     relevancy_scores[i] = abs(1.0 - score)
-                 else:
-                     relevancy_scores[i] = 0.0
-
-                 retrieved_content = "\n\n".join([
-                     doc.page_content for doc, _ in relevant_docs_with_scores
-                 ])
-             else:
-                 retrieved_content = cleaned_text
-
              summary = chain.invoke({
-                 "retrieved_content": retrieved_content,
+                 "retrieved_content": cleaned_text,
                  "context": state["context"],
              })

@@ -346,15 +295,18 @@ class ArxivAgent(BaseAgent):
                  i, result = future.result()
                  summaries[i] = result

-         if self.rag_embedding:
-             print(f"\nMax Relevancy Score: {max(relevancy_scores)}")
-             print(f"Min Relevancy Score: {min(relevancy_scores)}")
-             print(
-                 f"Median Relevancy Score: {statistics.median(relevancy_scores)}\n"
-             )
-
          return {**state, "summaries": summaries}

+     def _rag_node(self, state: PaperState) -> PaperState:
+         new_state = state.copy()
+         rag_agent = RAGAgent(
+             llm=self.llm,
+             embedding=self.rag_embedding,
+             database_path=self.database_path,
+         )
+         new_state["final_summary"] = rag_agent.run(context=state["context"])
+         return new_state
+
      def _aggregate_node(self, state: PaperState) -> PaperState:
          summaries = state["summaries"]
          papers = state["papers"]
@@ -404,13 +356,20 @@ class ArxivAgent(BaseAgent):
          builder.add_node("fetch_papers", self._fetch_node)

          if self.summarize:
-             builder.add_node("summarize_each", self._summarize_node)
-             builder.add_node("aggregate", self._aggregate_node)
-
-             builder.set_entry_point("fetch_papers")
-             builder.add_edge("fetch_papers", "summarize_each")
-             builder.add_edge("summarize_each", "aggregate")
-             builder.set_finish_point("aggregate")
+             if self.rag_embedding:
+                 builder.add_node("rag_summarize", self._rag_node)
+
+                 builder.set_entry_point("fetch_papers")
+                 builder.add_edge("fetch_papers", "rag_summarize")
+                 builder.set_finish_point("rag_summarize")
+             else:
+                 builder.add_node("summarize_each", self._summarize_node)
+                 builder.add_node("aggregate", self._aggregate_node)
+
+                 builder.set_entry_point("fetch_papers")
+                 builder.add_edge("fetch_papers", "summarize_each")
+                 builder.add_edge("summarize_each", "aggregate")
+                 builder.set_finish_point("aggregate")

          else:
              builder.set_entry_point("fetch_papers")
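
In arxiv_agent.py, the per-paper vectorstore and the relevancy-score diagnostics are removed; when a rag_embedding is configured, the graph now routes fetch_papers into a single rag_summarize node that delegates to the new RAGAgent, and otherwise falls back to the original summarize_each and aggregate path. The relevancy heuristic also changes: the old code scored a paper as abs(1.0 - mean_distance) over its top chunks, while the new rag_agent.py below maps each Chroma distance d to 1.0 / (1.0 + d). An illustrative comparison (distance values are made up, not from the package):

    # Hypothetical Chroma distances for three retrieved chunks
    # (smaller distance = more similar).
    distances = [0.2, 0.8, 1.5]

    # Old arxiv_agent heuristic: one score per paper from the mean distance.
    old_score = abs(1.0 - sum(distances) / len(distances))   # ~0.167

    # New rag_agent heuristic: one score per retrieved chunk, bounded in (0, 1].
    new_scores = [1.0 / (1.0 + d) for d in distances]        # [0.833, 0.556, 0.4]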
ursa/agents/rag_agent.py ADDED
@@ -0,0 +1,272 @@
+ import os
+ import re
+ import statistics
+ from threading import Lock
+ from typing import List, Optional, TypedDict
+
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_chroma import Chroma
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_core.embeddings import Embeddings
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate
+ from langgraph.graph import StateGraph
+
+ from ursa.agents.base import BaseAgent
+
+
+ class RAGState(TypedDict, total=False):
+     context: str
+     doc_texts: List[str]
+     doc_ids: List[str]
+     summary: str
+
+
+ def remove_surrogates(text: str) -> str:
+     return re.sub(r"[\ud800-\udfff]", "", text)
+
+
+ class RAGAgent(BaseAgent):
+     def __init__(
+         self,
+         llm="openai/o3-mini",
+         embedding: Optional[Embeddings] = None,
+         return_k: int = 10,
+         chunk_size: int = 1000,
+         chunk_overlap: int = 200,
+         database_path: str = "database",
+         summaries_path: str = "database",
+         vectorstore_path: str = "vectorstore",
+         **kwargs,
+     ):
+         super().__init__(llm, **kwargs)
+         self.retriever = None
+         self._vs_lock = Lock()
+         self.return_k = return_k
+         self.embedding = embedding
+         self.chunk_size = chunk_size
+         self.chunk_overlap = chunk_overlap
+         self.database_path = database_path
+         self.summaries_path = summaries_path
+         self.vectorstore_path = vectorstore_path
+         self.graph = self._build_graph()
+
+         os.makedirs(self.vectorstore_path, exist_ok=True)
+         self.vectorstore = self._open_global_vectorstore()
+
+     @property
+     def manifest_path(self) -> str:
+         return os.path.join(self.vectorstore_path, "_ingested_ids.txt")
+
+     @property
+     def manifest_exists(self) -> bool:
+         return os.path.exists(self.manifest_path)
+
+     def _open_global_vectorstore(self) -> Chroma:
+         return Chroma(
+             persist_directory=self.vectorstore_path,
+             embedding_function=self.embedding,
+         )
+
+     def _paper_exists_in_vectorstore(self, doc_id: str) -> bool:
+         try:
+             col = self.vectorstore._collection
+             res = col.get(where={"id": doc_id}, limit=1)
+             return len(res.get("ids", [])) > 0
+         except Exception:
+             if not self.manifest_exists:
+                 return False
+             with open(self.manifest_path, "r") as f:
+                 return any(line.strip() == doc_id for line in f)
+
+     def _mark_paper_ingested(self, arxiv_id: str) -> None:
+         with open(self.manifest_path, "a") as f:
+             f.write(f"{arxiv_id}\n")
+
+     def _ensure_doc_in_vectorstore(self, paper_text: str, doc_id: str) -> None:
+         splitter = RecursiveCharacterTextSplitter(
+             chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+         )
+         docs = splitter.create_documents(
+             [paper_text], metadatas=[{"id": doc_id}]
+         )
+         with self._vs_lock:
+             if not self._paper_exists_in_vectorstore(doc_id):
+                 ids = [f"{doc_id}::{i}" for i, _ in enumerate(docs)]
+                 self.vectorstore.add_documents(docs, ids=ids)
+                 self._mark_paper_ingested(doc_id)
+
+     def _get_global_retriever(self, k: int = 5):
+         return self.vectorstore, self.vectorstore.as_retriever(
+             search_kwargs={"k": k}
+         )
+
+     def _read_docs(self, state: RAGState) -> RAGState:
+         print("[RAG Agent] Reading Documents....")
+         papers = []
+         new_state = state.copy()
+
+         pdf_files = [
+             f
+             for f in os.listdir(self.database_path)
+             if f.lower().endswith(".pdf")
+         ]
+
+         doc_ids = [
+             pdf_filename.rsplit(".pdf", 1)[0] for pdf_filename in pdf_files
+         ]
+         pdf_files = [
+             pdf_filename
+             for pdf_filename, id in zip(pdf_files, doc_ids)
+             if not self._paper_exists_in_vectorstore(id)
+         ]
+
+         for pdf_filename in pdf_files:
+             full_text = ""
+
+             try:
+                 loader = PyPDFLoader(
+                     os.path.join(self.database_path, pdf_filename)
+                 )
+                 pages = loader.load()
+                 full_text = "\n".join([p.page_content for p in pages])
+
+             except Exception as e:
+                 full_text = f"Error loading paper: {e}"
+
+             papers.append(full_text)
+
+         new_state["doc_texts"] = papers
+         new_state["doc_ids"] = doc_ids
+
+         return new_state
+
+     def _ingest_docs(self, state: RAGState) -> RAGState:
+         splitter = RecursiveCharacterTextSplitter(
+             chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+         )
+
+         batch_docs, batch_ids = [], []
+         for paper, id in zip(state["doc_texts"], state["doc_ids"]):
+             cleaned_text = remove_surrogates(paper)
+             docs = splitter.create_documents(
+                 [cleaned_text], metadatas=[{"id": id}]
+             )
+             ids = [f"{id}::{i}" for i, _ in enumerate(docs)]
+             batch_docs.extend(docs)
+             batch_ids.extend(ids)
+
+         if state["doc_texts"]:
+             print("[RAG Agent] Ingesting Documents Into RAG Database....")
+             with self._vs_lock:
+                 self.vectorstore.add_documents(batch_docs, ids=batch_ids)
+                 for id in ids:
+                     self._mark_paper_ingested(id)
+
+         return state
+
+     def _summarize_node(self, state: RAGState) -> RAGState:
+         print(
+             "[RAG Agent] Retrieving Contextually Relevant Information From Database..."
+         )
+         prompt = ChatPromptTemplate.from_template("""
+ You are a scientific assistant responsible for summarizing extracts from research papers, in the context of the following task: {context}
+
+ Summarize the retrieved scientific content below.
+ Cite sources by ID when relevant: {source_ids}
+
+ {retrieved_content}
+ """)
+         chain = prompt | self.llm | StrOutputParser()
+
+         # 2) One retrieval over the global DB with the task context
+         try:
+             results = self.vectorstore.similarity_search_with_score(
+                 state["context"], k=self.return_k
+             )
+         except Exception as e:
+             print(f"RAG failed due to: {e}")
+             return {**state, "summary": ""}
+
+         source_ids_list = []
+         for doc, _ in results:
+             aid = doc.metadata.get("id")
+             if aid and aid not in source_ids_list:
+                 source_ids_list.append(aid)
+         source_ids = ", ".join(source_ids_list)
+
+         # Compute a simple similarity-based quality score
+         relevancy_scores = []
+         if results:
+             distances = [score for _, score in results]
+             sims = [1.0 / (1.0 + d) for d in distances]  # map distance -> [0,1)
+             relevancy_scores = sims
+
+         retrieved_content = (
+             "\n\n".join(doc.page_content for doc, _ in results)
+             if results
+             else ""
+         )
+
+         print("[RAG Agent] Summarizing Retrieved Information From Database...")
+         # 3) One summary based on retrieved chunks
+         rag_summary = chain.invoke({
+             "retrieved_content": retrieved_content,
+             "context": state["context"],
+             "source_ids": source_ids,
+         })
+
+         # Persist a single file for the batch (optional)
+         batch_name = "RAG_summary.txt"
+         os.makedirs(self.summaries_path, exist_ok=True)
+         with open(os.path.join(self.summaries_path, batch_name), "w") as f:
+             f.write(rag_summary)
+
+         # Diagnostics
+         if relevancy_scores:
+             print(f"\nMax Relevancy Score: {max(relevancy_scores):.4f}")
+             print(f"Min Relevancy Score: {min(relevancy_scores):.4f}")
+             print(
+                 f"Median Relevancy Score: {statistics.median(relevancy_scores):.4f}\n"
+             )
+         else:
+             print("\nNo RAG results retrieved (score list empty).\n")
+
+         # Return a single-element list by default (preferred)
+         return {
+             **state,
+             "summary": rag_summary,
+             "rag_metadata": {
+                 "k": self.return_k,
+                 "num_results": len(results),
+                 "relevancy_scores": relevancy_scores,
+             },
+         }
+
+     def _build_graph(self):
+         builder = StateGraph(RAGState)
+         builder.add_node("Read Documents", self._read_docs)
+         builder.add_node("Ingest Documents", self._ingest_docs)
+         builder.add_node("Retrieve and Summarize", self._summarize_node)
+         builder.add_edge("Read Documents", "Ingest Documents")
+         builder.add_edge("Ingest Documents", "Retrieve and Summarize")
+
+         builder.set_entry_point("Read Documents")
+         builder.set_finish_point("Retrieve and Summarize")
+
+         graph = builder.compile()
+         return graph
+
+     def run(self, context: str) -> str:
+         result = self.graph.invoke({"context": context})
+
+         return result.get("summary", "No summary generated.")
+
+
+ if __name__ == "__main__":
+     agent = RAGAgent(database_path="workspace/arxiv_papers_neutron_star")
+     result = agent.run(
+         context="What are the constraints on the neutron star radius and what uncertainties are there on the constraints?",
+     )
+
+     print(result)
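
The new module is self-contained and can be run directly against a directory of PDFs, as its __main__ block shows. A minimal usage sketch (the OpenAIEmbeddings backend and the paths are assumptions for illustration; any langchain_core Embeddings implementation can be passed as embedding):

    from langchain_openai import OpenAIEmbeddings  # assumed backend, not a ursa-ai requirement
    from ursa.agents import RAGAgent

    agent = RAGAgent(
        embedding=OpenAIEmbeddings(),            # used by the persistent Chroma store
        database_path="workspace/my_pdfs",       # directory containing the *.pdf files to ingest
        vectorstore_path="vectorstore",          # where the Chroma collection and ingest manifest live
    )

    # Reads and chunks any not-yet-ingested PDFs, retrieves the top return_k
    # chunks for the question, and writes RAG_summary.txt under summaries_path.
    print(agent.run(context="What does this corpus say about topic X?"))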
{ursa_ai-0.4.2.dist-info → ursa_ai-0.5.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ursa-ai
- Version: 0.4.2
+ Version: 0.5.0
  Summary: Agents for science at LANL
  Author-email: Mike Grosskopf <mikegros@lanl.gov>, Nathan Debardeleben <ndebard@lanl.gov>, Rahul Somasundaram <rsomasundaram@lanl.gov>, Isaac Michaud <imichaud@lanl.gov>, Avanish Mishra <avanish@lanl.gov>, Arthur Lui <alui@lanl.gov>, Russell Bent <rbent@lanl.gov>, Earl Lawrence <earl@lanl.gov>
  License-Expression: BSD-3-Clause
{ursa_ai-0.4.2.dist-info → ursa_ai-0.5.0.dist-info}/RECORD RENAMED
@@ -1,5 +1,5 @@
- ursa/agents/__init__.py,sha256=HG95_hMlR0cEAF8Vbr7dF_7I2cWM-rMW1p8ULRXTzTg,1114
- ursa/agents/arxiv_agent.py,sha256=hS6RKqP0fz9KdAyFbXIxl7DuZyuhxpR-jJLfq2tCnlY,15106
+ ursa/agents/__init__.py,sha256=u5ClncJ-w4nIJLQJTOC-NBv6Hu4pXxALgZJJxDt3tZw,1202
+ ursa/agents/arxiv_agent.py,sha256=kiRPshNI-fR-uEFZGp6dyfQ38sIFFb6rUiiNVw3lAXc,13494
  ursa/agents/base.py,sha256=uFhRLVzqhFbTZVA7IePKbUi03ATCXuvga7rzwaHy1B0,1321
  ursa/agents/code_review_agent.py,sha256=aUDq5gT-jdl9Qs-Wewj2oz1d60xov9sN-DOYRfGNTU0,11550
  ursa/agents/execution_agent.py,sha256=okVTsZhG0S92evHtmxge3Ymq3pH0QLYY2VqzO39WG5Y,16581
@@ -7,6 +7,7 @@ ursa/agents/hypothesizer_agent.py,sha256=pUwFDWGBJAqL7CDXxWYJrQrknC3DgRe82Poc_Q_
  ursa/agents/lammps_agent.py,sha256=16eKOtAXEm-clnIZcfEaoxQONqbUJm_1dZhZhm2C2uM,14099
  ursa/agents/mp_agent.py,sha256=UyJSheMGHZpWQJL3EgYgPPqArfv6F8sndN05q4CPtyo,6015
  ursa/agents/planning_agent.py,sha256=ayyNDQifPvYtQ-JYnFk3TaXWZcd_6k8qUheJGariqG8,5574
+ ursa/agents/rag_agent.py,sha256=kEL5F3lZeSKtdXXyUWlXJYYwP2ZNQ1Hz9IWl6pzQnlY,9409
  ursa/agents/recall_agent.py,sha256=bQk7ZJtiO5pj89A50OBDzAJ4G2F7ZdsMwmKnp1WWR7g,813
  ursa/agents/websearch_agent.py,sha256=rCv4AWbqe5Us4FmuypM6jptri21nKoNg044ncsu9u3E,8014
  ursa/prompt_library/code_review_prompts.py,sha256=-HuhwW9W_p2LDn44bXLntxLADHCOyl-2KIXxRHto66w,2444
@@ -20,8 +21,8 @@ ursa/tools/write_code.py,sha256=DtCsUMZegYm0mk-HMPG5Zo3Ba1gbGfnXHsv1NZTdDs8,1220
  ursa/util/diff_renderer.py,sha256=1L1q2qWWb8gLhR532-LgJn2TrqXDx0gUpPVOWD_sqeU,4086
  ursa/util/memory_logger.py,sha256=GiKYbQBpxlNRLKyqKFJyrbSbVCkXpRB7Yr5so43tUAw,6097
  ursa/util/parse.py,sha256=M0cjyQWmjatxX4WbVmDRUiirTLyW-t_Aemlrlrsc5nA,2811
- ursa_ai-0.4.2.dist-info/licenses/LICENSE,sha256=4Vr6_u2zTHIUvYjoOBg9ztDbfpV3hyCFv3mTCS87gYU,1482
- ursa_ai-0.4.2.dist-info/METADATA,sha256=QDR_81uiAgRFxGGjqPMsFmIH9gQ93-_gFHbAS4WY54s,6898
- ursa_ai-0.4.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- ursa_ai-0.4.2.dist-info/top_level.txt,sha256=OjA1gRYSUAeiXGnpqPC8iOOGfcjFO1IlP848qMnYSdY,5
- ursa_ai-0.4.2.dist-info/RECORD,,
+ ursa_ai-0.5.0.dist-info/licenses/LICENSE,sha256=4Vr6_u2zTHIUvYjoOBg9ztDbfpV3hyCFv3mTCS87gYU,1482
+ ursa_ai-0.5.0.dist-info/METADATA,sha256=Um85Xzvs6LSGUg9X0tWitsLVwYqYgMJ0I3EhYzK8oIE,6898
+ ursa_ai-0.5.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ ursa_ai-0.5.0.dist-info/top_level.txt,sha256=OjA1gRYSUAeiXGnpqPC8iOOGfcjFO1IlP848qMnYSdY,5
+ ursa_ai-0.5.0.dist-info/RECORD,,