aiagents4pharma 1.39.0__py3-none-any.whl → 1.39.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. aiagents4pharma/talk2scholars/agents/main_agent.py +7 -7
  2. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +88 -12
  3. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/default.yaml +5 -0
  4. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +5 -0
  5. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +1 -20
  6. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +1 -26
  7. aiagents4pharma/talk2scholars/configs/tools/download_arxiv_paper/default.yaml +4 -0
  8. aiagents4pharma/talk2scholars/configs/tools/download_biorxiv_paper/default.yaml +2 -0
  9. aiagents4pharma/talk2scholars/configs/tools/download_medrxiv_paper/default.yaml +2 -0
  10. aiagents4pharma/talk2scholars/configs/tools/question_and_answer/default.yaml +22 -0
  11. aiagents4pharma/talk2scholars/tests/test_main_agent.py +20 -2
  12. aiagents4pharma/talk2scholars/tests/test_nvidia_nim_reranker_utils.py +28 -0
  13. aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +107 -29
  14. aiagents4pharma/talk2scholars/tests/test_pdf_agent.py +2 -3
  15. aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +194 -543
  16. aiagents4pharma/talk2scholars/tests/test_s2_agent.py +2 -2
  17. aiagents4pharma/talk2scholars/tests/{test_s2_display.py → test_s2_display_dataframe.py} +2 -3
  18. aiagents4pharma/talk2scholars/tests/test_s2_query_dataframe.py +201 -0
  19. aiagents4pharma/talk2scholars/tests/test_s2_retrieve.py +7 -6
  20. aiagents4pharma/talk2scholars/tests/test_s2_utils_ext_ids.py +413 -0
  21. aiagents4pharma/talk2scholars/tests/test_tool_helper_utils.py +140 -0
  22. aiagents4pharma/talk2scholars/tests/test_zotero_agent.py +0 -1
  23. aiagents4pharma/talk2scholars/tests/test_zotero_read.py +16 -18
  24. aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +92 -37
  25. aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +73 -575
  26. aiagents4pharma/talk2scholars/tools/pdf/utils/__init__.py +10 -0
  27. aiagents4pharma/talk2scholars/tools/pdf/utils/generate_answer.py +97 -0
  28. aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py +77 -0
  29. aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py +83 -0
  30. aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py +125 -0
  31. aiagents4pharma/talk2scholars/tools/pdf/utils/vector_store.py +162 -0
  32. aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py +33 -10
  33. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +39 -16
  34. aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +124 -10
  35. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +49 -17
  36. aiagents4pharma/talk2scholars/tools/s2/search.py +39 -16
  37. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +34 -16
  38. aiagents4pharma/talk2scholars/tools/s2/utils/multi_helper.py +49 -16
  39. aiagents4pharma/talk2scholars/tools/s2/utils/search_helper.py +51 -16
  40. aiagents4pharma/talk2scholars/tools/s2/utils/single_helper.py +50 -17
  41. {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/METADATA +58 -105
  42. {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/RECORD +45 -32
  43. aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +0 -89
  44. aiagents4pharma/talk2scholars/tests/test_routing_logic.py +0 -74
  45. aiagents4pharma/talk2scholars/tests/test_s2_query.py +0 -95
  46. {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/WHEEL +0 -0
  47. {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/licenses/LICENSE +0 -0
  48. {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2scholars/tools/pdf/utils/__init__.py
@@ -0,0 +1,10 @@
+ """
+ Utility modules for the PDF question_and_answer tool.
+ """
+
+ from . import generate_answer
+ from . import nvidia_nim_reranker
+ from . import retrieve_chunks
+ from . import vector_store
+
+ __all__ = ["generate_answer", "nvidia_nim_reranker", "retrieve_chunks", "vector_store"]
aiagents4pharma/talk2scholars/tools/pdf/utils/generate_answer.py
@@ -0,0 +1,97 @@
+ """
+ Generate an answer for a question using retrieved chunks of documents.
+ """
+
+ import logging
+ import os
+ from typing import Any, Dict, List
+
+ import hydra
+ from langchain_core.documents import Document
+ from langchain_core.language_models.chat_models import BaseChatModel
+
+ # Set up logging with configurable level
+ log_level = os.environ.get("LOG_LEVEL", "INFO")
+ logging.basicConfig(level=getattr(logging, log_level))
+ logger = logging.getLogger(__name__)
+ logger.setLevel(getattr(logging, log_level))
+
+
+ def _build_context_and_sources(
+     retrieved_chunks: List[Document],
+ ) -> tuple[str, set[str]]:
+     """
+     Build the combined context string and set of paper_ids from retrieved chunks.
+     """
+     papers = {}
+     for doc in retrieved_chunks:
+         pid = doc.metadata.get("paper_id", "unknown")
+         papers.setdefault(pid, []).append(doc)
+     formatted = []
+     idx = 1
+     for pid, chunks in papers.items():
+         title = chunks[0].metadata.get("title", "Unknown")
+         formatted.append(f"[Document {idx}] From: '{title}' (ID: {pid})")
+         for chunk in chunks:
+             page = chunk.metadata.get("page", "unknown")
+             formatted.append(f"Page {page}: {chunk.page_content}")
+         idx += 1
+     context = "\n\n".join(formatted)
+     sources: set[str] = set()
+     for doc in retrieved_chunks:
+         pid = doc.metadata.get("paper_id")
+         if isinstance(pid, str):
+             sources.add(pid)
+     return context, sources
+
+
+ def load_hydra_config() -> Any:
+     """
+     Load the configuration using Hydra and return the configuration for the Q&A tool.
+     """
+     with hydra.initialize(version_base=None, config_path="../../../configs"):
+         cfg = hydra.compose(
+             config_name="config",
+             overrides=["tools/question_and_answer=default"],
+         )
+         config = cfg.tools.question_and_answer
+         logger.debug("Loaded Question and Answer tool configuration.")
+         return config
+
+
+ def generate_answer(
+     question: str,
+     retrieved_chunks: List[Document],
+     llm_model: BaseChatModel,
+     config: Any,
+ ) -> Dict[str, Any]:
+     """
+     Generate an answer for a question using retrieved chunks.
+
+     Args:
+         question (str): The question to answer
+         retrieved_chunks (List[Document]): List of relevant document chunks
+         llm_model (BaseChatModel): Language model for generating answers
+         config (Any): Configuration for answer generation
+
+     Returns:
+         Dict[str, Any]: Dictionary with the answer and metadata
+     """
+     # Ensure the configuration is provided and has the prompt_template.
+     if config is None:
+         raise ValueError("Configuration for generate_answer is required.")
+     if "prompt_template" not in config:
+         raise ValueError("The prompt_template is missing from the configuration.")
+
+     # Build context and sources, then invoke LLM
+     context, paper_sources = _build_context_and_sources(retrieved_chunks)
+     prompt = config["prompt_template"].format(context=context, question=question)
+     response = llm_model.invoke(prompt)
+
+     # Return the response with metadata
+     return {
+         "output_text": response.content,
+         "sources": [doc.metadata for doc in retrieved_chunks],
+         "num_sources": len(retrieved_chunks),
+         "papers_used": list(paper_sources),
+     }
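For orientation, `generate_answer` only needs a mapping-like config that exposes `prompt_template` with `{context}` and `{question}` placeholders (in the package this is resolved from the question_and_answer Hydra config updated in item 10 above). A minimal sketch of a direct call; the chunk contents, template text, and model choice here are illustrative placeholders, not the shipped defaults:

from langchain_core.documents import Document
from langchain_openai import ChatOpenAI  # any BaseChatModel implementation works

# One retrieved chunk with the metadata fields the function reads
chunks = [
    Document(
        page_content="We benchmark retrieval-augmented QA on pharma literature...",
        metadata={"paper_id": "arxiv_1234.5678", "title": "Example Paper", "page": 3},
    ),
]
# A plain dict works in place of the Hydra DictConfig for this sketch
config = {"prompt_template": "Context:\n{context}\n\nQuestion: {question}\nAnswer:"}

result = generate_answer(
    question="What does the paper benchmark?",
    retrieved_chunks=chunks,
    llm_model=ChatOpenAI(model="gpt-4o-mini"),
    config=config,
)
print(result["output_text"], result["papers_used"])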
aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py
@@ -0,0 +1,77 @@
+ """
+ NVIDIA NIM Reranker Utility
+ """
+
+ import logging
+ import os
+
+
+ from typing import Any, List
+
+ from langchain_core.documents import Document
+ from langchain_nvidia_ai_endpoints import NVIDIARerank
+
+ # Set up logging with configurable level
+ log_level = os.environ.get("LOG_LEVEL", "INFO")
+ logging.basicConfig(level=getattr(logging, log_level))
+ logger = logging.getLogger(__name__)
+ logger.setLevel(getattr(logging, log_level))
+
+
+ def rank_papers_by_query(self, query: str, config: Any, top_k: int = 5) -> List[str]:
+     """
+     Rank papers by relevance to the query using NVIDIA's off-the-shelf re-ranker.
+
+     This function aggregates all chunks per paper, ranks them using the NVIDIA model,
+     and returns the top-k papers.
+
+     Args:
+         query (str): The query string.
+         config (Any): Configuration containing reranker settings (model, api_key).
+         top_k (int): Number of top papers to return.
+
+     Returns:
+         List[str]: Paper IDs sorted by relevance to the query.
+     """
+
+     logger.info("Starting NVIDIA re-ranker for query: '%s' with top_k=%d", query, top_k)
+     # Aggregate all document chunks for each paper
+     paper_texts = {}
+     for doc in self.documents.values():
+         paper_id = doc.metadata["paper_id"]
+         paper_texts.setdefault(paper_id, []).append(doc.page_content)
+
+     aggregated_documents = []
+     for paper_id, texts in paper_texts.items():
+         aggregated_text = " ".join(texts)
+         aggregated_documents.append(
+             Document(page_content=aggregated_text, metadata={"paper_id": paper_id})
+         )
+
+     logger.info(
+         "Aggregated %d papers into %d documents for reranking",
+         len(paper_texts),
+         len(aggregated_documents),
+     )
+     # Instantiate the NVIDIA re-ranker client using provided config
+     # Use NVIDIA API key from Hydra configuration (expected to be resolved via oc.env)
+     api_key = config.reranker.api_key
+     if not api_key:
+         logger.error("No NVIDIA API key found in configuration for reranking")
+         raise ValueError("Configuration 'reranker.api_key' must be set for reranking")
+     logger.info("Using NVIDIA API key from configuration for reranking")
+     # Truncate long inputs at the model end to avoid exceeding the max token size
+     logger.info("Setting NVIDIA reranker truncate mode to END to limit input length")
+     reranker = NVIDIARerank(
+         model=config.reranker.model,
+         api_key=api_key,
+         truncate="END",
+     )
+
+     # Get the ranked list of documents based on the query
+     response = reranker.compress_documents(query=query, documents=aggregated_documents)
+     logger.info("Received %d documents from NVIDIA reranker", len(response))
+
+     ranked_papers = [doc.metadata["paper_id"] for doc in response[:top_k]]
+     logger.info("Top %d papers after reranking: %s", top_k, ranked_papers)
+     return ranked_papers
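Note that despite its `self` parameter, `rank_papers_by_query` is a plain module-level function: `QAToolHelper.run_reranker` (below) passes a `Vectorstore` instance explicitly as the first argument, so `self.documents` resolves to that store's chunk dictionary. A hedged sketch of a direct call; the `SimpleNamespace` stands in for the Hydra-resolved config, and the model name and key are placeholders:

from types import SimpleNamespace

# `store` is a populated Vectorstore (see vector_store.py below)
cfg = SimpleNamespace(
    reranker=SimpleNamespace(
        model="nvidia/nv-rerankqa-mistral-4b-v3",  # placeholder model id
        api_key="nvapi-...",                       # placeholder key
    )
)
top_ids = rank_papers_by_query(store, "CRISPR off-target effects", cfg, top_k=3)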
aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py
@@ -0,0 +1,83 @@
+ """
+ Retrieve relevant chunks from a vector store using MMR (Maximal Marginal Relevance).
+ """
+
+ import logging
+ import os
+ from typing import List, Optional
+
+ import numpy as np
+ from langchain_core.documents import Document
+ from langchain_core.vectorstores.utils import maximal_marginal_relevance
+
+
+ # Set up logging with configurable level
+ log_level = os.environ.get("LOG_LEVEL", "INFO")
+ logging.basicConfig(level=getattr(logging, log_level))
+ logger = logging.getLogger(__name__)
+ logger.setLevel(getattr(logging, log_level))
+
+
+ def retrieve_relevant_chunks(
+     self,
+     query: str,
+     paper_ids: Optional[List[str]] = None,
+     top_k: int = 25,
+     mmr_diversity: float = 1.00,
+ ) -> List[Document]:
+     """
+     Retrieve the most relevant chunks for a query using maximal marginal relevance.
+
+     Args:
+         query: Query string
+         paper_ids: Optional list of paper IDs to filter by
+         top_k: Number of chunks to retrieve
+         mmr_diversity: MMR lambda_mult (1.0 = pure relevance; lower values increase diversity)
+
+     Returns:
+         List of document chunks
+     """
+     if not self.vector_store:
+         logger.error("Failed to build vector store")
+         return []
+
+     if paper_ids:
+         logger.info("Filtering retrieval to papers: %s", paper_ids)
+
+     # Step 1: Embed the query
+     logger.info("Embedding query using model: %s", type(self.embedding_model).__name__)
+     query_embedding = np.array(self.embedding_model.embed_query(query))
+
+     # Step 2: Filter relevant documents
+     all_docs = [
+         doc
+         for doc in self.documents.values()
+         if not paper_ids or doc.metadata["paper_id"] in paper_ids
+     ]
+
+     if not all_docs:
+         logger.warning("No documents found after filtering by paper_ids.")
+         return []
+
+     # Step 3: Retrieve or compute embeddings for all documents using cache
+     logger.info("Retrieving embeddings for %d chunks...", len(all_docs))
+     all_embeddings = []
+     for doc in all_docs:
+         doc_id = f"{doc.metadata['paper_id']}_{doc.metadata['chunk_id']}"
+         if doc_id not in self.embeddings:
+             logger.info("Embedding missing chunk %s", doc_id)
+             emb = self.embedding_model.embed_documents([doc.page_content])[0]
+             self.embeddings[doc_id] = emb
+         all_embeddings.append(self.embeddings[doc_id])
+
+     # Step 4: Apply MMR
+     mmr_indices = maximal_marginal_relevance(
+         query_embedding,
+         all_embeddings,
+         k=top_k,
+         lambda_mult=mmr_diversity,
+     )
+
+     results = [all_docs[i] for i in mmr_indices]
+     logger.info("Retrieved %d chunks using MMR", len(results))
+     return results
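`retrieve_relevant_chunks` follows the same unbound-function pattern and expects the `Vectorstore` as its first argument. Since `mmr_diversity` is passed straight through as LangChain's `lambda_mult`, the default of 1.0 means pure relevance with no diversity penalty. A sketch, assuming `store` already holds embedded chunks; query and paper ID are illustrative:

chunks = retrieve_relevant_chunks(
    store,
    query="Which assay was used to validate the targets?",
    paper_ids=["arxiv_1234.5678"],  # optional filter; None searches all loaded papers
    top_k=10,
    mmr_diversity=0.8,  # lower values trade some relevance for diversity
)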
aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py
@@ -0,0 +1,125 @@
+ """
+ Helper class for PDF Q&A tool orchestration: state validation, vectorstore init,
+ paper loading, reranking, and answer formatting.
+ """
+
+ import logging
+ from typing import Any, Dict, List, Optional
+
+ from .generate_answer import generate_answer
+ from .nvidia_nim_reranker import rank_papers_by_query
+ from .vector_store import Vectorstore
+
+ logger = logging.getLogger(__name__)
+
+
+ class QAToolHelper:
+     """Encapsulates helper routines for the PDF Question & Answer tool."""
+
+     def __init__(self) -> None:
+         self.prebuilt_vector_store: Optional[Vectorstore] = None
+         self.config: Any = None
+         self.call_id: str = ""
+         logger.debug("Initialized QAToolHelper")
+
+     def start_call(self, config: Any, call_id: str) -> None:
+         """Initialize helper with current config and call identifier."""
+         self.config = config
+         self.call_id = call_id
+         logger.debug("QAToolHelper started call %s", call_id)
+
+     def get_state_models_and_data(self, state: dict) -> tuple[Any, Any, Dict[str, Any]]:
+         """Retrieve embedding model, LLM, and article data from agent state."""
+         text_emb = state.get("text_embedding_model")
+         if not text_emb:
+             msg = "No text embedding model found in state."
+             logger.error("%s: %s", self.call_id, msg)
+             raise ValueError(msg)
+         llm = state.get("llm_model")
+         if not llm:
+             msg = "No LLM model found in state."
+             logger.error("%s: %s", self.call_id, msg)
+             raise ValueError(msg)
+         articles = state.get("article_data", {})
+         if not articles:
+             msg = "No article_data found in state."
+             logger.error("%s: %s", self.call_id, msg)
+             raise ValueError(msg)
+         return text_emb, llm, articles
+
+     def init_vector_store(self, emb_model: Any) -> Vectorstore:
+         """Return shared or new Vectorstore instance."""
+         if self.prebuilt_vector_store is not None:
+             logger.info("Using shared pre-built vector store from memory")
+             return self.prebuilt_vector_store
+         vs = Vectorstore(embedding_model=emb_model, config=self.config)
+         logger.info("Initialized new vector store with provided configuration")
+         self.prebuilt_vector_store = vs
+         return vs
+
+     def load_candidate_papers(
+         self,
+         vs: Vectorstore,
+         articles: Dict[str, Any],
+         candidates: List[str],
+     ) -> None:
+         """Ensure each candidate paper is loaded into the vector store."""
+         for pid in candidates:
+             if pid not in vs.loaded_papers:
+                 pdf_url = articles.get(pid, {}).get("pdf_url")
+                 if not pdf_url:
+                     continue
+                 try:
+                     vs.add_paper(pid, pdf_url, articles[pid])
+                 except (IOError, ValueError) as exc:
+                     logger.warning(
+                         "%s: Error loading paper %s: %s", self.call_id, pid, exc
+                     )
+
+     def run_reranker(
+         self,
+         vs: Vectorstore,
+         query: str,
+         candidates: List[str],
+     ) -> List[str]:
+         """Rank papers by relevance and return filtered paper IDs."""
+         try:
+             ranked = rank_papers_by_query(
+                 vs, query, self.config, top_k=self.config.top_k_papers
+             )
+             logger.info("%s: Papers after NVIDIA reranking: %s", self.call_id, ranked)
+             return [pid for pid in ranked if pid in candidates]
+         except (ValueError, RuntimeError) as exc:
+             logger.error("%s: NVIDIA reranker failed: %s", self.call_id, exc)
+             logger.info(
+                 "%s: Falling back to all %d candidate papers",
+                 self.call_id,
+                 len(candidates),
+             )
+             return candidates
+
+     def format_answer(
+         self,
+         question: str,
+         chunks: List[Any],
+         llm: Any,
+         articles: Dict[str, Any],
+     ) -> str:
+         """Generate the final answer text with source attributions."""
+         result = generate_answer(question, chunks, llm, self.config)
+         answer = result.get("output_text", "No answer generated.")
+         titles: Dict[str, str] = {}
+         for pid in result.get("papers_used", []):
+             if pid in articles:
+                 titles[pid] = articles[pid].get("Title", "Unknown paper")
+         if titles:
+             srcs = "\n\nSources:\n" + "\n".join(f"- {t}" for t in titles.values())
+         else:
+             srcs = ""
+         logger.info(
+             "%s: Generated answer using %d chunks from %d papers",
+             self.call_id,
+             len(chunks),
+             len(titles),
+         )
+         return f"{answer}{srcs}"
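Taken together, `QAToolHelper` lets the rewritten `question_and_answer` tool (item 25, which shrank by roughly 500 lines) become a thin orchestration layer. A sketch of the end-to-end flow under assumed inputs; `state`, `question`, and `candidates` are illustrative, and the actual tool wires these from the agent state:

from .retrieve_chunks import retrieve_relevant_chunks

helper = QAToolHelper()
helper.start_call(config, call_id="call-001")

# Validate state and pull out the models and article metadata
emb, llm, articles = helper.get_state_models_and_data(state)
vs = helper.init_vector_store(emb)

# Load PDFs, rerank candidates, retrieve chunks, and format the answer
candidates = list(articles.keys())
helper.load_candidate_papers(vs, articles, candidates)
ranked = helper.run_reranker(vs, question, candidates)
chunks = retrieve_relevant_chunks(vs, question, paper_ids=ranked)
print(helper.format_answer(question, chunks, llm, articles))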
aiagents4pharma/talk2scholars/tools/pdf/utils/vector_store.py
@@ -0,0 +1,162 @@
+ """
+ Vectorstore class for managing document embeddings and retrieval.
+ """
+
+ import logging
+ import os
+ import time
+ from typing import Any, Dict, List, Optional
+
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_community.vectorstores import FAISS
+ from langchain_core.documents import Document
+ from langchain_core.embeddings import Embeddings
+ from langchain_core.vectorstores import VectorStore
+
+
+ # Set up logging with configurable level
+ log_level = os.environ.get("LOG_LEVEL", "INFO")
+ logging.basicConfig(level=getattr(logging, log_level))
+ logger = logging.getLogger(__name__)
+ logger.setLevel(getattr(logging, log_level))
+
+
+ class Vectorstore:
+     """
+     A class for managing document embeddings and retrieval.
+     Provides unified access to documents across multiple papers.
+     """
+
+     def __init__(
+         self,
+         embedding_model: Embeddings,
+         metadata_fields: Optional[List[str]] = None,
+         config: Any = None,
+     ):
+         """
+         Initialize the document store.
+
+         Args:
+             embedding_model: The embedding model to use
+             metadata_fields: Fields to include in document metadata for filtering/retrieval
+         """
+         self.embedding_model = embedding_model
+         self.config = config
+         self.metadata_fields = metadata_fields or [
+             "title",
+             "paper_id",
+             "page",
+             "chunk_id",
+         ]
+         self.initialization_time = time.time()
+         logger.info("Vectorstore initialized at: %s", self.initialization_time)
+
+         # Track loaded papers to prevent duplicate loading
+         self.loaded_papers = set()
+         self.vector_store_class = FAISS
+         logger.info("Using FAISS vector store")
+
+         # Store for initialized documents
+         self.documents: Dict[str, Document] = {}
+         self.vector_store: Optional[VectorStore] = None
+         self.paper_metadata: Dict[str, Dict[str, Any]] = {}
+         # Cache for document chunk embeddings to avoid recomputation
+         self.embeddings: Dict[str, Any] = {}
+
+     def add_paper(
+         self,
+         paper_id: str,
+         pdf_url: str,
+         paper_metadata: Dict[str, Any],
+     ) -> None:
+         """
+         Add a paper to the document store.
+
+         Args:
+             paper_id: Unique identifier for the paper
+             pdf_url: URL to the PDF
+             paper_metadata: Metadata about the paper
+         """
+         # Skip if already loaded
+         if paper_id in self.loaded_papers:
+             logger.info("Paper %s already loaded, skipping", paper_id)
+             return
+
+         logger.info("Loading paper %s from %s", paper_id, pdf_url)
+
+         # Store paper metadata
+         self.paper_metadata[paper_id] = paper_metadata
+
+         # Load the PDF and split into chunks according to Hydra config
+         loader = PyPDFLoader(pdf_url)
+         documents = loader.load()
+         logger.info("Loaded %d pages from %s", len(documents), paper_id)
+
+         # Create text splitter according to provided configuration
+         if self.config is None:
+             raise ValueError(
+                 "Configuration is required for text splitting in Vectorstore."
+             )
+         splitter = RecursiveCharacterTextSplitter(
+             chunk_size=self.config.chunk_size,
+             chunk_overlap=self.config.chunk_overlap,
+             separators=["\n\n", "\n", ". ", " ", ""],
+         )
+
+         # Split documents and add metadata for each chunk
+         chunks = splitter.split_documents(documents)
+         logger.info("Split %s into %d chunks", paper_id, len(chunks))
+         # Embed and cache chunk embeddings
+         chunk_texts = [chunk.page_content for chunk in chunks]
+         chunk_embeddings = self.embedding_model.embed_documents(chunk_texts)
+         logger.info("Embedded %d chunks for paper %s", len(chunk_embeddings), paper_id)
+
+         # Enhance document metadata
+         for i, chunk in enumerate(chunks):
+             # Add paper metadata to each chunk
+             chunk.metadata.update(
+                 {
+                     "paper_id": paper_id,
+                     "title": paper_metadata.get("Title", "Unknown"),
+                     "chunk_id": i,
+                     # Keep existing page number if available
+                     "page": chunk.metadata.get("page", 0),
+                 }
+             )
+
+             # Add any additional metadata fields
+             for field in self.metadata_fields:
+                 if field in paper_metadata and field not in chunk.metadata:
+                     chunk.metadata[field] = paper_metadata[field]
+
+             # Store chunk
+             doc_id = f"{paper_id}_{i}"
+             self.documents[doc_id] = chunk
+             # Cache embedding if available
+             if chunk_embeddings[i] is not None:
+                 self.embeddings[doc_id] = chunk_embeddings[i]
+
+         # Mark as loaded to prevent duplicate loading
+         self.loaded_papers.add(paper_id)
+         logger.info("Added %d chunks from paper %s", len(chunks), paper_id)
+
+     def build_vector_store(self) -> None:
+         """
+         Build the vector store from all loaded documents.
+         Should be called after all papers are added.
+         """
+         if not self.documents:
+             logger.warning("No documents added to build vector store")
+             return
+
+         if self.vector_store is not None:
+             logger.info("Vector store already built, skipping")
+             return
+
+         # Create vector store from documents
+         documents_list = list(self.documents.values())
+         self.vector_store = self.vector_store_class.from_documents(
+             documents=documents_list, embedding=self.embedding_model
+         )
+         logger.info("Built vector store with %d documents", len(documents_list))
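A minimal usage sketch for `Vectorstore`; the embedding model, chunking values, and URL are placeholders, and only `chunk_size` and `chunk_overlap` are read from the config on this path:

from types import SimpleNamespace
from langchain_openai import OpenAIEmbeddings  # any Embeddings implementation works

store = Vectorstore(
    embedding_model=OpenAIEmbeddings(),
    config=SimpleNamespace(chunk_size=1200, chunk_overlap=200),  # stand-in for Hydra config
)
store.add_paper(
    paper_id="arxiv_1234.5678",
    pdf_url="https://arxiv.org/pdf/1234.5678",  # placeholder URL
    paper_metadata={"Title": "Example Paper"},
)
store.build_vector_store()  # call once after all papers are added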
aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py
@@ -4,17 +4,19 @@
  """
  Tool for rendering the most recently displayed papers as a DataFrame artifact for the front-end.
 
- This module defines a tool that retrieves the paper metadata stored under the state key
- 'last_displayed_papers' and returns it as an artifact (dictionary of papers). The front-end
- can then render this artifact as a pandas DataFrame for display. If no papers are found,
- a NoPapersFoundError is raised to indicate that a search or recommendation should be
- performed first.
+ Call this tool when you need to present the current set of retrieved papers to the user
+ (e.g., "show me the papers", "display results"). It reads the 'last_displayed_papers'
+ dictionary from the agent state and returns it as an artifact that the UI will render
+ as a pandas DataFrame. This tool does not perform any new searches or filtering; it
+ only displays the existing list. If no papers are available, it raises NoPapersFoundError
+ to signal that a search or recommendation must be executed first.
  """
 
 
  import logging
 
  from typing import Annotated
+ from pydantic import BaseModel, Field
  from langchain_core.messages import ToolMessage
  from langchain_core.tools import tool
  from langchain_core.tools.base import InjectedToolCallId
@@ -40,10 +42,31 @@ class NoPapersFoundError(Exception):
      """
 
 
- @tool("display_dataframe", parse_docstring=True)
+ class DisplayDataFrameInput(BaseModel):
+     """
+     Pydantic schema for displaying the last set of papers as a DataFrame artifact.
+
+     Fields:
+         state: Agent state dict containing the 'last_displayed_papers' key.
+         tool_call_id: LangGraph-injected identifier for this tool invocation.
+     """
+
+     state: Annotated[dict, InjectedState] = Field(
+         ..., description="Agent state containing the 'last_displayed_papers' reference."
+     )
+     tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+         ..., description="LangGraph-injected identifier for this tool call."
+     )
+
+
+ @tool(
+     "display_dataframe",
+     args_schema=DisplayDataFrameInput,
+     parse_docstring=True,
+ )
  def display_dataframe(
-     tool_call_id: Annotated[str, InjectedToolCallId],
-     state: Annotated[dict, InjectedState],
+     tool_call_id: str,
+     state: dict,
  ) -> Command:
      """
      Render the last set of retrieved papers as a DataFrame in the front-end.
@@ -55,7 +78,7 @@ def display_dataframe(
      that a search or recommendation must be performed first.
 
      Args:
-         tool_call_id (InjectedToolCallId): Unique ID of this tool invocation.
+         tool_call_id (str): LangGraph-injected unique ID for this tool call.
          state (dict): The agent's state containing the 'last_displayed_papers' reference.
 
      Returns:
@@ -65,7 +88,7 @@ def display_dataframe(
      Raises:
          NoPapersFoundError: If no entries exist under 'last_displayed_papers' in state.
      """
-     logger.info("Displaying papers")
+     logger.info("Displaying papers from 'last_displayed_papers'")
      context_val = state.get("last_displayed_papers")
      # Support both key reference (str) and direct mapping
      if isinstance(context_val, dict):
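The net effect of this refactor is that `InjectedState` and `InjectedToolCallId` now live on the `args_schema` rather than the function signature, so the schema shown to the LLM hides both fields while LangGraph still injects them at call time. A hedged sketch of invoking the tool directly with explicit values; the state contents are illustrative:

cmd = display_dataframe.invoke(
    {
        "state": {"last_displayed_papers": {"p1": {"Title": "Example Paper"}}},
        "tool_call_id": "tool-call-001",
    }
)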