visual-rag-toolkit 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
demo/app.py CHANGED
@@ -1,13 +1,23 @@
  """Main entry point for the Visual RAG Toolkit demo application."""
 
+ import os
  import sys
  from pathlib import Path
 
- ROOT_DIR = Path(__file__).parent.parent
- sys.path.insert(0, str(ROOT_DIR))
+ # Ensure repo root is in sys.path for local development
+ # (In HF Space / Docker, PYTHONPATH is already set correctly)
+ _app_dir = Path(__file__).resolve().parent
+ _repo_root = _app_dir.parent
+ if str(_repo_root) not in sys.path:
+     sys.path.insert(0, str(_repo_root))
 
  from dotenv import load_dotenv
- load_dotenv(ROOT_DIR / ".env")
+
+ # Load .env from the repo root (works both locally and in Docker)
+ if (_repo_root / ".env").exists():
+     load_dotenv(_repo_root / ".env")
+ if (_app_dir / ".env").exists():
+     load_dotenv(_app_dir / ".env")
 
  import streamlit as st
 
@@ -28,15 +38,17 @@ from demo.ui.benchmark import render_benchmark_tab
  def main():
      render_header()
      render_sidebar()
-
-     tab_upload, tab_playground, tab_benchmark = st.tabs(["📤 Upload", "🎮 Playground", "📊 Benchmarking"])
-
+
+     tab_upload, tab_playground, tab_benchmark = st.tabs(
+         ["📤 Upload", "🎮 Playground", "📊 Benchmarking"]
+     )
+
      with tab_upload:
          render_upload_tab()
-
+
      with tab_playground:
          render_playground_tab()
-
+
      with tab_benchmark:
          render_benchmark_tab()
 
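Note on the twin `load_dotenv` calls above: python-dotenv does not overwrite variables that are already set (its `override` parameter defaults to `False`), so when both files define the same key, the repo-root value loaded first wins, and anything already in the process environment wins over both. A minimal sketch with hypothetical paths:

```python
from dotenv import load_dotenv

# Hypothetical paths for illustration. With override=False (the default),
# keys already present in os.environ, including those set by the first
# call, are left untouched by the second call.
load_dotenv("/repo/.env")        # loaded first: its keys win
load_dotenv("/repo/demo/.env")   # only fills keys that are still missing
```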
demo/evaluation.py CHANGED
@@ -1,20 +1,23 @@
  """Evaluation runner with UI updates."""
 
  import hashlib
- import importlib.util
  import json
  import logging
  import time
  import traceback
  from datetime import datetime
- from pathlib import Path
  from typing import Any, Dict, List, Optional
 
  import numpy as np
  import streamlit as st
  import torch
+ from qdrant_client.models import FieldCondition, Filter, MatchValue
 
  from visual_rag import VisualEmbedder
+ from visual_rag.retrieval import MultiVectorRetriever
+ from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
+ from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
+ from demo.qdrant_utils import get_qdrant_credentials
 
 
  TORCH_DTYPE_MAP = {
@@ -22,49 +25,6 @@ TORCH_DTYPE_MAP = {
      "float32": torch.float32,
      "bfloat16": torch.bfloat16,
  }
- from qdrant_client.models import Filter, FieldCondition, MatchValue
-
- from visual_rag.retrieval import MultiVectorRetriever
-
-
- def _load_local_benchmark_module(module_filename: str):
-     """
-     Load `benchmarks/vidore_tatdqa_test/<module_filename>` via file path.
-
-     Motivation:
-     - Some environments (notably containers / Spaces) can have a third-party
-       `benchmarks` package installed, causing `import benchmarks...` to resolve
-       to the wrong module.
-     - This fallback guarantees we load the repo's benchmark utilities.
-     """
-     root = Path(__file__).resolve().parents[1]  # demo/.. = repo root
-     target = root / "benchmarks" / "vidore_tatdqa_test" / module_filename
-     if not target.exists():
-         raise ModuleNotFoundError(f"Missing local benchmark module file: {target}")
-
-     name = f"_visual_rag_toolkit_local_{target.stem}"
-     spec = importlib.util.spec_from_file_location(name, str(target))
-     if spec is None or spec.loader is None:
-         raise ModuleNotFoundError(f"Could not load module spec for: {target}")
-     mod = importlib.util.module_from_spec(spec)
-     spec.loader.exec_module(mod)  # type: ignore[attr-defined]
-     return mod
-
-
- try:
-     # Preferred: normal import
-     from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
-     from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
- except ModuleNotFoundError:
-     # Robust fallback: load from local file paths
-     _dl = _load_local_benchmark_module("dataset_loader.py")
-     _mx = _load_local_benchmark_module("metrics.py")
-     load_vidore_beir_dataset = _dl.load_vidore_beir_dataset
-     ndcg_at_k = _mx.ndcg_at_k
-     mrr_at_k = _mx.mrr_at_k
-     recall_at_k = _mx.recall_at_k
-
- from demo.qdrant_utils import get_qdrant_credentials
 
  logger = logging.getLogger(__name__)
  logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
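The `FieldCondition`/`Filter`/`MatchValue` imports that moved to the top of the module are the standard qdrant-client building blocks for payload filtering. A minimal sketch of how such a filter is typically assembled (the field value here is illustrative, not taken from the diff):

```python
from qdrant_client.models import FieldCondition, Filter, MatchValue

# Restrict a search to points whose "dataset" payload field matches one
# dataset; "vidore/tatdqa_test" is an illustrative value.
dataset_filter = Filter(
    must=[
        FieldCondition(key="dataset", match=MatchValue(value="vidore/tatdqa_test")),
    ]
)
```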
demo/indexing.py CHANGED
@@ -12,6 +12,10 @@ import streamlit as st
  import torch
 
  from visual_rag import VisualEmbedder
+ from visual_rag.embedding.pooling import tile_level_mean_pooling
+ from visual_rag.indexing.qdrant_indexer import QdrantIndexer
+ from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
+ from demo.qdrant_utils import get_qdrant_credentials
 
 
  TORCH_DTYPE_MAP = {
@@ -19,10 +23,6 @@ TORCH_DTYPE_MAP = {
      "float32": torch.float32,
      "bfloat16": torch.bfloat16,
  }
- from visual_rag.indexing import QdrantIndexer
- from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
-
- from demo.qdrant_utils import get_qdrant_credentials
 
 
  def _stable_uuid(text: str) -> str:
@@ -31,7 +31,9 @@ def _stable_uuid(text: str) -> str:
      return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
 
 
- def _union_point_id(*, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]) -> str:
+ def _union_point_id(
+     *, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]
+ ) -> str:
      """Generate union point ID (same as benchmark script)."""
      ns = f"{union_namespace}::{dataset_name}" if union_namespace else dataset_name
      return _stable_uuid(f"{ns}::{source_doc_id}")
@@ -39,16 +41,16 @@ def _union_point_id(*, dataset_name: str, source_doc_id: str, union_namespace: O
 
  def run_indexing_with_ui(config: Dict[str, Any]):
      st.divider()
-
+
      print("=" * 60)
      print("[INDEX] Starting indexing via UI")
      print("=" * 60)
-
+
      url, api_key = get_qdrant_credentials()
      if not url:
          st.error("QDRANT_URL not configured")
          return
-
+
      datasets = config.get("datasets", [])
      collection = config["collection"]
      model = config.get("model", "vidore/colpali-v1.3")
@@ -58,42 +60,50 @@ def run_indexing_with_ui(config: Dict[str, Any]):
      prefer_grpc = config.get("prefer_grpc", True)
      batch_size = config.get("batch_size", 4)
      max_docs = config.get("max_docs")
-
+
      print(f"[INDEX] Config: collection={collection}, model={model}")
      print(f"[INDEX] Datasets: {datasets}")
-     print(f"[INDEX] max_docs={max_docs}, batch_size={batch_size}, recreate={recreate}")
-     print(f"[INDEX] torch_dtype={torch_dtype}, qdrant_dtype={qdrant_vector_dtype}, grpc={prefer_grpc}")
-
+     print(
+         f"[INDEX] max_docs={max_docs}, batch_size={batch_size}, recreate={recreate}"
+     )
+     print(
+         f"[INDEX] torch_dtype={torch_dtype}, qdrant_dtype={qdrant_vector_dtype}, grpc={prefer_grpc}"
+     )
+
      phase1_container = st.container()
      phase2_container = st.container()
      phase3_container = st.container()
      results_container = st.container()
-
+
      try:
          with phase1_container:
              st.markdown("##### 🤖 Phase 1: Loading Model")
              model_status = st.empty()
              model_status.info(f"Loading `{model.split('/')[-1]}`...")
-
+
              print(f"[INDEX] Loading embedder: {model}")
              torch_dtype_obj = TORCH_DTYPE_MAP.get(torch_dtype, torch.float16)
-             output_dtype_obj = np.float16 if qdrant_vector_dtype == "float16" else np.float32
+             output_dtype_obj = (
+                 np.float16 if qdrant_vector_dtype == "float16" else np.float32
+             )
              embedder = VisualEmbedder(
                  model_name=model,
                  torch_dtype=torch_dtype_obj,
                  output_dtype=output_dtype_obj,
              )
              embedder._load_model()
-             print(f"[INDEX] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_vector_dtype})")
+             print(
+                 f"[INDEX] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_vector_dtype})"
+             )
              model_status.success(f"✅ Model `{model.split('/')[-1]}` loaded")
-
+
          with phase2_container:
              st.markdown("##### 📦 Phase 2: Setting Up Collection")
-
+
              indexer_status = st.empty()
-             indexer_status.info(f"Connecting to Qdrant...")
-
-             print(f"[INDEX] Connecting to Qdrant...")
+             indexer_status.info("Connecting to Qdrant...")
+
+             print("[INDEX] Connecting to Qdrant...")
              indexer = QdrantIndexer(
                  url=url,
                  api_key=api_key,
@@ -101,186 +111,164 @@ def run_indexing_with_ui(config: Dict[str, Any]):
                  prefer_grpc=prefer_grpc,
                  vector_datatype=qdrant_vector_dtype,
              )
-             print(f"[INDEX] Connected to Qdrant")
-             indexer_status.success(f"✅ Connected to Qdrant")
-
+             print("[INDEX] Connected to Qdrant")
+             indexer_status.success("✅ Connected to Qdrant")
+
              coll_status = st.empty()
              action = "Recreating" if recreate else "Creating/verifying"
              coll_status.info(f"{action} collection `{collection}`...")
-
+
              print(f"[INDEX] {action} collection: {collection}")
              indexer.create_collection(force_recreate=recreate)
-             indexer.create_payload_indexes(fields=[
-                 {"field": "dataset", "type": "keyword"},
-                 {"field": "doc_id", "type": "keyword"},
-                 {"field": "source_doc_id", "type": "keyword"},
-             ])
-             print(f"[INDEX] Collection ready")
+             indexer.create_payload_indexes(
+                 fields=[
+                     {"field": "dataset", "type": "keyword"},
+                     {"field": "doc_id", "type": "keyword"},
+                     {"field": "source_doc_id", "type": "keyword"},
+                 ]
+             )
+             print("[INDEX] Collection ready")
              coll_status.success(f"✅ Collection `{collection}` ready")
-
+
          with phase3_container:
-             st.markdown("##### 🚀 Phase 3: Indexing Documents")
-
-             total_uploaded = 0
-             total_docs = 0
-             total_time = 0
-
-             for ds_name in datasets:
-                 ds_short = ds_name.split("/")[-1]
-                 ds_header = st.empty()
-                 ds_header.info(f"📚 Loading `{ds_short}`...")
-
-                 print(f"[INDEX] Loading dataset: {ds_name}")
-                 corpus, queries, qrels = load_vidore_beir_dataset(ds_name)
-
-                 if max_docs and max_docs > 0 and len(corpus) > max_docs:
-                     corpus = corpus[:max_docs]
-                     print(f"[INDEX] Limited to {len(corpus)} docs (max_docs={max_docs})")
-
-                 total_docs += len(corpus)
-                 print(f"[INDEX] Dataset {ds_name}: {len(corpus)} documents to index")
-                 ds_header.success(f"📚 `{ds_short}`: {len(corpus)} documents")
-
-                 progress_bar = st.progress(0.0)
-                 batch_status = st.empty()
-                 log_area = st.empty()
-                 log_lines = []
-
-                 num_batches = (len(corpus) + batch_size - 1) // batch_size
-                 ds_start = time.time()
-
-                 for i in range(0, len(corpus), batch_size):
-                     batch = corpus[i:i + batch_size]
-                     images = [doc.image for doc in batch if hasattr(doc, 'image') and doc.image]
-
-                     if not images:
-                         continue
-
-                     batch_num = i // batch_size + 1
-                     batch_status.info(f"Batch {batch_num}/{num_batches}: embedding & uploading...")
-
-                     batch_start = time.time()
-                     embeddings, token_infos = embedder.embed_images(images, return_token_info=True)
-                     embed_time = time.time() - batch_start
-
-                     points = []
-                     for j, (doc, emb, token_info) in enumerate(zip(batch, embeddings, token_infos)):
-                         doc_id = doc.doc_id if hasattr(doc, 'doc_id') else str(i + j)
-                         source_doc_id = str(doc.payload.get("source_doc_id", doc_id) if hasattr(doc, 'payload') else doc_id)
-
-                         union_doc_id = _union_point_id(
-                             dataset_name=ds_name,
-                             source_doc_id=source_doc_id,
-                             union_namespace=collection,
-                         )
-
-                         emb_np = emb.cpu().numpy() if hasattr(emb, 'cpu') else np.array(emb)
-                         visual_indices = token_info.get("visual_token_indices") or list(range(emb_np.shape[0]))
-                         visual_emb = emb_np[visual_indices].astype(embedder.output_dtype)
-
-                         tile_pooled = embedder.mean_pool_visual_embedding(visual_emb, token_info, target_vectors=32)
-                         experimental = embedder.experimental_pool_visual_embedding(
-                             visual_emb, token_info, target_vectors=32, mean_pool=tile_pooled
-                         )
-                         global_pooled = embedder.global_pool_from_mean_pool(tile_pooled)
-
-                         points.append({
-                             "id": union_doc_id,
-                             "visual_embedding": visual_emb,
-                             "tile_pooled_embedding": tile_pooled,
-                             "experimental_pooled_embedding": experimental,
-                             "global_pooled_embedding": global_pooled,
-                             "metadata": {
-                                 "dataset": ds_name,
+             st.markdown("##### 📊 Phase 3: Processing Datasets")
+
+             all_results = []
+
+             for ds_idx, dataset_name in enumerate(datasets):
+                 ds_short = dataset_name.split("/")[-1]
+                 ds_container = st.container()
+
+                 with ds_container:
+                     st.markdown(
+                         f"**Dataset {ds_idx + 1}/{len(datasets)}: `{ds_short}`**"
+                     )
+
+                     load_status = st.empty()
+                     load_status.info(f"Loading dataset `{ds_short}`...")
+
+                     print(f"[INDEX] Loading dataset: {dataset_name}")
+                     corpus, queries, qrels = load_vidore_beir_dataset(dataset_name)
+                     total_docs = len(corpus)
+                     print(f"[INDEX] Dataset loaded: {total_docs} docs")
+                     load_status.success(f"✅ Loaded {total_docs:,} documents")
+
+                     if max_docs and max_docs < total_docs:
+                         corpus = corpus[:max_docs]
+                         print(f"[INDEX] Limiting to {max_docs} docs")
+
+                     progress_bar = st.progress(0)
+                     status_text = st.empty()
+
+                     uploaded = 0
+                     failed = 0
+                     total = len(corpus)
+
+                     for i, doc in enumerate(corpus):
+                         try:
+                             doc_id = str(doc.doc_id)
+                             image = doc.image
+                             if image is None:
+                                 failed += 1
+                                 continue
+
+                             status_text.text(
+                                 f"Processing {i + 1}/{total}: {doc_id[:30]}..."
+                             )
+
+                             embeddings, token_infos = embedder.embed_images(
+                                 [image],
+                                 return_token_info=True,
+                                 show_progress=False,
+                             )
+                             emb = embeddings[0]
+                             token_info = token_infos[0] if token_infos else {}
+
+                             if hasattr(emb, "cpu"):
+                                 emb = emb.cpu()
+                             emb_np = np.asarray(emb, dtype=output_dtype_obj)
+
+                             initial = emb_np.tolist()
+                             global_pool = emb_np.mean(axis=0).tolist()
+
+                             num_tiles = token_info.get("num_tiles")
+                             mean_pooling = None
+                             experimental_pooling = None
+
+                             if num_tiles and num_tiles > 0:
+                                 try:
+                                     mean_pooling = tile_level_mean_pooling(
+                                         emb_np, num_tiles=num_tiles, patches_per_tile=64
+                                     ).tolist()
+                                 except Exception:
+                                     pass
+
+                                 try:
+                                     exp_pool = embedder.experimental_pool_visual_embedding(
+                                         emb_np, num_tiles=num_tiles
+                                     )
+                                     if exp_pool is not None:
+                                         experimental_pooling = exp_pool.tolist()
+                                 except Exception:
+                                     pass
+
+                             union_doc_id = _union_point_id(
+                                 dataset_name=dataset_name,
+                                 source_doc_id=doc_id,
+                                 union_namespace=collection,
+                             )
+
+                             payload = {
+                                 "dataset": dataset_name,
                                  "doc_id": doc_id,
-                                 "source_doc_id": source_doc_id,
+                                 "source_doc_id": doc_id,
                                  "union_doc_id": union_doc_id,
-                             },
-                         })
-
-                     upload_start = time.time()
-                     indexer.upload_batch(points)
-                     upload_time = time.time() - upload_start
-                     total_uploaded += len(points)
-
-                     progress = (i + len(batch)) / len(corpus)
-                     progress_bar.progress(progress)
-                     batch_status.info(f"Batch {batch_num}/{num_batches} ({int(progress*100)}%) — embed: {embed_time:.1f}s, upload: {upload_time:.1f}s")
-
-                     log_interval = max(2, num_batches // 10)
-                     should_log = batch_num % log_interval == 0 or batch_num == num_batches
-
-                     if should_log and batch_num > 1:
-                         log_lines.append(f"[Batch {batch_num}/{num_batches}] +{len(points)} pts, total={total_uploaded}")
-                         log_area.code("\n".join(log_lines[-8:]), language="text")
-                     print(f"[INDEX] Batch {batch_num}/{num_batches}: +{len(points)} pts, total={total_uploaded}, embed={embed_time:.1f}s, upload={upload_time:.1f}s")
-
-                 ds_time = time.time() - ds_start
-                 total_time += ds_time
-                 progress_bar.progress(1.0)
-                 batch_status.success(f"✅ `{ds_short}` indexed: {len(corpus)} docs in {ds_time:.1f}s")
-                 print(f"[INDEX] Dataset {ds_name} complete: {len(corpus)} docs in {ds_time:.1f}s")
-
+                                 "num_tiles": num_tiles,
+                                 "num_visual_tokens": token_info.get("num_visual_tokens"),
+                             }
+
+                             vectors = {"initial": initial, "global_pooling": global_pool}
+                             if mean_pooling:
+                                 vectors["mean_pooling"] = mean_pooling
+                             if experimental_pooling:
+                                 vectors["experimental_pooling"] = experimental_pooling
+
+                             indexer.upsert_point(
+                                 point_id=union_doc_id,
+                                 vectors=vectors,
+                                 payload=payload,
+                             )
+
+                             uploaded += 1
+
+                         except Exception as e:
+                             print(f"[INDEX] Error on doc {i}: {e}")
+                             failed += 1
+
+                         progress_bar.progress((i + 1) / total)
+
+                     status_text.text(f"✅ Done: {uploaded} uploaded, {failed} failed")
+                     all_results.append(
+                         {
+                             "dataset": dataset_name,
+                             "total": total,
+                             "uploaded": uploaded,
+                             "failed": failed,
+                         }
+                     )
+
          with results_container:
-             st.markdown("##### 📊 Summary")
-
-             docs_per_sec = total_uploaded / total_time if total_time > 0 else 0
-
-             print("=" * 60)
-             print("[INDEX] INDEXING COMPLETE")
-             print(f"[INDEX] Total Uploaded: {total_uploaded:,}")
-             print(f"[INDEX] Datasets: {len(datasets)}")
-             print(f"[INDEX] Collection: {collection}")
-             print(f"[INDEX] Total Time: {total_time:.1f}s")
-             print(f"[INDEX] Throughput: {docs_per_sec:.2f} docs/s")
-             print("=" * 60)
-
-             c1, c2, c3, c4 = st.columns(4)
-             c1.metric("Total Uploaded", f"{total_uploaded:,}")
-             c2.metric("Datasets", len(datasets))
-             c3.metric("Total Time", f"{total_time:.1f}s")
-             c4.metric("Throughput", f"{docs_per_sec:.2f}/s")
-
-             st.success(f"🎉 Indexing complete! {total_uploaded:,} documents indexed to `{collection}`")
-
-             detailed_report = {
-                 "generated_at": datetime.now().isoformat(),
-                 "config": {
-                     "collection": collection,
-                     "model": model,
-                     "datasets": datasets,
-                     "batch_size": batch_size,
-                     "max_docs_per_dataset": max_docs,
-                     "recreate": recreate,
-                     "prefer_grpc": prefer_grpc,
-                     "torch_dtype": torch_dtype,
-                     "qdrant_vector_dtype": qdrant_vector_dtype,
-                 },
-                 "results": {
-                     "total_docs_uploaded": total_uploaded,
-                     "total_time_s": round(total_time, 2),
-                     "throughput_docs_per_s": round(docs_per_sec, 2),
-                     "num_datasets": len(datasets),
-                 },
-             }
-
-             with st.expander("📋 Full Summary"):
-                 st.json(detailed_report)
-
-             report_json = json.dumps(detailed_report, indent=2)
-             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-             filename = f"index_report__{collection}__{timestamp}.json"
-
-             st.download_button(
-                 label="📥 Download Indexing Report",
-                 data=report_json,
-                 file_name=filename,
-                 mime="application/json",
-                 use_container_width=True,
-             )
-
+             st.markdown("##### 📋 Results Summary")
+
+             for r in all_results:
+                 st.write(
+                     f"**{r['dataset'].split('/')[-1]}**: {r['uploaded']:,} uploaded, {r['failed']:,} failed"
+                 )
+
+             st.success("✅ Indexing complete!")
+
      except Exception as e:
+         st.error(f"Indexing error: {e}")
+         st.code(traceback.format_exc())
          print(f"[INDEX] ERROR: {e}")
-         st.error(f"❌ Error: {e}")
-         with st.expander("🔍 Full Error Details"):
-             st.code(traceback.format_exc(), language="text")
+         traceback.print_exc()
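For reference, the `_stable_uuid` helper that `_union_point_id` builds on appears only partially in this hunk (just its final formatting line). A plausible reconstruction, assuming the 32-character hex digest comes from MD5 (the hash choice is an assumption; the diff shows only the return line):

```python
import hashlib

def _stable_uuid(text: str) -> str:
    # Assumption: MD5, since the return line formats exactly 32 hex chars
    # into the 8-4-4-4-12 UUID layout. The digest call is reconstructed;
    # only the return line is visible in the diff.
    hex_str = hashlib.md5(text.encode("utf-8")).hexdigest()
    return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
```

Deterministic IDs of this kind make the new `upsert_point` path idempotent: re-indexing the same document overwrites the existing point instead of duplicating it.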
demo/qdrant_utils.py CHANGED
@@ -8,12 +8,19 @@ import streamlit as st
 
 
  def get_qdrant_credentials() -> Tuple[Optional[str], Optional[str]]:
-     url = st.session_state.get("qdrant_url_input") or os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
-     api_key = st.session_state.get("qdrant_key_input") or (
-         os.getenv("SIGIR_QDRANT_KEY")
-         or os.getenv("SIGIR_QDRANT_API_KEY")
-         or os.getenv("DEST_QDRANT_API_KEY")
+     """Get Qdrant credentials from session state or environment variables.
+
+     Priority: session_state > QDRANT_URL/QDRANT_API_KEY > legacy env vars
+     """
+     url = (
+         st.session_state.get("qdrant_url_input")
+         or os.getenv("QDRANT_URL")
+         or os.getenv("SIGIR_QDRANT_URL")  # legacy
+     )
+     api_key = (
+         st.session_state.get("qdrant_key_input")
          or os.getenv("QDRANT_API_KEY")
+         or os.getenv("SIGIR_QDRANT_KEY")  # legacy
      )
      return url, api_key
 
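The practical effect of the reordering: when both canonical and legacy variables are set, 0.1.1 resolved the legacy `SIGIR_*` names first, while 0.1.3 prefers `QDRANT_URL`/`QDRANT_API_KEY` and drops the `DEST_*` and `SIGIR_QDRANT_API_KEY` fallbacks entirely. A small sketch with made-up values:

```python
import os

# Hypothetical environment with both canonical and legacy names set.
os.environ["QDRANT_URL"] = "https://new.example:6333"
os.environ["SIGIR_QDRANT_URL"] = "https://legacy.example:6333"

# 0.1.3 order: the canonical name wins (0.1.1 would have returned the legacy URL).
url = os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL")
assert url == "https://new.example:6333"
```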
demo/ui/playground.py CHANGED
@@ -9,6 +9,7 @@ from demo.qdrant_utils import (
      sample_points_cached,
      search_collection,
  )
+ from visual_rag.retrieval import MultiVectorRetriever
 
 
  def render_playground_tab():
@@ -46,7 +47,6 @@ def render_playground_tab():
      if not st.session_state.get("model_loaded"):
          with st.spinner(f"Loading {model_short}..."):
              try:
-                 from visual_rag.retrieval import MultiVectorRetriever
                  _ = MultiVectorRetriever(collection_name=active_collection, model_name=model_name)
                  st.session_state["model_loaded"] = True
                  st.session_state["loaded_model_key"] = cache_key
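Hoisting the `MultiVectorRetriever` import to module level does not change runtime behavior (Python caches imported modules), but it surfaces a missing dependency as an `ImportError` at app startup rather than inside the spinner block. The constructor call itself is unchanged; a standalone sketch with hypothetical argument values mirroring the call in the diff:

```python
from visual_rag.retrieval import MultiVectorRetriever

# Hypothetical collection name; the model name is the default used in
# demo/indexing.py.
retriever = MultiVectorRetriever(
    collection_name="demo_collection",
    model_name="vidore/colpali-v1.3",
)
```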
demo/ui/sidebar.py CHANGED
@@ -3,6 +3,8 @@
  import os
  import streamlit as st
 
+ from qdrant_client.models import VectorParamsDiff
+
  from demo.qdrant_utils import (
      get_qdrant_credentials,
      init_qdrant_client_with_creds,
@@ -17,8 +19,8 @@ def render_sidebar():
      with st.sidebar:
          st.subheader("🔑 Qdrant Credentials")
 
-         env_url = os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL") or ""
-         env_key = os.getenv("SIGIR_QDRANT_KEY") or os.getenv("SIGIR_QDRANT_API_KEY") or os.getenv("DEST_QDRANT_API_KEY") or os.getenv("QDRANT_API_KEY") or ""
+         env_url = os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL") or ""
+         env_key = os.getenv("QDRANT_API_KEY") or os.getenv("SIGIR_QDRANT_KEY") or ""
 
          if "qdrant_url_input" not in st.session_state:
              st.session_state["qdrant_url_input"] = env_url
@@ -136,7 +138,6 @@ def render_sidebar():
          if target_in_ram != current_in_ram:
              if st.button("💾 Apply Change", key="admin_apply"):
                  try:
-                     from qdrant_client.models import VectorParamsDiff
                      client.update_collection(
                          collection_name=active,
                          vectors_config={sel_vec: VectorParamsDiff(on_disk=not target_in_ram)}
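As in the playground change, the `VectorParamsDiff` import simply moves to module level; the call pattern is standard qdrant-client usage for toggling whether a named vector is held in RAM or on disk. A minimal sketch with hypothetical client, collection, and vector names:

```python
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParamsDiff

# Hypothetical connection details and names, for illustration only.
client = QdrantClient(url="https://new.example:6333", api_key="...")

# Move the named vector "mean_pooling" to disk; on_disk=False would pull
# it back into RAM, matching the sidebar's "Apply Change" toggle.
client.update_collection(
    collection_name="demo_collection",
    vectors_config={"mean_pooling": VectorParamsDiff(on_disk=True)},
)
```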