visual-rag-toolkit 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,83 @@
1
+ # ViDoRe TAT-DQA (Qdrant) — commands
2
+
3
+ ## Environment
4
+
5
+ Either export:
6
+
7
+ ```bash
8
+ export QDRANT_URL="..."
9
+ export QDRANT_API_KEY="..." # optional
10
+ ```
11
+
12
+ Or create a `.env` file in `visual-rag-toolkit/` with the same variables.
13
+
14
+ ## Index + evaluate (single run)
15
+
16
+ This is the “all-in-one” script (indexes, then evaluates once):
17
+
18
+ ```bash
19
+ python -m benchmarks.vidore_tatdqa_test.run_qdrant \
20
+ --dataset vidore/tatdqa_test \
21
+ --collection vidore_tatdqa_test \
22
+ --recreate --index \
23
+ --indexing-threshold 0 \
24
+ --batch-size 6 \
25
+ --upload-batch-size 12 \
26
+ --upload-workers 0 \
27
+ --loader-workers 0 \
28
+ --prefer-grpc \
29
+ --torch-dtype float16 \
30
+ --no-upsert-wait \
31
+ --qdrant-vector-dtype float16
32
+ ```
33
+
34
+ ## Evaluate only (no re-index) — baseline + sweeps
35
+
36
+ These commands assume the Qdrant collection already exists and is populated.
37
+
38
+ ### Baseline: single-stage full MaxSim
39
+
40
+ ```bash
41
+ python -m benchmarks.vidore_tatdqa_test.sweep_eval \
42
+ --dataset vidore/tatdqa_test \
43
+ --collection vidore_tatdqa_test \
44
+ --prefer-grpc \
45
+ --mode single_full \
46
+ --torch-dtype auto \
47
+ --query-batch-size 32 \
48
+ --top-k 10 \
49
+ --out-dir results/sweeps
50
+ ```
51
+
52
+ ### Two-stage sweep (preferred): stage-1 tokens vs tiles, stage-2 full rerank
53
+
54
+ ```bash
55
+ python -m benchmarks.vidore_tatdqa_test.sweep_eval \
56
+ --dataset vidore/tatdqa_test \
57
+ --collection vidore_tatdqa_test \
58
+ --prefer-grpc \
59
+ --mode two_stage \
60
+ --stage1-mode tokens_vs_tiles \
61
+ --prefetch-ks 20,50,100,200,400 \
62
+ --torch-dtype auto \
63
+ --query-batch-size 32 \
64
+ --top-k 10 \
65
+ --out-dir results/sweeps
66
+ ```
67
+
68
+ ### Smoke test (optional): run only N queries
69
+
70
+ ```bash
71
+ python -m benchmarks.vidore_tatdqa_test.sweep_eval \
72
+ --dataset vidore/tatdqa_test \
73
+ --collection vidore_tatdqa_test \
74
+ --prefer-grpc \
75
+ --mode single_full \
76
+ --torch-dtype auto \
77
+ --query-batch-size 32 \
78
+ --top-k 10 \
79
+ --max-queries 50 \
80
+ --out-dir results/sweeps
81
+ ```
82
+
83
+
@@ -0,0 +1,6 @@
1
# Package marker for benchmarks.vidore_tatdqa_test; nothing is re-exported.
__all__ = []
2
+
3
+
4
+
5
+
6
+
@@ -0,0 +1,363 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ import hashlib
5
+ import re
6
+ from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
7
+
8
+
9
@dataclass(frozen=True)
class CorpusDoc:
    """One corpus page/document to index for retrieval."""

    # Deterministic UUID-shaped id derived from the source document id.
    doc_id: str
    # Page image object; the paired loader leaves this as None (images are
    # dropped before iteration there — presumably reloaded later; verify).
    image: Any
    # Metadata stored alongside the vectors (always includes source_doc_id).
    payload: Dict[str, Any]
14
+
15
+
16
@dataclass(frozen=True)
class Query:
    """A retrieval query: stable identifier plus its text."""

    # Query id as found in (or synthesized from) the dataset.
    query_id: str
    # Natural-language query text.
    text: str
20
+
21
+
22
+ def _as_str(x: Any) -> str:
23
+ if x is None:
24
+ return ""
25
+ return str(x)
26
+
27
+
28
+ def _stable_uuid(text: str) -> str:
29
+ hex_str = hashlib.sha256(text.encode("utf-8")).hexdigest()[:32]
30
+ return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
31
+
32
def paired_source_doc_id(row: Mapping[str, Any], idx: int) -> str:
    """Best-effort stable id for a paired-dataset row.

    Prefers the dataset's own `_id`; otherwise synthesizes one from the
    image filename, page number and row index.
    """
    explicit = _as_str(row.get("_id"))
    if explicit:
        return explicit
    filename = _as_str(row.get("image_filename"))
    page = _as_str(row.get("page"))
    return f"{filename}::page={page}::idx={idx}"
39
+
40
+
41
def paired_doc_id(row: Mapping[str, Any], idx: int) -> str:
    """UUID-shaped point id for a paired-dataset row (stable across runs)."""
    source_id = paired_source_doc_id(row, idx)
    return _stable_uuid(source_id)
43
+
44
+
45
def paired_payload(row: Mapping[str, Any], idx: int) -> Dict[str, Any]:
    """Qdrant payload for a paired-dataset row (source/page metadata only)."""
    payload = {key: _as_str(row.get(key)) for key in ("source", "image_filename", "page")}
    payload["source_doc_id"] = paired_source_doc_id(row, idx)
    return payload
52
+
53
+
54
def _normalize_qrels(qrels_rows: Iterable[Mapping[str, Any]]) -> Dict[str, Dict[str, int]]:
    """Convert raw qrels rows into {query_id: {stable_doc_uuid: relevance}}.

    Accepts the common BEIR column spellings for query id, document id and
    score. Rows missing either id are skipped; unparsable scores become 0.
    """
    qrels: Dict[str, Dict[str, int]] = {}
    for row in qrels_rows:
        qid = _as_str(row.get("query-id") or row.get("query_id") or row.get("qid"))
        did = _as_str(row.get("corpus-id") or row.get("corpus_id") or row.get("doc_id") or row.get("did"))
        # Take the first score column actually present. The previous
        # `a or b or c` chain skipped an explicit score of 0 and fell through
        # to the next column, which could fabricate a relevance value.
        score: Any = 0
        for key in ("score", "relevance", "label"):
            if row.get(key) is not None:
                score = row[key]
                break
        try:
            score_int = int(score)
        except (TypeError, ValueError):
            score_int = 0
        if not qid or not did:
            continue
        # Doc ids are indexed under their deterministic UUID, so qrels keys
        # must be mapped the same way to line up with retrieved point ids.
        qrels.setdefault(qid, {})[_stable_uuid(did)] = score_int
    return qrels
68
+
69
+
70
+ def _expect_fields(obj: Any, required: List[str], context: str) -> None:
71
+ missing = [k for k in required if k not in obj]
72
+ if missing:
73
+ raise ValueError(f"{context}: missing required field(s): {missing}. Available: {list(obj.keys())}")
74
+
75
+
76
+ def _extract_beir_splits(ds: Any):
77
+ if isinstance(ds, Mapping) and all(k in ds for k in ("corpus", "queries", "qrels")):
78
+ return ds["corpus"], ds["queries"], ds["qrels"]
79
+ if isinstance(ds, Mapping) and "test" in ds:
80
+ test_split = ds["test"]
81
+ if hasattr(test_split, "column_names"):
82
+ cols = set(test_split.column_names)
83
+ if all(k in cols for k in ("corpus", "queries", "qrels")):
84
+ row = test_split[0]
85
+ return row["corpus"], row["queries"], row["qrels"]
86
+ return None
87
+
88
+
89
+ def _first_split(ds: Any):
90
+ if isinstance(ds, Mapping):
91
+ if "test" in ds:
92
+ return ds["test"]
93
+ return ds[next(iter(ds.keys()))]
94
+ return ds
95
+
96
+
97
+ def _get_config_names(dataset_name: str) -> List[str]:
98
+ from datasets import load_dataset_builder
99
+
100
+ try:
101
+ builder = load_dataset_builder(dataset_name)
102
+ return list(getattr(builder, "builder_configs", {}).keys())
103
+ except Exception:
104
+ return []
105
+
106
+
107
+ def _normalize_dataset_alias(name: str) -> str:
108
+ s = str(name or "").strip().lower()
109
+ s = re.sub(r"[\s\\-]+", "_", s)
110
+ s = re.sub(r"[^a-z0-9_/]+", "", s)
111
+ s = re.sub(r"_+", "_", s)
112
+ return s
113
+
114
+
115
def _resolve_vidore_dataset_name(dataset_name: str) -> str:
    """Map a user-supplied dataset alias to a canonical HF dataset id.

    Resolution order:
      1. empty input -> returned unchanged (raw string);
      2. known historical aliases -> their canonical vidore/* id;
      3. already namespaced ("org/name") -> used as-is;
      4. bare names -> defaulted into the "vidore/" org, preferring "_v2".

    Returns a dataset id string; callers still probe alternates themselves.
    """
    raw = str(dataset_name or "").strip()
    norm = _normalize_dataset_alias(raw)
    if not norm:
        return raw

    # Historical aliases for renamed ViDoRe v2 datasets.
    aliases = {
        "economics_macro_multilingual": "vidore/economics_reports_v2",
        "economics_macro_multilingual_v2": "vidore/economics_reports_v2",
        "economics_macro_multilingual_eng": "vidore/economics_reports_eng_v2",
        "economics_macro_multilingual_eng_v2": "vidore/economics_reports_eng_v2",
        "economics_reports": "vidore/economics_reports_v2",
        "economics_reports_v2": "vidore/economics_reports_v2",
        "economics_reports_eng": "vidore/economics_reports_eng_v2",
        "economics_reports_eng_v2": "vidore/economics_reports_eng_v2",
    }
    if norm in aliases:
        return aliases[norm]

    # NOTE: the previous implementation built a list of alternate "_v2"
    # candidates and carried an `elif norm.startswith("vidore/")` branch that
    # was unreachable (such names contain "/" and hit the first branch), yet
    # only the first candidate was ever returned. The returns below preserve
    # that observable behavior exactly, without the dead code.
    if "/" in norm:
        return norm
    if norm.endswith("_v2"):
        return f"vidore/{norm}"
    return f"vidore/{norm}_v2"
154
+
155
+
156
def _load_dataset_with_beir_config(dataset_name: str, config_names: List[str]):
    """Try each config whose name mentions "beir"; return the first loaded
    dataset that parses into BEIR splits, or None when none does."""
    from datasets import load_dataset

    for config in (n for n in config_names if "beir" in n.lower()):
        try:
            candidate = load_dataset(dataset_name, name=config)
        except Exception:
            continue
        if _extract_beir_splits(candidate) is not None:
            return candidate
    return None
168
+
169
+
170
def _load_beir_from_separate_configs(dataset_name: str, config_names: List[str]):
    """Load BEIR data published as three separate configs (corpus/queries/qrels).

    Returns (corpus_split, queries_split, qrels_split), or None when any of
    the three configs cannot be identified or loaded.
    """
    from datasets import load_dataset

    def _pick(names: List[str]) -> Optional[str]:
        # With no config listing available, optimistically try the
        # conventional first name.
        if not config_names:
            return names[0] if names else None
        for candidate in names:
            if candidate in config_names:
                return candidate
        return None

    chosen = {
        "corpus": _pick(["corpus", "docs"]),
        "queries": _pick(["queries"]),
        "qrels": _pick(["qrels"]),
    }
    if not all(chosen.values()):
        return None

    try:
        loaded = {role: load_dataset(dataset_name, name=cfg) for role, cfg in chosen.items()}
    except Exception:
        return None

    return (
        _first_split(loaded["corpus"]),
        _first_split(loaded["queries"]),
        _first_split(loaded["qrels"]),
    )
195
+
196
+
197
def load_vidore_beir_dataset(dataset_name: str) -> Tuple[List[CorpusDoc], List[Query], Dict[str, Dict[str, int]]]:
    """Load a ViDoRe-v2 / BEIR-style dataset from the Hugging Face Hub.

    Tries the resolved canonical name first, then the raw input, then for each
    a "_v2"-toggled alternate. For each candidate it probes three layouts in
    order: the default config, any "beir"-named config, and three separate
    corpus/queries/qrels configs.

    Returns:
        (corpus_docs, queries, qrels) where qrels maps query_id to
        {stable_doc_uuid: relevance}.

    Raises:
        ImportError: if the `datasets` package is not installed.
        ValueError: if no candidate loads, the data is not BEIR-shaped,
            required id/text/image fields are missing, or qrels parse empty.
    """
    try:
        from datasets import load_dataset
    except ImportError as e:
        raise ImportError("datasets is required. Install with: pip install datasets") from e
    resolved = _resolve_vidore_dataset_name(dataset_name)
    # Build a de-duplicated candidate list: resolved + raw names, each with a
    # "_v2"-toggled alternate when namespaced ("org/name").
    candidates = []
    for cand in [resolved, dataset_name]:
        cand = str(cand or "").strip()
        if not cand:
            continue
        if cand not in candidates:
            candidates.append(cand)
        if "/" in cand:
            repo = cand.rsplit("/", 1)[-1]
            if repo.endswith("_v2"):
                alt = cand[: -(len("_v2"))]
                if alt not in candidates:
                    candidates.append(alt)
            else:
                alt = f"{cand}_v2"
                if alt not in candidates:
                    candidates.append(alt)

    last_err: Optional[Exception] = None
    extracted = None
    used_name = None
    used_configs: List[str] = []
    for name_try in candidates:
        config_names = _get_config_names(name_try)
        # NOTE: only the last candidate's configs survive for error reporting.
        used_configs = config_names
        ds = None
        # Only attempt a bare load when there is a default config (or no
        # config information at all); named-config datasets would fail here.
        if not config_names or "default" in config_names:
            try:
                ds = load_dataset(name_try)
            except Exception as e:
                last_err = e
                ds = None

        extracted = _extract_beir_splits(ds) if ds is not None else None
        if extracted is None:
            # Fallback 1: a config whose name mentions "beir".
            ds_beir = _load_dataset_with_beir_config(name_try, config_names)
            if ds_beir is not None:
                extracted = _extract_beir_splits(ds_beir)
        if extracted is None:
            # Fallback 2: separate corpus/queries/qrels configs.
            extracted = _load_beir_from_separate_configs(name_try, config_names)
        if extracted is not None:
            # Assigned for debugging; not read elsewhere in this function.
            used_name = name_try
            break

    if extracted is None:
        # No configs discovered + a load error suggests the id itself is bad.
        if last_err is not None and not used_configs:
            raise ValueError(
                "Could not load dataset (check the dataset id and HF access). "
                f"Tried: {candidates}."
            ) from last_err
        raise ValueError(
            "Dataset does not look like BEIR/ViDoRe-v2 format. "
            f"Tried: {candidates}. Available configs: {used_configs or 'unknown'}"
        )

    corpus_split, queries_split, qrels_split = extracted

    corpus_docs: List[CorpusDoc] = []
    for row in corpus_split:
        # Accept the common id column spellings, preferring "corpus-id".
        if "corpus-id" in row:
            source_doc_id = _as_str(row["corpus-id"])
        elif "_id" in row:
            source_doc_id = _as_str(row["_id"])
        elif "doc-id" in row:
            source_doc_id = _as_str(row["doc-id"])
        else:
            # Raises with a helpful message listing the available keys.
            _expect_fields(row, ["_id"], context="corpus row")
            source_doc_id = _as_str(row["_id"])
        # Stable UUID so re-indexing produces identical point ids.
        doc_id = _stable_uuid(source_doc_id)
        image = row.get("image") or row.get("page_image") or row.get("document") or row.get("img")
        if image is None:
            raise ValueError("corpus row: missing image field (tried image/page_image/document/img)")
        # Payload keeps all metadata columns except the image variants.
        payload = {
            **{
                k: v
                for k, v in row.items()
                if k != "image" and k != "page_image" and k != "document" and k != "img"
            },
            "source_doc_id": source_doc_id,
        }
        corpus_docs.append(CorpusDoc(doc_id=doc_id, image=image, payload=payload))

    queries: List[Query] = []
    for row in queries_split:
        # Accept the common query-id column spellings, preferring "_id".
        if "_id" in row:
            qid = _as_str(row["_id"])
        elif "query-id" in row:
            qid = _as_str(row["query-id"])
        elif "query_id" in row:
            qid = _as_str(row["query_id"])
        else:
            _expect_fields(row, ["_id"], context="queries row")
            qid = _as_str(row["_id"])
        text = _as_str(row.get("text") or row.get("query") or row.get("question"))
        if not text:
            raise ValueError("queries row: missing text field (tried text/query/question)")
        queries.append(Query(query_id=qid, text=text))

    qrels = _normalize_qrels(qrels_split)
    if not qrels:
        raise ValueError("qrels split parsed to empty mapping; expected non-empty qrels")

    return corpus_docs, queries, qrels
306
+
307
+
308
def load_vidore_paired_dataset(dataset_name: str) -> Tuple[List[CorpusDoc], List[Query], Dict[str, Dict[str, int]]]:
    """
    Load ViDoRe v1-style paired QA datasets.

    Expected shape:
      - single split (usually "test")
      - each row has at least: query + image (+ optional metadata like page/image_filename/source)

    This protocol is "paired": each query is relevant to its paired page image.

    Returns:
        (corpus_docs, queries, qrels) where each query's qrels contain
        exactly its own paired document with relevance 1.

    Raises:
        ImportError: if the `datasets` package is not installed.
        ValueError: if the split lacks a query/question or image column.
    """
    try:
        from datasets import load_dataset
    except ImportError as e:
        raise ImportError("datasets is required. Install with: pip install datasets") from e

    ds = load_dataset(dataset_name, split="test")
    cols = set(ds.column_names)
    if "query" not in cols and "question" not in cols:
        raise ValueError(f"paired dataset: missing query/question column. Got: {sorted(cols)}")
    if "image" not in cols and "page_image" not in cols:
        raise ValueError(f"paired dataset: missing image/page_image column. Got: {sorted(cols)}")

    corpus_docs: List[CorpusDoc] = []
    queries: List[Query] = []
    qrels: Dict[str, Dict[str, int]] = {}

    # Drop image columns before iterating so rows stay cheap to materialize.
    image_cols = [c for c in ("image", "page_image") if c in cols]
    ds_meta = ds.remove_columns(image_cols) if image_cols else ds

    for idx, row in enumerate(ds_meta):
        query_text = _as_str(row.get("query") or row.get("question"))
        # doc_id is derived from row metadata only (ids/filename/page/idx),
        # so dropping the image columns above does not affect it.
        doc_id = paired_doc_id(row, idx)
        query_id = _as_str(row.get("query_id") or row.get("_query_id")) or f"q_{idx}"
        payload = paired_payload(row, idx)

        # NOTE(review): image=None here — the page images stripped above are
        # presumably re-fetched by the indexing side; confirm against callers.
        corpus_docs.append(CorpusDoc(doc_id=doc_id, image=None, payload=payload))
        queries.append(Query(query_id=query_id, text=query_text))
        # Paired protocol: each query is relevant only to its own page.
        qrels[query_id] = {doc_id: 1}

    return corpus_docs, queries, qrels
348
+
349
+
350
def load_vidore_dataset_auto(dataset_name: str) -> Tuple[List[CorpusDoc], List[Query], Dict[str, Dict[str, int]], str]:
    """
    Auto-detect ViDoRe dataset format.

    Tries the BEIR (v2) layout first; a ValueError triggers the fallback to
    the v1 paired layout.

    Returns: (corpus, queries, qrels, protocol)
      protocol in {"beir", "paired"}.
    """
    try:
        loaded = load_vidore_beir_dataset(dataset_name)
        protocol = "beir"
    except ValueError:
        loaded = load_vidore_paired_dataset(dataset_name)
        protocol = "paired"
    corpus, queries, qrels = loaded
    return corpus, queries, qrels, protocol
362
+
363
+
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Dict, List
4
+
5
+
6
+ def _dcg(relevances: List[float]) -> float:
7
+ import math
8
+
9
+ score = 0.0
10
+ for i, rel in enumerate(relevances):
11
+ if rel <= 0:
12
+ continue
13
+ score += (2.0**rel - 1.0) / math.log2(i + 2)
14
+ return score
15
+
16
+
17
def ndcg_at_k(ranking: List[str], qrels: Dict[str, int], k: int) -> float:
    """nDCG@k of *ranking* against graded relevance judgments *qrels*.

    Returns 0.0 when the ideal DCG is zero (no relevant documents).
    """
    gains = [float(qrels.get(doc, 0)) for doc in ranking[:k]]
    ideal = sorted((float(rel) for rel in qrels.values()), reverse=True)[:k]
    ideal_dcg = _dcg(ideal)
    if ideal_dcg <= 0:
        return 0.0
    return _dcg(gains) / ideal_dcg
25
+
26
+
27
def mrr_at_k(ranking: List[str], qrels: Dict[str, int], k: int) -> float:
    """Reciprocal rank of the first relevant doc within the top-k, else 0.0."""
    for rank, doc in enumerate(ranking[:k], start=1):
        if qrels.get(doc, 0) > 0:
            return 1.0 / rank
    return 0.0
32
+
33
+
34
def recall_at_k(ranking: List[str], qrels: Dict[str, int], k: int) -> float:
    """Fraction of relevant docs retrieved in the top-k (0.0 if none relevant)."""
    relevant = {doc for doc, rel in qrels.items() if rel > 0}
    if not relevant:
        return 0.0
    hits = relevant.intersection(ranking[:k])
    return len(hits) / len(relevant)
40
+
41
+
42
+
43
+
44
+