aiagents4pharma 1.31.0__py3-none-any.whl → 1.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +1 -0
  2. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_uniprot.py +44 -0
  3. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +1 -0
  4. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/uniprot_proteins.py +90 -0
  5. aiagents4pharma/talk2scholars/agents/main_agent.py +4 -3
  6. aiagents4pharma/talk2scholars/agents/paper_download_agent.py +3 -4
  7. aiagents4pharma/talk2scholars/agents/pdf_agent.py +6 -7
  8. aiagents4pharma/talk2scholars/agents/s2_agent.py +23 -20
  9. aiagents4pharma/talk2scholars/agents/zotero_agent.py +11 -11
  10. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +19 -19
  11. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +20 -15
  12. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +27 -6
  13. aiagents4pharma/talk2scholars/state/state_talk2scholars.py +7 -7
  14. aiagents4pharma/talk2scholars/tests/test_main_agent.py +16 -16
  15. aiagents4pharma/talk2scholars/tests/test_paper_download_agent.py +17 -24
  16. aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +152 -135
  17. aiagents4pharma/talk2scholars/tests/test_pdf_agent.py +9 -16
  18. aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +790 -218
  19. aiagents4pharma/talk2scholars/tests/test_s2_agent.py +9 -9
  20. aiagents4pharma/talk2scholars/tests/test_s2_display.py +8 -8
  21. aiagents4pharma/talk2scholars/tests/test_s2_query.py +8 -8
  22. aiagents4pharma/talk2scholars/tests/test_zotero_agent.py +12 -12
  23. aiagents4pharma/talk2scholars/tests/test_zotero_path.py +11 -12
  24. aiagents4pharma/talk2scholars/tests/test_zotero_read.py +400 -22
  25. aiagents4pharma/talk2scholars/tools/paper_download/__init__.py +0 -6
  26. aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +89 -31
  27. aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +540 -156
  28. aiagents4pharma/talk2scholars/tools/s2/__init__.py +4 -4
  29. aiagents4pharma/talk2scholars/tools/s2/{display_results.py → display_dataframe.py} +19 -21
  30. aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +71 -0
  31. aiagents4pharma/talk2scholars/tools/zotero/utils/read_helper.py +213 -35
  32. aiagents4pharma/talk2scholars/tools/zotero/zotero_read.py +3 -3
  33. {aiagents4pharma-1.31.0.dist-info → aiagents4pharma-1.33.0.dist-info}/METADATA +3 -1
  34. {aiagents4pharma-1.31.0.dist-info → aiagents4pharma-1.33.0.dist-info}/RECORD +37 -37
  35. {aiagents4pharma-1.31.0.dist-info → aiagents4pharma-1.33.0.dist-info}/WHEEL +1 -1
  36. aiagents4pharma/talk2scholars/tools/paper_download/abstract_downloader.py +0 -45
  37. aiagents4pharma/talk2scholars/tools/paper_download/arxiv_downloader.py +0 -115
  38. aiagents4pharma/talk2scholars/tools/s2/query_results.py +0 -61
  39. {aiagents4pharma-1.31.0.dist-info → aiagents4pharma-1.33.0.dist-info}/licenses/LICENSE +0 -0
  40. {aiagents4pharma-1.31.0.dist-info → aiagents4pharma-1.33.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py
@@ -1,217 +1,601 @@
-#!/usr/bin/env python3
 """
-question_and_answer: Tool for performing Q&A on PDF documents using retrieval augmented generation.
-
-This module provides functionality to extract text from PDF binary data, split it into
-chunks, retrieve relevant segments via a vector store, and generate an answer to a
-user-provided question using a language model chain.
+Tool for performing Q&A on PDF documents using retrieval augmented generation.
+This module provides functionality to load PDFs from URLs, split them into
+chunks, retrieve relevant segments via semantic search, and generate answers
+to user-provided questions using a language model chain.
 """
 
-import io
 import logging
-from typing import Annotated, Dict, Any, List
+import os
+import time
+from typing import Annotated, Any, Dict, List, Optional, Tuple
 
-from PyPDF2 import PdfReader
-from pydantic import BaseModel, Field
 import hydra
-
-from langchain.chains.question_answering import load_qa_chain
-from langchain.docstore.document import Document
-from langchain.text_splitter import CharacterTextSplitter
+import numpy as np
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_community.vectorstores import FAISS
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.vectorstores import InMemoryVectorStore
 from langchain_core.messages import ToolMessage
 from langchain_core.tools import tool
 from langchain_core.tools.base import InjectedToolCallId
-from langchain_core.embeddings import Embeddings
-from langchain_community.vectorstores import Annoy
-from langchain_community.document_loaders import PyPDFLoader
-from langchain_openai import OpenAIEmbeddings
-from langgraph.types import Command
+from langchain_core.vectorstores import VectorStore
+from langchain_core.vectorstores.utils import maximal_marginal_relevance
+from langchain_nvidia_ai_endpoints import NVIDIARerank
 from langgraph.prebuilt import InjectedState
+from langgraph.types import Command
+from pydantic import BaseModel, Field
 
-# Set up logging.
-logging.basicConfig(level=logging.INFO)
+# Set up logging with configurable level
+log_level = os.environ.get("LOG_LEVEL", "INFO")
+logging.basicConfig(level=getattr(logging, log_level))
 logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
+logger.setLevel(getattr(logging, log_level))
+# pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals, too-many-branches, too-many-statements
 
-# Load configuration using Hydra.
-with hydra.initialize(version_base=None, config_path="../../configs"):
-    cfg = hydra.compose(
-        config_name="config", overrides=["tools/question_and_answer=default"]
-    )
-    cfg = cfg.tools.question_and_answer
-    logger.info("Loaded Question and Answer tool configuration.")
+
+def load_hydra_config() -> Any:
+    """
+    Load the configuration using Hydra and return the configuration for the Q&A tool.
+    """
+    with hydra.initialize(version_base=None, config_path="../../configs"):
+        cfg = hydra.compose(
+            config_name="config",
+            overrides=["tools/question_and_answer=default"],
+        )
+        config = cfg.tools.question_and_answer
+        logger.info("Loaded Question and Answer tool configuration.")
+        return config
 
 
 class QuestionAndAnswerInput(BaseModel):
     """
     Input schema for the PDF Question and Answer tool.
 
+    This schema defines the inputs required for querying academic or research-related
+    PDFs to answer a specific question using a language model and document retrieval.
+
     Attributes:
         question (str): The question to ask regarding the PDF content.
+        paper_ids (Optional[List[str]]): Optional list of specific paper IDs to query.
+            If not provided, the system will determine relevant papers automatically.
+        use_all_papers (bool): Whether to use all available papers for answering the question.
+            If True, the system will include all loaded papers regardless of relevance filtering.
         tool_call_id (str): Unique identifier for the tool call, injected automatically.
+        state (dict): Shared application state, injected automatically.
     """
 
     question: str = Field(description="The question to ask regarding the PDF content.")
+    paper_ids: Optional[List[str]] = Field(
+        default=None,
+        description="Optional list of specific paper IDs to query. "
+        "If not provided, relevant papers will be selected automatically.",
+    )
+    use_all_papers: bool = Field(
+        default=False,
+        description="Whether to use all available papers for answering the question. "
+        "Set to True to bypass relevance filtering and include all loaded papers.",
+    )
     tool_call_id: Annotated[str, InjectedToolCallId]
     state: Annotated[dict, InjectedState]
 
 
-def extract_text_from_pdf_data(pdf_bytes: bytes) -> str:
+class Vectorstore:
+    """
+    A class for managing document embeddings and retrieval.
+    Provides unified access to documents across multiple papers.
     """
-    Extract text content from PDF binary data.
 
-    This function uses PyPDF2 to read the provided PDF bytes and concatenates the text
-    extracted from each page.
+    def __init__(
+        self,
+        embedding_model: Embeddings,
+        metadata_fields: Optional[List[str]] = None,
+    ):
+        """
+        Initialize the document store.
 
-    Args:
-        pdf_bytes (bytes): The binary data of the PDF document.
+        Args:
+            embedding_model: The embedding model to use
+            metadata_fields: Fields to include in document metadata for filtering/retrieval
+        """
+        self.embedding_model = embedding_model
+        self.metadata_fields = metadata_fields or [
+            "title",
+            "paper_id",
+            "page",
+            "chunk_id",
+        ]
+        self.initialization_time = time.time()
+        logger.info("Vectorstore initialized at: %s", self.initialization_time)
 
-    Returns:
-        str: The complete text extracted from the PDF.
-    """
-    reader = PdfReader(io.BytesIO(pdf_bytes))
-    text = ""
-    for page in reader.pages:
-        page_text = page.extract_text() or ""
-        text += page_text
-    return text
+        # Track loaded papers to prevent duplicate loading
+        self.loaded_papers = set()
+        self.vector_store_class = FAISS
+        logger.info("Using FAISS vector store")
+
+        # Store for initialized documents
+        self.documents: Dict[str, Document] = {}
+        self.vector_store: Optional[VectorStore] = None
+        self.paper_metadata: Dict[str, Dict[str, Any]] = {}
+
+    def add_paper(
+        self,
+        paper_id: str,
+        pdf_url: str,
+        paper_metadata: Dict[str, Any],
+    ) -> None:
+        """
+        Add a paper to the document store.
+
+        Args:
+            paper_id: Unique identifier for the paper
+            pdf_url: URL to the PDF
+            paper_metadata: Metadata about the paper
+        """
+        # Skip if already loaded
+        if paper_id in self.loaded_papers:
+            logger.info("Paper %s already loaded, skipping", paper_id)
+            return
+
+        logger.info("Loading paper %s from %s", paper_id, pdf_url)
+
+        # Store paper metadata
+        self.paper_metadata[paper_id] = paper_metadata
+
+        # Load the PDF and split into chunks according to Hydra config
+        loader = PyPDFLoader(pdf_url)
+        documents = loader.load()
+        logger.info("Loaded %d pages from %s", len(documents), paper_id)
+
+        # Create text splitter according to Hydra config
+        cfg = load_hydra_config()
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=cfg.chunk_size,
+            chunk_overlap=cfg.chunk_overlap,
+            separators=["\n\n", "\n", ". ", " ", ""],
+        )
+
+        # Split documents and add metadata for each chunk
+        chunks = splitter.split_documents(documents)
+        logger.info("Split %s into %d chunks", paper_id, len(chunks))
+
+        # Enhance document metadata
+        for i, chunk in enumerate(chunks):
+            # Add paper metadata to each chunk
+            chunk.metadata.update(
+                {
+                    "paper_id": paper_id,
+                    "title": paper_metadata.get("Title", "Unknown"),
+                    "chunk_id": i,
+                    # Keep existing page number if available
+                    "page": chunk.metadata.get("page", 0),
+                }
+            )
+
+            # Add any additional metadata fields
+            for field in self.metadata_fields:
+                if field in paper_metadata and field not in chunk.metadata:
+                    chunk.metadata[field] = paper_metadata[field]
+
+            # Store chunk
+            doc_id = f"{paper_id}_{i}"
+            self.documents[doc_id] = chunk
+
+        # Mark as loaded to prevent duplicate loading
+        self.loaded_papers.add(paper_id)
+        logger.info("Added %d chunks from paper %s", len(chunks), paper_id)
+
+    def build_vector_store(self) -> None:
+        """
+        Build the vector store from all loaded documents.
+        Should be called after all papers are added.
+        """
+        if not self.documents:
+            logger.warning("No documents added to build vector store")
+            return
+
+        if self.vector_store is not None:
+            logger.info("Vector store already built, skipping")
+            return
+
+        # Create vector store from documents
+        documents_list = list(self.documents.values())
+        self.vector_store = self.vector_store_class.from_documents(
+            documents=documents_list, embedding=self.embedding_model
+        )
+        logger.info("Built vector store with %d documents", len(documents_list))
+
+    def rank_papers_by_query(
+        self, query: str, top_k: int = 40
+    ) -> List[Tuple[str, float]]:
+        """
+        Rank papers by relevance to the query using NVIDIA's off-the-shelf re-ranker.
+
+        This function aggregates all chunks per paper, ranks them using the NVIDIA model,
+        and returns the top-k papers.
+
+        Args:
+            query (str): The query string.
+            top_k (int): Number of top papers to return.
+
+        Returns:
+            List of tuples (paper_id, dummy_score) sorted by relevance.
+        """
+
+        # Aggregate all document chunks for each paper
+        paper_texts = {}
+        for doc in self.documents.values():
+            paper_id = doc.metadata["paper_id"]
+            paper_texts.setdefault(paper_id, []).append(doc.page_content)
+
+        aggregated_documents = []
+        for paper_id, texts in paper_texts.items():
+            aggregated_text = " ".join(texts)
+            aggregated_documents.append(
+                Document(page_content=aggregated_text, metadata={"paper_id": paper_id})
+            )
+
+        # Instantiate the NVIDIA re-ranker client
+        config = load_hydra_config()
+        reranker = NVIDIARerank(
+            model=config.reranker.model,
+            api_key=config.reranker.api_key,
+        )
+
+        # Get the ranked list of documents based on the query
+        response = reranker.compress_documents(
+            query=query, documents=aggregated_documents
+        )
+
+        ranked_papers = [doc.metadata["paper_id"] for doc in response[:top_k]]
+        return ranked_papers
+
+    def retrieve_relevant_chunks(
+        self,
+        query: str,
+        paper_ids: Optional[List[str]] = None,
+        top_k: int = 25,
+        mmr_diversity: float = 1.00,
+    ) -> List[Document]:
+        """
+        Retrieve the most relevant chunks for a query using maximal marginal relevance.
+
+        Args:
+            query: Query string
+            paper_ids: Optional list of paper IDs to filter by
+            top_k: Number of chunks to retrieve
+            mmr_diversity: Diversity parameter for MMR (higher = more diverse)
+
+        Returns:
+            List of document chunks
+        """
+        if not self.vector_store:
+            logger.error("Failed to build vector store")
+            return []
+
+        if paper_ids:
+            logger.info("Filtering retrieval to papers: %s", paper_ids)
+
+        # Step 1: Embed the query
+        logger.info(
+            "Embedding query using model: %s", type(self.embedding_model).__name__
+        )
+        query_embedding = np.array(self.embedding_model.embed_query(query))
+
+        # Step 2: Filter relevant documents
+        all_docs = [
+            doc
+            for doc in self.documents.values()
+            if not paper_ids or doc.metadata["paper_id"] in paper_ids
+        ]
+
+        if not all_docs:
+            logger.warning("No documents found after filtering by paper_ids.")
+            return []
+
+        texts = [doc.page_content for doc in all_docs]
+
+        # Step 3: Batch embed all documents
+        logger.info("Starting batch embedding for %d chunks...", len(texts))
+        all_embeddings = self.embedding_model.embed_documents(texts)
+        logger.info("Completed embedding for %d chunks...", len(texts))
+
+        # Step 4: Apply MMR
+        mmr_indices = maximal_marginal_relevance(
+            query_embedding,
+            all_embeddings,
+            k=top_k,
+            lambda_mult=mmr_diversity,
+        )
+
+        results = [all_docs[i] for i in mmr_indices]
+        logger.info("Retrieved %d chunks using MMR", len(results))
+        return results
 
 
 def generate_answer(
-    question: str, pdf_bytes: bytes, llm_model: BaseChatModel
+    question: str,
+    retrieved_chunks: List[Document],
+    llm_model: BaseChatModel,
+    config: Optional[Any] = None,
 ) -> Dict[str, Any]:
     """
-    Generate an answer for a question using retrieval augmented generation on PDF content.
-
-    This function extracts text from the PDF data, splits the text into manageable chunks,
-    performs a similarity search to retrieve the most relevant segments, and then uses a
-    question-answering chain (built using the provided llm_model) to generate an answer.
+    Generate an answer for a question using retrieved chunks.
 
     Args:
-        question (str): The question to be answered.
-        pdf_bytes (bytes): The binary content of the PDF document.
-        llm_model (BaseChatModel): The language model instance to use for answering.
+        question (str): The question to answer
+        retrieved_chunks (List[Document]): List of relevant document chunks
+        llm_model (BaseChatModel): Language model for generating answers
+        config (Optional[Any]): Configuration for answer generation
 
     Returns:
-        Dict[str, Any]: A dictionary containing the answer generated by the language model.
+        Dict[str, Any]: Dictionary with the answer and metadata
     """
-    text = extract_text_from_pdf_data(pdf_bytes)
-    logger.info("Extracted text from PDF.")
-    text_splitter = CharacterTextSplitter(
-        separator="\n", chunk_size=cfg.chunk_size, chunk_overlap=cfg.chunk_overlap
-    )
-    chunks = text_splitter.split_text(text)
-    documents: List[Document] = [Document(page_content=chunk) for chunk in chunks]
-    logger.info("Split PDF text into %d chunks.", len(documents))
-
-    embeddings = OpenAIEmbeddings(openai_api_key=cfg.openai_api_key)
-    vector_store = Annoy.from_documents(documents, embeddings)
-    search_results = vector_store.similarity_search(question, k=cfg.num_retrievals)
-    logger.info("Retrieved %d relevant document chunks.", len(search_results))
-    # Use the provided llm_model to build the QA chain.
-    qa_chain = load_qa_chain(llm_model, chain_type=cfg.qa_chain_type)
-    answer = qa_chain.invoke(
-        input={"input_documents": search_results, "question": question}
-    )
-    return answer
+    # Load configuration using the global function.
+    config = load_hydra_config()
 
+    # Ensure the configuration is not None and has the prompt_template.
+    if config is None:
+        raise ValueError("Hydra config loading failed: config is None.")
+    if "prompt_template" not in config:
+        raise ValueError("The prompt_template is missing from the configuration.")
 
-def generate_answer2(
-    question: str, pdf_url: str, text_embedding_model: Embeddings
-) -> Dict[str, Any]:
-    """
-    Generate an answer for a question using retrieval augmented generation on PDF content.
+    # Prepare context from retrieved documents with source attribution.
+    # Group chunks by paper_id
+    papers = {}
+    for doc in retrieved_chunks:
+        paper_id = doc.metadata.get("paper_id", "unknown")
+        if paper_id not in papers:
+            papers[paper_id] = []
+        papers[paper_id].append(doc)
 
-    This function extracts text from the PDF data, splits the text into manageable chunks,
-    performs a similarity search to retrieve the most relevant segments, and then uses a
-    question-answering chain (built using the provided llm_model) to generate an answer.
+    # Format chunks by paper
+    formatted_chunks = []
+    doc_index = 1
+    for paper_id, chunks in papers.items():
+        # Get the title from the first chunk (should be the same for all chunks)
+        title = chunks[0].metadata.get("title", "Unknown")
 
-    Args:
-        question (str): The question to be answered.
-        pdf_bytes (bytes): The binary content of the PDF document.
-        llm_model (BaseChatModel): The language model instance to use for answering.
+        # Add a document header
+        formatted_chunks.append(
+            f"[Document {doc_index}] From: '{title}' (ID: {paper_id})"
+        )
 
-    Returns:
-        Dict[str, Any]: A dictionary containing the answer generated by the language model.
-    """
-    # text = extract_text_from_pdf_data(pdf_bytes)
-    # logger.info("Extracted text from PDF.")
-    logger.log(logging.INFO, "searching the article with the question: %s", question)
-    # Load the article
-    # loader = PyPDFLoader(state['pdf_file_name'])
-    # loader = PyPDFLoader("https://arxiv.org/pdf/2310.08365")
-    loader = PyPDFLoader(pdf_url)
-    # Load the pages of the article
-    pages = []
-    for page in loader.lazy_load():
-        pages.append(page)
-    # Set up text embedding model
-    # text_embedding_model = state['text_embedding_model']
-    # text_embedding_model = OpenAIEmbeddings(openai_api_key=cfg.openai_api_key)
-    logging.info("Loaded text embedding model %s", text_embedding_model)
-    # Create a vector store from the pages
-    vector_store = InMemoryVectorStore.from_documents(pages, text_embedding_model)
-    # Search the article with the question
-    docs = vector_store.similarity_search(question)
-    # Return the content of the pages
-    return "\n".join([doc.page_content for doc in docs])
-    # return answer
-
-
-@tool(args_schema=QuestionAndAnswerInput)
-def question_and_answer_tool(
+        # Add each chunk with its page information
+        for chunk in chunks:
+            page = chunk.metadata.get("page", "unknown")
+            formatted_chunks.append(f"Page {page}: {chunk.page_content}")
+
+        # Increment document index for the next paper
+        doc_index += 1
+
+    # Join all chunks
+    context = "\n\n".join(formatted_chunks)
+
+    # Get unique paper sources.
+    paper_sources = {doc.metadata["paper_id"] for doc in retrieved_chunks}
+
+    # Create prompt using the Hydra-provided prompt_template.
+    prompt = config["prompt_template"].format(context=context, question=question)
+
+    # Get the answer from the language model
+    response = llm_model.invoke(prompt)
+
+    # Return the response with metadata
+    return {
+        "output_text": response.content,
+        "sources": [doc.metadata for doc in retrieved_chunks],
+        "num_sources": len(retrieved_chunks),
+        "papers_used": list(paper_sources),
+    }
+
+
+@tool(args_schema=QuestionAndAnswerInput, parse_docstring=True)
+def question_and_answer(
     question: str,
-    tool_call_id: Annotated[str, InjectedToolCallId],
     state: Annotated[dict, InjectedState],
-) -> Dict[str, Any]:
+    tool_call_id: Annotated[str, InjectedToolCallId],
+    paper_ids: Optional[List[str]] = None,
+    use_all_papers: bool = False,
+) -> Command[Any]:
     """
-    Answer a question using PDF content stored in the state via retrieval augmented generation.
+    Answer a question using PDF content with advanced retrieval augmented generation.
 
-    This tool retrieves the PDF binary data from the state (under the key "pdf_data"), extracts its
-    textual content, and generates an answer to the specified question. It also extracts the
-    llm_model (of type BaseChatModel) from the state to use for answering.
+    This tool retrieves PDF documents from URLs, processes them using semantic search,
+    and generates an answer to the user's question based on the most relevant content.
+    It can work with multiple papers simultaneously and provides source attribution.
 
     Args:
-        question (str): The question regarding the PDF content.
+        question (str): The question to answer based on PDF content.
+        paper_ids (Optional[List[str]]): Optional list of specific paper IDs to query.
+        use_all_papers (bool): Whether to use all available papers.
        tool_call_id (str): Unique identifier for the current tool call.
-        state (dict): A dictionary representing the current state, expected to contain PDF data
-            under the key "pdf_data" with a sub-key "pdf_object" for the binary content,
-            and a key "llm_model" holding the language model instance.
+        state (dict): Current state dictionary containing article data and required models.
+            Expected keys:
+            - "article_data": Dictionary containing article metadata including PDF URLs
+            - "text_embedding_model": Model for generating embeddings
+            - "llm_model": Language model for generating answers
+            - "vector_store": Optional Vectorstore instance
 
     Returns:
-        Dict[str, Any]: A dictionary containing the generated answer or an error message.
+        Dict[str, Any]: A dictionary wrapped in a Command that updates the conversation
+            with either the answer or an error message.
+
+    Raises:
+        ValueError: If required components are missing or if PDF processing fails.
     """
-    logger.info("Starting PDF Question and Answer tool using PDF data from state.")
-    # print (state['text_embedding_model'])
-    text_embedding_model = state["text_embedding_model"]
-    pdf_state = state.get("pdf_data")
-    if not pdf_state:
-        error_msg = "No pdf_data found in state."
-        logger.error(error_msg)
-        return Command(
-            update={
-                "messages": [ToolMessage(content=error_msg, tool_call_id=tool_call_id)]
-            }
-        )
-    pdf_bytes = pdf_state.get("pdf_object")
-    if not pdf_bytes:
-        error_msg = "PDF binary data is missing in the pdf_data from state."
-        logger.error(error_msg)
-        return Command(
-            update={
-                "messages": [ToolMessage(content=error_msg, tool_call_id=tool_call_id)]
-            }
-        )
-    pdf_url = pdf_state.get("pdf_url")
-    # Retrieve llm_model from state; use a default if not provided.
+    # Load configuration
+    config = load_hydra_config()
+    # Create a unique identifier for this call to track potential infinite loops
+    call_id = f"qa_call_{time.time()}"
+    logger.info(
+        "Starting PDF Question and Answer tool call %s for question: %s",
+        call_id,
+        question,
+    )
+
+    # Get required models from state
+    text_embedding_model = state.get("text_embedding_model")
+    if not text_embedding_model:
+        error_msg = "No text embedding model found in state."
+        logger.error("%s: %s", call_id, error_msg)
+        raise ValueError(error_msg)
+
     llm_model = state.get("llm_model")
     if not llm_model:
-        logger.error("Missing LLM model instance in state.")
-        return {"error": "No LLM model found in state."}
-    # answer = generate_answer(question, pdf_bytes, llm_model)
-    print(pdf_url)
-    answer = generate_answer2(question, pdf_url, text_embedding_model)
-    # logger.info("Generated answer: %s", answer)
-    return answer
+        error_msg = "No LLM model found in state."
+        logger.error("%s: %s", call_id, error_msg)
+        raise ValueError(error_msg)
+
+    # Get article data from state
+    article_data = state.get("article_data", {})
+    if not article_data:
+        error_msg = "No article_data found in state."
+        logger.error("%s: %s", call_id, error_msg)
+        raise ValueError(error_msg)
+
+    # Always use a fresh in-memory document store for this Q&A call
+    vector_store = Vectorstore(embedding_model=text_embedding_model)
+
+    # Check if there are papers from different sources
+    has_uploaded_papers = any(
+        paper.get("source") == "upload"
+        for paper in article_data.values()
+        if isinstance(paper, dict)
+    )
+
+    has_zotero_papers = any(
+        paper.get("source") == "zotero"
+        for paper in article_data.values()
+        if isinstance(paper, dict)
+    )
+
+    has_arxiv_papers = any(
+        paper.get("source") == "arxiv"
+        for paper in article_data.values()
+        if isinstance(paper, dict)
+    )
+
+    # Choose papers to use
+    selected_paper_ids = []
+
+    if paper_ids:
+        # Use explicitly specified papers
+        selected_paper_ids = [pid for pid in paper_ids if pid in article_data]
+        logger.info(
+            "%s: Using explicitly specified papers: %s", call_id, selected_paper_ids
+        )
+
+        if not selected_paper_ids:
+            logger.warning(
+                "%s: None of the provided paper_ids %s were found", call_id, paper_ids
+            )
+
+    elif use_all_papers or has_uploaded_papers or has_zotero_papers or has_arxiv_papers:
+        # Use all available papers if explicitly requested or if we have papers from any source
+        selected_paper_ids = list(article_data.keys())
+        logger.info(
+            "%s: Using all %d available papers", call_id, len(selected_paper_ids)
+        )
+
+    else:
+        # Use semantic ranking to find relevant papers
+        # First ensure papers are loaded
+        for paper_id, paper in article_data.items():
+            pdf_url = paper.get("pdf_url")
+            if pdf_url and paper_id not in vector_store.loaded_papers:
+                try:
+                    vector_store.add_paper(paper_id, pdf_url, paper)
+                except (IOError, ValueError) as e:
+                    logger.error("Error loading paper %s: %s", paper_id, e)
+                    raise
+
+        # Now rank papers
+        ranked_papers = vector_store.rank_papers_by_query(
+            question, top_k=config.top_k_papers
+        )
+        selected_paper_ids = [paper_id for paper_id, _ in ranked_papers]
+        logger.info(
+            "%s: Selected papers based on semantic relevance: %s",
+            call_id,
+            selected_paper_ids,
+        )
+
+    if not selected_paper_ids:
+        # Fallback to all papers if selection failed
+        selected_paper_ids = list(article_data.keys())
+        logger.info(
+            "%s: Falling back to all %d papers", call_id, len(selected_paper_ids)
+        )
+
+    # Load selected papers if needed
+    for paper_id in selected_paper_ids:
+        if paper_id not in vector_store.loaded_papers:
+            pdf_url = article_data[paper_id].get("pdf_url")
+            if pdf_url:
+                try:
+                    vector_store.add_paper(paper_id, pdf_url, article_data[paper_id])
+                except (IOError, ValueError) as e:
+                    logger.warning(
+                        "%s: Error loading paper %s: %s", call_id, paper_id, e
+                    )
+
+    # Ensure vector store is built
+    if not vector_store.vector_store:
+        vector_store.build_vector_store()
+
+    # Retrieve relevant chunks across selected papers
+    relevant_chunks = vector_store.retrieve_relevant_chunks(
+        query=question, paper_ids=selected_paper_ids, top_k=config.top_k_chunks
+    )
+
+    if not relevant_chunks:
+        error_msg = "No relevant chunks found in the papers."
+        logger.warning("%s: %s", call_id, error_msg)
+        raise RuntimeError(
+            f"I couldn't find relevant information to answer your question: '{question}'. "
+            "Please try rephrasing or asking a different question."
+        )
+
+    # Generate answer using retrieved chunks
+    result = generate_answer(question, relevant_chunks, llm_model)
+
+    # Format answer with attribution
+    answer_text = result.get("output_text", "No answer generated.")
+
+    # Get paper titles for sources
+    paper_titles = {}
+    for paper_id in result.get("papers_used", []):
+        if paper_id in article_data:
+            paper_titles[paper_id] = article_data[paper_id].get(
+                "Title", "Unknown paper"
+            )
+
+    # Format source information
+    sources_text = ""
+    if paper_titles:
+        sources_text = "\n\nSources:\n" + "\n".join(
+            [f"- {title}" for title in paper_titles.values()]
+        )
+
+    # Prepare the final response
+    response_text = f"{answer_text}{sources_text}"
+    logger.info(
+        "%s: Successfully generated answer using %d chunks from %d papers",
+        call_id,
+        len(relevant_chunks),
+        len(paper_titles),
+    )
+
+    return Command(
+        update={
+            "messages": [
+                ToolMessage(
+                    content=response_text,
+                    tool_call_id=tool_call_id,
+                )
+            ],
+        }
+    )
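
Note on the retrieval change above: retrieve_relevant_chunks replaces the old Annoy similarity search with batch embedding plus maximal marginal relevance (MMR) over the filtered chunks. The following sketch (not part of the diff) exercises that primitive in isolation, with random vectors standing in for real query and chunk embeddings; the dimensions and seed are illustrative assumptions only. In LangChain's implementation, lambda_mult=1.0 (the value the tool passes via its mmr_diversity default) scores by query similarity alone, while smaller values trade relevance for diversity among the selected chunks.

import numpy as np
from langchain_core.vectorstores.utils import maximal_marginal_relevance

# Stand-ins for embed_query / embed_documents output (illustrative only).
rng = np.random.default_rng(seed=0)
query_embedding = rng.normal(size=16)  # one query vector
chunk_embeddings = list(rng.normal(size=(8, 16)))  # eight chunk vectors

# Same call shape as "Step 4: Apply MMR" in retrieve_relevant_chunks.
mmr_indices = maximal_marginal_relevance(
    query_embedding,
    chunk_embeddings,
    k=4,
    lambda_mult=1.0,
)
print(mmr_indices)  # indices into chunk_embeddings of the four selected chunks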