ursa-ai 0.5.0__py3-none-any.whl → 0.6.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

ursa/agents/rag_agent.py CHANGED
@@ -1,8 +1,9 @@
 import os
 import re
 import statistics
+from functools import cached_property
 from threading import Lock
-from typing import List, Optional, TypedDict
+from typing import Any, Mapping, TypedDict

 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_chroma import Chroma
@@ -11,15 +12,23 @@ from langchain_core.embeddings import Embeddings
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langgraph.graph import StateGraph
+from tqdm import tqdm

 from ursa.agents.base import BaseAgent


+class RAGMetadata(TypedDict):
+    k: int
+    num_results: int
+    relevance_scores: list[float]
+
+
 class RAGState(TypedDict, total=False):
     context: str
-    doc_texts: List[str]
-    doc_ids: List[str]
+    doc_texts: list[str]
+    doc_ids: list[str]
     summary: str
+    rag_metadata: RAGMetadata


 def remove_surrogates(text: str) -> str:
@@ -29,8 +38,8 @@ def remove_surrogates(text: str) -> str:
 class RAGAgent(BaseAgent):
     def __init__(
         self,
+        embedding: Embeddings,
         llm="openai/o3-mini",
-        embedding: Optional[Embeddings] = None,
         return_k: int = 10,
         chunk_size: int = 1000,
         chunk_overlap: int = 200,
@@ -49,11 +58,18 @@ class RAGAgent(BaseAgent):
         self.database_path = database_path
         self.summaries_path = summaries_path
         self.vectorstore_path = vectorstore_path
-        self.graph = self._build_graph()

         os.makedirs(self.vectorstore_path, exist_ok=True)
         self.vectorstore = self._open_global_vectorstore()

+    @cached_property
+    def graph(self):
+        return self._build_graph()
+
+    @property
+    def _action(self):
+        return self.graph
+
     @property
     def manifest_path(self) -> str:
         return os.path.join(self.vectorstore_path, "_ingested_ids.txt")
@@ -66,6 +82,7 @@ class RAGAgent(BaseAgent):
         return Chroma(
             persist_directory=self.vectorstore_path,
             embedding_function=self.embedding,
+            collection_metadata={"hnsw:space": "cosine"},
         )

     def _paper_exists_in_vectorstore(self, doc_id: str) -> bool:
@@ -101,7 +118,7 @@ class RAGAgent(BaseAgent):
             search_kwargs={"k": k}
         )

-    def _read_docs(self, state: RAGState) -> RAGState:
+    def _read_docs_node(self, state: RAGState) -> RAGState:
         print("[RAG Agent] Reading Documents....")
         papers = []
         new_state = state.copy()
@@ -121,7 +138,7 @@ class RAGAgent(BaseAgent):
             if not self._paper_exists_in_vectorstore(id)
         ]

-        for pdf_filename in pdf_files:
+        for pdf_filename in tqdm(pdf_files, desc="RAG parsing text"):
             full_text = ""

             try:
@@ -141,13 +158,23 @@ class RAGAgent(BaseAgent):

         return new_state

-    def _ingest_docs(self, state: RAGState) -> RAGState:
+    def _ingest_docs_node(self, state: RAGState) -> RAGState:
         splitter = RecursiveCharacterTextSplitter(
             chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
         )

+        if "doc_texts" not in state:
+            raise RuntimeError("Unexpected error: doc_ids not in state!")
+
+        if "doc_ids" not in state:
+            raise RuntimeError("Unexpected error: doc_texts not in state!")
+
         batch_docs, batch_ids = [], []
-        for paper, id in zip(state["doc_texts"], state["doc_ids"]):
+        for paper, id in tqdm(
+            zip(state["doc_texts"], state["doc_ids"]),
+            total=len(state["doc_texts"]),
+            desc="RAG Ingesting",
+        ):
             cleaned_text = remove_surrogates(paper)
             docs = splitter.create_documents(
                 [cleaned_text], metadatas=[{"id": id}]
@@ -160,12 +187,12 @@ class RAGAgent(BaseAgent):
         print("[RAG Agent] Ingesting Documents Into RAG Database....")
         with self._vs_lock:
             self.vectorstore.add_documents(batch_docs, ids=batch_ids)
-            for id in ids:
+            for id in batch_ids:
                 self._mark_paper_ingested(id)

         return state

-    def _summarize_node(self, state: RAGState) -> RAGState:
+    def _retrieve_and_summarize_node(self, state: RAGState) -> RAGState:
         print(
             "[RAG Agent] Retrieving Contextually Relevant Information From Database..."
         )
@@ -181,9 +208,14 @@ class RAGAgent(BaseAgent):

         # 2) One retrieval over the global DB with the task context
         try:
-            results = self.vectorstore.similarity_search_with_score(
+            if "context" not in state:
+                raise RuntimeError("Unexpected error: context not in state!")
+
+            results = self.vectorstore.similarity_search_with_relevance_scores(
                 state["context"], k=self.return_k
             )
+
+            relevance_scores = [score for _, score in results]
         except Exception as e:
             print(f"RAG failed due to: {e}")
             return {**state, "summary": ""}
@@ -195,13 +227,6 @@ class RAGAgent(BaseAgent):
                 source_ids_list.append(aid)
         source_ids = ", ".join(source_ids_list)

-        # Compute a simple similarity-based quality score
-        relevancy_scores = []
-        if results:
-            distances = [score for _, score in results]
-            sims = [1.0 / (1.0 + d) for d in distances]  # map distance -> [0,1)
-            relevancy_scores = sims
-
         retrieved_content = (
             "\n\n".join(doc.page_content for doc, _ in results)
             if results
@@ -223,11 +248,11 @@ class RAGAgent(BaseAgent):
             f.write(rag_summary)

         # Diagnostics
-        if relevancy_scores:
-            print(f"\nMax Relevancy Score: {max(relevancy_scores):.4f}")
-            print(f"Min Relevancy Score: {min(relevancy_scores):.4f}")
+        if relevance_scores:
+            print(f"\nMax Relevance Score: {max(relevance_scores):.4f}")
+            print(f"Min Relevance Score: {min(relevance_scores):.4f}")
             print(
-                f"Median Relevancy Score: {statistics.median(relevancy_scores):.4f}\n"
+                f"Median Relevance Score: {statistics.median(relevance_scores):.4f}\n"
             )
         else:
             print("\nNo RAG results retrieved (score list empty).\n")
@@ -239,34 +264,40 @@ class RAGAgent(BaseAgent):
             "rag_metadata": {
                 "k": self.return_k,
                 "num_results": len(results),
-                "relevancy_scores": relevancy_scores,
+                "relevance_scores": relevance_scores,
             },
         }

-    def _build_graph(self):
-        builder = StateGraph(RAGState)
-        builder.add_node("Read Documents", self._read_docs)
-        builder.add_node("Ingest Documents", self._ingest_docs)
-        builder.add_node("Retrieve and Summarize", self._summarize_node)
-        builder.add_edge("Read Documents", "Ingest Documents")
-        builder.add_edge("Ingest Documents", "Retrieve and Summarize")
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 100000, **_
+    ):
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+        return self._action.invoke(inputs, config)

-        builder.set_entry_point("Read Documents")
-        builder.set_finish_point("Retrieve and Summarize")
+    def _build_graph(self):
+        graph = StateGraph(RAGState)

-        graph = builder.compile()
-        return graph
+        self.add_node(graph, self._read_docs_node)
+        self.add_node(graph, self._ingest_docs_node)
+        self.add_node(graph, self._retrieve_and_summarize_node)

-    def run(self, context: str) -> str:
-        result = self.graph.invoke({"context": context})
+        graph.add_edge("_read_docs_node", "_ingest_docs_node")
+        graph.add_edge("_ingest_docs_node", "_retrieve_and_summarize_node")

-        return result.get("summary", "No summary generated.")
+        graph.set_entry_point("_read_docs_node")
+        graph.set_finish_point("_retrieve_and_summarize_node")

+        return graph.compile(checkpointer=self.checkpointer)

-if __name__ == "__main__":
-    agent = RAGAgent(database_path="workspace/arxiv_papers_neutron_star")
-    result = agent.run(
-        context="What are the constraints on the neutron star radius and what uncertainties are there on the constraints?",
-    )

-    print(result)
+# NOTE: Run test in `tests/agents/test_rag_agent/test_rag_agent.py` via:
+#
+# pytest -s tests/agents/test_rag_agent
+#
+# OR
+#
+# uv run pytest -s tests/agents/test_rag_agent
+#
+# NOTE: You may need to `rm -rf workspace/rag-agent` to remove the vectorstore.
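Net effect of this file's changes: `embedding` is now a required constructor argument, the public `run(context)` method is replaced by a checkpointed LangGraph graph behind `_invoke`, and the reported scores come from `similarity_search_with_relevance_scores` over a cosine-space collection rather than inverted distances. A minimal usage sketch, with assumptions flagged: `OpenAIEmbeddings` as the embedding backend, illustrative placeholder paths, and a public `invoke()` on `BaseAgent` that forwards to the `_invoke()` shown above (the updated `main()` in the websearch diff below calls `invoke()` in exactly that way).

# Hypothetical usage sketch -- not taken from the package docs.
from langchain_openai import OpenAIEmbeddings

from ursa.agents.rag_agent import RAGAgent

agent = RAGAgent(
    embedding=OpenAIEmbeddings(model="text-embedding-3-small"),  # required as of 0.6.0
    database_path="workspace/arxiv_papers",  # placeholder directory of PDFs
)
# invoke() is assumed to wrap the _invoke() shown above; the graph runs
# _read_docs_node -> _ingest_docs_node -> _retrieve_and_summarize_node.
result = agent.invoke({"context": "What constraints exist on neutron star radii?"})
print(result.get("summary", "No summary generated."))
print(result.get("rag_metadata", {}))  # {"k": ..., "num_results": ..., "relevance_scores": [...]}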
ursa/agents/recall_agent.py CHANGED
@@ -1,23 +1,53 @@
+from typing import Any, Mapping, TypedDict
+
+from langgraph.graph import StateGraph
+
 from .base import BaseAgent


+class RecallState(TypedDict):
+    query: str
+    memory: str
+
+
 class RecallAgent(BaseAgent):
     def __init__(self, llm, memory, **kwargs):
         super().__init__(llm, **kwargs)
         self.memorydb = memory
+        self._action = self._build_graph()

-    def remember(self, query):
-        memories = self.memorydb.retrieve(query)
+    def _remember(self, state: RecallState) -> str:
+        memories = self.memorydb.retrieve(state["query"])
         summarize_query = f"""
        You are being given the critical task of generating a detailed description of logged information
        to an important official to make a decision. Summarize the following memories that are related to
        the statement. Ensure that any specific details that are important are retained in the summary.

-        Query: {query}
+        Query: {state["query"]}

        """

        for memory in memories:
            summarize_query += f"Memory: {memory} \n\n"
-        memory = self.llm.invoke(summarize_query).content
-        return memory
+        state["memory"] = self.llm.invoke(summarize_query).content
+        return state
+
+    def _build_graph(self):
+        graph = StateGraph(RecallState)
+
+        self.add_node(graph, self._remember)
+        graph.set_entry_point("_remember")
+        graph.set_finish_point("_remember")
+        return graph.compile(checkpointer=self.checkpointer)
+
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 100000, **_
+    ):
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+        if "query" not in inputs:
+            raise ("'query' is a required argument")
+
+        output = self._action.invoke(inputs, config)
+        return output["memory"]
ursa/agents/websearch_agent.py CHANGED
@@ -1,6 +1,6 @@
 # from langchain_community.tools import TavilySearchResults
 # from langchain_core.runnables.graph import MermaidDrawMethod
-from typing import Annotated, Any, List, Optional
+from typing import Annotated, Any, List, Mapping, Optional

 import requests
 from bs4 import BeautifulSoup
@@ -8,7 +8,7 @@ from langchain_community.tools import DuckDuckGoSearchResults
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
 from langchain_openai import ChatOpenAI
-from langgraph.graph import END, START, StateGraph
+from langgraph.graph import StateGraph
 from langgraph.graph.message import add_messages
 from langgraph.prebuilt import InjectedState, create_react_agent
 from pydantic import Field
@@ -57,9 +57,9 @@ class WebSearchAgent(BaseAgent):
         self.has_internet = self._check_for_internet(
             kwargs.get("url", "http://www.lanl.gov")
         )
-        self._initialize_agent()
+        self._build_graph()

-    def review_node(self, state: WebSearchState) -> WebSearchState:
+    def _review_node(self, state: WebSearchState) -> WebSearchState:
         if not self.has_internet:
             return {
                 "messages": [
@@ -78,7 +78,7 @@ class WebSearchAgent(BaseAgent):
         )
         return {"messages": [HumanMessage(content=res.content)]}

-    def response_node(self, state: WebSearchState) -> WebSearchState:
+    def _response_node(self, state: WebSearchState) -> WebSearchState:
         if not self.has_internet:
             return {
                 "messages": [
@@ -111,60 +111,50 @@ class WebSearchAgent(BaseAgent):
         except (requests.ConnectionError, requests.Timeout):
             return False

-    def state_store_node(self, state: WebSearchState) -> WebSearchState:
+    def _state_store_node(self, state: WebSearchState) -> WebSearchState:
         state["thread_id"] = self.thread_id
         return state
         # return dict(**state, thread_id=self.thread_id)

-    def _initialize_agent(self):
-        self.graph = StateGraph(WebSearchState)
-        self.graph.add_node("state_store", self.state_store_node)
-        self.graph.add_node(
-            "websearch",
-            create_react_agent(
-                self.llm,
-                self.tools,
-                state_schema=WebSearchState,
-                prompt=self.websearch_prompt,
-            ),
+    def _create_react(self, state: WebSearchState) -> WebSearchState:
+        react_agent = create_react_agent(
+            self.llm,
+            self.tools,
+            state_schema=WebSearchState,
+            prompt=self.websearch_prompt,
         )
-
-        self.graph.add_node("review", self.review_node)
-        self.graph.add_node("response", self.response_node)
-
-        self.graph.add_edge(START, "state_store")
-        self.graph.add_edge("state_store", "websearch")
-        self.graph.add_edge("websearch", "review")
-        self.graph.add_edge("response", END)
-
-        self.graph.add_conditional_edges(
-            "review",
+        return react_agent.invoke(state)
+
+    def _build_graph(self):
+        graph = StateGraph(WebSearchState)
+        self.add_node(graph, self._state_store_node)
+        self.add_node(graph, self._create_react)
+        self.add_node(graph, self._review_node)
+        self.add_node(graph, self._response_node)
+
+        graph.set_entry_point("_state_store_node")
+        graph.add_edge("_state_store_node", "_create_react")
+        graph.add_edge("_create_react", "_review_node")
+        graph.set_finish_point("_response_node")
+
+        graph.add_conditional_edges(
+            "_review_node",
             should_continue,
-            {"websearch": "websearch", "response": "response"},
-        )
-        self.action = self.graph.compile(checkpointer=self.checkpointer)
-        # self.action.get_graph().draw_mermaid_png(output_file_path="./websearch_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)
-
-    def run(self, prompt, recursion_limit=100):
-        if not self.has_internet:
-            return {
-                "messages": [
-                    HumanMessage(
-                        content="No internet for WebSearch Agent. No research carried out."
-                    )
-                ]
-            }
-        inputs = {
-            "messages": [HumanMessage(content=prompt)],
-            "model": self.llm,
-        }
-        return self.action.invoke(
-            inputs,
             {
-                "recursion_limit": recursion_limit,
-                "configurable": {"thread_id": self.thread_id},
+                "_create_react": "_create_react",
+                "_response_node": "_response_node",
             },
         )
+        self._action = graph.compile(checkpointer=self.checkpointer)
+        # self._action.get_graph().draw_mermaid_png(output_file_path="./websearch_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)
+
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 1000, **_
+    ):
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+        return self._action.invoke(inputs, config)


 def process_content(
@@ -204,10 +194,10 @@ search_tool = DuckDuckGoSearchResults(output_format="json", num_results=10)

 def should_continue(state: WebSearchState):
     if len(state["messages"]) > (state.get("max_websearch_steps", 100) + 3):
-        return "response"
+        return "_response_node"
     if "[APPROVED]" in state["messages"][-1].content:
-        return "response"
-    return "websearch"
+        return "_response_node"
+    return "_create_react"


 def main():
@@ -220,7 +210,7 @@ def main():
         "messages": [HumanMessage(content=problem_string)],
         "model": model,
     }
-    result = websearcher.action.invoke(
+    result = websearcher.invoke(
        inputs,
        {
            "recursion_limit": 10000,
ursa/cli/__init__.py ADDED
@@ -0,0 +1,127 @@
+from pathlib import Path
+from typing import Annotated, Optional
+
+from rich.console import Console
+from typer import Option, Typer
+
+app = Typer()
+
+
+# TODO: add help
+@app.command()
+def run(
+    workspace: Annotated[
+        Path, Option(help="Directory to store ursa ouput")
+    ] = Path(".ursa"),
+    llm_model_name: Annotated[
+        str,
+        Option(
+            help="Name of LLM to use for agent tasks", envvar="URSA_LLM_NAME"
+        ),
+    ] = "gpt-5",
+    llm_base_url: Annotated[
+        str, Option(help="Base url for LLM.", envvar="URSA_LLM_BASE_URL")
+    ] = "https://api.openai.com/v1",
+    llm_api_key: Annotated[
+        Optional[str], Option(help="API key for LLM", envvar="URSA_LLM_API_KEY")
+    ] = None,
+    max_completion_tokens: Annotated[
+        int, Option(help="Maximum tokens for LLM to output")
+    ] = 50000,
+    emb_model_name: Annotated[
+        str, Option(help="Embedding model name", envvar="URSA_EMB_NAME")
+    ] = "text-embedding-3-small",
+    emb_base_url: Annotated[
+        str,
+        Option(help="Base url for embedding model", envvar="URSA_EMB_BASE_URL"),
+    ] = "https://api.openai.com/v1",
+    emb_api_key: Annotated[
+        Optional[str],
+        Option(help="API key for embedding model", envvar="URSA_EMB_API_KEY"),
+    ] = None,
+    share_key: Annotated[
+        bool,
+        Option(
+            help=(
+                "Whether or not the LLM and embedding model share the same "
+                "API key. If yes, then you can specify only one of them."
+            )
+        ),
+    ] = False,
+    arxiv_summarize: Annotated[
+        bool,
+        Option(
+            help="Whether or not to allow ArxivAgent to summarize response."
+        ),
+    ] = True,
+    arxiv_process_images: Annotated[
+        bool,
+        Option(help="Whether or not to allow ArxivAgent to process images."),
+    ] = False,
+    arxiv_max_results: Annotated[
+        int,
+        Option(
+            help="Maximum number of results for ArxivAgent to retrieve from ArXiv."
+        ),
+    ] = 10,
+    arxiv_database_path: Annotated[
+        Optional[Path],
+        Option(
+            help="Path to download/downloaded ArXiv documents; used by ArxivAgent."
+        ),
+    ] = None,
+    arxiv_summaries_path: Annotated[
+        Optional[Path],
+        Option(help="Path to store ArXiv paper summaries; used by ArxivAgent."),
+    ] = None,
+    arxiv_vectorstore_path: Annotated[
+        Optional[Path],
+        Option(
+            help="Path to store ArXiv paper vector store; used by ArxivAgent."
+        ),
+    ] = None,
+    arxiv_download_papers: Annotated[
+        bool,
+        Option(
+            help="Whether or not to allow ArxivAgent to download ArXiv papers."
+        ),
+    ] = True,
+    ssl_verify: Annotated[
+        bool, Option(help="Whether or not to verify SSL certificates.")
+    ] = True,
+) -> None:
+    console = Console()
+    with console.status("[grey50]Loading ursa ..."):
+        from ursa.cli.hitl import HITL, UrsaRepl
+
+        hitl = HITL(
+            workspace=workspace,
+            llm_model_name=llm_model_name,
+            llm_base_url=llm_base_url,
+            llm_api_key=llm_api_key,
+            max_completion_tokens=max_completion_tokens,
+            emb_model_name=emb_model_name,
+            emb_base_url=emb_base_url,
+            emb_api_key=emb_api_key,
+            share_key=share_key,
+            arxiv_summarize=arxiv_summarize,
+            arxiv_process_images=arxiv_process_images,
+            arxiv_max_results=arxiv_max_results,
+            arxiv_database_path=arxiv_database_path,
+            arxiv_summaries_path=arxiv_summaries_path,
+            arxiv_vectorstore_path=arxiv_vectorstore_path,
+            arxiv_download_papers=arxiv_download_papers,
+            ssl_verify=ssl_verify,
+        )
+    UrsaRepl(hitl).run()
+
+
+@app.command()
+def version() -> None:
+    from importlib.metadata import version as get_version
+
+    print(get_version("ursa-ai"))
+
+
+def main():
+    app()
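The new Typer app wires up a `run` command, which defers the heavy `ursa.cli.hitl` import until the console spinner is showing, and a `version` command. A hedged smoke-test sketch: the installed console-script name is not visible in this diff, so this drives the app programmatically via Typer's test runner instead of guessing an entry-point name.

# Hypothetical smoke test for the new CLI; invokes the Typer app directly
# rather than assuming the name of the installed console script.
from typer.testing import CliRunner

from ursa.cli import app

runner = CliRunner()
result = runner.invoke(app, ["version"])
print(result.stdout)  # the installed ursa-ai version

# Option names follow Typer's kebab-casing of parameters, e.g.:
#   run --workspace .ursa --llm-model-name gpt-5 --emb-model-name text-embedding-3-small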