ursa-ai 0.4.2__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ursa-ai might be problematic.
- ursa/__init__.py +0 -0
- ursa/agents/__init__.py +2 -0
- ursa/agents/arxiv_agent.py +88 -99
- ursa/agents/base.py +369 -2
- ursa/agents/code_review_agent.py +3 -1
- ursa/agents/execution_agent.py +92 -48
- ursa/agents/hypothesizer_agent.py +39 -42
- ursa/agents/lammps_agent.py +51 -29
- ursa/agents/mp_agent.py +45 -20
- ursa/agents/optimization_agent.py +405 -0
- ursa/agents/planning_agent.py +63 -28
- ursa/agents/rag_agent.py +303 -0
- ursa/agents/recall_agent.py +35 -5
- ursa/agents/websearch_agent.py +44 -54
- ursa/cli/__init__.py +127 -0
- ursa/cli/hitl.py +426 -0
- ursa/observability/pricing.py +319 -0
- ursa/observability/timing.py +1441 -0
- ursa/prompt_library/__init__.py +0 -0
- ursa/prompt_library/execution_prompts.py +7 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/tools/__init__.py +0 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/tools/write_code.py +1 -1
- ursa/util/__init__.py +0 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/optimization_schema.py +78 -0
- ursa/util/parse.py +1 -1
- {ursa_ai-0.4.2.dist-info → ursa_ai-0.6.0.dist-info}/METADATA +123 -4
- ursa_ai-0.6.0.dist-info/RECORD +43 -0
- ursa_ai-0.6.0.dist-info/entry_points.txt +2 -0
- ursa_ai-0.4.2.dist-info/RECORD +0 -27
- {ursa_ai-0.4.2.dist-info → ursa_ai-0.6.0.dist-info}/WHEEL +0 -0
- {ursa_ai-0.4.2.dist-info → ursa_ai-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {ursa_ai-0.4.2.dist-info → ursa_ai-0.6.0.dist-info}/top_level.txt +0 -0
ursa/agents/rag_agent.py
ADDED
@@ -0,0 +1,303 @@
+import os
+import re
+import statistics
+from functools import cached_property
+from threading import Lock
+from typing import Any, Mapping, TypedDict
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_chroma import Chroma
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_core.embeddings import Embeddings
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langgraph.graph import StateGraph
+from tqdm import tqdm
+
+from ursa.agents.base import BaseAgent
+
+
+class RAGMetadata(TypedDict):
+    k: int
+    num_results: int
+    relevance_scores: list[float]
+
+
+class RAGState(TypedDict, total=False):
+    context: str
+    doc_texts: list[str]
+    doc_ids: list[str]
+    summary: str
+    rag_metadata: RAGMetadata
+
+
+def remove_surrogates(text: str) -> str:
+    return re.sub(r"[\ud800-\udfff]", "", text)
+
+
+class RAGAgent(BaseAgent):
+    def __init__(
+        self,
+        embedding: Embeddings,
+        llm="openai/o3-mini",
+        return_k: int = 10,
+        chunk_size: int = 1000,
+        chunk_overlap: int = 200,
+        database_path: str = "database",
+        summaries_path: str = "database",
+        vectorstore_path: str = "vectorstore",
+        **kwargs,
+    ):
+        super().__init__(llm, **kwargs)
+        self.retriever = None
+        self._vs_lock = Lock()
+        self.return_k = return_k
+        self.embedding = embedding
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.database_path = database_path
+        self.summaries_path = summaries_path
+        self.vectorstore_path = vectorstore_path
+
+        os.makedirs(self.vectorstore_path, exist_ok=True)
+        self.vectorstore = self._open_global_vectorstore()
+
+    @cached_property
+    def graph(self):
+        return self._build_graph()
+
+    @property
+    def _action(self):
+        return self.graph
+
+    @property
+    def manifest_path(self) -> str:
+        return os.path.join(self.vectorstore_path, "_ingested_ids.txt")
+
+    @property
+    def manifest_exists(self) -> bool:
+        return os.path.exists(self.manifest_path)
+
+    def _open_global_vectorstore(self) -> Chroma:
+        return Chroma(
+            persist_directory=self.vectorstore_path,
+            embedding_function=self.embedding,
+            collection_metadata={"hnsw:space": "cosine"},
+        )
+
+    def _paper_exists_in_vectorstore(self, doc_id: str) -> bool:
+        try:
+            col = self.vectorstore._collection
+            res = col.get(where={"id": doc_id}, limit=1)
+            return len(res.get("ids", [])) > 0
+        except Exception:
+            if not self.manifest_exists:
+                return False
+            with open(self.manifest_path, "r") as f:
+                return any(line.strip() == doc_id for line in f)
+
+    def _mark_paper_ingested(self, arxiv_id: str) -> None:
+        with open(self.manifest_path, "a") as f:
+            f.write(f"{arxiv_id}\n")
+
+    def _ensure_doc_in_vectorstore(self, paper_text: str, doc_id: str) -> None:
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+        )
+        docs = splitter.create_documents(
+            [paper_text], metadatas=[{"id": doc_id}]
+        )
+        with self._vs_lock:
+            if not self._paper_exists_in_vectorstore(doc_id):
+                ids = [f"{doc_id}::{i}" for i, _ in enumerate(docs)]
+                self.vectorstore.add_documents(docs, ids=ids)
+                self._mark_paper_ingested(doc_id)
+
+    def _get_global_retriever(self, k: int = 5):
+        return self.vectorstore, self.vectorstore.as_retriever(
+            search_kwargs={"k": k}
+        )
+
+    def _read_docs_node(self, state: RAGState) -> RAGState:
+        print("[RAG Agent] Reading Documents....")
+        papers = []
+        new_state = state.copy()
+
+        pdf_files = [
+            f
+            for f in os.listdir(self.database_path)
+            if f.lower().endswith(".pdf")
+        ]
+
+        doc_ids = [
+            pdf_filename.rsplit(".pdf", 1)[0] for pdf_filename in pdf_files
+        ]
+        pdf_files = [
+            pdf_filename
+            for pdf_filename, id in zip(pdf_files, doc_ids)
+            if not self._paper_exists_in_vectorstore(id)
+        ]
+
+        for pdf_filename in tqdm(pdf_files, desc="RAG parsing text"):
+            full_text = ""
+
+            try:
+                loader = PyPDFLoader(
+                    os.path.join(self.database_path, pdf_filename)
+                )
+                pages = loader.load()
+                full_text = "\n".join([p.page_content for p in pages])
+
+            except Exception as e:
+                full_text = f"Error loading paper: {e}"
+
+            papers.append(full_text)
+
+        new_state["doc_texts"] = papers
+        new_state["doc_ids"] = doc_ids
+
+        return new_state
+
+    def _ingest_docs_node(self, state: RAGState) -> RAGState:
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+        )
+
if "doc_texts" not in state:
|
|
167
|
+
raise RuntimeError("Unexpected error: doc_ids not in state!")
|
|
168
|
+
|
|
169
|
+
if "doc_ids" not in state:
|
|
170
|
+
raise RuntimeError("Unexpected error: doc_texts not in state!")
+
+        batch_docs, batch_ids = [], []
+        for paper, id in tqdm(
+            zip(state["doc_texts"], state["doc_ids"]),
+            total=len(state["doc_texts"]),
+            desc="RAG Ingesting",
+        ):
+            cleaned_text = remove_surrogates(paper)
+            docs = splitter.create_documents(
+                [cleaned_text], metadatas=[{"id": id}]
+            )
+            ids = [f"{id}::{i}" for i, _ in enumerate(docs)]
+            batch_docs.extend(docs)
+            batch_ids.extend(ids)
+
+        if state["doc_texts"]:
+            print("[RAG Agent] Ingesting Documents Into RAG Database....")
+            with self._vs_lock:
+                self.vectorstore.add_documents(batch_docs, ids=batch_ids)
+            for id in batch_ids:
+                self._mark_paper_ingested(id)
+
+        return state
+
+    def _retrieve_and_summarize_node(self, state: RAGState) -> RAGState:
+        print(
+            "[RAG Agent] Retrieving Contextually Relevant Information From Database..."
+        )
+        prompt = ChatPromptTemplate.from_template("""
+You are a scientific assistant responsible for summarizing extracts from research papers, in the context of the following task: {context}
+
+Summarize the retrieved scientific content below.
+Cite sources by ID when relevant: {source_ids}
+
+{retrieved_content}
+""")
+        chain = prompt | self.llm | StrOutputParser()
+
+        # 2) One retrieval over the global DB with the task context
+        try:
+            if "context" not in state:
+                raise RuntimeError("Unexpected error: context not in state!")
+
+            results = self.vectorstore.similarity_search_with_relevance_scores(
+                state["context"], k=self.return_k
+            )
+
+            relevance_scores = [score for _, score in results]
+        except Exception as e:
+            print(f"RAG failed due to: {e}")
+            return {**state, "summary": ""}
+
+        source_ids_list = []
+        for doc, _ in results:
+            aid = doc.metadata.get("id")
+            if aid and aid not in source_ids_list:
+                source_ids_list.append(aid)
+        source_ids = ", ".join(source_ids_list)
+
+        retrieved_content = (
+            "\n\n".join(doc.page_content for doc, _ in results)
+            if results
+            else ""
+        )
+
+        print("[RAG Agent] Summarizing Retrieved Information From Database...")
+        # 3) One summary based on retrieved chunks
+        rag_summary = chain.invoke({
+            "retrieved_content": retrieved_content,
+            "context": state["context"],
+            "source_ids": source_ids,
+        })
+
+        # Persist a single file for the batch (optional)
+        batch_name = "RAG_summary.txt"
+        os.makedirs(self.summaries_path, exist_ok=True)
+        with open(os.path.join(self.summaries_path, batch_name), "w") as f:
+            f.write(rag_summary)
+
+        # Diagnostics
+        if relevance_scores:
+            print(f"\nMax Relevance Score: {max(relevance_scores):.4f}")
+            print(f"Min Relevance Score: {min(relevance_scores):.4f}")
+            print(
+                f"Median Relevance Score: {statistics.median(relevance_scores):.4f}\n"
+            )
+        else:
+            print("\nNo RAG results retrieved (score list empty).\n")
+
+        # Return a single-element list by default (preferred)
+        return {
+            **state,
+            "summary": rag_summary,
+            "rag_metadata": {
+                "k": self.return_k,
+                "num_results": len(results),
+                "relevance_scores": relevance_scores,
+            },
+        }
+
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 100000, **_
+    ):
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+        return self._action.invoke(inputs, config)
+
+    def _build_graph(self):
+        graph = StateGraph(RAGState)
+
+        self.add_node(graph, self._read_docs_node)
+        self.add_node(graph, self._ingest_docs_node)
+        self.add_node(graph, self._retrieve_and_summarize_node)
+
+        graph.add_edge("_read_docs_node", "_ingest_docs_node")
+        graph.add_edge("_ingest_docs_node", "_retrieve_and_summarize_node")
+
+        graph.set_entry_point("_read_docs_node")
+        graph.set_finish_point("_retrieve_and_summarize_node")
+
+        return graph.compile(checkpointer=self.checkpointer)
+
+
+# NOTE: Run test in `tests/agents/test_rag_agent/test_rag_agent.py` via:
+#
+# pytest -s tests/agents/test_rag_agent
+#
+# OR
+#
+# uv run pytest -s tests/agents/test_rag_agent
+#
+# NOTE: You may need to `rm -rf workspace/rag-agent` to remove the vectorstore.
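For orientation, here is a minimal usage sketch of the new RAGAgent. It is an illustration, not part of the diff: the public invoke() entry point is assumed from BaseAgent (the websearch main() below calls agents the same way), OpenAIEmbeddings stands in for any langchain Embeddings implementation, and the paths and query are hypothetical.

from langchain_openai import OpenAIEmbeddings

from ursa.agents.rag_agent import RAGAgent

# Hypothetical setup: any langchain Embeddings implementation works here.
agent = RAGAgent(
    embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
    llm="openai/o3-mini",        # default model name from the diff
    database_path="papers",      # directory of PDFs to ingest (assumed)
    vectorstore_path="vectorstore",
)

# The compiled graph reads PDFs, ingests unseen chunks into Chroma,
# then retrieves and summarizes against the task context.
result = agent.invoke({"context": "surrogate models for PDE solvers"})
print(result["summary"])
print(result["rag_metadata"]["num_results"])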
ursa/agents/recall_agent.py
CHANGED
@@ -1,23 +1,53 @@
+from typing import Any, Mapping, TypedDict
+
+from langgraph.graph import StateGraph
+
 from .base import BaseAgent
 
 
+class RecallState(TypedDict):
+    query: str
+    memory: str
+
+
 class RecallAgent(BaseAgent):
     def __init__(self, llm, memory, **kwargs):
         super().__init__(llm, **kwargs)
         self.memorydb = memory
+        self._action = self._build_graph()
 
-    def
-        memories = self.memorydb.retrieve(query)
+    def _remember(self, state: RecallState) -> RecallState:
+        memories = self.memorydb.retrieve(state["query"])
         summarize_query = f"""
         You are being given the critical task of generating a detailed description of logged information
         to an important official to make a decision. Summarize the following memories that are related to
         the statement. Ensure that any specific details that are important are retained in the summary.
 
-        Query: {query}
+        Query: {state["query"]}
 
         """
 
         for memory in memories:
             summarize_query += f"Memory: {memory} \n\n"
-        memory = self.llm.invoke(summarize_query).content
-        return
+        state["memory"] = self.llm.invoke(summarize_query).content
+        return state
+
+    def _build_graph(self):
+        graph = StateGraph(RecallState)
+
+        self.add_node(graph, self._remember)
+        graph.set_entry_point("_remember")
+        graph.set_finish_point("_remember")
+        return graph.compile(checkpointer=self.checkpointer)
+
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 100000, **_
+    ):
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
if "query" not in inputs:
|
|
50
|
+
raise ("'query' is a required argument")
+
+        output = self._action.invoke(inputs, config)
+        return output["memory"]
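As above, a hypothetical usage sketch (not part of the diff) for the reworked RecallAgent: the memory store is any object exposing retrieve(query), ToyMemory and the model name are stand-ins, and the public invoke() wrapper around _invoke is assumed from BaseAgent.

from langchain_openai import ChatOpenAI

from ursa.agents.recall_agent import RecallAgent


class ToyMemory:
    # Stand-in backend: the agent only needs a retrieve(query) method.
    def __init__(self, memories):
        self.memories = memories

    def retrieve(self, query):
        # Naive keyword overlap; a real store would use vector search.
        return [m for m in self.memories if any(w in m for w in query.split())]


memory = ToyMemory(["2024-01-02: run 17 diverged at step 300."])
agent = RecallAgent(ChatOpenAI(model="gpt-4o-mini"), memory)

# _invoke validates that "query" is present and returns the summary string.
print(agent.invoke({"query": "what happened to run 17"}))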
ursa/agents/websearch_agent.py
CHANGED
@@ -1,6 +1,6 @@
 # from langchain_community.tools import TavilySearchResults
 # from langchain_core.runnables.graph import MermaidDrawMethod
-from typing import Annotated, Any, List, Optional
+from typing import Annotated, Any, List, Mapping, Optional
 
 import requests
 from bs4 import BeautifulSoup
@@ -8,7 +8,7 @@ from langchain_community.tools import DuckDuckGoSearchResults
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
 from langchain_openai import ChatOpenAI
-from langgraph.graph import
+from langgraph.graph import StateGraph
 from langgraph.graph.message import add_messages
 from langgraph.prebuilt import InjectedState, create_react_agent
 from pydantic import Field
@@ -57,9 +57,9 @@ class WebSearchAgent(BaseAgent):
         self.has_internet = self._check_for_internet(
             kwargs.get("url", "http://www.lanl.gov")
         )
-        self.
+        self._build_graph()
 
-    def
+    def _review_node(self, state: WebSearchState) -> WebSearchState:
         if not self.has_internet:
             return {
                 "messages": [
@@ -78,7 +78,7 @@ class WebSearchAgent(BaseAgent):
         )
         return {"messages": [HumanMessage(content=res.content)]}
 
-    def
+    def _response_node(self, state: WebSearchState) -> WebSearchState:
         if not self.has_internet:
             return {
                 "messages": [
@@ -111,60 +111,50 @@
         except (requests.ConnectionError, requests.Timeout):
             return False
 
-    def
+    def _state_store_node(self, state: WebSearchState) -> WebSearchState:
         state["thread_id"] = self.thread_id
         return state
         # return dict(**state, thread_id=self.thread_id)
 
-    def
-
-
-
-
-
-            self.llm,
-            self.tools,
-            state_schema=WebSearchState,
-            prompt=self.websearch_prompt,
-        ),
+    def _create_react(self, state: WebSearchState) -> WebSearchState:
+        react_agent = create_react_agent(
+            self.llm,
+            self.tools,
+            state_schema=WebSearchState,
+            prompt=self.websearch_prompt,
         )
-
-
-
-
-        self.graph
-        self.graph
-        self.graph
-        self.graph
-
-
-
+        return react_agent.invoke(state)
+
+    def _build_graph(self):
+        graph = StateGraph(WebSearchState)
+        self.add_node(graph, self._state_store_node)
+        self.add_node(graph, self._create_react)
+        self.add_node(graph, self._review_node)
+        self.add_node(graph, self._response_node)
+
+        graph.set_entry_point("_state_store_node")
+        graph.add_edge("_state_store_node", "_create_react")
+        graph.add_edge("_create_react", "_review_node")
+        graph.set_finish_point("_response_node")
+
+        graph.add_conditional_edges(
+            "_review_node",
             should_continue,
-            {"websearch": "websearch", "response": "response"},
-        )
-        self.action = self.graph.compile(checkpointer=self.checkpointer)
-        # self.action.get_graph().draw_mermaid_png(output_file_path="./websearch_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)
-
-    def run(self, prompt, recursion_limit=100):
-        if not self.has_internet:
-            return {
-                "messages": [
-                    HumanMessage(
-                        content="No internet for WebSearch Agent. No research carried out."
-                    )
-                ]
-            }
-        inputs = {
-            "messages": [HumanMessage(content=prompt)],
-            "model": self.llm,
-        }
-        return self.action.invoke(
-            inputs,
             {
-                "
-                "
+                "_create_react": "_create_react",
+                "_response_node": "_response_node",
             },
         )
+        self._action = graph.compile(checkpointer=self.checkpointer)
+        # self._action.get_graph().draw_mermaid_png(output_file_path="./websearch_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)
+
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 1000, **_
+    ):
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+        return self._action.invoke(inputs, config)
 
 
 def process_content(
@@ -204,10 +194,10 @@ search_tool = DuckDuckGoSearchResults(output_format="json", num_results=10)
 
 def should_continue(state: WebSearchState):
     if len(state["messages"]) > (state.get("max_websearch_steps", 100) + 3):
-        return "
+        return "_response_node"
     if "[APPROVED]" in state["messages"][-1].content:
-        return "
-    return "
+        return "_response_node"
+    return "_create_react"
 
 
 def main():
@@ -220,7 +210,7 @@ def main():
         "messages": [HumanMessage(content=problem_string)],
         "model": model,
     }
-    result = websearcher.
+    result = websearcher.invoke(
         inputs,
         {
             "recursion_limit": 10000,
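The renamed routing targets above are easy to sanity-check in isolation. A small sketch (not part of the diff; the message contents are stand-ins) of how should_continue steers the graph between the react loop and the response node:

from langchain_core.messages import AIMessage

from ursa.agents.websearch_agent import should_continue

# Reviewer has not approved yet: control returns to the react agent.
state = {"messages": [AIMessage(content="needs more sources")]}
assert should_continue(state) == "_create_react"

# Reviewer approval routes to the response node.
state = {"messages": [AIMessage(content="[APPROVED] done")]}
assert should_continue(state) == "_response_node"

# The step budget also forces a response once the message count exceeds
# max_websearch_steps + 3 (default budget: 100).
state = {
    "messages": [AIMessage(content="x")] * 110,
    "max_websearch_steps": 100,
}
assert should_continue(state) == "_response_node"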
ursa/cli/__init__.py
ADDED
@@ -0,0 +1,127 @@
+from pathlib import Path
+from typing import Annotated, Optional
+
+from rich.console import Console
+from typer import Option, Typer
+
+app = Typer()
+
+
+# TODO: add help
+@app.command()
+def run(
+    workspace: Annotated[
+        Path, Option(help="Directory to store ursa output")
+    ] = Path(".ursa"),
+    llm_model_name: Annotated[
+        str,
+        Option(
+            help="Name of LLM to use for agent tasks", envvar="URSA_LLM_NAME"
+        ),
+    ] = "gpt-5",
+    llm_base_url: Annotated[
+        str, Option(help="Base url for LLM.", envvar="URSA_LLM_BASE_URL")
+    ] = "https://api.openai.com/v1",
+    llm_api_key: Annotated[
+        Optional[str], Option(help="API key for LLM", envvar="URSA_LLM_API_KEY")
+    ] = None,
+    max_completion_tokens: Annotated[
+        int, Option(help="Maximum tokens for LLM to output")
+    ] = 50000,
+    emb_model_name: Annotated[
+        str, Option(help="Embedding model name", envvar="URSA_EMB_NAME")
+    ] = "text-embedding-3-small",
+    emb_base_url: Annotated[
+        str,
+        Option(help="Base url for embedding model", envvar="URSA_EMB_BASE_URL"),
+    ] = "https://api.openai.com/v1",
+    emb_api_key: Annotated[
+        Optional[str],
+        Option(help="API key for embedding model", envvar="URSA_EMB_API_KEY"),
+    ] = None,
+    share_key: Annotated[
+        bool,
+        Option(
+            help=(
+                "Whether or not the LLM and embedding model share the same "
+                "API key. If yes, then you can specify only one of them."
+            )
+        ),
+    ] = False,
+    arxiv_summarize: Annotated[
+        bool,
+        Option(
+            help="Whether or not to allow ArxivAgent to summarize response."
+        ),
+    ] = True,
+    arxiv_process_images: Annotated[
+        bool,
+        Option(help="Whether or not to allow ArxivAgent to process images."),
+    ] = False,
+    arxiv_max_results: Annotated[
+        int,
+        Option(
+            help="Maximum number of results for ArxivAgent to retrieve from ArXiv."
+        ),
+    ] = 10,
+    arxiv_database_path: Annotated[
+        Optional[Path],
+        Option(
+            help="Path to download/downloaded ArXiv documents; used by ArxivAgent."
+        ),
+    ] = None,
+    arxiv_summaries_path: Annotated[
+        Optional[Path],
+        Option(help="Path to store ArXiv paper summaries; used by ArxivAgent."),
+    ] = None,
+    arxiv_vectorstore_path: Annotated[
+        Optional[Path],
+        Option(
+            help="Path to store ArXiv paper vector store; used by ArxivAgent."
+        ),
+    ] = None,
+    arxiv_download_papers: Annotated[
+        bool,
+        Option(
+            help="Whether or not to allow ArxivAgent to download ArXiv papers."
+        ),
+    ] = True,
+    ssl_verify: Annotated[
+        bool, Option(help="Whether or not to verify SSL certificates.")
+    ] = True,
+) -> None:
+    console = Console()
+    with console.status("[grey50]Loading ursa ..."):
+        from ursa.cli.hitl import HITL, UrsaRepl
+
+    hitl = HITL(
+        workspace=workspace,
+        llm_model_name=llm_model_name,
+        llm_base_url=llm_base_url,
+        llm_api_key=llm_api_key,
+        max_completion_tokens=max_completion_tokens,
+        emb_model_name=emb_model_name,
+        emb_base_url=emb_base_url,
+        emb_api_key=emb_api_key,
+        share_key=share_key,
+        arxiv_summarize=arxiv_summarize,
+        arxiv_process_images=arxiv_process_images,
+        arxiv_max_results=arxiv_max_results,
+        arxiv_database_path=arxiv_database_path,
+        arxiv_summaries_path=arxiv_summaries_path,
+        arxiv_vectorstore_path=arxiv_vectorstore_path,
+        arxiv_download_papers=arxiv_download_papers,
+        ssl_verify=ssl_verify,
+    )
+    UrsaRepl(hitl).run()
+
+
+@app.command()
+def version() -> None:
+    from importlib.metadata import version as get_version
+
+    print(get_version("ursa-ai"))
+
+
+def main():
+    app()