ursa-ai 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ursa-ai might be problematic.
- ursa/__init__.py +0 -0
- ursa/agents/arxiv_agent.py +77 -47
- ursa/agents/base.py +369 -2
- ursa/agents/code_review_agent.py +3 -1
- ursa/agents/execution_agent.py +92 -48
- ursa/agents/hypothesizer_agent.py +39 -42
- ursa/agents/lammps_agent.py +51 -29
- ursa/agents/mp_agent.py +45 -20
- ursa/agents/optimization_agent.py +405 -0
- ursa/agents/planning_agent.py +63 -28
- ursa/agents/rag_agent.py +75 -44
- ursa/agents/recall_agent.py +35 -5
- ursa/agents/websearch_agent.py +44 -54
- ursa/cli/__init__.py +127 -0
- ursa/cli/hitl.py +426 -0
- ursa/observability/pricing.py +319 -0
- ursa/observability/timing.py +1441 -0
- ursa/prompt_library/__init__.py +0 -0
- ursa/prompt_library/execution_prompts.py +7 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/tools/__init__.py +0 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/tools/write_code.py +1 -1
- ursa/util/__init__.py +0 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/optimization_schema.py +78 -0
- ursa/util/parse.py +1 -1
- {ursa_ai-0.5.0.dist-info → ursa_ai-0.6.0.dist-info}/METADATA +123 -4
- ursa_ai-0.6.0.dist-info/RECORD +43 -0
- ursa_ai-0.6.0.dist-info/entry_points.txt +2 -0
- ursa_ai-0.5.0.dist-info/RECORD +0 -28
- {ursa_ai-0.5.0.dist-info → ursa_ai-0.6.0.dist-info}/WHEEL +0 -0
- {ursa_ai-0.5.0.dist-info → ursa_ai-0.6.0.dist-info}/licenses/LICENSE +0 -0
- {ursa_ai-0.5.0.dist-info → ursa_ai-0.6.0.dist-info}/top_level.txt +0 -0
ursa/agents/rag_agent.py
CHANGED
@@ -1,8 +1,9 @@
 import os
 import re
 import statistics
+from functools import cached_property
 from threading import Lock
-from typing import
+from typing import Any, Mapping, TypedDict
 
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_chroma import Chroma
@@ -11,15 +12,23 @@ from langchain_core.embeddings import Embeddings
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.prompts import ChatPromptTemplate
 from langgraph.graph import StateGraph
+from tqdm import tqdm
 
 from ursa.agents.base import BaseAgent
 
 
+class RAGMetadata(TypedDict):
+    k: int
+    num_results: int
+    relevance_scores: list[float]
+
+
 class RAGState(TypedDict, total=False):
     context: str
-    doc_texts:
-    doc_ids:
+    doc_texts: list[str]
+    doc_ids: list[str]
     summary: str
+    rag_metadata: RAGMetadata
 
 
 def remove_surrogates(text: str) -> str:
@@ -29,8 +38,8 @@ def remove_surrogates(text: str) -> str:
 class RAGAgent(BaseAgent):
     def __init__(
         self,
+        embedding: Embeddings,
         llm="openai/o3-mini",
-        embedding: Optional[Embeddings] = None,
         return_k: int = 10,
         chunk_size: int = 1000,
         chunk_overlap: int = 200,
@@ -49,11 +58,18 @@
         self.database_path = database_path
         self.summaries_path = summaries_path
         self.vectorstore_path = vectorstore_path
-        self.graph = self._build_graph()
 
         os.makedirs(self.vectorstore_path, exist_ok=True)
         self.vectorstore = self._open_global_vectorstore()
 
+    @cached_property
+    def graph(self):
+        return self._build_graph()
+
+    @property
+    def _action(self):
+        return self.graph
+
     @property
     def manifest_path(self) -> str:
         return os.path.join(self.vectorstore_path, "_ingested_ids.txt")
@@ -66,6 +82,7 @@
         return Chroma(
             persist_directory=self.vectorstore_path,
             embedding_function=self.embedding,
+            collection_metadata={"hnsw:space": "cosine"},
         )
 
     def _paper_exists_in_vectorstore(self, doc_id: str) -> bool:
@@ -101,7 +118,7 @@
             search_kwargs={"k": k}
         )
 
-    def
+    def _read_docs_node(self, state: RAGState) -> RAGState:
         print("[RAG Agent] Reading Documents....")
         papers = []
         new_state = state.copy()
@@ -121,7 +138,7 @@
             if not self._paper_exists_in_vectorstore(id)
         ]
 
-        for pdf_filename in pdf_files:
+        for pdf_filename in tqdm(pdf_files, desc="RAG parsing text"):
             full_text = ""
 
             try:
@@ -141,13 +158,23 @@
 
         return new_state
 
-    def
+    def _ingest_docs_node(self, state: RAGState) -> RAGState:
         splitter = RecursiveCharacterTextSplitter(
             chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
         )
 
+        if "doc_texts" not in state:
+            raise RuntimeError("Unexpected error: doc_texts not in state!")
+
+        if "doc_ids" not in state:
+            raise RuntimeError("Unexpected error: doc_ids not in state!")
+
         batch_docs, batch_ids = [], []
-        for paper, id in
+        for paper, id in tqdm(
+            zip(state["doc_texts"], state["doc_ids"]),
+            total=len(state["doc_texts"]),
+            desc="RAG Ingesting",
+        ):
             cleaned_text = remove_surrogates(paper)
             docs = splitter.create_documents(
                 [cleaned_text], metadatas=[{"id": id}]
@@ -160,12 +187,12 @@
         print("[RAG Agent] Ingesting Documents Into RAG Database....")
         with self._vs_lock:
             self.vectorstore.add_documents(batch_docs, ids=batch_ids)
-        for id in
+        for id in batch_ids:
             self._mark_paper_ingested(id)
 
         return state
 
-    def
+    def _retrieve_and_summarize_node(self, state: RAGState) -> RAGState:
         print(
             "[RAG Agent] Retrieving Contextually Relevant Information From Database..."
         )
@@ -181,9 +208,14 @@
 
         # 2) One retrieval over the global DB with the task context
         try:
-
+            if "context" not in state:
+                raise RuntimeError("Unexpected error: context not in state!")
+
+            results = self.vectorstore.similarity_search_with_relevance_scores(
                 state["context"], k=self.return_k
             )
+
+            relevance_scores = [score for _, score in results]
         except Exception as e:
             print(f"RAG failed due to: {e}")
             return {**state, "summary": ""}
@@ -195,13 +227,6 @@
             source_ids_list.append(aid)
         source_ids = ", ".join(source_ids_list)
 
-        # Compute a simple similarity-based quality score
-        relevancy_scores = []
-        if results:
-            distances = [score for _, score in results]
-            sims = [1.0 / (1.0 + d) for d in distances]  # map distance -> [0,1)
-            relevancy_scores = sims
-
         retrieved_content = (
             "\n\n".join(doc.page_content for doc, _ in results)
             if results
@@ -223,11 +248,11 @@
             f.write(rag_summary)
 
         # Diagnostics
-        if
-        print(f"\nMax
-        print(f"Min
+        if relevance_scores:
+            print(f"\nMax Relevance Score: {max(relevance_scores):.4f}")
+            print(f"Min Relevance Score: {min(relevance_scores):.4f}")
             print(
-            f"Median
+                f"Median Relevance Score: {statistics.median(relevance_scores):.4f}\n"
             )
         else:
             print("\nNo RAG results retrieved (score list empty).\n")
@@ -239,34 +264,40 @@
             "rag_metadata": {
                 "k": self.return_k,
                 "num_results": len(results),
-                "
+                "relevance_scores": relevance_scores,
             },
         }
 
-    def
-
-
-
-
-
-
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 100000, **_
+    ):
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+        return self._action.invoke(inputs, config)
 
-
-
+    def _build_graph(self):
+        graph = StateGraph(RAGState)
 
-        graph
-
+        self.add_node(graph, self._read_docs_node)
+        self.add_node(graph, self._ingest_docs_node)
+        self.add_node(graph, self._retrieve_and_summarize_node)
 
-
-
+        graph.add_edge("_read_docs_node", "_ingest_docs_node")
+        graph.add_edge("_ingest_docs_node", "_retrieve_and_summarize_node")
 
-
+        graph.set_entry_point("_read_docs_node")
+        graph.set_finish_point("_retrieve_and_summarize_node")
 
+        return graph.compile(checkpointer=self.checkpointer)
 
-if __name__ == "__main__":
-    agent = RAGAgent(database_path="workspace/arxiv_papers_neutron_star")
-    result = agent.run(
-        context="What are the constraints on the neutron star radius and what uncertainties are there on the constraints?",
-    )
 
-
+# NOTE: Run test in `tests/agents/test_rag_agent/test_rag_agent.py` via:
+#
+# pytest -s tests/agents/test_rag_agent
+#
+# OR
+#
+# uv run pytest -s tests/agents/test_rag_agent
+#
+# NOTE: You may need to `rm -rf workspace/rag-agent` to remove the vectorstore.
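
The net effect of these changes: `embedding` is now a required constructor argument, the Chroma collection is created with cosine distance, and the inline `__main__` demo is replaced by a compiled LangGraph pipeline reached through `_invoke`. A minimal usage sketch follows; it is not part of the diff, and it assumes BaseAgent exposes a public `invoke` wrapper around `_invoke` and that `OpenAIEmbeddings` plus the `workspace/...` path are acceptable stand-ins.

# Hypothetical sketch, not from the package source.
from langchain_openai import OpenAIEmbeddings

from ursa.agents.rag_agent import RAGAgent

agent = RAGAgent(
    embedding=OpenAIEmbeddings(model="text-embedding-3-small"),  # assumed model
    llm="openai/o3-mini",
    database_path="workspace/arxiv_papers",  # assumed directory of PDFs
)

# The compiled graph runs _read_docs_node -> _ingest_docs_node ->
# _retrieve_and_summarize_node; "context" seeds the retrieval query.
result = agent.invoke({"context": "What constrains the neutron star radius?"})
print(result["summary"])
print(result["rag_metadata"]["relevance_scores"])

Because the collection now declares `{"hnsw:space": "cosine"}`, `similarity_search_with_relevance_scores` yields scores already normalized to [0, 1], which is presumably why the hand-rolled distance-to-similarity mapping (`1.0 / (1.0 + d)`) was dropped.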
ursa/agents/recall_agent.py
CHANGED
@@ -1,23 +1,53 @@
+from typing import Any, Mapping, TypedDict
+
+from langgraph.graph import StateGraph
+
 from .base import BaseAgent
 
 
+class RecallState(TypedDict):
+    query: str
+    memory: str
+
+
 class RecallAgent(BaseAgent):
     def __init__(self, llm, memory, **kwargs):
         super().__init__(llm, **kwargs)
         self.memorydb = memory
+        self._action = self._build_graph()
 
-    def
-        memories = self.memorydb.retrieve(query)
+    def _remember(self, state: RecallState) -> str:
+        memories = self.memorydb.retrieve(state["query"])
         summarize_query = f"""
         You are being given the critical task of generating a detailed description of logged information
         to an important official to make a decision. Summarize the following memories that are related to
         the statement. Ensure that any specific details that are important are retained in the summary.
 
-        Query: {query}
+        Query: {state["query"]}
 
         """
 
         for memory in memories:
             summarize_query += f"Memory: {memory} \n\n"
-        memory = self.llm.invoke(summarize_query).content
-        return
+        state["memory"] = self.llm.invoke(summarize_query).content
+        return state
+
+    def _build_graph(self):
+        graph = StateGraph(RecallState)
+
+        self.add_node(graph, self._remember)
+        graph.set_entry_point("_remember")
+        graph.set_finish_point("_remember")
+        return graph.compile(checkpointer=self.checkpointer)
+
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 100000, **_
+    ):
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+        if "query" not in inputs:
+            raise ValueError("'query' is a required argument")
+
+        output = self._action.invoke(inputs, config)
+        return output["memory"]
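
A usage sketch for the reworked RecallAgent, again assuming a public `invoke` wrapper on BaseAgent; the `TinyMemory` stand-in is hypothetical and only has to provide the `retrieve(query)` method the agent calls.

# Hypothetical sketch, not from the package source.
from langchain_openai import ChatOpenAI

from ursa.agents.recall_agent import RecallAgent


class TinyMemory:
    """Stand-in memory store; RecallAgent only needs retrieve(query)."""

    def __init__(self, entries: list[str]):
        self.entries = entries

    def retrieve(self, query: str) -> list[str]:
        return [e for e in self.entries if query.lower() in e.lower()]


agent = RecallAgent(
    llm=ChatOpenAI(model="gpt-4o-mini"),  # assumed model
    memory=TinyMemory(["2025-01-10: adopted pytest for the test harness"]),
)

# _invoke rejects inputs without "query" and returns just the summary string.
print(agent.invoke({"query": "test harness"}))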
ursa/agents/websearch_agent.py
CHANGED
@@ -1,6 +1,6 @@
 # from langchain_community.tools import TavilySearchResults
 # from langchain_core.runnables.graph import MermaidDrawMethod
-from typing import Annotated, Any, List, Optional
+from typing import Annotated, Any, List, Mapping, Optional
 
 import requests
 from bs4 import BeautifulSoup
@@ -8,7 +8,7 @@ from langchain_community.tools import DuckDuckGoSearchResults
 from langchain_core.language_models import BaseChatModel
 from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
 from langchain_openai import ChatOpenAI
-from langgraph.graph import
+from langgraph.graph import StateGraph
 from langgraph.graph.message import add_messages
 from langgraph.prebuilt import InjectedState, create_react_agent
 from pydantic import Field
@@ -57,9 +57,9 @@ class WebSearchAgent(BaseAgent):
         self.has_internet = self._check_for_internet(
             kwargs.get("url", "http://www.lanl.gov")
         )
-        self.
+        self._build_graph()
 
-    def
+    def _review_node(self, state: WebSearchState) -> WebSearchState:
         if not self.has_internet:
             return {
                 "messages": [
@@ -78,7 +78,7 @@
         )
         return {"messages": [HumanMessage(content=res.content)]}
 
-    def
+    def _response_node(self, state: WebSearchState) -> WebSearchState:
         if not self.has_internet:
             return {
                 "messages": [
@@ -111,60 +111,50 @@
         except (requests.ConnectionError, requests.Timeout):
             return False
 
-    def
+    def _state_store_node(self, state: WebSearchState) -> WebSearchState:
         state["thread_id"] = self.thread_id
         return state
         # return dict(**state, thread_id=self.thread_id)
 
-    def
-
-
-
-
-
-                self.llm,
-                self.tools,
-                state_schema=WebSearchState,
-                prompt=self.websearch_prompt,
-            ),
+    def _create_react(self, state: WebSearchState) -> WebSearchState:
+        react_agent = create_react_agent(
+            self.llm,
+            self.tools,
+            state_schema=WebSearchState,
+            prompt=self.websearch_prompt,
         )
-
-
-
-
-        self.graph
-        self.graph
-        self.graph
-        self.graph
-
-
-
+        return react_agent.invoke(state)
+
+    def _build_graph(self):
+        graph = StateGraph(WebSearchState)
+        self.add_node(graph, self._state_store_node)
+        self.add_node(graph, self._create_react)
+        self.add_node(graph, self._review_node)
+        self.add_node(graph, self._response_node)
+
+        graph.set_entry_point("_state_store_node")
+        graph.add_edge("_state_store_node", "_create_react")
+        graph.add_edge("_create_react", "_review_node")
+        graph.set_finish_point("_response_node")
+
+        graph.add_conditional_edges(
+            "_review_node",
             should_continue,
-            {"websearch": "websearch", "response": "response"},
-        )
-        self.action = self.graph.compile(checkpointer=self.checkpointer)
-        # self.action.get_graph().draw_mermaid_png(output_file_path="./websearch_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)
-
-    def run(self, prompt, recursion_limit=100):
-        if not self.has_internet:
-            return {
-                "messages": [
-                    HumanMessage(
-                        content="No internet for WebSearch Agent. No research carried out."
-                    )
-                ]
-            }
-        inputs = {
-            "messages": [HumanMessage(content=prompt)],
-            "model": self.llm,
-        }
-        return self.action.invoke(
-            inputs,
             {
-                "
-                "
+                "_create_react": "_create_react",
+                "_response_node": "_response_node",
             },
         )
+        self._action = graph.compile(checkpointer=self.checkpointer)
+        # self._action.get_graph().draw_mermaid_png(output_file_path="./websearch_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)
+
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 1000, **_
+    ):
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+        return self._action.invoke(inputs, config)
 
 
 def process_content(
@@ -204,10 +194,10 @@ search_tool = DuckDuckGoSearchResults(output_format="json", num_results=10)
 
 def should_continue(state: WebSearchState):
     if len(state["messages"]) > (state.get("max_websearch_steps", 100) + 3):
-        return "
+        return "_response_node"
     if "[APPROVED]" in state["messages"][-1].content:
-        return "
-    return "
+        return "_response_node"
+    return "_create_react"
@@ -220,7 +210,7 @@ def main():
         "messages": [HumanMessage(content=problem_string)],
         "model": model,
     }
-    result = websearcher.
+    result = websearcher.invoke(
         inputs,
         {
             "recursion_limit": 10000,
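
The package's own `main()` (partially visible above) shows the new calling convention: the old `run(prompt)` helper is gone, and callers build the message/state dict themselves. A trimmed sketch, with the model choice and prompt as placeholders:

# Hypothetical sketch, not from the package source.
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

from ursa.agents.websearch_agent import WebSearchAgent

model = ChatOpenAI(model="gpt-4o-mini")  # assumed model
agent = WebSearchAgent(llm=model)

# should_continue loops _review_node back to _create_react until the review
# message contains "[APPROVED]" or the step budget runs out, then the graph
# exits through _response_node.
result = agent.invoke(
    {
        "messages": [HumanMessage(content="Find recent LIGO observation news.")],
        "model": model,
    },
    {"recursion_limit": 10000},
)
print(result["messages"][-1].content)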
ursa/cli/__init__.py
ADDED
@@ -0,0 +1,127 @@
+from pathlib import Path
+from typing import Annotated, Optional
+
+from rich.console import Console
+from typer import Option, Typer
+
+app = Typer()
+
+
+# TODO: add help
+@app.command()
+def run(
+    workspace: Annotated[
+        Path, Option(help="Directory to store ursa output")
+    ] = Path(".ursa"),
+    llm_model_name: Annotated[
+        str,
+        Option(
+            help="Name of LLM to use for agent tasks", envvar="URSA_LLM_NAME"
+        ),
+    ] = "gpt-5",
+    llm_base_url: Annotated[
+        str, Option(help="Base url for LLM.", envvar="URSA_LLM_BASE_URL")
+    ] = "https://api.openai.com/v1",
+    llm_api_key: Annotated[
+        Optional[str], Option(help="API key for LLM", envvar="URSA_LLM_API_KEY")
+    ] = None,
+    max_completion_tokens: Annotated[
+        int, Option(help="Maximum tokens for LLM to output")
+    ] = 50000,
+    emb_model_name: Annotated[
+        str, Option(help="Embedding model name", envvar="URSA_EMB_NAME")
+    ] = "text-embedding-3-small",
+    emb_base_url: Annotated[
+        str,
+        Option(help="Base url for embedding model", envvar="URSA_EMB_BASE_URL"),
+    ] = "https://api.openai.com/v1",
+    emb_api_key: Annotated[
+        Optional[str],
+        Option(help="API key for embedding model", envvar="URSA_EMB_API_KEY"),
+    ] = None,
+    share_key: Annotated[
+        bool,
+        Option(
+            help=(
+                "Whether or not the LLM and embedding model share the same "
+                "API key. If yes, then you can specify only one of them."
+            )
+        ),
+    ] = False,
+    arxiv_summarize: Annotated[
+        bool,
+        Option(
+            help="Whether or not to allow ArxivAgent to summarize response."
+        ),
+    ] = True,
+    arxiv_process_images: Annotated[
+        bool,
+        Option(help="Whether or not to allow ArxivAgent to process images."),
+    ] = False,
+    arxiv_max_results: Annotated[
+        int,
+        Option(
+            help="Maximum number of results for ArxivAgent to retrieve from ArXiv."
+        ),
+    ] = 10,
+    arxiv_database_path: Annotated[
+        Optional[Path],
+        Option(
+            help="Path to download/downloaded ArXiv documents; used by ArxivAgent."
+        ),
+    ] = None,
+    arxiv_summaries_path: Annotated[
+        Optional[Path],
+        Option(help="Path to store ArXiv paper summaries; used by ArxivAgent."),
+    ] = None,
+    arxiv_vectorstore_path: Annotated[
+        Optional[Path],
+        Option(
+            help="Path to store ArXiv paper vector store; used by ArxivAgent."
+        ),
+    ] = None,
+    arxiv_download_papers: Annotated[
+        bool,
+        Option(
+            help="Whether or not to allow ArxivAgent to download ArXiv papers."
+        ),
+    ] = True,
+    ssl_verify: Annotated[
+        bool, Option(help="Whether or not to verify SSL certificates.")
+    ] = True,
+) -> None:
+    console = Console()
+    with console.status("[grey50]Loading ursa ..."):
+        from ursa.cli.hitl import HITL, UrsaRepl
+
+    hitl = HITL(
+        workspace=workspace,
+        llm_model_name=llm_model_name,
+        llm_base_url=llm_base_url,
+        llm_api_key=llm_api_key,
+        max_completion_tokens=max_completion_tokens,
+        emb_model_name=emb_model_name,
+        emb_base_url=emb_base_url,
+        emb_api_key=emb_api_key,
+        share_key=share_key,
+        arxiv_summarize=arxiv_summarize,
+        arxiv_process_images=arxiv_process_images,
+        arxiv_max_results=arxiv_max_results,
+        arxiv_database_path=arxiv_database_path,
+        arxiv_summaries_path=arxiv_summaries_path,
+        arxiv_vectorstore_path=arxiv_vectorstore_path,
+        arxiv_download_papers=arxiv_download_papers,
+        ssl_verify=ssl_verify,
+    )
+    UrsaRepl(hitl).run()
+
+
+@app.command()
+def version() -> None:
+    from importlib.metadata import version as get_version
+
+    print(get_version("ursa-ai"))
+
+
+def main():
+    app()
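
Together with the new `entry_points.txt`, this module presumably installs a console script (the name `ursa` below is an assumption). The commands can also be exercised in-process with Typer's test runner:

# Hypothetical sketch, not from the package source.
from typer.testing import CliRunner

from ursa.cli import app

runner = CliRunner()

# `version` prints the installed ursa-ai version and exits.
result = runner.invoke(app, ["version"])
print(result.output)

# `run` drops into the human-in-the-loop REPL, so it is better launched from
# a shell, e.g.:  ursa run --workspace .ursa --llm-model-name gpt-5
# with secrets supplied via env vars such as URSA_LLM_API_KEY.

Note that Typer derives the flag names from the parameter names (`llm_model_name` becomes `--llm-model-name`), and each option's `envvar` gives an environment-variable fallback.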