ursa-ai 0.0.3__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of ursa-ai might be problematic.
- ursa/agents/__init__.py +10 -0
- ursa/agents/arxiv_agent.py +349 -0
- ursa/agents/base.py +42 -0
- ursa/agents/code_review_agent.py +332 -0
- ursa/agents/execution_agent.py +497 -0
- ursa/agents/hypothesizer_agent.py +597 -0
- ursa/agents/mp_agent.py +257 -0
- ursa/agents/planning_agent.py +138 -0
- ursa/agents/recall_agent.py +25 -0
- ursa/agents/websearch_agent.py +193 -0
- ursa/prompt_library/code_review_prompts.py +51 -0
- ursa/prompt_library/execution_prompts.py +36 -0
- ursa/prompt_library/hypothesizer_prompts.py +17 -0
- ursa/prompt_library/literature_prompts.py +11 -0
- ursa/prompt_library/planning_prompts.py +79 -0
- ursa/prompt_library/websearch_prompts.py +131 -0
- ursa/tools/run_command.py +27 -0
- ursa/tools/write_code.py +42 -0
- ursa/util/diff_renderer.py +121 -0
- ursa/util/memory_logger.py +171 -0
- ursa/util/parse.py +89 -0
- ursa_ai-0.2.2.dist-info/METADATA +130 -0
- ursa_ai-0.2.2.dist-info/RECORD +26 -0
- ursa_ai-0.2.2.dist-info/licenses/LICENSE +8 -0
- ursa/__init__.py +0 -2
- ursa/py.typed +0 -0
- ursa_ai-0.0.3.dist-info/METADATA +0 -7
- ursa_ai-0.0.3.dist-info/RECORD +0 -6
- {ursa_ai-0.0.3.dist-info → ursa_ai-0.2.2.dist-info}/WHEEL +0 -0
- {ursa_ai-0.0.3.dist-info → ursa_ai-0.2.2.dist-info}/top_level.txt +0 -0
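A comparison like this one can be reproduced locally. Below is a minimal sketch (not part of the package), assuming both wheels have already been fetched, e.g. with "pip download ursa-ai==0.0.3 --no-deps" and "pip download ursa-ai==0.2.2 --no-deps"; the wheel filenames are assumptions based on the versions above.

# Sketch: wheels are zip archives, so their member lists can be diffed directly.
from zipfile import ZipFile

old = set(ZipFile("ursa_ai-0.0.3-py3-none-any.whl").namelist())  # assumed filename
new = set(ZipFile("ursa_ai-0.2.2-py3-none-any.whl").namelist())  # assumed filename

print("added:  ", sorted(new - old))
print("removed:", sorted(old - new))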
ursa/agents/mp_agent.py
ADDED
@@ -0,0 +1,257 @@
import os
import json
from typing import List, Dict
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

from mp_api.client import MPRester
from langchain.schema import Document

import os
import pymupdf
import requests
import feedparser
from PIL import Image
from io import BytesIO
import base64
from urllib.parse import quote
from typing_extensions import TypedDict, List
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import re

from langchain_community.document_loaders import PyPDFLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, END, START
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings

from openai import OpenAI

from .base import BaseAgent


client = OpenAI()

embeddings = OpenAIEmbeddings()


class PaperMetadata(TypedDict):
    arxiv_id: str
    full_text: str


class PaperState(TypedDict, total=False):
    query: str
    context: str
    papers: List[PaperMetadata]
    summaries: List[str]
    final_summary: str


def describe_image(image: Image.Image) -> str:
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_base64 = base64.b64encode(buffered.getvalue()).decode()

    response = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=[
            {"role": "system", "content": "You are a scientific assistant who explains plots and scientific diagrams."},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this scientific image or plot in detail."},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_base64}"}}
                ],
            },
        ],
        max_tokens=500,
    )
    return response.choices[0].message.content.strip()


def extract_and_describe_images(pdf_path: str, max_images: int = 5) -> List[str]:
    doc = pymupdf.open(pdf_path)
    descriptions = []
    image_count = 0

    for page_index in range(len(doc)):
        if image_count >= max_images:
            break
        page = doc[page_index]
        images = page.get_images(full=True)

        for img_index, img in enumerate(images):
            if image_count >= max_images:
                break
            xref = img[0]
            base_image = doc.extract_image(xref)
            image_bytes = base_image["image"]
            image = Image.open(BytesIO(image_bytes))

            try:
                desc = describe_image(image)
                descriptions.append(f"Page {page_index + 1}, Image {img_index + 1}: {desc}")
            except Exception as e:
                descriptions.append(f"Page {page_index + 1}, Image {img_index + 1}: [Error: {e}]")
            image_count += 1

    return descriptions


def remove_surrogates(text: str) -> str:
    return re.sub(r'[\ud800-\udfff]', '', text)


class MaterialsProjectAgent(BaseAgent):
    def __init__(
        self,
        llm="openai/o3-mini",
        summarize: bool = True,
        max_results: int = 3,
        database_path: str = 'mp_database',
        summaries_path: str = 'mp_summaries',
        vectorstore_path: str = 'mp_vectorstores',
        **kwargs
    ):
        super().__init__(llm, **kwargs)
        self.summarize = summarize
        self.max_results = max_results
        self.database_path = database_path
        self.summaries_path = summaries_path
        self.vectorstore_path = vectorstore_path

        os.makedirs(self.database_path, exist_ok=True)
        os.makedirs(self.summaries_path, exist_ok=True)
        os.makedirs(self.vectorstore_path, exist_ok=True)

        self.embeddings = OpenAIEmbeddings()  # or your preferred embedding
        self.graph = self._build_graph()

    def _fetch_node(self, state: Dict) -> Dict:
        f = state["query"]
        els = f["elements"]  # e.g. ["Ga","In"]
        bg = (f["band_gap_min"], f["band_gap_max"])
        e_above_hull = (0, 0)  # only on-hull (stable)
        mats = []
        with MPRester() as mpr:
            # get ALL matching materials…
            all_results = mpr.materials.summary.search(
                elements=els,
                band_gap=bg,
                energy_above_hull=e_above_hull,
                is_stable=True  # equivalent filter
            )
            # …then take only the first `max_results`
            for doc in all_results[: self.max_results]:
                mid = doc.material_id
                data = doc.dict()
                # cache to disk
                path = os.path.join(self.database_path, f"{mid}.json")
                if not os.path.exists(path):
                    with open(path, "w") as f:
                        json.dump(data, f, indent=2)
                mats.append({"material_id": mid, "metadata": data})

        return {**state, "materials": mats}

    def _get_or_build_vectorstore(self, text: str, mid: str):
        """Build or load a Chroma vectorstore for a single material's description."""
        persist_dir = os.path.join(self.vectorstore_path, mid)
        if os.path.exists(persist_dir):
            store = Chroma(persist_directory=persist_dir, embedding_function=self.embeddings)
        else:
            splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
            docs = splitter.create_documents([text])
            store = Chroma.from_documents(docs, self.embeddings, persist_directory=persist_dir)
        return store.as_retriever(search_kwargs={"k": 5})

    def _summarize_node(self, state: Dict) -> Dict:
        """Summarize each material via LLM over its metadata."""
        # prompt template
        prompt = ChatPromptTemplate.from_template("""
You are a materials-science assistant. Given the following metadata about a material, produce a concise summary focusing on its key properties:

{metadata}
""")
        chain = prompt | self.llm | StrOutputParser()

        summaries = [None] * len(state["materials"])
        relevancy = [0.0] * len(state["materials"])

        def process(i, mat):
            mid = mat["material_id"]
            meta = mat["metadata"]
            # flatten metadata to text
            text = "\n".join(f"{k}: {v}" for k, v in meta.items())
            # build or load summary
            summary_file = os.path.join(self.summaries_path, f"{mid}_summary.txt")
            if os.path.exists(summary_file):
                with open(summary_file) as f:
                    return i, f.read()
            # optional: vectorize & retrieve, but here we just summarize full text
            result = chain.invoke({"metadata": text})
            with open(summary_file, 'w') as f:
                f.write(result)
            return i, result

        with ThreadPoolExecutor(max_workers=min(8, len(state["materials"]))) as exe:
            futures = [exe.submit(process, i, m) for i, m in enumerate(state["materials"])]
            for future in tqdm(futures, desc="Summarizing materials"):
                i, summ = future.result()
                summaries[i] = summ

        return {**state, "summaries": summaries}

    def _aggregate_node(self, state: Dict) -> Dict:
        """Combine all summaries into a single, coherent answer."""
        combined = "\n\n----\n\n".join(
            f"[{i+1}] {m['material_id']}\n\n{summary}"
            for i, (m, summary) in enumerate(zip(state["materials"], state["summaries"]))
        )

        prompt = ChatPromptTemplate.from_template("""
You are a materials informatics assistant. Below are brief summaries of several materials:

{summaries}

Answer the user’s question in context:

{context}
""")
        chain = prompt | self.llm | StrOutputParser()
        final = chain.invoke({"summaries": combined, "context": state["context"]})
        return {**state, "final_summary": final}

    def _build_graph(self):
        g = StateGraph(dict)  # using plain dict for state
        g.add_node("fetch", self._fetch_node)
        if self.summarize:
            g.add_node("summarize", self._summarize_node)
            g.add_node("aggregate", self._aggregate_node)
            g.set_entry_point("fetch")
            g.add_edge("fetch", "summarize")
            g.add_edge("summarize", "aggregate")
            g.set_finish_point("aggregate")
        else:
            g.set_entry_point("fetch")
            g.set_finish_point("fetch")
        return g.compile()

    def run(self, mp_query: str, context: str) -> str:
        state = {"query": mp_query, "context": context}
        out = self.graph.invoke(state)
        if self.summarize:
            return out.get("final_summary", "")
        return json.dumps(out.get("materials", []), indent=2)


if __name__ == "__main__":
    agent = MaterialsProjectAgent()
    resp = agent.run(
        mp_query="LiFePO4",
        context="What is its band gap and stability, and any synthesis challenges?"
    )
    print(resp)
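
A usage sketch (not part of the release): run() annotates mp_query as str, but _fetch_node indexes the query like a dict with "elements", "band_gap_min", and "band_gap_max", so a filter-shaped query is shown here. Assumes an MP_API_KEY is configured for MPRester and OPENAI_API_KEY for the embeddings and LLM.

# Sketch only: exercise MaterialsProjectAgent with a dict-shaped query matching
# what _fetch_node reads from state["query"].
from ursa.agents.mp_agent import MaterialsProjectAgent

agent = MaterialsProjectAgent(summarize=True, max_results=3)
report = agent.run(
    mp_query={"elements": ["Ga", "In"], "band_gap_min": 0.5, "band_gap_max": 2.0},
    context="Which of these candidates are stable, and what are their band gaps?",
)
print(report)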
ursa/agents/planning_agent.py
ADDED
@@ -0,0 +1,138 @@
# from langgraph.checkpoint.memory import MemorySaver
# from langchain_core.runnables.graph import MermaidDrawMethod
from typing import Annotated, Any, Dict, List, Optional

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from pydantic import Field
from typing_extensions import TypedDict

from ..prompt_library.planning_prompts import (
    formalize_prompt,
    planner_prompt,
    reflection_prompt,
)
from ..util.parse import extract_json
from .base import BaseAgent


class PlanningState(TypedDict):
    messages: Annotated[list, add_messages]
    plan_steps: List[Dict[str, Any]] = Field(
        default_factory=list, description="Ordered steps in the solution plan"
    )
    reflection_steps: Optional[int] = Field(
        default=3, description="Number of reflection steps"
    )


class PlanningAgent(BaseAgent):
    def __init__(
        self, llm: str | BaseChatModel = "openai/gpt-4o-mini", **kwargs
    ):
        super().__init__(llm, **kwargs)
        self.planner_prompt = planner_prompt
        self.formalize_prompt = formalize_prompt
        self.reflection_prompt = reflection_prompt
        self._initialize_agent()

    def generation_node(self, state: PlanningState) -> PlanningState:
        messages = state["messages"]
        if type(messages[0]) == SystemMessage:
            messages[0] = SystemMessage(content=self.planner_prompt)
        else:
            messages = [SystemMessage(content=self.planner_prompt)] + messages
        return {"messages": [self.llm.invoke(messages, {"configurable": {"thread_id": self.thread_id}})]}

    def formalize_node(self, state: PlanningState) -> PlanningState:
        cls_map = {"ai": HumanMessage, "human": AIMessage}
        translated = [state["messages"][0]] + [
            cls_map[msg.type](content=msg.content)
            for msg in state["messages"][1:]
        ]
        translated = [SystemMessage(content=self.formalize_prompt)] + translated
        for _ in range(10):
            try:
                res = self.llm.invoke(translated, {"configurable": {"thread_id": self.thread_id}})
                json_out = extract_json(res.content)
                break
            except ValueError:
                translated.append(
                    HumanMessage(
                        content="Your response was not valid JSON. Try again."
                    )
                )
        return {
            "messages": [HumanMessage(content=res.content)],
            "plan_steps": json_out,
        }

    def reflection_node(self, state: PlanningState) -> PlanningState:
        cls_map = {"ai": HumanMessage, "human": AIMessage}
        translated = [state["messages"][0]] + [
            cls_map[msg.type](content=msg.content)
            for msg in state["messages"][1:]
        ]
        translated = [SystemMessage(content=reflection_prompt)] + translated
        res = self.llm.invoke(translated, {"configurable": {"thread_id": self.thread_id}})
        return {"messages": [HumanMessage(content=res.content)]}

    def _initialize_agent(self):
        self.graph = StateGraph(PlanningState)
        self.graph.add_node("generate", self.generation_node)
        self.graph.add_node("reflect", self.reflection_node)
        self.graph.add_node("formalize", self.formalize_node)

        self.graph.add_edge(START, "generate")
        self.graph.add_edge("generate", "reflect")
        self.graph.add_edge("formalize", END)

        self.graph.add_conditional_edges(
            "reflect",
            should_continue,
            {"generate": "generate", "formalize": "formalize"},
        )

        # memory = MemorySaver()
        # self.action = self.graph.compile(checkpointer=memory)
        self.action = self.graph.compile(checkpointer=self.checkpointer)
        # self.action.get_graph().draw_mermaid_png(output_file_path="planning_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)

    def run(self, prompt, recursion_limit=100):
        initial_state = {"messages": [HumanMessage(content=prompt)]}
        return self.action.invoke(initial_state, {"recursion_limit": recursion_limit, "configurable": {"thread_id": self.thread_id}})


config = {"configurable": {"thread_id": "1"}}


def should_continue(state: PlanningState):
    if len(state["messages"]) > (state.get("reflection_steps", 3) + 3):
        return "formalize"
    if "[APPROVED]" in state["messages"][-1].content:
        return "formalize"
    return "generate"


def main():
    planning_agent = PlanningAgent()
    for event in planning_agent.action.stream(
        {
            "messages": [
                HumanMessage(
                    content="Find a city with as least 10 vowels in its name."  # "Write an essay on ideal high-entropy alloys for spacecraft."
                )
            ],
        },
        config,
    ):
        print("-" * 30)
        print(event.keys())
        print(event[list(event.keys())[0]]["messages"][-1].content)
        print("-" * 30)


if __name__ == "__main__":
    main()
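
A usage sketch (not part of the release), assuming OPENAI_API_KEY is set for the default "openai/gpt-4o-mini" model: run() returns the final graph state, and formalize_node stores the JSON plan it extracts under "plan_steps".

# Sketch only: generate, reflect on, and formalize a plan, then print the steps.
from ursa.agents.planning_agent import PlanningAgent

planner = PlanningAgent()
result = planner.run("Design a benchmark comparing two ODE solvers on stiff problems.")
for step in result.get("plan_steps", []):
    print(step)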
ursa/agents/recall_agent.py
ADDED
@@ -0,0 +1,25 @@
import os
from .base import BaseAgent


class RecallAgent(BaseAgent):
    def __init__(self, llm, memory, **kwargs):

        super().__init__(llm, **kwargs)
        self.memorydb = memory

    def remember(self, query):
        memories = self.memorydb.retrieve(query)
        summarize_query = f"""
        You are being given the critical task of generating a detailed description of logged information
        to an important official to make a decision. Summarize the following memories that are related to
        the statement. Ensure that any specific details that are important are retained in the summary.

        Query: {query}

        """

        for memory in memories:
            summarize_query += f"Memory: {memory} \n\n"
        memory = self.llm.invoke(summarize_query).content
        return memory
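
A usage sketch (not part of the release): RecallAgent only needs an LLM and a memory object exposing retrieve(query). The ListMemory class below is a hypothetical stand-in for the store shipped in ursa/util/memory_logger.py, whose exact interface is not shown in this diff.

# Sketch only: feed retrieved "memories" through RecallAgent.remember().
from ursa.agents.recall_agent import RecallAgent


class ListMemory:
    """Hypothetical stand-in: any object with retrieve(query) -> list of strings."""

    def __init__(self, entries):
        self.entries = entries

    def retrieve(self, query):
        return self.entries


recaller = RecallAgent(
    llm="openai/gpt-4o-mini",
    memory=ListMemory(["We settled on Dirichlet boundary conditions for the test case."]),
)
print(recaller.remember("What boundary conditions did we settle on?"))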
ursa/agents/websearch_agent.py
ADDED
@@ -0,0 +1,193 @@
import inspect

# from langchain_community.tools import TavilySearchResults
# from langchain_core.runnables.graph import MermaidDrawMethod
from typing import Annotated, Any, List, Optional

import requests
from bs4 import BeautifulSoup
from langchain_community.tools import DuckDuckGoSearchResults
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import InjectedState, create_react_agent
from pydantic import Field
from typing_extensions import TypedDict

from ..prompt_library.websearch_prompts import (
    reflection_prompt,
    websearch_prompt,
    summarize_prompt,
)
from .base import BaseAgent

# --- ANSI color codes ---
BLUE = "\033[1;34m"
RED = "\033[1;31m"
GREEN = "\033[92m"
RESET = "\033[0m"


class WebSearchState(TypedDict):
    websearch_query: str
    messages: Annotated[list, add_messages]
    urls_visited: List[str]
    max_websearch_steps: Optional[int] = Field(
        default=100, description="Maximum number of websearch steps"
    )
    remaining_steps: int
    is_last_step: bool
    model: Any


# Adding the model to the state clumsily so that all "read" sources arent in the
# context window. That eats a ton of tokens because each `llm.invoke` passes
# all the tokens of all the sources.


class WebSearchAgent(BaseAgent):
    def __init__(
        self, llm: str | BaseChatModel = "openai/gpt-4o-mini", **kwargs
    ):
        super().__init__(llm, **kwargs)
        self.websearch_prompt = websearch_prompt
        self.reflection_prompt = reflection_prompt
        self.tools = [search_tool, process_content]  # + cb_tools
        self.has_internet = self._check_for_internet(kwargs.get("url", "http://www.lanl.gov"))
        self._initialize_agent()

    def review_node(self, state: WebSearchState) -> WebSearchState:
        if not self.has_internet:
            return {"messages": [HumanMessage(content="No internet for WebSearch Agent so no research to review.")], "urls_visited": []}

        translated = [SystemMessage(content=reflection_prompt)] + state[
            "messages"
        ]
        res = self.llm.invoke(translated, {"configurable": {"thread_id": self.thread_id}})
        return {"messages": [HumanMessage(content=res.content)]}

    def response_node(self, state: WebSearchState) -> WebSearchState:
        if not self.has_internet:
            return {"messages": [HumanMessage(content="No internet for WebSearch Agent. No research carried out.")], "urls_visited": []}

        messages = state["messages"] + [SystemMessage(content=summarize_prompt)]
        response = self.llm.invoke(messages, {"configurable": {"thread_id": self.thread_id}})

        urls_visited = []
        for message in messages:
            if message.model_dump().get("tool_calls", []):
                if "url" in message.tool_calls[0]["args"]:
                    urls_visited.append(message.tool_calls[0]["args"]["url"])
        return {"messages": [response.content], "urls_visited": urls_visited}

    def _check_for_internet(self, url, timeout=2):
        """
        Checks for internet connectivity by attempting an HTTP GET request.
        """
        try:
            requests.get(url, timeout=timeout)
            return True
        except (requests.ConnectionError, requests.Timeout):
            return False

    def _initialize_agent(self):
        self.graph = StateGraph(WebSearchState)
        self.graph.add_node(
            "websearch",
            create_react_agent(
                self.llm,
                self.tools,
                state_schema=WebSearchState,
                prompt=self.websearch_prompt,
            ),
        )

        self.graph.add_node("review", self.review_node)
        self.graph.add_node("response", self.response_node)

        self.graph.add_edge(START, "websearch")
        self.graph.add_edge("websearch", "review")
        self.graph.add_edge("response", END)

        self.graph.add_conditional_edges(
            "review",
            should_continue,
            {"websearch": "websearch", "response": "response"},
        )
        self.action = self.graph.compile(checkpointer=self.checkpointer)
        # self.action.get_graph().draw_mermaid_png(output_file_path="./websearch_agent_graph.png", draw_method=MermaidDrawMethod.PYPPETEER)

    def run(self, prompt, recursion_limit=100):
        if not self.has_internet:
            return {"messages": [HumanMessage(content="No internet for WebSearch Agent. No research carried out.")]}
        inputs = {
            "messages": [HumanMessage(content=prompt)],
            "model": self.llm,
        }
        return self.action.invoke(inputs, {"recursion_limit": recursion_limit, "configurable": {"thread_id": self.thread_id}})


def process_content(
    url: str, context: str, state: Annotated[dict, InjectedState]
) -> str:
    """
    Processes content from a given webpage.

    Args:
        url: string with the url to obtain text content from.
        context: string summary of the information the agent wants from the url for summarizing salient information.
    """
    print("Parsing information from ", url)
    response = requests.get(url)
    soup = BeautifulSoup(response.content, "html.parser")

    content_prompt = f"""
    Here is the full content:
    {soup.get_text()}

    Carefully summarize the content in full detail, given the following context:
    {context}
    """
    summarized_information = state["model"].invoke(content_prompt, {"configurable": {"thread_id": self.thread_id}}).content
    return summarized_information


search_tool = DuckDuckGoSearchResults(output_format="json", num_results=10)
# search_tool = TavilySearchResults(max_results=10, search_depth="advanced",include_answer=True)


def should_continue(state: WebSearchState):
    if len(state["messages"]) > (state.get("max_websearch_steps", 100) + 3):
        return "response"
    if "[APPROVED]" in state["messages"][-1].content:
        return "response"
    return "websearch"


def main():
    model = ChatOpenAI(
        model="gpt-4o", max_tokens=10000, timeout=None, max_retries=2
    )
    websearcher = WebSearchAgent(llm=model)
    problem_string = "Who are the 2025 Detroit Tigers top 10 prospects and what year were they born?"
    inputs = {
        "messages": [HumanMessage(content=problem_string)],
        "model": model,
    }
    result = websearcher.action.invoke(inputs, {"recursion_limit": 10000, "configurable": {"thread_id": self.thread_id}})

    colors = [BLUE, RED]
    for ii, x in enumerate(result["messages"][:-1]):
        if not isinstance(x, ToolMessage):
            print(f"{colors[ii % 2]}" + x.content + f"{RESET}")

    print(80 * "#")
    print(f"{GREEN}" + result["messages"][-1].content + f"{RESET}")
    print("Citations: ", result["urls_visited"])
    return result


if __name__ == "__main__":
    main()
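
A usage sketch (not part of the release): the agent degrades gracefully when its connectivity probe fails, returning a placeholder message instead of searching. Assumes OPENAI_API_KEY is set for the ChatOpenAI model.

# Sketch only: run the web-search graph via the public run() entry point.
from langchain_openai import ChatOpenAI
from ursa.agents.websearch_agent import WebSearchAgent

searcher = WebSearchAgent(llm=ChatOpenAI(model="gpt-4o-mini"))
out = searcher.run("Summarize recent work on solid-state electrolytes.")
print(out["messages"][-1].content)
print(out.get("urls_visited", []))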
ursa/prompt_library/code_review_prompts.py
ADDED
@@ -0,0 +1,51 @@

def get_code_review_prompt(project_prompt, file_list):
    return f'''
    You are a responsible and efficient code review agent tasked with assessing if given files meet the goals of a project description.

    The project goals are:
    {project_prompt}

    The list of files is:
    {file_list}

    Your responsibilities are as follows:

    1. Read and review the given file and assess if it meets its requirements.
        - Do not trust that the contents of the file are reflected in the filename without checking.
    2. Ensure that all code uses real data and fully addresses the problem
        - No fake, synthetic, or placeholder data. Obtain any needed data by reliable means.
        - No simplifying assumptions. Ensure that code is implemented at the fully complexity required.
        - Remove any code that may be dangerous, adversarial, or performing actions detrimental to the plan.
        - Ensure files work together modularly, do not duplicate effort!
        - The project code should be clean and results reproducible.
    3. Clearly document each action you take, including:
        - The tools or methods you used.
        - Any changes made including where the change occurred.
        - Outcomes, results, or errors encountered during execution.
    4. Immediately highlight and clearly communicate any steps that appear unclear, unsafe, or impractical before proceeding.

    Your goal is to ensure the implemented code addresses the plan accurately, safely, and transparently, maintaining accountability at each step.
    '''

def get_plan_review_prompt(project_prompt, file_list):
    return f'''
    You are a responsible and efficient code review agent tasked with assessing if given files meet the goals of a project description.

    The project goals are:
    {project_prompt}

    The list of files is:
    {file_list}

    Your responsibilities are as follows:

    1. Formulate how the list of files work together to solve the given problem.
    2. List potential problems to be reviewed in each file:
        - Is any work duplicated or are the steps properly modularized?
        - Does the file organization reflect a clear, reproducible workflow?
        - Are there extraneous files or missing steps?
        - Do any files appear dangerous, adversarial, or performing actions detrimental to the plan.

    Your goal is to provide that information in a clear, concise way for use by a code reviewer who will look over files in detail.
    '''
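
A usage sketch (not part of the release): the prompt builders are plain functions returning f-strings, so they can be passed directly to a chat model as a system message. The project goals and file list below are illustrative values.

# Sketch only: build review prompts for a hypothetical project.
from langchain_core.messages import SystemMessage
from ursa.prompt_library.code_review_prompts import (
    get_code_review_prompt,
    get_plan_review_prompt,
)

goals = "Reproduce the convergence study and plot the error curves."
files = ["simulate.py", "analysis.py", "plots.py"]

plan_review_msg = SystemMessage(content=get_plan_review_prompt(goals, files))
code_review_msg = SystemMessage(content=get_code_review_prompt(goals, files))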