aiagents4pharma 1.40.1__py3-none-any.whl → 1.42.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +1 -1
  2. aiagents4pharma/talk2knowledgegraphs/configs/tools/multimodal_subgraph_extraction/default.yaml +37 -0
  3. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/ols_terms/default.yaml +3 -0
  4. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/reactome_pathways/default.yaml +3 -0
  5. aiagents4pharma/talk2knowledgegraphs/configs/utils/enrichments/uniprot_proteins/default.yaml +6 -0
  6. aiagents4pharma/talk2knowledgegraphs/configs/utils/pubchem_utils/default.yaml +5 -0
  7. aiagents4pharma/talk2knowledgegraphs/milvus_data_dump.py +752 -350
  8. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +4 -0
  9. aiagents4pharma/talk2scholars/configs/tools/question_and_answer/default.yaml +44 -4
  10. aiagents4pharma/talk2scholars/tests/test_nvidia_nim_reranker.py +127 -0
  11. aiagents4pharma/talk2scholars/tests/test_pdf_answer_formatter.py +66 -0
  12. aiagents4pharma/talk2scholars/tests/test_pdf_batch_processor.py +101 -0
  13. aiagents4pharma/talk2scholars/tests/test_pdf_collection_manager.py +150 -0
  14. aiagents4pharma/talk2scholars/tests/test_pdf_document_processor.py +69 -0
  15. aiagents4pharma/talk2scholars/tests/test_pdf_generate_answer.py +75 -0
  16. aiagents4pharma/talk2scholars/tests/test_pdf_gpu_detection.py +140 -0
  17. aiagents4pharma/talk2scholars/tests/test_pdf_paper_loader.py +116 -0
  18. aiagents4pharma/talk2scholars/tests/test_pdf_rag_pipeline.py +98 -0
  19. aiagents4pharma/talk2scholars/tests/test_pdf_retrieve_chunks.py +197 -0
  20. aiagents4pharma/talk2scholars/tests/test_pdf_singleton_manager.py +156 -0
  21. aiagents4pharma/talk2scholars/tests/test_pdf_vector_normalization.py +121 -0
  22. aiagents4pharma/talk2scholars/tests/test_pdf_vector_store.py +434 -0
  23. aiagents4pharma/talk2scholars/tests/test_question_and_answer_tool.py +89 -509
  24. aiagents4pharma/talk2scholars/tests/test_tool_helper_utils.py +34 -89
  25. aiagents4pharma/talk2scholars/tools/paper_download/download_biorxiv_input.py +8 -6
  26. aiagents4pharma/talk2scholars/tools/paper_download/download_medrxiv_input.py +6 -4
  27. aiagents4pharma/talk2scholars/tools/pdf/question_and_answer.py +74 -40
  28. aiagents4pharma/talk2scholars/tools/pdf/utils/__init__.py +26 -1
  29. aiagents4pharma/talk2scholars/tools/pdf/utils/answer_formatter.py +62 -0
  30. aiagents4pharma/talk2scholars/tools/pdf/utils/batch_processor.py +200 -0
  31. aiagents4pharma/talk2scholars/tools/pdf/utils/collection_manager.py +172 -0
  32. aiagents4pharma/talk2scholars/tools/pdf/utils/document_processor.py +76 -0
  33. aiagents4pharma/talk2scholars/tools/pdf/utils/generate_answer.py +14 -14
  34. aiagents4pharma/talk2scholars/tools/pdf/utils/get_vectorstore.py +63 -0
  35. aiagents4pharma/talk2scholars/tools/pdf/utils/gpu_detection.py +154 -0
  36. aiagents4pharma/talk2scholars/tools/pdf/utils/nvidia_nim_reranker.py +60 -40
  37. aiagents4pharma/talk2scholars/tools/pdf/utils/paper_loader.py +123 -0
  38. aiagents4pharma/talk2scholars/tools/pdf/utils/rag_pipeline.py +122 -0
  39. aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py +162 -40
  40. aiagents4pharma/talk2scholars/tools/pdf/utils/singleton_manager.py +140 -0
  41. aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py +40 -78
  42. aiagents4pharma/talk2scholars/tools/pdf/utils/vector_normalization.py +159 -0
  43. aiagents4pharma/talk2scholars/tools/pdf/utils/vector_store.py +277 -96
  44. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +12 -9
  45. aiagents4pharma/talk2scholars/tools/s2/query_dataframe.py +0 -1
  46. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +9 -8
  47. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +5 -5
  48. {aiagents4pharma-1.40.1.dist-info → aiagents4pharma-1.42.0.dist-info}/METADATA +52 -126
  49. {aiagents4pharma-1.40.1.dist-info → aiagents4pharma-1.42.0.dist-info}/RECORD +52 -25
  50. aiagents4pharma/talk2scholars/tests/test_nvidia_nim_reranker_utils.py +0 -28
  51. {aiagents4pharma-1.40.1.dist-info → aiagents4pharma-1.42.0.dist-info}/WHEEL +0 -0
  52. {aiagents4pharma-1.40.1.dist-info → aiagents4pharma-1.42.0.dist-info}/licenses/LICENSE +0 -0
  53. {aiagents4pharma-1.40.1.dist-info → aiagents4pharma-1.42.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2scholars/tools/pdf/utils/retrieve_chunks.py
@@ -1,14 +1,14 @@
 """
-Retrieve relevant chunks from a vector store using MMR (Maximal Marginal Relevance).
+Retrieve relevant chunks from a Milvus vector store using MMR (Maximal Marginal Relevance).
+Follows traditional RAG pipeline - retrieve first, then rerank.
+With automatic GPU/CPU search parameter optimization.
 """
 
 import logging
 import os
 from typing import List, Optional
 
-import numpy as np
 from langchain_core.documents import Document
-from langchain_core.vectorstores.utils import maximal_marginal_relevance
 
 
 # Set up logging with configurable level
@@ -19,65 +19,187 @@ logger.setLevel(getattr(logging, log_level))
 
 
 def retrieve_relevant_chunks(
-    self,
+    vector_store,
     query: str,
     paper_ids: Optional[List[str]] = None,
-    top_k: int = 25,
-    mmr_diversity: float = 1.00,
+    top_k: int = 100,  # Increased default to cast wider net before reranking
+    mmr_diversity: float = 0.8,  # Slightly reduced for better diversity
 ) -> List[Document]:
     """
     Retrieve the most relevant chunks for a query using maximal marginal relevance.
+    Automatically uses GPU-optimized search parameters if GPU is available.
+
+    In the traditional RAG pipeline, this should retrieve chunks from ALL available papers,
+    not just pre-selected ones. The reranker will then select the best chunks.
 
     Args:
+        vector_store: The Milvus vector store instance
         query: Query string
-        paper_ids: Optional list of paper IDs to filter by
-        top_k: Number of chunks to retrieve
-        mmr_diversity: Diversity parameter for MMR (higher = more diverse)
+        paper_ids: Optional list of paper IDs to filter by (default: None - search all papers)
+        top_k: Number of chunks to retrieve (default: 100 for reranking pipeline)
+        mmr_diversity: Diversity parameter for MMR (0=max diversity, 1=max relevance)
 
     Returns:
         List of document chunks
     """
-    if not self.vector_store:
-        logger.error("Failed to build vector store")
+    if not vector_store:
+        logger.error("Vector store is not initialized")
         return []
 
+    # Check if vector store has GPU capabilities
+    has_gpu = getattr(vector_store, "has_gpu", False)
+    search_mode = "GPU-accelerated" if has_gpu else "CPU"
+
+    # Prepare filter for paper_ids if provided
+    filter_dict = None
     if paper_ids:
+        logger.warning(
+            "Paper IDs filter provided. Traditional RAG pipeline typically "
+            "retrieves from ALL papers first. "
+            "Consider removing paper_ids filter for better results."
+        )
        logger.info("Filtering retrieval to papers: %s", paper_ids)
+        filter_dict = {"paper_id": paper_ids}
+    else:
+        logger.info(
+            "Retrieving chunks from ALL papers (traditional RAG approach) using %s search",
+            search_mode,
+        )
+
+    # Use Milvus's built-in MMR search with optimized parameters
+    logger.info(
+        "Performing %s MMR search with query: '%s', k=%d, diversity=%.2f",
+        search_mode,
+        query[:50] + "..." if len(query) > 50 else query,
+        top_k,
+        mmr_diversity,
+    )
 
-    # Step 1: Embed the query
-    logger.info("Embedding query using model: %s", type(self.embedding_model).__name__)
-    query_embedding = np.array(self.embedding_model.embed_query(query))
+    # Fetch more candidates for better MMR results
+    # Adjust fetch_k based on available hardware
+    if has_gpu:
+        # GPU can handle larger candidate sets efficiently
+        fetch_k = min(top_k * 6, 800)  # Increased for GPU
+        logger.debug("Using GPU-optimized fetch_k: %d", fetch_k)
+    else:
+        # CPU - more conservative to avoid performance issues
+        fetch_k = min(top_k * 4, 500)  # Original conservative approach
+        logger.debug("Using CPU-optimized fetch_k: %d", fetch_k)
 
-    # Step 2: Filter relevant documents
-    all_docs = [
-        doc
-        for doc in self.documents.values()
-        if not paper_ids or doc.metadata["paper_id"] in paper_ids
-    ]
+    # Get search parameters from vector store if available
+    search_params = getattr(vector_store, "search_params", None)
 
-    if not all_docs:
-        logger.warning("No documents found after filtering by paper_ids.")
-        return []
+    if search_params:
+        logger.debug("Using hardware-optimized search parameters: %s", search_params)
+    else:
+        logger.debug("Using default search parameters (no hardware optimization)")
 
-    # Step 3: Retrieve or compute embeddings for all documents using cache
-    logger.info("Retrieving embeddings for %d chunks...", len(all_docs))
-    all_embeddings = []
-    for doc in all_docs:
-        doc_id = f"{doc.metadata['paper_id']}_{doc.metadata['chunk_id']}"
-        if doc_id not in self.embeddings:
-            logger.info("Embedding missing chunk %s", doc_id)
-            emb = self.embedding_model.embed_documents([doc.page_content])[0]
-            self.embeddings[doc_id] = emb
-        all_embeddings.append(self.embeddings[doc_id])
-
-    # Step 4: Apply MMR
-    mmr_indices = maximal_marginal_relevance(
-        query_embedding,
-        all_embeddings,
+    # Perform MMR search - let the vector store handle search_params internally
+    # Don't pass search_params explicitly to avoid conflicts
+    results = vector_store.max_marginal_relevance_search(
+        query=query,
         k=top_k,
+        fetch_k=fetch_k,
         lambda_mult=mmr_diversity,
+        filter=filter_dict,
     )
 
-    results = [all_docs[i] for i in mmr_indices]
-    logger.info("Retrieved %d chunks using MMR", len(results))
+    logger.info(
+        "Retrieved %d chunks using %s MMR from Milvus", len(results), search_mode
+    )
+
+    # Log some details about retrieved chunks for debugging
+    if results and logger.isEnabledFor(logging.DEBUG):
+        paper_counts = {}
+        for doc in results:
+            paper_id = doc.metadata.get("paper_id", "unknown")
+            paper_counts[paper_id] = paper_counts.get(paper_id, 0) + 1
+
+        logger.debug(
+            "%s retrieval - chunks per paper: %s",
+            search_mode,
+            dict(sorted(paper_counts.items(), key=lambda x: x[1], reverse=True)[:10]),
+        )
+        logger.debug(
+            "%s retrieval - total papers represented: %d",
+            search_mode,
+            len(paper_counts),
+        )
+
     return results
+
+
+def retrieve_relevant_chunks_with_scores(
+    vector_store,
+    query: str,
+    paper_ids: Optional[List[str]] = None,
+    top_k: int = 100,
+    score_threshold: float = 0.0,
+) -> List[tuple[Document, float]]:
+    """
+    Retrieve chunks with similarity scores, optimized for GPU/CPU.
+
+    Args:
+        vector_store: The Milvus vector store instance
+        query: Query string
+        paper_ids: Optional list of paper IDs to filter by
+        top_k: Number of chunks to retrieve
+        score_threshold: Minimum similarity score threshold
+
+    Returns:
+        List of (document, score) tuples
+    """
+    if not vector_store:
+        logger.error("Vector store is not initialized")
+        return []
+
+    has_gpu = getattr(vector_store, "has_gpu", False)
+    search_mode = "GPU-accelerated" if has_gpu else "CPU"
+
+    # Prepare filter
+    filter_dict = None
+    if paper_ids:
+        filter_dict = {"paper_id": paper_ids}
+
+    logger.info(
+        "Performing %s similarity search with scores: query='%s', k=%d, threshold=%.3f",
+        search_mode,
+        query[:50] + "..." if len(query) > 50 else query,
+        top_k,
+        score_threshold,
+    )
+
+    # Check hardware optimization status instead of unused search_params
+    has_optimization = hasattr(vector_store, "has_gpu") and vector_store.has_gpu
+
+    if has_optimization:
+        logger.debug("GPU-accelerated similarity search enabled")
+    else:
+        logger.debug("Standard CPU similarity search")
+
+    if hasattr(vector_store, "similarity_search_with_score"):
+        # Don't pass search_params to avoid conflicts
+        results = vector_store.similarity_search_with_score(
+            query=query,
+            k=top_k,
+            filter=filter_dict,
+        )
+
+        # Filter by score threshold
+        filtered_results = [
+            (doc, score) for doc, score in results if score >= score_threshold
+        ]
+
+        logger.info(
+            "%s search with scores retrieved %d/%d chunks above threshold %.3f",
+            search_mode,
+            len(filtered_results),
+            len(results),
+            score_threshold,
+        )
+
+        return filtered_results
+
+    raise NotImplementedError(
+        "Vector store does not support similarity_search_with_score"
+    )
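For orientation, a minimal caller sketch of the two retrieval functions above (not part of the diff; emb_model and cfg are placeholders for the embedding model and Hydra config the Q&A tool supplies at runtime, and get_vectorstore is the new helper module added in this release, called here with the same keyword arguments tool_helper.py uses):

from aiagents4pharma.talk2scholars.tools.pdf.utils.get_vectorstore import get_vectorstore
from aiagents4pharma.talk2scholars.tools.pdf.utils.retrieve_chunks import (
    retrieve_relevant_chunks,
    retrieve_relevant_chunks_with_scores,
)

# Shared Milvus-backed store; emb_model and cfg are app-specific placeholders
vs = get_vectorstore(embedding_model=emb_model, config=cfg)

# Wide-net MMR retrieval across ALL papers; reranking happens downstream
chunks = retrieve_relevant_chunks(vs, "Which assay measured binding affinity?", top_k=100)

# Scored variant: returns (Document, score) tuples above the threshold
scored = retrieve_relevant_chunks_with_scores(
    vs, "Which assay measured binding affinity?", top_k=50, score_threshold=0.5
)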
aiagents4pharma/talk2scholars/tools/pdf/utils/singleton_manager.py
@@ -0,0 +1,140 @@
+"""
+Singleton manager for Milvus connections and vector stores.
+Handles connection reuse, event loops, and GPU detection caching.
+"""
+
+import asyncio
+import logging
+import threading
+from typing import Any, Dict
+
+from langchain_core.embeddings import Embeddings
+from langchain_milvus import Milvus
+from pymilvus import connections, db, utility
+from pymilvus.exceptions import MilvusException
+
+from .gpu_detection import detect_nvidia_gpu
+
+logger = logging.getLogger(__name__)
+
+
+class VectorstoreSingleton:
+    """Singleton manager for Milvus connections and vector stores."""
+
+    _instance = None
+    _lock = threading.Lock()
+    _connections = {}  # Store connections by connection string
+    _vector_stores = {}  # Store vector stores by collection name
+    _event_loops = {}  # Store event loops by thread ID
+    _gpu_detected = None  # Cache GPU detection result
+
+    def __new__(cls):
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def get_event_loop(self) -> asyncio.AbstractEventLoop:
+        """Get or create event loop for current thread."""
+        thread_id = threading.get_ident()
+
+        if thread_id not in self._event_loops:
+            try:
+                loop = asyncio.get_event_loop()
+                if loop.is_closed():
+                    raise RuntimeError("Event loop is closed")
+            except RuntimeError:
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+            self._event_loops[thread_id] = loop
+            logger.info("Created new event loop for thread %s", thread_id)
+
+        return self._event_loops[thread_id]
+
+    def detect_gpu_once(self) -> bool:
+        """Detect GPU availability once and cache the result."""
+        if self._gpu_detected is None:
+            self._gpu_detected = detect_nvidia_gpu()
+            gpu_status = "available" if self._gpu_detected else "not available"
+            logger.info("GPU detection completed: NVIDIA GPU %s", gpu_status)
+        return self._gpu_detected
+
+    def get_connection(self, host: str, port: int, db_name: str) -> str:
+        """Get or create a Milvus connection."""
+        conn_key = f"{host}:{port}/{db_name}"
+
+        if conn_key not in self._connections:
+            try:
+                # Check if already connected
+                if connections.has_connection("default"):
+                    connections.remove_connection("default")
+
+                # Connect to Milvus
+                connections.connect(
+                    alias="default",
+                    host=host,
+                    port=port,
+                )
+                logger.info("Connected to Milvus at %s:%s", host, port)
+
+                # Check if database exists, create if not
+                existing_dbs = db.list_database()
+                if db_name not in existing_dbs:
+                    db.create_database(db_name)
+                    logger.info("Created database: %s", db_name)
+
+                # Use the database
+                db.using_database(db_name)
+                logger.info("Using database: %s", db_name)
+                logger.debug(
+                    "Milvus DB switched to: %s, available collections: %s",
+                    db_name,
+                    utility.list_collections(),
+                )
+
+                self._connections[conn_key] = "default"
+
+            except MilvusException as e:
+                logger.error("Failed to connect to Milvus: %s", e)
+                raise
+
+        return self._connections[conn_key]
+
+    def get_vector_store(
+        self,
+        collection_name: str,
+        embedding_model: Embeddings,
+        connection_args: Dict[str, Any],
+    ) -> Milvus:
+        """Get or create a vector store for a collection."""
+        if collection_name not in self._vector_stores:
+            # Ensure event loop exists for this thread
+            self.get_event_loop()
+
+            # Create LangChain Milvus instance with explicit URI format
+            # This ensures LangChain uses the correct host
+            milvus_uri = f"http://{connection_args['host']}:{connection_args['port']}"
+
+            vector_store = Milvus(
+                embedding_function=embedding_model,
+                collection_name=collection_name,
+                connection_args={
+                    "uri": milvus_uri,  # Use URI format instead of host/port
+                    "host": connection_args["host"],
+                    "port": connection_args["port"],
+                },
+                text_field="text",
+                auto_id=False,
+                drop_old=False,
+                consistency_level="Strong",
+            )
+
+            self._vector_stores[collection_name] = vector_store
+            logger.info(
+                "Created new vector store for collection: %s with URI: %s",
+                collection_name,
+                milvus_uri,
+            )
+
+        return self._vector_stores[collection_name]
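A short sketch of how the singleton is meant to be used (illustrative host, port, database, and collection names; the stub embeddings class stands in for the app's real embedding model, and a reachable Milvus server is assumed):

from langchain_core.embeddings import Embeddings

from aiagents4pharma.talk2scholars.tools.pdf.utils.singleton_manager import (
    VectorstoreSingleton,
)

class _StubEmbeddings(Embeddings):
    """Illustrative stand-in for the application's embedding model."""
    def embed_documents(self, texts):
        return [[0.0] * 8 for _ in texts]
    def embed_query(self, text):
        return [0.0] * 8

singleton = VectorstoreSingleton()
assert singleton is VectorstoreSingleton()  # __new__ always hands back one instance

# Connections are keyed by host:port/db; repeat calls reuse the cached alias
singleton.get_connection("localhost", 19530, "papers_db")

# Vector stores are cached per collection name
store = singleton.get_vector_store(
    "pdf_chunks", _StubEmbeddings(), {"host": "localhost", "port": 19530}
)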
aiagents4pharma/talk2scholars/tools/pdf/utils/tool_helper.py
@@ -1,25 +1,26 @@
 """
-Helper class for PDF Q&A tool orchestration: state validation, vectorstore init,
-paper loading, reranking, and answer formatting.
+Helper class for question and answer tool in PDF processing.
 """
 
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict
 
-from .generate_answer import generate_answer
-from .nvidia_nim_reranker import rank_papers_by_query
-from .vector_store import Vectorstore
+
+from .get_vectorstore import get_vectorstore
 
 logger = logging.getLogger(__name__)
 
 
 class QAToolHelper:
-    """Encapsulates helper routines for the PDF Question & Answer tool."""
+    """
+    Encapsulates helper routines for the PDF Question & Answer tool.
+    Enhanced with automatic GPU/CPU detection and optimization.
+    """
 
     def __init__(self) -> None:
-        self.prebuilt_vector_store: Optional[Vectorstore] = None
         self.config: Any = None
         self.call_id: str = ""
+        self.has_gpu: bool = False  # Track GPU availability
        logger.debug("Initialized QAToolHelper")
 
    def start_call(self, config: Any, call_id: str) -> None:
@@ -47,79 +48,40 @@ class QAToolHelper:
             raise ValueError(msg)
         return text_emb, llm, articles
 
-    def init_vector_store(self, emb_model: Any) -> Vectorstore:
-        """Return shared or new Vectorstore instance."""
-        if self.prebuilt_vector_store is not None:
-            logger.info("Using shared pre-built vector store from memory")
-            return self.prebuilt_vector_store
-        vs = Vectorstore(embedding_model=emb_model, config=self.config)
-        logger.info("Initialized new vector store with provided configuration")
-        self.prebuilt_vector_store = vs
-        return vs
+    def init_vector_store(self, emb_model: Any) -> Any:
+        """Get the singleton Milvus vector store instance with GPU/CPU optimization."""
+        logger.info(
+            "%s: Getting singleton vector store instance with hardware optimization",
+            self.call_id,
+        )
+        vs = get_vectorstore(embedding_model=emb_model, config=self.config)
 
-    def load_candidate_papers(
-        self,
-        vs: Vectorstore,
-        articles: Dict[str, Any],
-        candidates: List[str],
-    ) -> None:
-        """Ensure each candidate paper is loaded into the vector store."""
-        for pid in candidates:
-            if pid not in vs.loaded_papers:
-                pdf_url = articles.get(pid, {}).get("pdf_url")
-                if not pdf_url:
-                    continue
-                try:
-                    vs.add_paper(pid, pdf_url, articles[pid])
-                except (IOError, ValueError) as exc:
-                    logger.warning(
-                        "%s: Error loading paper %s: %s", self.call_id, pid, exc
-                    )
+        # Track GPU availability from vector store
+        self.has_gpu = getattr(vs, "has_gpu", False)
+        hardware_type = "GPU-accelerated" if self.has_gpu else "CPU-only"
 
-    def run_reranker(
-        self,
-        vs: Vectorstore,
-        query: str,
-        candidates: List[str],
-    ) -> List[str]:
-        """Rank papers by relevance and return filtered paper IDs."""
-        try:
-            ranked = rank_papers_by_query(
-                vs, query, self.config, top_k=self.config.top_k_papers
-            )
-            logger.info("%s: Papers after NVIDIA reranking: %s", self.call_id, ranked)
-            return [pid for pid in ranked if pid in candidates]
-        except (ValueError, RuntimeError) as exc:
-            logger.error("%s: NVIDIA reranker failed: %s", self.call_id, exc)
+        logger.info(
+            "%s: Vector store initialized (%s mode)",
+            self.call_id,
+            hardware_type,
+        )
+
+        # Log hardware-specific configuration
+        if hasattr(vs, "index_params"):
+            index_type = vs.index_params.get("index_type", "Unknown")
             logger.info(
-                "%s: Falling back to all %d candidate papers",
+                "%s: Using %s index type for %s processing",
                 self.call_id,
-                len(candidates),
+                index_type,
+                hardware_type,
             )
-            return candidates
 
-    def format_answer(
-        self,
-        question: str,
-        chunks: List[Any],
-        llm: Any,
-        articles: Dict[str, Any],
-    ) -> str:
-        """Generate the final answer text with source attributions."""
-        result = generate_answer(question, chunks, llm, self.config)
-        answer = result.get("output_text", "No answer generated.")
-        titles: Dict[str, str] = {}
-        for pid in result.get("papers_used", []):
-            if pid in articles:
-                titles[pid] = articles[pid].get("Title", "Unknown paper")
-        if titles:
-            srcs = "\n\nSources:\n" + "\n".join(f"- {t}" for t in titles.values())
-        else:
-            srcs = ""
-        logger.info(
-            "%s: Generated answer using %d chunks from %d papers",
-            self.call_id,
-            len(chunks),
-            len(titles),
-        )
-        return f"{answer}{srcs}"
+        return vs
+
+    def get_hardware_stats(self) -> Dict[str, Any]:
+        """Get current hardware configuration stats for monitoring."""
+        return {
+            "gpu_available": self.has_gpu,
+            "hardware_mode": "GPU-accelerated" if self.has_gpu else "CPU-only",
+            "call_id": self.call_id,
+        }
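QAToolHelper is now a thin wrapper: the removed load_candidate_papers, run_reranker, and format_answer duties appear to migrate to the new paper_loader, rag_pipeline, nvidia_nim_reranker, and answer_formatter utilities listed above. A usage sketch (config and emb_model are placeholders for the runtime state the tool passes in; the call id is illustrative):

from aiagents4pharma.talk2scholars.tools.pdf.utils.tool_helper import QAToolHelper

helper = QAToolHelper()
helper.start_call(config, "qa-demo-1")    # config: the tool's Hydra config (placeholder)
vs = helper.init_vector_store(emb_model)  # emb_model: embeddings from tool state (placeholder)
stats = helper.get_hardware_stats()
# e.g. {"gpu_available": False, "hardware_mode": "CPU-only", "call_id": "qa-demo-1"}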
aiagents4pharma/talk2scholars/tools/pdf/utils/vector_normalization.py
@@ -0,0 +1,159 @@
+"""
+Vector normalization utilities for GPU COSINE similarity support.
+Since GPU indexes don't support COSINE distance, we normalize vectors
+and use IP (Inner Product) distance instead.
+"""
+
+import logging
+from typing import List, Union
+
+import numpy as np
+from langchain_core.embeddings import Embeddings
+
+logger = logging.getLogger(__name__)
+
+
+def normalize_vector(vector: Union[List[float], np.ndarray]) -> List[float]:
+    """
+    Normalize a single vector to unit length.
+
+    Args:
+        vector: Input vector as list or numpy array
+
+    Returns:
+        Normalized vector as list
+    """
+    vector = np.asarray(vector, dtype=np.float32)
+    norm = np.linalg.norm(vector)
+
+    if norm == 0:
+        logger.warning("Zero vector encountered during normalization")
+        return vector.tolist()
+
+    normalized = vector / norm
+    return normalized.tolist()
+
+
+def normalize_vectors_batch(vectors: List[List[float]]) -> List[List[float]]:
+    """
+    Normalize a batch of vectors to unit length.
+
+    Args:
+        vectors: List of vectors
+
+    Returns:
+        List of normalized vectors
+    """
+    if not vectors:
+        return vectors
+
+    # Convert to numpy array for efficient computation
+    vectors_array = np.asarray(vectors, dtype=np.float32)
+
+    # Calculate norms for each vector
+    norms = np.linalg.norm(vectors_array, axis=1, keepdims=True)
+
+    # Handle zero vectors
+    zero_mask = norms.flatten() == 0
+    if np.any(zero_mask):
+        logger.warning(
+            "Found %d zero vectors during batch normalization", np.sum(zero_mask)
+        )
+        norms[zero_mask] = 1.0  # Avoid division by zero
+
+    # Normalize
+    normalized = vectors_array / norms
+
+    return normalized.tolist()
+
+
+class NormalizingEmbeddings(Embeddings):
+    """
+    Wrapper around an embedding model that automatically normalizes outputs.
+    This is needed for GPU indexes when using COSINE similarity.
+    """
+
+    def __init__(self, embedding_model: Embeddings, normalize_for_gpu: bool = True):
+        """
+        Initialize the normalizing wrapper.
+
+        Args:
+            embedding_model: The underlying embedding model
+            normalize_for_gpu: Whether to normalize embeddings (for GPU compatibility)
+        """
+        self.embedding_model = embedding_model
+        self.normalize_for_gpu = normalize_for_gpu
+
+        if normalize_for_gpu:
+            logger.info(
+                "Embedding model wrapped with normalization for GPU compatibility"
+            )
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed documents and optionally normalize."""
+        embeddings = self.embedding_model.embed_documents(texts)
+
+        if self.normalize_for_gpu:
+            embeddings = normalize_vectors_batch(embeddings)
+            logger.debug("Normalized %d document embeddings for GPU", len(embeddings))
+
+        return embeddings
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed query and optionally normalize."""
+        embedding = self.embedding_model.embed_query(text)
+
+        if self.normalize_for_gpu:
+            embedding = normalize_vector(embedding)
+            logger.debug("Normalized query embedding for GPU")
+
+        return embedding
+
+    def __getattr__(self, name):
+        """Delegate other attributes to the underlying model."""
+        return getattr(self.embedding_model, name)
+
+
+def should_normalize_vectors(has_gpu: bool, use_cosine: bool) -> bool:
+    """
+    Determine if vectors should be normalized based on hardware and similarity metric.
+
+    Args:
+        has_gpu: Whether GPU is being used
+        use_cosine: Whether COSINE similarity is desired
+
+    Returns:
+        True if vectors should be normalized
+    """
+    needs_normalization = has_gpu and use_cosine
+
+    if needs_normalization:
+        logger.info(
+            "Vector normalization ENABLED: GPU detected with COSINE similarity request"
+        )
+    else:
+        logger.info(
+            "Vector normalization DISABLED: GPU=%s, COSINE=%s", has_gpu, use_cosine
+        )
+
+    return needs_normalization
+
+
+def wrap_embedding_model_if_needed(
+    embedding_model: Embeddings, has_gpu: bool, use_cosine: bool = True
+) -> Embeddings:
+    """
+    Wrap embedding model with normalization if needed for GPU compatibility.
+
+    Args:
+        embedding_model: Original embedding model
+        has_gpu: Whether GPU is being used
+        use_cosine: Whether COSINE similarity is desired
+
+    Returns:
+        Original or wrapped embedding model
+    """
+    if should_normalize_vectors(has_gpu, use_cosine):
+        return NormalizingEmbeddings(embedding_model, normalize_for_gpu=True)
+
+    return embedding_model
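The module rests on the identity cos(a, b) = dot(a/||a||, b/||b||): once embeddings are unit length, an IP (inner product) index ranks results exactly as a COSINE index would. A quick self-contained check using the new helper (arbitrary example vectors):

import numpy as np

from aiagents4pharma.talk2scholars.tools.pdf.utils.vector_normalization import (
    normalize_vector,
)

a, b = [3.0, 4.0], [1.0, 2.0]
cosine = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
ip_on_unit_vectors = np.dot(normalize_vector(a), normalize_vector(b))
assert np.isclose(cosine, ip_on_unit_vectors)  # IP over normalized vectors == COSINE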