aiagents4pharma 1.39.0__py3-none-any.whl → 1.39.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2scholars/agents/main_agent.py +7 -7
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +88 -12
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/paper_download_agent/default.yaml +5 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/pdf_agent/default.yaml +5 -0
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +1 -20
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/zotero_agent/default.yaml +1 -26
- aiagents4pharma/talk2scholars/configs/tools/download_arxiv_paper/default.yaml +4 -0
- aiagents4pharma/talk2scholars/configs/tools/download_biorxiv_paper/default.yaml +2 -0
- aiagents4pharma/talk2scholars/configs/tools/download_medrxiv_paper/default.yaml +2 -0
- aiagents4pharma/talk2scholars/configs/tools/question_and_answer/default.yaml +22 -0
- aiagents4pharma/talk2scholars/tests/test_main_agent.py +20 -2
- aiagents4pharma/talk2scholars/tests/test_nvidia_nim_reranker_utils.py +28 -0
- aiagents4pharma/talk2scholars/tests/test_paper_download_tools.py +107 -29
- aiagents4pharma/talk2scholars/tests/test_pdf_agent.py +2 -3
- aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +194 -543
- aiagents4pharma/talk2scholars/tests/test_s2_agent.py +2 -2
- aiagents4pharma/talk2scholars/tests/{test_s2_display.py → test_s2_display_dataframe.py} +2 -3
- aiagents4pharma/talk2scholars/tests/test_s2_query_dataframe.py +201 -0
- aiagents4pharma/talk2scholars/tests/test_s2_retrieve.py +7 -6
- aiagents4pharma/talk2scholars/tests/test_s2_utils_ext_ids.py +413 -0
- aiagents4pharma/talk2scholars/tests/test_tool_helper_utils.py +140 -0
- aiagents4pharma/talk2scholars/tests/test_zotero_agent.py +0 -1
- aiagents4pharma/talk2scholars/tests/test_zotero_read.py +16 -18
- aiagents4pharma/talk2scholars/tools/paper_download/download_arxiv_input.py +92 -37
- aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +73 -575
- aiagents4pharma/talk2scholars/tools/pdf/utils/__init__.py +10 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/generate_answer.py +97 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py +77 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py +83 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py +125 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/vector_store.py +162 -0
- aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py +33 -10
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +39 -16
- aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +124 -10
- aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +49 -17
- aiagents4pharma/talk2scholars/tools/s2/search.py +39 -16
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +34 -16
- aiagents4pharma/talk2scholars/tools/s2/utils/multi_helper.py +49 -16
- aiagents4pharma/talk2scholars/tools/s2/utils/search_helper.py +51 -16
- aiagents4pharma/talk2scholars/tools/s2/utils/single_helper.py +50 -17
- {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/METADATA +58 -105
- {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/RECORD +45 -32
- aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +0 -89
- aiagents4pharma/talk2scholars/tests/test_routing_logic.py +0 -74
- aiagents4pharma/talk2scholars/tests/test_s2_query.py +0 -95
- {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/licenses/LICENSE +0 -0
- {aiagents4pharma-1.39.0.dist-info → aiagents4pharma-1.39.2.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2scholars/tools/pdf/utils/__init__.py
@@ -0,0 +1,10 @@
+"""
+Utility modules for the PDF question_and_answer tool.
+"""
+
+from . import generate_answer
+from . import nvidia_nim_reranker
+from . import retrieve_chunks
+from . import vector_store
+
+__all__ = ["generate_answer", "nvidia_nim_reranker", "retrieve_chunks", "vector_store"]
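For orientation, a minimal sketch (not part of the diff) of how the new `utils` subpackage is addressed once the wheel is installed; the eager `from . import ...` lines above make each submodule reachable from the package object:

```python
# Assumes aiagents4pharma >= 1.39.2 is installed.
from aiagents4pharma.talk2scholars.tools.pdf import utils

print(utils.__all__)
# ['generate_answer', 'nvidia_nim_reranker', 'retrieve_chunks', 'vector_store']
print(utils.vector_store.Vectorstore)  # submodules are imported eagerly
```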
aiagents4pharma/talk2scholars/tools/pdf/utils/generate_answer.py
@@ -0,0 +1,97 @@
+"""
+Generate an answer for a question using retrieved chunks of documents.
+"""
+
+import logging
+import os
+from typing import Any, Dict, List
+
+import hydra
+from langchain_core.documents import Document
+from langchain_core.language_models.chat_models import BaseChatModel
+
+# Set up logging with configurable level
+log_level = os.environ.get("LOG_LEVEL", "INFO")
+logging.basicConfig(level=getattr(logging, log_level))
+logger = logging.getLogger(__name__)
+logger.setLevel(getattr(logging, log_level))
+
+
+def _build_context_and_sources(
+    retrieved_chunks: List[Document],
+) -> tuple[str, set[str]]:
+    """
+    Build the combined context string and set of paper_ids from retrieved chunks.
+    """
+    papers = {}
+    for doc in retrieved_chunks:
+        pid = doc.metadata.get("paper_id", "unknown")
+        papers.setdefault(pid, []).append(doc)
+    formatted = []
+    idx = 1
+    for pid, chunks in papers.items():
+        title = chunks[0].metadata.get("title", "Unknown")
+        formatted.append(f"[Document {idx}] From: '{title}' (ID: {pid})")
+        for chunk in chunks:
+            page = chunk.metadata.get("page", "unknown")
+            formatted.append(f"Page {page}: {chunk.page_content}")
+        idx += 1
+    context = "\n\n".join(formatted)
+    sources: set[str] = set()
+    for doc in retrieved_chunks:
+        pid = doc.metadata.get("paper_id")
+        if isinstance(pid, str):
+            sources.add(pid)
+    return context, sources
+
+
+def load_hydra_config() -> Any:
+    """
+    Load the configuration using Hydra and return the configuration for the Q&A tool.
+    """
+    with hydra.initialize(version_base=None, config_path="../../../configs"):
+        cfg = hydra.compose(
+            config_name="config",
+            overrides=["tools/question_and_answer=default"],
+        )
+        config = cfg.tools.question_and_answer
+        logger.debug("Loaded Question and Answer tool configuration.")
+        return config
+
+
+def generate_answer(
+    question: str,
+    retrieved_chunks: List[Document],
+    llm_model: BaseChatModel,
+    config: Any,
+) -> Dict[str, Any]:
+    """
+    Generate an answer for a question using retrieved chunks.
+
+    Args:
+        question (str): The question to answer
+        retrieved_chunks (List[Document]): List of relevant document chunks
+        llm_model (BaseChatModel): Language model for generating answers
+        config (Any): Configuration for answer generation
+
+    Returns:
+        Dict[str, Any]: Dictionary with the answer and metadata
+    """
+    # Ensure the configuration is provided and has the prompt_template.
+    if config is None:
+        raise ValueError("Configuration for generate_answer is required.")
+    if "prompt_template" not in config:
+        raise ValueError("The prompt_template is missing from the configuration.")
+
+    # Build context and sources, then invoke LLM
+    context, paper_sources = _build_context_and_sources(retrieved_chunks)
+    prompt = config["prompt_template"].format(context=context, question=question)
+    response = llm_model.invoke(prompt)
+
+    # Return the response with metadata
+    return {
+        "output_text": response.content,
+        "sources": [doc.metadata for doc in retrieved_chunks],
+        "num_sources": len(retrieved_chunks),
+        "papers_used": list(paper_sources),
+    }
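A hedged usage sketch for `generate_answer` (not package code): `FakeListChatModel` stands in for a real LLM, and the plain dict mimics the Hydra config, of which only the `prompt_template` key is assumed; the real tool passes the composed Hydra config instead:

```python
from langchain_core.documents import Document
from langchain_core.language_models import FakeListChatModel

from aiagents4pharma.talk2scholars.tools.pdf.utils.generate_answer import generate_answer

chunks = [
    Document(
        page_content="The compound reduced tumor volume by 40%.",
        metadata={"paper_id": "arxiv:2401.00001", "title": "Oncology Study", "page": 3},
    )
]
# Stub config; the packaged default.yaml supplies the real prompt_template.
config = {"prompt_template": "Context:\n{context}\n\nQuestion: {question}"}
llm = FakeListChatModel(responses=["A 40% reduction in tumor volume."])

result = generate_answer("What was the effect size?", chunks, llm, config)
print(result["output_text"])  # A 40% reduction in tumor volume.
print(result["papers_used"])  # ['arxiv:2401.00001']
```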
aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py
@@ -0,0 +1,77 @@
+"""
+NVIDIA NIM Reranker Utility
+"""
+
+import logging
+import os
+
+
+from typing import Any, List
+
+from langchain_core.documents import Document
+from langchain_nvidia_ai_endpoints import NVIDIARerank
+
+# Set up logging with configurable level
+log_level = os.environ.get("LOG_LEVEL", "INFO")
+logging.basicConfig(level=getattr(logging, log_level))
+logger = logging.getLogger(__name__)
+logger.setLevel(getattr(logging, log_level))
+
+
+def rank_papers_by_query(self, query: str, config: Any, top_k: int = 5) -> List[str]:
+    """
+    Rank papers by relevance to the query using NVIDIA's off-the-shelf re-ranker.
+
+    This function aggregates all chunks per paper, ranks them using the NVIDIA model,
+    and returns the top-k papers.
+
+    Args:
+        query (str): The query string.
+        config (Any): Configuration containing reranker settings (model, api_key).
+        top_k (int): Number of top papers to return.
+
+    Returns:
+        List[str]: The top-k paper IDs, sorted by relevance.
+    """
+
+    logger.info("Starting NVIDIA re-ranker for query: '%s' with top_k=%d", query, top_k)
+    # Aggregate all document chunks for each paper
+    paper_texts = {}
+    for doc in self.documents.values():
+        paper_id = doc.metadata["paper_id"]
+        paper_texts.setdefault(paper_id, []).append(doc.page_content)
+
+    aggregated_documents = []
+    for paper_id, texts in paper_texts.items():
+        aggregated_text = " ".join(texts)
+        aggregated_documents.append(
+            Document(page_content=aggregated_text, metadata={"paper_id": paper_id})
+        )
+
+    logger.info(
+        "Aggregated %d papers into %d documents for reranking",
+        len(paper_texts),
+        len(aggregated_documents),
+    )
+    # Instantiate the NVIDIA re-ranker client using provided config
+    # Use NVIDIA API key from Hydra configuration (expected to be resolved via oc.env)
+    api_key = config.reranker.api_key
+    if not api_key:
+        logger.error("No NVIDIA API key found in configuration for reranking")
+        raise ValueError("Configuration 'reranker.api_key' must be set for reranking")
+    logger.info("Using NVIDIA API key from configuration for reranking")
+    # Truncate long inputs at the model-end to avoid exceeding max token size
+    logger.info("Setting NVIDIA reranker truncate mode to END to limit input length")
+    reranker = NVIDIARerank(
+        model=config.reranker.model,
+        api_key=api_key,
+        truncate="END",
+    )
+
+    # Get the ranked list of documents based on the query
+    response = reranker.compress_documents(query=query, documents=aggregated_documents)
+    logger.info("Received %d documents from NVIDIA reranker", len(response))
+
+    ranked_papers = [doc.metadata["paper_id"] for doc in response[:top_k]]
+    logger.info("Top %d papers after reranking: %s", top_k, ranked_papers)
+    return ranked_papers
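Note the unusual signature: `rank_papers_by_query` is module-level but takes `self`, so callers pass the `Vectorstore` instance positionally, as `QAToolHelper.run_reranker` does further down in this diff. A hedged sketch of a direct call; the model id is illustrative, a valid `NVIDIA_API_KEY` is assumed, and the call hits the NVIDIA endpoint over the network:

```python
import os
from types import SimpleNamespace

from langchain_core.documents import Document

from aiagents4pharma.talk2scholars.tools.pdf.utils.nvidia_nim_reranker import (
    rank_papers_by_query,
)

# Any object with a `.documents` dict of chunk Documents works as `self`.
store = SimpleNamespace(
    documents={
        "p1_0": Document(page_content="CRISPR screens in tumor cells", metadata={"paper_id": "p1"}),
        "p2_0": Document(page_content="Protein folding with transformers", metadata={"paper_id": "p2"}),
    }
)
config = SimpleNamespace(
    reranker=SimpleNamespace(
        model="nvidia/llama-3.2-nv-rerankqa-1b-v2",  # illustrative model id
        api_key=os.environ["NVIDIA_API_KEY"],
    )
)

print(rank_papers_by_query(store, "gene editing in cancer", config, top_k=1))  # likely ['p1']
```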
aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py
@@ -0,0 +1,83 @@
+"""
+Retrieve relevant chunks from a vector store using MMR (Maximal Marginal Relevance).
+"""
+
+import logging
+import os
+from typing import List, Optional
+
+import numpy as np
+from langchain_core.documents import Document
+from langchain_core.vectorstores.utils import maximal_marginal_relevance
+
+
+# Set up logging with configurable level
+log_level = os.environ.get("LOG_LEVEL", "INFO")
+logging.basicConfig(level=getattr(logging, log_level))
+logger = logging.getLogger(__name__)
+logger.setLevel(getattr(logging, log_level))
+
+
+def retrieve_relevant_chunks(
+    self,
+    query: str,
+    paper_ids: Optional[List[str]] = None,
+    top_k: int = 25,
+    mmr_diversity: float = 1.00,
+) -> List[Document]:
+    """
+    Retrieve the most relevant chunks for a query using maximal marginal relevance.
+
+    Args:
+        query: Query string
+        paper_ids: Optional list of paper IDs to filter by
+        top_k: Number of chunks to retrieve
+        mmr_diversity: MMR lambda_mult; 1.0 ranks purely by relevance, lower values favor diversity
+
+    Returns:
+        List of document chunks
+    """
+    if not self.vector_store:
+        logger.error("Failed to build vector store")
+        return []
+
+    if paper_ids:
+        logger.info("Filtering retrieval to papers: %s", paper_ids)
+
+    # Step 1: Embed the query
+    logger.info("Embedding query using model: %s", type(self.embedding_model).__name__)
+    query_embedding = np.array(self.embedding_model.embed_query(query))
+
+    # Step 2: Filter relevant documents
+    all_docs = [
+        doc
+        for doc in self.documents.values()
+        if not paper_ids or doc.metadata["paper_id"] in paper_ids
+    ]
+
+    if not all_docs:
+        logger.warning("No documents found after filtering by paper_ids.")
+        return []
+
+    # Step 3: Retrieve or compute embeddings for all documents using cache
+    logger.info("Retrieving embeddings for %d chunks...", len(all_docs))
+    all_embeddings = []
+    for doc in all_docs:
+        doc_id = f"{doc.metadata['paper_id']}_{doc.metadata['chunk_id']}"
+        if doc_id not in self.embeddings:
+            logger.info("Embedding missing chunk %s", doc_id)
+            emb = self.embedding_model.embed_documents([doc.page_content])[0]
+            self.embeddings[doc_id] = emb
+        all_embeddings.append(self.embeddings[doc_id])
+
+    # Step 4: Apply MMR
+    mmr_indices = maximal_marginal_relevance(
+        query_embedding,
+        all_embeddings,
+        k=top_k,
+        lambda_mult=mmr_diversity,
+    )
+
+    results = [all_docs[i] for i in mmr_indices]
+    logger.info("Retrieved %d chunks using MMR", len(results))
+    return results
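Since `lambda_mult` comes straight from `mmr_diversity`, the default of 1.00 means pure relevance ranking with no diversity penalty. A small self-contained illustration of the langchain-core MMR primitive the function delegates to (synthetic vectors, not package code):

```python
import numpy as np
from langchain_core.vectorstores.utils import maximal_marginal_relevance

query = np.array([1.0, 0.0])
chunk_embeddings = [
    np.array([0.99, 0.10]),  # near-duplicate of the next chunk
    np.array([1.00, 0.05]),  # most similar to the query
    np.array([0.00, 1.00]),  # dissimilar, but adds diversity
]

# lambda_mult=1.0 -> pure relevance: picks the two near-duplicates.
print(maximal_marginal_relevance(query, chunk_embeddings, k=2, lambda_mult=1.0))  # [1, 0]
# lambda_mult=0.1 -> heavily diversity-weighted: swaps in the dissimilar chunk.
print(maximal_marginal_relevance(query, chunk_embeddings, k=2, lambda_mult=0.1))  # [1, 2]
```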
aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py
@@ -0,0 +1,125 @@
+"""
+Helper class for PDF Q&A tool orchestration: state validation, vectorstore init,
+paper loading, reranking, and answer formatting.
+"""
+
+import logging
+from typing import Any, Dict, List, Optional
+
+from .generate_answer import generate_answer
+from .nvidia_nim_reranker import rank_papers_by_query
+from .vector_store import Vectorstore
+
+logger = logging.getLogger(__name__)
+
+
+class QAToolHelper:
+    """Encapsulates helper routines for the PDF Question & Answer tool."""
+
+    def __init__(self) -> None:
+        self.prebuilt_vector_store: Optional[Vectorstore] = None
+        self.config: Any = None
+        self.call_id: str = ""
+        logger.debug("Initialized QAToolHelper")
+
+    def start_call(self, config: Any, call_id: str) -> None:
+        """Initialize helper with current config and call identifier."""
+        self.config = config
+        self.call_id = call_id
+        logger.debug("QAToolHelper started call %s", call_id)
+
+    def get_state_models_and_data(self, state: dict) -> tuple[Any, Any, Dict[str, Any]]:
+        """Retrieve embedding model, LLM, and article data from agent state."""
+        text_emb = state.get("text_embedding_model")
+        if not text_emb:
+            msg = "No text embedding model found in state."
+            logger.error("%s: %s", self.call_id, msg)
+            raise ValueError(msg)
+        llm = state.get("llm_model")
+        if not llm:
+            msg = "No LLM model found in state."
+            logger.error("%s: %s", self.call_id, msg)
+            raise ValueError(msg)
+        articles = state.get("article_data", {})
+        if not articles:
+            msg = "No article_data found in state."
+            logger.error("%s: %s", self.call_id, msg)
+            raise ValueError(msg)
+        return text_emb, llm, articles
+
+    def init_vector_store(self, emb_model: Any) -> Vectorstore:
+        """Return shared or new Vectorstore instance."""
+        if self.prebuilt_vector_store is not None:
+            logger.info("Using shared pre-built vector store from memory")
+            return self.prebuilt_vector_store
+        vs = Vectorstore(embedding_model=emb_model, config=self.config)
+        logger.info("Initialized new vector store with provided configuration")
+        self.prebuilt_vector_store = vs
+        return vs
+
+    def load_candidate_papers(
+        self,
+        vs: Vectorstore,
+        articles: Dict[str, Any],
+        candidates: List[str],
+    ) -> None:
+        """Ensure each candidate paper is loaded into the vector store."""
+        for pid in candidates:
+            if pid not in vs.loaded_papers:
+                pdf_url = articles.get(pid, {}).get("pdf_url")
+                if not pdf_url:
+                    continue
+                try:
+                    vs.add_paper(pid, pdf_url, articles[pid])
+                except (IOError, ValueError) as exc:
+                    logger.warning(
+                        "%s: Error loading paper %s: %s", self.call_id, pid, exc
+                    )
+
+    def run_reranker(
+        self,
+        vs: Vectorstore,
+        query: str,
+        candidates: List[str],
+    ) -> List[str]:
+        """Rank papers by relevance and return filtered paper IDs."""
+        try:
+            ranked = rank_papers_by_query(
+                vs, query, self.config, top_k=self.config.top_k_papers
+            )
+            logger.info("%s: Papers after NVIDIA reranking: %s", self.call_id, ranked)
+            return [pid for pid in ranked if pid in candidates]
+        except (ValueError, RuntimeError) as exc:
+            logger.error("%s: NVIDIA reranker failed: %s", self.call_id, exc)
+            logger.info(
+                "%s: Falling back to all %d candidate papers",
+                self.call_id,
+                len(candidates),
+            )
+            return candidates
+
+    def format_answer(
+        self,
+        question: str,
+        chunks: List[Any],
+        llm: Any,
+        articles: Dict[str, Any],
+    ) -> str:
+        """Generate the final answer text with source attributions."""
+        result = generate_answer(question, chunks, llm, self.config)
+        answer = result.get("output_text", "No answer generated.")
+        titles: Dict[str, str] = {}
+        for pid in result.get("papers_used", []):
+            if pid in articles:
+                titles[pid] = articles[pid].get("Title", "Unknown paper")
+        if titles:
+            srcs = "\n\nSources:\n" + "\n".join(f"- {t}" for t in titles.values())
+        else:
+            srcs = ""
+        logger.info(
+            "%s: Generated answer using %d chunks from %d papers",
+            self.call_id,
+            len(chunks),
+            len(titles),
+        )
+        return f"{answer}{srcs}"
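A runnable sketch (stand-in objects, not package defaults) of `QAToolHelper`'s fail-fast state validation; the full orchestration in `question_and_answer.py` then calls `init_vector_store`, `load_candidate_papers`, `run_reranker`, and `format_answer` in sequence after this step:

```python
from types import SimpleNamespace

from aiagents4pharma.talk2scholars.tools.pdf.utils.tool_helper import QAToolHelper

helper = QAToolHelper()
helper.start_call(config=SimpleNamespace(top_k_papers=3), call_id="demo-1")

state = {
    "text_embedding_model": object(),  # placeholder embedding model
    "llm_model": object(),             # placeholder chat model
    "article_data": {"p1": {"Title": "Example Paper", "pdf_url": "https://example.org/p1.pdf"}},
}
emb, llm, articles = helper.get_state_models_and_data(state)  # passes

try:
    helper.get_state_models_and_data({})  # nothing in state
except ValueError as err:
    print(err)  # No text embedding model found in state.
```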
aiagents4pharma/talk2scholars/tools/pdf/utils/vector_store.py
@@ -0,0 +1,162 @@
+"""
+Vectorstore class for managing document embeddings and retrieval.
+"""
+
+import logging
+import os
+import time
+from typing import Any, Dict, List, Optional
+
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import PyPDFLoader
+from langchain_community.vectorstores import FAISS
+from langchain_core.documents import Document
+from langchain_core.embeddings import Embeddings
+from langchain_core.vectorstores import VectorStore
+
+
+# Set up logging with configurable level
+log_level = os.environ.get("LOG_LEVEL", "INFO")
+logging.basicConfig(level=getattr(logging, log_level))
+logger = logging.getLogger(__name__)
+logger.setLevel(getattr(logging, log_level))
+
+
+class Vectorstore:
+    """
+    A class for managing document embeddings and retrieval.
+    Provides unified access to documents across multiple papers.
+    """
+
+    def __init__(
+        self,
+        embedding_model: Embeddings,
+        metadata_fields: Optional[List[str]] = None,
+        config: Any = None,
+    ):
+        """
+        Initialize the document store.
+
+        Args:
+            embedding_model: The embedding model to use
+            metadata_fields: Fields to include in document metadata for filtering/retrieval
+        """
+        self.embedding_model = embedding_model
+        self.config = config
+        self.metadata_fields = metadata_fields or [
+            "title",
+            "paper_id",
+            "page",
+            "chunk_id",
+        ]
+        self.initialization_time = time.time()
+        logger.info("Vectorstore initialized at: %s", self.initialization_time)
+
+        # Track loaded papers to prevent duplicate loading
+        self.loaded_papers = set()
+        self.vector_store_class = FAISS
+        logger.info("Using FAISS vector store")
+
+        # Store for initialized documents
+        self.documents: Dict[str, Document] = {}
+        self.vector_store: Optional[VectorStore] = None
+        self.paper_metadata: Dict[str, Dict[str, Any]] = {}
+        # Cache for document chunk embeddings to avoid recomputation
+        self.embeddings: Dict[str, Any] = {}
+
+    def add_paper(
+        self,
+        paper_id: str,
+        pdf_url: str,
+        paper_metadata: Dict[str, Any],
+    ) -> None:
+        """
+        Add a paper to the document store.
+
+        Args:
+            paper_id: Unique identifier for the paper
+            pdf_url: URL to the PDF
+            paper_metadata: Metadata about the paper
+        """
+        # Skip if already loaded
+        if paper_id in self.loaded_papers:
+            logger.info("Paper %s already loaded, skipping", paper_id)
+            return
+
+        logger.info("Loading paper %s from %s", paper_id, pdf_url)
+
+        # Store paper metadata
+        self.paper_metadata[paper_id] = paper_metadata
+
+        # Load the PDF and split into chunks according to Hydra config
+        loader = PyPDFLoader(pdf_url)
+        documents = loader.load()
+        logger.info("Loaded %d pages from %s", len(documents), paper_id)
+
+        # Create text splitter according to provided configuration
+        if self.config is None:
+            raise ValueError(
+                "Configuration is required for text splitting in Vectorstore."
+            )
+        splitter = RecursiveCharacterTextSplitter(
+            chunk_size=self.config.chunk_size,
+            chunk_overlap=self.config.chunk_overlap,
+            separators=["\n\n", "\n", ". ", " ", ""],
+        )
+
+        # Split documents and add metadata for each chunk
+        chunks = splitter.split_documents(documents)
+        logger.info("Split %s into %d chunks", paper_id, len(chunks))
+        # Embed and cache chunk embeddings
+        chunk_texts = [chunk.page_content for chunk in chunks]
+        chunk_embeddings = self.embedding_model.embed_documents(chunk_texts)
+        logger.info("Embedded %d chunks for paper %s", len(chunk_embeddings), paper_id)
+
+        # Enhance document metadata
+        for i, chunk in enumerate(chunks):
+            # Add paper metadata to each chunk
+            chunk.metadata.update(
+                {
+                    "paper_id": paper_id,
+                    "title": paper_metadata.get("Title", "Unknown"),
+                    "chunk_id": i,
+                    # Keep existing page number if available
+                    "page": chunk.metadata.get("page", 0),
+                }
+            )
+
+            # Add any additional metadata fields
+            for field in self.metadata_fields:
+                if field in paper_metadata and field not in chunk.metadata:
+                    chunk.metadata[field] = paper_metadata[field]
+
+            # Store chunk
+            doc_id = f"{paper_id}_{i}"
+            self.documents[doc_id] = chunk
+            # Cache embedding if available
+            if chunk_embeddings[i] is not None:
+                self.embeddings[doc_id] = chunk_embeddings[i]
+
+        # Mark as loaded to prevent duplicate loading
+        self.loaded_papers.add(paper_id)
+        logger.info("Added %d chunks from paper %s", len(chunks), paper_id)
+
+    def build_vector_store(self) -> None:
+        """
+        Build the vector store from all loaded documents.
+        Should be called after all papers are added.
+        """
+        if not self.documents:
+            logger.warning("No documents added to build vector store")
+            return
+
+        if self.vector_store is not None:
+            logger.info("Vector store already built, skipping")
+            return
+
+        # Create vector store from documents
+        documents_list = list(self.documents.values())
+        self.vector_store = self.vector_store_class.from_documents(
+            documents=documents_list, embedding=self.embedding_model
+        )
+        logger.info("Built vector store with %d documents", len(documents_list))
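A hedged end-to-end sketch of `Vectorstore`: `DeterministicFakeEmbedding` replaces the real embedding model so no API key is needed, the chunking numbers are placeholders for the Hydra-supplied values, and a readable local PDF at `sample.pdf` is assumed (`PyPDFLoader` accepts local paths as well as URLs):

```python
from types import SimpleNamespace

from langchain_core.embeddings import DeterministicFakeEmbedding

from aiagents4pharma.talk2scholars.tools.pdf.utils.vector_store import Vectorstore

config = SimpleNamespace(chunk_size=1200, chunk_overlap=150)  # placeholder values
store = Vectorstore(embedding_model=DeterministicFakeEmbedding(size=256), config=config)

store.add_paper(
    paper_id="p1",
    pdf_url="sample.pdf",  # assumed local PDF; a direct-download URL also works
    paper_metadata={"Title": "Example Paper"},
)
store.build_vector_store()  # builds the FAISS index; repeat calls are no-ops
print(f"{len(store.documents)} chunks indexed; embeddings cached: {len(store.embeddings)}")
```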
aiagents4pharma/talk2scholars/tools/s2/display_dataframe.py
@@ -4,17 +4,19 @@
 """
 Tool for rendering the most recently displayed papers as a DataFrame artifact for the front-end.
 
-
-
-
-a
-
+Call this tool when you need to present the current set of retrieved papers to the user
+(e.g., "show me the papers", "display results"). It reads the 'last_displayed_papers'
+dictionary from the agent state and returns it as an artifact that the UI will render
+as a pandas DataFrame. This tool does not perform any new searches or filtering; it
+only displays the existing list. If no papers are available, it raises NoPapersFoundError
+to signal that a search or recommendation must be executed first.
 """
 
 
 import logging
 
 from typing import Annotated
+from pydantic import BaseModel, Field
 from langchain_core.messages import ToolMessage
 from langchain_core.tools import tool
 from langchain_core.tools.base import InjectedToolCallId
@@ -40,10 +42,31 @@ class NoPapersFoundError(Exception):
     """
 
 
-
+class DisplayDataFrameInput(BaseModel):
+    """
+    Pydantic schema for displaying the last set of papers as a DataFrame artifact.
+
+    Fields:
+        state: Agent state dict containing the 'last_displayed_papers' key.
+        tool_call_id: LangGraph-injected identifier for this tool invocation.
+    """
+
+    state: Annotated[dict, InjectedState] = Field(
+        ..., description="Agent state containing the 'last_displayed_papers' reference."
+    )
+    tool_call_id: Annotated[str, InjectedToolCallId] = Field(
+        ..., description="LangGraph-injected identifier for this tool call."
+    )
+
+
+@tool(
+    "display_dataframe",
+    args_schema=DisplayDataFrameInput,
+    parse_docstring=True,
+)
 def display_dataframe(
-    tool_call_id:
-    state:
+    tool_call_id: str,
+    state: dict,
 ) -> Command:
     """
     Render the last set of retrieved papers as a DataFrame in the front-end.
@@ -55,7 +78,7 @@ def display_dataframe(
     that a search or recommendation must be performed first.
 
     Args:
-        tool_call_id (
+        tool_call_id (str): LangGraph-injected unique ID for this tool call.
        state (dict): The agent's state containing the 'last_displayed_papers' reference.
 
     Returns:
@@ -65,7 +88,7 @@ def display_dataframe(
     Raises:
        NoPapersFoundError: If no entries exist under 'last_displayed_papers' in state.
     """
-    logger.info("Displaying papers")
+    logger.info("Displaying papers from 'last_displayed_papers'")
     context_val = state.get("last_displayed_papers")
     # Support both key reference (str) and direct mapping
     if isinstance(context_val, dict):