visual-rag-toolkit 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
benchmarks/run_vidore.py
ADDED
@@ -0,0 +1,513 @@
#!/usr/bin/env python3
"""
ViDoRe Benchmark Evaluation Script

Evaluates visual document retrieval on the ViDoRe benchmark datasets.

Usage:
    # Single dataset
    python run_vidore.py --dataset vidore/docvqa_test_subsampled

    # All datasets
    python run_vidore.py --all

    # With two-stage retrieval
    python run_vidore.py --dataset vidore/docvqa_test_subsampled --two-stage
"""

import argparse
import json
import time
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional

import numpy as np
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# ViDoRe benchmark datasets
# Official leaderboard: https://huggingface.co/spaces/vidore/vidore-leaderboard
VIDORE_DATASETS = {
    # === RECOMMENDED FOR QUICK TESTING (smaller, faster) ===
    "docvqa": "vidore/docvqa_test_subsampled",      # ~500 queries, Document VQA
    "infovqa": "vidore/infovqa_test_subsampled",    # ~500 queries, Infographics
    "tabfquad": "vidore/tabfquad_test_subsampled",  # ~500 queries, Tables

    # === FULL EVALUATION ===
    "tatdqa": "vidore/tatdqa_test",                 # ~1500 queries, Financial tables
    "arxivqa": "vidore/arxivqa_test_subsampled",    # ~500 queries, Scientific papers
    "shift": "vidore/shiftproject_test",            # ~500 queries, Sustainability reports
}

# Aliases for convenience
QUICK_DATASETS = ["docvqa", "infovqa"]  # Fast testing
ALL_DATASETS = list(VIDORE_DATASETS.keys())

def load_dataset(dataset_name: str) -> Dict[str, Any]:
    """Load a ViDoRe dataset from HuggingFace."""
    try:
        from datasets import load_dataset
    except ImportError:
        raise ImportError("datasets library required. Install with: pip install datasets")

    logger.info(f"Loading dataset: {dataset_name}")

    # Load dataset
    ds = load_dataset(dataset_name, split="test")

    # Extract queries and documents
    # ViDoRe format: each example has query, image, and relevant doc info
    queries = []
    documents = []
    qrels = {}  # query_id -> {doc_id: relevance}

    for idx, example in enumerate(tqdm(ds, desc="Loading data")):
        query_id = f"q_{idx}"
        doc_id = f"d_{idx}"

        # Get query text
        query_text = example.get("query", example.get("question", ""))
        queries.append({
            "id": query_id,
            "text": query_text,
        })

        # Get document image
        image = example.get("image", example.get("page_image"))
        documents.append({
            "id": doc_id,
            "image": image,
        })

        # Relevance (self-document is relevant)
        qrels[query_id] = {doc_id: 1}

    logger.info(f"Loaded {len(queries)} queries and {len(documents)} documents")

    return {
        "queries": queries,
        "documents": documents,
        "qrels": qrels,
    }

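Note (illustrative, not part of run_vidore.py): because each query is paired with its own page and judged relevant only to that page, a two-example dataset comes back from load_dataset in this shape (values elided):

    {
        "queries": [{"id": "q_0", "text": "..."}, {"id": "q_1", "text": "..."}],
        "documents": [{"id": "d_0", "image": <PIL image>}, {"id": "d_1", "image": <PIL image>}],
        "qrels": {"q_0": {"d_0": 1}, "q_1": {"d_1": 1}},
    }
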
def embed_documents(
    documents: List[Dict],
    embedder,
    batch_size: int = 4,
    return_pooled: bool = False,
) -> Dict[str, np.ndarray]:
    """
    Embed all documents.

    Args:
        documents: List of {id, image} dicts
        embedder: VisualEmbedder instance
        batch_size: Batch size for embedding
        return_pooled: Also return tile-level pooled embeddings (for two-stage)

    Returns:
        doc_embeddings dict, and optionally pooled_embeddings dict
    """
    from visual_rag.embedding.pooling import tile_level_mean_pooling

    logger.info(f"Embedding {len(documents)} documents...")

    images = [doc["image"] for doc in documents]

    # Get embeddings with token info for proper pooling
    embeddings, token_infos = embedder.embed_images(
        images, batch_size=batch_size, return_token_info=True
    )

    doc_embeddings = {}
    pooled_embeddings = {} if return_pooled else None

    for doc, emb, token_info in zip(documents, embeddings, token_infos):
        if hasattr(emb, "numpy"):
            emb_np = emb.numpy()
        elif hasattr(emb, "cpu"):
            emb_np = emb.cpu().numpy()
        else:
            emb_np = np.array(emb)

        doc_embeddings[doc["id"]] = emb_np.astype(np.float32)

        # Compute tile-level pooling (NOVEL approach)
        if return_pooled:
            n_rows = token_info.get("n_rows", 4)
            n_cols = token_info.get("n_cols", 3)
            num_tiles = n_rows * n_cols + 1 if n_rows and n_cols else 13

            pooled = tile_level_mean_pooling(emb_np, num_tiles, patches_per_tile=64)
            pooled_embeddings[doc["id"]] = pooled.astype(np.float32)

    if return_pooled:
        return doc_embeddings, pooled_embeddings
    return doc_embeddings

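Note (illustrative, not part of run_vidore.py): tile_level_mean_pooling is imported from visual_rag/embedding/pooling.py, which this diff section does not show. A minimal sketch of one plausible implementation, assuming consecutive groups of 64 patch vectors are averaged into one vector per tile:

    import numpy as np

    def tile_mean_pooling_sketch(emb: np.ndarray, num_tiles: int, patches_per_tile: int = 64) -> np.ndarray:
        """Reduce [num_tiles * patches_per_tile, dim] patch embeddings to [num_tiles, dim]."""
        usable = num_tiles * patches_per_tile
        tiles = emb[:usable].reshape(num_tiles, patches_per_tile, -1)
        return tiles.mean(axis=1)

    # A 4x3 grid plus one global tile (13 tiles) of 128-dim patches:
    page = np.random.rand(13 * 64, 128).astype(np.float32)
    print(tile_mean_pooling_sketch(page, num_tiles=13).shape)  # (13, 128)
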
def embed_queries(
    queries: List[Dict],
    embedder,
) -> Dict[str, np.ndarray]:
    """Embed all queries."""
    logger.info(f"Embedding {len(queries)} queries...")

    query_embeddings = {}
    for query in tqdm(queries, desc="Embedding queries"):
        emb = embedder.embed_query(query["text"])
        if hasattr(emb, "numpy"):
            emb = emb.numpy()
        elif hasattr(emb, "cpu"):
            emb = emb.cpu().numpy()
        query_embeddings[query["id"]] = np.array(emb, dtype=np.float32)

    return query_embeddings

def compute_maxsim(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    """Compute ColBERT-style MaxSim score."""
    # Normalize
    query_norm = query_emb / (np.linalg.norm(query_emb, axis=1, keepdims=True) + 1e-8)
    doc_norm = doc_emb / (np.linalg.norm(doc_emb, axis=1, keepdims=True) + 1e-8)

    # Compute similarity matrix
    sim_matrix = np.dot(query_norm, doc_norm.T)

    # MaxSim: max per query token, then sum
    max_sims = sim_matrix.max(axis=1)
    return float(max_sims.sum())

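Note (illustrative, not part of run_vidore.py): a small worked example of the MaxSim scoring above, assuming compute_maxsim is in scope:

    import numpy as np

    q = np.array([[1.0, 0.0], [0.0, 1.0]])               # 2 query token vectors (dim 2)
    d = np.array([[1.0, 0.0], [0.7, 0.7], [0.0, -1.0]])  # 3 document patch vectors
    # After L2-normalisation the similarity matrix is approximately
    #   [[1.00, 0.71,  0.00],
    #    [0.00, 0.71, -1.00]]
    # Row-wise maxima are 1.00 and 0.71, so the score is their sum.
    print(compute_maxsim(q, d))  # ~1.71
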
def search_exhaustive(
    query_emb: np.ndarray,
    doc_embeddings: Dict[str, np.ndarray],
    top_k: int = 10,
) -> List[Dict]:
    """Exhaustive MaxSim search over all documents."""
    scores = []
    for doc_id, doc_emb in doc_embeddings.items():
        score = compute_maxsim(query_emb, doc_emb)
        scores.append({"id": doc_id, "score": score})

    # Sort by score
    scores.sort(key=lambda x: x["score"], reverse=True)
    return scores[:top_k]

def search_two_stage(
    query_emb: np.ndarray,
    doc_embeddings: Dict[str, np.ndarray],
    pooled_embeddings: Dict[str, np.ndarray],
    prefetch_k: int = 100,
    top_k: int = 10,
) -> List[Dict]:
    """
    Two-stage retrieval: tile-level pooled prefetch + MaxSim rerank.

    Stage 1: Use tile-level pooled vectors for fast retrieval
        Each doc has [num_tiles, 128] pooled representation
        Compute MaxSim on pooled vectors (much faster)

    Stage 2: Exact MaxSim reranking on top candidates
        Use full multi-vector embeddings for precision
    """
    # Stage 1: Pooled MaxSim (fast approximation)
    # Query pooled: mean across query tokens → [128]
    query_pooled = query_emb.mean(axis=0)
    query_pooled = query_pooled / (np.linalg.norm(query_pooled) + 1e-8)

    stage1_scores = []
    for doc_id, doc_pooled in pooled_embeddings.items():
        # doc_pooled shape: [num_tiles, 128] from tile-level pooling
        # Compute similarity with each tile, take max (simplified MaxSim)
        doc_norm = doc_pooled / (np.linalg.norm(doc_pooled, axis=1, keepdims=True) + 1e-8)
        tile_sims = np.dot(doc_norm, query_pooled)
        score = float(tile_sims.max())  # Max tile similarity
        stage1_scores.append({"id": doc_id, "score": score})

    stage1_scores.sort(key=lambda x: x["score"], reverse=True)
    candidates = stage1_scores[:prefetch_k]

    # Stage 2: Exact MaxSim rerank on candidates
    reranked = []
    for cand in candidates:
        doc_id = cand["id"]
        doc_emb = doc_embeddings[doc_id]
        score = compute_maxsim(query_emb, doc_emb)
        reranked.append({
            "id": doc_id,
            "score": score,
            "stage1_score": cand["score"],
            "stage1_rank": stage1_scores.index(cand) + 1,
        })

    reranked.sort(key=lambda x: x["score"], reverse=True)
    return reranked[:top_k]

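Note (illustrative, not part of run_vidore.py): a quick way to see what the pooled prefetch gives up relative to exact search is to run both functions on the same synthetic embeddings and compare the top lists; shapes below follow the [13, 128]-tile convention used above, and both functions are assumed to be in scope:

    import numpy as np

    rng = np.random.default_rng(0)
    docs = {f"d_{i}": rng.standard_normal((13 * 64, 128)).astype(np.float32) for i in range(50)}
    pooled = {k: v.reshape(13, 64, 128).mean(axis=1) for k, v in docs.items()}
    query = rng.standard_normal((20, 128)).astype(np.float32)

    exact = search_exhaustive(query, docs, top_k=5)
    approx = search_two_stage(query, docs, pooled, prefetch_k=20, top_k=5)
    overlap = {d["id"] for d in exact} & {d["id"] for d in approx}
    print(len(overlap))  # 5 means the prefetch lost nothing for this query
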
def compute_metrics(
    results: Dict[str, List[Dict]],
    qrels: Dict[str, Dict[str, int]],
) -> Dict[str, float]:
    """Compute evaluation metrics."""
    ndcg_5 = []
    ndcg_10 = []
    mrr_10 = []
    recall_5 = []
    recall_10 = []

    for query_id, ranking in results.items():
        relevant = set(qrels.get(query_id, {}).keys())

        # MRR@10
        rr = 0.0
        for i, doc in enumerate(ranking[:10]):
            if doc["id"] in relevant:
                rr = 1.0 / (i + 1)
                break
        mrr_10.append(rr)

        # Recall@5, Recall@10
        retrieved_5 = set(d["id"] for d in ranking[:5])
        retrieved_10 = set(d["id"] for d in ranking[:10])

        if relevant:
            recall_5.append(len(retrieved_5 & relevant) / len(relevant))
            recall_10.append(len(retrieved_10 & relevant) / len(relevant))

        # NDCG@5, NDCG@10
        dcg_5 = sum(
            1.0 / np.log2(i + 2) for i, d in enumerate(ranking[:5]) if d["id"] in relevant
        )
        dcg_10 = sum(
            1.0 / np.log2(i + 2) for i, d in enumerate(ranking[:10]) if d["id"] in relevant
        )

        # Ideal DCG
        k_rel = min(len(relevant), 5)
        idcg_5 = sum(1.0 / np.log2(i + 2) for i in range(k_rel))
        k_rel = min(len(relevant), 10)
        idcg_10 = sum(1.0 / np.log2(i + 2) for i in range(k_rel))

        ndcg_5.append(dcg_5 / idcg_5 if idcg_5 > 0 else 0.0)
        ndcg_10.append(dcg_10 / idcg_10 if idcg_10 > 0 else 0.0)

    return {
        "ndcg@5": float(np.mean(ndcg_5)),
        "ndcg@10": float(np.mean(ndcg_10)),
        "mrr@10": float(np.mean(mrr_10)),
        "recall@5": float(np.mean(recall_5)),
        "recall@10": float(np.mean(recall_10)),
    }

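Note (illustrative, not part of run_vidore.py): with exactly one relevant document per query, every metric above reduces to a simple function of where that document lands. A toy check, assuming compute_metrics is in scope:

    results = {"q_0": [
        {"id": "d_7", "score": 0.9},
        {"id": "d_0", "score": 0.8},
        {"id": "d_3", "score": 0.1},
    ]}
    qrels = {"q_0": {"d_0": 1}}
    print(compute_metrics(results, qrels))
    # The relevant doc sits at rank 2: mrr@10 = 0.5, recall@5 = recall@10 = 1.0,
    # and ndcg@5 = ndcg@10 = (1 / log2(3)) / (1 / log2(2)) ≈ 0.63.
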
def run_evaluation(
    dataset_name: str,
    model_name: str = "vidore/colSmol-500M",
    two_stage: bool = False,
    prefetch_k: int = 100,
    top_k: int = 10,
    output_dir: Optional[str] = None,
) -> Dict[str, Any]:
    """Run full evaluation on a dataset."""
    from visual_rag.embedding import VisualEmbedder

    logger.info(f"=" * 60)
    logger.info(f"Evaluating: {dataset_name}")
    logger.info(f"Model: {model_name}")
    logger.info(f"Two-stage: {two_stage}")
    logger.info(f"=" * 60)

    # Load dataset
    data = load_dataset(dataset_name)

    # Initialize embedder
    embedder = VisualEmbedder(model_name=model_name)

    # Embed documents (with tile-level pooling if two-stage)
    start_time = time.time()
    if two_stage:
        doc_embeddings, pooled_embeddings = embed_documents(
            data["documents"], embedder, return_pooled=True
        )
        logger.info(f"Using tile-level pooling for two-stage retrieval")
    else:
        doc_embeddings = embed_documents(data["documents"], embedder)
        pooled_embeddings = None
    embed_time = time.time() - start_time
    logger.info(f"Document embedding time: {embed_time:.2f}s")

    # Embed queries
    query_embeddings = embed_queries(data["queries"], embedder)

    # Run search
    logger.info("Running search...")
    results = {}
    search_times = []

    for query in tqdm(data["queries"], desc="Searching"):
        query_id = query["id"]
        query_emb = query_embeddings[query_id]

        start = time.time()
        if two_stage:
            ranking = search_two_stage(
                query_emb, doc_embeddings, pooled_embeddings,
                prefetch_k=prefetch_k, top_k=top_k
            )
        else:
            ranking = search_exhaustive(query_emb, doc_embeddings, top_k=top_k)
        search_times.append(time.time() - start)

        results[query_id] = ranking

    avg_search_time = np.mean(search_times)
    logger.info(f"Average search time: {avg_search_time * 1000:.2f}ms")

    # Compute metrics
    metrics = compute_metrics(results, data["qrels"])
    metrics["avg_search_time_ms"] = avg_search_time * 1000
    metrics["embed_time_s"] = embed_time

    logger.info(f"\nResults:")
    for k, v in metrics.items():
        logger.info(f"  {k}: {v:.4f}")

    # Save results
    if output_dir:
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        dataset_short = dataset_name.split("/")[-1]
        suffix = "_twostage" if two_stage else ""
        result_file = output_path / f"{dataset_short}{suffix}.json"

        with open(result_file, "w") as f:
            json.dump({
                "dataset": dataset_name,
                "model": model_name,
                "two_stage": two_stage,
                "metrics": metrics,
            }, f, indent=2)

        logger.info(f"Saved results to: {result_file}")

    return metrics

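Note (illustrative, not part of run_vidore.py): run_evaluation can also be called directly from Python instead of going through the CLI defined below; a minimal sketch, assuming the benchmarks package is importable (it ships an __init__.py) and the model weights and datasets can be downloaded:

    from benchmarks.run_vidore import run_evaluation, VIDORE_DATASETS

    metrics = run_evaluation(
        dataset_name=VIDORE_DATASETS["docvqa"],  # "vidore/docvqa_test_subsampled"
        model_name="vidore/colSmol-500M",
        two_stage=True,       # tile-level pooled prefetch + MaxSim rerank
        prefetch_k=100,
        top_k=10,
        output_dir="results",
    )
    print(metrics["ndcg@10"], metrics["avg_search_time_ms"])
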
def main():
    parser = argparse.ArgumentParser(
        description="ViDoRe Benchmark Evaluation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=f"""
Available datasets:
  {', '.join(VIDORE_DATASETS.keys())}

Examples:
  # Quick test on DocVQA
  python run_vidore.py --dataset docvqa

  # Quick test with two-stage (your novel approach)
  python run_vidore.py --dataset docvqa --two-stage

  # Run on recommended quick datasets
  python run_vidore.py --quick

  # Full evaluation on all datasets
  python run_vidore.py --all

  # Compare exhaustive vs two-stage
  python run_vidore.py --dataset docvqa
  python run_vidore.py --dataset docvqa --two-stage
  python analyze_results.py --results results/ --compare
"""
    )
    parser.add_argument(
        "--dataset", type=str, choices=list(VIDORE_DATASETS.keys()),
        help=f"Dataset to evaluate: {', '.join(VIDORE_DATASETS.keys())}"
    )
    parser.add_argument(
        "--quick", action="store_true",
        help=f"Run on quick datasets: {QUICK_DATASETS}"
    )
    parser.add_argument(
        "--all", action="store_true",
        help="Evaluate on all ViDoRe datasets"
    )
    parser.add_argument(
        "--model", type=str, default="vidore/colSmol-500M",
        help="Model: vidore/colSmol-500M (default), vidore/colpali-v1.3, vidore/colqwen2-v1.0"
    )
    parser.add_argument(
        "--two-stage", action="store_true",
        help="Use two-stage retrieval (tile-level pooled prefetch + MaxSim rerank)"
    )
    parser.add_argument(
        "--prefetch-k", type=int, default=100,
        help="Stage 1 candidates (default: 100)"
    )
    parser.add_argument(
        "--top-k", type=int, default=10,
        help="Final results (default: 10)"
    )
    parser.add_argument(
        "--output-dir", type=str, default="results",
        help="Output directory (default: results)"
    )

    args = parser.parse_args()

    # Determine which datasets to run
    if args.all:
        dataset_keys = ALL_DATASETS
    elif args.quick:
        dataset_keys = QUICK_DATASETS
    elif args.dataset:
        dataset_keys = [args.dataset]
    else:
        parser.error("Specify --dataset, --quick, or --all")

    # Convert keys to full HuggingFace paths
    datasets = [VIDORE_DATASETS[k] for k in dataset_keys]
    logger.info(f"Running on {len(datasets)} dataset(s): {dataset_keys}")

    all_results = {}
    for dataset in datasets:
        try:
            metrics = run_evaluation(
                dataset_name=dataset,
                model_name=args.model,
                two_stage=args.two_stage,
                prefetch_k=args.prefetch_k,
                top_k=args.top_k,
                output_dir=args.output_dir,
            )
            all_results[dataset] = metrics
        except Exception as e:
            logger.error(f"Failed on {dataset}: {e}")
            continue

    # Summary
    if len(all_results) > 1:
        logger.info("\n" + "=" * 60)
        logger.info("SUMMARY")
        logger.info("=" * 60)

        avg_ndcg10 = np.mean([m["ndcg@10"] for m in all_results.values()])
        avg_mrr10 = np.mean([m["mrr@10"] for m in all_results.values()])

        logger.info(f"Average NDCG@10: {avg_ndcg10:.4f}")
        logger.info(f"Average MRR@10: {avg_mrr10:.4f}")


if __name__ == "__main__":
    main()