visual-rag-toolkit 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,173 @@
1
+ import logging
2
+ from typing import Any, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
class ThreeStageRetriever:
    """Three-stage, fully server-side retrieval against a Qdrant collection.

    Pipeline (each stage narrows the candidate set of the previous one):
      1. Fast prefetch: mean-pooled query vector vs. the global pooled
         document vectors (``global_vector_name``).
      2. Prefetch: full query token matrix vs. the experimental pooled
         vectors, restricted to stage-1 hits via a ``HasIdCondition``.
      3. Exact (exhaustive) rerank on the full multi-vector field,
         restricted to stage-2 hits.

    Args:
        qdrant_client: Connected Qdrant client.
        collection_name: Target collection name.
        full_vector_name: Full multi-vector field used for the exact rerank.
        experimental_vector_name: Pooled vector field used in stage 2.
        global_vector_name: Single-vector field used in stage 1.
        request_timeout: Per-request timeout (seconds) passed to Qdrant.
        max_retries: Retry budget per request; at least one call is always made.
        retry_sleep: Base sleep (seconds) for exponential backoff between retries.
    """

    def __init__(
        self,
        qdrant_client,
        collection_name: str,
        *,
        full_vector_name: str = "initial",
        experimental_vector_name: str = "experimental_pooling",
        global_vector_name: str = "global_pooling",
        request_timeout: int = 120,
        max_retries: int = 3,
        retry_sleep: float = 0.5,
    ):
        self.client = qdrant_client
        self.collection_name = collection_name
        self.full_vector_name = full_vector_name
        self.experimental_vector_name = experimental_vector_name
        self.global_vector_name = global_vector_name
        self.request_timeout = int(request_timeout)
        self.max_retries = int(max_retries)
        self.retry_sleep = float(retry_sleep)

        # Lazily-inferred collection facts; currently informational only
        # (nothing in this class reads them back yet).
        self._global_is_multivector: Optional[bool] = None
        self._experimental_is_multivector: Optional[bool] = None

    def _retry_call(self, fn):
        """Invoke ``fn`` with exponential backoff; re-raise the last error.

        Bug fix: with ``max_retries <= 0`` the previous loop body never ran,
        so the request was silently skipped and ``None`` returned. We now
        always make at least one attempt.
        """
        import time

        attempts = max(1, self.max_retries)
        last_err: Optional[Exception] = None
        for attempt in range(attempts):
            try:
                return fn()
            except Exception as e:
                last_err = e
                if attempt >= attempts - 1:
                    break
                # Exponential backoff: retry_sleep, 2*retry_sleep, 4*...
                time.sleep(self.retry_sleep * (2**attempt))
        if last_err is not None:
            raise last_err

    def _to_numpy(self, embedding: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Convert a torch tensor or array-like to a numpy array.

        bfloat16 tensors are upcast to float32 first because numpy has no
        native bfloat16 dtype.
        """
        if isinstance(embedding, torch.Tensor):
            if embedding.dtype == torch.bfloat16:
                return embedding.cpu().float().numpy()
            return embedding.cpu().numpy()
        return np.array(embedding, dtype=np.float32)

    def _infer_vector_is_multivector(self, vector_name: str) -> bool:
        """Return True if ``vector_name`` is configured as a multi-vector field.

        Walks the collection config defensively (attribute or dict access),
        returning False whenever the named vector cannot be found or has no
        ``multivector_config``.
        """
        info = self.client.get_collection(self.collection_name)
        cfg = getattr(info, "config", None)
        params = getattr(cfg, "params", None) if cfg is not None else None
        vectors = getattr(params, "vectors", None) if params is not None else None
        v = None
        try:
            if isinstance(vectors, dict):
                v = vectors.get(vector_name)
            else:
                v = vectors[vector_name]
        except Exception:
            v = None
        mv = getattr(v, "multivector_config", None) if v is not None else None
        if mv is None and isinstance(v, dict):
            mv = v.get("multivector_config")
        return mv is not None

    def _and_filter(self, base_filter, ids: List[Any]):
        """Combine ``base_filter`` (may be None) with a HasId restriction."""
        from qdrant_client.http import models as m

        has_id = m.HasIdCondition(has_id=list(ids))
        if base_filter is None:
            return m.Filter(must=[has_id])
        return m.Filter(must=[base_filter, has_id])

    def search_server_side(
        self,
        *,
        query_embedding: Union[torch.Tensor, np.ndarray],
        top_k: int = 100,
        stage1_k: int = 1000,
        stage2_k: int = 300,
        filter_obj=None,
    ) -> List[Dict[str, Any]]:
        """Run the full three-stage funnel server-side.

        Args:
            query_embedding: Query token embeddings, shape [num_tokens, dim].
            top_k: Final number of results (stage 3 limit).
            stage1_k: Candidate count after the global-vector prefetch.
            stage2_k: Candidate count after the experimental-vector prefetch.
            filter_obj: Optional Qdrant filter applied at every stage.

        Returns:
            List of dicts with per-stage scores and payloads; ``score_final``
            equals the stage-3 exact score. Stage-1/2 scores may be None for
            points the earlier stages scored under a different id ordering.
        """
        from qdrant_client.http import models as m

        query_np = self._to_numpy(query_embedding)

        # Stage 1 uses a single pooled query vector; stages 2/3 use all tokens.
        stage1_query = query_np.mean(axis=0).tolist()
        stage2_query = query_np.tolist()
        stage3_query = query_np.tolist()

        logger.info(f"Stage 1: global prefetch {int(stage1_k)}")

        def _do_stage1():
            return self.client.query_points(
                collection_name=self.collection_name,
                query=stage1_query,
                using=self.global_vector_name,
                limit=int(stage1_k),
                query_filter=filter_obj,
                with_payload=False,
                with_vectors=False,
                timeout=self.request_timeout,
            ).points

        s1 = self._retry_call(_do_stage1)
        if not s1:
            return []
        s1_ids = [p.id for p in s1]
        s1_score = {str(p.id): float(p.score) for p in s1}

        logger.info(f"Stage 2: experimental prefetch {int(stage2_k)} (restricted to stage1)")

        stage2_filter = self._and_filter(filter_obj, s1_ids)

        def _do_stage2():
            return self.client.query_points(
                collection_name=self.collection_name,
                query=stage2_query,
                using=self.experimental_vector_name,
                # Never ask for more candidates than stage 1 produced.
                limit=int(min(int(stage2_k), len(s1_ids))),
                query_filter=stage2_filter,
                with_payload=False,
                with_vectors=False,
                timeout=self.request_timeout,
            ).points

        s2 = self._retry_call(_do_stage2)
        if not s2:
            return []
        s2_ids = [p.id for p in s2]
        s2_score = {str(p.id): float(p.score) for p in s2}

        logger.info(f"Stage 3: exact rerank on initial to top {int(top_k)} (restricted to stage2)")

        stage3_filter = self._and_filter(filter_obj, s2_ids)

        def _do_stage3():
            return self.client.query_points(
                collection_name=self.collection_name,
                query=stage3_query,
                using=self.full_vector_name,
                limit=int(top_k),
                query_filter=stage3_filter,
                with_payload=True,
                with_vectors=False,
                # exact=True forces exhaustive scoring (no HNSW approximation).
                search_params=m.SearchParams(exact=True),
                timeout=self.request_timeout,
            ).points

        s3 = self._retry_call(_do_stage3)
        out = []
        for p in s3:
            pid = str(p.id)
            out.append(
                {
                    "id": p.id,
                    "score_stage1": s1_score.get(pid),
                    "score_stage2": s2_score.get(pid),
                    "score_stage3": float(p.score),
                    "score_final": float(p.score),
                    "payload": p.payload,
                }
            )
        return out
@@ -0,0 +1,471 @@
1
+ """
2
+ Two-Stage Retrieval for Scalable Visual Document Search.
3
+
4
+ This is our NOVEL contribution:
5
+ - Stage 1: Fast prefetch using tile-level pooled vectors (mean_pooling)
6
+ - Stage 2: Exact reranking using full multi-vector embeddings (MaxSim)
7
+
8
+ Benefits:
9
+ - 5-10x faster than full MaxSim at scale
10
+ - Maintains 95%+ accuracy compared to full search
11
+ - Memory efficient (don't load all embeddings upfront)
12
+
13
+ Research Context:
14
+ - Different from HPC-ColPali (compression vs pooling)
15
+ - Inspired by text ColBERT two-stage retrieval
16
+ - Novel: tile-level pooling preserves spatial structure
17
+ """
18
+
19
+ import logging
20
+ from typing import Any, Dict, List, Optional, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
class TwoStageRetriever:
    """
    Two-stage visual document retrieval with pooling and reranking.

    Stage 1 (Prefetch):
        Uses tile-level mean-pooled vectors for fast HNSW search.
        Retrieves prefetch_k candidates (e.g., 100-500).

    Stage 2 (Rerank):
        Fetches full multi-vector embeddings for candidates.
        Computes exact MaxSim scores for precise ranking.
        Returns top_k results (e.g., 10).

    Args:
        qdrant_client: Connected Qdrant client
        collection_name: Name of the Qdrant collection
        full_vector_name: Name of full multi-vector field (default: "initial")
        pooled_vector_name: Name of pooled vector field (default: "mean_pooling")
        experimental_vector_name: Alternative pooled field for stage-1 experiments
        global_vector_name: Single global pooled vector field
        request_timeout: Per-request timeout (seconds) passed to Qdrant
        max_retries: Retry budget per request; at least one call is always made
        retry_sleep: Base sleep (seconds) for exponential backoff between retries

    Example:
        >>> retriever = TwoStageRetriever(client, "my_collection")
        >>> results = retriever.search(
        ...     query_embedding=query,
        ...     top_k=10,
        ...     prefetch_k=200,
        ... )
    """

    def __init__(
        self,
        qdrant_client,
        collection_name: str,
        full_vector_name: str = "initial",
        pooled_vector_name: str = "mean_pooling",
        experimental_vector_name: str = "experimental_pooling",
        global_vector_name: str = "global_pooling",
        request_timeout: int = 120,
        max_retries: int = 3,
        retry_sleep: float = 0.5,
    ):
        self.client = qdrant_client
        self.collection_name = collection_name
        self.full_vector_name = full_vector_name
        self.pooled_vector_name = pooled_vector_name
        self.experimental_vector_name = experimental_vector_name
        self.global_vector_name = global_vector_name
        self.request_timeout = int(request_timeout)
        self.max_retries = int(max_retries)
        self.retry_sleep = float(retry_sleep)

    def _retry_call(self, fn):
        """Invoke ``fn`` with exponential backoff; re-raise the last error.

        Bug fix: with ``max_retries <= 0`` the previous loop body never ran,
        so the request was silently skipped and ``None`` returned. We now
        always make at least one attempt.
        """
        import time

        attempts = max(1, self.max_retries)
        last_err: Optional[Exception] = None
        for attempt in range(attempts):
            try:
                return fn()
            except Exception as e:
                last_err = e
                if attempt >= attempts - 1:
                    break
                # Exponential backoff: retry_sleep, 2*retry_sleep, 4*...
                time.sleep(self.retry_sleep * (2**attempt))
        if last_err is not None:
            raise last_err

    def search_server_side(
        self,
        query_embedding: Union[torch.Tensor, np.ndarray],
        top_k: int = 10,
        prefetch_k: Optional[int] = None,
        filter_obj=None,
        stage1_mode: str = "pooled_query_vs_tiles",
    ) -> List[Dict[str, Any]]:
        """
        Two-stage retrieval using Qdrant's native prefetch (all server-side).

        This is MUCH faster than search() because it avoids network transfer
        of large multi-vector embeddings. All computation happens in Qdrant.

        Args:
            query_embedding: Query embeddings [num_tokens, dim]
            top_k: Final number of results
            prefetch_k: Candidates for stage 1 (default: 10x top_k, min 100)
            filter_obj: Qdrant filter (applied to the rerank stage)
            stage1_mode: How to do stage 1 prefetch; one of
                "pooled_query_vs_tiles", "tokens_vs_tiles",
                "pooled_query_vs_experimental", "tokens_vs_experimental",
                "pooled_query_vs_global"

        Returns:
            List of results with scores

        Raises:
            ValueError: If ``stage1_mode`` is not recognized.
        """
        from qdrant_client.http import models

        query_np = self._to_numpy(query_embedding)

        if prefetch_k is None:
            prefetch_k = max(100, top_k * 10)

        # Map stage1_mode to (query shape, vector field). "pooled_*" modes
        # collapse the token matrix to a single mean vector; "tokens_*" modes
        # send the full matrix.
        if stage1_mode == "pooled_query_vs_tiles":
            prefetch_query = query_np.mean(axis=0).tolist()
            prefetch_using = self.pooled_vector_name
        elif stage1_mode == "tokens_vs_tiles":
            prefetch_query = query_np.tolist()
            prefetch_using = self.pooled_vector_name
        elif stage1_mode == "pooled_query_vs_experimental":
            prefetch_query = query_np.mean(axis=0).tolist()
            prefetch_using = self.experimental_vector_name
        elif stage1_mode == "tokens_vs_experimental":
            prefetch_query = query_np.tolist()
            prefetch_using = self.experimental_vector_name
        elif stage1_mode == "pooled_query_vs_global":
            prefetch_query = query_np.mean(axis=0).tolist()
            prefetch_using = self.global_vector_name
        else:
            raise ValueError(f"Unknown stage1_mode: {stage1_mode}")

        rerank_query = query_np.tolist()

        def _do_query():
            return self.client.query_points(
                collection_name=self.collection_name,
                query=rerank_query,
                using=self.full_vector_name,
                limit=top_k,
                query_filter=filter_obj,
                with_payload=True,
                # exact=True forces exhaustive MaxSim over the prefetched set.
                search_params=models.SearchParams(exact=True),
                prefetch=[
                    models.Prefetch(
                        query=prefetch_query,
                        using=prefetch_using,
                        limit=prefetch_k,
                    )
                ],
                timeout=self.request_timeout,
            ).points

        results = self._retry_call(_do_query)

        # Stage-1 scores are not returned by the server-side prefetch API.
        return [
            {
                "id": r.id,
                "score_stage1": None,
                "score_stage2": r.score,
                "score_final": r.score,
                "payload": r.payload,
            }
            for r in results
        ]

    def search(
        self,
        query_embedding: Union[torch.Tensor, np.ndarray],
        top_k: int = 10,
        prefetch_k: Optional[int] = None,
        filter_obj=None,
        use_reranking: bool = True,
        return_embeddings: bool = False,
        stage1_mode: str = "pooled_query_vs_tiles",
    ) -> List[Dict[str, Any]]:
        """
        Two-stage retrieval: prefetch with pooling, rerank with MaxSim.

        Args:
            query_embedding: Query embeddings [num_tokens, dim]
            top_k: Final number of results to return
            prefetch_k: Candidates for stage 1 (default: 10x top_k, min 100)
            filter_obj: Qdrant filter for metadata filtering
            use_reranking: Enable stage 2 reranking (default: True)
            return_embeddings: Include embeddings in results
            stage1_mode:
                - "pooled_query_vs_tiles": pool query to 1×dim and search tile vectors (using="mean_pooling")
                - "tokens_vs_tiles": search tile vectors with full query tokens (using="mean_pooling")
                - "pooled_query_vs_global": pool query to 1×dim and search global pooled doc vectors (using="global_pooling")

        Returns:
            List of results with scores and metadata:
            [
                {
                    "id": point_id,
                    "score_stage1": float, # Pooled similarity
                    "score_stage2": float, # MaxSim (if reranking)
                    "score_final": float,  # Final score used for ranking
                    "payload": {...},      # Document metadata
                },
                ...
            ]
        """
        query_np = self._to_numpy(query_embedding)

        if prefetch_k is None:
            prefetch_k = max(100, top_k * 10)

        # Stage 1: Prefetch with pooled vectors
        logger.info(f"🔍 Stage 1: Prefetching {prefetch_k} candidates ({stage1_mode})")
        candidates = self._stage1_prefetch(
            query_np=query_np,
            top_k=prefetch_k,
            filter_obj=filter_obj,
            stage1_mode=stage1_mode,
        )

        if not candidates:
            logger.warning("No candidates found in stage 1")
            return []

        logger.info(f"✅ Stage 1: Retrieved {len(candidates)} candidates")

        # Stage 2: Rerank with full embeddings. Skipped when reranking is off
        # or the candidate set is already no bigger than top_k.
        if use_reranking and len(candidates) > top_k:
            logger.info("🎯 Stage 2: Reranking with MaxSim...")
            results = self._stage2_rerank(
                query_np=query_np,
                candidates=candidates,
                top_k=top_k,
                return_embeddings=return_embeddings,
            )
            logger.info(f"✅ Stage 2: Reranked to top {len(results)} results")
        else:
            results = candidates[:top_k]
            for r in results:
                r["score_final"] = r["score_stage1"]
            logger.info(f"⏭️ Skipping reranking, returning top {len(results)}")

        return results

    def search_single_stage(
        self,
        query_embedding: Union[torch.Tensor, np.ndarray],
        top_k: int = 10,
        filter_obj=None,
        use_pooling: bool = False,
    ) -> List[Dict[str, Any]]:
        """
        Single-stage search (either pooled or full multi-vector).

        Args:
            query_embedding: Query embeddings
            top_k: Number of results
            filter_obj: Qdrant filter
            use_pooling: Use pooled vectors (faster) or full (more accurate)

        Returns:
            List of results
        """
        query_np = self._to_numpy(query_embedding)

        if use_pooling:
            # Pool query and search pooled vectors
            query_pooled = query_np.mean(axis=0)
            vector_name = self.pooled_vector_name
            query_vector = query_pooled.tolist()
            logger.info(f"🔍 Pooled search: {vector_name}")
        else:
            # Native multi-vector search
            vector_name = self.full_vector_name
            query_vector = query_np.tolist()
            logger.info(f"🎯 Multi-vector search: {vector_name}")

        # Consistency fix: use the configured timeout and the shared retry
        # wrapper like every other query path (was a hard-coded timeout=120
        # with no retries).
        def _do_query():
            return self.client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                using=vector_name,
                query_filter=filter_obj,
                limit=top_k,
                with_payload=True,
                with_vectors=False,
                timeout=self.request_timeout,
            ).points

        results = self._retry_call(_do_query)

        return [
            {
                "id": r.id,
                "score_stage1": r.score,
                "score_final": r.score,
                "payload": r.payload,
            }
            for r in results
        ]

    def _stage1_prefetch(
        self,
        query_np: np.ndarray,
        top_k: int,
        filter_obj=None,
        stage1_mode: str = "pooled_query_vs_tiles",
    ) -> List[Dict[str, Any]]:
        """Stage 1: Prefetch candidates using the configured pooled field.

        Raises:
            ValueError: If ``stage1_mode`` is not recognized.
        """
        if stage1_mode == "pooled_query_vs_tiles":
            query_vector = query_np.mean(axis=0).tolist()
            vector_name = self.pooled_vector_name
        elif stage1_mode == "tokens_vs_tiles":
            query_vector = query_np.tolist()
            vector_name = self.pooled_vector_name
        elif stage1_mode == "pooled_query_vs_global":
            query_vector = query_np.mean(axis=0).tolist()
            vector_name = self.global_vector_name
        else:
            raise ValueError(f"Unknown stage1_mode: {stage1_mode}")

        def _do_query():
            return self.client.query_points(
                collection_name=self.collection_name,
                query=query_vector,
                using=vector_name,
                query_filter=filter_obj,
                limit=top_k,
                with_payload=True,
                with_vectors=False,
                timeout=self.request_timeout,
            ).points

        results = self._retry_call(_do_query)

        return [
            {
                "id": r.id,
                "score_stage1": r.score,
                "payload": r.payload,
            }
            for r in results
        ]

    def _stage2_rerank(
        self,
        query_np: np.ndarray,
        candidates: List[Dict[str, Any]],
        top_k: int,
        return_embeddings: bool = False,
    ) -> List[Dict[str, Any]]:
        """Stage 2: Rerank with full multi-vector MaxSim scoring."""
        from visual_rag.embedding.pooling import compute_maxsim_score

        # Fetch full embeddings for candidates
        candidate_ids = [c["id"] for c in candidates]

        def _do_retrieve():
            return self.client.retrieve(
                collection_name=self.collection_name,
                ids=candidate_ids,
                with_payload=False,
                with_vectors=[self.full_vector_name],
                timeout=self.request_timeout,
            )

        points = self._retry_call(_do_retrieve)

        # Build ID -> full embedding map; points missing the vector are
        # handled below by falling back to their stage-1 score.
        id_to_embedding = {}
        for point in points:
            if point.vector and self.full_vector_name in point.vector:
                id_to_embedding[point.id] = np.array(
                    point.vector[self.full_vector_name], dtype=np.float32
                )

        # Compute MaxSim scores
        reranked = []
        for candidate in candidates:
            point_id = candidate["id"]
            doc_embedding = id_to_embedding.get(point_id)

            if doc_embedding is None:
                # Fallback to stage 1 score
                candidate["score_stage2"] = candidate["score_stage1"]
                candidate["score_final"] = candidate["score_stage1"]
            else:
                # Compute exact MaxSim
                maxsim_score = compute_maxsim_score(query_np, doc_embedding)
                candidate["score_stage2"] = maxsim_score
                candidate["score_final"] = maxsim_score

                if return_embeddings:
                    candidate["embedding"] = doc_embedding

            reranked.append(candidate)

        # Sort by final score (descending)
        reranked.sort(key=lambda x: x["score_final"], reverse=True)

        return reranked[:top_k]

    def _to_numpy(self, embedding: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Convert embedding to numpy array (bfloat16 upcast to float32)."""
        if isinstance(embedding, torch.Tensor):
            if embedding.dtype == torch.bfloat16:
                return embedding.cpu().float().numpy()
            return embedding.cpu().numpy()
        return np.array(embedding, dtype=np.float32)

    def build_filter(
        self,
        year: Optional[Any] = None,
        source: Optional[str] = None,
        district: Optional[str] = None,
        filename: Optional[str] = None,
        has_text: Optional[bool] = None,
    ):
        """
        Build Qdrant filter from parameters.

        Supports single values or lists (using MatchAny). Returns None when
        no condition is given so callers can pass the result straight through
        as ``query_filter``.
        """
        from qdrant_client.models import FieldCondition, Filter, MatchAny, MatchValue

        conditions = []

        if year is not None:
            # Years may arrive as strings (e.g. from a query string); coerce.
            if isinstance(year, list):
                year_values = [int(y) if isinstance(y, str) else y for y in year]
                conditions.append(FieldCondition(key="year", match=MatchAny(any=year_values)))
            else:
                year_value = int(year) if isinstance(year, str) else year
                conditions.append(FieldCondition(key="year", match=MatchValue(value=year_value)))

        if source is not None:
            if isinstance(source, list):
                conditions.append(FieldCondition(key="source", match=MatchAny(any=source)))
            else:
                conditions.append(FieldCondition(key="source", match=MatchValue(value=source)))

        if district is not None:
            if isinstance(district, list):
                conditions.append(FieldCondition(key="district", match=MatchAny(any=district)))
            else:
                conditions.append(FieldCondition(key="district", match=MatchValue(value=district)))

        if filename is not None:
            if isinstance(filename, list):
                conditions.append(FieldCondition(key="filename", match=MatchAny(any=filename)))
            else:
                conditions.append(FieldCondition(key="filename", match=MatchValue(value=filename)))

        if has_text is not None:
            conditions.append(FieldCondition(key="has_text", match=MatchValue(value=has_text)))

        return Filter(must=conditions) if conditions else None
@@ -0,0 +1,19 @@
1
"""
Visualization module - Saliency maps and attention visualization.

This module provides:
- Saliency map generation showing query-document relevance
- Attention heatmaps for visual token analysis
"""

# Re-export the public saliency helpers so callers can import them directly
# from ``visual_rag.visualization``.
from visual_rag.visualization.saliency import (
    create_saliency_overlay,
    generate_saliency_map,
    visualize_search_results,
)

# Explicit public API of this package.
__all__ = [
    "generate_saliency_map",
    "create_saliency_overlay",
    "visualize_search_results",
]