visual-rag-toolkit 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
demo/indexing.py ADDED
@@ -0,0 +1,286 @@
+ """Indexing runner with UI updates."""
+
+ import hashlib
+ import json
+ import time
+ import traceback
+ from datetime import datetime
+ from typing import Any, Dict, Optional
+
+ import numpy as np
+ import streamlit as st
+ import torch
+
+ from visual_rag import VisualEmbedder
+
+
+ TORCH_DTYPE_MAP = {
+     "float16": torch.float16,
+     "float32": torch.float32,
+     "bfloat16": torch.bfloat16,
+ }
+ from visual_rag.indexing import QdrantIndexer
+ from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
+
+ from demo.qdrant_utils import get_qdrant_credentials
+
+
+ def _stable_uuid(text: str) -> str:
+     """Generate a stable UUID from text (same as benchmark script)."""
+     hex_str = hashlib.sha256(text.encode("utf-8")).hexdigest()[:32]
+     return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
+
+
+ def _union_point_id(*, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]) -> str:
+     """Generate union point ID (same as benchmark script)."""
+     ns = f"{union_namespace}::{dataset_name}" if union_namespace else dataset_name
+     return _stable_uuid(f"{ns}::{source_doc_id}")
+
+
+ def run_indexing_with_ui(config: Dict[str, Any]):
+     st.divider()
+
+     print("=" * 60)
+     print("[INDEX] Starting indexing via UI")
+     print("=" * 60)
+
+     url, api_key = get_qdrant_credentials()
+     if not url:
+         st.error("QDRANT_URL not configured")
+         return
+
+     datasets = config.get("datasets", [])
+     collection = config["collection"]
+     model = config.get("model", "vidore/colpali-v1.3")
+     recreate = config.get("recreate", False)
+     torch_dtype = config.get("torch_dtype", "float16")
+     qdrant_vector_dtype = config.get("qdrant_vector_dtype", "float16")
+     prefer_grpc = config.get("prefer_grpc", True)
+     batch_size = config.get("batch_size", 4)
+     max_docs = config.get("max_docs")
+
+     print(f"[INDEX] Config: collection={collection}, model={model}")
+     print(f"[INDEX] Datasets: {datasets}")
+     print(f"[INDEX] max_docs={max_docs}, batch_size={batch_size}, recreate={recreate}")
+     print(f"[INDEX] torch_dtype={torch_dtype}, qdrant_dtype={qdrant_vector_dtype}, grpc={prefer_grpc}")
+
+     phase1_container = st.container()
+     phase2_container = st.container()
+     phase3_container = st.container()
+     results_container = st.container()
+
+     try:
+         with phase1_container:
+             st.markdown("##### 🤖 Phase 1: Loading Model")
+             model_status = st.empty()
+             model_status.info(f"Loading `{model.split('/')[-1]}`...")
+
+             print(f"[INDEX] Loading embedder: {model}")
+             torch_dtype_obj = TORCH_DTYPE_MAP.get(torch_dtype, torch.float16)
+             output_dtype_obj = np.float16 if qdrant_vector_dtype == "float16" else np.float32
+             embedder = VisualEmbedder(
+                 model_name=model,
+                 torch_dtype=torch_dtype_obj,
+                 output_dtype=output_dtype_obj,
+             )
+             embedder._load_model()
+             print(f"[INDEX] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_vector_dtype})")
+             model_status.success(f"✅ Model `{model.split('/')[-1]}` loaded")
+
+         with phase2_container:
+             st.markdown("##### 📦 Phase 2: Setting Up Collection")
+
+             indexer_status = st.empty()
+             indexer_status.info(f"Connecting to Qdrant...")
+
+             print(f"[INDEX] Connecting to Qdrant...")
+             indexer = QdrantIndexer(
+                 url=url,
+                 api_key=api_key,
+                 collection_name=collection,
+                 prefer_grpc=prefer_grpc,
+                 vector_datatype=qdrant_vector_dtype,
+             )
+             print(f"[INDEX] Connected to Qdrant")
+             indexer_status.success(f"✅ Connected to Qdrant")
+
+             coll_status = st.empty()
+             action = "Recreating" if recreate else "Creating/verifying"
+             coll_status.info(f"{action} collection `{collection}`...")
+
+             print(f"[INDEX] {action} collection: {collection}")
+             indexer.create_collection(force_recreate=recreate)
+             indexer.create_payload_indexes(fields=[
+                 {"field": "dataset", "type": "keyword"},
+                 {"field": "doc_id", "type": "keyword"},
+                 {"field": "source_doc_id", "type": "keyword"},
+             ])
+             print(f"[INDEX] Collection ready")
+             coll_status.success(f"✅ Collection `{collection}` ready")
+
+         with phase3_container:
+             st.markdown("##### 🚀 Phase 3: Indexing Documents")
+
+             total_uploaded = 0
+             total_docs = 0
+             total_time = 0
+
+             for ds_name in datasets:
+                 ds_short = ds_name.split("/")[-1]
+                 ds_header = st.empty()
+                 ds_header.info(f"📚 Loading `{ds_short}`...")
+
+                 print(f"[INDEX] Loading dataset: {ds_name}")
+                 corpus, queries, qrels = load_vidore_beir_dataset(ds_name)
+
+                 if max_docs and max_docs > 0 and len(corpus) > max_docs:
+                     corpus = corpus[:max_docs]
+                     print(f"[INDEX] Limited to {len(corpus)} docs (max_docs={max_docs})")
+
+                 total_docs += len(corpus)
+                 print(f"[INDEX] Dataset {ds_name}: {len(corpus)} documents to index")
+                 ds_header.success(f"📚 `{ds_short}`: {len(corpus)} documents")
+
+                 progress_bar = st.progress(0.0)
+                 batch_status = st.empty()
+                 log_area = st.empty()
+                 log_lines = []
+
+                 num_batches = (len(corpus) + batch_size - 1) // batch_size
+                 ds_start = time.time()
+
+                 for i in range(0, len(corpus), batch_size):
+                     batch = corpus[i:i + batch_size]
+                     images = [doc.image for doc in batch if hasattr(doc, 'image') and doc.image]
+
+                     if not images:
+                         continue
+
+                     batch_num = i // batch_size + 1
+                     batch_status.info(f"Batch {batch_num}/{num_batches}: embedding & uploading...")
+
+                     batch_start = time.time()
+                     embeddings, token_infos = embedder.embed_images(images, return_token_info=True)
+                     embed_time = time.time() - batch_start
+
+                     points = []
+                     for j, (doc, emb, token_info) in enumerate(zip(batch, embeddings, token_infos)):
+                         doc_id = doc.doc_id if hasattr(doc, 'doc_id') else str(i + j)
+                         source_doc_id = str(doc.payload.get("source_doc_id", doc_id) if hasattr(doc, 'payload') else doc_id)
+
+                         union_doc_id = _union_point_id(
+                             dataset_name=ds_name,
+                             source_doc_id=source_doc_id,
+                             union_namespace=collection,
+                         )
+
+                         emb_np = emb.cpu().numpy() if hasattr(emb, 'cpu') else np.array(emb)
+                         visual_indices = token_info.get("visual_token_indices") or list(range(emb_np.shape[0]))
+                         visual_emb = emb_np[visual_indices].astype(embedder.output_dtype)
+
+                         tile_pooled = embedder.mean_pool_visual_embedding(visual_emb, token_info, target_vectors=32)
+                         experimental = embedder.experimental_pool_visual_embedding(
+                             visual_emb, token_info, target_vectors=32, mean_pool=tile_pooled
+                         )
+                         global_pooled = embedder.global_pool_from_mean_pool(tile_pooled)
+
+                         points.append({
+                             "id": union_doc_id,
+                             "visual_embedding": visual_emb,
+                             "tile_pooled_embedding": tile_pooled,
+                             "experimental_pooled_embedding": experimental,
+                             "global_pooled_embedding": global_pooled,
+                             "metadata": {
+                                 "dataset": ds_name,
+                                 "doc_id": doc_id,
+                                 "source_doc_id": source_doc_id,
+                                 "union_doc_id": union_doc_id,
+                             },
+                         })
+
+                     upload_start = time.time()
+                     indexer.upload_batch(points)
+                     upload_time = time.time() - upload_start
+                     total_uploaded += len(points)
+
+                     progress = (i + len(batch)) / len(corpus)
+                     progress_bar.progress(progress)
+                     batch_status.info(f"Batch {batch_num}/{num_batches} ({int(progress*100)}%) — embed: {embed_time:.1f}s, upload: {upload_time:.1f}s")
+
+                     log_interval = max(2, num_batches // 10)
+                     should_log = batch_num % log_interval == 0 or batch_num == num_batches
+
+                     if should_log and batch_num > 1:
+                         log_lines.append(f"[Batch {batch_num}/{num_batches}] +{len(points)} pts, total={total_uploaded}")
+                         log_area.code("\n".join(log_lines[-8:]), language="text")
+                         print(f"[INDEX] Batch {batch_num}/{num_batches}: +{len(points)} pts, total={total_uploaded}, embed={embed_time:.1f}s, upload={upload_time:.1f}s")
+
+                 ds_time = time.time() - ds_start
+                 total_time += ds_time
+                 progress_bar.progress(1.0)
+                 batch_status.success(f"✅ `{ds_short}` indexed: {len(corpus)} docs in {ds_time:.1f}s")
+                 print(f"[INDEX] Dataset {ds_name} complete: {len(corpus)} docs in {ds_time:.1f}s")
+
+         with results_container:
+             st.markdown("##### 📊 Summary")
+
+             docs_per_sec = total_uploaded / total_time if total_time > 0 else 0
+
+             print("=" * 60)
+             print("[INDEX] INDEXING COMPLETE")
+             print(f"[INDEX] Total Uploaded: {total_uploaded:,}")
+             print(f"[INDEX] Datasets: {len(datasets)}")
+             print(f"[INDEX] Collection: {collection}")
+             print(f"[INDEX] Total Time: {total_time:.1f}s")
+             print(f"[INDEX] Throughput: {docs_per_sec:.2f} docs/s")
+             print("=" * 60)
+
+             c1, c2, c3, c4 = st.columns(4)
+             c1.metric("Total Uploaded", f"{total_uploaded:,}")
+             c2.metric("Datasets", len(datasets))
+             c3.metric("Total Time", f"{total_time:.1f}s")
+             c4.metric("Throughput", f"{docs_per_sec:.2f}/s")
+
+             st.success(f"🎉 Indexing complete! {total_uploaded:,} documents indexed to `{collection}`")
+
+             detailed_report = {
+                 "generated_at": datetime.now().isoformat(),
+                 "config": {
+                     "collection": collection,
+                     "model": model,
+                     "datasets": datasets,
+                     "batch_size": batch_size,
+                     "max_docs_per_dataset": max_docs,
+                     "recreate": recreate,
+                     "prefer_grpc": prefer_grpc,
+                     "torch_dtype": torch_dtype,
+                     "qdrant_vector_dtype": qdrant_vector_dtype,
+                 },
+                 "results": {
+                     "total_docs_uploaded": total_uploaded,
+                     "total_time_s": round(total_time, 2),
+                     "throughput_docs_per_s": round(docs_per_sec, 2),
+                     "num_datasets": len(datasets),
+                 },
+             }
+
+             with st.expander("📋 Full Summary"):
+                 st.json(detailed_report)
+
+             report_json = json.dumps(detailed_report, indent=2)
+             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+             filename = f"index_report__{collection}__{timestamp}.json"
+
+             st.download_button(
+                 label="📥 Download Indexing Report",
+                 data=report_json,
+                 file_name=filename,
+                 mime="application/json",
+                 use_container_width=True,
+             )
+
+     except Exception as e:
+         print(f"[INDEX] ERROR: {e}")
+         st.error(f"❌ Error: {e}")
+         with st.expander("🔍 Full Error Details"):
+             st.code(traceback.format_exc(), language="text")
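For orientation, here is a minimal sketch of how this runner might be driven from the demo's Streamlit app. The config keys mirror the config.get(...) lookups in run_indexing_with_ui; the collection and dataset names below are illustrative placeholders, and Qdrant credentials are expected via QDRANT_URL/QDRANT_API_KEY or the sidebar inputs read by get_qdrant_credentials.

# Hypothetical driver for run_indexing_with_ui (run inside a Streamlit app).
# Collection and dataset names are placeholders; the keys follow the
# config.get(...) calls in demo/indexing.py above.
from demo.indexing import run_indexing_with_ui

config = {
    "collection": "demo_docs",              # target Qdrant collection (placeholder)
    "datasets": ["vidore/tatdqa_test"],     # ViDoRe BEIR-style dataset id (placeholder)
    "model": "vidore/colpali-v1.3",
    "batch_size": 4,
    "max_docs": 50,                         # per-dataset cap; None indexes everything
    "recreate": False,
    "torch_dtype": "float16",
    "qdrant_vector_dtype": "float16",
    "prefer_grpc": True,
}

run_indexing_with_ui(config)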
demo/qdrant_utils.py ADDED
@@ -0,0 +1,211 @@
+ """Qdrant connection and utility functions."""
+
+ import os
+ import traceback
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import streamlit as st
+
+
+ def get_qdrant_credentials() -> Tuple[Optional[str], Optional[str]]:
+     url = st.session_state.get("qdrant_url_input") or os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
+     api_key = st.session_state.get("qdrant_key_input") or (
+         os.getenv("SIGIR_QDRANT_KEY")
+         or os.getenv("SIGIR_QDRANT_API_KEY")
+         or os.getenv("DEST_QDRANT_API_KEY")
+         or os.getenv("QDRANT_API_KEY")
+     )
+     return url, api_key
+
+
+ def init_qdrant_client_with_creds(url: str, api_key: str):
+     try:
+         from qdrant_client import QdrantClient
+         if not url:
+             return None, "QDRANT_URL not configured"
+         client = QdrantClient(url=url, api_key=api_key, timeout=60)
+         client.get_collections()
+         return client, None
+     except Exception as e:
+         return None, str(e)
+
+
+ @st.cache_resource(show_spinner="Connecting to Qdrant...")
+ def init_qdrant_client():
+     url, api_key = get_qdrant_credentials()
+     return init_qdrant_client_with_creds(url, api_key)
+
+
+ @st.cache_resource(show_spinner="Loading embedding model...")
+ def init_embedder(model_name: str):
+     try:
+         from visual_rag import VisualEmbedder
+         return VisualEmbedder(model_name=model_name), None
+     except Exception as e:
+         return None, f"{e}\n\n{traceback.format_exc()}"
+
+
+ @st.cache_data(ttl=300, show_spinner="Fetching collections...")
+ def get_collections(_url: str, _api_key: str) -> List[str]:
+     client, err = init_qdrant_client_with_creds(_url, _api_key)
+     if client is None:
+         return []
+     try:
+         collections = client.get_collections().collections
+         return sorted([c.name for c in collections])
+     except Exception:
+         return []
+
+
+ @st.cache_data(ttl=120, show_spinner="Loading collection stats...")
+ def get_collection_stats(collection_name: str) -> Dict[str, Any]:
+     url, api_key = get_qdrant_credentials()
+     client, err = init_qdrant_client_with_creds(url, api_key)
+     if client is None:
+         return {"error": err}
+     try:
+         info = client.get_collection(collection_name)
+         vectors_config = getattr(getattr(getattr(info, "config", None), "params", None), "vectors", None)
+         vector_info = {}
+         if vectors_config is not None:
+             if hasattr(vectors_config, "items"):
+                 for name, cfg in vectors_config.items():
+                     size = getattr(cfg, "size", None)
+                     multivec = getattr(cfg, "multivector_config", None)
+                     on_disk = getattr(cfg, "on_disk", None)
+                     datatype = str(getattr(cfg, "datatype", "Float32")).replace("Datatype.", "")
+                     quantization = getattr(cfg, "quantization_config", None)
+                     num_vectors = 1
+                     if multivec is not None:
+                         comparator = getattr(multivec, "comparator", None)
+                         num_vectors = "N" if comparator else 1
+                     vector_info[name] = {
+                         "size": size,
+                         "num_vectors": num_vectors,
+                         "is_multivector": multivec is not None,
+                         "on_disk": on_disk,
+                         "datatype": datatype,
+                         "quantization": quantization is not None,
+                     }
+             elif hasattr(vectors_config, "size"):
+                 on_disk = getattr(vectors_config, "on_disk", None)
+                 datatype = str(getattr(vectors_config, "datatype", "Float32")).replace("Datatype.", "")
+                 multivec = getattr(vectors_config, "multivector_config", None)
+                 vector_info["default"] = {
+                     "size": getattr(vectors_config, "size", None),
+                     "num_vectors": "N" if multivec else 1,
+                     "is_multivector": multivec is not None,
+                     "on_disk": on_disk,
+                     "datatype": datatype,
+                 }
+         return {
+             "points_count": getattr(info, "points_count", 0),
+             "vectors_count": getattr(info, "vectors_count", getattr(info, "points_count", 0)),
+             "status": str(getattr(info, "status", "unknown")),
+             "vector_info": vector_info,
+             "indexed_vectors_count": getattr(info, "indexed_vectors_count", None),
+         }
+     except Exception as e:
+         return {"error": f"{e}\n\n{traceback.format_exc()}"}
+
+
+ @st.cache_data(ttl=60)
+ def sample_points_cached(collection_name: str, n: int, seed: int, _url: str, _api_key: str) -> List[Dict[str, Any]]:
+     client, err = init_qdrant_client_with_creds(_url, _api_key)
+     if client is None:
+         return []
+     try:
+         import random
+         rng = random.Random(seed)
+         points, _ = client.scroll(
+             collection_name=collection_name,
+             limit=min(n * 10, 100),
+             with_payload=True,
+             with_vectors=False,
+         )
+         if not points:
+             return []
+         sampled = rng.sample(points, min(n, len(points)))
+         results = []
+         for p in sampled:
+             payload = dict(p.payload) if p.payload else {}
+             results.append({
+                 "id": str(p.id),
+                 "payload": payload,
+             })
+         return results
+     except Exception:
+         return []
+
+
+ @st.cache_data(ttl=300)
+ def get_vector_sizes(collection_name: str, _url: str, _api_key: str) -> Dict[str, int]:
+     client, err = init_qdrant_client_with_creds(_url, _api_key)
+     if client is None:
+         return {}
+     try:
+         points, _ = client.scroll(
+             collection_name=collection_name,
+             limit=1,
+             with_payload=False,
+             with_vectors=True,
+         )
+         if not points:
+             return {}
+         vectors = points[0].vector
+         sizes = {}
+         if isinstance(vectors, dict):
+             for name, vec in vectors.items():
+                 if isinstance(vec, list):
+                     if vec and isinstance(vec[0], list):
+                         sizes[name] = len(vec)
+                     else:
+                         sizes[name] = 1
+                 else:
+                     sizes[name] = 1
+         return sizes
+     except Exception:
+         return {}
+
+
+ def search_collection(
+     collection_name: str,
+     query: str,
+     top_k: int = 10,
+     mode: str = "single_full",
+     prefetch_k: int = 256,
+     stage1_mode: str = "tokens_vs_tiles",
+     stage1_k: int = 1000,
+     stage2_k: int = 300,
+     model_name: str = "vidore/colSmol-500M",
+ ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
+     try:
+         import traceback
+         from visual_rag.retrieval import MultiVectorRetriever
+         retriever = MultiVectorRetriever(
+             collection_name=collection_name,
+             model_name=model_name,
+         )
+         if mode == "three_stage":
+             q_emb = retriever.embedder.embed_query(query)
+             if hasattr(q_emb, "cpu"):
+                 q_emb = q_emb.cpu().numpy()
+             results = retriever.search_embedded(
+                 query_embedding=q_emb,
+                 top_k=top_k,
+                 mode=mode,
+                 stage1_k=stage1_k,
+                 stage2_k=stage2_k,
+             )
+         else:
+             results = retriever.search(
+                 query=query,
+                 top_k=top_k,
+                 mode=mode,
+                 prefetch_k=prefetch_k,
+                 stage1_mode=stage1_mode,
+             )
+         return results, None
+     except Exception as e:
+         import traceback
+         return [], f"{e}\n\n{traceback.format_exc()}"
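A rough usage sketch of these helpers follows; the collection name and query are placeholders, and credentials are resolved from Streamlit session state or the QDRANT_* environment variables by get_qdrant_credentials.

from demo.qdrant_utils import get_collection_stats, search_collection

# Inspect a collection's point count and named-vector layout.
stats = get_collection_stats("demo_docs")        # placeholder collection name
print(stats.get("points_count"), stats.get("vector_info"))

# Run a retrieval query; returns (results, error) where error is None on success.
results, err = search_collection(
    collection_name="demo_docs",
    query="total revenue in 2019",               # placeholder query
    top_k=5,
    mode="single_full",
)
if err:
    print("search failed:", err)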
demo/results.py ADDED
@@ -0,0 +1,35 @@
+ """Results file handling utilities."""
+
+ import json
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+
+ def load_results_file(path: Path) -> Optional[Dict[str, Any]]:
+     try:
+         with open(path, "r") as f:
+             return json.load(f)
+     except Exception:
+         return None
+
+
+ def get_available_results() -> List[Path]:
+     results_dir = Path(__file__).parent.parent / "results"
+     if not results_dir.exists():
+         return []
+     results = []
+     for subdir in results_dir.iterdir():
+         if subdir.is_dir():
+             for f in subdir.glob("*.json"):
+                 if "index_failures" not in f.name:
+                     results.append(f)
+     return sorted(results, key=lambda x: x.stat().st_mtime, reverse=True)
+
+
+ def find_main_result_file(collection: str, mode: str) -> Optional[Path]:
+     results = get_available_results()
+     for r in results:
+         if collection in str(r) and mode in r.name:
+             if "__vidore_" not in r.name:
+                 return r
+     return results[0] if results else None
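A small sketch of how these helpers fit together, assuming a results/ directory of JSON run outputs next to the demo package (as implied by get_available_results):

from demo.results import get_available_results, load_results_file

for path in get_available_results()[:5]:    # newest files first
    data = load_results_file(path)          # None if the file is missing or not valid JSON
    if data is not None:
        print(path.name, sorted(data.keys())[:5])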
demo/test_qdrant_connection.py ADDED
@@ -0,0 +1,119 @@
+ #!/usr/bin/env python3
+ """Test Qdrant connection and collection creation."""
+
+ import os
+ import sys
+ from pathlib import Path
+
+ sys.path.insert(0, str(Path(__file__).parent.parent))
+
+ from dotenv import load_dotenv
+ load_dotenv(Path(__file__).parent.parent / ".env")
+ load_dotenv(Path(__file__).parent.parent.parent / ".env")
+
+ def test_connection():
+     from qdrant_client import QdrantClient
+     from qdrant_client.http import models
+
+     url = os.getenv("QDRANT_URL")
+     api_key = os.getenv("QDRANT_API_KEY")
+
+     print(f"URL: {url}")
+     print(f"API Key: {'***' + api_key[-4:] if api_key else 'NOT SET'}")
+
+     if not url or not api_key:
+         print("ERROR: QDRANT_URL or QDRANT_API_KEY not set")
+         return
+
+     print("\n1. Creating client...")
+     client = QdrantClient(url=url, api_key=api_key, timeout=60)
+
+     print("\n2. Getting collections...")
+     try:
+         collections = client.get_collections()
+         print(f" Found {len(collections.collections)} collections:")
+         for c in collections.collections:
+             print(f" - {c.name}")
+     except Exception as e:
+         print(f" ERROR: {e}")
+         return
+
+     test_collection = "_test_visual_rag_toolkit"
+
+     print(f"\n3. Checking if '{test_collection}' exists...")
+     exists = any(c.name == test_collection for c in collections.collections)
+     print(f" Exists: {exists}")
+
+     if exists:
+         print(f"\n4. Deleting test collection...")
+         try:
+             client.delete_collection(test_collection)
+             print(" Deleted")
+         except Exception as e:
+             print(f" ERROR: {e}")
+
+     print(f"\n5. Creating SIMPLE collection (single vector)...")
+     try:
+         client.create_collection(
+             collection_name=test_collection,
+             vectors_config=models.VectorParams(
+                 size=128,
+                 distance=models.Distance.COSINE,
+             ),
+         )
+         print(" SUCCESS: Simple collection created")
+     except Exception as e:
+         print(f" ERROR: {e}")
+         print("\n This means basic collection creation is failing.")
+         print(" Check your Qdrant Cloud cluster status/limits.")
+         return
+
+     print(f"\n6. Deleting test collection...")
+     try:
+         client.delete_collection(test_collection)
+         print(" Deleted")
+     except Exception as e:
+         print(f" ERROR: {e}")
+
+     print(f"\n7. Creating MULTI-VECTOR collection (like visual-rag)...")
+     try:
+         client.create_collection(
+             collection_name=test_collection,
+             vectors_config={
+                 "initial": models.VectorParams(
+                     size=128,
+                     distance=models.Distance.COSINE,
+                     multivector_config=models.MultiVectorConfig(
+                         comparator=models.MultiVectorComparator.MAX_SIM
+                     ),
+                 ),
+                 "mean_pooling": models.VectorParams(
+                     size=128,
+                     distance=models.Distance.COSINE,
+                     multivector_config=models.MultiVectorConfig(
+                         comparator=models.MultiVectorComparator.MAX_SIM
+                     ),
+                 ),
+             },
+         )
+         print(" SUCCESS: Multi-vector collection created")
+     except Exception as e:
+         print(f" ERROR: {e}")
+         print("\n Multi-vector collection failed but simple worked.")
+         print(" Your Qdrant version may not support multi-vector.")
+         return
+
+     print(f"\n8. Final cleanup...")
+     try:
+         client.delete_collection(test_collection)
+         print(" Deleted")
+     except Exception as e:
+         print(f" ERROR: {e}")
+
+     print("\n" + "="*50)
+     print("ALL TESTS PASSED - Qdrant connection is working!")
+     print("="*50)
+
+
+ if __name__ == "__main__":
+     test_connection()
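To exercise this check outside the demo, something along these lines should work, assuming QDRANT_URL and QDRANT_API_KEY are set in the environment or in one of the .env files loaded above; the cluster URL and key below are placeholders.

import os

# Placeholders; normally these come from a .env file picked up by load_dotenv.
os.environ.setdefault("QDRANT_URL", "https://YOUR-CLUSTER.cloud.qdrant.io:6333")
os.environ.setdefault("QDRANT_API_KEY", "YOUR-API-KEY")

from demo.test_qdrant_connection import test_connection

test_connection()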
demo/ui/__init__.py ADDED
@@ -0,0 +1,15 @@
+ """UI components for the demo app."""
+
+ from demo.ui.header import render_header
+ from demo.ui.sidebar import render_sidebar
+ from demo.ui.upload import render_upload_tab
+ from demo.ui.playground import render_playground_tab
+ from demo.ui.benchmark import render_benchmark_tab
+
+ __all__ = [
+     "render_header",
+     "render_sidebar",
+     "render_upload_tab",
+     "render_playground_tab",
+     "render_benchmark_tab",
+ ]