aiagents4pharma 1.39.0__py3-none-any.whl → 1.39.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. aiagents4pharma/talk2scholars/agents/main_agent.py +7 -7
  2. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +88 -12
  3. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/default.yaml +5 -0
  4. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +5 -0
  5. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +1 -20
  6. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +1 -26
  7. aiagents4pharma/talk2scholars/configs/tools/download_arxiv_paper/default.yaml +4 -0
  8. aiagents4pharma/talk2scholars/configs/tools/download_biorxiv_paper/default.yaml +2 -0
  9. aiagents4pharma/talk2scholars/configs/tools/download_medrxiv_paper/default.yaml +2 -0
  10. aiagents4pharma/talk2scholars/configs/tools/question_and_answer/default.yaml +22 -0
  11. aiagents4pharma/talk2scholars/tests/test_main_agent.py +20 -2
  12. aiagents4pharma/talk2scholars/tests/test_nvidia_nim_reranker_utils.py +28 -0
  13. aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +107 -29
  14. aiagents4pharma/talk2scholars/tests/test_pdf_agent.py +2 -3
  15. aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +194 -543
  16. aiagents4pharma/talk2scholars/tests/test_s2_agent.py +2 -2
  17. aiagents4pharma/talk2scholars/tests/{test_s2_display.py → test_s2_display_dataframe.py} +2 -3
  18. aiagents4pharma/talk2scholars/tests/test_s2_query_dataframe.py +201 -0
  19. aiagents4pharma/talk2scholars/tests/test_s2_retrieve.py +7 -6
  20. aiagents4pharma/talk2scholars/tests/test_s2_utils_ext_ids.py +413 -0
  21. aiagents4pharma/talk2scholars/tests/test_tool_helper_utils.py +140 -0
  22. aiagents4pharma/talk2scholars/tests/test_zotero_agent.py +0 -1
  23. aiagents4pharma/talk2scholars/tests/test_zotero_read.py +16 -18
  24. aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +92 -37
  25. aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +73 -575
  26. aiagents4pharma/talk2scholars/tools/pdf/utils/__init__.py +10 -0
  27. aiagents4pharma/talk2scholars/tools/pdf/utils/generate_answer.py +97 -0
  28. aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py +77 -0
  29. aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py +83 -0
  30. aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py +125 -0
  31. aiagents4pharma/talk2scholars/tools/pdf/utils/vector_store.py +162 -0
  32. aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py +33 -10
  33. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +39 -16
  34. aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +124 -10
  35. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +49 -17
  36. aiagents4pharma/talk2scholars/tools/s2/search.py +39 -16
  37. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +34 -16
  38. aiagents4pharma/talk2scholars/tools/s2/utils/multi_helper.py +49 -16
  39. aiagents4pharma/talk2scholars/tools/s2/utils/search_helper.py +51 -16
  40. aiagents4pharma/talk2scholars/tools/s2/utils/single_helper.py +50 -17
  41. {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/METADATA +58 -105
  42. {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/RECORD +45 -32
  43. aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +0 -89
  44. aiagents4pharma/talk2scholars/tests/test_routing_logic.py +0 -74
  45. aiagents4pharma/talk2scholars/tests/test_s2_query.py +0 -95
  46. {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/WHEEL +0 -0
  47. {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/licenses/LICENSE +0 -0
  48. {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/top_level.txt +0 -0
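The diff below covers the largest change: aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py shrinks from 635 to 133 lines by delegating vector-store management, reranking, retrieval, and answer formatting to the new tools/pdf/utils modules. The following sketch, pieced together only from the signatures visible in this diff, illustrates roughly how the refactored tool could be invoked directly; the embedding/chat models, the paper entry, and the ToolCall-style invocation are illustrative assumptions, not documented package API.

# Hypothetical usage sketch based solely on signatures visible in this diff.
# The model choices and the paper entry are placeholders; inside the package,
# LangGraph agents wire this state up automatically.
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from aiagents4pharma.talk2scholars.tools.pdf.question_and_answer import (
    question_and_answer,
)

state = {
    "article_data": {
        "paper_001": {  # placeholder paper ID and metadata
            "Title": "Example paper",
            "pdf_url": "https://example.org/paper.pdf",
            "source": "arxiv",
        }
    },
    "text_embedding_model": OpenAIEmbeddings(),
    "llm_model": ChatOpenAI(model="gpt-4o-mini"),
}

# Tools that use InjectedToolCallId are invoked with a ToolCall-shaped dict;
# LangGraph's ToolNode normally performs this step for you.
command = question_and_answer.invoke(
    {
        "type": "tool_call",
        "id": "qa-demo-1",
        "name": "question_and_answer",
        "args": {"question": "What methods does the paper use?", "state": state},
    }
)
print(command.update["messages"][0].content)  # answer with source attributions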
aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py
@@ -1,635 +1,133 @@
  """
- PDF Question & Answer Tool
-
- This LangGraph tool answers user questions by leveraging a pre-built FAISS vector store
- of embedded PDF document chunks. Given a question, it retrieves the most relevant text
- segments from the loaded PDFs, invokes an LLM for answer generation, and returns the
- response with source attribution.
+ LangGraph PDF Retrieval-Augmented Generation (RAG) Tool
+
+ This tool answers user questions by retrieving and ranking relevant text chunks from PDFs
+ and invoking an LLM to generate a concise, source-attributed response. It supports
+ single or multiple PDF sources—such as Zotero libraries, arXiv papers, or direct uploads.
+
+ Workflow:
+ 1. (Optional) Load PDFs from diverse sources into a FAISS vector store of embeddings.
+ 2. Rerank candidate papers using NVIDIA NIM semantic re-ranker.
+ 3. Retrieve top-K diverse text chunks via Maximal Marginal Relevance (MMR).
+ 4. Build a context-rich prompt combining retrieved chunks and the user question.
+ 5. Invoke the LLM to craft a clear answer with source citations.
+ 6. Return the answer in a ToolMessage for LangGraph to dispatch.
  """

  import logging
  import os
  import time
- from typing import Annotated, Any, Dict, List, Optional, Tuple
+ from typing import Annotated, Any

- import hydra
- import numpy as np
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.document_loaders import PyPDFLoader
- from langchain_community.vectorstores import FAISS
- from langchain_core.documents import Document
- from langchain_core.embeddings import Embeddings
- from langchain_core.language_models.chat_models import BaseChatModel
  from langchain_core.messages import ToolMessage
  from langchain_core.tools import tool
  from langchain_core.tools.base import InjectedToolCallId
- from langchain_core.vectorstores import VectorStore
- from langchain_core.vectorstores.utils import maximal_marginal_relevance
- from langchain_nvidia_ai_endpoints import NVIDIARerank
  from langgraph.prebuilt import InjectedState
  from langgraph.types import Command
  from pydantic import BaseModel, Field

+ from .utils.generate_answer import load_hydra_config
+ from .utils.retrieve_chunks import retrieve_relevant_chunks
+ from .utils.tool_helper import QAToolHelper
+
+ # Helper for managing state, vectorstore, reranking, and formatting
+ helper = QAToolHelper()
+ # Load configuration and start logging
+ config = load_hydra_config()
+
  # Set up logging with configurable level
  log_level = os.environ.get("LOG_LEVEL", "INFO")
  logging.basicConfig(level=getattr(logging, log_level))
  logger = logging.getLogger(__name__)
  logger.setLevel(getattr(logging, log_level))
- # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals, too-many-branches, too-many-statements
-
-
- def load_hydra_config() -> Any:
-     """
-     Load the configuration using Hydra and return the configuration for the Q&A tool.
-     """
-     with hydra.initialize(version_base=None, config_path="../../configs"):
-         cfg = hydra.compose(
-             config_name="config",
-             overrides=["tools/question_and_answer=default"],
-         )
-         config = cfg.tools.question_and_answer
-         logger.info("Loaded Question and Answer tool configuration.")
-         return config


  class QuestionAndAnswerInput(BaseModel):
      """
-     Input schema for the PDF Q&A tool.
-
-     Attributes:
-         question (str): Free-text question to answer based on PDF content.
-         paper_ids (Optional[List[str]]): If provided, restricts retrieval to these paper IDs.
-         use_all_papers (bool): If True, include all loaded papers without semantic ranking.
-         tool_call_id (str): Internal ID injected by LangGraph for this tool call.
-         state (dict): Shared agent state containing:
-             - 'article_data': dict of paper metadata with 'pdf_url' keys
-             - 'text_embedding_model': embedding model instance
-             - 'llm_model': chat/LLM instance
-             - 'vector_store': pre-built Vectorstore for retrieval
+     Pydantic schema for the PDF Q&A tool inputs.
+
+     Fields:
+         question: User's free-text query to answer based on PDF content.
+         tool_call_id: LangGraph-injected call identifier for tracking.
+         state: Shared agent state dict containing:
+             - article_data: metadata mapping of paper IDs to info (e.g., 'pdf_url', title).
+             - text_embedding_model: embedding model instance for chunk indexing.
+             - llm_model: chat/LLM instance for answer generation.
+             - vector_store: optional pre-built Vectorstore for retrieval.
      """

-     question: str = Field(description="The question to ask regarding the PDF content.")
-     paper_ids: Optional[List[str]] = Field(
-         default=None,
-         description="Optional list of specific paper IDs to query. "
-         "If not provided, relevant papers will be selected automatically.",
-     )
-     use_all_papers: bool = Field(
-         default=False,
-         description="Whether to use all available papers for answering the question. "
-         "Set to True to bypass relevance filtering and include all loaded papers.",
+     question: str = Field(
+         description="User question for generating a PDF-based answer."
      )
      tool_call_id: Annotated[str, InjectedToolCallId]
      state: Annotated[dict, InjectedState]


- class Vectorstore:
-     """
-     A class for managing document embeddings and retrieval.
-     Provides unified access to documents across multiple papers.
-     """
-
-     def __init__(
-         self,
-         embedding_model: Embeddings,
-         metadata_fields: Optional[List[str]] = None,
-     ):
-         """
-         Initialize the document store.
-
-         Args:
-             embedding_model: The embedding model to use
-             metadata_fields: Fields to include in document metadata for filtering/retrieval
-         """
-         self.embedding_model = embedding_model
-         self.metadata_fields = metadata_fields or [
-             "title",
-             "paper_id",
-             "page",
-             "chunk_id",
-         ]
-         self.initialization_time = time.time()
-         logger.info("Vectorstore initialized at: %s", self.initialization_time)
-
-         # Track loaded papers to prevent duplicate loading
-         self.loaded_papers = set()
-         self.vector_store_class = FAISS
-         logger.info("Using FAISS vector store")
-
-         # Store for initialized documents
-         self.documents: Dict[str, Document] = {}
-         self.vector_store: Optional[VectorStore] = None
-         self.paper_metadata: Dict[str, Dict[str, Any]] = {}
-         # Cache for document chunk embeddings to avoid recomputation
-         self.embeddings: Dict[str, Any] = {}
-
-     def add_paper(
-         self,
-         paper_id: str,
-         pdf_url: str,
-         paper_metadata: Dict[str, Any],
-     ) -> None:
-         """
-         Add a paper to the document store.
-
-         Args:
-             paper_id: Unique identifier for the paper
-             pdf_url: URL to the PDF
-             paper_metadata: Metadata about the paper
-         """
-         # Skip if already loaded
-         if paper_id in self.loaded_papers:
-             logger.info("Paper %s already loaded, skipping", paper_id)
-             return
-
-         logger.info("Loading paper %s from %s", paper_id, pdf_url)
-
-         # Store paper metadata
-         self.paper_metadata[paper_id] = paper_metadata
-
-         # Load the PDF and split into chunks according to Hydra config
-         loader = PyPDFLoader(pdf_url)
-         documents = loader.load()
-         logger.info("Loaded %d pages from %s", len(documents), paper_id)
-
-         # Create text splitter according to Hydra config
-         cfg = load_hydra_config()
-         splitter = RecursiveCharacterTextSplitter(
-             chunk_size=cfg.chunk_size,
-             chunk_overlap=cfg.chunk_overlap,
-             separators=["\n\n", "\n", ". ", " ", ""],
-         )
-
-         # Split documents and add metadata for each chunk
-         chunks = splitter.split_documents(documents)
-         logger.info("Split %s into %d chunks", paper_id, len(chunks))
-         # Embed and cache chunk embeddings
-         chunk_texts = [chunk.page_content for chunk in chunks]
-         chunk_embeddings = self.embedding_model.embed_documents(chunk_texts)
-         logger.info("Embedded %d chunks for paper %s", len(chunk_embeddings), paper_id)
-
-         # Enhance document metadata
-         for i, chunk in enumerate(chunks):
-             # Add paper metadata to each chunk
-             chunk.metadata.update(
-                 {
-                     "paper_id": paper_id,
-                     "title": paper_metadata.get("Title", "Unknown"),
-                     "chunk_id": i,
-                     # Keep existing page number if available
-                     "page": chunk.metadata.get("page", 0),
-                 }
-             )
-
-             # Add any additional metadata fields
-             for field in self.metadata_fields:
-                 if field in paper_metadata and field not in chunk.metadata:
-                     chunk.metadata[field] = paper_metadata[field]
-
-             # Store chunk
-             doc_id = f"{paper_id}_{i}"
-             self.documents[doc_id] = chunk
-             # Cache embedding if available
-             if chunk_embeddings[i] is not None:
-                 self.embeddings[doc_id] = chunk_embeddings[i]
-
-         # Mark as loaded to prevent duplicate loading
-         self.loaded_papers.add(paper_id)
-         logger.info("Added %d chunks from paper %s", len(chunks), paper_id)
-
-     def build_vector_store(self) -> None:
-         """
-         Build the vector store from all loaded documents.
-         Should be called after all papers are added.
-         """
-         if not self.documents:
-             logger.warning("No documents added to build vector store")
-             return
-
-         if self.vector_store is not None:
-             logger.info("Vector store already built, skipping")
-             return
-
-         # Create vector store from documents
-         documents_list = list(self.documents.values())
-         self.vector_store = self.vector_store_class.from_documents(
-             documents=documents_list, embedding=self.embedding_model
-         )
-         logger.info("Built vector store with %d documents", len(documents_list))
-
-     def rank_papers_by_query(
-         self, query: str, top_k: int = 40
-     ) -> List[Tuple[str, float]]:
-         """
-         Rank papers by relevance to the query using NVIDIA's off-the-shelf re-ranker.
-
-         This function aggregates all chunks per paper, ranks them using the NVIDIA model,
-         and returns the top-k papers.
-
-         Args:
-             query (str): The query string.
-             top_k (int): Number of top papers to return.
-
-         Returns:
-             List of tuples (paper_id, dummy_score) sorted by relevance.
-         """
-
-         # Aggregate all document chunks for each paper
-         paper_texts = {}
-         for doc in self.documents.values():
-             paper_id = doc.metadata["paper_id"]
-             paper_texts.setdefault(paper_id, []).append(doc.page_content)
-
-         aggregated_documents = []
-         for paper_id, texts in paper_texts.items():
-             aggregated_text = " ".join(texts)
-             aggregated_documents.append(
-                 Document(page_content=aggregated_text, metadata={"paper_id": paper_id})
-             )
-
-         # Instantiate the NVIDIA re-ranker client
-         config = load_hydra_config()
-         reranker = NVIDIARerank(
-             model=config.reranker.model,
-             api_key=config.reranker.api_key,
-         )
-
-         # Get the ranked list of documents based on the query
-         response = reranker.compress_documents(
-             query=query, documents=aggregated_documents
-         )
-
-         ranked_papers = [doc.metadata["paper_id"] for doc in response[:top_k]]
-         return ranked_papers
-
-     def retrieve_relevant_chunks(
-         self,
-         query: str,
-         paper_ids: Optional[List[str]] = None,
-         top_k: int = 25,
-         mmr_diversity: float = 1.00,
-     ) -> List[Document]:
-         """
-         Retrieve the most relevant chunks for a query using maximal marginal relevance.
-
-         Args:
-             query: Query string
-             paper_ids: Optional list of paper IDs to filter by
-             top_k: Number of chunks to retrieve
-             mmr_diversity: Diversity parameter for MMR (higher = more diverse)
-
-         Returns:
-             List of document chunks
-         """
-         if not self.vector_store:
-             logger.error("Failed to build vector store")
-             return []
-
-         if paper_ids:
-             logger.info("Filtering retrieval to papers: %s", paper_ids)
-
-         # Step 1: Embed the query
-         logger.info(
-             "Embedding query using model: %s", type(self.embedding_model).__name__
-         )
-         query_embedding = np.array(self.embedding_model.embed_query(query))
-
-         # Step 2: Filter relevant documents
-         all_docs = [
-             doc
-             for doc in self.documents.values()
-             if not paper_ids or doc.metadata["paper_id"] in paper_ids
-         ]
-
-         if not all_docs:
-             logger.warning("No documents found after filtering by paper_ids.")
-             return []
-
-         # Step 3: Retrieve or compute embeddings for all documents using cache
-         logger.info("Retrieving embeddings for %d chunks...", len(all_docs))
-         all_embeddings = []
-         for doc in all_docs:
-             doc_id = f"{doc.metadata['paper_id']}_{doc.metadata['chunk_id']}"
-             if doc_id not in self.embeddings:
-                 logger.info("Embedding missing chunk %s", doc_id)
-                 emb = self.embedding_model.embed_documents([doc.page_content])[0]
-                 self.embeddings[doc_id] = emb
-             all_embeddings.append(self.embeddings[doc_id])
-
-         # Step 4: Apply MMR
-         mmr_indices = maximal_marginal_relevance(
-             query_embedding,
-             all_embeddings,
-             k=top_k,
-             lambda_mult=mmr_diversity,
-         )
-
-         results = [all_docs[i] for i in mmr_indices]
-         logger.info("Retrieved %d chunks using MMR", len(results))
-         return results
-
-
- def generate_answer(
-     question: str,
-     retrieved_chunks: List[Document],
-     llm_model: BaseChatModel,
-     config: Optional[Any] = None,
- ) -> Dict[str, Any]:
-     """
-     Generate an answer for a question using retrieved chunks.
-
-     Args:
-         question (str): The question to answer
-         retrieved_chunks (List[Document]): List of relevant document chunks
-         llm_model (BaseChatModel): Language model for generating answers
-         config (Optional[Any]): Configuration for answer generation
-
-     Returns:
-         Dict[str, Any]: Dictionary with the answer and metadata
-     """
-     # Load configuration using the global function.
-     config = load_hydra_config()
-
-     # Ensure the configuration is not None and has the prompt_template.
-     if config is None:
-         raise ValueError("Hydra config loading failed: config is None.")
-     if "prompt_template" not in config:
-         raise ValueError("The prompt_template is missing from the configuration.")
-
-     # Prepare context from retrieved documents with source attribution.
-     # Group chunks by paper_id
-     papers = {}
-     for doc in retrieved_chunks:
-         paper_id = doc.metadata.get("paper_id", "unknown")
-         if paper_id not in papers:
-             papers[paper_id] = []
-         papers[paper_id].append(doc)
-
-     # Format chunks by paper
-     formatted_chunks = []
-     doc_index = 1
-     for paper_id, chunks in papers.items():
-         # Get the title from the first chunk (should be the same for all chunks)
-         title = chunks[0].metadata.get("title", "Unknown")
-
-         # Add a document header
-         formatted_chunks.append(
-             f"[Document {doc_index}] From: '{title}' (ID: {paper_id})"
-         )
-
-         # Add each chunk with its page information
-         for chunk in chunks:
-             page = chunk.metadata.get("page", "unknown")
-             formatted_chunks.append(f"Page {page}: {chunk.page_content}")
-
-         # Increment document index for the next paper
-         doc_index += 1
-
-     # Join all chunks
-     context = "\n\n".join(formatted_chunks)
-
-     # Get unique paper sources.
-     paper_sources = {doc.metadata["paper_id"] for doc in retrieved_chunks}
-
-     # Create prompt using the Hydra-provided prompt_template.
-     prompt = config["prompt_template"].format(context=context, question=question)
-
-     # Get the answer from the language model
-     response = llm_model.invoke(prompt)
-
-     # Return the response with metadata
-     return {
-         "output_text": response.content,
-         "sources": [doc.metadata for doc in retrieved_chunks],
-         "num_sources": len(retrieved_chunks),
-         "papers_used": list(paper_sources),
-     }
-
-
- # Shared pre-built Vectorstore for RAG (set externally, e.g., by Streamlit startup)
- prebuilt_vector_store: Optional[Vectorstore] = None
-
-
  @tool(args_schema=QuestionAndAnswerInput, parse_docstring=True)
  def question_and_answer(
      question: str,
      state: Annotated[dict, InjectedState],
      tool_call_id: Annotated[str, InjectedToolCallId],
-     paper_ids: Optional[List[str]] = None,
-     use_all_papers: bool = False,
  ) -> Command[Any]:
      """
-     Generate an answer to a user question using Retrieval-Augmented Generation (RAG) over PDFs.
-
-     This tool expects that a FAISS vector store of PDF document chunks has already been built
-     and stored in shared state. It retrieves the most relevant chunks for the input question,
-     invokes an LLM to craft a response, and returns the answer with source attribution.
+     LangGraph tool for Retrieval-Augmented Generation over PDFs.
+
+     Given a user question, this tool applies the following pipeline:
+     1. Validates that embedding and LLM models, plus article metadata, are in state.
+     2. Initializes or reuses a FAISS-based Vectorstore for PDF embeddings.
+     3. Loads one or more PDFs (from Zotero, arXiv, uploads) as text chunks into the store.
+     4. Uses NVIDIA NIM semantic re-ranker to select top candidate papers.
+     5. Retrieves the most relevant and diverse text chunks via Maximal Marginal Relevance.
+     6. Constructs an LLM prompt combining contextual chunks and the query.
+     7. Invokes the LLM to generate an answer, appending source attributions.
+     8. Returns a LangGraph Command with a ToolMessage containing the answer.

      Args:
-         question (str): The free-text question to answer.
-         state (dict): Injected agent state mapping that must include:
-             - 'article_data': mapping of paper IDs to metadata (including 'pdf_url')
-             - 'text_embedding_model': the embedding model instance
-             - 'llm_model': the chat/LLM instance
-         tool_call_id (str): Internal identifier for this tool call.
-         paper_ids (Optional[List[str]]): Specific paper IDs to restrict retrieval (default: None).
-         use_all_papers (bool): If True, bypasses semantic ranking and includes all papers.
+         question (str): The free-text question to answer.
+         state (dict): Injected agent state; must include:
+             - article_data: mapping paper IDs metadata (pdf_url, title, etc.)
+             - text_embedding_model: embedding model instance.
+             - llm_model: chat/LLM instance.
+         tool_call_id (str): Internal identifier for this tool invocation.

      Returns:
-         Command[Any]: A LangGraph Command that updates the conversation state:
-             - 'messages': a single ToolMessage containing the generated answer text.
+         Command[Any]: updates conversation state with a ToolMessage(answer).

      Raises:
-         ValueError: If required models or 'article_data' are missing from state.
-         RuntimeError: If no relevant document chunks can be retrieved.
+         ValueError: when required models or metadata are missing in state.
+         RuntimeError: when no relevant chunks can be retrieved for the query.
      """
-     # Load configuration
-     config = load_hydra_config()
-     # Create a unique identifier for this call to track potential infinite loops
      call_id = f"qa_call_{time.time()}"
      logger.info(
          "Starting PDF Question and Answer tool call %s for question: %s",
          call_id,
          question,
      )
+     helper.start_call(config, call_id)

-     # Get required models from state
-     text_embedding_model = state.get("text_embedding_model")
-     if not text_embedding_model:
-         error_msg = "No text embedding model found in state."
-         logger.error("%s: %s", call_id, error_msg)
-         raise ValueError(error_msg)
-
-     llm_model = state.get("llm_model")
-     if not llm_model:
-         error_msg = "No LLM model found in state."
-         logger.error("%s: %s", call_id, error_msg)
-         raise ValueError(error_msg)
-
-     # Get article data from state
-     article_data = state.get("article_data", {})
-     if not article_data:
-         error_msg = "No article_data found in state."
-         logger.error("%s: %s", call_id, error_msg)
-         raise ValueError(error_msg)
+     # Extract models and article metadata
+     text_emb, llm_model, article_data = helper.get_state_models_and_data(state)

-     # Use shared pre-built Vectorstore if provided, else create a new one
-     if prebuilt_vector_store is not None:
-         vector_store = prebuilt_vector_store
-         logger.info("Using shared pre-built vector store from the memory")
-     else:
-         vector_store = Vectorstore(embedding_model=text_embedding_model)
-         logger.info("Initialized new vector store (no pre-built store found)")
+     # Initialize or reuse vector store, then load candidate papers
+     vs = helper.init_vector_store(text_emb)
+     candidate_ids = list(article_data.keys())
+     logger.info("%s: Candidate paper IDs for reranking: %s", call_id, candidate_ids)
+     helper.load_candidate_papers(vs, article_data, candidate_ids)

-     # Check if there are papers from different sources
-     has_uploaded_papers = any(
-         paper.get("source") == "upload"
-         for paper in article_data.values()
-         if isinstance(paper, dict)
+     # Rerank papers and retrieve top chunks
+     selected_ids = helper.run_reranker(vs, question, candidate_ids)
+     relevant_chunks = retrieve_relevant_chunks(
+         vs, query=question, paper_ids=selected_ids, top_k=config.top_k_chunks
      )
-
-     has_zotero_papers = any(
-         paper.get("source") == "zotero"
-         for paper in article_data.values()
-         if isinstance(paper, dict)
-     )
-
-     has_arxiv_papers = any(
-         paper.get("source") == "arxiv"
-         for paper in article_data.values()
-         if isinstance(paper, dict)
-     )
-
-     has_biorxiv_papers = any(
-         paper.get("source") == "biorxiv"
-         for paper in article_data.values()
-         if isinstance(paper, dict)
-     )
-
-     has_medrxiv_papers = any(
-         paper.get("source") == "medrxiv"
-         for paper in article_data.values()
-         if isinstance(paper, dict)
-     )
-
-     # Choose papers to use
-     selected_paper_ids = []
-     has_combimed_papers = (
-         has_uploaded_papers
-         or has_zotero_papers
-         or has_arxiv_papers
-         or has_biorxiv_papers
-         or has_medrxiv_papers
-     )
-
-     if paper_ids:
-         # Use explicitly specified papers
-         selected_paper_ids = [pid for pid in paper_ids if pid in article_data]
-         logger.info(
-             "%s: Using explicitly specified papers: %s", call_id, selected_paper_ids
-         )
-
-         if not selected_paper_ids:
-             logger.warning(
-                 "%s: None of the provided paper_ids %s were found", call_id, paper_ids
-             )
-
-     elif use_all_papers or has_combimed_papers:
-         # Use all available papers if explicitly requested or if we have papers from any source
-         selected_paper_ids = list(article_data.keys())
-         logger.info(
-             "%s: Using all %d available papers", call_id, len(selected_paper_ids)
-         )
-
-     else:
-         # Use semantic ranking to find relevant papers
-         # First ensure papers are loaded
-         for paper_id, paper in article_data.items():
-             pdf_url = paper.get("pdf_url")
-             if pdf_url and paper_id not in vector_store.loaded_papers:
-                 try:
-                     vector_store.add_paper(paper_id, pdf_url, paper)
-                 except (IOError, ValueError) as e:
-                     logger.error("Error loading paper %s: %s", paper_id, e)
-                     raise
-
-         # Now rank papers
-         ranked_papers = vector_store.rank_papers_by_query(
-             question, top_k=config.top_k_papers
-         )
-         selected_paper_ids = [paper_id for paper_id, _ in ranked_papers]
-         logger.info(
-             "%s: Selected papers based on semantic relevance: %s",
-             call_id,
-             selected_paper_ids,
-         )
-
-     if not selected_paper_ids:
-         # Fallback to all papers if selection failed
-         selected_paper_ids = list(article_data.keys())
-         logger.info(
-             "%s: Falling back to all %d papers", call_id, len(selected_paper_ids)
-         )
-
-     # Load selected papers if needed
-     for paper_id in selected_paper_ids:
-         if paper_id not in vector_store.loaded_papers:
-             pdf_url = article_data[paper_id].get("pdf_url")
-             if pdf_url:
-                 try:
-                     vector_store.add_paper(paper_id, pdf_url, article_data[paper_id])
-                 except (IOError, ValueError) as e:
-                     logger.warning(
-                         "%s: Error loading paper %s: %s", call_id, paper_id, e
-                     )
-
-     # Ensure vector store is built
-     if not vector_store.vector_store:
-         vector_store.build_vector_store()
-
-     # Retrieve relevant chunks across selected papers
-     relevant_chunks = vector_store.retrieve_relevant_chunks(
-         query=question, paper_ids=selected_paper_ids, top_k=config.top_k_chunks
-     )
-
      if not relevant_chunks:
-         error_msg = "No relevant chunks found in the papers."
-         logger.warning("%s: %s", call_id, error_msg)
-         raise RuntimeError(
-             f"I couldn't find relevant information to answer your question: '{question}'. "
-             "Please try rephrasing or asking a different question."
-         )
-
-     # Generate answer using retrieved chunks
-     result = generate_answer(question, relevant_chunks, llm_model)
+         msg = f"No relevant chunks found for question: '{question}'"
+         logger.warning("%s: %s", call_id, msg)
+         raise RuntimeError(msg)

-     # Format answer with attribution
-     answer_text = result.get("output_text", "No answer generated.")
-
-     # Get paper titles for sources
-     paper_titles = {}
-     for paper_id in result.get("papers_used", []):
-         if paper_id in article_data:
-             paper_titles[paper_id] = article_data[paper_id].get(
-                 "Title", "Unknown paper"
-             )
-
-     # Format source information
-     sources_text = ""
-     if paper_titles:
-         sources_text = "\n\nSources:\n" + "\n".join(
-             [f"- {title}" for title in paper_titles.values()]
-         )
-
-     # Prepare the final response
-     response_text = f"{answer_text}{sources_text}"
-     logger.info(
-         "%s: Successfully generated answer using %d chunks from %d papers",
-         call_id,
-         len(relevant_chunks),
-         len(paper_titles),
+     # Generate answer and format with sources
+     response_text = helper.format_answer(
+         question, relevant_chunks, llm_model, article_data
      )
-
      return Command(
          update={
              "messages": [