aiagents4pharma 1.31.0__py3-none-any.whl → 1.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +1 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_uniprot.py +44 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +1 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/uniprot_proteins.py +90 -0
- aiagents4pharma/talk2scholars/agents/main_agent.py +4 -3
- aiagents4pharma/talk2scholars/agents/paper_download_agent.py +3 -4
- aiagents4pharma/talk2scholars/agents/pdf_agent.py +6 -7
- aiagents4pharma/talk2scholars/agents/s2_agent.py +23 -20
- aiagents4pharma/talk2scholars/agents/zotero_agent.py +11 -11
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +19 -19
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +20 -15
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +27 -6
- aiagents4pharma/talk2scholars/state/state_talk2scholars.py +7 -7
- aiagents4pharma/talk2scholars/tests/test_main_agent.py +16 -16
- aiagents4pharma/talk2scholars/tests/test_paper_download_agent.py +17 -24
- aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +152 -135
- aiagents4pharma/talk2scholars/tests/test_pdf_agent.py +9 -16
- aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +790 -218
- aiagents4pharma/talk2scholars/tests/test_s2_agent.py +9 -9
- aiagents4pharma/talk2scholars/tests/test_s2_display.py +8 -8
- aiagents4pharma/talk2scholars/tests/test_s2_query.py +8 -8
- aiagents4pharma/talk2scholars/tests/test_zotero_agent.py +12 -12
- aiagents4pharma/talk2scholars/tests/test_zotero_path.py +11 -12
- aiagents4pharma/talk2scholars/tests/test_zotero_read.py +400 -22
- aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +0 -6
- aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +89 -31
- aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +540 -156
- aiagents4pharma/talk2scholars/tools/s2/__init__.py +4 -4
- aiagents4pharma/talk2scholars/tools/s2/{display_results.py → display_dataframe.py} +19 -21
- aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +71 -0
- aiagents4pharma/talk2scholars/tools/zotero/utils/read_helper.py +213 -35
- aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +3 -3
- {aiagents4pharma-1.31.0.dist-info → aiagents4pharma-1.33.0.dist-info}/METADATA +3 -1
- {aiagents4pharma-1.31.0.dist-info → aiagents4pharma-1.33.0.dist-info}/RECORD +37 -37
- {aiagents4pharma-1.31.0.dist-info → aiagents4pharma-1.33.0.dist-info}/WHEEL +1 -1
- aiagents4pharma/talk2scholars/tools/paper_download/abstract_downloader.py +0 -45
- aiagents4pharma/talk2scholars/tools/paper_download/arxiv_downloader.py +0 -115
- aiagents4pharma/talk2scholars/tools/s2/query_results.py +0 -61
- {aiagents4pharma-1.31.0.dist-info → aiagents4pharma-1.33.0.dist-info}/licenses/LICENSE +0 -0
- {aiagents4pharma-1.31.0.dist-info → aiagents4pharma-1.33.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py (resulting file, lines 1-601)
@@ -1,217 +1,601 @@
"""
Tool for performing Q&A on PDF documents using retrieval augmented generation.
This module provides functionality to load PDFs from URLs, split them into
chunks, retrieve relevant segments via semantic search, and generate answers
to user-provided questions using a language model chain.
"""

import logging
import os
import time
from typing import Annotated, Any, Dict, List, Optional, Tuple

import hydra
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import ToolMessage
from langchain_core.tools import tool
from langchain_core.tools.base import InjectedToolCallId
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores.utils import maximal_marginal_relevance
from langchain_nvidia_ai_endpoints import NVIDIARerank
from langgraph.prebuilt import InjectedState
from langgraph.types import Command
from pydantic import BaseModel, Field

# Set up logging with configurable level
log_level = os.environ.get("LOG_LEVEL", "INFO")
logging.basicConfig(level=getattr(logging, log_level))
logger = logging.getLogger(__name__)
logger.setLevel(getattr(logging, log_level))
# pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals, too-many-branches, too-many-statements


def load_hydra_config() -> Any:
    """
    Load the configuration using Hydra and return the configuration for the Q&A tool.
    """
    with hydra.initialize(version_base=None, config_path="../../configs"):
        cfg = hydra.compose(
            config_name="config",
            overrides=["tools/question_and_answer=default"],
        )
        config = cfg.tools.question_and_answer
        logger.info("Loaded Question and Answer tool configuration.")
        return config


class QuestionAndAnswerInput(BaseModel):
    """
    Input schema for the PDF Question and Answer tool.

    This schema defines the inputs required for querying academic or research-related
    PDFs to answer a specific question using a language model and document retrieval.

    Attributes:
        question (str): The question to ask regarding the PDF content.
        paper_ids (Optional[List[str]]): Optional list of specific paper IDs to query.
            If not provided, the system will determine relevant papers automatically.
        use_all_papers (bool): Whether to use all available papers for answering the question.
            If True, the system will include all loaded papers regardless of relevance filtering.
        tool_call_id (str): Unique identifier for the tool call, injected automatically.
        state (dict): Shared application state, injected automatically.
    """

    question: str = Field(description="The question to ask regarding the PDF content.")
    paper_ids: Optional[List[str]] = Field(
        default=None,
        description="Optional list of specific paper IDs to query. "
        "If not provided, relevant papers will be selected automatically.",
    )
    use_all_papers: bool = Field(
        default=False,
        description="Whether to use all available papers for answering the question. "
        "Set to True to bypass relevance filtering and include all loaded papers.",
    )
    tool_call_id: Annotated[str, InjectedToolCallId]
    state: Annotated[dict, InjectedState]


class Vectorstore:
    """
    A class for managing document embeddings and retrieval.
    Provides unified access to documents across multiple papers.
    """

    def __init__(
        self,
        embedding_model: Embeddings,
        metadata_fields: Optional[List[str]] = None,
    ):
        """
        Initialize the document store.

        Args:
            embedding_model: The embedding model to use
            metadata_fields: Fields to include in document metadata for filtering/retrieval
        """
        self.embedding_model = embedding_model
        self.metadata_fields = metadata_fields or [
            "title",
            "paper_id",
            "page",
            "chunk_id",
        ]
        self.initialization_time = time.time()
        logger.info("Vectorstore initialized at: %s", self.initialization_time)

        # Track loaded papers to prevent duplicate loading
        self.loaded_papers = set()
        self.vector_store_class = FAISS
        logger.info("Using FAISS vector store")

        # Store for initialized documents
        self.documents: Dict[str, Document] = {}
        self.vector_store: Optional[VectorStore] = None
        self.paper_metadata: Dict[str, Dict[str, Any]] = {}

    def add_paper(
        self,
        paper_id: str,
        pdf_url: str,
        paper_metadata: Dict[str, Any],
    ) -> None:
        """
        Add a paper to the document store.

        Args:
            paper_id: Unique identifier for the paper
            pdf_url: URL to the PDF
            paper_metadata: Metadata about the paper
        """
        # Skip if already loaded
        if paper_id in self.loaded_papers:
            logger.info("Paper %s already loaded, skipping", paper_id)
            return

        logger.info("Loading paper %s from %s", paper_id, pdf_url)

        # Store paper metadata
        self.paper_metadata[paper_id] = paper_metadata

        # Load the PDF and split into chunks according to Hydra config
        loader = PyPDFLoader(pdf_url)
        documents = loader.load()
        logger.info("Loaded %d pages from %s", len(documents), paper_id)

        # Create text splitter according to Hydra config
        cfg = load_hydra_config()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=cfg.chunk_size,
            chunk_overlap=cfg.chunk_overlap,
            separators=["\n\n", "\n", ". ", " ", ""],
        )

        # Split documents and add metadata for each chunk
        chunks = splitter.split_documents(documents)
        logger.info("Split %s into %d chunks", paper_id, len(chunks))

        # Enhance document metadata
        for i, chunk in enumerate(chunks):
            # Add paper metadata to each chunk
            chunk.metadata.update(
                {
                    "paper_id": paper_id,
                    "title": paper_metadata.get("Title", "Unknown"),
                    "chunk_id": i,
                    # Keep existing page number if available
                    "page": chunk.metadata.get("page", 0),
                }
            )

            # Add any additional metadata fields
            for field in self.metadata_fields:
                if field in paper_metadata and field not in chunk.metadata:
                    chunk.metadata[field] = paper_metadata[field]

            # Store chunk
            doc_id = f"{paper_id}_{i}"
            self.documents[doc_id] = chunk

        # Mark as loaded to prevent duplicate loading
        self.loaded_papers.add(paper_id)
        logger.info("Added %d chunks from paper %s", len(chunks), paper_id)

    def build_vector_store(self) -> None:
        """
        Build the vector store from all loaded documents.
        Should be called after all papers are added.
        """
        if not self.documents:
            logger.warning("No documents added to build vector store")
            return

        if self.vector_store is not None:
            logger.info("Vector store already built, skipping")
            return

        # Create vector store from documents
        documents_list = list(self.documents.values())
        self.vector_store = self.vector_store_class.from_documents(
            documents=documents_list, embedding=self.embedding_model
        )
        logger.info("Built vector store with %d documents", len(documents_list))

    def rank_papers_by_query(
        self, query: str, top_k: int = 40
    ) -> List[Tuple[str, float]]:
        """
        Rank papers by relevance to the query using NVIDIA's off-the-shelf re-ranker.

        This function aggregates all chunks per paper, ranks them using the NVIDIA model,
        and returns the top-k papers.

        Args:
            query (str): The query string.
            top_k (int): Number of top papers to return.

        Returns:
            List of tuples (paper_id, dummy_score) sorted by relevance.
        """

        # Aggregate all document chunks for each paper
        paper_texts = {}
        for doc in self.documents.values():
            paper_id = doc.metadata["paper_id"]
            paper_texts.setdefault(paper_id, []).append(doc.page_content)

        aggregated_documents = []
        for paper_id, texts in paper_texts.items():
            aggregated_text = " ".join(texts)
            aggregated_documents.append(
                Document(page_content=aggregated_text, metadata={"paper_id": paper_id})
            )

        # Instantiate the NVIDIA re-ranker client
        config = load_hydra_config()
        reranker = NVIDIARerank(
            model=config.reranker.model,
            api_key=config.reranker.api_key,
        )

        # Get the ranked list of documents based on the query
        response = reranker.compress_documents(
            query=query, documents=aggregated_documents
        )

        ranked_papers = [doc.metadata["paper_id"] for doc in response[:top_k]]
        return ranked_papers

    def retrieve_relevant_chunks(
        self,
        query: str,
        paper_ids: Optional[List[str]] = None,
        top_k: int = 25,
        mmr_diversity: float = 1.00,
    ) -> List[Document]:
        """
        Retrieve the most relevant chunks for a query using maximal marginal relevance.

        Args:
            query: Query string
            paper_ids: Optional list of paper IDs to filter by
            top_k: Number of chunks to retrieve
            mmr_diversity: Diversity parameter for MMR (higher = more diverse)

        Returns:
            List of document chunks
        """
        if not self.vector_store:
            logger.error("Failed to build vector store")
            return []

        if paper_ids:
            logger.info("Filtering retrieval to papers: %s", paper_ids)

        # Step 1: Embed the query
        logger.info(
            "Embedding query using model: %s", type(self.embedding_model).__name__
        )
        query_embedding = np.array(self.embedding_model.embed_query(query))

        # Step 2: Filter relevant documents
        all_docs = [
            doc
            for doc in self.documents.values()
            if not paper_ids or doc.metadata["paper_id"] in paper_ids
        ]

        if not all_docs:
            logger.warning("No documents found after filtering by paper_ids.")
            return []

        texts = [doc.page_content for doc in all_docs]

        # Step 3: Batch embed all documents
        logger.info("Starting batch embedding for %d chunks...", len(texts))
        all_embeddings = self.embedding_model.embed_documents(texts)
        logger.info("Completed embedding for %d chunks...", len(texts))

        # Step 4: Apply MMR
        mmr_indices = maximal_marginal_relevance(
            query_embedding,
            all_embeddings,
            k=top_k,
            lambda_mult=mmr_diversity,
        )

        results = [all_docs[i] for i in mmr_indices]
        logger.info("Retrieved %d chunks using MMR", len(results))
        return results


def generate_answer(
    question: str,
    retrieved_chunks: List[Document],
    llm_model: BaseChatModel,
    config: Optional[Any] = None,
) -> Dict[str, Any]:
    """
    Generate an answer for a question using retrieved chunks.

    Args:
        question (str): The question to answer
        retrieved_chunks (List[Document]): List of relevant document chunks
        llm_model (BaseChatModel): Language model for generating answers
        config (Optional[Any]): Configuration for answer generation

    Returns:
        Dict[str, Any]: Dictionary with the answer and metadata
    """
    # Load configuration using the global function.
    config = load_hydra_config()

    # Ensure the configuration is not None and has the prompt_template.
    if config is None:
        raise ValueError("Hydra config loading failed: config is None.")
    if "prompt_template" not in config:
        raise ValueError("The prompt_template is missing from the configuration.")

    # Prepare context from retrieved documents with source attribution.
    # Group chunks by paper_id
    papers = {}
    for doc in retrieved_chunks:
        paper_id = doc.metadata.get("paper_id", "unknown")
        if paper_id not in papers:
            papers[paper_id] = []
        papers[paper_id].append(doc)

    # Format chunks by paper
    formatted_chunks = []
    doc_index = 1
    for paper_id, chunks in papers.items():
        # Get the title from the first chunk (should be the same for all chunks)
        title = chunks[0].metadata.get("title", "Unknown")

        # Add a document header
        formatted_chunks.append(
            f"[Document {doc_index}] From: '{title}' (ID: {paper_id})"
        )

        # Add each chunk with its page information
        for chunk in chunks:
            page = chunk.metadata.get("page", "unknown")
            formatted_chunks.append(f"Page {page}: {chunk.page_content}")

        # Increment document index for the next paper
        doc_index += 1

    # Join all chunks
    context = "\n\n".join(formatted_chunks)

    # Get unique paper sources.
    paper_sources = {doc.metadata["paper_id"] for doc in retrieved_chunks}

    # Create prompt using the Hydra-provided prompt_template.
    prompt = config["prompt_template"].format(context=context, question=question)

    # Get the answer from the language model
    response = llm_model.invoke(prompt)

    # Return the response with metadata
    return {
        "output_text": response.content,
        "sources": [doc.metadata for doc in retrieved_chunks],
        "num_sources": len(retrieved_chunks),
        "papers_used": list(paper_sources),
    }


@tool(args_schema=QuestionAndAnswerInput, parse_docstring=True)
def question_and_answer(
    question: str,
    state: Annotated[dict, InjectedState],
    tool_call_id: Annotated[str, InjectedToolCallId],
    paper_ids: Optional[List[str]] = None,
    use_all_papers: bool = False,
) -> Command[Any]:
    """
    Answer a question using PDF content with advanced retrieval augmented generation.

    This tool retrieves PDF documents from URLs, processes them using semantic search,
    and generates an answer to the user's question based on the most relevant content.
    It can work with multiple papers simultaneously and provides source attribution.

    Args:
        question (str): The question to answer based on PDF content.
        paper_ids (Optional[List[str]]): Optional list of specific paper IDs to query.
        use_all_papers (bool): Whether to use all available papers.
        tool_call_id (str): Unique identifier for the current tool call.
        state (dict): Current state dictionary containing article data and required models.
            Expected keys:
            - "article_data": Dictionary containing article metadata including PDF URLs
            - "text_embedding_model": Model for generating embeddings
            - "llm_model": Language model for generating answers
            - "vector_store": Optional Vectorstore instance

    Returns:
        Dict[str, Any]: A dictionary wrapped in a Command that updates the conversation
        with either the answer or an error message.

    Raises:
        ValueError: If required components are missing or if PDF processing fails.
    """
    # Load configuration
    config = load_hydra_config()
    # Create a unique identifier for this call to track potential infinite loops
    call_id = f"qa_call_{time.time()}"
    logger.info(
        "Starting PDF Question and Answer tool call %s for question: %s",
        call_id,
        question,
    )

    # Get required models from state
    text_embedding_model = state.get("text_embedding_model")
    if not text_embedding_model:
        error_msg = "No text embedding model found in state."
        logger.error("%s: %s", call_id, error_msg)
        raise ValueError(error_msg)

    llm_model = state.get("llm_model")
    if not llm_model:
        error_msg = "No LLM model found in state."
        logger.error("%s: %s", call_id, error_msg)
        raise ValueError(error_msg)

    # Get article data from state
    article_data = state.get("article_data", {})
    if not article_data:
        error_msg = "No article_data found in state."
        logger.error("%s: %s", call_id, error_msg)
        raise ValueError(error_msg)

    # Always use a fresh in-memory document store for this Q&A call
    vector_store = Vectorstore(embedding_model=text_embedding_model)

    # Check if there are papers from different sources
    has_uploaded_papers = any(
        paper.get("source") == "upload"
        for paper in article_data.values()
        if isinstance(paper, dict)
    )

    has_zotero_papers = any(
        paper.get("source") == "zotero"
        for paper in article_data.values()
        if isinstance(paper, dict)
    )

    has_arxiv_papers = any(
        paper.get("source") == "arxiv"
        for paper in article_data.values()
        if isinstance(paper, dict)
    )

    # Choose papers to use
    selected_paper_ids = []

    if paper_ids:
        # Use explicitly specified papers
        selected_paper_ids = [pid for pid in paper_ids if pid in article_data]
        logger.info(
            "%s: Using explicitly specified papers: %s", call_id, selected_paper_ids
        )

        if not selected_paper_ids:
            logger.warning(
                "%s: None of the provided paper_ids %s were found", call_id, paper_ids
            )

    elif use_all_papers or has_uploaded_papers or has_zotero_papers or has_arxiv_papers:
        # Use all available papers if explicitly requested or if we have papers from any source
        selected_paper_ids = list(article_data.keys())
        logger.info(
            "%s: Using all %d available papers", call_id, len(selected_paper_ids)
        )

    else:
        # Use semantic ranking to find relevant papers
        # First ensure papers are loaded
        for paper_id, paper in article_data.items():
            pdf_url = paper.get("pdf_url")
            if pdf_url and paper_id not in vector_store.loaded_papers:
                try:
                    vector_store.add_paper(paper_id, pdf_url, paper)
                except (IOError, ValueError) as e:
                    logger.error("Error loading paper %s: %s", paper_id, e)
                    raise

        # Now rank papers
        ranked_papers = vector_store.rank_papers_by_query(
            question, top_k=config.top_k_papers
        )
        selected_paper_ids = [paper_id for paper_id, _ in ranked_papers]
        logger.info(
            "%s: Selected papers based on semantic relevance: %s",
            call_id,
            selected_paper_ids,
        )

    if not selected_paper_ids:
        # Fallback to all papers if selection failed
        selected_paper_ids = list(article_data.keys())
        logger.info(
            "%s: Falling back to all %d papers", call_id, len(selected_paper_ids)
        )

    # Load selected papers if needed
    for paper_id in selected_paper_ids:
        if paper_id not in vector_store.loaded_papers:
            pdf_url = article_data[paper_id].get("pdf_url")
            if pdf_url:
                try:
                    vector_store.add_paper(paper_id, pdf_url, article_data[paper_id])
                except (IOError, ValueError) as e:
                    logger.warning(
                        "%s: Error loading paper %s: %s", call_id, paper_id, e
                    )

    # Ensure vector store is built
    if not vector_store.vector_store:
        vector_store.build_vector_store()

    # Retrieve relevant chunks across selected papers
    relevant_chunks = vector_store.retrieve_relevant_chunks(
        query=question, paper_ids=selected_paper_ids, top_k=config.top_k_chunks
    )

    if not relevant_chunks:
        error_msg = "No relevant chunks found in the papers."
        logger.warning("%s: %s", call_id, error_msg)
        raise RuntimeError(
            f"I couldn't find relevant information to answer your question: '{question}'. "
            "Please try rephrasing or asking a different question."
        )

    # Generate answer using retrieved chunks
    result = generate_answer(question, relevant_chunks, llm_model)

    # Format answer with attribution
    answer_text = result.get("output_text", "No answer generated.")

    # Get paper titles for sources
    paper_titles = {}
    for paper_id in result.get("papers_used", []):
        if paper_id in article_data:
            paper_titles[paper_id] = article_data[paper_id].get(
                "Title", "Unknown paper"
            )

    # Format source information
    sources_text = ""
    if paper_titles:
        sources_text = "\n\nSources:\n" + "\n".join(
            [f"- {title}" for title in paper_titles.values()]
        )

    # Prepare the final response
    response_text = f"{answer_text}{sources_text}"
    logger.info(
        "%s: Successfully generated answer using %d chunks from %d papers",
        call_id,
        len(relevant_chunks),
        len(paper_titles),
    )

    return Command(
        update={
            "messages": [
                ToolMessage(
                    content=response_text,
                    tool_call_id=tool_call_id,
                )
            ],
        }
    )