aiagents4pharma-1.40.1-py3-none-any.whl → aiagents4pharma-1.41.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +4 -0
- aiagents4pharma/talk2scholars/configs/tools/question_and_answer/default.yaml +44 -4
- aiagents4pharma/talk2scholars/tests/test_nvidia_nim_reranker.py +127 -0
- aiagents4pharma/talk2scholars/tests/test_pdf_answer_formatter.py +66 -0
- aiagents4pharma/talk2scholars/tests/test_pdf_batch_processor.py +101 -0
- aiagents4pharma/talk2scholars/tests/test_pdf_collection_manager.py +150 -0
- aiagents4pharma/talk2scholars/tests/test_pdf_document_processor.py +69 -0
- aiagents4pharma/talk2scholars/tests/test_pdf_generate_answer.py +75 -0
- aiagents4pharma/talk2scholars/tests/test_pdf_gpu_detection.py +140 -0
- aiagents4pharma/talk2scholars/tests/test_pdf_paper_loader.py +116 -0
- aiagents4pharma/talk2scholars/tests/test_pdf_rag_pipeline.py +98 -0
- aiagents4pharma/talk2scholars/tests/test_pdf_retrieve_chunks.py +197 -0
- aiagents4pharma/talk2scholars/tests/test_pdf_singleton_manager.py +156 -0
- aiagents4pharma/talk2scholars/tests/test_pdf_vector_normalization.py +121 -0
- aiagents4pharma/talk2scholars/tests/test_pdf_vector_store.py +434 -0
- aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +89 -509
- aiagents4pharma/talk2scholars/tests/test_tool_helper_utils.py +34 -89
- aiagents4pharma/talk2scholars/tools/paper_download/download_biorxiv_input.py +8 -6
- aiagents4pharma/talk2scholars/tools/paper_download/download_medrxiv_input.py +6 -4
- aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +74 -40
- aiagents4pharma/talk2scholars/tools/pdf/utils/__init__.py +26 -1
- aiagents4pharma/talk2scholars/tools/pdf/utils/answer_formatter.py +62 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/batch_processor.py +200 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/collection_manager.py +172 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/document_processor.py +76 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/generate_answer.py +14 -14
- aiagents4pharma/talk2scholars/tools/pdf/utils/get_vectorstore.py +63 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/gpu_detection.py +154 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py +60 -40
- aiagents4pharma/talk2scholars/tools/pdf/utils/paper_loader.py +123 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/rag_pipeline.py +122 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py +162 -40
- aiagents4pharma/talk2scholars/tools/pdf/utils/singleton_manager.py +140 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py +40 -78
- aiagents4pharma/talk2scholars/tools/pdf/utils/vector_normalization.py +159 -0
- aiagents4pharma/talk2scholars/tools/pdf/utils/vector_store.py +277 -96
- aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +12 -9
- aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +0 -1
- aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +9 -8
- aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +5 -5
- {aiagents4pharma-1.40.1.dist-info → aiagents4pharma-1.41.0.dist-info}/METADATA +27 -115
- {aiagents4pharma-1.40.1.dist-info → aiagents4pharma-1.41.0.dist-info}/RECORD +45 -23
- aiagents4pharma/talk2scholars/tests/test_nvidia_nim_reranker_utils.py +0 -28
- {aiagents4pharma-1.40.1.dist-info → aiagents4pharma-1.41.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.40.1.dist-info → aiagents4pharma-1.41.0.dist-info}/licenses/LICENSE +0 -0
- {aiagents4pharma-1.40.1.dist-info → aiagents4pharma-1.41.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2scholars/tools/pdf/utils/gpu_detection.py (new file)
@@ -0,0 +1,154 @@
+"""
+GPU Detection Utility for Milvus Index Selection
+Handle COSINE -> IP conversion for GPU indexes
+"""
+
+import logging
+import subprocess
+from typing import Dict, Any, Tuple
+
+logger = logging.getLogger(__name__)
+
+
+def detect_nvidia_gpu(config=None) -> bool:
+    """
+    Detect if NVIDIA GPU is available and should be used.
+
+    Args:
+        config: Hydra config object that may contain force_cpu_mode flag
+
+    Returns:
+        bool: True if GPU should be used, False if CPU should be used
+    """
+
+    # Check for force CPU mode in config
+    if config and hasattr(config, "gpu_detection"):
+        force_cpu = getattr(config.gpu_detection, "force_cpu_mode", False)
+        if force_cpu:
+            logger.info(
+                "Force CPU mode enabled in config - using CPU even though GPU may be available"
+            )
+            return False
+
+    # Normal GPU detection logic
+    try:
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
+            capture_output=True,
+            text=True,
+            timeout=10,
+            check=False,
+        )
+
+        if result.returncode == 0 and result.stdout.strip():
+            gpu_names = result.stdout.strip().split("\n")
+            logger.info("Detected NVIDIA GPU(s): %s", gpu_names)
+            logger.info("To force CPU mode, set 'force_cpu_mode: true' in config")
+            return True
+
+        logger.info("nvidia-smi command failed or no GPUs detected")
+        return False
+
+    except (subprocess.TimeoutExpired, FileNotFoundError) as e:
+        logger.info("NVIDIA GPU detection failed: %s", e)
+        return False
+
+
+def get_optimal_index_config(
+    has_gpu: bool, embedding_dim: int = 768, use_cosine: bool = True
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    """
+    Get optimal index and search parameters based on GPU availability.
+
+    IMPORTANT: GPU indexes don't support COSINE distance. When using GPU with COSINE,
+    vectors must be normalized and IP distance used instead.
+
+    Args:
+        has_gpu (bool): Whether NVIDIA GPU is available
+        embedding_dim (int): Dimension of embeddings
+        use_cosine (bool): Whether to use cosine similarity (will be converted to IP for GPU)
+
+    Returns:
+        Tuple[Dict[str, Any], Dict[str, Any]]: (index_params, search_params)
+    """
+    if has_gpu:
+        logger.info("Configuring GPU_CAGRA index for NVIDIA GPU")
+
+        # For GPU: COSINE is not supported, must use IP with normalized vectors
+        if use_cosine:
+            logger.warning(
+                "GPU indexes don't support COSINE distance. "
+                "Vectors will be normalized and IP distance will be used instead."
+            )
+            metric_type = (
+                "IP"  # Inner Product for normalized vectors = cosine similarity
+            )
+        else:
+            metric_type = "IP"  # Default to IP for GPU
+
+        # GPU_CAGRA index parameters - optimized for performance
+        index_params = {
+            "index_type": "GPU_CAGRA",
+            "metric_type": metric_type,
+            "params": {
+                "intermediate_graph_degree": 64,  # Higher for better recall
+                "graph_degree": 32,  # Balanced performance/recall
+                "build_algo": "IVF_PQ",  # Higher quality build
+                "cache_dataset_on_device": "true",  # Cache for better recall
+                "adapt_for_cpu": "false",  # Pure GPU mode
+            },
+        }
+
+        # GPU_CAGRA search parameters
+        search_params = {
+            "metric_type": metric_type,
+            "params": {
+                "itopk_size": 128,  # Power of 2, good for intermediate results
+                "search_width": 16,  # Balanced entry points
+                "team_size": 16,  # Optimize for typical vector dimensions
+            },
+        }
+
+    else:
+        logger.info("Configuring CPU index (IVF_FLAT) - no NVIDIA GPU detected")
+
+        # CPU supports COSINE directly
+        metric_type = "COSINE" if use_cosine else "IP"
+
+        # CPU IVF_FLAT index parameters
+        index_params = {
+            "index_type": "IVF_FLAT",
+            "metric_type": metric_type,
+            "params": {
+                "nlist": min(
+                    1024, max(64, embedding_dim // 8)
+                )  # Dynamic nlist based on dimension
+            },
+        }
+
+        # CPU search parameters
+        search_params = {
+            "metric_type": metric_type,
+            "params": {"nprobe": 16},  # Slightly higher than original for better recall
+        }
+
+    return index_params, search_params
+
+
+def log_index_configuration(
+    index_params: Dict[str, Any], search_params: Dict[str, Any], use_cosine: bool = True
+) -> None:
+    """Log the selected index configuration for debugging."""
+    index_type = index_params.get("index_type", "Unknown")
+    metric_type = index_params.get("metric_type", "Unknown")
+
+    logger.info("=== Milvus Index Configuration ===")
+    logger.info("Index Type: %s", index_type)
+    logger.info("Metric Type: %s", metric_type)
+
+    if index_type == "GPU_CAGRA" and use_cosine and metric_type == "IP":
+        logger.info("NOTE: Using IP with normalized vectors to simulate COSINE for GPU")
+
+    logger.info("Index Params: %s", index_params.get("params", {}))
+    logger.info("Search Params: %s", search_params.get("params", {}))
+    logger.info("===================================")
aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py
@@ -1,11 +1,10 @@
 """
-NVIDIA NIM Reranker Utility
+NVIDIA NIM Reranker Utility for Milvus Integration
+Rerank chunks instead of papers following traditional RAG pipeline
 """
 
 import logging
 import os
-
-
 from typing import Any, List
 
 from langchain_core.documents import Document
@@ -18,60 +17,81 @@ logger = logging.getLogger(__name__)
 logger.setLevel(getattr(logging, log_level))
 
 
+def rerank_chunks(
+    chunks: List[Document], query: str, config: Any, top_k: int = 25
+) -> List[Document]:
     """
+    Rerank chunks by relevance to the query using NVIDIA's reranker.
 
-    and returns the top-k papers.
+    This follows the traditional RAG pipeline: first retrieve chunks, then rerank them.
 
     Args:
+        chunks (List[Document]): List of chunks to rerank
+        query (str): The query string
+        config (Any): Configuration containing reranker settings
+        top_k (int): Number of top chunks to return after reranking
 
     Returns:
+        List[Document]: Reranked chunks (top_k most relevant)
     """
+    logger.info(
+        "Starting NVIDIA chunk reranker for query: '%s' with %d chunks, top_k=%d",
+        query[:50] + "..." if len(query) > 50 else query,
+        len(chunks),
+        top_k,
+    )
 
-    aggregated_documents = []
-    for paper_id, texts in paper_texts.items():
-        aggregated_text = " ".join(texts)
-        aggregated_documents.append(
-            Document(page_content=aggregated_text, metadata={"paper_id": paper_id})
+    # If we have fewer chunks than top_k, just return all
+    if len(chunks) <= top_k:
+        logger.info(
+            "Number of chunks (%d) <= top_k (%d), returning all chunks without reranking",
+            len(chunks),
+            top_k,
        )
+        return chunks
 
-        "Aggregated %d papers into %d documents for reranking",
-        len(paper_texts),
-        len(aggregated_documents),
-    )
-    # Instantiate the NVIDIA re-ranker client using provided config
-    # Use NVIDIA API key from Hydra configuration (expected to be resolved via oc.env)
+    # Get API key from config
     api_key = config.reranker.api_key
     if not api_key:
         logger.error("No NVIDIA API key found in configuration for reranking")
         raise ValueError("Configuration 'reranker.api_key' must be set for reranking")
+
+    logger.info("Using NVIDIA reranker model: %s", config.reranker.model)
+
+    # Initialize reranker with truncation to handle long chunks
     reranker = NVIDIARerank(
         model=config.reranker.model,
         api_key=api_key,
-        truncate="END",
+        truncate="END",  # Truncate at the end if too long
+    )
+
+    # Log chunk metadata for debugging
+    logger.debug(
+        "Reranking chunks from papers: %s",
+        list(set(chunk.metadata.get("paper_id", "unknown") for chunk in chunks))[:5],
+    )
+
+    # Rerank the chunks
+    logger.info("Calling NVIDIA reranker API with %d chunks...", len(chunks))
+    reranked_chunks = reranker.compress_documents(query=query, documents=chunks)
+
+    for i, doc in enumerate(reranked_chunks[:top_k]):
+        score = doc.metadata.get("relevance_score", "N/A")
+        source = doc.metadata.get("paper_id", "unknown")
+        logger.info("Rank %d | Score: %.4f | Source: %s", i + 1, score, source)
+
+    logger.info(
+        "Successfully reranked chunks. Returning top %d chunks",
+        min(top_k, len(reranked_chunks)),
     )
 
+    # Log which papers the top chunks come from
+    if reranked_chunks and logger.isEnabledFor(logging.DEBUG):
+        top_papers = {}
+        for chunk in reranked_chunks[:top_k]:
+            paper_id = chunk.metadata.get("paper_id", "unknown")
+            top_papers[paper_id] = top_papers.get(paper_id, 0) + 1
+        logger.debug("Top %d chunks distribution by paper: %s", top_k, top_papers)
 
-    return ranked_papers
+    # Return only top_k chunks (convert to list to match return type)
+    return list(reranked_chunks[:top_k])
aiagents4pharma/talk2scholars/tools/pdf/utils/paper_loader.py (new file)
@@ -0,0 +1,123 @@
+"""
+Paper loading utilities for managing PDF documents in vector store.
+"""
+
+import logging
+from typing import Any, Dict
+
+from .batch_processor import add_papers_batch
+
+logger = logging.getLogger(__name__)
+
+
+def load_all_papers(
+    vector_store: Any,  # The Vectorstore instance
+    articles: Dict[str, Any],
+    call_id: str,
+    config: Any,
+    has_gpu: bool,
+) -> None:
+    """
+    Ensure all papers from article_data are loaded into the Milvus vector store.
+    Optimized for GPU/CPU processing.
+
+    Args:
+        vector_store: The Vectorstore instance
+        articles: Dictionary of article data
+        call_id: Call identifier for logging
+        config: Configuration object
+        has_gpu: Whether GPU is available
+    """
+    papers_to_load = []
+    skipped_papers = []
+    already_loaded = []
+
+    # Check which papers need to be loaded
+    for pid, article_info in articles.items():
+        if pid not in vector_store.loaded_papers:
+            pdf_url = article_info.get("pdf_url")
+            if pdf_url:
+                # Prepare tuple for batch loading
+                papers_to_load.append((pid, pdf_url, article_info))
+            else:
+                skipped_papers.append(pid)
+        else:
+            already_loaded.append(pid)
+
+    # Log summary of papers status with hardware info
+    hardware_info = f" (GPU acceleration: {'enabled' if has_gpu else 'disabled'})"
+    logger.info(
+        "%s: Paper loading summary%s - Total: %d, Already loaded: %d, To load: %d, No PDF: %d",
+        call_id,
+        hardware_info,
+        len(articles),
+        len(already_loaded),
+        len(papers_to_load),
+        len(skipped_papers),
+    )
+
+    if skipped_papers:
+        logger.warning(
+            "%s: Skipping %d papers without PDF URLs: %s%s",
+            call_id,
+            len(skipped_papers),
+            skipped_papers[:5],  # Show first 5
+            "..." if len(skipped_papers) > 5 else "",
+        )
+
+    if not papers_to_load:
+        logger.info("%s: All papers with PDFs are already loaded in Milvus", call_id)
+        return
+
+    # Use batch loading with parallel processing for ALL papers at once
+    # Adjust parameters based on hardware capabilities
+    if has_gpu:
+        # GPU can handle more parallel processing
+        max_workers = min(12, max(4, len(papers_to_load)))  # More workers for GPU
+        batch_size = config.get("embedding_batch_size", 2000)  # Larger batches for GPU
+        logger.info(
+            "%s: Using GPU-optimized loading parameters: %d workers, batch size %d",
+            call_id,
+            max_workers,
+            batch_size,
+        )
+    else:
+        # CPU - more conservative parameters
+        max_workers = min(8, max(3, len(papers_to_load)))  # Conservative for CPU
+        batch_size = config.get("embedding_batch_size", 1000)  # Smaller batches for CPU
+        logger.info(
+            "%s: Using CPU-optimized loading parameters: %d workers, batch size %d",
+            call_id,
+            max_workers,
+            batch_size,
+        )
+
+    logger.info(
+        "%s: Loading %d papers in ONE BATCH using %d parallel workers (batch size: %d, %s)",
+        call_id,
+        len(papers_to_load),
+        max_workers,
+        batch_size,
+        "GPU accelerated" if has_gpu else "CPU processing",
+    )
+
+    # This should process ALL papers at once with hardware optimization
+    add_papers_batch(
+        papers_to_add=papers_to_load,
+        vector_store=vector_store.vector_store,  # Pass the LangChain vector store
+        loaded_papers=vector_store.loaded_papers,
+        paper_metadata=vector_store.paper_metadata,
+        documents=vector_store.documents,
+        config=vector_store.config,
+        metadata_fields=vector_store.metadata_fields,
+        has_gpu=vector_store.has_gpu,
+        max_workers=max_workers,
+        batch_size=batch_size,
+    )
+
+    logger.info(
+        "%s: Successfully completed batch loading of all %d papers with %s",
+        call_id,
+        len(papers_to_load),
+        "GPU acceleration" if has_gpu else "CPU processing",
+    )
aiagents4pharma/talk2scholars/tools/pdf/utils/rag_pipeline.py (new file)
@@ -0,0 +1,122 @@
+"""
+RAG pipeline for retrieving and reranking chunks from a vector store.
+"""
+
+import logging
+from typing import Any, List
+
+
+# Retrieval and reranking utilities
+from .nvidia_nim_reranker import rerank_chunks
+from .retrieve_chunks import retrieve_relevant_chunks
+
+logger = logging.getLogger(__name__)
+
+
+def retrieve_and_rerank_chunks(
+    vector_store: Any, query: str, config: Any, call_id: str, has_gpu: bool
+) -> List[Any]:
+    """
+    Traditional RAG pipeline: retrieve chunks from all papers, then rerank.
+    Optimized for GPU/CPU hardware.
+
+    Args:
+        vector_store: Vector store instance
+        query: User query
+
+    Returns:
+        List of reranked chunks
+    """
+    hardware_mode = "GPU-accelerated" if has_gpu else "CPU-optimized"
+    logger.info(
+        "%s: Starting traditional RAG pipeline - retrieve then rerank (%s)",
+        call_id,
+        hardware_mode,
+    )
+
+    # Step 1: Retrieve chunks from ALL papers (cast wide net)
+    # Adjust initial retrieval count based on hardware
+    if has_gpu:
+        # GPU can handle larger initial retrieval efficiently
+        initial_chunks_count = config.get(
+            "initial_retrieval_k", 150
+        )  # Increased for GPU
+        mmr_diversity = config.get(
+            "mmr_diversity", 0.75
+        )  # Slightly more diverse for larger sets
+    else:
+        # CPU - use conservative settings
+        initial_chunks_count = config.get("initial_retrieval_k", 100)  # Original
+        mmr_diversity = config.get("mmr_diversity", 0.8)  # Original
+
+    logger.info(
+        "%s: Step 1 - Retrieving top %d chunks from ALL papers (%s mode)",
+        call_id,
+        initial_chunks_count,
+        hardware_mode,
+    )
+
+    retrieved_chunks = retrieve_relevant_chunks(
+        vector_store,
+        query=query,
+        paper_ids=None,  # No filter - retrieve from all papers
+        top_k=initial_chunks_count,
+        mmr_diversity=mmr_diversity,
+    )
+
+    if not retrieved_chunks:
+        logger.warning("%s: No chunks retrieved from vector store", call_id)
+        return []
+
+    logger.info(
+        "%s: Retrieved %d chunks from %d unique papers using %s",
+        call_id,
+        len(retrieved_chunks),
+        len(
+            set(chunk.metadata.get("paper_id", "unknown") for chunk in retrieved_chunks)
+        ),
+        hardware_mode,
+    )
+
+    # Step 2: Rerank the retrieved chunks
+    final_chunk_count = config.top_k_chunks
+    logger.info(
+        "%s: Step 2 - Reranking %d chunks to get top %d",
+        call_id,
+        len(retrieved_chunks),
+        final_chunk_count,
+    )
+
+    reranked_chunks = rerank_chunks(
+        chunks=retrieved_chunks,
+        query=query,
+        config=config,
+        top_k=final_chunk_count,
+    )
+
+    # Log final results with hardware info
+    final_papers = len(
+        set(chunk.metadata.get("paper_id", "unknown") for chunk in reranked_chunks)
+    )
+
+    logger.info(
+        "%s: Reranking complete using %s. Final %d chunks from %d unique papers",
+        call_id,
+        hardware_mode,
+        len(reranked_chunks),
+        final_papers,
+    )
+
+    # Log performance insights
+    if len(retrieved_chunks) > 0:
+        efficiency = len(reranked_chunks) / len(retrieved_chunks) * 100
+        logger.debug(
+            "%s: Pipeline efficiency: %.1f%% (%d final / %d initial chunks) - %s",
+            call_id,
+            efficiency,
+            len(reranked_chunks),
+            len(retrieved_chunks),
+            hardware_mode,
+        )
+
+    return reranked_chunks