visual-rag-toolkit 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0

benchmarks/vidore_tatdqa_test/COMMANDS.md
@@ -0,0 +1,83 @@

# ViDoRe TAT-DQA (Qdrant) — commands

## Environment

Either export:

```bash
export QDRANT_URL="..."
export QDRANT_API_KEY="..."  # optional
```

Or create a `.env` file in `visual-rag-toolkit/` with the same variables.
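
For reference, a minimal `.env` sketch (values are placeholders; `http://localhost:6333` is Qdrant's default REST address, and the API key can be omitted for unsecured local instances):

```bash
# visual-rag-toolkit/.env
QDRANT_URL="http://localhost:6333"
QDRANT_API_KEY="..."
```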

## Index + evaluate (single run)

This is the “all-in-one” script (indexes, then evaluates once):

```bash
python -m benchmarks.vidore_tatdqa_test.run_qdrant \
  --dataset vidore/tatdqa_test \
  --collection vidore_tatdqa_test \
  --recreate --index \
  --indexing-threshold 0 \
  --batch-size 6 \
  --upload-batch-size 12 \
  --upload-workers 0 \
  --loader-workers 0 \
  --prefer-grpc \
  --torch-dtype float16 \
  --no-upsert-wait \
  --qdrant-vector-dtype float16
```

## Evaluate only (no re-index) — baseline + sweeps

These commands assume the Qdrant collection already exists and is populated.

### Baseline: single-stage full MaxSim

```bash
python -m benchmarks.vidore_tatdqa_test.sweep_eval \
  --dataset vidore/tatdqa_test \
  --collection vidore_tatdqa_test \
  --prefer-grpc \
  --mode single_full \
  --torch-dtype auto \
  --query-batch-size 32 \
  --top-k 10 \
  --out-dir results/sweeps
```

### Two-stage sweep (preferred): stage-1 tokens vs tiles, stage-2 full rerank

```bash
python -m benchmarks.vidore_tatdqa_test.sweep_eval \
  --dataset vidore/tatdqa_test \
  --collection vidore_tatdqa_test \
  --prefer-grpc \
  --mode two_stage \
  --stage1-mode tokens_vs_tiles \
  --prefetch-ks 20,50,100,200,400 \
  --torch-dtype auto \
  --query-batch-size 32 \
  --top-k 10 \
  --out-dir results/sweeps
```

### Smoke test (optional): run only N queries

```bash
python -m benchmarks.vidore_tatdqa_test.sweep_eval \
  --dataset vidore/tatdqa_test \
  --collection vidore_tatdqa_test \
  --prefer-grpc \
  --mode single_full \
  --torch-dtype auto \
  --query-batch-size 32 \
  --top-k 10 \
  --max-queries 50 \
  --out-dir results/sweeps
```

benchmarks/vidore_tatdqa_test/dataset_loader.py
@@ -0,0 +1,363 @@

from __future__ import annotations

from dataclasses import dataclass
import hashlib
import re
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple


@dataclass(frozen=True)
class CorpusDoc:
    doc_id: str
    image: Any
    payload: Dict[str, Any]


@dataclass(frozen=True)
class Query:
    query_id: str
    text: str


def _as_str(x: Any) -> str:
    if x is None:
        return ""
    return str(x)


def _stable_uuid(text: str) -> str:
    # Derive a deterministic UUID-formatted id from arbitrary text
    # (Qdrant point ids must be UUIDs or unsigned integers).
    hex_str = hashlib.sha256(text.encode("utf-8")).hexdigest()[:32]
    return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
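
# Example (a sketch; the id string is illustrative): calling
#   _stable_uuid("report.pdf::page=3")
# always returns the same UUID-shaped string, so the point id written at
# indexing time and the id derived from raw qrels in _normalize_qrels()
# below agree without any stored lookup table.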


def paired_source_doc_id(row: Mapping[str, Any], idx: int) -> str:
    source_doc_id = _as_str(row.get("_id"))
    if source_doc_id:
        return source_doc_id
    image_filename = _as_str(row.get("image_filename"))
    page = _as_str(row.get("page"))
    return f"{image_filename}::page={page}::idx={idx}"


def paired_doc_id(row: Mapping[str, Any], idx: int) -> str:
    return _stable_uuid(paired_source_doc_id(row, idx))


def paired_payload(row: Mapping[str, Any], idx: int) -> Dict[str, Any]:
    return {
        "source": _as_str(row.get("source")),
        "image_filename": _as_str(row.get("image_filename")),
        "page": _as_str(row.get("page")),
        "source_doc_id": paired_source_doc_id(row, idx),
    }


def _normalize_qrels(qrels_rows: Iterable[Mapping[str, Any]]) -> Dict[str, Dict[str, int]]:
    qrels: Dict[str, Dict[str, int]] = {}
    for row in qrels_rows:
        qid = _as_str(row.get("query-id") or row.get("query_id") or row.get("qid"))
        did = _as_str(row.get("corpus-id") or row.get("corpus_id") or row.get("doc_id") or row.get("did"))
        score = row.get("score") or row.get("relevance") or row.get("label") or 0
        try:
            score_int = int(score)
        except Exception:
            score_int = 0
        if not qid or not did:
            continue
        qrels.setdefault(qid, {})[_stable_uuid(did)] = score_int
    return qrels
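
# For illustration (assumed row shapes; the field aliases above behave alike):
#   {"query-id": "q1", "corpus-id": "d1", "score": 1}
#   {"query_id": "q1", "doc_id": "d1", "relevance": 1}
# both normalize to {"q1": {_stable_uuid("d1"): 1}}.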


def _expect_fields(obj: Any, required: List[str], context: str) -> None:
    missing = [k for k in required if k not in obj]
    if missing:
        raise ValueError(f"{context}: missing required field(s): {missing}. Available: {list(obj.keys())}")


def _extract_beir_splits(ds: Any):
    if isinstance(ds, Mapping) and all(k in ds for k in ("corpus", "queries", "qrels")):
        return ds["corpus"], ds["queries"], ds["qrels"]
    if isinstance(ds, Mapping) and "test" in ds:
        test_split = ds["test"]
        if hasattr(test_split, "column_names"):
            cols = set(test_split.column_names)
            if all(k in cols for k in ("corpus", "queries", "qrels")):
                row = test_split[0]
                return row["corpus"], row["queries"], row["qrels"]
    return None


def _first_split(ds: Any):
    if isinstance(ds, Mapping):
        if "test" in ds:
            return ds["test"]
        return ds[next(iter(ds.keys()))]
    return ds


def _get_config_names(dataset_name: str) -> List[str]:
    from datasets import load_dataset_builder

    try:
        builder = load_dataset_builder(dataset_name)
        return list(getattr(builder, "builder_configs", {}).keys())
    except Exception:
        return []


def _normalize_dataset_alias(name: str) -> str:
    s = str(name or "").strip().lower()
    s = re.sub(r"[\s\-]+", "_", s)
    s = re.sub(r"[^a-z0-9_/]+", "", s)
    s = re.sub(r"_+", "_", s)
    return s


def _resolve_vidore_dataset_name(dataset_name: str) -> str:
    raw = str(dataset_name or "").strip()
    norm = _normalize_dataset_alias(raw)
    if not norm:
        return raw

    aliases = {
        "economics_macro_multilingual": "vidore/economics_reports_v2",
        "economics_macro_multilingual_v2": "vidore/economics_reports_v2",
        "economics_macro_multilingual_eng": "vidore/economics_reports_eng_v2",
        "economics_macro_multilingual_eng_v2": "vidore/economics_reports_eng_v2",
        "economics_reports": "vidore/economics_reports_v2",
        "economics_reports_v2": "vidore/economics_reports_v2",
        "economics_reports_eng": "vidore/economics_reports_eng_v2",
        "economics_reports_eng_v2": "vidore/economics_reports_eng_v2",
    }
    if norm in aliases:
        return aliases[norm]

    candidates: List[str] = []
    if "/" in norm:
        candidates.append(norm)
        repo = norm.rsplit("/", 1)[-1]
        if repo.endswith("_v2"):
            candidates.append(norm[: -len("_v2")])
        else:
            candidates.append(f"{norm}_v2")
    else:
        # Bare name: assume the vidore org, preferring the _v2 repo.
        if norm.endswith("_v2"):
            candidates.append(f"vidore/{norm}")
        else:
            candidates.append(f"vidore/{norm}_v2")
            candidates.append(f"vidore/{norm}")

    return candidates[0] if candidates else raw
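
# For illustration (assumed inputs): "Economics Reports" and
# "economics-reports-v2" both resolve to "vidore/economics_reports_v2" via
# the alias table; an unlisted bare name such as "tatdqa_test" resolves to
# "vidore/tatdqa_test_v2", with "vidore/tatdqa_test" kept as a fallback
# candidate by load_vidore_beir_dataset() below.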


def _load_dataset_with_beir_config(dataset_name: str, config_names: List[str]):
    from datasets import load_dataset

    preferred = [n for n in config_names if "beir" in n.lower()]
    for name in preferred:
        try:
            ds = load_dataset(dataset_name, name=name)
        except Exception:
            continue
        if _extract_beir_splits(ds) is not None:
            return ds
    return None


def _load_beir_from_separate_configs(dataset_name: str, config_names: List[str]):
    from datasets import load_dataset

    def _pick(names: List[str]) -> Optional[str]:
        if not config_names:
            return names[0] if names else None
        for name in names:
            if name in config_names:
                return name
        return None

    corpus_name = _pick(["corpus", "docs"])
    queries_name = _pick(["queries"])
    qrels_name = _pick(["qrels"])
    if not corpus_name or not queries_name or not qrels_name:
        return None

    try:
        corpus_ds = load_dataset(dataset_name, name=corpus_name)
        queries_ds = load_dataset(dataset_name, name=queries_name)
        qrels_ds = load_dataset(dataset_name, name=qrels_name)
    except Exception:
        return None

    return _first_split(corpus_ds), _first_split(queries_ds), _first_split(qrels_ds)


def load_vidore_beir_dataset(dataset_name: str) -> Tuple[List[CorpusDoc], List[Query], Dict[str, Dict[str, int]]]:
    try:
        from datasets import load_dataset
    except ImportError as e:
        raise ImportError("datasets is required. Install with: pip install datasets") from e

    resolved = _resolve_vidore_dataset_name(dataset_name)
    candidates = []
    for cand in [resolved, dataset_name]:
        cand = str(cand or "").strip()
        if not cand:
            continue
        if cand not in candidates:
            candidates.append(cand)
        if "/" in cand:
            repo = cand.rsplit("/", 1)[-1]
            if repo.endswith("_v2"):
                alt = cand[: -len("_v2")]
                if alt not in candidates:
                    candidates.append(alt)
            else:
                alt = f"{cand}_v2"
                if alt not in candidates:
                    candidates.append(alt)

    last_err: Optional[Exception] = None
    extracted = None
    used_name = None
    used_configs: List[str] = []
    for name_try in candidates:
        config_names = _get_config_names(name_try)
        used_configs = config_names
        ds = None
        if not config_names or "default" in config_names:
            try:
                ds = load_dataset(name_try)
            except Exception as e:
                last_err = e
                ds = None

        extracted = _extract_beir_splits(ds) if ds is not None else None
        if extracted is None:
            ds_beir = _load_dataset_with_beir_config(name_try, config_names)
            if ds_beir is not None:
                extracted = _extract_beir_splits(ds_beir)
        if extracted is None:
            extracted = _load_beir_from_separate_configs(name_try, config_names)
        if extracted is not None:
            used_name = name_try
            break

    if extracted is None:
        if last_err is not None and not used_configs:
            raise ValueError(
                "Could not load dataset (check the dataset id and HF access). "
                f"Tried: {candidates}."
            ) from last_err
        raise ValueError(
            "Dataset does not look like BEIR/ViDoRe-v2 format. "
            f"Tried: {candidates}. Available configs: {used_configs or 'unknown'}"
        )

    corpus_split, queries_split, qrels_split = extracted

    corpus_docs: List[CorpusDoc] = []
    for row in corpus_split:
        if "corpus-id" in row:
            source_doc_id = _as_str(row["corpus-id"])
        elif "_id" in row:
            source_doc_id = _as_str(row["_id"])
        elif "doc-id" in row:
            source_doc_id = _as_str(row["doc-id"])
        else:
            _expect_fields(row, ["_id"], context="corpus row")
            source_doc_id = _as_str(row["_id"])
        doc_id = _stable_uuid(source_doc_id)
        image = row.get("image") or row.get("page_image") or row.get("document") or row.get("img")
        if image is None:
            raise ValueError("corpus row: missing image field (tried image/page_image/document/img)")
        payload = {
            **{
                k: v
                for k, v in row.items()
                if k not in ("image", "page_image", "document", "img")
            },
            "source_doc_id": source_doc_id,
        }
        corpus_docs.append(CorpusDoc(doc_id=doc_id, image=image, payload=payload))

    queries: List[Query] = []
    for row in queries_split:
        if "_id" in row:
            qid = _as_str(row["_id"])
        elif "query-id" in row:
            qid = _as_str(row["query-id"])
        elif "query_id" in row:
            qid = _as_str(row["query_id"])
        else:
            _expect_fields(row, ["_id"], context="queries row")
            qid = _as_str(row["_id"])
        text = _as_str(row.get("text") or row.get("query") or row.get("question"))
        if not text:
            raise ValueError("queries row: missing text field (tried text/query/question)")
        queries.append(Query(query_id=qid, text=text))

    qrels = _normalize_qrels(qrels_split)
    if not qrels:
        raise ValueError("qrels split parsed to empty mapping; expected non-empty qrels")

    return corpus_docs, queries, qrels


def load_vidore_paired_dataset(dataset_name: str) -> Tuple[List[CorpusDoc], List[Query], Dict[str, Dict[str, int]]]:
    """
    Load ViDoRe v1-style paired QA datasets.

    Expected shape:
    - single split (usually "test")
    - each row has at least: query + image (+ optional metadata like page/image_filename/source)

    This protocol is "paired": each query is relevant to its paired page image.
    """
    try:
        from datasets import load_dataset
    except ImportError as e:
        raise ImportError("datasets is required. Install with: pip install datasets") from e

    ds = load_dataset(dataset_name, split="test")
    cols = set(ds.column_names)
    if "query" not in cols and "question" not in cols:
        raise ValueError(f"paired dataset: missing query/question column. Got: {sorted(cols)}")
    if "image" not in cols and "page_image" not in cols:
        raise ValueError(f"paired dataset: missing image/page_image column. Got: {sorted(cols)}")

    corpus_docs: List[CorpusDoc] = []
    queries: List[Query] = []
    qrels: Dict[str, Dict[str, int]] = {}

    # Drop image columns before iterating so rows decode quickly; docs built
    # here carry image=None.
    image_cols = [c for c in ("image", "page_image") if c in cols]
    ds_meta = ds.remove_columns(image_cols) if image_cols else ds

    for idx, row in enumerate(ds_meta):
        query_text = _as_str(row.get("query") or row.get("question"))
        doc_id = paired_doc_id(row, idx)
        query_id = _as_str(row.get("query_id") or row.get("_query_id")) or f"q_{idx}"
        payload = paired_payload(row, idx)

        corpus_docs.append(CorpusDoc(doc_id=doc_id, image=None, payload=payload))
        queries.append(Query(query_id=query_id, text=query_text))
        qrels[query_id] = {doc_id: 1}

    return corpus_docs, queries, qrels
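
# For illustration: a paired split whose rows look like
#   {"query": "What was the revenue in 2019?", "image": <PIL image>, ...}
# yields one query per row and qrels of the form
#   {"q_0": {paired_doc_id(row, 0): 1}},
# i.e. exactly one relevant page image per query.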


def load_vidore_dataset_auto(dataset_name: str) -> Tuple[List[CorpusDoc], List[Query], Dict[str, Dict[str, int]], str]:
    """
    Auto-detect ViDoRe dataset format.

    Returns: (corpus, queries, qrels, protocol)
    protocol in {"beir", "paired"}.
    """
    try:
        corpus, queries, qrels = load_vidore_beir_dataset(dataset_name)
        return corpus, queries, qrels, "beir"
    except ValueError:
        corpus, queries, qrels = load_vidore_paired_dataset(dataset_name)
        return corpus, queries, qrels, "paired"
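
# A minimal usage sketch (dataset id taken from COMMANDS.md above):
#
#   corpus, queries, qrels, protocol = load_vidore_dataset_auto("vidore/tatdqa_test")
#   print(protocol, len(corpus), len(queries), len(qrels))
#
# run_qdrant.py and sweep_eval.py in this package presumably consume these
# three structures when indexing and scoring.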

benchmarks/vidore_tatdqa_test/metrics.py
@@ -0,0 +1,44 @@

from __future__ import annotations

import math
from typing import Dict, List


def _dcg(relevances: List[float]) -> float:
    # Exponentiated-gain DCG: gain = 2^rel - 1, discount = log2(rank + 1).
    score = 0.0
    for i, rel in enumerate(relevances):
        if rel <= 0:
            continue
        score += (2.0**rel - 1.0) / math.log2(i + 2)
    return score


def ndcg_at_k(ranking: List[str], qrels: Dict[str, int], k: int) -> float:
    rels = [float(qrels.get(doc_id, 0)) for doc_id in ranking[:k]]
    dcg = _dcg(rels)
    ideal_rels = sorted((float(v) for v in qrels.values()), reverse=True)[:k]
    idcg = _dcg(ideal_rels)
    if idcg <= 0:
        return 0.0
    return dcg / idcg


def mrr_at_k(ranking: List[str], qrels: Dict[str, int], k: int) -> float:
    for i, doc_id in enumerate(ranking[:k]):
        if qrels.get(doc_id, 0) > 0:
            return 1.0 / (i + 1)
    return 0.0


def recall_at_k(ranking: List[str], qrels: Dict[str, int], k: int) -> float:
    relevant = {doc_id for doc_id, rel in qrels.items() if rel > 0}
    if not relevant:
        return 0.0
    retrieved = set(ranking[:k])
    return len(retrieved & relevant) / len(relevant)
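
# Worked example (numbers check out by hand): with qrels {"d1": 1} and
# ranking ["d2", "d1"] at k=10:
#   ndcg_at_k   -> DCG = (2**1 - 1) / log2(3) ≈ 0.631, IDCG = 1.0, so ≈ 0.631
#   mrr_at_k    -> first relevant hit at rank 2, so 0.5
#   recall_at_k -> 1 of 1 relevant docs retrieved, so 1.0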