visual-rag-toolkit 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
benchmarks/quick_test.py
ADDED
@@ -0,0 +1,566 @@
#!/usr/bin/env python3
"""
Quick Benchmark - Validate retrieval quality with ViDoRe data.

This script:
1. Downloads samples from ViDoRe (with ground truth relevance)
2. Embeds with ColSmol-500M
3. Tests retrieval strategies (exhaustive vs two-stage)
4. Computes metrics: NDCG@K, MRR@K, Recall@K
5. Compares speed and quality

Usage:
    python quick_test.py --samples 100
    python quick_test.py --samples 500 --skip-exhaustive  # Faster
"""

import sys
import time
import argparse
import logging
from pathlib import Path
from typing import List, Dict, Any

# Add parent directory to Python path (so we can import visual_rag).
# This allows running the script directly without pip install.
_script_dir = Path(__file__).parent
_parent_dir = _script_dir.parent
if str(_parent_dir) not in sys.path:
    sys.path.insert(0, str(_parent_dir))

import numpy as np
from tqdm import tqdm

# Visual RAG imports (now work without pip install)
from visual_rag.embedding import VisualEmbedder
from visual_rag.embedding.pooling import (
    tile_level_mean_pooling,
    compute_maxsim_score,
)

# Optional: datasets for ViDoRe
try:
    from datasets import load_dataset as hf_load_dataset
    HAS_DATASETS = True
except ImportError:
    HAS_DATASETS = False

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def load_vidore_sample(num_samples: int = 100) -> List[Dict]:
    """
    Load a sample from ViDoRe DocVQA with ground truth.

    Each sample has a query and its relevant document (1:1 mapping).
    This allows computing retrieval metrics.
    """
    if not HAS_DATASETS:
        logger.error("Install datasets: pip install datasets")
        sys.exit(1)

    logger.info(f"Loading {num_samples} samples from ViDoRe DocVQA...")

    ds = hf_load_dataset("vidore/docvqa_test_subsampled", split="test")

    samples = []
    for i, example in enumerate(ds):
        if i >= num_samples:
            break

        samples.append({
            "id": i,
            "doc_id": f"doc_{i}",
            "query_id": f"q_{i}",
            "image": example.get("image", example.get("page_image")),
            "query": example.get("query", example.get("question", "")),
            # Ground truth: query i is relevant to doc i
            "relevant_doc": f"doc_{i}",
        })

    logger.info(f"Loaded {len(samples)} samples with ground truth")
    return samples

def embed_all(
    samples: List[Dict],
    model_name: str = "vidore/colSmol-500M",
) -> Dict[str, Any]:
    """Embed all documents and queries."""
    logger.info(f"\nLoading model: {model_name}")
    embedder = VisualEmbedder(model_name=model_name)

    images = [s["image"] for s in samples]
    queries = [s["query"] for s in samples if s["query"]]

    # Embed images
    logger.info(f"Embedding {len(images)} documents...")
    start_time = time.time()

    embeddings, token_infos = embedder.embed_images(
        images, batch_size=4, return_token_info=True
    )

    doc_embed_time = time.time() - start_time
    logger.info(f"  Time: {doc_embed_time:.2f}s ({doc_embed_time/len(images)*1000:.1f}ms/doc)")

    # Process embeddings: extract visual tokens + tile-level pooling
    doc_data = {}
    for i, (emb, token_info) in enumerate(zip(embeddings, token_infos)):
        if hasattr(emb, 'cpu'):
            emb = emb.cpu()
        emb_np = emb.numpy() if hasattr(emb, 'numpy') else np.array(emb)

        # Extract visual tokens only (filter special tokens)
        visual_indices = token_info["visual_token_indices"]
        visual_emb = emb_np[visual_indices].astype(np.float32)

        # Tile-level pooling
        n_rows = token_info.get("n_rows", 4)
        n_cols = token_info.get("n_cols", 3)
        num_tiles = n_rows * n_cols + 1 if n_rows and n_cols else 13

        tile_pooled = tile_level_mean_pooling(visual_emb, num_tiles, patches_per_tile=64)

        doc_data[f"doc_{i}"] = {
            "embedding": visual_emb,
            "pooled": tile_pooled,
            "num_visual_tokens": len(visual_indices),
            "num_tiles": tile_pooled.shape[0],
        }

    # Embed queries
    logger.info(f"Embedding {len(queries)} queries...")
    start_time = time.time()

    query_data = {}
    for i, query in enumerate(tqdm(queries, desc="Queries")):
        q_emb = embedder.embed_query(query)
        if hasattr(q_emb, 'cpu'):
            q_emb = q_emb.cpu()
        q_np = q_emb.numpy() if hasattr(q_emb, 'numpy') else np.array(q_emb)
        query_data[f"q_{i}"] = q_np.astype(np.float32)

    query_embed_time = time.time() - start_time

    return {
        "docs": doc_data,
        "queries": query_data,
        "samples": samples,
        "doc_embed_time": doc_embed_time,
        "query_embed_time": query_embed_time,
        "model": model_name,
    }

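# --- Editor's note: illustrative sketch, not part of the released wheel. ---
# tile_level_mean_pooling is imported from visual_rag.embedding.pooling, so
# its exact behavior lives in that module. Assuming the visual tokens are
# laid out tile by tile (patches_per_tile patches each, as the call above
# suggests), a minimal reference implementation would look like this:
def _tile_mean_pool_sketch(visual_emb: np.ndarray, num_tiles: int,
                           patches_per_tile: int = 64) -> np.ndarray:
    """Mean-pool each consecutive block of patches into one vector per tile."""
    pooled = np.zeros((num_tiles, visual_emb.shape[1]), dtype=np.float32)
    for t in range(num_tiles):
        tile = visual_emb[t * patches_per_tile:(t + 1) * patches_per_tile]
        if len(tile) > 0:
            pooled[t] = tile.mean(axis=0)  # average the patch vectors in this tile
    return pooled
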
def search_exhaustive(query_emb: np.ndarray, docs: Dict, top_k: int = 10) -> List[Dict]:
    """Exhaustive MaxSim search over all documents."""
    scores = []
    for doc_id, doc in docs.items():
        score = compute_maxsim_score(query_emb, doc["embedding"])
        scores.append({"id": doc_id, "score": score})

    scores.sort(key=lambda x: x["score"], reverse=True)
    return scores[:top_k]

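# --- Editor's note: illustrative sketch, not part of the released wheel. ---
# compute_maxsim_score comes from visual_rag.embedding.pooling. Assuming it
# follows the standard ColBERT-style late-interaction definition, it is
# equivalent to this two-line NumPy reference:
def _maxsim_reference(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    """Sum over query tokens of the max dot-product over document tokens."""
    sims = query_emb @ doc_emb.T          # (n_query_tokens, n_doc_tokens)
    return float(sims.max(axis=1).sum())  # best doc token per query token, summed
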
def search_two_stage(
    query_emb: np.ndarray,
    docs: Dict,
    prefetch_k: int = 20,
    top_k: int = 10,
) -> List[Dict]:
    """
    Two-stage retrieval with tile-level pooling.

    Stage 1: Fast prefetch using tile-pooled vectors
    Stage 2: Exact MaxSim reranking on candidates
    """
    # Stage 1: tile-level pooled search
    query_pooled = query_emb.mean(axis=0)
    query_pooled = query_pooled / (np.linalg.norm(query_pooled) + 1e-8)

    stage1_scores = []
    for doc_id, doc in docs.items():
        doc_pooled = doc["pooled"]
        doc_norm = doc_pooled / (np.linalg.norm(doc_pooled, axis=1, keepdims=True) + 1e-8)
        tile_sims = np.dot(doc_norm, query_pooled)
        score = float(tile_sims.max())
        stage1_scores.append({"id": doc_id, "score": score})

    stage1_scores.sort(key=lambda x: x["score"], reverse=True)
    candidates = stage1_scores[:prefetch_k]

    # Stage 2: exact MaxSim on candidates
    reranked = []
    for cand in candidates:
        doc_id = cand["id"]
        score = compute_maxsim_score(query_emb, docs[doc_id]["embedding"])
        reranked.append({"id": doc_id, "score": score, "stage1_rank": stage1_scores.index(cand) + 1})

    reranked.sort(key=lambda x: x["score"], reverse=True)
    return reranked[:top_k]

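# --- Editor's note: cost intuition, not part of the released wheel. ---
# Stage 1 collapses the query to a single mean vector and each document to
# num_tiles pooled vectors, so scoring one document costs num_tiles dot
# products, while full MaxSim compares every query token against every
# document token. With the defaults above (13 tiles of 64 patches = 832
# visual tokens) and an assumed ~20-token query, that is 13 similarity terms
# per document in stage 1 versus roughly 20 * 832 = 16,640 in the exhaustive
# baseline, which is why reranking only prefetch_k candidates is cheaper.
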
def compute_metrics(
    results: Dict[str, List[Dict]],
    samples: List[Dict],
    k_values: List[int] = [1, 3, 5, 7, 10],
) -> Dict[str, float]:
    """
    Compute retrieval metrics.

    Since ViDoRe has a 1:1 query-doc mapping (one relevant doc per query):
    - Recall@K (Hit Rate): is the relevant doc in the top-K? (0 or 1)
    - Precision@K: (# relevant in top-K) / K
    - MRR@K: 1/rank if found in top-K, else 0
    - NDCG@K: DCG / IDCG with binary relevance
    """
    metrics = {}

    # Also track per-query ranks for analysis
    all_ranks = []

    for k in k_values:
        recalls = []
        precisions = []
        mrrs = []
        ndcgs = []

        for sample in samples:
            query_id = sample["query_id"]
            relevant_doc = sample["relevant_doc"]

            if query_id not in results:
                continue

            ranking = results[query_id][:k]
            ranked_ids = [r["id"] for r in ranking]

            # Find rank of relevant doc (1-indexed, 0 if not found)
            rank = 0
            for i, doc_id in enumerate(ranked_ids):
                if doc_id == relevant_doc:
                    rank = i + 1
                    break

            # Recall@K (Hit Rate): 1 if found in top-K
            found = 1.0 if rank > 0 else 0.0
            recalls.append(found)

            # Precision@K: (# relevant found) / K
            # With 1 relevant doc: 1/K if found, 0 otherwise
            precision = found / k
            precisions.append(precision)

            # MRR@K: 1/rank if found
            mrr = 1.0 / rank if rank > 0 else 0.0
            mrrs.append(mrr)

            # NDCG@K (binary relevance)
            # DCG = 1/log2(rank+1) if found, 0 otherwise
            # IDCG = 1/log2(2) = 1 (best case: relevant at rank 1)
            dcg = 1.0 / np.log2(rank + 1) if rank > 0 else 0.0
            idcg = 1.0
            ndcg = dcg / idcg
            ndcgs.append(ndcg)

            # Track actual rank for analysis (only on the largest-k pass)
            if k == max(k_values):
                full_ranking = results[query_id]
                full_rank = 0
                for i, r in enumerate(full_ranking):
                    if r["id"] == relevant_doc:
                        full_rank = i + 1
                        break
                all_ranks.append(full_rank)

        metrics[f"Recall@{k}"] = np.mean(recalls)
        metrics[f"P@{k}"] = np.mean(precisions)
        metrics[f"MRR@{k}"] = np.mean(mrrs)
        metrics[f"NDCG@{k}"] = np.mean(ndcgs)

    # Add summary stats
    if all_ranks:
        found_ranks = [r for r in all_ranks if r > 0]
        metrics["avg_rank"] = np.mean(found_ranks) if found_ranks else float('inf')
        metrics["median_rank"] = np.median(found_ranks) if found_ranks else float('inf')
        metrics["not_found"] = sum(1 for r in all_ranks if r == 0)

    return metrics

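# --- Editor's note: worked example, not part of the released wheel. ---
# With a single relevant document per query, the metrics above collapse to
# closed forms. If the relevant doc lands at rank 3 within the top 10:
#   Recall@10 = 1.0                  (it appears in the top 10)
#   P@10      = 1/10 = 0.100         (one relevant doc among ten results)
#   MRR@10    = 1/3 = 0.333
#   NDCG@10   = 1/log2(3 + 1) = 0.5
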
def run_benchmark(
    data: Dict,
    skip_exhaustive: bool = False,
    prefetch_k: int = None,
    top_k: int = 10,
) -> Dict[str, Dict]:
    """Run retrieval benchmark with metrics."""
    docs = data["docs"]
    queries = data["queries"]
    samples = data["samples"]
    num_docs = len(docs)

    # Auto-set prefetch_k to be meaningful (default: 20, or 20% of docs if >100 docs)
    if prefetch_k is None:
        if num_docs <= 100:
            prefetch_k = 20  # Default: prefetch 20, rerank to top-10
        else:
            prefetch_k = max(20, min(100, int(num_docs * 0.2)))  # 20% for larger collections

    # Ensure prefetch_k < num_docs for a meaningful two-stage comparison
    if prefetch_k >= num_docs:
        logger.warning(f"prefetch_k={prefetch_k} >= num_docs={num_docs}")
        logger.warning("  Two-stage will fetch ALL docs (same as exhaustive)")
        logger.warning(f"  Use --samples > {prefetch_k * 3} for a meaningful comparison")

    logger.info(f"Benchmark config: {num_docs} docs, prefetch_k={prefetch_k}, top_k={top_k}")
    logger.info(f"  (Both methods return top-{top_k} results - realistic retrieval scenario)")

    results = {}

    # Two-stage retrieval (NOVEL)
    logger.info(f"\nRunning Two-Stage retrieval (prefetch top-{prefetch_k}, rerank to top-{top_k})...")
    two_stage_results = {}
    two_stage_times = []

    for sample in tqdm(samples, desc="Two-Stage"):
        query_id = sample["query_id"]
        query_emb = queries[query_id]

        start = time.time()
        ranking = search_two_stage(query_emb, docs, prefetch_k=prefetch_k, top_k=top_k)
        two_stage_times.append(time.time() - start)

        two_stage_results[query_id] = ranking

    two_stage_metrics = compute_metrics(two_stage_results, samples)
    two_stage_metrics["avg_time_ms"] = np.mean(two_stage_times) * 1000
    two_stage_metrics["prefetch_k"] = prefetch_k
    two_stage_metrics["top_k"] = top_k
    results["two_stage"] = two_stage_metrics

    # Exhaustive search (baseline)
    if not skip_exhaustive:
        logger.info(f"Running Exhaustive MaxSim (searches ALL {num_docs} docs, returns top-{top_k})...")
        exhaustive_results = {}
        exhaustive_times = []

        for sample in tqdm(samples, desc="Exhaustive"):
            query_id = sample["query_id"]
            query_emb = queries[query_id]

            start = time.time()
            ranking = search_exhaustive(query_emb, docs, top_k=top_k)
            exhaustive_times.append(time.time() - start)

            exhaustive_results[query_id] = ranking

        exhaustive_metrics = compute_metrics(exhaustive_results, samples)
        exhaustive_metrics["avg_time_ms"] = np.mean(exhaustive_times) * 1000
        exhaustive_metrics["top_k"] = top_k
        results["exhaustive"] = exhaustive_metrics

    return results

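# --- Editor's note: usage sketch, not part of the released wheel. ---
# The functions above can be driven directly from Python, mirroring main():
#
#   samples = load_vidore_sample(num_samples=50)
#   data = embed_all(samples, model_name="vidore/colSmol-500M")
#   results = run_benchmark(data, prefetch_k=20, top_k=10)
#   print(results["two_stage"]["Recall@10"])
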
def print_results(data: Dict, benchmark_results: Dict, show_precision: bool = False):
    """Print benchmark results."""
    print("\n" + "=" * 80)
    print("BENCHMARK RESULTS")
    print("=" * 80)

    num_docs = len(data['docs'])
    print(f"\nModel: {data['model']}")
    print(f"Documents: {num_docs}")
    print(f"Queries: {len(data['queries'])}")

    # Embedding stats
    sample_doc = list(data['docs'].values())[0]
    print("\nEmbedding (after visual token filtering):")
    print(f"  Visual tokens per doc: {sample_doc['num_visual_tokens']}")
    print(f"  Tile-pooled vectors: {sample_doc['num_tiles']}")

    if "two_stage" in benchmark_results:
        prefetch_k = benchmark_results["two_stage"].get("prefetch_k", "?")
        print(f"  Two-stage prefetch_k: {prefetch_k} (of {num_docs} docs)")

    # Method labels - clearer naming
    def get_label(method):
        if method == "two_stage":
            return "Pooled+Rerank"  # Tile-pooled prefetch + MaxSim rerank
        else:
            return "Full MaxSim"  # Exhaustive MaxSim on all docs

    # Recall / Hit Rate table
    print("\nRECALL (Hit Rate) @ K:")
    print(f"  {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
    print(f"  {'-'*60}")

    for method, metrics in benchmark_results.items():
        print(f"  {get_label(method):<20} "
              f"{metrics.get('Recall@1', 0):>8.3f} "
              f"{metrics.get('Recall@3', 0):>8.3f} "
              f"{metrics.get('Recall@5', 0):>8.3f} "
              f"{metrics.get('Recall@7', 0):>8.3f} "
              f"{metrics.get('Recall@10', 0):>8.3f}")

    # Precision table (optional)
    if show_precision:
        print("\nPRECISION @ K:")
        print(f"  {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
        print(f"  {'-'*60}")

        for method, metrics in benchmark_results.items():
            print(f"  {get_label(method):<20} "
                  f"{metrics.get('P@1', 0):>8.3f} "
                  f"{metrics.get('P@3', 0):>8.3f} "
                  f"{metrics.get('P@5', 0):>8.3f} "
                  f"{metrics.get('P@7', 0):>8.3f} "
                  f"{metrics.get('P@10', 0):>8.3f}")

    # NDCG table
    print("\nNDCG @ K:")
    print(f"  {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
    print(f"  {'-'*60}")

    for method, metrics in benchmark_results.items():
        print(f"  {get_label(method):<20} "
              f"{metrics.get('NDCG@1', 0):>8.3f} "
              f"{metrics.get('NDCG@3', 0):>8.3f} "
              f"{metrics.get('NDCG@5', 0):>8.3f} "
              f"{metrics.get('NDCG@7', 0):>8.3f} "
              f"{metrics.get('NDCG@10', 0):>8.3f}")

    # MRR table
    print("\nMRR @ K:")
    print(f"  {'Method':<20} {'@1':>8} {'@3':>8} {'@5':>8} {'@7':>8} {'@10':>8}")
    print(f"  {'-'*60}")

    for method, metrics in benchmark_results.items():
        print(f"  {get_label(method):<20} "
              f"{metrics.get('MRR@1', 0):>8.3f} "
              f"{metrics.get('MRR@3', 0):>8.3f} "
              f"{metrics.get('MRR@5', 0):>8.3f} "
              f"{metrics.get('MRR@7', 0):>8.3f} "
              f"{metrics.get('MRR@10', 0):>8.3f}")

    # Speed comparison
    top_k = benchmark_results.get("two_stage", benchmark_results.get("exhaustive", {})).get("top_k", 10)
    print(f"\nSPEED (both return top-{top_k} results):")
    print(f"  {'Method':<20} {'Time (ms)':>12} {'Docs searched':>15}")
    print(f"  {'-'*50}")

    for method, metrics in benchmark_results.items():
        if method == "two_stage":
            searched = metrics.get("prefetch_k", "?")
            label = f"{searched} (stage-1)"
        else:
            searched = num_docs
            label = f"{searched} (all)"
        print(f"  {get_label(method):<20} {metrics.get('avg_time_ms', 0):>12.2f} {label:>15}")

    # Comparison summary
    if "exhaustive" in benchmark_results and "two_stage" in benchmark_results:
        ex = benchmark_results["exhaustive"]
        ts = benchmark_results["two_stage"]

        print("\nPOOLED+RERANK vs FULL MAXSIM:")

        for k in [1, 5, 10]:
            ex_recall = ex.get(f"Recall@{k}", 0)
            ts_recall = ts.get(f"Recall@{k}", 0)
            if ex_recall > 0:
                retention = ts_recall / ex_recall * 100
                print(f"  • Recall@{k} retention: {retention:.1f}% ({ts_recall:.3f} vs {ex_recall:.3f})")

        speedup = ex["avg_time_ms"] / ts["avg_time_ms"] if ts["avg_time_ms"] > 0 else 0
        print(f"  • Speedup: {speedup:.1f}x")

        # Rank stats with explanation
        if "avg_rank" in ts:
            prefetch_k = ts.get("prefetch_k", "?")
            top_k = ts.get("top_k", 10)
            not_found = ts.get("not_found", 0)
            total = len(data["queries"])

            print("\nPOOLED+RERANK STATISTICS:")
            print("  Stage-1 (pooled prefetch):")
            print(f"    • Searches top-{prefetch_k} candidates using tile-pooled vectors")
            print(f"    • {total - not_found}/{total} queries ({100 - not_found/total*100:.1f}%) had the relevant doc in the prefetch")
            print(f"    • {not_found}/{total} queries ({not_found/total*100:.1f}%) missed (relevant doc ranked >{prefetch_k})")
            print("  Stage-2 (MaxSim reranking):")
            print("    • Reranks prefetch candidates with exact MaxSim")
            print(f"    • Returns final top-{top_k} results")
            if ts['avg_rank'] < float('inf'):
                print(f"    • Avg rank of relevant doc (when found): {ts['avg_rank']:.1f}")
                print(f"    • Median rank: {ts['median_rank']:.1f}")
            print(f"\n  Note: the {not_found/total*100:.1f}% miss rate applies to the stage-1 prefetch.")
            print(f"  Final Recall@{top_k} shows how many relevant docs ARE in the top-{top_k} results.")

    print("\n" + "=" * 80)
    print("Benchmark complete!")


def main():
    parser = argparse.ArgumentParser(
        description="Quick benchmark for visual-rag-toolkit",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--samples", type=int, default=100,
        help="Number of samples (default: 100)"
    )
    parser.add_argument(
        "--model", type=str, default="vidore/colSmol-500M",
        help="Model: vidore/colSmol-500M (default), vidore/colpali-v1.3"
    )
    parser.add_argument(
        "--prefetch-k", type=int, default=None,
        help="Stage 1 candidates for two-stage (default: 20 for <=100 docs, auto for larger)"
    )
    parser.add_argument(
        "--skip-exhaustive", action="store_true",
        help="Skip exhaustive baseline (faster)"
    )
    parser.add_argument(
        "--show-precision", action="store_true",
        help="Show Precision@K metrics (hidden by default)"
    )
    parser.add_argument(
        "--top-k", type=int, default=10,
        help="Number of results to return (default: 10, realistic retrieval scenario)"
    )

    args = parser.parse_args()

    print("\n" + "=" * 70)
    print("VISUAL RAG TOOLKIT - RETRIEVAL BENCHMARK")
    print("=" * 70)

    # Load samples
    samples = load_vidore_sample(args.samples)

    if not samples:
        logger.error("No samples loaded!")
        sys.exit(1)

    # Embed all
    data = embed_all(samples, args.model)

    # Run benchmark
    benchmark_results = run_benchmark(
        data,
        skip_exhaustive=args.skip_exhaustive,
        prefetch_k=args.prefetch_k,
        top_k=args.top_k,
    )

    # Print results
    print_results(data, benchmark_results, show_precision=args.show_precision)


if __name__ == "__main__":
    main()