visual_rag_toolkit-0.1.1-py3-none-any.whl

This diff shows the content of a publicly available package version as released to a supported public registry. It is provided for informational purposes only and reflects the package exactly as it appears in that registry.
Files changed (59)
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
demo/ui/benchmark.py ADDED
@@ -0,0 +1,355 @@
+ """Benchmark tab component."""
+
+ from pathlib import Path
+ from typing import Any, Dict, List
+
+ import altair as alt
+ import pandas as pd
+ import streamlit as st
+
+ from demo.config import (
+     AVAILABLE_MODELS,
+     BENCHMARK_DATASETS,
+     DATASET_STATS,
+     RETRIEVAL_MODES,
+     STAGE1_MODES,
+ )
+ from demo.qdrant_utils import get_qdrant_credentials, get_collections
+ from demo.commands import build_index_command, build_eval_command, generate_python_eval_code, generate_python_index_code
+ from demo.results import get_available_results, load_results_file
+ from demo.evaluation import run_evaluation_with_ui
+ from demo.indexing import run_indexing_with_ui
+
+
+ def render_benchmark_tab():
+     st.subheader("📊 Benchmarking")
+
+     tab_index, tab_eval, tab_results = st.tabs(["Indexing", "Evaluation", "Results"])
+
+     url, api_key = get_qdrant_credentials()
+     collections = get_collections(url, api_key)
+
+     with tab_index:
+         render_benchmark_indexing(collections)
+
+     with tab_eval:
+         render_benchmark_evaluation(collections)
+
+     with tab_results:
+         render_benchmark_results()
+
+
+ def render_benchmark_indexing(collections: List[str]):
+     st.caption("Create a new collection with benchmark datasets")
+
+     c1, c2, c3 = st.columns(3)
+     with c1:
+         datasets = st.multiselect("Datasets", BENCHMARK_DATASETS, default=BENCHMARK_DATASETS, key="bi_ds")
+     with c2:
+         model = st.selectbox("Model", AVAILABLE_MODELS, key="bi_model")
+     with c3:
+         model_short = model.split("/")[-1].replace("-", "_").replace(".", "_")
+         collection = st.text_input("New Collection Name", value=f"vidore_{len(datasets)}ds__{model_short}", key="bi_coll")
+
+     sel_docs = sum(DATASET_STATS.get(d, {}).get("docs", 0) for d in datasets)
+     sel_queries = sum(DATASET_STATS.get(d, {}).get("queries", 0) for d in datasets)
+     st.markdown(f"🎯 **Selected:** {len(datasets)} dataset(s) — **{sel_docs:,}** docs, **{sel_queries:,}** queries")
+
+     c4, c5, c6, c7 = st.columns(4)
+     with c4:
+         crop = st.toggle("Crop", value=True, key="bi_crop")
+     with c5:
+         cloudinary = st.toggle("Cloudinary", value=True, key="bi_cloud")
+     with c6:
+         grpc = st.toggle("gRPC", value=True, key="bi_grpc")
+     with c7:
+         recreate = st.toggle("Recreate", value=False, key="bi_recreate")
+
+     crop_pct = st.slider("Crop %", 0.8, 0.99, 0.99, 0.01, key="bi_crop_pct") if crop else 0.99
+
+     st.markdown("---")
+
+     col_max, col_batch, col_torch, col_qdrant = st.columns([2, 2, 1, 1])
+     with col_max:
+         max_docs_val = max(sel_docs, 1)
+         max_docs = st.number_input(
+             "Max Docs (per dataset)",
+             min_value=1,
+             max_value=max_docs_val,
+             value=max_docs_val,
+             key="bi_max_docs",
+             help="Limit docs per dataset. Useful for quick tests."
+         )
+     with col_batch:
+         batch_size = st.number_input("Batch Size", min_value=1, max_value=16, value=4, key="bi_batch")
+     with col_torch:
+         torch_dtype = st.selectbox("Torch dtype", ["float16", "float32"], index=0, key="bi_torch_dtype")
+     with col_qdrant:
+         qdrant_dtype = st.selectbox("Qdrant dtype", ["float16", "float32"], index=0, key="bi_qdrant_dtype")
+
+     effective_docs = min(max_docs * len(datasets), sel_docs) if max_docs < max_docs_val else sel_docs
+
+     config = {
+         "datasets": datasets, "model": model, "collection": collection,
+         "crop_empty": crop, "crop_percentage": crop_pct,
+         "no_cloudinary": not cloudinary, "recreate": recreate, "resume": False,
+         "prefer_grpc": grpc, "batch_size": batch_size, "upload_batch_size": 8,
+         "qdrant_timeout": 180, "qdrant_retries": 5,
+         "torch_dtype": torch_dtype, "qdrant_vector_dtype": qdrant_dtype,
+         "max_docs": max_docs if max_docs < max_docs_val else None,
+     }
+
+     cmd = build_index_command(config)
+     python_code = generate_python_index_code(config)
+
+     col_cmd, col_info = st.columns([2, 1])
+     with col_cmd:
+         code_tab1, code_tab2 = st.tabs(["🐚 Bash", "🐍 Python"])
+         with code_tab1:
+             st.code(cmd, language="bash")
+         with code_tab2:
+             st.code(python_code, language="python")
+     with col_info:
+         st.markdown("<br><br><br>", unsafe_allow_html=True)
+
+         st.metric("Docs to Index", f"{effective_docs:,}")
+         st.metric("Model", model.split("/")[-1])
+         if effective_docs < sel_docs:
+             st.caption(f"Limited from {sel_docs:,} total")
+         st.divider()
+         run_index = st.button("🚀 Run Index", type="primary", key="bi_run", use_container_width=True)
+
+     if run_index:
+         if not collection:
+             st.error("Please provide a collection name")
+         elif not datasets:
+             st.error("Please select at least one dataset")
+         else:
+             run_indexing_with_ui(config)
+
+
+ def render_benchmark_evaluation(collections: List[str]):
+     collection = st.session_state.get("active_collection")
+
+     if not collection:
+         st.warning("⚠️ Select a collection from the sidebar first")
+         return
+
+     st.info(f"**Collection:** `{collection}` (from sidebar)")
+
+     all_docs = sum(DATASET_STATS.get(d, {}).get("docs", 0) for d in BENCHMARK_DATASETS)
+     all_queries = sum(DATASET_STATS.get(d, {}).get("queries", 0) for d in BENCHMARK_DATASETS)
+     st.markdown(f"📊 **Available:** {len(BENCHMARK_DATASETS)} datasets — **{all_docs:,}** docs, **{all_queries:,}** queries")
+
+     c1, c2 = st.columns([3, 1])
+     with c1:
+         st.multiselect("Datasets", BENCHMARK_DATASETS, default=BENCHMARK_DATASETS, key="be_ds")
+     with c2:
+         model = st.selectbox("Model", AVAILABLE_MODELS, key="be_model")
+
+     datasets = st.session_state.get("be_ds", BENCHMARK_DATASETS)
+     sel_docs = sum(DATASET_STATS.get(d, {}).get("docs", 0) for d in datasets)
+     sel_queries = sum(DATASET_STATS.get(d, {}).get("queries", 0) for d in datasets)
+     st.markdown(f"🎯 **Selected:** {len(datasets)} dataset(s) — **{sel_docs:,}** docs, **{sel_queries:,}** queries")
+
+     st.markdown("---")
+
+     col_mode, col_topk = st.columns([2, 1])
+     with col_mode:
+         mode = st.selectbox("Mode", RETRIEVAL_MODES, key="be_mode")
+     with col_topk:
+         top_k = st.slider("Top K", 10, 100, 100, key="be_topk")
+
+     stage1_mode, prefetch_k, stage1_k, stage2_k = "tokens_vs_tiles", 256, 1000, 300
+
+     if mode == "two_stage":
+         cc1, cc2 = st.columns(2)
+         with cc1:
+             stage1_mode = st.selectbox("Stage1 Mode", STAGE1_MODES, key="be_s1mode")
+         with cc2:
+             prefetch_k = st.slider("Prefetch K", 50, 1000, 256, key="be_pk")
+     elif mode == "three_stage":
+         cc1, cc2 = st.columns(2)
+         with cc1:
+             stage1_k = st.number_input("Stage1 K", 100, 5000, 1000, key="be_s1k")
+         with cc2:
+             stage2_k = st.number_input("Stage2 K", 50, 1000, 300, key="be_s2k")
+
+     st.markdown("---")
+
+     col_scope, _, col_grpc, col_nq = st.columns([2, 0.5, 1, 2])
+     with col_scope:
+         scope = st.selectbox("Scope", ["union", "per_dataset"], key="be_scope")
+     with col_grpc:
+         st.write("")
+         st.write("")
+         grpc = st.toggle("gRPC", value=True, key="be_grpc")
+     with col_nq:
+         max_q_val = max(sel_queries, 1)
+         max_queries = st.number_input(
+             "Max Queries",
+             min_value=1,
+             max_value=max_q_val,
+             value=max_q_val,
+             key="be_max_queries",
+             help="Limit number of queries to evaluate (useful for quick tests)"
+         )
+
+     result_prefix_val = st.session_state.get("be_prefix", "")
+
+     config = {
+         "datasets": datasets, "model": model, "collection": collection,
+         "mode": mode, "top_k": top_k, "evaluation_scope": scope,
+         "prefer_grpc": grpc,
+         "torch_dtype": "float16",
+         "qdrant_vector_dtype": "float16",
+         "qdrant_timeout": 180,
+         "stage1_mode": stage1_mode, "prefetch_k": prefetch_k,
+         "stage1_k": stage1_k, "stage2_k": stage2_k,
+         "result_prefix": result_prefix_val,
+         "max_queries": max_queries,
+     }
+
+     cmd = build_eval_command(config)
+
+     python_code = generate_python_eval_code(config)
+
+     col_cmd, col_info = st.columns([2, 1])
+     with col_cmd:
+         code_tab1, code_tab2 = st.tabs(["🐚 Bash", "🐍 Python"])
+         with code_tab1:
+             st.code(cmd, language="bash")
+         with code_tab2:
+             st.code(python_code, language="python")
+     with col_info:
+         st.markdown("<br><br><br>", unsafe_allow_html=True)
+
+         mode_desc = {
+             "single_full": "🔹 **Single Full**: Query all visual tokens against full document embeddings in one pass.",
+             "single_tiles": "🔸 **Single Tiles**: Query against tile-level embeddings only.",
+             "single_global": "🔶 **Single Global**: Query against global (pooled) document embeddings.",
+             "two_stage": "🔷 **Two Stage**: Fast prefetch with global/tiles, then rerank with full tokens.",
+             "three_stage": "🔶 **Three Stage**: Global → Tiles → Full tokens for maximum precision.",
+         }
+         scope_desc = {
+             "union": "📊 **Union**: Evaluate across all datasets combined as one corpus.",
+             "per_dataset": "📁 **Per Dataset**: Evaluate each dataset separately and report individual metrics.",
+         }
+         st.markdown(mode_desc.get(mode, ""))
+         st.markdown(scope_desc.get(scope, ""))
+         st.divider()
+         st.text_input("Result Prefix", placeholder="optional prefix for output", key="be_prefix")
+
+         run_eval = st.button("🚀 Run Eval", type="primary", key="be_run", use_container_width=True)
+
+     if run_eval:
+         if not collection:
+             st.error("Please select a collection first")
+         else:
+             run_evaluation_with_ui(config)
+
+
+ def render_benchmark_results():
+     st.markdown("##### Load Results")
+
+     available = get_available_results()
+
+     if not available:
+         st.info("No results found")
+         return
+
+     default_select = []
+     if st.session_state.get("auto_select_result"):
+         auto = st.session_state.pop("auto_select_result")
+         if auto in [str(p) for p in available]:
+             default_select = [auto]
+
+     selected = st.multiselect(
+         "Result files",
+         options=[str(p) for p in available],
+         format_func=lambda x: Path(x).name[:60],
+         default=default_select,
+         key="br_files",
+     )
+
+     for path in selected:
+         data = load_results_file(Path(path))
+         if data:
+             render_result_card(data, Path(path).name)
+
+
+ def render_result_card(data: Dict[str, Any], filename: str):
+     with st.expander(f"📊 {filename[:50]}", expanded=True):
+         c1, c2, c3, c4 = st.columns(4)
+         c1.metric("Model", (data.get("model") or "?").split("/")[-1])
+         c2.metric("Mode", data.get("mode", "?"))
+         c3.metric("Top K", data.get("top_k", "?"))
+         c4.metric("Time", f"{data.get('eval_wall_time_s', 0):.0f}s")
+
+         metrics = data.get("metrics_by_dataset", {})
+         if not metrics:
+             st.warning("No metrics data")
+             return
+
+         rows = []
+         for ds, m in metrics.items():
+             rows.append({
+                 "Dataset": ds.split("/")[-1].replace("_v2", ""),
+                 "NDCG@5": m.get("ndcg@5", 0),
+                 "NDCG@10": m.get("ndcg@10", 0),
+                 "Recall@5": m.get("recall@5", 0),
+                 "Recall@10": m.get("recall@10", 0),
+                 "MRR@10": m.get("mrr@10", 0),
+                 "Latency": m.get("avg_latency_ms", 0),
+                 "QPS": m.get("qps", 0),
+             })
+
+         df = pd.DataFrame(rows)
+
+         st.dataframe(
+             df.style.format({
+                 "NDCG@5": "{:.4f}", "NDCG@10": "{:.4f}",
+                 "Recall@5": "{:.4f}", "Recall@10": "{:.4f}",
+                 "MRR@10": "{:.4f}", "Latency": "{:.1f}", "QPS": "{:.2f}"
+             }),
+             hide_index=True, use_container_width=True
+         )
+
+         chart_data = []
+         for ds, m in metrics.items():
+             ds_short = ds.split("/")[-1].replace("_v2", "").replace("_", " ").title()
+             chart_data.append({"Dataset": ds_short, "Metric": "NDCG@10", "Value": m.get("ndcg@10", 0)})
+             chart_data.append({"Dataset": ds_short, "Metric": "Recall@10", "Value": m.get("recall@10", 0)})
+             chart_data.append({"Dataset": ds_short, "Metric": "MRR@10", "Value": m.get("mrr@10", 0)})
+
+         chart_df = pd.DataFrame(chart_data)
+
+         chart = alt.Chart(chart_df).mark_bar().encode(
+             x=alt.X("Dataset:N", title=None),
+             y=alt.Y("Value:Q", scale=alt.Scale(domain=[0, 1]), title="Score"),
+             color=alt.Color("Metric:N", scale=alt.Scale(scheme="tableau10")),
+             xOffset="Metric:N",
+             tooltip=["Dataset", "Metric", alt.Tooltip("Value:Q", format=".4f")]
+         ).properties(height=300, title="Metrics by Dataset")
+
+         st.altair_chart(chart, use_container_width=True)
+
+         latency_data = [{"Dataset": ds.split("/")[-1].replace("_v2", ""), "Latency (ms)": m.get("avg_latency_ms", 0), "QPS": m.get("qps", 0)} for ds, m in metrics.items()]
+         latency_df = pd.DataFrame(latency_data)
+
+         c1, c2 = st.columns(2)
+         with c1:
+             lat_chart = alt.Chart(latency_df).mark_bar(color="#ff6b6b").encode(
+                 x=alt.X("Dataset:N"),
+                 y=alt.Y("Latency (ms):Q"),
+                 tooltip=["Dataset", alt.Tooltip("Latency (ms):Q", format=".1f")]
+             ).properties(height=200, title="Avg Latency")
+             st.altair_chart(lat_chart, use_container_width=True)
+
+         with c2:
+             qps_chart = alt.Chart(latency_df).mark_bar(color="#4ecdc4").encode(
+                 x=alt.X("Dataset:N"),
+                 y=alt.Y("QPS:Q"),
+                 tooltip=["Dataset", alt.Tooltip("QPS:Q", format=".2f")]
+             ).properties(height=200, title="QPS (Queries/sec)")
+             st.altair_chart(qps_chart, use_container_width=True)
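Two points in this file are worth illustrating. First, the "two_stage" mode description ("fast prefetch with global/tiles, then rerank with full tokens") maps naturally onto Qdrant's Query API. The sketch below is illustrative only: the collection name and the named vectors "global" and "full", the 128-dim size, and the placeholder embeddings are all assumptions, and the toolkit's own wrapper in visual_rag/retrieval/two_stage.py (not shown in this diff) may differ.

from qdrant_client import QdrantClient, models

client = QdrantClient(url="http://localhost:6333")

# Placeholder query embeddings; in the real pipeline these would come from
# the visual_rag.embedding module (one pooled vector plus per-token vectors).
pooled_query_vec = [0.0] * 128
query_token_vecs = [[0.0] * 128 for _ in range(16)]

hits = client.query_points(
    collection_name="vidore_demo",  # hypothetical collection name
    # Stage 1: cheap ANN prefetch against the pooled single-vector embedding.
    prefetch=models.Prefetch(query=pooled_query_vec, using="global", limit=256),
    # Stage 2: rerank the prefetched candidates with full multi-vector
    # (late-interaction / MaxSim) scoring over all query tokens.
    query=query_token_vecs,
    using="full",
    limit=100,
)

Second, render_result_card() implies the shape a results file must have. The dict below is inferred purely from the .get() calls above; the field names are exact, while the values and the model/dataset ids are placeholders, and real files may carry additional fields.

example_results = {
    "model": "org/model-name",        # placeholder model id
    "mode": "two_stage",
    "top_k": 100,
    "eval_wall_time_s": 0.0,
    "metrics_by_dataset": {
        "org/dataset_name_v2": {      # placeholder dataset id
            "ndcg@5": 0.0, "ndcg@10": 0.0,
            "recall@5": 0.0, "recall@10": 0.0,
            "mrr@10": 0.0,
            "avg_latency_ms": 0.0, "qps": 0.0,
        },
    },
}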
demo/ui/header.py ADDED
@@ -0,0 +1,30 @@
+ """Header component."""
+
+ import streamlit as st
+
+
+ def render_header():
+     st.markdown("""
+     <div style="text-align: center; padding: 10px 0 15px 0;">
+         <h1 style="
+             font-family: 'Georgia', serif;
+             font-size: 2.5rem;
+             font-weight: 700;
+             color: #1a1a2e;
+             letter-spacing: 3px;
+             margin: 0;
+             text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
+         ">
+             🔬 Visual RAG Toolkit
+         </h1>
+         <p style="
+             font-family: 'Helvetica Neue', sans-serif;
+             font-size: 0.95rem;
+             color: #666;
+             margin-top: 5px;
+             letter-spacing: 1px;
+         ">
+             SIGIR 2026 Demo - Multi-Vector Visual Document Retrieval
+         </p>
+     </div>
+     """, unsafe_allow_html=True)