visual_rag_toolkit-0.1.1-py3-none-any.whl

Files changed (59)
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
benchmarks/quick_test.py
@@ -0,0 +1,566 @@
#!/usr/bin/env python3
"""
Quick Benchmark - Validate retrieval quality with ViDoRe data.

This script:
1. Downloads samples from ViDoRe (with ground truth relevance)
2. Embeds with ColSmol-500M
3. Tests retrieval strategies (exhaustive vs two-stage)
4. Computes metrics: NDCG@K, MRR@K, Recall@K
5. Compares speed and quality

Usage:
    python quick_test.py --samples 100
    python quick_test.py --samples 500 --skip-exhaustive  # Faster
"""

import sys
import time
import argparse
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional

# Add parent directory to Python path (so we can import visual_rag)
# This allows running the script directly without pip install
_script_dir = Path(__file__).parent
_parent_dir = _script_dir.parent
if str(_parent_dir) not in sys.path:
    sys.path.insert(0, str(_parent_dir))

import numpy as np
from tqdm import tqdm

# Visual RAG imports (now works without pip install)
from visual_rag.embedding import VisualEmbedder
from visual_rag.embedding.pooling import (
    tile_level_mean_pooling,
    compute_maxsim_score,
)

# Optional: datasets for ViDoRe
try:
    from datasets import load_dataset as hf_load_dataset
    HAS_DATASETS = True
except ImportError:
    HAS_DATASETS = False

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def load_vidore_sample(num_samples: int = 100) -> List[Dict]:
    """
    Load sample from ViDoRe DocVQA with ground truth.

    Each sample has a query and its relevant document (1:1 mapping).
    This allows computing retrieval metrics.
    """
    if not HAS_DATASETS:
        logger.error("Install datasets: pip install datasets")
        sys.exit(1)

    logger.info(f"šŸ“„ Loading {num_samples} samples from ViDoRe DocVQA...")

    ds = hf_load_dataset("vidore/docvqa_test_subsampled", split="test")

    samples = []
    for i, example in enumerate(ds):
        if i >= num_samples:
            break

        samples.append({
            "id": i,
            "doc_id": f"doc_{i}",
            "query_id": f"q_{i}",
            "image": example.get("image", example.get("page_image")),
            "query": example.get("query", example.get("question", "")),
            # Ground truth: query i is relevant to doc i
            "relevant_doc": f"doc_{i}",
        })

    logger.info(f"āœ… Loaded {len(samples)} samples with ground truth")
    return samples


def embed_all(
    samples: List[Dict],
    model_name: str = "vidore/colSmol-500M",
) -> Dict[str, Any]:
    """Embed all documents and queries."""
    logger.info(f"\nšŸ¤– Loading model: {model_name}")
    embedder = VisualEmbedder(model_name=model_name)

    images = [s["image"] for s in samples]
    queries = [s["query"] for s in samples if s["query"]]

    # Embed images
    logger.info(f"šŸŽØ Embedding {len(images)} documents...")
    start_time = time.time()

    embeddings, token_infos = embedder.embed_images(
        images, batch_size=4, return_token_info=True
    )

    doc_embed_time = time.time() - start_time
    logger.info(f" Time: {doc_embed_time:.2f}s ({doc_embed_time/len(images)*1000:.1f}ms/doc)")

    # Process embeddings: extract visual tokens + tile-level pooling
    doc_data = {}
    for i, (emb, token_info) in enumerate(zip(embeddings, token_infos)):
        if hasattr(emb, 'cpu'):
            emb = emb.cpu()
        emb_np = emb.numpy() if hasattr(emb, 'numpy') else np.array(emb)

        # Extract visual tokens only (filter special tokens)
        visual_indices = token_info["visual_token_indices"]
        visual_emb = emb_np[visual_indices].astype(np.float32)

        # Tile-level pooling
        n_rows = token_info.get("n_rows", 4)
        n_cols = token_info.get("n_cols", 3)
        num_tiles = n_rows * n_cols + 1 if n_rows and n_cols else 13

        tile_pooled = tile_level_mean_pooling(visual_emb, num_tiles, patches_per_tile=64)

        doc_data[f"doc_{i}"] = {
            "embedding": visual_emb,
            "pooled": tile_pooled,
            "num_visual_tokens": len(visual_indices),
            "num_tiles": tile_pooled.shape[0],
        }

    # Embed queries
    logger.info(f"šŸ” Embedding {len(queries)} queries...")
    start_time = time.time()

    query_data = {}
    for i, query in enumerate(tqdm(queries, desc="Queries")):
        q_emb = embedder.embed_query(query)
        if hasattr(q_emb, 'cpu'):
            q_emb = q_emb.cpu()
        q_np = q_emb.numpy() if hasattr(q_emb, 'numpy') else np.array(q_emb)
        query_data[f"q_{i}"] = q_np.astype(np.float32)

    query_embed_time = time.time() - start_time

    return {
        "docs": doc_data,
        "queries": query_data,
        "samples": samples,
        "doc_embed_time": doc_embed_time,
        "query_embed_time": query_embed_time,
        "model": model_name,
    }

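# Illustrative sketch (editor's note, not part of the packaged script): tile-level
# mean pooling as it is assumed to behave above -- group the per-patch vectors of
# each image tile and average them, turning a (num_tiles * patches_per_tile, dim)
# token matrix into (num_tiles, dim) summary vectors. The toolkit's
# tile_level_mean_pooling may differ in details (e.g. handling of a partial last
# tile); this reference is only an aid for reading the code.
def _tile_mean_pool_sketch(visual_emb: np.ndarray, num_tiles: int, patches_per_tile: int = 64) -> np.ndarray:
    pooled = []
    for t in range(num_tiles):
        tile = visual_emb[t * patches_per_tile:(t + 1) * patches_per_tile]
        if len(tile) > 0:
            pooled.append(tile.mean(axis=0))  # one mean vector per tile
    return np.stack(pooled).astype(np.float32)
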
157
+ def search_exhaustive(query_emb: np.ndarray, docs: Dict, top_k: int = 10) -> List[Dict]:
158
+ """Exhaustive MaxSim search over all documents."""
159
+ scores = []
160
+ for doc_id, doc in docs.items():
161
+ score = compute_maxsim_score(query_emb, doc["embedding"])
162
+ scores.append({"id": doc_id, "score": score})
163
+
164
+ scores.sort(key=lambda x: x["score"], reverse=True)
165
+ return scores[:top_k]
166
+
167
+
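# Illustrative sketch (editor's note, not part of the packaged script): the
# ColBERT-style MaxSim score that compute_maxsim_score is assumed to implement.
# For each query token, take the maximum dot-product similarity over all document
# tokens, then sum over query tokens. The toolkit's version may normalize vectors
# or scores differently.
def _maxsim_sketch(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    sims = query_emb @ doc_emb.T          # (num_query_tokens, num_doc_tokens)
    return float(sims.max(axis=1).sum())  # best doc token per query token, summed
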
def search_two_stage(
    query_emb: np.ndarray,
    docs: Dict,
    prefetch_k: int = 20,
    top_k: int = 10,
) -> List[Dict]:
    """
    Two-stage retrieval with tile-level pooling.

    Stage 1: Fast prefetch using tile-pooled vectors
    Stage 2: Exact MaxSim reranking on candidates
    """
    # Stage 1: Tile-level pooled search
    query_pooled = query_emb.mean(axis=0)
    query_pooled = query_pooled / (np.linalg.norm(query_pooled) + 1e-8)

    stage1_scores = []
    for doc_id, doc in docs.items():
        doc_pooled = doc["pooled"]
        doc_norm = doc_pooled / (np.linalg.norm(doc_pooled, axis=1, keepdims=True) + 1e-8)
        tile_sims = np.dot(doc_norm, query_pooled)
        score = float(tile_sims.max())
        stage1_scores.append({"id": doc_id, "score": score})

    stage1_scores.sort(key=lambda x: x["score"], reverse=True)
    candidates = stage1_scores[:prefetch_k]

    # Stage 2: Exact MaxSim on candidates
    reranked = []
    for cand in candidates:
        doc_id = cand["id"]
        score = compute_maxsim_score(query_emb, docs[doc_id]["embedding"])
        reranked.append({"id": doc_id, "score": score, "stage1_rank": stage1_scores.index(cand) + 1})

    reranked.sort(key=lambda x: x["score"], reverse=True)
    return reranked[:top_k]

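# Usage note (editor's addition): with `data` returned by embed_all() above, the
# two search functions are interchangeable and both return a list of
# {"id": ..., "score": ...} dicts sorted by score, e.g.
#
#     q = data["queries"]["q_0"]
#     full = search_exhaustive(q, data["docs"], top_k=5)
#     fast = search_two_stage(q, data["docs"], prefetch_k=20, top_k=5)
#
# run_benchmark() below does this for every query (the exhaustive pass is optional)
# and compares the two rankings with the metrics defined next.
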
206
+ def compute_metrics(
207
+ results: Dict[str, List[Dict]],
208
+ samples: List[Dict],
209
+ k_values: List[int] = [1, 3, 5, 7, 10],
210
+ ) -> Dict[str, float]:
211
+ """
212
+ Compute retrieval metrics.
213
+
214
+ Since ViDoRe has 1:1 query-doc mapping (1 relevant doc per query):
215
+ - Recall@K (Hit Rate): Is the relevant doc in top-K? (0 or 1)
216
+ - Precision@K: (# relevant in top-K) / K
217
+ - MRR@K: 1/rank if found in top-K, else 0
218
+ - NDCG@K: DCG / IDCG with binary relevance
219
+ """
    metrics = {}

    # Also track per-query ranks for analysis
    all_ranks = []

    for k in k_values:
        recalls = []
        precisions = []
        mrrs = []
        ndcgs = []

        for sample in samples:
            query_id = sample["query_id"]
            relevant_doc = sample["relevant_doc"]

            if query_id not in results:
                continue

            ranking = results[query_id][:k]
            ranked_ids = [r["id"] for r in ranking]

            # Find rank of relevant doc (1-indexed, 0 if not found)
            rank = 0
            for i, doc_id in enumerate(ranked_ids):
                if doc_id == relevant_doc:
                    rank = i + 1
                    break

            # Recall@K (Hit Rate): 1 if found in top-K
            found = 1.0 if rank > 0 else 0.0
            recalls.append(found)

            # Precision@K: (# relevant found) / K
            # With 1 relevant doc: 1/K if found, 0 otherwise
            precision = found / k
            precisions.append(precision)

            # MRR@K: 1/rank if found
            mrr = 1.0 / rank if rank > 0 else 0.0
            mrrs.append(mrr)

            # NDCG@K (binary relevance)
            # DCG = 1/log2(rank+1) if found, 0 otherwise
            # IDCG = 1/log2(2) = 1 (best case: relevant at rank 1)
            dcg = 1.0 / np.log2(rank + 1) if rank > 0 else 0.0
            idcg = 1.0
            ndcg = dcg / idcg
            ndcgs.append(ndcg)

            # Track actual rank for analysis (only for k=10)
            if k == max(k_values):
                full_ranking = results[query_id]
                full_rank = 0
                for i, r in enumerate(full_ranking):
                    if r["id"] == relevant_doc:
                        full_rank = i + 1
                        break
                all_ranks.append(full_rank)

        metrics[f"Recall@{k}"] = np.mean(recalls)
        metrics[f"P@{k}"] = np.mean(precisions)
        metrics[f"MRR@{k}"] = np.mean(mrrs)
        metrics[f"NDCG@{k}"] = np.mean(ndcgs)

    # Add summary stats
    if all_ranks:
        found_ranks = [r for r in all_ranks if r > 0]
        metrics["avg_rank"] = np.mean(found_ranks) if found_ranks else float('inf')
        metrics["median_rank"] = np.median(found_ranks) if found_ranks else float('inf')
        metrics["not_found"] = sum(1 for r in all_ranks if r == 0)

    return metrics

def run_benchmark(
    data: Dict,
    skip_exhaustive: bool = False,
    prefetch_k: Optional[int] = None,
    top_k: int = 10,
) -> Dict[str, Dict]:
    """Run retrieval benchmark with metrics."""
    docs = data["docs"]
    queries = data["queries"]
    samples = data["samples"]
    num_docs = len(docs)

    # Auto-set prefetch_k to be meaningful (default: 20, or 20% of docs if >100 docs)
    if prefetch_k is None:
        if num_docs <= 100:
            prefetch_k = 20  # Default: prefetch 20, rerank to top-10
        else:
            prefetch_k = max(20, min(100, int(num_docs * 0.2)))  # 20% for larger collections

    # Ensure prefetch_k < num_docs for meaningful two-stage comparison
    if prefetch_k >= num_docs:
        logger.warning(f"āš ļø prefetch_k={prefetch_k} >= num_docs={num_docs}")
        logger.warning(f" Two-stage will fetch ALL docs (same as exhaustive)")
        logger.warning(f" Use --samples > {prefetch_k * 3} for meaningful comparison")

    logger.info(f"šŸ“Š Benchmark config: {num_docs} docs, prefetch_k={prefetch_k}, top_k={top_k}")
    logger.info(f" (Both methods return top-{top_k} results - realistic retrieval scenario)")

    results = {}

    # Two-stage retrieval (tile-pooled prefetch + MaxSim rerank)
    logger.info(f"\nšŸ”¬ Running Two-Stage retrieval (prefetch top-{prefetch_k}, rerank to top-{top_k})...")
    two_stage_results = {}
    two_stage_times = []

    for sample in tqdm(samples, desc="Two-Stage"):
        query_id = sample["query_id"]
        query_emb = queries[query_id]

        start = time.time()
        ranking = search_two_stage(query_emb, docs, prefetch_k=prefetch_k, top_k=top_k)
        two_stage_times.append(time.time() - start)

        two_stage_results[query_id] = ranking

    two_stage_metrics = compute_metrics(two_stage_results, samples)
    two_stage_metrics["avg_time_ms"] = np.mean(two_stage_times) * 1000
    two_stage_metrics["prefetch_k"] = prefetch_k
    two_stage_metrics["top_k"] = top_k
    results["two_stage"] = two_stage_metrics

    # Exhaustive search (baseline)
    if not skip_exhaustive:
        logger.info(f"šŸ”¬ Running Exhaustive MaxSim (searches ALL {num_docs} docs, returns top-{top_k})...")
        exhaustive_results = {}
        exhaustive_times = []

        for sample in tqdm(samples, desc="Exhaustive"):
            query_id = sample["query_id"]
            query_emb = queries[query_id]

            start = time.time()
            ranking = search_exhaustive(query_emb, docs, top_k=top_k)
            exhaustive_times.append(time.time() - start)

            exhaustive_results[query_id] = ranking

        exhaustive_metrics = compute_metrics(exhaustive_results, samples)
        exhaustive_metrics["avg_time_ms"] = np.mean(exhaustive_times) * 1000
        exhaustive_metrics["top_k"] = top_k
        results["exhaustive"] = exhaustive_metrics

    return results

def print_results(data: Dict, benchmark_results: Dict, show_precision: bool = False):
    """Print benchmark results."""
    print("\n" + "=" * 80)
    print("šŸ“Š BENCHMARK RESULTS")
    print("=" * 80)

    num_docs = len(data['docs'])
    print(f"\nšŸ¤– Model: {data['model']}")
    print(f"šŸ“„ Documents: {num_docs}")
    print(f"šŸ” Queries: {len(data['queries'])}")

    # Embedding stats
    sample_doc = list(data['docs'].values())[0]
    print(f"\nšŸ“ Embedding (after visual token filtering):")
    print(f" Visual tokens per doc: {sample_doc['num_visual_tokens']}")
    print(f" Tile-pooled vectors: {sample_doc['num_tiles']}")

    if "two_stage" in benchmark_results:
        prefetch_k = benchmark_results["two_stage"].get("prefetch_k", "?")
        print(f" Two-stage prefetch_k: {prefetch_k} (of {num_docs} docs)")

    # Method labels - clearer naming
    def get_label(method):
        if method == "two_stage":
            return "Pooled+Rerank"  # Tile-pooled prefetch + MaxSim rerank
        else:
            return "Full MaxSim"  # Exhaustive MaxSim on all docs

    # Recall / Hit Rate table
    print(f"\nšŸŽÆ RECALL (Hit Rate) @ K:")
    print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
    print(f" {'-'*60}")

    for method, metrics in benchmark_results.items():
        print(f" {get_label(method):<20} "
              f"{metrics.get('Recall@1', 0):>8.3f} "
              f"{metrics.get('Recall@3', 0):>8.3f} "
              f"{metrics.get('Recall@5', 0):>8.3f} "
              f"{metrics.get('Recall@7', 0):>8.3f} "
              f"{metrics.get('Recall@10', 0):>8.3f}")

    # Precision table (optional)
    if show_precision:
        print(f"\nšŸ“ PRECISION @ K:")
        print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
        print(f" {'-'*60}")

        for method, metrics in benchmark_results.items():
            print(f" {get_label(method):<20} "
                  f"{metrics.get('P@1', 0):>8.3f} "
                  f"{metrics.get('P@3', 0):>8.3f} "
                  f"{metrics.get('P@5', 0):>8.3f} "
                  f"{metrics.get('P@7', 0):>8.3f} "
                  f"{metrics.get('P@10', 0):>8.3f}")

    # NDCG table
    print(f"\nšŸ“ˆ NDCG @ K:")
    print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
    print(f" {'-'*60}")

    for method, metrics in benchmark_results.items():
        print(f" {get_label(method):<20} "
              f"{metrics.get('NDCG@1', 0):>8.3f} "
              f"{metrics.get('NDCG@3', 0):>8.3f} "
              f"{metrics.get('NDCG@5', 0):>8.3f} "
              f"{metrics.get('NDCG@7', 0):>8.3f} "
              f"{metrics.get('NDCG@10', 0):>8.3f}")

    # MRR table
    print(f"\nšŸ” MRR @ K:")
    print(f" {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
    print(f" {'-'*60}")

    for method, metrics in benchmark_results.items():
        print(f" {get_label(method):<20} "
              f"{metrics.get('MRR@1', 0):>8.3f} "
              f"{metrics.get('MRR@3', 0):>8.3f} "
              f"{metrics.get('MRR@5', 0):>8.3f} "
              f"{metrics.get('MRR@7', 0):>8.3f} "
              f"{metrics.get('MRR@10', 0):>8.3f}")

    # Speed comparison
    top_k = benchmark_results.get("two_stage", benchmark_results.get("exhaustive", {})).get("top_k", 10)
    print(f"\nā±ļø SPEED (both return top-{top_k} results):")
    print(f" {'Method':<20} {'Time (ms)':>12} {'Docs searched':>15}")
    print(f" {'-'*50}")

    for method, metrics in benchmark_results.items():
        if method == "two_stage":
            searched = metrics.get("prefetch_k", "?")
            label = f"{searched} (stage-1)"
        else:
            searched = num_docs
            label = f"{searched} (all)"
        print(f" {get_label(method):<20} {metrics.get('avg_time_ms', 0):>12.2f} {label:>15}")

    # Comparison summary
    if "exhaustive" in benchmark_results and "two_stage" in benchmark_results:
        ex = benchmark_results["exhaustive"]
        ts = benchmark_results["two_stage"]

        print(f"\nšŸ’” POOLED+RERANK vs FULL MAXSIM:")

        for k in [1, 5, 10]:
            ex_recall = ex.get(f"Recall@{k}", 0)
            ts_recall = ts.get(f"Recall@{k}", 0)
            if ex_recall > 0:
                retention = ts_recall / ex_recall * 100
                print(f" • Recall@{k} retention: {retention:.1f}% ({ts_recall:.3f} vs {ex_recall:.3f})")

        speedup = ex["avg_time_ms"] / ts["avg_time_ms"] if ts["avg_time_ms"] > 0 else 0
        print(f" • Speedup: {speedup:.1f}x")

        # Rank stats with explanation
        if "avg_rank" in ts:
            prefetch_k = ts.get("prefetch_k", "?")
            top_k = ts.get("top_k", 10)
            not_found = ts.get("not_found", 0)
            total = len(data["queries"])

            print(f"\nšŸ“Š POOLED+RERANK STATISTICS:")
            print(f" Stage-1 (pooled prefetch):")
            print(f" • Searches top-{prefetch_k} candidates using tile-pooled vectors")
            print(f" • {total - not_found}/{total} queries ({100 - not_found/total*100:.1f}%) had relevant doc in prefetch")
            print(f" • {not_found}/{total} queries ({not_found/total*100:.1f}%) missed (relevant doc ranked >{prefetch_k})")
            print(f" Stage-2 (MaxSim reranking):")
            print(f" • Reranks prefetch candidates with exact MaxSim")
            print(f" • Returns final top-{top_k} results")
            if ts['avg_rank'] < float('inf'):
                print(f" • Avg rank of relevant doc (when found): {ts['avg_rank']:.1f}")
                print(f" • Median rank: {ts['median_rank']:.1f}")
            print(f"\n šŸ’” The {not_found/total*100:.1f}% miss rate is for stage-1 prefetch.")
            print(f" Final Recall@{top_k} shows how many relevant docs ARE in top-{top_k} results.")

    print("\n" + "=" * 80)
    print("āœ… Benchmark complete!")

def main():
    parser = argparse.ArgumentParser(
        description="Quick benchmark for visual-rag-toolkit",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--samples", type=int, default=100,
        help="Number of samples (default: 100)"
    )
    parser.add_argument(
        "--model", type=str, default="vidore/colSmol-500M",
        help="Model: vidore/colSmol-500M (default), vidore/colpali-v1.3"
    )
    parser.add_argument(
        "--prefetch-k", type=int, default=None,
        help="Stage 1 candidates for two-stage (default: 20 for <=100 docs, auto for larger)"
    )
    parser.add_argument(
        "--skip-exhaustive", action="store_true",
        help="Skip exhaustive baseline (faster)"
    )
    parser.add_argument(
        "--show-precision", action="store_true",
        help="Show Precision@K metrics (hidden by default)"
    )
    parser.add_argument(
        "--top-k", type=int, default=10,
        help="Number of results to return (default: 10, realistic retrieval scenario)"
    )

    args = parser.parse_args()

    print("\n" + "=" * 70)
    print("🧪 VISUAL RAG TOOLKIT - RETRIEVAL BENCHMARK")
    print("=" * 70)

    # Load samples
    samples = load_vidore_sample(args.samples)

    if not samples:
        logger.error("No samples loaded!")
        sys.exit(1)

    # Embed all
    data = embed_all(samples, args.model)

    # Run benchmark
    benchmark_results = run_benchmark(
        data,
        skip_exhaustive=args.skip_exhaustive,
        prefetch_k=args.prefetch_k,
        top_k=args.top_k,
    )

    # Print results
    print_results(data, benchmark_results, show_precision=args.show_precision)


if __name__ == "__main__":
    main()