aiagents4pharma 1.39.0__py3-none-any.whl → 1.39.2__py3-none-any.whl
This diff compares the contents of publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- aiagents4pharma/talk2scholars/agents/main_agent.py +7 -7
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +88 -12
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/default.yaml +5 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +5 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +1 -20
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +1 -26
- aiagents4pharma/talk2scholars/configs/tools/download_arxiv_paper/default.yaml +4 -0
- aiagents4pharma/talk2scholars/configs/tools/download_biorxiv_paper/default.yaml +2 -0
- aiagents4pharma/talk2scholars/configs/tools/download_medrxiv_paper/default.yaml +2 -0
- aiagents4pharma/talk2scholars/configs/tools/question_and_answer/default.yaml +22 -0
- aiagents4pharma/talk2scholars/tests/test_main_agent.py +20 -2
- aiagents4pharma/talk2scholars/tests/test_nvidia_nim_reranker_utils.py +28 -0
- aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +107 -29
- aiagents4pharma/talk2scholars/tests/test_pdf_agent.py +2 -3
- aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +194 -543
- aiagents4pharma/talk2scholars/tests/test_s2_agent.py +2 -2
- aiagents4pharma/talk2scholars/tests/{test_s2_display.py → test_s2_display_dataframe.py} +2 -3
- aiagents4pharma/talk2scholars/tests/test_s2_query_dataframe.py +201 -0
- aiagents4pharma/talk2scholars/tests/test_s2_retrieve.py +7 -6
- aiagents4pharma/talk2scholars/tests/test_s2_utils_ext_ids.py +413 -0
- aiagents4pharma/talk2scholars/tests/test_tool_helper_utils.py +140 -0
- aiagents4pharma/talk2scholars/tests/test_zotero_agent.py +0 -1
- aiagents4pharma/talk2scholars/tests/test_zotero_read.py +16 -18
- aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +92 -37
- aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +73 -575
- aiagents4pharma/talk2scholars/tools/pdf/utils/__init__.py +10 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/generate_answer.py +97 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py +77 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py +83 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py +125 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/vector_store.py +162 -0
- aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py +33 -10
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +39 -16
- aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +124 -10
- aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +49 -17
- aiagents4pharma/talk2scholars/tools/s2/search.py +39 -16
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +34 -16
- aiagents4pharma/talk2scholars/tools/s2/utils/multi_helper.py +49 -16
- aiagents4pharma/talk2scholars/tools/s2/utils/search_helper.py +51 -16
- aiagents4pharma/talk2scholars/tools/s2/utils/single_helper.py +50 -17
- {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/METADATA +58 -105
- {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/RECORD +45 -32
- aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +0 -89
- aiagents4pharma/talk2scholars/tests/test_routing_logic.py +0 -74
- aiagents4pharma/talk2scholars/tests/test_s2_query.py +0 -95
- {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/licenses/LICENSE +0 -0
- {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/top_level.txt +0 -0
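The bulk of this release is a refactor of the PDF Q&A tool: question_and_answer.py shrinks from 635 to 133 lines, with the former in-file Vectorstore, NVIDIA NIM reranker, MMR retrieval, and answer-generation code moving into the new tools/pdf/utils modules listed above. As orientation for the diff below, here is a minimal sketch of the refactored call flow, condensed from the added lines; the QAToolHelper method bodies live in utils/tool_helper.py and are assumed from their call sites rather than shown, and the _qa_pipeline wrapper is a hypothetical name used only for this sketch.

    # Orientation sketch only: condensed from the "+" lines of the diff below.
    # QAToolHelper's internals (tools/pdf/utils/tool_helper.py) are not
    # reproduced here; method names and signatures are taken from the diff.
    from aiagents4pharma.talk2scholars.tools.pdf.utils.generate_answer import load_hydra_config
    from aiagents4pharma.talk2scholars.tools.pdf.utils.retrieve_chunks import retrieve_relevant_chunks
    from aiagents4pharma.talk2scholars.tools.pdf.utils.tool_helper import QAToolHelper

    helper = QAToolHelper()
    config = load_hydra_config()

    def _qa_pipeline(question: str, state: dict, call_id: str) -> str:
        helper.start_call(config, call_id)
        # 1. Pull models and paper metadata out of the shared agent state.
        text_emb, llm_model, article_data = helper.get_state_models_and_data(state)
        # 2. New or reused FAISS store; load every candidate paper's chunks.
        vs = helper.init_vector_store(text_emb)
        candidate_ids = list(article_data.keys())
        helper.load_candidate_papers(vs, article_data, candidate_ids)
        # 3. NVIDIA NIM rerank of candidate papers, then MMR retrieval of top-K chunks.
        selected_ids = helper.run_reranker(vs, question, candidate_ids)
        chunks = retrieve_relevant_chunks(
            vs, query=question, paper_ids=selected_ids, top_k=config.top_k_chunks
        )
        # 4. LLM answer plus "Sources:" attribution.
        return helper.format_answer(question, chunks, llm_model, article_data)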
aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py

@@ -1,635 +1,133 @@
 """
-PDF
-
-This
-
-
-
+LangGraph PDF Retrieval-Augmented Generation (RAG) Tool
+
+This tool answers user questions by retrieving and ranking relevant text chunks from PDFs
+and invoking an LLM to generate a concise, source-attributed response. It supports
+single or multiple PDF sources—such as Zotero libraries, arXiv papers, or direct uploads.
+
+Workflow:
+1. (Optional) Load PDFs from diverse sources into a FAISS vector store of embeddings.
+2. Rerank candidate papers using NVIDIA NIM semantic re-ranker.
+3. Retrieve top-K diverse text chunks via Maximal Marginal Relevance (MMR).
+4. Build a context-rich prompt combining retrieved chunks and the user question.
+5. Invoke the LLM to craft a clear answer with source citations.
+6. Return the answer in a ToolMessage for LangGraph to dispatch.
 """
 
 import logging
 import os
 import time
-from typing import Annotated, Any, Dict, List, Optional, Tuple
+from typing import Annotated, Any
 
-import hydra
-import numpy as np
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.document_loaders import PyPDFLoader
-from langchain_community.vectorstores import FAISS
-from langchain_core.documents import Document
-from langchain_core.embeddings import Embeddings
-from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import ToolMessage
 from langchain_core.tools import tool
 from langchain_core.tools.base import InjectedToolCallId
-from langchain_core.vectorstores import VectorStore
-from langchain_core.vectorstores.utils import maximal_marginal_relevance
-from langchain_nvidia_ai_endpoints import NVIDIARerank
 from langgraph.prebuilt import InjectedState
 from langgraph.types import Command
 from pydantic import BaseModel, Field
 
+from .utils.generate_answer import load_hydra_config
+from .utils.retrieve_chunks import retrieve_relevant_chunks
+from .utils.tool_helper import QAToolHelper
+
+# Helper for managing state, vectorstore, reranking, and formatting
+helper = QAToolHelper()
+# Load configuration and start logging
+config = load_hydra_config()
+
 # Set up logging with configurable level
 log_level = os.environ.get("LOG_LEVEL", "INFO")
 logging.basicConfig(level=getattr(logging, log_level))
 logger = logging.getLogger(__name__)
 logger.setLevel(getattr(logging, log_level))
-# pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals, too-many-branches, too-many-statements
-
-
-def load_hydra_config() -> Any:
-    """
-    Load the configuration using Hydra and return the configuration for the Q&A tool.
-    """
-    with hydra.initialize(version_base=None, config_path="../../configs"):
-        cfg = hydra.compose(
-            config_name="config",
-            overrides=["tools/question_and_answer=default"],
-        )
-        config = cfg.tools.question_and_answer
-        logger.info("Loaded Question and Answer tool configuration.")
-        return config
 
 
 class QuestionAndAnswerInput(BaseModel):
     """
-
-
-
-
-
-
-
-
-
-
-        - 'llm_model': chat/LLM instance
-        - 'vector_store': pre-built Vectorstore for retrieval
+    Pydantic schema for the PDF Q&A tool inputs.
+
+    Fields:
+      question: User's free-text query to answer based on PDF content.
+      tool_call_id: LangGraph-injected call identifier for tracking.
+      state: Shared agent state dict containing:
+        - article_data: metadata mapping of paper IDs to info (e.g., 'pdf_url', title).
+        - text_embedding_model: embedding model instance for chunk indexing.
+        - llm_model: chat/LLM instance for answer generation.
+        - vector_store: optional pre-built Vectorstore for retrieval.
     """
 
-    question: str = Field(
-
-        default=None,
-        description="Optional list of specific paper IDs to query. "
-        "If not provided, relevant papers will be selected automatically.",
-    )
-    use_all_papers: bool = Field(
-        default=False,
-        description="Whether to use all available papers for answering the question. "
-        "Set to True to bypass relevance filtering and include all loaded papers.",
+    question: str = Field(
+        description="User question for generating a PDF-based answer."
    )
     tool_call_id: Annotated[str, InjectedToolCallId]
     state: Annotated[dict, InjectedState]
 
 
-class Vectorstore:
-    """
-    A class for managing document embeddings and retrieval.
-    Provides unified access to documents across multiple papers.
-    """
-
-    def __init__(
-        self,
-        embedding_model: Embeddings,
-        metadata_fields: Optional[List[str]] = None,
-    ):
-        """
-        Initialize the document store.
-
-        Args:
-            embedding_model: The embedding model to use
-            metadata_fields: Fields to include in document metadata for filtering/retrieval
-        """
-        self.embedding_model = embedding_model
-        self.metadata_fields = metadata_fields or [
-            "title",
-            "paper_id",
-            "page",
-            "chunk_id",
-        ]
-        self.initialization_time = time.time()
-        logger.info("Vectorstore initialized at: %s", self.initialization_time)
-
-        # Track loaded papers to prevent duplicate loading
-        self.loaded_papers = set()
-        self.vector_store_class = FAISS
-        logger.info("Using FAISS vector store")
-
-        # Store for initialized documents
-        self.documents: Dict[str, Document] = {}
-        self.vector_store: Optional[VectorStore] = None
-        self.paper_metadata: Dict[str, Dict[str, Any]] = {}
-        # Cache for document chunk embeddings to avoid recomputation
-        self.embeddings: Dict[str, Any] = {}
-
-    def add_paper(
-        self,
-        paper_id: str,
-        pdf_url: str,
-        paper_metadata: Dict[str, Any],
-    ) -> None:
-        """
-        Add a paper to the document store.
-
-        Args:
-            paper_id: Unique identifier for the paper
-            pdf_url: URL to the PDF
-            paper_metadata: Metadata about the paper
-        """
-        # Skip if already loaded
-        if paper_id in self.loaded_papers:
-            logger.info("Paper %s already loaded, skipping", paper_id)
-            return
-
-        logger.info("Loading paper %s from %s", paper_id, pdf_url)
-
-        # Store paper metadata
-        self.paper_metadata[paper_id] = paper_metadata
-
-        # Load the PDF and split into chunks according to Hydra config
-        loader = PyPDFLoader(pdf_url)
-        documents = loader.load()
-        logger.info("Loaded %d pages from %s", len(documents), paper_id)
-
-        # Create text splitter according to Hydra config
-        cfg = load_hydra_config()
-        splitter = RecursiveCharacterTextSplitter(
-            chunk_size=cfg.chunk_size,
-            chunk_overlap=cfg.chunk_overlap,
-            separators=["\n\n", "\n", ". ", " ", ""],
-        )
-
-        # Split documents and add metadata for each chunk
-        chunks = splitter.split_documents(documents)
-        logger.info("Split %s into %d chunks", paper_id, len(chunks))
-        # Embed and cache chunk embeddings
-        chunk_texts = [chunk.page_content for chunk in chunks]
-        chunk_embeddings = self.embedding_model.embed_documents(chunk_texts)
-        logger.info("Embedded %d chunks for paper %s", len(chunk_embeddings), paper_id)
-
-        # Enhance document metadata
-        for i, chunk in enumerate(chunks):
-            # Add paper metadata to each chunk
-            chunk.metadata.update(
-                {
-                    "paper_id": paper_id,
-                    "title": paper_metadata.get("Title", "Unknown"),
-                    "chunk_id": i,
-                    # Keep existing page number if available
-                    "page": chunk.metadata.get("page", 0),
-                }
-            )
-
-            # Add any additional metadata fields
-            for field in self.metadata_fields:
-                if field in paper_metadata and field not in chunk.metadata:
-                    chunk.metadata[field] = paper_metadata[field]
-
-            # Store chunk
-            doc_id = f"{paper_id}_{i}"
-            self.documents[doc_id] = chunk
-            # Cache embedding if available
-            if chunk_embeddings[i] is not None:
-                self.embeddings[doc_id] = chunk_embeddings[i]
-
-        # Mark as loaded to prevent duplicate loading
-        self.loaded_papers.add(paper_id)
-        logger.info("Added %d chunks from paper %s", len(chunks), paper_id)
-
-    def build_vector_store(self) -> None:
-        """
-        Build the vector store from all loaded documents.
-        Should be called after all papers are added.
-        """
-        if not self.documents:
-            logger.warning("No documents added to build vector store")
-            return
-
-        if self.vector_store is not None:
-            logger.info("Vector store already built, skipping")
-            return
-
-        # Create vector store from documents
-        documents_list = list(self.documents.values())
-        self.vector_store = self.vector_store_class.from_documents(
-            documents=documents_list, embedding=self.embedding_model
-        )
-        logger.info("Built vector store with %d documents", len(documents_list))
-
-    def rank_papers_by_query(
-        self, query: str, top_k: int = 40
-    ) -> List[Tuple[str, float]]:
-        """
-        Rank papers by relevance to the query using NVIDIA's off-the-shelf re-ranker.
-
-        This function aggregates all chunks per paper, ranks them using the NVIDIA model,
-        and returns the top-k papers.
-
-        Args:
-            query (str): The query string.
-            top_k (int): Number of top papers to return.
-
-        Returns:
-            List of tuples (paper_id, dummy_score) sorted by relevance.
-        """
-
-        # Aggregate all document chunks for each paper
-        paper_texts = {}
-        for doc in self.documents.values():
-            paper_id = doc.metadata["paper_id"]
-            paper_texts.setdefault(paper_id, []).append(doc.page_content)
-
-        aggregated_documents = []
-        for paper_id, texts in paper_texts.items():
-            aggregated_text = " ".join(texts)
-            aggregated_documents.append(
-                Document(page_content=aggregated_text, metadata={"paper_id": paper_id})
-            )
-
-        # Instantiate the NVIDIA re-ranker client
-        config = load_hydra_config()
-        reranker = NVIDIARerank(
-            model=config.reranker.model,
-            api_key=config.reranker.api_key,
-        )
-
-        # Get the ranked list of documents based on the query
-        response = reranker.compress_documents(
-            query=query, documents=aggregated_documents
-        )
-
-        ranked_papers = [doc.metadata["paper_id"] for doc in response[:top_k]]
-        return ranked_papers
-
-    def retrieve_relevant_chunks(
-        self,
-        query: str,
-        paper_ids: Optional[List[str]] = None,
-        top_k: int = 25,
-        mmr_diversity: float = 1.00,
-    ) -> List[Document]:
-        """
-        Retrieve the most relevant chunks for a query using maximal marginal relevance.
-
-        Args:
-            query: Query string
-            paper_ids: Optional list of paper IDs to filter by
-            top_k: Number of chunks to retrieve
-            mmr_diversity: Diversity parameter for MMR (higher = more diverse)
-
-        Returns:
-            List of document chunks
-        """
-        if not self.vector_store:
-            logger.error("Failed to build vector store")
-            return []
-
-        if paper_ids:
-            logger.info("Filtering retrieval to papers: %s", paper_ids)
-
-        # Step 1: Embed the query
-        logger.info(
-            "Embedding query using model: %s", type(self.embedding_model).__name__
-        )
-        query_embedding = np.array(self.embedding_model.embed_query(query))
-
-        # Step 2: Filter relevant documents
-        all_docs = [
-            doc
-            for doc in self.documents.values()
-            if not paper_ids or doc.metadata["paper_id"] in paper_ids
-        ]
-
-        if not all_docs:
-            logger.warning("No documents found after filtering by paper_ids.")
-            return []
-
-        # Step 3: Retrieve or compute embeddings for all documents using cache
-        logger.info("Retrieving embeddings for %d chunks...", len(all_docs))
-        all_embeddings = []
-        for doc in all_docs:
-            doc_id = f"{doc.metadata['paper_id']}_{doc.metadata['chunk_id']}"
-            if doc_id not in self.embeddings:
-                logger.info("Embedding missing chunk %s", doc_id)
-                emb = self.embedding_model.embed_documents([doc.page_content])[0]
-                self.embeddings[doc_id] = emb
-            all_embeddings.append(self.embeddings[doc_id])
-
-        # Step 4: Apply MMR
-        mmr_indices = maximal_marginal_relevance(
-            query_embedding,
-            all_embeddings,
-            k=top_k,
-            lambda_mult=mmr_diversity,
-        )
-
-        results = [all_docs[i] for i in mmr_indices]
-        logger.info("Retrieved %d chunks using MMR", len(results))
-        return results
-
-
-def generate_answer(
-    question: str,
-    retrieved_chunks: List[Document],
-    llm_model: BaseChatModel,
-    config: Optional[Any] = None,
-) -> Dict[str, Any]:
-    """
-    Generate an answer for a question using retrieved chunks.
-
-    Args:
-        question (str): The question to answer
-        retrieved_chunks (List[Document]): List of relevant document chunks
-        llm_model (BaseChatModel): Language model for generating answers
-        config (Optional[Any]): Configuration for answer generation
-
-    Returns:
-        Dict[str, Any]: Dictionary with the answer and metadata
-    """
-    # Load configuration using the global function.
-    config = load_hydra_config()
-
-    # Ensure the configuration is not None and has the prompt_template.
-    if config is None:
-        raise ValueError("Hydra config loading failed: config is None.")
-    if "prompt_template" not in config:
-        raise ValueError("The prompt_template is missing from the configuration.")
-
-    # Prepare context from retrieved documents with source attribution.
-    # Group chunks by paper_id
-    papers = {}
-    for doc in retrieved_chunks:
-        paper_id = doc.metadata.get("paper_id", "unknown")
-        if paper_id not in papers:
-            papers[paper_id] = []
-        papers[paper_id].append(doc)
-
-    # Format chunks by paper
-    formatted_chunks = []
-    doc_index = 1
-    for paper_id, chunks in papers.items():
-        # Get the title from the first chunk (should be the same for all chunks)
-        title = chunks[0].metadata.get("title", "Unknown")
-
-        # Add a document header
-        formatted_chunks.append(
-            f"[Document {doc_index}] From: '{title}' (ID: {paper_id})"
-        )
-
-        # Add each chunk with its page information
-        for chunk in chunks:
-            page = chunk.metadata.get("page", "unknown")
-            formatted_chunks.append(f"Page {page}: {chunk.page_content}")
-
-        # Increment document index for the next paper
-        doc_index += 1
-
-    # Join all chunks
-    context = "\n\n".join(formatted_chunks)
-
-    # Get unique paper sources.
-    paper_sources = {doc.metadata["paper_id"] for doc in retrieved_chunks}
-
-    # Create prompt using the Hydra-provided prompt_template.
-    prompt = config["prompt_template"].format(context=context, question=question)
-
-    # Get the answer from the language model
-    response = llm_model.invoke(prompt)
-
-    # Return the response with metadata
-    return {
-        "output_text": response.content,
-        "sources": [doc.metadata for doc in retrieved_chunks],
-        "num_sources": len(retrieved_chunks),
-        "papers_used": list(paper_sources),
-    }
-
-
-# Shared pre-built Vectorstore for RAG (set externally, e.g., by Streamlit startup)
-prebuilt_vector_store: Optional[Vectorstore] = None
-
-
 @tool(args_schema=QuestionAndAnswerInput, parse_docstring=True)
 def question_and_answer(
     question: str,
     state: Annotated[dict, InjectedState],
     tool_call_id: Annotated[str, InjectedToolCallId],
-    paper_ids: Optional[List[str]] = None,
-    use_all_papers: bool = False,
 ) -> Command[Any]:
     """
-
-
-
-
-
+    LangGraph tool for Retrieval-Augmented Generation over PDFs.
+
+    Given a user question, this tool applies the following pipeline:
+    1. Validates that embedding and LLM models, plus article metadata, are in state.
+    2. Initializes or reuses a FAISS-based Vectorstore for PDF embeddings.
+    3. Loads one or more PDFs (from Zotero, arXiv, uploads) as text chunks into the store.
+    4. Uses NVIDIA NIM semantic re-ranker to select top candidate papers.
+    5. Retrieves the most relevant and diverse text chunks via Maximal Marginal Relevance.
+    6. Constructs an LLM prompt combining contextual chunks and the query.
+    7. Invokes the LLM to generate an answer, appending source attributions.
+    8. Returns a LangGraph Command with a ToolMessage containing the answer.
 
     Args:
-
-
-
-
-
-
-        paper_ids (Optional[List[str]]): Specific paper IDs to restrict retrieval (default: None).
-        use_all_papers (bool): If True, bypasses semantic ranking and includes all papers.
+        question (str): The free-text question to answer.
+        state (dict): Injected agent state; must include:
+            - article_data: mapping paper IDs → metadata (pdf_url, title, etc.)
+            - text_embedding_model: embedding model instance.
+            - llm_model: chat/LLM instance.
+        tool_call_id (str): Internal identifier for this tool invocation.
 
     Returns:
-
-            - 'messages': a single ToolMessage containing the generated answer text.
+        Command[Any]: updates conversation state with a ToolMessage(answer).
 
     Raises:
-
-
+        ValueError: when required models or metadata are missing in state.
+        RuntimeError: when no relevant chunks can be retrieved for the query.
     """
-    # Load configuration
-    config = load_hydra_config()
-    # Create a unique identifier for this call to track potential infinite loops
     call_id = f"qa_call_{time.time()}"
     logger.info(
         "Starting PDF Question and Answer tool call %s for question: %s",
         call_id,
         question,
     )
+    helper.start_call(config, call_id)
 
-    #
-    text_embedding_model = state.get("text_embedding_model")
-    if not text_embedding_model:
-        error_msg = "No text embedding model found in state."
-        logger.error("%s: %s", call_id, error_msg)
-        raise ValueError(error_msg)
-
-    llm_model = state.get("llm_model")
-    if not llm_model:
-        error_msg = "No LLM model found in state."
-        logger.error("%s: %s", call_id, error_msg)
-        raise ValueError(error_msg)
-
-    # Get article data from state
-    article_data = state.get("article_data", {})
-    if not article_data:
-        error_msg = "No article_data found in state."
-        logger.error("%s: %s", call_id, error_msg)
-        raise ValueError(error_msg)
+    # Extract models and article metadata
+    text_emb, llm_model, article_data = helper.get_state_models_and_data(state)
 
-    #
-    if prebuilt_vector_store is not None:
-        vector_store = prebuilt_vector_store
-
-    else:
-        vector_store = Vectorstore(embedding_model=text_embedding_model)
-        logger.info("Initialized new vector store (no pre-built store found)")
+    # Initialize or reuse vector store, then load candidate papers
+    vs = helper.init_vector_store(text_emb)
+    candidate_ids = list(article_data.keys())
+    logger.info("%s: Candidate paper IDs for reranking: %s", call_id, candidate_ids)
+    helper.load_candidate_papers(vs, article_data, candidate_ids)
 
-    #
-    has_uploaded_papers = any(
-
-        for paper in article_data.values()
-        if isinstance(paper, dict)
+    # Rerank papers and retrieve top chunks
+    selected_ids = helper.run_reranker(vs, question, candidate_ids)
+    relevant_chunks = retrieve_relevant_chunks(
+        vs, query=question, paper_ids=selected_ids, top_k=config.top_k_chunks
     )
-
-    has_zotero_papers = any(
-        paper.get("source") == "zotero"
-        for paper in article_data.values()
-        if isinstance(paper, dict)
-    )
-
-    has_arxiv_papers = any(
-        paper.get("source") == "arxiv"
-        for paper in article_data.values()
-        if isinstance(paper, dict)
-    )
-
-    has_biorxiv_papers = any(
-        paper.get("source") == "biorxiv"
-        for paper in article_data.values()
-        if isinstance(paper, dict)
-    )
-
-    has_medrxiv_papers = any(
-        paper.get("source") == "medrxiv"
-        for paper in article_data.values()
-        if isinstance(paper, dict)
-    )
-
-    # Choose papers to use
-    selected_paper_ids = []
-    has_combimed_papers = (
-        has_uploaded_papers
-        or has_zotero_papers
-        or has_arxiv_papers
-        or has_biorxiv_papers
-        or has_medrxiv_papers
-    )
-
-    if paper_ids:
-        # Use explicitly specified papers
-        selected_paper_ids = [pid for pid in paper_ids if pid in article_data]
-        logger.info(
-            "%s: Using explicitly specified papers: %s", call_id, selected_paper_ids
-        )
-
-        if not selected_paper_ids:
-            logger.warning(
-                "%s: None of the provided paper_ids %s were found", call_id, paper_ids
-            )
-
-    elif use_all_papers or has_combimed_papers:
-        # Use all available papers if explicitly requested or if we have papers from any source
-        selected_paper_ids = list(article_data.keys())
-        logger.info(
-            "%s: Using all %d available papers", call_id, len(selected_paper_ids)
-        )
-
-    else:
-        # Use semantic ranking to find relevant papers
-        # First ensure papers are loaded
-        for paper_id, paper in article_data.items():
-            pdf_url = paper.get("pdf_url")
-            if pdf_url and paper_id not in vector_store.loaded_papers:
-                try:
-                    vector_store.add_paper(paper_id, pdf_url, paper)
-                except (IOError, ValueError) as e:
-                    logger.error("Error loading paper %s: %s", paper_id, e)
-                    raise
-
-        # Now rank papers
-        ranked_papers = vector_store.rank_papers_by_query(
-            question, top_k=config.top_k_papers
-        )
-        selected_paper_ids = [paper_id for paper_id, _ in ranked_papers]
-        logger.info(
-            "%s: Selected papers based on semantic relevance: %s",
-            call_id,
-            selected_paper_ids,
-        )
-
-    if not selected_paper_ids:
-        # Fallback to all papers if selection failed
-        selected_paper_ids = list(article_data.keys())
-        logger.info(
-            "%s: Falling back to all %d papers", call_id, len(selected_paper_ids)
-        )
-
-    # Load selected papers if needed
-    for paper_id in selected_paper_ids:
-        if paper_id not in vector_store.loaded_papers:
-            pdf_url = article_data[paper_id].get("pdf_url")
-            if pdf_url:
-                try:
-                    vector_store.add_paper(paper_id, pdf_url, article_data[paper_id])
-                except (IOError, ValueError) as e:
-                    logger.warning(
-                        "%s: Error loading paper %s: %s", call_id, paper_id, e
-                    )
-
-    # Ensure vector store is built
-    if not vector_store.vector_store:
-        vector_store.build_vector_store()
-
-    # Retrieve relevant chunks across selected papers
-    relevant_chunks = vector_store.retrieve_relevant_chunks(
-        query=question, paper_ids=selected_paper_ids, top_k=config.top_k_chunks
-    )
-
     if not relevant_chunks:
-
-        logger.warning("%s: %s", call_id,
-        raise RuntimeError(
-            f"I couldn't find relevant information to answer your question: '{question}'. "
-            "Please try rephrasing or asking a different question."
-        )
-
-    # Generate answer using retrieved chunks
-    result = generate_answer(question, relevant_chunks, llm_model)
+        msg = f"No relevant chunks found for question: '{question}'"
+        logger.warning("%s: %s", call_id, msg)
+        raise RuntimeError(msg)
 
-    #
-
-
-    # Get paper titles for sources
-    paper_titles = {}
-    for paper_id in result.get("papers_used", []):
-        if paper_id in article_data:
-            paper_titles[paper_id] = article_data[paper_id].get(
-                "Title", "Unknown paper"
-            )
-
-    # Format source information
-    sources_text = ""
-    if paper_titles:
-        sources_text = "\n\nSources:\n" + "\n".join(
-            [f"- {title}" for title in paper_titles.values()]
-        )
-
-    # Prepare the final response
-    response_text = f"{answer_text}{sources_text}"
-    logger.info(
-        "%s: Successfully generated answer using %d chunks from %d papers",
-        call_id,
-        len(relevant_chunks),
-        len(paper_titles),
+    # Generate answer and format with sources
+    response_text = helper.format_answer(
+        question, relevant_chunks, llm_model, article_data
     )
-
     return Command(
         update={
             "messages": [