ursa-ai 0.7.0rc1__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this release of ursa-ai has been flagged as potentially problematic; see the registry page for details.
- ursa/agents/__init__.py +13 -2
- ursa/agents/acquisition_agents.py +812 -0
- ursa/agents/arxiv_agent.py +1 -1
- ursa/agents/base.py +352 -91
- ursa/agents/chat_agent.py +58 -0
- ursa/agents/execution_agent.py +506 -260
- ursa/agents/lammps_agent.py +81 -31
- ursa/agents/planning_agent.py +27 -2
- ursa/agents/websearch_agent.py +2 -2
- ursa/cli/__init__.py +5 -1
- ursa/cli/hitl.py +46 -34
- ursa/observability/pricing.json +85 -0
- ursa/observability/pricing.py +20 -18
- ursa/util/parse.py +316 -0
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/METADATA +5 -1
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/RECORD +20 -17
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/WHEEL +0 -0
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/entry_points.txt +0 -0
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/licenses/LICENSE +0 -0
- {ursa_ai-0.7.0rc1.dist-info → ursa_ai-0.7.1.dist-info}/top_level.txt +0 -0
ursa/agents/acquisition_agents.py
@@ -0,0 +1,812 @@
# generic_acquisition_agents.py

import hashlib
import json
import os
import re
import shutil
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from typing import Any, Dict, Mapping, Optional
from urllib.parse import quote, urlparse

import feedparser

# PDF & Vision extras (match your existing stack)
import pymupdf
import requests
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.language_models import BaseChatModel
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph
from PIL import Image
from typing_extensions import List, TypedDict

from ursa.agents.base import BaseAgent
from ursa.agents.rag_agent import RAGAgent
from ursa.util.parse import (
    _derive_filename_from_cd_or_url,
    _download_stream_to,
    _get_soup,
    _is_pdf_response,
    extract_main_text_only,
    resolve_pdf_from_osti_record,
)

try:
    from ddgs import DDGS  # pip install duckduckgo-search
except Exception:
    DDGS = None

try:
    from openai import OpenAI
except Exception:
    OpenAI = None


# ---------- Shared State / Types ----------


class ItemMetadata(TypedDict, total=False):
    id: str  # canonical ID (e.g., arxiv_id, sha, OSTI id)
    title: str
    url: str
    local_path: str
    full_text: str
    extra: Dict[str, Any]


class AcquisitionState(TypedDict, total=False):
    query: str
    context: str
    items: List[ItemMetadata]
    summaries: List[str]
    final_summary: str
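# Illustrative only, not part of the packaged module: roughly what a populated
# ItemMetadata looks like once a concrete agent has materialized a hit (the
# values below are made-up placeholders).
#
#     {
#         "id": "2101.00001",
#         "title": "An Example Paper",
#         "url": "https://arxiv.org/pdf/2101.00001.pdf",
#         "local_path": "arxiv_papers/2101.00001.pdf",
#         "full_text": "...extracted PDF text, plus optional image descriptions...",
#         "extra": {},
#     }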

# ---------- Small Utilities reused across agents ----------


def _safe_filename(s: str) -> str:
    s = re.sub(r"[^\w\-_.]+", "_", s)
    return s[:240]


def _hash(s: str) -> str:
    return hashlib.sha256(s.encode("utf-8")).hexdigest()[:16]


def remove_surrogates(text: str) -> str:
    return re.sub(r"[\ud800-\udfff]", "", text)


def _looks_like_pdf_url(url: str) -> bool:
    parsed = urlparse(url)
    return parsed.path.lower().endswith(".pdf")


def _download(url: str, dest_path: str, timeout: int = 20) -> str:
    r = requests.get(url, stream=True, timeout=timeout)
    r.raise_for_status()
    os.makedirs(os.path.dirname(dest_path), exist_ok=True)
    with open(dest_path, "wb") as f:
        shutil.copyfileobj(r.raw, f)
    return dest_path


def _load_pdf_text(path: str) -> str:
    loader = PyPDFLoader(path)
    pages = loader.load()
    return "\n".join(p.page_content for p in pages)


# def _basic_readable_text_from_html(html: str) -> str:
#     soup = BeautifulSoup(html, "html.parser")
#     # Drop scripts/styles/navs for a crude readability
#     for tag in soup(["script", "style", "noscript", "header", "footer", "nav"]):
#         tag.decompose()
#     # Keep title for context
#     title = soup.title.get_text(strip=True) if soup.title else ""
#     # Join paragraphs
#     texts = [
#         p.get_text(" ", strip=True)
#         for p in soup.find_all(["p", "h1", "h2", "h3", "li", "figcaption"])
#     ]
#     body = "\n".join(t for t in texts if t)
#     return (title + "\n\n" + body).strip()


def describe_image(image: Image.Image) -> str:
    if OpenAI is None:
        return ""
    client = OpenAI()
    buf = BytesIO()
    image.save(buf, format="PNG")
    import base64

    img_b64 = base64.b64encode(buf.getvalue()).decode()
    resp = client.chat.completions.create(
        model="gpt-4-vision-preview",
        messages=[
            {
                "role": "system",
                "content": "You are a scientific assistant who explains plots and scientific diagrams.",
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Describe this scientific image or plot in detail.",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{img_b64}"
                        },
                    },
                ],
            },
        ],
        max_tokens=400,
    )
    return resp.choices[0].message.content.strip()


def extract_and_describe_images(
    pdf_path: str, max_images: int = 5
) -> List[str]:
    descriptions: List[str] = []
    try:
        doc = pymupdf.open(pdf_path)
    except Exception as e:
        return [f"[Image extraction failed: {e}]"]

    count = 0
    for pi in range(len(doc)):
        if count >= max_images:
            break
        page = doc[pi]
        for ji, img in enumerate(page.get_images(full=True)):
            if count >= max_images:
                break
            xref = img[0]
            base = doc.extract_image(xref)
            image = Image.open(BytesIO(base["image"]))
            try:
                desc = describe_image(image) if OpenAI else ""
            except Exception as e:
                desc = f"[Error: {e}]"
            descriptions.append(f"Page {pi + 1}, Image {ji + 1}: {desc}")
            count += 1
    return descriptions

# ---------- The Parent / Generic Agent ----------


class BaseAcquisitionAgent(BaseAgent):
    """
    A generic "acquire-then-summarize-or-RAG" agent.

    Subclasses must implement:
      - _search(self, query) -> List[dict-like]: lightweight hits
      - _materialize(self, hit) -> ItemMetadata: download or scrape and return populated item
      - _id(self, hit_or_item) -> str: stable id for caching/file naming
      - _citation(self, item) -> str: human-readable citation string

    Optional hooks:
      - _postprocess_text(self, text, local_path) -> str (e.g., image interpretation)
      - _filter_hit(self, hit) -> bool
    """

    def __init__(
        self,
        llm: str | BaseChatModel = "openai/o3-mini",
        *,
        summarize: bool = True,
        rag_embedding=None,
        process_images: bool = True,
        max_results: int = 5,
        database_path: str = "acq_db",
        summaries_path: str = "acq_summaries",
        vectorstore_path: str = "acq_vectorstores",
        download: bool = True,
        **kwargs,
    ):
        super().__init__(llm, **kwargs)
        self.summarize = summarize
        self.rag_embedding = rag_embedding
        self.process_images = process_images
        self.max_results = max_results
        self.database_path = database_path
        self.summaries_path = summaries_path
        self.vectorstore_path = vectorstore_path
        self.download = download

        os.makedirs(self.database_path, exist_ok=True)
        os.makedirs(self.summaries_path, exist_ok=True)

        self._action = self._build_graph()

    # ---- abstract-ish methods ----
    def _search(self, query: str) -> List[Dict[str, Any]]:
        raise NotImplementedError

    def _materialize(self, hit: Dict[str, Any]) -> ItemMetadata:
        raise NotImplementedError

    def _id(self, hit_or_item: Dict[str, Any]) -> str:
        raise NotImplementedError

    def _citation(self, item: ItemMetadata) -> str:
        # Subclass should format its ideal citation; fallback is ID or URL.
        return item.get("id") or item.get("url", "Unknown Source")

    # ---- optional hooks ----
    def _filter_hit(self, hit: Dict[str, Any]) -> bool:
        return True

    def _postprocess_text(self, text: str, local_path: Optional[str]) -> str:
        # Default: optionally add image descriptions for PDFs
        if (
            self.process_images
            and local_path
            and local_path.lower().endswith(".pdf")
        ):
            try:
                descs = extract_and_describe_images(local_path)
                if any(descs):
                    text += "\n\n[Image Interpretations]\n" + "\n".join(descs)
            except Exception:
                pass
        return text

    # ---- shared nodes ----
    def _fetch_items(self, query: str) -> List[ItemMetadata]:
        hits = self._search(query)[: self.max_results] if self.download else []
        items: List[ItemMetadata] = []

        # If not downloading/scraping, try to load whatever is cached in database_path.
        if not self.download:
            for fname in os.listdir(self.database_path):
                if fname.lower().endswith((".pdf", ".txt", ".html")):
                    item_id = os.path.splitext(fname)[0]
                    local_path = os.path.join(self.database_path, fname)
                    full_text = ""
                    try:
                        if fname.lower().endswith(".pdf"):
                            full_text = _load_pdf_text(local_path)
                        else:
                            with open(
                                local_path,
                                "r",
                                encoding="utf-8",
                                errors="ignore",
                            ) as f:
                                full_text = f.read()
                    except Exception as e:
                        full_text = f"[Error reading cached file: {e}]"
                    full_text = self._postprocess_text(full_text, local_path)
                    items.append({
                        "id": item_id,
                        "local_path": local_path,
                        "full_text": full_text,
                    })
            return items

        # Normal path: search → materialize each
        with ThreadPoolExecutor(max_workers=min(32, max(1, len(hits)))) as ex:
            futures = [
                ex.submit(self._materialize, h)
                for h in hits
                if self._filter_hit(h)
            ]
            for fut in as_completed(futures):
                try:
                    item = fut.result()
                    items.append(item)
                except Exception as e:
                    items.append({
                        "id": _hash(str(time.time())),
                        "full_text": f"[Error: {e}]",
                    })
        return items

    def _fetch_node(self, state: AcquisitionState) -> AcquisitionState:
        items = self._fetch_items(state["query"])
        return {**state, "items": items}

    def _summarize_node(self, state: AcquisitionState) -> AcquisitionState:
        prompt = ChatPromptTemplate.from_template("""
You are an assistant responsible for summarizing retrieved content in the context of this task: {context}

Summarize the content below:

{retrieved_content}
""")
        chain = prompt | self.llm | StrOutputParser()

        if "items" not in state or not state["items"]:
            return {**state, "summaries": None}

        summaries: List[Optional[str]] = [None] * len(state["items"])

        def process(i: int, item: ItemMetadata):
            item_id = item.get("id", f"item_{i}")
            out_path = os.path.join(
                self.summaries_path, f"{_safe_filename(item_id)}_summary.txt"
            )
            try:
                cleaned = remove_surrogates(item.get("full_text", ""))
                summary = chain.invoke(
                    {"retrieved_content": cleaned, "context": state["context"]},
                    config=self.build_config(tags=["acq", "summarize_each"]),
                )
            except Exception as e:
                summary = f"[Error summarizing item {item_id}: {e}]"
            with open(out_path, "w", encoding="utf-8") as f:
                f.write(summary)
            return i, summary

        with ThreadPoolExecutor(max_workers=min(32, len(state["items"]))) as ex:
            futures = [
                ex.submit(process, i, it) for i, it in enumerate(state["items"])
            ]
            for fut in as_completed(futures):
                i, s = fut.result()
                summaries[i] = s

        return {**state, "summaries": summaries}  # type: ignore

    def _rag_node(self, state: AcquisitionState) -> AcquisitionState:
        new_state = state.copy()
        rag_agent = RAGAgent(
            llm=self.llm,
            embedding=self.rag_embedding,
            database_path=self.database_path,
        )
        new_state["final_summary"] = rag_agent.invoke(context=state["context"])[
            "summary"
        ]
        return new_state

    def _aggregate_node(self, state: AcquisitionState) -> AcquisitionState:
        if not state.get("summaries") or not state.get("items"):
            return {**state, "final_summary": None}

        blocks: List[str] = []
        for idx, (item, summ) in enumerate(
            zip(state["items"], state["summaries"])
        ):  # type: ignore
            cite = self._citation(item)
            blocks.append(f"[{idx + 1}] {cite}\n\nSummary:\n{summ}")

        combined = "\n\n" + ("\n\n" + "-" * 40 + "\n\n").join(blocks)
        with open(
            os.path.join(self.summaries_path, "summaries_combined.txt"),
            "w",
            encoding="utf-8",
        ) as f:
            f.write(combined)

        prompt = ChatPromptTemplate.from_template("""
You are a scientific assistant extracting insights from multiple summaries.

Here are the summaries:

{Summaries}

Your task is to read all the summaries and provide a response to this task: {context}
""")
        chain = prompt | self.llm | StrOutputParser()

        final_summary = chain.invoke(
            {"Summaries": combined, "context": state["context"]},
            config=self.build_config(tags=["acq", "aggregate"]),
        )
        with open(
            os.path.join(self.summaries_path, "final_summary.txt"),
            "w",
            encoding="utf-8",
        ) as f:
            f.write(final_summary)

        return {**state, "final_summary": final_summary}

    def _build_graph(self):
        graph = StateGraph(AcquisitionState)
        self.add_node(graph, self._fetch_node)

        if self.summarize:
            if self.rag_embedding:
                self.add_node(graph, self._rag_node)
                graph.set_entry_point("_fetch_node")
                graph.add_edge("_fetch_node", "_rag_node")
                graph.set_finish_point("_rag_node")
            else:
                self.add_node(graph, self._summarize_node)
                self.add_node(graph, self._aggregate_node)

                graph.set_entry_point("_fetch_node")
                graph.add_edge("_fetch_node", "_summarize_node")
                graph.add_edge("_summarize_node", "_aggregate_node")
                graph.set_finish_point("_aggregate_node")
        else:
            graph.set_entry_point("_fetch_node")
            graph.set_finish_point("_fetch_node")

        return graph.compile(checkpointer=self.checkpointer)
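    # Descriptive note on the wiring above:
    #   - rag_embedding set:        _fetch_node -> _rag_node
    #   - summarize, no embedding:  _fetch_node -> _summarize_node -> _aggregate_node
    #   - summarize=False:          _fetch_node only (fetch/cache, no summary)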

    def _invoke(
        self,
        inputs: Mapping[str, Any],
        *,
        summarize: bool | None = None,
        recursion_limit: int = 1000,
        **_,
    ) -> str:
        config = self.build_config(
            recursion_limit=recursion_limit, tags=["graph"]
        )

        # alias support like your ArxivAgent
        if "query" not in inputs:
            if "arxiv_search_query" in inputs:
                inputs = dict(inputs)
                inputs["query"] = inputs.pop("arxiv_search_query")
            else:
                raise KeyError(
                    "Missing 'query' in inputs (alias 'arxiv_search_query' also accepted)."
                )

        result = self._action.invoke(inputs, config)
        use_summary = self.summarize if summarize is None else summarize
        return (
            result.get("final_summary", "No summary generated.")
            if use_summary
            else "\n\nFinished fetching items!"
        )
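

# Illustrative sketch only, not part of the released module: the smallest
# subclass that satisfies the BaseAcquisitionAgent contract above. The
# "local folder" source is a hypothetical example, not a shipped agent.
class _LocalFolderAgentSketch(BaseAcquisitionAgent):
    def _search(self, query: str) -> List[Dict[str, Any]]:
        # Lightweight hits: PDFs already sitting in database_path.
        return [
            {"path": os.path.join(self.database_path, f)}
            for f in os.listdir(self.database_path)
            if f.lower().endswith(".pdf")
        ]

    def _id(self, hit_or_item: Dict[str, Any]) -> str:
        return _hash(hit_or_item.get("path", json.dumps(hit_or_item)))

    def _citation(self, item: ItemMetadata) -> str:
        return item.get("local_path") or item.get("id", "Local file")

    def _materialize(self, hit: Dict[str, Any]) -> ItemMetadata:
        path = hit["path"]
        text = self._postprocess_text(_load_pdf_text(path), path)
        return {"id": self._id(hit), "local_path": path, "full_text": text}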

# ---------- Concrete: Web Search via ddgs ----------


class WebSearchAgent(BaseAcquisitionAgent):
    """
    Uses DuckDuckGo Search (ddgs) to find pages, downloads HTML or PDFs,
    extracts text, and then follows the same summarize/RAG path.
    """

    def __init__(self, *args, user_agent: str = "Mozilla/5.0", **kwargs):
        super().__init__(*args, **kwargs)
        self.user_agent = user_agent
        if DDGS is None:
            raise ImportError(
                "duckduckgo-search (DDGS) is required for WebSearchAgentGeneric."
            )

    def _id(self, hit_or_item: Dict[str, Any]) -> str:
        url = hit_or_item.get("href") or hit_or_item.get("url") or ""
        return (
            _hash(url)
            if url
            else hit_or_item.get("id", _hash(json.dumps(hit_or_item)))
        )

    def _citation(self, item: ItemMetadata) -> str:
        t = item.get("title", "") or ""
        u = item.get("url", "") or ""
        return f"{t} ({u})" if t else (u or item.get("id", "Web result"))

    def _search(self, query: str) -> List[Dict[str, Any]]:
        results: List[Dict[str, Any]] = []
        with DDGS() as ddgs:
            for r in ddgs.text(
                query, max_results=self.max_results, backend="duckduckgo"
            ):
                # r keys typically: title, href, body
                results.append(r)
        return results

    def _materialize(self, hit: Dict[str, Any]) -> ItemMetadata:
        url = hit.get("href") or hit.get("url")
        title = hit.get("title", "")
        if not url:
            return {"id": self._id(hit), "title": title, "full_text": ""}

        headers = {"User-Agent": self.user_agent}
        local_path = ""
        full_text = ""
        item_id = self._id(hit)

        try:
            if _looks_like_pdf_url(url):
                local_path = os.path.join(
                    self.database_path, _safe_filename(item_id) + ".pdf"
                )
                _download(url, local_path)
                full_text = _load_pdf_text(local_path)
            else:
                r = requests.get(url, headers=headers, timeout=20)
                r.raise_for_status()
                html = r.text
                local_path = os.path.join(
                    self.database_path, _safe_filename(item_id) + ".html"
                )
                with open(local_path, "w", encoding="utf-8") as f:
                    f.write(html)
                full_text = extract_main_text_only(html)
                # full_text = _basic_readable_text_from_html(html)
        except Exception as e:
            full_text = f"[Error retrieving {url}: {e}]"

        full_text = self._postprocess_text(full_text, local_path)
        return {
            "id": item_id,
            "title": title,
            "url": url,
            "local_path": local_path,
            "full_text": full_text,
            "extra": {"snippet": hit.get("body", "")},
        }
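

# Illustrative usage sketch, not from the packaged module. It assumes BaseAgent
# exposes an invoke() wrapper around _invoke() (as the RAGAgent call above
# suggests), that the ddgs package is installed, and that credentials for the
# default "openai/o3-mini" model are configured.
#
#     agent = WebSearchAgent(max_results=3)
#     report = agent.invoke({
#         "query": "perovskite solar cell stability",
#         "context": "Summarize the main degradation mechanisms reported.",
#     })
#     print(report)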

# ---------- Concrete: OSTI.gov Agent (minimal, adaptable) ----------


class OSTIAgent(BaseAcquisitionAgent):
    """
    Minimal OSTI.gov acquisition agent.

    NOTE:
      - OSTI provides search endpoints that can return metadata including full-text links.
      - Depending on your environment, you may prefer the public API or site scraping.
      - Here we assume a JSON API that yields results with keys like:
          {'osti_id': '12345', 'title': '...', 'pdf_url': 'https://...pdf', 'landing_page': 'https://...'}
        Adapt field names if your OSTI integration differs.

    Customize `_search` and `_materialize` to match your OSTI access path.
    """

    def __init__(
        self,
        *args,
        api_base: str = "https://www.osti.gov/api/v1/records",
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.api_base = api_base

    def _id(self, hit_or_item: Dict[str, Any]) -> str:
        if "osti_id" in hit_or_item:
            return str(hit_or_item["osti_id"])
        if "id" in hit_or_item:
            return str(hit_or_item["id"])
        if "landing_page" in hit_or_item:
            return _hash(hit_or_item["landing_page"])
        return _hash(json.dumps(hit_or_item))

    def _citation(self, item: ItemMetadata) -> str:
        t = item.get("title", "") or ""
        oid = item.get("id", "")
        return f"OSTI {oid}: {t}" if t else f"OSTI {oid}"

    def _search(self, query: str) -> List[Dict[str, Any]]:
        """
        Adjust params to your OSTI setup. This call is intentionally simple;
        add paging/auth as needed.
        """
        params = {
            "q": query,
            "size": self.max_results,
        }
        try:
            r = requests.get(self.api_base, params=params, timeout=25)
            r.raise_for_status()
            data = r.json()
            # Normalize to a list of hits; adapt key if your API differs.
            if isinstance(data, dict) and "records" in data:
                hits = data["records"]
            elif isinstance(data, list):
                hits = data
            else:
                hits = []
            return hits[: self.max_results]
        except Exception as e:
            return [
                {
                    "id": _hash(query + str(time.time())),
                    "title": "Search error",
                    "error": str(e),
                }
            ]

    def _materialize(self, hit: Dict[str, Any]) -> ItemMetadata:
        item_id = self._id(hit)
        title = hit.get("title") or hit.get("title_public", "") or ""
        landing = None
        local_path = ""
        full_text = ""

        try:
            pdf_url, landing_used, _ = resolve_pdf_from_osti_record(
                hit,
                headers={"User-Agent": "Mozilla/5.0"},
                unpaywall_email=os.environ.get("UNPAYWALL_EMAIL"),  # optional
            )

            if pdf_url:
                # Try to download as PDF (validate headers)
                with requests.get(
                    pdf_url,
                    headers={"User-Agent": "Mozilla/5.0"},
                    timeout=25,
                    allow_redirects=True,
                    stream=True,
                ) as r:
                    r.raise_for_status()
                    if _is_pdf_response(r):
                        fname = _derive_filename_from_cd_or_url(
                            r, f"osti_{item_id}.pdf"
                        )
                        local_path = os.path.join(self.database_path, fname)
                        _download_stream_to(local_path, r)
                        # Extract PDF text
                        try:
                            from langchain_community.document_loaders import (
                                PyPDFLoader,
                            )

                            loader = PyPDFLoader(local_path)
                            pages = loader.load()
                            full_text = "\n".join(p.page_content for p in pages)
                        except Exception as e:
                            full_text = (
                                f"[Downloaded but text extraction failed: {e}]"
                            )
                    else:
                        # Not a PDF; treat as HTML landing and parse text
                        landing = r.url
                        r.close()
            # If we still have no text, try scraping the DOE PAGES landing or citation page
            if not full_text:
                # Prefer DOE PAGES landing if present, else OSTI biblio
                landing = (
                    landing
                    or landing_used
                    or next(
                        (
                            link.get("href")
                            for link in hit.get("links", [])
                            if link.get("rel")
                            in ("citation_doe_pages", "citation")
                        ),
                        None,
                    )
                )
                if landing:
                    soup = _get_soup(
                        landing,
                        timeout=25,
                        headers={"User-Agent": "Mozilla/5.0"},
                    )
                    html_text = soup.get_text(" ", strip=True)
                    full_text = html_text[:1_000_000]  # keep it bounded
                    # Save raw HTML for cache/inspection
                    local_path = os.path.join(
                        self.database_path, f"{item_id}.html"
                    )
                    with open(local_path, "w", encoding="utf-8") as f:
                        f.write(str(soup))
                else:
                    full_text = "[No PDF or landing page text available.]"

        except Exception as e:
            full_text = f"[Error materializing OSTI {item_id}: {e}]"

        full_text = self._postprocess_text(full_text, local_path)
        return {
            "id": item_id,
            "title": title,
            "url": landing,
            "local_path": local_path,
            "full_text": full_text,
            "extra": {"raw_hit": hit},
        }
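

# Illustrative only: the record shape _materialize above expects, per the class
# docstring. Field names are examples, not a guaranteed OSTI schema; the "links"
# entry mirrors the rel values checked in the landing-page fallback.
#
#     {
#         "osti_id": "12345",
#         "title": "Example technical report",
#         "pdf_url": "https://...pdf",
#         "landing_page": "https://www.osti.gov/biblio/12345",
#         "links": [{"rel": "citation", "href": "https://www.osti.gov/biblio/12345"}],
#     }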

# ---------- (Optional) Refactor your ArxivAgent to reuse the parent ----------


class ArxivAgent(BaseAcquisitionAgent):
    """
    Drop-in replacement for your existing ArxivAgent that reuses the generic flow.
    Keeps the same behaviors (download PDFs, image processing, summarization/RAG).
    """

    def __init__(
        self,
        llm: str | BaseChatModel = "openai/o3-mini",
        *,
        process_images: bool = True,
        max_results: int = 3,
        download: bool = True,
        rag_embedding=None,
        database_path="arxiv_papers",
        summaries_path="arxiv_generated_summaries",
        vectorstore_path="arxiv_vectorstores",
        **kwargs,
    ):
        super().__init__(
            llm,
            rag_embedding=rag_embedding,
            process_images=process_images,
            max_results=max_results,
            database_path=database_path,
            summaries_path=summaries_path,
            vectorstore_path=vectorstore_path,
            download=download,
            **kwargs,
        )

    def _id(self, hit_or_item: Dict[str, Any]) -> str:
        # hits from arXiv feed have 'id' like ".../abs/XXXX.YYYY"
        arxiv_id = hit_or_item.get("arxiv_id")
        if arxiv_id:
            return arxiv_id
        feed_id = hit_or_item.get("id", "")
        if "/abs/" in feed_id:
            return feed_id.split("/abs/")[-1]
        return _hash(json.dumps(hit_or_item))

    def _citation(self, item: ItemMetadata) -> str:
        return f"ArXiv ID: {item.get('id', '?')}"

    def _search(self, query: str) -> List[Dict[str, Any]]:
        enc = quote(query)
        url = f"http://export.arxiv.org/api/query?search_query=all:{enc}&start=0&max_results={self.max_results}"
        try:
            resp = requests.get(url, timeout=15)
            resp.raise_for_status()
            feed = feedparser.parse(resp.content)
            entries = feed.entries if hasattr(feed, "entries") else []
            hits = []
            for e in entries:
                full_id = e.id.split("/abs/")[-1]
                hits.append({
                    "id": e.id,
                    "title": e.title.strip(),
                    "arxiv_id": full_id.split("/")[-1],
                })
            return hits
        except Exception as e:
            return [
                {
                    "id": _hash(query + str(time.time())),
                    "title": "Search error",
                    "error": str(e),
                }
            ]

    def _materialize(self, hit: Dict[str, Any]) -> ItemMetadata:
        arxiv_id = self._id(hit)
        title = hit.get("title", "")
        pdf_url = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
        local_path = os.path.join(self.database_path, f"{arxiv_id}.pdf")
        full_text = ""
        try:
            _download(pdf_url, local_path)
            full_text = _load_pdf_text(local_path)
        except Exception as e:
            full_text = f"[Error loading ArXiv {arxiv_id}: {e}]"
        full_text = self._postprocess_text(full_text, local_path)
        return {
            "id": arxiv_id,
            "title": title,
            "url": pdf_url,
            "local_path": local_path,
            "full_text": full_text,
        }
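

# Illustrative usage sketch, not from the packaged module; same invoke()
# assumption as noted for WebSearchAgent above.
#
#     agent = ArxivAgent(max_results=3, process_images=False)
#     # "arxiv_search_query" is accepted as an alias for "query" (see _invoke).
#     result = agent.invoke({
#         "arxiv_search_query": "physics informed neural networks",
#         "context": "Summarize training strategies reported in these papers.",
#     })
#
# With summarize=True (the default) this returns the aggregated final summary;
# with summarize=False it returns the "Finished fetching items!" sentinel.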