visual_rag_toolkit-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
benchmarks/run_vidore.py
@@ -0,0 +1,513 @@
+ #!/usr/bin/env python3
+ """
+ ViDoRe Benchmark Evaluation Script
+
+ Evaluates visual document retrieval on the ViDoRe benchmark datasets.
+
+ Usage:
+     # Single dataset
+     python run_vidore.py --dataset vidore/docvqa_test_subsampled
+
+     # All datasets
+     python run_vidore.py --all
+
+     # With two-stage retrieval
+     python run_vidore.py --dataset vidore/docvqa_test_subsampled --two-stage
+ """
+
+ import argparse
+ import json
+ import time
+ import logging
+ from pathlib import Path
+ from typing import List, Dict, Any, Optional
+
+ import numpy as np
+ from tqdm import tqdm
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ # ViDoRe benchmark datasets
+ # Official leaderboard: https://huggingface.co/spaces/vidore/vidore-leaderboard
+ VIDORE_DATASETS = {
+     # === RECOMMENDED FOR QUICK TESTING (smaller, faster) ===
+     "docvqa": "vidore/docvqa_test_subsampled",      # ~500 queries, Document VQA
+     "infovqa": "vidore/infovqa_test_subsampled",    # ~500 queries, Infographics
+     "tabfquad": "vidore/tabfquad_test_subsampled",  # ~500 queries, Tables
+
+     # === FULL EVALUATION ===
+     "tatdqa": "vidore/tatdqa_test",                 # ~1500 queries, Financial tables
+     "arxivqa": "vidore/arxivqa_test_subsampled",    # ~500 queries, Scientific papers
+     "shift": "vidore/shiftproject_test",            # ~500 queries, Sustainability reports
+ }
+
+ # Aliases for convenience
+ QUICK_DATASETS = ["docvqa", "infovqa"]  # Fast testing
+ ALL_DATASETS = list(VIDORE_DATASETS.keys())
+
+
+ def load_dataset(dataset_name: str) -> Dict[str, Any]:
+     """Load a ViDoRe dataset from HuggingFace."""
+     try:
+         from datasets import load_dataset
+     except ImportError:
+         raise ImportError("datasets library required. Install with: pip install datasets")
+
+     logger.info(f"Loading dataset: {dataset_name}")
+
+     # Load dataset
+     ds = load_dataset(dataset_name, split="test")
+
+     # Extract queries and documents
+     # ViDoRe format: each example has query, image, and relevant doc info
+     queries = []
+     documents = []
+     qrels = {}  # query_id -> {doc_id: relevance}
+
+     for idx, example in enumerate(tqdm(ds, desc="Loading data")):
+         query_id = f"q_{idx}"
+         doc_id = f"d_{idx}"
+
+         # Get query text
+         query_text = example.get("query", example.get("question", ""))
+         queries.append({
+             "id": query_id,
+             "text": query_text,
+         })
+
+         # Get document image
+         image = example.get("image", example.get("page_image"))
+         documents.append({
+             "id": doc_id,
+             "image": image,
+         })
+
+         # Relevance (self-document is relevant)
+         qrels[query_id] = {doc_id: 1}
+
+     logger.info(f"Loaded {len(queries)} queries and {len(documents)} documents")
+
+     return {
+         "queries": queries,
+         "documents": documents,
+         "qrels": qrels,
+     }
+
+
+ def embed_documents(
+     documents: List[Dict],
+     embedder,
+     batch_size: int = 4,
+     return_pooled: bool = False,
+ ) -> Dict[str, np.ndarray]:
+     """
+     Embed all documents.
+
+     Args:
+         documents: List of {id, image} dicts
+         embedder: VisualEmbedder instance
+         batch_size: Batch size for embedding
+         return_pooled: Also return tile-level pooled embeddings (for two-stage)
+
+     Returns:
+         doc_embeddings dict, and optionally pooled_embeddings dict
+     """
+     from visual_rag.embedding.pooling import tile_level_mean_pooling
+
+     logger.info(f"Embedding {len(documents)} documents...")
+
+     images = [doc["image"] for doc in documents]
+
+     # Get embeddings with token info for proper pooling
+     embeddings, token_infos = embedder.embed_images(
+         images, batch_size=batch_size, return_token_info=True
+     )
+
+     doc_embeddings = {}
+     pooled_embeddings = {} if return_pooled else None
+
+     for doc, emb, token_info in zip(documents, embeddings, token_infos):
+         if hasattr(emb, "numpy"):
+             emb_np = emb.numpy()
+         elif hasattr(emb, "cpu"):
+             emb_np = emb.cpu().numpy()
+         else:
+             emb_np = np.array(emb)
+
+         doc_embeddings[doc["id"]] = emb_np.astype(np.float32)
+
+         # Compute tile-level pooling (NOVEL approach)
+         if return_pooled:
+             n_rows = token_info.get("n_rows", 4)
+             n_cols = token_info.get("n_cols", 3)
+             num_tiles = n_rows * n_cols + 1 if n_rows and n_cols else 13
+
+             pooled = tile_level_mean_pooling(emb_np, num_tiles, patches_per_tile=64)
+             pooled_embeddings[doc["id"]] = pooled.astype(np.float32)
+
+     if return_pooled:
+         return doc_embeddings, pooled_embeddings
+     return doc_embeddings
+
+
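The pooled vectors built above come from tile_level_mean_pooling in visual_rag/embedding/pooling.py (file 38 in this release), which is not shown in this hunk. For orientation only, a minimal numpy sketch of the shape contract implied by the call above (patches_per_tile=64, one 128-dim vector per tile); the packaged function may handle special tokens and tile layout differently:

import numpy as np

def tile_mean_pool_sketch(emb, num_tiles, patches_per_tile=64):
    # Illustrative stand-in, not the packaged function: average the patch
    # vectors of each tile, ignoring any trailing non-patch tokens.
    patch_tokens = emb[: num_tiles * patches_per_tile]
    return patch_tokens.reshape(num_tiles, patches_per_tile, -1).mean(axis=1)

page_emb = np.random.rand(13 * 64, 128).astype(np.float32)  # e.g. 12 grid tiles + 1 global tile
pooled = tile_mean_pool_sketch(page_emb, num_tiles=13)
print(pooled.shape)                                          # (13, 128)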
+ def embed_queries(
+     queries: List[Dict],
+     embedder,
+ ) -> Dict[str, np.ndarray]:
+     """Embed all queries."""
+     logger.info(f"Embedding {len(queries)} queries...")
+
+     query_embeddings = {}
+     for query in tqdm(queries, desc="Embedding queries"):
+         emb = embedder.embed_query(query["text"])
+         if hasattr(emb, "numpy"):
+             emb = emb.numpy()
+         elif hasattr(emb, "cpu"):
+             emb = emb.cpu().numpy()
+         query_embeddings[query["id"]] = np.array(emb, dtype=np.float32)
+
+     return query_embeddings
+
+
+ def compute_maxsim(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
+     """Compute ColBERT-style MaxSim score."""
+     # Normalize
+     query_norm = query_emb / (np.linalg.norm(query_emb, axis=1, keepdims=True) + 1e-8)
+     doc_norm = doc_emb / (np.linalg.norm(doc_emb, axis=1, keepdims=True) + 1e-8)
+
+     # Compute similarity matrix
+     sim_matrix = np.dot(query_norm, doc_norm.T)
+
+     # MaxSim: max per query token, then sum
+     max_sims = sim_matrix.max(axis=1)
+     return float(max_sims.sum())
+
+
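As a quick sanity check of the MaxSim scoring above, a toy case with two query tokens and three document tokens (values are illustrative); compute_maxsim returns the same value up to the 1e-8 normalization epsilon:

import numpy as np

q = np.array([[1.0, 0.0],
              [0.0, 1.0]], dtype=np.float32)   # 2 query tokens, dim 2
d = np.array([[1.0, 0.0],
              [0.7, 0.7],
              [0.0, 1.0]], dtype=np.float32)   # 3 doc tokens, dim 2

sim = q @ d.T                  # [2, 3] similarity matrix
print(sim.max(axis=1).sum())   # 2.0: each query token's best-matching doc token scores 1.0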
+ def search_exhaustive(
+     query_emb: np.ndarray,
+     doc_embeddings: Dict[str, np.ndarray],
+     top_k: int = 10,
+ ) -> List[Dict]:
+     """Exhaustive MaxSim search over all documents."""
+     scores = []
+     for doc_id, doc_emb in doc_embeddings.items():
+         score = compute_maxsim(query_emb, doc_emb)
+         scores.append({"id": doc_id, "score": score})
+
+     # Sort by score
+     scores.sort(key=lambda x: x["score"], reverse=True)
+     return scores[:top_k]
+
+
+ def search_two_stage(
+     query_emb: np.ndarray,
+     doc_embeddings: Dict[str, np.ndarray],
+     pooled_embeddings: Dict[str, np.ndarray],
+     prefetch_k: int = 100,
+     top_k: int = 10,
+ ) -> List[Dict]:
+     """
+     Two-stage retrieval: tile-level pooled prefetch + MaxSim rerank.
+
+     Stage 1: Use tile-level pooled vectors for fast retrieval
+         Each doc has [num_tiles, 128] pooled representation
+         Compute MaxSim on pooled vectors (much faster)
+
+     Stage 2: Exact MaxSim reranking on top candidates
+         Use full multi-vector embeddings for precision
+     """
+     # Stage 1: Pooled MaxSim (fast approximation)
+     # Query pooled: mean across query tokens → [128]
+     query_pooled = query_emb.mean(axis=0)
+     query_pooled = query_pooled / (np.linalg.norm(query_pooled) + 1e-8)
+
+     stage1_scores = []
+     for doc_id, doc_pooled in pooled_embeddings.items():
+         # doc_pooled shape: [num_tiles, 128] from tile-level pooling
+         # Compute similarity with each tile, take max (simplified MaxSim)
+         doc_norm = doc_pooled / (np.linalg.norm(doc_pooled, axis=1, keepdims=True) + 1e-8)
+         tile_sims = np.dot(doc_norm, query_pooled)
+         score = float(tile_sims.max())  # Max tile similarity
+         stage1_scores.append({"id": doc_id, "score": score})
+
+     stage1_scores.sort(key=lambda x: x["score"], reverse=True)
+     candidates = stage1_scores[:prefetch_k]
+
+     # Stage 2: Exact MaxSim rerank on candidates
+     reranked = []
+     for cand in candidates:
+         doc_id = cand["id"]
+         doc_emb = doc_embeddings[doc_id]
+         score = compute_maxsim(query_emb, doc_emb)
+         reranked.append({
+             "id": doc_id,
+             "score": score,
+             "stage1_score": cand["score"],
+             "stage1_rank": stage1_scores.index(cand) + 1,
+         })
+
+     reranked.sort(key=lambda x: x["score"], reverse=True)
+     return reranked[:top_k]
+
+
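To see the prefetch/rerank trade-off end to end, a small synthetic run of the two functions above (random embeddings, shapes matching the 13-tile x 64-patch, 128-dim layout assumed elsewhere in this file); this presumes search_exhaustive and search_two_stage are in scope, e.g. imported from benchmarks.run_vidore:

import numpy as np

rng = np.random.default_rng(0)

# 200 synthetic pages: full multi-vector embeddings plus 13-tile mean-pooled summaries.
doc_embeddings = {f"d_{i}": rng.standard_normal((13 * 64, 128)).astype(np.float32)
                  for i in range(200)}
pooled_embeddings = {k: v.reshape(13, 64, 128).mean(axis=1) for k, v in doc_embeddings.items()}
query_emb = rng.standard_normal((20, 128)).astype(np.float32)   # 20 query tokens

exact = search_exhaustive(query_emb, doc_embeddings, top_k=5)    # 200 exact MaxSims
fast = search_two_stage(query_emb, doc_embeddings, pooled_embeddings,
                        prefetch_k=50, top_k=5)                  # 200 cheap sims + 50 exact MaxSims
print([r["id"] for r in exact])
print([(r["id"], r["stage1_rank"]) for r in fast])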
+ def compute_metrics(
+     results: Dict[str, List[Dict]],
+     qrels: Dict[str, Dict[str, int]],
+ ) -> Dict[str, float]:
+     """Compute evaluation metrics."""
+     ndcg_5 = []
+     ndcg_10 = []
+     mrr_10 = []
+     recall_5 = []
+     recall_10 = []
+
+     for query_id, ranking in results.items():
+         relevant = set(qrels.get(query_id, {}).keys())
+
+         # MRR@10
+         rr = 0.0
+         for i, doc in enumerate(ranking[:10]):
+             if doc["id"] in relevant:
+                 rr = 1.0 / (i + 1)
+                 break
+         mrr_10.append(rr)
+
+         # Recall@5, Recall@10
+         retrieved_5 = set(d["id"] for d in ranking[:5])
+         retrieved_10 = set(d["id"] for d in ranking[:10])
+
+         if relevant:
+             recall_5.append(len(retrieved_5 & relevant) / len(relevant))
+             recall_10.append(len(retrieved_10 & relevant) / len(relevant))
+
+         # NDCG@5, NDCG@10
+         dcg_5 = sum(
+             1.0 / np.log2(i + 2) for i, d in enumerate(ranking[:5]) if d["id"] in relevant
+         )
+         dcg_10 = sum(
+             1.0 / np.log2(i + 2) for i, d in enumerate(ranking[:10]) if d["id"] in relevant
+         )
+
+         # Ideal DCG
+         k_rel = min(len(relevant), 5)
+         idcg_5 = sum(1.0 / np.log2(i + 2) for i in range(k_rel))
+         k_rel = min(len(relevant), 10)
+         idcg_10 = sum(1.0 / np.log2(i + 2) for i in range(k_rel))
+
+         ndcg_5.append(dcg_5 / idcg_5 if idcg_5 > 0 else 0.0)
+         ndcg_10.append(dcg_10 / idcg_10 if idcg_10 > 0 else 0.0)
+
+     return {
+         "ndcg@5": float(np.mean(ndcg_5)),
+         "ndcg@10": float(np.mean(ndcg_10)),
+         "mrr@10": float(np.mean(mrr_10)),
+         "recall@5": float(np.mean(recall_5)),
+         "recall@10": float(np.mean(recall_10)),
+     }
+
+
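A worked example of the metric definitions above, with compute_metrics in scope: one query whose single relevant page is returned at rank 2, so MRR@10 = 1/2 and nDCG = 1/log2(3) ≈ 0.631:

results = {"q_0": [{"id": "d_7", "score": 0.91},
                   {"id": "d_0", "score": 0.88},
                   {"id": "d_3", "score": 0.60}]}
qrels = {"q_0": {"d_0": 1}}
print(compute_metrics(results, qrels))
# {'ndcg@5': 0.6309..., 'ndcg@10': 0.6309..., 'mrr@10': 0.5, 'recall@5': 1.0, 'recall@10': 1.0}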
+ def run_evaluation(
+     dataset_name: str,
+     model_name: str = "vidore/colSmol-500M",
+     two_stage: bool = False,
+     prefetch_k: int = 100,
+     top_k: int = 10,
+     output_dir: Optional[str] = None,
+ ) -> Dict[str, Any]:
+     """Run full evaluation on a dataset."""
+     from visual_rag.embedding import VisualEmbedder
+
+     logger.info(f"=" * 60)
+     logger.info(f"Evaluating: {dataset_name}")
+     logger.info(f"Model: {model_name}")
+     logger.info(f"Two-stage: {two_stage}")
+     logger.info(f"=" * 60)
+
+     # Load dataset
+     data = load_dataset(dataset_name)
+
+     # Initialize embedder
+     embedder = VisualEmbedder(model_name=model_name)
+
+     # Embed documents (with tile-level pooling if two-stage)
+     start_time = time.time()
+     if two_stage:
+         doc_embeddings, pooled_embeddings = embed_documents(
+             data["documents"], embedder, return_pooled=True
+         )
+         logger.info(f"Using tile-level pooling for two-stage retrieval")
+     else:
+         doc_embeddings = embed_documents(data["documents"], embedder)
+         pooled_embeddings = None
+     embed_time = time.time() - start_time
+     logger.info(f"Document embedding time: {embed_time:.2f}s")
+
+     # Embed queries
+     query_embeddings = embed_queries(data["queries"], embedder)
+
+     # Run search
+     logger.info("Running search...")
+     results = {}
+     search_times = []
+
+     for query in tqdm(data["queries"], desc="Searching"):
+         query_id = query["id"]
+         query_emb = query_embeddings[query_id]
+
+         start = time.time()
+         if two_stage:
+             ranking = search_two_stage(
+                 query_emb, doc_embeddings, pooled_embeddings,
+                 prefetch_k=prefetch_k, top_k=top_k
+             )
+         else:
+             ranking = search_exhaustive(query_emb, doc_embeddings, top_k=top_k)
+         search_times.append(time.time() - start)
+
+         results[query_id] = ranking
+
+     avg_search_time = np.mean(search_times)
+     logger.info(f"Average search time: {avg_search_time * 1000:.2f}ms")
+
+     # Compute metrics
+     metrics = compute_metrics(results, data["qrels"])
+     metrics["avg_search_time_ms"] = avg_search_time * 1000
+     metrics["embed_time_s"] = embed_time
+
+     logger.info(f"\nResults:")
+     for k, v in metrics.items():
+         logger.info(f"  {k}: {v:.4f}")
+
+     # Save results
+     if output_dir:
+         output_path = Path(output_dir)
+         output_path.mkdir(parents=True, exist_ok=True)
+
+         dataset_short = dataset_name.split("/")[-1]
+         suffix = "_twostage" if two_stage else ""
+         result_file = output_path / f"{dataset_short}{suffix}.json"
+
+         with open(result_file, "w") as f:
+             json.dump({
+                 "dataset": dataset_name,
+                 "model": model_name,
+                 "two_stage": two_stage,
+                 "metrics": metrics,
+             }, f, indent=2)
+
+         logger.info(f"Saved results to: {result_file}")
+
+     return metrics
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="ViDoRe Benchmark Evaluation",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog=f"""
+ Available datasets:
+ {', '.join(VIDORE_DATASETS.keys())}
+
+ Examples:
+ # Quick test on DocVQA
+ python run_vidore.py --dataset docvqa
+
+ # Quick test with two-stage (your novel approach)
+ python run_vidore.py --dataset docvqa --two-stage
+
+ # Run on recommended quick datasets
+ python run_vidore.py --quick
+
+ # Full evaluation on all datasets
+ python run_vidore.py --all
+
+ # Compare exhaustive vs two-stage
+ python run_vidore.py --dataset docvqa
+ python run_vidore.py --dataset docvqa --two-stage
+ python analyze_results.py --results results/ --compare
+ """
+     )
+     parser.add_argument(
+         "--dataset", type=str, choices=list(VIDORE_DATASETS.keys()),
+         help=f"Dataset to evaluate: {', '.join(VIDORE_DATASETS.keys())}"
+     )
+     parser.add_argument(
+         "--quick", action="store_true",
+         help=f"Run on quick datasets: {QUICK_DATASETS}"
+     )
+     parser.add_argument(
+         "--all", action="store_true",
+         help="Evaluate on all ViDoRe datasets"
+     )
+     parser.add_argument(
+         "--model", type=str, default="vidore/colSmol-500M",
+         help="Model: vidore/colSmol-500M (default), vidore/colpali-v1.3, vidore/colqwen2-v1.0"
+     )
+     parser.add_argument(
+         "--two-stage", action="store_true",
+         help="Use two-stage retrieval (tile-level pooled prefetch + MaxSim rerank)"
+     )
+     parser.add_argument(
+         "--prefetch-k", type=int, default=100,
+         help="Stage 1 candidates (default: 100)"
+     )
+     parser.add_argument(
+         "--top-k", type=int, default=10,
+         help="Final results (default: 10)"
+     )
+     parser.add_argument(
+         "--output-dir", type=str, default="results",
+         help="Output directory (default: results)"
+     )
+
+     args = parser.parse_args()
+
+     # Determine which datasets to run
+     if args.all:
+         dataset_keys = ALL_DATASETS
+     elif args.quick:
+         dataset_keys = QUICK_DATASETS
+     elif args.dataset:
+         dataset_keys = [args.dataset]
+     else:
+         parser.error("Specify --dataset, --quick, or --all")
+
+     # Convert keys to full HuggingFace paths
+     datasets = [VIDORE_DATASETS[k] for k in dataset_keys]
+     logger.info(f"Running on {len(datasets)} dataset(s): {dataset_keys}")
+
+     all_results = {}
+     for dataset in datasets:
+         try:
+             metrics = run_evaluation(
+                 dataset_name=dataset,
+                 model_name=args.model,
+                 two_stage=args.two_stage,
+                 prefetch_k=args.prefetch_k,
+                 top_k=args.top_k,
+                 output_dir=args.output_dir,
+             )
+             all_results[dataset] = metrics
+         except Exception as e:
+             logger.error(f"Failed on {dataset}: {e}")
+             continue
+
+     # Summary
+     if len(all_results) > 1:
+         logger.info("\n" + "=" * 60)
+         logger.info("SUMMARY")
+         logger.info("=" * 60)
+
+         avg_ndcg10 = np.mean([m["ndcg@10"] for m in all_results.values()])
+         avg_mrr10 = np.mean([m["mrr@10"] for m in all_results.values()])
+
+         logger.info(f"Average NDCG@10: {avg_ndcg10:.4f}")
+         logger.info(f"Average MRR@10: {avg_mrr10:.4f}")
+
+
+ if __name__ == "__main__":
+     main()
+
+
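For a programmatic version of the exhaustive-vs-two-stage comparison shown in the epilog, something along these lines should work, assuming the benchmarks package is importable from the installed wheel and the default colSmol-500M weights can be downloaded:

from benchmarks.run_vidore import run_evaluation

baseline = run_evaluation("vidore/docvqa_test_subsampled", output_dir="results")
twostage = run_evaluation("vidore/docvqa_test_subsampled", two_stage=True, output_dir="results")
print(f"nDCG@10  exhaustive={baseline['ndcg@10']:.4f}  two-stage={twostage['ndcg@10']:.4f}")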