ursa-ai 0.7.0rc2__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to their public registries. It is provided for informational purposes only and reflects the changes between the two versions.


@@ -0,0 +1,812 @@
1
+ # generic_acquisition_agents.py
2
+
3
+ import hashlib
4
+ import json
5
+ import os
6
+ import re
7
+ import shutil
8
+ import time
9
+ from concurrent.futures import ThreadPoolExecutor, as_completed
10
+ from io import BytesIO
11
+ from typing import Any, Dict, Mapping, Optional
12
+ from urllib.parse import quote, urlparse
13
+
14
+ import feedparser
15
+
16
+ # PDF & vision extras (image extraction and description helpers)
17
+ import pymupdf
18
+ import requests
19
+ from langchain_community.document_loaders import PyPDFLoader
20
+ from langchain_core.language_models import BaseChatModel
21
+ from langchain_core.output_parsers import StrOutputParser
22
+ from langchain_core.prompts import ChatPromptTemplate
23
+ from langgraph.graph import StateGraph
24
+ from PIL import Image
25
+ from typing_extensions import List, TypedDict
26
+
27
+ from ursa.agents.base import BaseAgent
28
+ from ursa.agents.rag_agent import RAGAgent
29
+ from ursa.util.parse import (
30
+ _derive_filename_from_cd_or_url,
31
+ _download_stream_to,
32
+ _get_soup,
33
+ _is_pdf_response,
34
+ extract_main_text_only,
35
+ resolve_pdf_from_osti_record,
36
+ )
37
+
38
+ try:
39
+ from ddgs import DDGS # pip install ddgs (successor to duckduckgo-search)
40
+ except Exception:
41
+ DDGS = None
42
+
43
+ try:
44
+ from openai import OpenAI
45
+ except Exception:
46
+ OpenAI = None
47
+
48
+
49
+ # ---------- Shared State / Types ----------
50
+
51
+
52
+ class ItemMetadata(TypedDict, total=False):
53
+ id: str # canonical ID (e.g., arxiv_id, sha, OSTI id)
54
+ title: str
55
+ url: str
56
+ local_path: str
57
+ full_text: str
58
+ extra: Dict[str, Any]
59
+
60
+
61
+ class AcquisitionState(TypedDict, total=False):
62
+ query: str
63
+ context: str
64
+ items: List[ItemMetadata]
65
+ summaries: List[str]
66
+ final_summary: str
67
+
68
+
69
+ # ---------- Small Utilities reused across agents ----------
70
+
71
+
72
+ def _safe_filename(s: str) -> str:
73
+ s = re.sub(r"[^\w\-_.]+", "_", s)
74
+ return s[:240]
75
+
76
+
77
+ def _hash(s: str) -> str:
78
+ return hashlib.sha256(s.encode("utf-8")).hexdigest()[:16]
79
+
80
+
81
+ def remove_surrogates(text: str) -> str:
82
+ return re.sub(r"[\ud800-\udfff]", "", text)
83
+
84
+
85
+ def _looks_like_pdf_url(url: str) -> bool:
86
+ parsed = urlparse(url)
87
+ return parsed.path.lower().endswith(".pdf")
88
+
89
+
90
+ def _download(url: str, dest_path: str, timeout: int = 20) -> str:
91
+ r = requests.get(url, stream=True, timeout=timeout)
92
+ r.raise_for_status()
93
+ os.makedirs(os.path.dirname(dest_path), exist_ok=True)
94
+ with open(dest_path, "wb") as f:
95
+ shutil.copyfileobj(r.raw, f)
96
+ return dest_path
97
+
98
+
99
+ def _load_pdf_text(path: str) -> str:
100
+ loader = PyPDFLoader(path)
101
+ pages = loader.load()
102
+ return "\n".join(p.page_content for p in pages)
103
+
104
+
105
+ # def _basic_readable_text_from_html(html: str) -> str:
106
+ # soup = BeautifulSoup(html, "html.parser")
107
+ # # Drop scripts/styles/navs for a crude readability
108
+ # for tag in soup(["script", "style", "noscript", "header", "footer", "nav"]):
109
+ # tag.decompose()
110
+ # # Keep title for context
111
+ # title = soup.title.get_text(strip=True) if soup.title else ""
112
+ # # Join paragraphs
113
+ # texts = [
114
+ # p.get_text(" ", strip=True)
115
+ # for p in soup.find_all(["p", "h1", "h2", "h3", "li", "figcaption"])
116
+ # ]
117
+ # body = "\n".join(t for t in texts if t)
118
+ # return (title + "\n\n" + body).strip()
119
+
120
+
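# Vision helpers: describe_image base64-encodes a PNG and sends it to the OpenAI
# chat completions API, returning "" when the openai package is unavailable;
# extract_and_describe_images pulls embedded images out of a PDF with pymupdf and
# describes up to max_images of them. The hard-coded vision model name below may
# need swapping for a currently available vision-capable model.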
121
+ def describe_image(image: Image.Image) -> str:
122
+ if OpenAI is None:
123
+ return ""
124
+ client = OpenAI()
125
+ buf = BytesIO()
126
+ image.save(buf, format="PNG")
127
+ import base64
128
+
129
+ img_b64 = base64.b64encode(buf.getvalue()).decode()
130
+ resp = client.chat.completions.create(
131
+ model="gpt-4-vision-preview",
132
+ messages=[
133
+ {
134
+ "role": "system",
135
+ "content": "You are a scientific assistant who explains plots and scientific diagrams.",
136
+ },
137
+ {
138
+ "role": "user",
139
+ "content": [
140
+ {
141
+ "type": "text",
142
+ "text": "Describe this scientific image or plot in detail.",
143
+ },
144
+ {
145
+ "type": "image_url",
146
+ "image_url": {
147
+ "url": f"data:image/png;base64,{img_b64}"
148
+ },
149
+ },
150
+ ],
151
+ },
152
+ ],
153
+ max_tokens=400,
154
+ )
155
+ return resp.choices[0].message.content.strip()
156
+
157
+
158
+ def extract_and_describe_images(
159
+ pdf_path: str, max_images: int = 5
160
+ ) -> List[str]:
161
+ descriptions: List[str] = []
162
+ try:
163
+ doc = pymupdf.open(pdf_path)
164
+ except Exception as e:
165
+ return [f"[Image extraction failed: {e}]"]
166
+
167
+ count = 0
168
+ for pi in range(len(doc)):
169
+ if count >= max_images:
170
+ break
171
+ page = doc[pi]
172
+ for ji, img in enumerate(page.get_images(full=True)):
173
+ if count >= max_images:
174
+ break
175
+ xref = img[0]
176
+ base = doc.extract_image(xref)
177
+ image = Image.open(BytesIO(base["image"]))
178
+ try:
179
+ desc = describe_image(image) if OpenAI else ""
180
+ except Exception as e:
181
+ desc = f"[Error: {e}]"
182
+ descriptions.append(f"Page {pi + 1}, Image {ji + 1}: {desc}")
183
+ count += 1
184
+ return descriptions
185
+
186
+
187
+ # ---------- The Parent / Generic Agent ----------
188
+
189
+
190
+ class BaseAcquisitionAgent(BaseAgent):
191
+ """
192
+ A generic "acquire-then-summarize-or-RAG" agent.
193
+
194
+ Subclasses must implement:
195
+ - _search(self, query) -> List[dict-like]: lightweight hits
196
+ - _materialize(self, hit) -> ItemMetadata: download or scrape and return populated item
197
+ - _id(self, hit_or_item) -> str: stable id for caching/file naming
198
+ - _citation(self, item) -> str: human-readable citation string
199
+
200
+ Optional hooks:
201
+ - _postprocess_text(self, text, local_path) -> str (e.g., image interpretation)
202
+ - _filter_hit(self, hit) -> bool
203
+ """
204
+
205
+ def __init__(
206
+ self,
207
+ llm: str | BaseChatModel = "openai/o3-mini",
208
+ *,
209
+ summarize: bool = True,
210
+ rag_embedding=None,
211
+ process_images: bool = True,
212
+ max_results: int = 5,
213
+ database_path: str = "acq_db",
214
+ summaries_path: str = "acq_summaries",
215
+ vectorstore_path: str = "acq_vectorstores",
216
+ download: bool = True,
217
+ **kwargs,
218
+ ):
219
+ super().__init__(llm, **kwargs)
220
+ self.summarize = summarize
221
+ self.rag_embedding = rag_embedding
222
+ self.process_images = process_images
223
+ self.max_results = max_results
224
+ self.database_path = database_path
225
+ self.summaries_path = summaries_path
226
+ self.vectorstore_path = vectorstore_path
227
+ self.download = download
228
+
229
+ os.makedirs(self.database_path, exist_ok=True)
230
+ os.makedirs(self.summaries_path, exist_ok=True)
231
+
232
+ self._action = self._build_graph()
233
+
234
+ # ---- abstract-ish methods ----
235
+ def _search(self, query: str) -> List[Dict[str, Any]]:
236
+ raise NotImplementedError
237
+
238
+ def _materialize(self, hit: Dict[str, Any]) -> ItemMetadata:
239
+ raise NotImplementedError
240
+
241
+ def _id(self, hit_or_item: Dict[str, Any]) -> str:
242
+ raise NotImplementedError
243
+
244
+ def _citation(self, item: ItemMetadata) -> str:
245
+ # Subclass should format its ideal citation; fallback is ID or URL.
246
+ return item.get("id") or item.get("url", "Unknown Source")
247
+
248
+ # ---- optional hooks ----
249
+ def _filter_hit(self, hit: Dict[str, Any]) -> bool:
250
+ return True
251
+
252
+ def _postprocess_text(self, text: str, local_path: Optional[str]) -> str:
253
+ # Default: optionally add image descriptions for PDFs
254
+ if (
255
+ self.process_images
256
+ and local_path
257
+ and local_path.lower().endswith(".pdf")
258
+ ):
259
+ try:
260
+ descs = extract_and_describe_images(local_path)
261
+ if any(descs):
262
+ text += "\n\n[Image Interpretations]\n" + "\n".join(descs)
263
+ except Exception:
264
+ pass
265
+ return text
266
+
267
+ # ---- shared nodes ----
268
+ def _fetch_items(self, query: str) -> List[ItemMetadata]:
269
+ hits = self._search(query)[: self.max_results] if self.download else []
270
+ items: List[ItemMetadata] = []
271
+
272
+ # If not downloading/scraping, try to load whatever is cached in database_path.
273
+ if not self.download:
274
+ for fname in os.listdir(self.database_path):
275
+ if fname.lower().endswith((".pdf", ".txt", ".html")):
276
+ item_id = os.path.splitext(fname)[0]
277
+ local_path = os.path.join(self.database_path, fname)
278
+ full_text = ""
279
+ try:
280
+ if fname.lower().endswith(".pdf"):
281
+ full_text = _load_pdf_text(local_path)
282
+ else:
283
+ with open(
284
+ local_path,
285
+ "r",
286
+ encoding="utf-8",
287
+ errors="ignore",
288
+ ) as f:
289
+ full_text = f.read()
290
+ except Exception as e:
291
+ full_text = f"[Error reading cached file: {e}]"
292
+ full_text = self._postprocess_text(full_text, local_path)
293
+ items.append({
294
+ "id": item_id,
295
+ "local_path": local_path,
296
+ "full_text": full_text,
297
+ })
298
+ return items
299
+
300
+ # Normal path: search → materialize each
301
+ with ThreadPoolExecutor(max_workers=min(32, max(1, len(hits)))) as ex:
302
+ futures = [
303
+ ex.submit(self._materialize, h)
304
+ for h in hits
305
+ if self._filter_hit(h)
306
+ ]
307
+ for fut in as_completed(futures):
308
+ try:
309
+ item = fut.result()
310
+ items.append(item)
311
+ except Exception as e:
312
+ items.append({
313
+ "id": _hash(str(time.time())),
314
+ "full_text": f"[Error: {e}]",
315
+ })
316
+ return items
317
+
318
+ def _fetch_node(self, state: AcquisitionState) -> AcquisitionState:
319
+ items = self._fetch_items(state["query"])
320
+ return {**state, "items": items}
321
+
322
+ def _summarize_node(self, state: AcquisitionState) -> AcquisitionState:
323
+ prompt = ChatPromptTemplate.from_template("""
324
+ You are an assistant responsible for summarizing retrieved content in the context of this task: {context}
325
+
326
+ Summarize the content below:
327
+
328
+ {retrieved_content}
329
+ """)
330
+ chain = prompt | self.llm | StrOutputParser()
331
+
332
+ if "items" not in state or not state["items"]:
333
+ return {**state, "summaries": None}
334
+
335
+ summaries: List[Optional[str]] = [None] * len(state["items"])
336
+
337
+ def process(i: int, item: ItemMetadata):
338
+ item_id = item.get("id", f"item_{i}")
339
+ out_path = os.path.join(
340
+ self.summaries_path, f"{_safe_filename(item_id)}_summary.txt"
341
+ )
342
+ try:
343
+ cleaned = remove_surrogates(item.get("full_text", ""))
344
+ summary = chain.invoke(
345
+ {"retrieved_content": cleaned, "context": state["context"]},
346
+ config=self.build_config(tags=["acq", "summarize_each"]),
347
+ )
348
+ except Exception as e:
349
+ summary = f"[Error summarizing item {item_id}: {e}]"
350
+ with open(out_path, "w", encoding="utf-8") as f:
351
+ f.write(summary)
352
+ return i, summary
353
+
354
+ with ThreadPoolExecutor(max_workers=min(32, len(state["items"]))) as ex:
355
+ futures = [
356
+ ex.submit(process, i, it) for i, it in enumerate(state["items"])
357
+ ]
358
+ for fut in as_completed(futures):
359
+ i, s = fut.result()
360
+ summaries[i] = s
361
+
362
+ return {**state, "summaries": summaries} # type: ignore
363
+
364
+ def _rag_node(self, state: AcquisitionState) -> AcquisitionState:
365
+ new_state = state.copy()
366
+ rag_agent = RAGAgent(
367
+ llm=self.llm,
368
+ embedding=self.rag_embedding,
369
+ database_path=self.database_path,
370
+ )
371
+ new_state["final_summary"] = rag_agent.invoke(context=state["context"])[
372
+ "summary"
373
+ ]
374
+ return new_state
375
+
376
+ def _aggregate_node(self, state: AcquisitionState) -> AcquisitionState:
377
+ if not state.get("summaries") or not state.get("items"):
378
+ return {**state, "final_summary": None}
379
+
380
+ blocks: List[str] = []
381
+ for idx, (item, summ) in enumerate(
382
+ zip(state["items"], state["summaries"])
383
+ ): # type: ignore
384
+ cite = self._citation(item)
385
+ blocks.append(f"[{idx + 1}] {cite}\n\nSummary:\n{summ}")
386
+
387
+ combined = "\n\n" + ("\n\n" + "-" * 40 + "\n\n").join(blocks)
388
+ with open(
389
+ os.path.join(self.summaries_path, "summaries_combined.txt"),
390
+ "w",
391
+ encoding="utf-8",
392
+ ) as f:
393
+ f.write(combined)
394
+
395
+ prompt = ChatPromptTemplate.from_template("""
396
+ You are a scientific assistant extracting insights from multiple summaries.
397
+
398
+ Here are the summaries:
399
+
400
+ {Summaries}
401
+
402
+ Your task is to read all the summaries and provide a response to this task: {context}
403
+ """)
404
+ chain = prompt | self.llm | StrOutputParser()
405
+
406
+ final_summary = chain.invoke(
407
+ {"Summaries": combined, "context": state["context"]},
408
+ config=self.build_config(tags=["acq", "aggregate"]),
409
+ )
410
+ with open(
411
+ os.path.join(self.summaries_path, "final_summary.txt"),
412
+ "w",
413
+ encoding="utf-8",
414
+ ) as f:
415
+ f.write(final_summary)
416
+
417
+ return {**state, "final_summary": final_summary}
418
+
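    # Graph wiring (see _build_graph below): _fetch_node is always the entry point.
    # With summarize=True and a rag_embedding, the path is _fetch_node -> _rag_node;
    # with summarize=True and no embedding, it is _fetch_node -> _summarize_node ->
    # _aggregate_node; with summarize=False, the graph is _fetch_node alone.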
419
+ def _build_graph(self):
420
+ graph = StateGraph(AcquisitionState)
421
+ self.add_node(graph, self._fetch_node)
422
+
423
+ if self.summarize:
424
+ if self.rag_embedding:
425
+ self.add_node(graph, self._rag_node)
426
+ graph.set_entry_point("_fetch_node")
427
+ graph.add_edge("_fetch_node", "_rag_node")
428
+ graph.set_finish_point("_rag_node")
429
+ else:
430
+ self.add_node(graph, self._summarize_node)
431
+ self.add_node(graph, self._aggregate_node)
432
+
433
+ graph.set_entry_point("_fetch_node")
434
+ graph.add_edge("_fetch_node", "_summarize_node")
435
+ graph.add_edge("_summarize_node", "_aggregate_node")
436
+ graph.set_finish_point("_aggregate_node")
437
+ else:
438
+ graph.set_entry_point("_fetch_node")
439
+ graph.set_finish_point("_fetch_node")
440
+
441
+ return graph.compile(checkpointer=self.checkpointer)
442
+
443
+ def _invoke(
444
+ self,
445
+ inputs: Mapping[str, Any],
446
+ *,
447
+ summarize: bool | None = None,
448
+ recursion_limit: int = 1000,
449
+ **_,
450
+ ) -> str:
451
+ config = self.build_config(
452
+ recursion_limit=recursion_limit, tags=["graph"]
453
+ )
454
+
455
+ # alias support: also accept the original ArxivAgent input key
456
+ if "query" not in inputs:
457
+ if "arxiv_search_query" in inputs:
458
+ inputs = dict(inputs)
459
+ inputs["query"] = inputs.pop("arxiv_search_query")
460
+ else:
461
+ raise KeyError(
462
+ "Missing 'query' in inputs (alias 'arxiv_search_query' also accepted)."
463
+ )
464
+
465
+ result = self._action.invoke(inputs, config)
466
+ use_summary = self.summarize if summarize is None else summarize
467
+ return (
468
+ result.get("final_summary", "No summary generated.")
469
+ if use_summary
470
+ else "\n\nFinished fetching items!"
471
+ )
472
+
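# Invocation sketch (illustrative; assumes the inherited BaseAgent entry point
# forwards a mapping of inputs to _invoke, in the same way RAGAgent.invoke is
# called above; the subclass name and strings are placeholders):
#
#     agent = SomeAcquisitionSubclass(max_results=3, summarize=True)
#     report = agent.invoke({
#         "query": "topic to search for",
#         "context": "what the downstream summary should focus on",
#     })
#
# "arxiv_search_query" is accepted as an alias for "query"; the return value is
# the aggregated final summary, or a completion message when summarize=False.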
473
+
474
+ # ---------- Concrete: Web Search via ddgs ----------
475
+
476
+
477
+ class WebSearchAgent(BaseAcquisitionAgent):
478
+ """
479
+ Uses DuckDuckGo Search (ddgs) to find pages, downloads HTML or PDFs,
480
+ extracts text, and then follows the same summarize/RAG path.
481
+ """
482
+
483
+ def __init__(self, *args, user_agent: str = "Mozilla/5.0", **kwargs):
484
+ super().__init__(*args, **kwargs)
485
+ self.user_agent = user_agent
486
+ if DDGS is None:
487
+ raise ImportError(
488
+ "duckduckgo-search (DDGS) is required for WebSearchAgentGeneric."
489
+ )
490
+
491
+ def _id(self, hit_or_item: Dict[str, Any]) -> str:
492
+ url = hit_or_item.get("href") or hit_or_item.get("url") or ""
493
+ return (
494
+ _hash(url)
495
+ if url
496
+ else hit_or_item.get("id", _hash(json.dumps(hit_or_item)))
497
+ )
498
+
499
+ def _citation(self, item: ItemMetadata) -> str:
500
+ t = item.get("title", "") or ""
501
+ u = item.get("url", "") or ""
502
+ return f"{t} ({u})" if t else (u or item.get("id", "Web result"))
503
+
504
+ def _search(self, query: str) -> List[Dict[str, Any]]:
505
+ results: List[Dict[str, Any]] = []
506
+ with DDGS() as ddgs:
507
+ for r in ddgs.text(
508
+ query, max_results=self.max_results, backend="duckduckgo"
509
+ ):
510
+ # r keys typically: title, href, body
511
+ results.append(r)
512
+ return results
513
+
514
+ def _materialize(self, hit: Dict[str, Any]) -> ItemMetadata:
515
+ url = hit.get("href") or hit.get("url")
516
+ title = hit.get("title", "")
517
+ if not url:
518
+ return {"id": self._id(hit), "title": title, "full_text": ""}
519
+
520
+ headers = {"User-Agent": self.user_agent}
521
+ local_path = ""
522
+ full_text = ""
523
+ item_id = self._id(hit)
524
+
525
+ try:
526
+ if _looks_like_pdf_url(url):
527
+ local_path = os.path.join(
528
+ self.database_path, _safe_filename(item_id) + ".pdf"
529
+ )
530
+ _download(url, local_path)
531
+ full_text = _load_pdf_text(local_path)
532
+ else:
533
+ r = requests.get(url, headers=headers, timeout=20)
534
+ r.raise_for_status()
535
+ html = r.text
536
+ local_path = os.path.join(
537
+ self.database_path, _safe_filename(item_id) + ".html"
538
+ )
539
+ with open(local_path, "w", encoding="utf-8") as f:
540
+ f.write(html)
541
+ full_text = extract_main_text_only(html)
542
+ # full_text = _basic_readable_text_from_html(html)
543
+ except Exception as e:
544
+ full_text = f"[Error retrieving {url}: {e}]"
545
+
546
+ full_text = self._postprocess_text(full_text, local_path)
547
+ return {
548
+ "id": item_id,
549
+ "title": title,
550
+ "url": url,
551
+ "local_path": local_path,
552
+ "full_text": full_text,
553
+ "extra": {"snippet": hit.get("body", "")},
554
+ }
555
+
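# Usage sketch for WebSearchAgent (illustrative; the query and context strings
# are placeholders and the call assumes the standard BaseAgent invoke entry point):
#
#     web_agent = WebSearchAgent(max_results=5, process_images=False)
#     answer = web_agent.invoke({
#         "query": "open-source PDE solvers benchmark",
#         "context": "Compare the solvers mentioned across the retrieved pages.",
#     })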
556
+
557
+ # ---------- Concrete: OSTI.gov Agent (minimal, adaptable) ----------
558
+
559
+
560
+ class OSTIAgent(BaseAcquisitionAgent):
561
+ """
562
+ Minimal OSTI.gov acquisition agent.
563
+
564
+ NOTE:
565
+ - OSTI provides search endpoints that can return metadata including full-text links.
566
+ - Depending on your environment, you may prefer the public API or site scraping.
567
+ - Here we assume a JSON API that yields results with keys like:
568
+ {'osti_id': '12345', 'title': '...', 'pdf_url': 'https://...pdf', 'landing_page': 'https://...'}
569
+ Adapt field names if your OSTI integration differs.
570
+
571
+ Customize `_search` and `_materialize` to match your OSTI access path.
572
+ """
573
+
574
+ def __init__(
575
+ self,
576
+ *args,
577
+ api_base: str = "https://www.osti.gov/api/v1/records",
578
+ **kwargs,
579
+ ):
580
+ super().__init__(*args, **kwargs)
581
+ self.api_base = api_base
582
+
583
+ def _id(self, hit_or_item: Dict[str, Any]) -> str:
584
+ if "osti_id" in hit_or_item:
585
+ return str(hit_or_item["osti_id"])
586
+ if "id" in hit_or_item:
587
+ return str(hit_or_item["id"])
588
+ if "landing_page" in hit_or_item:
589
+ return _hash(hit_or_item["landing_page"])
590
+ return _hash(json.dumps(hit_or_item))
591
+
592
+ def _citation(self, item: ItemMetadata) -> str:
593
+ t = item.get("title", "") or ""
594
+ oid = item.get("id", "")
595
+ return f"OSTI {oid}: {t}" if t else f"OSTI {oid}"
596
+
597
+ def _search(self, query: str) -> List[Dict[str, Any]]:
598
+ """
599
+ Adjust params to your OSTI setup. This call is intentionally simple;
600
+ add paging/auth as needed.
601
+ """
602
+ params = {
603
+ "q": query,
604
+ "size": self.max_results,
605
+ }
606
+ try:
607
+ r = requests.get(self.api_base, params=params, timeout=25)
608
+ r.raise_for_status()
609
+ data = r.json()
610
+ # Normalize to a list of hits; adapt key if your API differs.
611
+ if isinstance(data, dict) and "records" in data:
612
+ hits = data["records"]
613
+ elif isinstance(data, list):
614
+ hits = data
615
+ else:
616
+ hits = []
617
+ return hits[: self.max_results]
618
+ except Exception as e:
619
+ return [
620
+ {
621
+ "id": _hash(query + str(time.time())),
622
+ "title": "Search error",
623
+ "error": str(e),
624
+ }
625
+ ]
626
+
627
+ def _materialize(self, hit: Dict[str, Any]) -> ItemMetadata:
628
+ item_id = self._id(hit)
629
+ title = hit.get("title") or hit.get("title_public", "") or ""
630
+ landing = None
631
+ local_path = ""
632
+ full_text = ""
633
+
634
+ try:
635
+ pdf_url, landing_used, _ = resolve_pdf_from_osti_record(
636
+ hit,
637
+ headers={"User-Agent": "Mozilla/5.0"},
638
+ unpaywall_email=os.environ.get("UNPAYWALL_EMAIL"), # optional
639
+ )
640
+
641
+ if pdf_url:
642
+ # Try to download as PDF (validate headers)
643
+ with requests.get(
644
+ pdf_url,
645
+ headers={"User-Agent": "Mozilla/5.0"},
646
+ timeout=25,
647
+ allow_redirects=True,
648
+ stream=True,
649
+ ) as r:
650
+ r.raise_for_status()
651
+ if _is_pdf_response(r):
652
+ fname = _derive_filename_from_cd_or_url(
653
+ r, f"osti_{item_id}.pdf"
654
+ )
655
+ local_path = os.path.join(self.database_path, fname)
656
+ _download_stream_to(local_path, r)
657
+ # Extract PDF text
658
+ try:
659
+ from langchain_community.document_loaders import (
660
+ PyPDFLoader,
661
+ )
662
+
663
+ loader = PyPDFLoader(local_path)
664
+ pages = loader.load()
665
+ full_text = "\n".join(p.page_content for p in pages)
666
+ except Exception as e:
667
+ full_text = (
668
+ f"[Downloaded but text extraction failed: {e}]"
669
+ )
670
+ else:
671
+ # Not a PDF; treat as HTML landing and parse text
672
+ landing = r.url
673
+ r.close()
674
+ # If we still have no text, try scraping the DOE PAGES landing or citation page
675
+ if not full_text:
676
+ # Prefer DOE PAGES landing if present, else OSTI biblio
677
+ landing = (
678
+ landing
679
+ or landing_used
680
+ or next(
681
+ (
682
+ link.get("href")
683
+ for link in hit.get("links", [])
684
+ if link.get("rel")
685
+ in ("citation_doe_pages", "citation")
686
+ ),
687
+ None,
688
+ )
689
+ )
690
+ if landing:
691
+ soup = _get_soup(
692
+ landing,
693
+ timeout=25,
694
+ headers={"User-Agent": "Mozilla/5.0"},
695
+ )
696
+ html_text = soup.get_text(" ", strip=True)
697
+ full_text = html_text[:1_000_000] # keep it bounded
698
+ # Save raw HTML for cache/inspection
699
+ local_path = os.path.join(
700
+ self.database_path, f"{item_id}.html"
701
+ )
702
+ with open(local_path, "w", encoding="utf-8") as f:
703
+ f.write(str(soup))
704
+ else:
705
+ full_text = "[No PDF or landing page text available.]"
706
+
707
+ except Exception as e:
708
+ full_text = f"[Error materializing OSTI {item_id}: {e}]"
709
+
710
+ full_text = self._postprocess_text(full_text, local_path)
711
+ return {
712
+ "id": item_id,
713
+ "title": title,
714
+ "url": landing,
715
+ "local_path": local_path,
716
+ "full_text": full_text,
717
+ "extra": {"raw_hit": hit},
718
+ }
719
+
720
+
721
+ # ---------- (Optional) Refactor your ArxivAgent to reuse the parent ----------
722
+
723
+
724
+ class ArxivAgent(BaseAcquisitionAgent):
725
+ """
726
+ Drop-in replacement for the original ArxivAgent that reuses the generic flow.
727
+ Keeps the same behaviors (download PDFs, image processing, summarization/RAG).
728
+ """
729
+
730
+ def __init__(
731
+ self,
732
+ llm: str | BaseChatModel = "openai/o3-mini",
733
+ *,
734
+ process_images: bool = True,
735
+ max_results: int = 3,
736
+ download: bool = True,
737
+ rag_embedding=None,
738
+ database_path="arxiv_papers",
739
+ summaries_path="arxiv_generated_summaries",
740
+ vectorstore_path="arxiv_vectorstores",
741
+ **kwargs,
742
+ ):
743
+ super().__init__(
744
+ llm,
745
+ rag_embedding=rag_embedding,
746
+ process_images=process_images,
747
+ max_results=max_results,
748
+ database_path=database_path,
749
+ summaries_path=summaries_path,
750
+ vectorstore_path=vectorstore_path,
751
+ download=download,
752
+ **kwargs,
753
+ )
754
+
755
+ def _id(self, hit_or_item: Dict[str, Any]) -> str:
756
+ # hits from arXiv feed have 'id' like ".../abs/XXXX.YYYY"
757
+ arxiv_id = hit_or_item.get("arxiv_id")
758
+ if arxiv_id:
759
+ return arxiv_id
760
+ feed_id = hit_or_item.get("id", "")
761
+ if "/abs/" in feed_id:
762
+ return feed_id.split("/abs/")[-1]
763
+ return _hash(json.dumps(hit_or_item))
764
+
765
+ def _citation(self, item: ItemMetadata) -> str:
766
+ return f"ArXiv ID: {item.get('id', '?')}"
767
+
768
+ def _search(self, query: str) -> List[Dict[str, Any]]:
769
+ enc = quote(query)
770
+ url = f"http://export.arxiv.org/api/query?search_query=all:{enc}&start=0&max_results={self.max_results}"
771
+ try:
772
+ resp = requests.get(url, timeout=15)
773
+ resp.raise_for_status()
774
+ feed = feedparser.parse(resp.content)
775
+ entries = feed.entries if hasattr(feed, "entries") else []
776
+ hits = []
777
+ for e in entries:
778
+ full_id = e.id.split("/abs/")[-1]
779
+ hits.append({
780
+ "id": e.id,
781
+ "title": e.title.strip(),
782
+ "arxiv_id": full_id.split("/")[-1],
783
+ })
784
+ return hits
785
+ except Exception as e:
786
+ return [
787
+ {
788
+ "id": _hash(query + str(time.time())),
789
+ "title": "Search error",
790
+ "error": str(e),
791
+ }
792
+ ]
793
+
794
+ def _materialize(self, hit: Dict[str, Any]) -> ItemMetadata:
795
+ arxiv_id = self._id(hit)
796
+ title = hit.get("title", "")
797
+ pdf_url = f"https://arxiv.org/pdf/{arxiv_id}.pdf"
798
+ local_path = os.path.join(self.database_path, f"{arxiv_id}.pdf")
799
+ full_text = ""
800
+ try:
801
+ _download(pdf_url, local_path)
802
+ full_text = _load_pdf_text(local_path)
803
+ except Exception as e:
804
+ full_text = f"[Error loading ArXiv {arxiv_id}: {e}]"
805
+ full_text = self._postprocess_text(full_text, local_path)
806
+ return {
807
+ "id": arxiv_id,
808
+ "title": title,
809
+ "url": pdf_url,
810
+ "local_path": local_path,
811
+ "full_text": full_text,
812
+ }
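

# ---------- Example (illustrative sketch; not executed on import) ----------
#
# The snippet below assumes the inherited BaseAgent invoke entry point accepts a
# mapping of inputs; the query and context strings are placeholders.
#
#     if __name__ == "__main__":
#         agent = ArxivAgent(max_results=2, process_images=False)
#         print(agent.invoke({
#             "query": "physics-informed neural networks",
#             "context": "Summarize how each paper enforces physical constraints.",
#         }))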