visual-rag-toolkit 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,372 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import time
5
+ import logging
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional, Tuple
8
+
9
+ import numpy as np
10
+
11
+ from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_dataset_auto
12
+ from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
13
+ from visual_rag import VisualEmbedder
14
+ from visual_rag.retrieval import MultiVectorRetriever
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def _maybe_load_dotenv() -> None:
20
+ try:
21
+ from dotenv import load_dotenv
22
+ except ImportError:
23
+ return
24
+ if Path(".env").exists():
25
+ load_dotenv(".env")
26
+
27
+ def _torch_dtype_to_str(dtype) -> str:
28
+ if dtype is None:
29
+ return "auto"
30
+ s = str(dtype)
31
+ return s.replace("torch.", "")
32
+
33
+
34
+ def _parse_torch_dtype(dtype_str: str):
35
+ if dtype_str == "auto":
36
+ return None
37
+ import torch
38
+
39
+ mapping = {
40
+ "float32": torch.float32,
41
+ "float16": torch.float16,
42
+ "bfloat16": torch.bfloat16,
43
+ }
44
+ return mapping[dtype_str]
45
+
46
+
47
+ def _infer_collection_vector_dtype(*, client, collection_name: str) -> Optional[str]:
48
+ try:
49
+ info = client.get_collection(collection_name)
50
+ except Exception:
51
+ return None
52
+ vectors = getattr(getattr(getattr(info, "config", None), "params", None), "vectors", None)
53
+ if not vectors:
54
+ return None
55
+
56
+ initial = None
57
+ if isinstance(vectors, dict):
58
+ initial = vectors.get("initial")
59
+ else:
60
+ try:
61
+ initial = vectors.get("initial")
62
+ except Exception:
63
+ initial = None
64
+
65
+ dt = getattr(initial, "datatype", None) if initial is not None else None
66
+ if dt is None:
67
+ return None
68
+
69
+ s = str(dt).lower()
70
+ if "float16" in s:
71
+ return "float16"
72
+ if "float32" in s:
73
+ return "float32"
74
+ return None
75
+
76
+
77
def _evaluate(
    *,
    retriever: MultiVectorRetriever,
    queries: List,
    qrels: Dict[str, Dict[str, int]],
    top_k: int,
    mode: str,
    stage1_mode: str,
    prefetch_k: int,
    max_queries: int = 0,
    precomputed_query_embeddings: Optional[List[np.ndarray]] = None,
) -> Dict[str, float]:
    """Run retrieval for every query and aggregate ranking/latency metrics.

    Args:
        retriever: Retriever used for search.
        queries: Query objects exposing ``.text`` and ``.query_id``.
        qrels: Relevance judgments keyed by query id, then doc id.
        top_k: Number of results requested per query.
        mode: ``"single_full"`` or ``"two_stage"`` retrieval mode.
        stage1_mode: First-stage scoring mode (two-stage only).
        prefetch_k: First-stage candidate count (two-stage only).
        max_queries: If > 0, evaluate only the first ``max_queries`` queries.
        precomputed_query_embeddings: Optional embeddings aligned index-for-index
            with ``queries``; when given, the per-stage retrievers are called
            directly instead of re-embedding each query.

    Returns:
        Mean nDCG@10, MRR@10, Recall@5/@10, plus mean and p95 latency (ms).
    """
    if max_queries and max_queries > 0:
        queries = queries[:max_queries]

    ndcg_scores: List[float] = []
    mrr_scores: List[float] = []
    recall5_scores: List[float] = []
    recall10_scores: List[float] = []
    latencies_ms: List[float] = []

    # Wrap in a progress bar when tqdm is available; plain iteration otherwise.
    iterator = queries
    try:
        from tqdm import tqdm

        iterator = tqdm(queries, desc=f"🔎 Evaluating ({mode})", unit="q")
    except ImportError:
        pass

    for idx, query in enumerate(iterator):
        started = time.time()
        if precomputed_query_embeddings is None:
            results = retriever.search(
                query=query.text,
                top_k=top_k,
                mode=mode,
                prefetch_k=prefetch_k,
                stage1_mode=stage1_mode,
            )
        else:
            embedding = precomputed_query_embeddings[idx]
            if mode == "single_full":
                results = retriever._single_stage.search(
                    query_embedding=embedding,
                    top_k=top_k,
                    strategy="multi_vector",
                )
            elif mode == "two_stage":
                results = retriever._two_stage.search(
                    query_embedding=embedding,
                    top_k=top_k,
                    prefetch_k=prefetch_k,
                    stage1_mode=stage1_mode,
                )
            else:
                raise ValueError(f"Unsupported mode for precomputed embeddings: {mode}")
        latencies_ms.append((time.time() - started) * 1000.0)

        ranking = [str(hit["id"]) for hit in results]
        relevant = qrels.get(query.query_id, {})

        ndcg_scores.append(ndcg_at_k(ranking, relevant, k=10))
        mrr_scores.append(mrr_at_k(ranking, relevant, k=10))
        recall5_scores.append(recall_at_k(ranking, relevant, k=5))
        recall10_scores.append(recall_at_k(ranking, relevant, k=10))

    return {
        "ndcg@10": float(np.mean(ndcg_scores)),
        "mrr@10": float(np.mean(mrr_scores)),
        "recall@5": float(np.mean(recall5_scores)),
        "recall@10": float(np.mean(recall10_scores)),
        "avg_latency_ms": float(np.mean(latencies_ms)),
        "p95_latency_ms": float(np.percentile(latencies_ms, 95)),
    }
151
+
152
+
153
def main() -> None:
    """CLI entry point: sweep retrieval settings over a ViDoRe collection.

    Loads queries/qrels for the configured dataset, embeds queries (optionally
    pre-batched), runs either a single-stage evaluation or a sweep over
    prefetch_k values for two-stage retrieval, and writes one JSON result
    file per configuration under ``--out-dir/<collection>/``.

    Requires ``QDRANT_URL`` (and optionally ``QDRANT_API_KEY``) in the
    environment or a local ``.env`` file.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="vidore/tatdqa_test")
    parser.add_argument("--collection", type=str, default="vidore_tatdqa_test")
    parser.add_argument("--model", type=str, default="vidore/colSmol-500M")
    parser.add_argument(
        "--torch-dtype",
        type=str,
        default="auto",
        choices=["auto", "float32", "float16", "bfloat16"],
        help="Torch dtype for model weights (default: auto; inferred from collection vector dtype when possible).",
    )
    parser.add_argument("--top-k", type=int, default=10)
    parser.add_argument("--mode", type=str, default="two_stage", choices=["single_full", "two_stage"])
    parser.add_argument(
        "--stage1-mode",
        type=str,
        default="tokens_vs_tiles",
        choices=["pooled_query_vs_tiles", "tokens_vs_tiles", "pooled_query_vs_global"],
    )
    parser.add_argument(
        "--prefetch-ks",
        type=str,
        default="20,50,100,200,400",
        help="Comma-separated list of prefetch_k values (only used for mode=two_stage).",
    )
    parser.add_argument("--prefer-grpc", action="store_true")
    parser.add_argument("--out-dir", type=str, default="results/sweeps")
    parser.add_argument(
        "--max-queries",
        type=int,
        default=0,
        help="Limit number of queries for a quick smoke test (0 = all).",
    )
    parser.add_argument(
        "--sample-queries",
        type=int,
        default=0,
        help="Sample N queries for sweeps (0 = disable).",
    )
    parser.add_argument(
        "--sample-strategy",
        type=str,
        default="head",
        choices=["head", "random"],
        help="How to sample queries when --sample-queries is set.",
    )
    parser.add_argument(
        "--sample-seed",
        type=int,
        default=42,
        help="Random seed for --sample-strategy random.",
    )
    parser.add_argument(
        "--query-batch-size",
        type=int,
        default=32,
        help="Batch size for embedding queries. Set 0 to disable pre-embedding and embed per query.",
    )
    args = parser.parse_args()

    # Pull QDRANT_* settings from .env when python-dotenv is available.
    _maybe_load_dotenv()

    if not os.getenv("QDRANT_URL"):
        raise ValueError("QDRANT_URL not set. Add it to .env or export it in your shell.")

    logger.info(f"Dataset: {args.dataset}")
    logger.info(f"Collection: {args.collection}")
    logger.info(f"Model: {args.model}")
    logger.info(f"Mode: {args.mode}")
    if args.mode == "two_stage":
        logger.info(f"Stage1 mode: {args.stage1_mode}")
        logger.info(f"Prefetch ks: {args.prefetch_ks}")
    if args.max_queries:
        logger.info(f"Max queries (smoke test): {args.max_queries}")

    # Imported lazily so the argparse help path works without qdrant-client.
    from qdrant_client import QdrantClient

    client = QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY"),
        prefer_grpc=args.prefer_grpc,
        check_compatibility=False,
        timeout=120,
    )

    # "auto" dtype: match the model weights to the collection's stored vector
    # dtype when it can be inferred; otherwise leave it for the embedder.
    requested_torch_dtype = args.torch_dtype
    if requested_torch_dtype == "auto":
        vdt = _infer_collection_vector_dtype(client=client, collection_name=args.collection)
        if vdt == "float16":
            requested_torch_dtype = "float16"
        elif vdt == "float32":
            requested_torch_dtype = "float32"

    torch_dtype = _parse_torch_dtype(requested_torch_dtype)

    corpus, queries, qrels, protocol = load_vidore_dataset_auto(args.dataset)
    # The corpus images are not needed for evaluation (the collection is
    # already indexed); free them immediately.
    del corpus
    logger.info(f"Loaded protocol={protocol}, queries={len(queries)}")

    # Optional truncation/sampling of the query set, applied BEFORE query
    # embedding so pre-computed embeddings stay index-aligned with `queries`.
    if args.max_queries and args.max_queries > 0:
        queries = queries[: args.max_queries]
    if args.sample_queries and args.sample_queries > 0:
        if args.sample_strategy == "head":
            queries = queries[: args.sample_queries]
        else:
            rng = np.random.default_rng(int(args.sample_seed))
            n = min(int(args.sample_queries), len(queries))
            idxs = rng.choice(len(queries), size=n, replace=False).tolist()
            queries = [queries[i] for i in idxs]
    logger.info(f"Eval queries: {len(queries)}")

    embedder = VisualEmbedder(model_name=args.model, torch_dtype=torch_dtype)
    logger.info(f"Effective torch dtype: {_torch_dtype_to_str(embedder.torch_dtype)}")

    retriever = MultiVectorRetriever(
        collection_name=args.collection,
        embedder=embedder,
        prefer_grpc=args.prefer_grpc,
        qdrant_client=client,
    )

    # Pre-embed all queries in batches so each sweep configuration reuses the
    # same embeddings instead of re-running the model per prefetch_k value.
    precomputed_query_embeddings: Optional[List[np.ndarray]] = None
    if args.query_batch_size and args.query_batch_size > 0:
        texts = [q.text for q in queries]
        logger.info(f"Pre-embedding {len(texts)} queries (batch={args.query_batch_size})...")
        q_tensors = embedder.embed_queries(texts, batch_size=args.query_batch_size, show_progress=True)
        precomputed_query_embeddings = [t.detach().cpu().float().numpy() for t in q_tensors]
        # Best-effort GPU memory release on Apple Silicon after embedding.
        try:
            import torch

            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
        except Exception:
            pass

    out_dir = Path(args.out_dir) / args.collection
    out_dir.mkdir(parents=True, exist_ok=True)

    # Single-stage mode: one evaluation, one result file, then exit.
    if args.mode == "single_full":
        metrics = _evaluate(
            retriever=retriever,
            queries=queries,
            qrels=qrels,
            top_k=args.top_k,
            mode="single_full",
            stage1_mode=args.stage1_mode,
            prefetch_k=0,
            max_queries=args.max_queries,
            precomputed_query_embeddings=precomputed_query_embeddings,
        )
        out_path = out_dir / f"{protocol}__single_full__top{args.top_k}.json"
        with open(out_path, "w") as f:
            json.dump(
                {
                    "dataset": args.dataset,
                    "protocol": protocol,
                    "collection": args.collection,
                    "model": args.model,
                    "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
                    "mode": "single_full",
                    "top_k": args.top_k,
                    "max_queries": args.max_queries,
                    "sample_queries": args.sample_queries,
                    "sample_strategy": args.sample_strategy if args.sample_queries else None,
                    "sample_seed": args.sample_seed if args.sample_queries and args.sample_strategy == "random" else None,
                    "metrics": metrics,
                },
                f,
                indent=2,
            )
        print(out_path)
        print(json.dumps(metrics, indent=2))
        return

    # Two-stage mode: sweep over the requested prefetch_k values, writing a
    # separate result file per configuration.
    prefetch_ks = [int(x.strip()) for x in args.prefetch_ks.split(",") if x.strip()]
    for k in prefetch_ks:
        metrics = _evaluate(
            retriever=retriever,
            queries=queries,
            qrels=qrels,
            top_k=args.top_k,
            mode="two_stage",
            stage1_mode=args.stage1_mode,
            prefetch_k=k,
            max_queries=args.max_queries,
            precomputed_query_embeddings=precomputed_query_embeddings,
        )
        out_path = out_dir / f"{protocol}__two_stage__{args.stage1_mode}__prefetch{k}__top{args.top_k}.json"
        with open(out_path, "w") as f:
            json.dump(
                {
                    "dataset": args.dataset,
                    "protocol": protocol,
                    "collection": args.collection,
                    "model": args.model,
                    "mode": "two_stage",
                    "stage1_mode": args.stage1_mode,
                    "prefetch_k": k,
                    "top_k": args.top_k,
                    "torch_dtype": _torch_dtype_to_str(embedder.torch_dtype),
                    "max_queries": args.max_queries,
                    "sample_queries": args.sample_queries,
                    "sample_strategy": args.sample_strategy if args.sample_queries else None,
                    "sample_seed": args.sample_seed if args.sample_queries and args.sample_strategy == "random" else None,
                    "metrics": metrics,
                },
                f,
                indent=2,
            )
        print(out_path)
        print(json.dumps(metrics, indent=2))
367
+
368
+
369
# Script entry point: run the sweep when executed directly.
if __name__ == "__main__":
    main()
371
+
372
+
demo/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ """
2
+ Demo application for Visual RAG Toolkit.
3
+
4
+ A Streamlit-based UI for:
5
+ - Uploading and indexing PDFs
6
+ - Running benchmark evaluations
7
+ - Interactive playground for visual search
8
+ """
9
+
10
+ __version__ = "0.1.0"
demo/app.py ADDED
@@ -0,0 +1,45 @@
1
+ """Main entry point for the Visual RAG Toolkit demo application."""
2
+
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ ROOT_DIR = Path(__file__).parent.parent
7
+ sys.path.insert(0, str(ROOT_DIR))
8
+
9
+ from dotenv import load_dotenv
10
+ load_dotenv(ROOT_DIR / ".env")
11
+
12
+ import streamlit as st
13
+
14
+ st.set_page_config(
15
+ page_title="Visual RAG Toolkit",
16
+ page_icon="🔬",
17
+ layout="wide",
18
+ initial_sidebar_state="expanded",
19
+ )
20
+
21
+ from demo.ui.header import render_header
22
+ from demo.ui.sidebar import render_sidebar
23
+ from demo.ui.upload import render_upload_tab
24
+ from demo.ui.playground import render_playground_tab
25
+ from demo.ui.benchmark import render_benchmark_tab
26
+
27
+
28
def main():
    """Render the demo page: header, sidebar, and the three content tabs."""
    render_header()
    render_sidebar()

    tabs = st.tabs(["📤 Upload", "🎮 Playground", "📊 Benchmarking"])
    renderers = (render_upload_tab, render_playground_tab, render_benchmark_tab)
    for tab, render in zip(tabs, renderers):
        with tab:
            render()
42
+
43
+
44
# Script entry point: render the app when run via `streamlit run demo/app.py`.
if __name__ == "__main__":
    main()