ursa_ai-0.9.1-py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- ursa/__init__.py +3 -0
- ursa/agents/__init__.py +32 -0
- ursa/agents/acquisition_agents.py +812 -0
- ursa/agents/arxiv_agent.py +429 -0
- ursa/agents/base.py +728 -0
- ursa/agents/chat_agent.py +60 -0
- ursa/agents/code_review_agent.py +341 -0
- ursa/agents/execution_agent.py +915 -0
- ursa/agents/hypothesizer_agent.py +614 -0
- ursa/agents/lammps_agent.py +465 -0
- ursa/agents/mp_agent.py +204 -0
- ursa/agents/optimization_agent.py +410 -0
- ursa/agents/planning_agent.py +219 -0
- ursa/agents/rag_agent.py +304 -0
- ursa/agents/recall_agent.py +54 -0
- ursa/agents/websearch_agent.py +196 -0
- ursa/cli/__init__.py +363 -0
- ursa/cli/hitl.py +516 -0
- ursa/cli/hitl_api.py +75 -0
- ursa/observability/metrics_charts.py +1279 -0
- ursa/observability/metrics_io.py +11 -0
- ursa/observability/metrics_session.py +750 -0
- ursa/observability/pricing.json +97 -0
- ursa/observability/pricing.py +321 -0
- ursa/observability/timing.py +1466 -0
- ursa/prompt_library/__init__.py +0 -0
- ursa/prompt_library/code_review_prompts.py +51 -0
- ursa/prompt_library/execution_prompts.py +50 -0
- ursa/prompt_library/hypothesizer_prompts.py +17 -0
- ursa/prompt_library/literature_prompts.py +11 -0
- ursa/prompt_library/optimization_prompts.py +131 -0
- ursa/prompt_library/planning_prompts.py +79 -0
- ursa/prompt_library/websearch_prompts.py +131 -0
- ursa/tools/__init__.py +0 -0
- ursa/tools/feasibility_checker.py +114 -0
- ursa/tools/feasibility_tools.py +1075 -0
- ursa/tools/run_command.py +27 -0
- ursa/tools/write_code.py +42 -0
- ursa/util/__init__.py +0 -0
- ursa/util/diff_renderer.py +128 -0
- ursa/util/helperFunctions.py +142 -0
- ursa/util/logo_generator.py +625 -0
- ursa/util/memory_logger.py +183 -0
- ursa/util/optimization_schema.py +78 -0
- ursa/util/parse.py +405 -0
- ursa_ai-0.9.1.dist-info/METADATA +304 -0
- ursa_ai-0.9.1.dist-info/RECORD +51 -0
- ursa_ai-0.9.1.dist-info/WHEEL +5 -0
- ursa_ai-0.9.1.dist-info/entry_points.txt +2 -0
- ursa_ai-0.9.1.dist-info/licenses/LICENSE +8 -0
- ursa_ai-0.9.1.dist-info/top_level.txt +1 -0
ursa/agents/rag_agent.py
ADDED
@@ -0,0 +1,304 @@
+import os
+import re
+import statistics
+from functools import cached_property
+from threading import Lock
+from typing import Any, Mapping, TypedDict
+
+from langchain.chat_models import BaseChatModel
+from langchain.embeddings import Embeddings
+from langchain_chroma import Chroma
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langgraph.graph import StateGraph
+from tqdm import tqdm
+
+from ursa.agents.base import BaseAgent
+
+
+class RAGMetadata(TypedDict):
+    k: int
+    num_results: int
+    relevance_scores: list[float]
+
+
+class RAGState(TypedDict, total=False):
+    context: str
+    doc_texts: list[str]
+    doc_ids: list[str]
+    summary: str
+    rag_metadata: RAGMetadata
+
+
+def remove_surrogates(text: str) -> str:
+    return re.sub(r"[\ud800-\udfff]", "", text)
+
+
+class RAGAgent(BaseAgent):
+    def __init__(
+        self,
+        embedding: Embeddings,
+        llm: BaseChatModel,
+        return_k: int = 10,
+        chunk_size: int = 1000,
+        chunk_overlap: int = 200,
+        database_path: str = "database",
+        summaries_path: str = "database",
+        vectorstore_path: str = "vectorstore",
+        **kwargs,
+    ):
+        super().__init__(llm, **kwargs)
+        self.retriever = None
+        self._vs_lock = Lock()
+        self.return_k = return_k
+        self.embedding = embedding
+        self.chunk_size = chunk_size
+        self.chunk_overlap = chunk_overlap
+        self.database_path = database_path
+        self.summaries_path = summaries_path
+        self.vectorstore_path = vectorstore_path
+
+        os.makedirs(self.vectorstore_path, exist_ok=True)
+        self.vectorstore = self._open_global_vectorstore()
+
+    @cached_property
+    def graph(self):
+        return self._build_graph()
+
+    @property
+    def _action(self):
+        return self.graph
+
+    @property
+    def manifest_path(self) -> str:
+        return os.path.join(self.vectorstore_path, "_ingested_ids.txt")
+
+    @property
+    def manifest_exists(self) -> bool:
+        return os.path.exists(self.manifest_path)
+
+    def _open_global_vectorstore(self) -> Chroma:
+        return Chroma(
+            persist_directory=self.vectorstore_path,
+            embedding_function=self.embedding,
+            collection_metadata={"hnsw:space": "cosine"},
+        )
+
+    def _paper_exists_in_vectorstore(self, doc_id: str) -> bool:
+        try:
+            col = self.vectorstore._collection
+            res = col.get(where={"id": doc_id}, limit=1)
+            return len(res.get("ids", [])) > 0
+        except Exception:
+            # Fall back to the plain-text manifest if Chroma cannot be queried.
+            if not self.manifest_exists:
+                return False
+            with open(self.manifest_path, "r") as f:
+                return any(line.strip() == doc_id for line in f)
+
+    def _mark_paper_ingested(self, arxiv_id: str) -> None:
+        with open(self.manifest_path, "a") as f:
+            f.write(f"{arxiv_id}\n")
+
+    def _ensure_doc_in_vectorstore(self, paper_text: str, doc_id: str) -> None:
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+        )
+        docs = splitter.create_documents(
+            [paper_text], metadatas=[{"id": doc_id}]
+        )
+        with self._vs_lock:
+            if not self._paper_exists_in_vectorstore(doc_id):
+                ids = [f"{doc_id}::{i}" for i, _ in enumerate(docs)]
+                self.vectorstore.add_documents(docs, ids=ids)
+                self._mark_paper_ingested(doc_id)
+
+    def _get_global_retriever(self, k: int = 5):
+        return self.vectorstore, self.vectorstore.as_retriever(
+            search_kwargs={"k": k}
+        )
+
+    def _read_docs_node(self, state: RAGState) -> RAGState:
+        print("[RAG Agent] Reading Documents....")
+        papers = []
+        new_state = state.copy()
+
+        pdf_files = [
+            f
+            for f in os.listdir(self.database_path)
+            if f.lower().endswith(".pdf")
+        ]
+
+        doc_ids = [
+            pdf_filename.rsplit(".pdf", 1)[0] for pdf_filename in pdf_files
+        ]
+        # Filter filenames and ids together so doc_texts and doc_ids stay
+        # aligned for the zip in _ingest_docs_node.
+        kept = [
+            (pdf_filename, id)
+            for pdf_filename, id in zip(pdf_files, doc_ids)
+            if not self._paper_exists_in_vectorstore(id)
+        ]
+        pdf_files = [pdf_filename for pdf_filename, _ in kept]
+        doc_ids = [id for _, id in kept]
+
+        for pdf_filename in tqdm(pdf_files, desc="RAG parsing text"):
+            full_text = ""
+
+            try:
+                loader = PyPDFLoader(
+                    os.path.join(self.database_path, pdf_filename)
+                )
+                pages = loader.load()
+                full_text = "\n".join([p.page_content for p in pages])
+
+            except Exception as e:
+                full_text = f"Error loading paper: {e}"
+
+            papers.append(full_text)
+
+        new_state["doc_texts"] = papers
+        new_state["doc_ids"] = doc_ids
+
+        return new_state
+
+    def _ingest_docs_node(self, state: RAGState) -> RAGState:
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
+        )
+
+        if "doc_texts" not in state:
+            raise RuntimeError("Unexpected error: doc_texts not in state!")
+
+        if "doc_ids" not in state:
+            raise RuntimeError("Unexpected error: doc_ids not in state!")
+
+        batch_docs, batch_ids = [], []
+        for paper, id in tqdm(
+            zip(state["doc_texts"], state["doc_ids"]),
+            total=len(state["doc_texts"]),
+            desc="RAG Ingesting",
+        ):
+            cleaned_text = remove_surrogates(paper)
+            docs = splitter.create_documents(
+                [cleaned_text], metadatas=[{"id": id}]
+            )
+            ids = [f"{id}::{i}" for i, _ in enumerate(docs)]
+            batch_docs.extend(docs)
+            batch_ids.extend(ids)
+
+        if state["doc_texts"]:
+            print("[RAG Agent] Ingesting Documents Into RAG Database....")
+            with self._vs_lock:
+                self.vectorstore.add_documents(batch_docs, ids=batch_ids)
+                # Record document ids (not chunk ids) so the manifest fallback
+                # in _paper_exists_in_vectorstore can match them.
+                for id in state["doc_ids"]:
+                    self._mark_paper_ingested(id)
+
+        return state
+
+    def _retrieve_and_summarize_node(self, state: RAGState) -> RAGState:
+        print(
+            "[RAG Agent] Retrieving Contextually Relevant Information From Database..."
+        )
+        prompt = ChatPromptTemplate.from_template("""
+You are a scientific assistant responsible for summarizing extracts from research papers, in the context of the following task: {context}
+
+Summarize the retrieved scientific content below.
+Cite sources by ID when relevant: {source_ids}
+
+{retrieved_content}
+""")
+        chain = prompt | self.llm | StrOutputParser()
+
+        # One retrieval over the global DB with the task context
+        try:
+            if "context" not in state:
+                raise RuntimeError("Unexpected error: context not in state!")
+
+            results = self.vectorstore.similarity_search_with_relevance_scores(
+                state["context"], k=self.return_k
+            )
+
+            relevance_scores = [score for _, score in results]
+        except Exception as e:
+            print(f"RAG failed due to: {e}")
+            return {**state, "summary": ""}
+
+        source_ids_list = []
+        for doc, _ in results:
+            aid = doc.metadata.get("id")
+            if aid and aid not in source_ids_list:
+                source_ids_list.append(aid)
+        source_ids = ", ".join(source_ids_list)
+
+        retrieved_content = (
+            "\n\n".join(doc.page_content for doc, _ in results)
+            if results
+            else ""
+        )
+
+        print("[RAG Agent] Summarizing Retrieved Information From Database...")
+        # One summary based on the retrieved chunks
+        rag_summary = chain.invoke({
+            "retrieved_content": retrieved_content,
+            "context": state["context"],
+            "source_ids": source_ids,
+        })
+
+        # Persist a single file for the batch (optional)
+        batch_name = "RAG_summary.txt"
+        os.makedirs(self.summaries_path, exist_ok=True)
+        with open(os.path.join(self.summaries_path, batch_name), "w") as f:
+            f.write(rag_summary)
+
+        # Diagnostics
+        if relevance_scores:
+            print(f"\nMax Relevance Score: {max(relevance_scores):.4f}")
+            print(f"Min Relevance Score: {min(relevance_scores):.4f}")
+            print(
+                f"Median Relevance Score: {statistics.median(relevance_scores):.4f}\n"
+            )
+        else:
+            print("\nNo RAG results retrieved (score list empty).\n")
+
+        # Return the updated state with the summary and retrieval metadata
+        return {
+            **state,
+            "summary": rag_summary,
+            "rag_metadata": {
+                "k": self.return_k,
+                "num_results": len(results),
+                "relevance_scores": relevance_scores,
+            },
+        }
+
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 100000, **_
+    ):
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+        return self._action.invoke(inputs, config)
+
+    def _build_graph(self):
+        graph = StateGraph(RAGState)
+
+        self.add_node(graph, self._read_docs_node)
+        self.add_node(graph, self._ingest_docs_node)
+        self.add_node(graph, self._retrieve_and_summarize_node)
+
+        graph.add_edge("_read_docs_node", "_ingest_docs_node")
+        graph.add_edge("_ingest_docs_node", "_retrieve_and_summarize_node")
+
+        graph.set_entry_point("_read_docs_node")
+        graph.set_finish_point("_retrieve_and_summarize_node")
+
+        return graph.compile(checkpointer=self.checkpointer)
+
+
+# NOTE: Run the test in `tests/agents/test_rag_agent/test_rag_agent.py` via:
+#
+#     pytest -s tests/agents/test_rag_agent
+#
+# OR
+#
+#     uv run pytest -s tests/agents/test_rag_agent
+#
+# NOTE: You may need to `rm -rf workspace/rag-agent` to remove the vectorstore.
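For orientation, a minimal usage sketch for the RAGAgent above. It assumes langchain-openai is installed with an API key configured, and that BaseAgent (ursa/agents/base.py, not shown in this diff) exposes a public invoke() delegating to _invoke; the model names are placeholders, not part of ursa:

    from langchain_openai import ChatOpenAI, OpenAIEmbeddings

    from ursa.agents.rag_agent import RAGAgent

    # Any BaseChatModel/Embeddings pair should work; these are illustrative.
    agent = RAGAgent(
        embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
        llm=ChatOpenAI(model="gpt-4o-mini"),
        database_path="papers",  # directory of PDFs to read and ingest
    )

    # One call walks the graph: read PDFs -> ingest new ones into Chroma ->
    # retrieve and summarize chunks relevant to `context` (a required key).
    result = agent.invoke({"context": "methods for uncertainty quantification"})
    print(result["summary"])
    print(result["rag_metadata"]["relevance_scores"])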
ursa/agents/recall_agent.py
ADDED
@@ -0,0 +1,54 @@
+from typing import Any, Mapping, TypedDict
+
+from langchain.chat_models import BaseChatModel
+from langgraph.graph import StateGraph
+
+from .base import BaseAgent
+
+
+class RecallState(TypedDict):
+    query: str
+    memory: str
+
+
+class RecallAgent(BaseAgent):
+    def __init__(self, llm: BaseChatModel, memory, **kwargs):
+        super().__init__(llm, **kwargs)
+        self.memorydb = memory
+        self._action = self._build_graph()
+
+    def _remember(self, state: RecallState) -> RecallState:
+        memories = self.memorydb.retrieve(state["query"])
+        summarize_query = f"""
+        You are being given the critical task of generating a detailed description of logged information
+        for an important official to use in making a decision. Summarize the following memories that are
+        related to the statement. Ensure that any important specific details are retained in the summary.
+
+        Query: {state["query"]}
+
+        """
+
+        for memory in memories:
+            summarize_query += f"Memory: {memory} \n\n"
+        state["memory"] = self.llm.invoke(summarize_query).content
+        return state
+
+    def _build_graph(self):
+        graph = StateGraph(RecallState)
+
+        self.add_node(graph, self._remember)
+        graph.set_entry_point("_remember")
+        graph.set_finish_point("_remember")
+        return graph.compile(checkpointer=self.checkpointer)
+
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 100000, **_
+    ):
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+        if "query" not in inputs:
+            raise ValueError("'query' is a required argument")
+
+        output = self._action.invoke(inputs, config)
+        return output["memory"]
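A similar sketch for RecallAgent. The only contract on `memory` is a retrieve(query) method returning an iterable of strings; ListMemory below is a hypothetical stand-in, and the public invoke() wrapper is again an assumption about BaseAgent:

    from langchain_openai import ChatOpenAI

    from ursa.agents.recall_agent import RecallAgent


    class ListMemory:
        """Toy backend satisfying the retrieve() contract used by _remember."""

        def __init__(self, entries):
            self.entries = entries

        def retrieve(self, query):
            # A real backend would do similarity search; this returns everything.
            return self.entries


    agent = RecallAgent(
        llm=ChatOpenAI(model="gpt-4o-mini"),
        memory=ListMemory(["2025-03-01: run 17 converged after 312 steps."]),
    )
    # _invoke unwraps the graph output and returns just the "memory" string.
    print(agent.invoke({"query": "What happened with run 17?"}))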
ursa/agents/websearch_agent.py
ADDED
@@ -0,0 +1,196 @@
+# from langchain_community.tools import TavilySearchResults
+# from langchain_core.runnables.graph import MermaidDrawMethod
+from typing import Annotated, Any, Mapping
+
+import requests
+from bs4 import BeautifulSoup
+from langchain.agents import create_agent
+from langchain.chat_models import BaseChatModel
+from langchain.messages import HumanMessage, SystemMessage
+from langchain_community.tools import DuckDuckGoSearchResults
+from langgraph.graph import StateGraph
+from langgraph.graph.message import add_messages
+from langgraph.prebuilt import InjectedState
+from pydantic import Field
+from typing_extensions import TypedDict
+
+from ..prompt_library.websearch_prompts import (
+    reflection_prompt,
+    summarize_prompt,
+    websearch_prompt,
+)
+from .base import BaseAgent
+
+
+class WebSearchState(TypedDict):
+    websearch_query: str
+    messages: Annotated[list, add_messages]
+    urls_visited: list[str]
+    max_websearch_steps: Annotated[
+        int, Field(default=100, description="Maximum number of websearch steps")
+    ]
+    remaining_steps: int
+    is_last_step: bool
+    model: Any
+    thread_id: Any
+
+
+# The model is carried in the state (clumsily) so that the full text of every
+# "read" source isn't kept in the context window; otherwise each `llm.invoke`
+# would pass along the tokens of all previously read sources.
+
+
+class WebSearchAgentLegacy(BaseAgent):
+    def __init__(
+        self,
+        llm: BaseChatModel,
+        **kwargs,
+    ):
+        super().__init__(llm, **kwargs)
+        self.websearch_prompt = websearch_prompt
+        self.reflection_prompt = reflection_prompt
+        self.tools = [search_tool, process_content]  # + cb_tools
+        self.has_internet = self._check_for_internet(
+            kwargs.get("url", "http://www.lanl.gov")
+        )
+        self._build_graph()
+
+    def _review_node(self, state: WebSearchState) -> WebSearchState:
+        if not self.has_internet:
+            return {
+                "messages": [
+                    HumanMessage(
+                        content="No internet for WebSearch Agent so no research to review."
+                    )
+                ],
+                "urls_visited": [],
+            }
+
+        translated = [SystemMessage(content=reflection_prompt)] + state[
+            "messages"
+        ]
+        res = self.llm.invoke(
+            translated, {"configurable": {"thread_id": self.thread_id}}
+        )
+        return {"messages": [HumanMessage(content=res.content)]}
+
+    def _response_node(self, state: WebSearchState) -> WebSearchState:
+        if not self.has_internet:
+            return {
+                "messages": [
+                    HumanMessage(
+                        content="No internet for WebSearch Agent. No research carried out."
+                    )
+                ],
+                "urls_visited": [],
+            }
+
+        messages = state["messages"] + [SystemMessage(content=summarize_prompt)]
+        response = self.llm.invoke(
+            messages, {"configurable": {"thread_id": self.thread_id}}
+        )
+
+        urls_visited = []
+        for message in messages:
+            if message.model_dump().get("tool_calls", []):
+                if "url" in message.tool_calls[0]["args"]:
+                    urls_visited.append(message.tool_calls[0]["args"]["url"])
+        return {"messages": [response.content], "urls_visited": urls_visited}
+
+    def _check_for_internet(self, url, timeout=2):
+        """
+        Checks for internet connectivity by attempting an HTTP GET request.
+        """
+        try:
+            requests.get(url, timeout=timeout)
+            return True
+        except (requests.ConnectionError, requests.Timeout):
+            return False
+
+    def _state_store_node(self, state: WebSearchState) -> WebSearchState:
+        state["thread_id"] = self.thread_id
+        return state
+        # return dict(**state, thread_id=self.thread_id)
+
+    def _create_react(self, state: WebSearchState) -> WebSearchState:
+        react_agent = create_agent(
+            self.llm,
+            self.tools,
+            state_schema=WebSearchState,
+            system_prompt=self.websearch_prompt,
+        )
+        return react_agent.invoke(state)
+
+    def _build_graph(self):
+        graph = StateGraph(WebSearchState)
+        self.add_node(graph, self._state_store_node)
+        self.add_node(graph, self._create_react)
+        self.add_node(graph, self._review_node)
+        self.add_node(graph, self._response_node)
+
+        graph.set_entry_point("_state_store_node")
+        graph.add_edge("_state_store_node", "_create_react")
+        graph.add_edge("_create_react", "_review_node")
+        graph.set_finish_point("_response_node")
+
+        graph.add_conditional_edges(
+            "_review_node",
+            should_continue,
+            {
+                "_create_react": "_create_react",
+                "_response_node": "_response_node",
+            },
+        )
+        self._action = graph.compile(checkpointer=self.checkpointer)
+        # self._action.get_graph().draw_mermaid_png(output_file_path="./websearch_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)
+
+    def _invoke(
+        self, inputs: Mapping[str, Any], recursion_limit: int = 1000, **_
+    ):
+        config = self.build_config(
+            recursion_limit=recursion_limit, tags=["graph"]
+        )
+        return self._action.invoke(inputs, config)
+
+
+def process_content(
+    url: str, context: str, state: Annotated[dict, InjectedState]
+) -> str:
+    """
+    Processes content from a given webpage.
+
+    Args:
+        url: string with the url to obtain text content from.
+        context: string describing the information the agent wants from the
+            url, used to guide the summary toward the salient details.
+    """
+    print("Parsing information from ", url)
+    response = requests.get(url)
+    soup = BeautifulSoup(response.content, "html.parser")
+
+    content_prompt = f"""
+    Here is the full content:
+    {soup.get_text()}
+
+    Carefully summarize the content in full detail, given the following context:
+    {context}
+    """
+    summarized_information = (
+        state["model"]
+        .invoke(
+            content_prompt, {"configurable": {"thread_id": state["thread_id"]}}
+        )
+        .content
+    )
+    return summarized_information
+
+
+search_tool = DuckDuckGoSearchResults(output_format="json", num_results=10)
+# search_tool = TavilySearchResults(max_results=10, search_depth="advanced", include_answer=True)
+
+
+def should_continue(state: WebSearchState):
+    if len(state["messages"]) > (state.get("max_websearch_steps", 100) + 3):
+        return "_response_node"
+    if "[APPROVED]" in state["messages"][-1].content:
+        return "_response_node"
+    return "_create_react"
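And a sketch for WebSearchAgentLegacy. Because process_content pulls the chat model and thread id out of the injected state, the caller supplies `model` in the inputs (thread_id is filled in by _state_store_node); the public invoke() wrapper is assumed as before:

    from langchain_openai import ChatOpenAI

    from ursa.agents.websearch_agent import WebSearchAgentLegacy

    llm = ChatOpenAI(model="gpt-4o-mini")
    agent = WebSearchAgentLegacy(llm)

    result = agent.invoke({
        "messages": [("user", "Find recent benchmarks for sparse linear solvers")],
        "model": llm,  # consumed by process_content via InjectedState
        "max_websearch_steps": 10,  # should_continue ends the loop after this
    })
    print(result["messages"][-1].content)  # final summary from _response_node
    print(result["urls_visited"])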