visual-rag-toolkit 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
demo/__init__.py CHANGED
@@ -7,4 +7,4 @@ A Streamlit-based UI for:
  - Interactive playground for visual search
  """

- __version__ = "0.1.0"
+ __version__ = "0.1.4"
demo/app.py CHANGED
@@ -1,13 +1,23 @@
  """Main entry point for the Visual RAG Toolkit demo application."""

+ import os
  import sys
  from pathlib import Path

- ROOT_DIR = Path(__file__).parent.parent
- sys.path.insert(0, str(ROOT_DIR))
+ # Ensure repo root is in sys.path for local development
+ # (In HF Space / Docker, PYTHONPATH is already set correctly)
+ _app_dir = Path(__file__).resolve().parent
+ _repo_root = _app_dir.parent
+ if str(_repo_root) not in sys.path:
+     sys.path.insert(0, str(_repo_root))

  from dotenv import load_dotenv
- load_dotenv(ROOT_DIR / ".env")
+
+ # Load .env from the repo root (works both locally and in Docker)
+ if (_repo_root / ".env").exists():
+     load_dotenv(_repo_root / ".env")
+ if (_app_dir / ".env").exists():
+     load_dotenv(_app_dir / ".env")

  import streamlit as st

@@ -28,15 +38,17 @@ from demo.ui.benchmark import render_benchmark_tab
  def main():
      render_header()
      render_sidebar()
-
-     tab_upload, tab_playground, tab_benchmark = st.tabs(["📤 Upload", "🎮 Playground", "📊 Benchmarking"])
-
+
+     tab_upload, tab_playground, tab_benchmark = st.tabs(
+         ["📤 Upload", "🎮 Playground", "📊 Benchmarking"]
+     )
+
      with tab_upload:
          render_upload_tab()
-
+
      with tab_playground:
          render_playground_tab()
-
+
      with tab_benchmark:
          render_benchmark_tab()

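Note on the new bootstrap: python-dotenv's load_dotenv does not overwrite environment variables that are already set, so when both files exist the repo-root .env (loaded first) takes precedence over the app-dir .env. A minimal sketch of the same pattern outside the demo (paths hypothetical):

    import sys
    from pathlib import Path

    from dotenv import load_dotenv

    repo_root = Path(__file__).resolve().parent.parent
    if str(repo_root) not in sys.path:  # guard against duplicate entries on reruns
        sys.path.insert(0, str(repo_root))
    if (repo_root / ".env").exists():
        load_dotenv(repo_root / ".env")  # loaded first, so its values win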
demo/evaluation.py CHANGED
@@ -1,20 +1,23 @@
  """Evaluation runner with UI updates."""

  import hashlib
- import importlib.util
  import json
  import logging
  import time
  import traceback
  from datetime import datetime
- from pathlib import Path
  from typing import Any, Dict, List, Optional

  import numpy as np
  import streamlit as st
  import torch
+ from qdrant_client.models import FieldCondition, Filter, MatchValue

  from visual_rag import VisualEmbedder
+ from visual_rag.retrieval import MultiVectorRetriever
+ from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
+ from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
+ from demo.qdrant_utils import get_qdrant_credentials


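The qdrant_client model imports added here support dataset-scoped retrieval. A sketch of the kind of filter they express, using the "dataset" payload field that demo/indexing.py indexes (values illustrative):

    from qdrant_client.models import FieldCondition, Filter, MatchValue

    # Match only points whose payload field "dataset" equals the given value.
    dataset_filter = Filter(
        must=[FieldCondition(key="dataset", match=MatchValue(value="vidore/tatdqa_test"))]
    )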
  TORCH_DTYPE_MAP = {
@@ -22,49 +25,6 @@ TORCH_DTYPE_MAP = {
22
25
  "float32": torch.float32,
23
26
  "bfloat16": torch.bfloat16,
24
27
  }
25
- from qdrant_client.models import Filter, FieldCondition, MatchValue
26
-
27
- from visual_rag.retrieval import MultiVectorRetriever
28
-
29
-
30
- def _load_local_benchmark_module(module_filename: str):
31
- """
32
- Load `benchmarks/vidore_tatdqa_test/<module_filename>` via file path.
33
-
34
- Motivation:
35
- - Some environments (notably containers / Spaces) can have a third-party
36
- `benchmarks` package installed, causing `import benchmarks...` to resolve
37
- to the wrong module.
38
- - This fallback guarantees we load the repo's benchmark utilities.
39
- """
40
- root = Path(__file__).resolve().parents[1] # demo/.. = repo root
41
- target = root / "benchmarks" / "vidore_tatdqa_test" / module_filename
42
- if not target.exists():
43
- raise ModuleNotFoundError(f"Missing local benchmark module file: {target}")
44
-
45
- name = f"_visual_rag_toolkit_local_{target.stem}"
46
- spec = importlib.util.spec_from_file_location(name, str(target))
47
- if spec is None or spec.loader is None:
48
- raise ModuleNotFoundError(f"Could not load module spec for: {target}")
49
- mod = importlib.util.module_from_spec(spec)
50
- spec.loader.exec_module(mod) # type: ignore[attr-defined]
51
- return mod
52
-
53
-
54
- try:
55
- # Preferred: normal import
56
- from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
57
- from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
58
- except ModuleNotFoundError:
59
- # Robust fallback: load from local file paths
60
- _dl = _load_local_benchmark_module("dataset_loader.py")
61
- _mx = _load_local_benchmark_module("metrics.py")
62
- load_vidore_beir_dataset = _dl.load_vidore_beir_dataset
63
- ndcg_at_k = _mx.ndcg_at_k
64
- mrr_at_k = _mx.mrr_at_k
65
- recall_at_k = _mx.recall_at_k
66
-
67
- from demo.qdrant_utils import get_qdrant_credentials
68
28
 
69
29
  logger = logging.getLogger(__name__)
70
30
  logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
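Both demo modules now import the benchmark utilities directly instead of falling back to file-path loading, relying on the sys.path bootstrap in demo/app.py. For reference, the removed fallback used the standard importlib recipe for loading a module from an explicit path; a self-contained sketch:

    import importlib.util
    from pathlib import Path

    def load_module_from_path(path: Path):
        # Build a spec from the file path and execute it, bypassing sys.path
        # resolution (and thus any third-party package shadowing the name).
        spec = importlib.util.spec_from_file_location(path.stem, str(path))
        if spec is None or spec.loader is None:
            raise ModuleNotFoundError(f"Could not load module spec for: {path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module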
demo/indexing.py CHANGED
@@ -1,12 +1,10 @@
  """Indexing runner with UI updates."""

  import hashlib
- import importlib.util
  import json
  import time
  import traceback
  from datetime import datetime
- from pathlib import Path
  from typing import Any, Dict, Optional

  import numpy as np
@@ -14,6 +12,10 @@ import streamlit as st
  import torch

  from visual_rag import VisualEmbedder
+ from visual_rag.embedding.pooling import tile_level_mean_pooling
+ from visual_rag.indexing.qdrant_indexer import QdrantIndexer
+ from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
+ from demo.qdrant_utils import get_qdrant_credentials


  TORCH_DTYPE_MAP = {
@@ -22,37 +24,6 @@ TORCH_DTYPE_MAP = {
      "bfloat16": torch.bfloat16,
  }

- # --- Robust imports (Spaces-friendly) ---
- # Some environments can have a third-party `benchmarks` package installed, or
- # resolve `visual_rag.indexing` oddly. These fallbacks keep the demo working.
- try:
-     from visual_rag.indexing import QdrantIndexer
- except Exception:  # pragma: no cover
-     from visual_rag.indexing.qdrant_indexer import QdrantIndexer
-
-
- def _load_local_benchmark_module(module_filename: str):
-     root = Path(__file__).resolve().parents[1]  # demo/.. = repo root
-     target = root / "benchmarks" / "vidore_tatdqa_test" / module_filename
-     if not target.exists():
-         raise ModuleNotFoundError(f"Missing local benchmark module file: {target}")
-     name = f"_visual_rag_toolkit_local_{target.stem}"
-     spec = importlib.util.spec_from_file_location(name, str(target))
-     if spec is None or spec.loader is None:
-         raise ModuleNotFoundError(f"Could not load module spec for: {target}")
-     mod = importlib.util.module_from_spec(spec)
-     spec.loader.exec_module(mod)  # type: ignore[attr-defined]
-     return mod
-
-
- try:
-     from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
- except ModuleNotFoundError:  # pragma: no cover
-     _dl = _load_local_benchmark_module("dataset_loader.py")
-     load_vidore_beir_dataset = _dl.load_vidore_beir_dataset
-
- from demo.qdrant_utils import get_qdrant_credentials
-

  def _stable_uuid(text: str) -> str:
      """Generate a stable UUID from text (same as benchmark script)."""
@@ -60,7 +31,9 @@ def _stable_uuid(text: str) -> str:
      return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"


- def _union_point_id(*, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]) -> str:
+ def _union_point_id(
+     *, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]
+ ) -> str:
      """Generate union point ID (same as benchmark script)."""
      ns = f"{union_namespace}::{dataset_name}" if union_namespace else dataset_name
      return _stable_uuid(f"{ns}::{source_doc_id}")
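_stable_uuid (body elided above) formats a 32-character hex digest as a canonical 8-4-4-4-12 UUID, and _union_point_id prefixes the document ID with a collection/dataset namespace so the same source document maps to distinct but reproducible point IDs per union. A sketch of the idea; the actual hash function is not shown in this diff, MD5 is assumed here for illustration:

    import hashlib

    def stable_uuid(text: str) -> str:
        hex_str = hashlib.md5(text.encode("utf-8")).hexdigest()  # any 128-bit digest works
        return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"

    # Same source doc, different namespaces -> different IDs, each stable across runs:
    print(stable_uuid("collection_a::vidore/tatdqa_test::doc_42"))
    print(stable_uuid("collection_b::vidore/tatdqa_test::doc_42"))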
@@ -68,16 +41,16 @@ def _union_point_id(*, dataset_name: str, source_doc_id: str, union_namespace: O

  def run_indexing_with_ui(config: Dict[str, Any]):
      st.divider()
-
+
      print("=" * 60)
      print("[INDEX] Starting indexing via UI")
      print("=" * 60)
-
+
      url, api_key = get_qdrant_credentials()
      if not url:
          st.error("QDRANT_URL not configured")
          return
-
+
      datasets = config.get("datasets", [])
      collection = config["collection"]
      model = config.get("model", "vidore/colpali-v1.3")
@@ -87,42 +60,50 @@ def run_indexing_with_ui(config: Dict[str, Any]):
      prefer_grpc = config.get("prefer_grpc", True)
      batch_size = config.get("batch_size", 4)
      max_docs = config.get("max_docs")
-
+
      print(f"[INDEX] Config: collection={collection}, model={model}")
      print(f"[INDEX] Datasets: {datasets}")
-     print(f"[INDEX] max_docs={max_docs}, batch_size={batch_size}, recreate={recreate}")
-     print(f"[INDEX] torch_dtype={torch_dtype}, qdrant_dtype={qdrant_vector_dtype}, grpc={prefer_grpc}")
-
+     print(
+         f"[INDEX] max_docs={max_docs}, batch_size={batch_size}, recreate={recreate}"
+     )
+     print(
+         f"[INDEX] torch_dtype={torch_dtype}, qdrant_dtype={qdrant_vector_dtype}, grpc={prefer_grpc}"
+     )
+
      phase1_container = st.container()
      phase2_container = st.container()
      phase3_container = st.container()
      results_container = st.container()
-
+
      try:
          with phase1_container:
              st.markdown("##### 🤖 Phase 1: Loading Model")
              model_status = st.empty()
              model_status.info(f"Loading `{model.split('/')[-1]}`...")
-
+
              print(f"[INDEX] Loading embedder: {model}")
              torch_dtype_obj = TORCH_DTYPE_MAP.get(torch_dtype, torch.float16)
-             output_dtype_obj = np.float16 if qdrant_vector_dtype == "float16" else np.float32
+             output_dtype_obj = (
+                 np.float16 if qdrant_vector_dtype == "float16" else np.float32
+             )
              embedder = VisualEmbedder(
                  model_name=model,
                  torch_dtype=torch_dtype_obj,
                  output_dtype=output_dtype_obj,
              )
              embedder._load_model()
-             print(f"[INDEX] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_vector_dtype})")
+             print(
+                 f"[INDEX] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_vector_dtype})"
+             )
              model_status.success(f"✅ Model `{model.split('/')[-1]}` loaded")
-
+
          with phase2_container:
              st.markdown("##### 📦 Phase 2: Setting Up Collection")
-
+
              indexer_status = st.empty()
-             indexer_status.info(f"Connecting to Qdrant...")
-
-             print(f"[INDEX] Connecting to Qdrant...")
+             indexer_status.info("Connecting to Qdrant...")
+
+             print("[INDEX] Connecting to Qdrant...")
              indexer = QdrantIndexer(
                  url=url,
                  api_key=api_key,
@@ -130,186 +111,164 @@ def run_indexing_with_ui(config: Dict[str, Any]):
                  prefer_grpc=prefer_grpc,
                  vector_datatype=qdrant_vector_dtype,
              )
-             print(f"[INDEX] Connected to Qdrant")
-             indexer_status.success(f"✅ Connected to Qdrant")
-
+             print("[INDEX] Connected to Qdrant")
+             indexer_status.success("✅ Connected to Qdrant")
+
              coll_status = st.empty()
              action = "Recreating" if recreate else "Creating/verifying"
              coll_status.info(f"{action} collection `{collection}`...")
-
+
              print(f"[INDEX] {action} collection: {collection}")
              indexer.create_collection(force_recreate=recreate)
-             indexer.create_payload_indexes(fields=[
-                 {"field": "dataset", "type": "keyword"},
-                 {"field": "doc_id", "type": "keyword"},
-                 {"field": "source_doc_id", "type": "keyword"},
-             ])
-             print(f"[INDEX] Collection ready")
+             indexer.create_payload_indexes(
+                 fields=[
+                     {"field": "dataset", "type": "keyword"},
+                     {"field": "doc_id", "type": "keyword"},
+                     {"field": "source_doc_id", "type": "keyword"},
+                 ]
+             )
+             print("[INDEX] Collection ready")
              coll_status.success(f"✅ Collection `{collection}` ready")
-
+
          with phase3_container:
-             st.markdown("##### 🚀 Phase 3: Indexing Documents")
-
-             total_uploaded = 0
-             total_docs = 0
-             total_time = 0
-
-             for ds_name in datasets:
-                 ds_short = ds_name.split("/")[-1]
-                 ds_header = st.empty()
-                 ds_header.info(f"📚 Loading `{ds_short}`...")
-
-                 print(f"[INDEX] Loading dataset: {ds_name}")
-                 corpus, queries, qrels = load_vidore_beir_dataset(ds_name)
-
-                 if max_docs and max_docs > 0 and len(corpus) > max_docs:
-                     corpus = corpus[:max_docs]
-                     print(f"[INDEX] Limited to {len(corpus)} docs (max_docs={max_docs})")
-
-                 total_docs += len(corpus)
-                 print(f"[INDEX] Dataset {ds_name}: {len(corpus)} documents to index")
-                 ds_header.success(f"📚 `{ds_short}`: {len(corpus)} documents")
-
-                 progress_bar = st.progress(0.0)
-                 batch_status = st.empty()
-                 log_area = st.empty()
-                 log_lines = []
-
-                 num_batches = (len(corpus) + batch_size - 1) // batch_size
-                 ds_start = time.time()
-
-                 for i in range(0, len(corpus), batch_size):
-                     batch = corpus[i:i + batch_size]
-                     images = [doc.image for doc in batch if hasattr(doc, 'image') and doc.image]
-
-                     if not images:
-                         continue
-
-                     batch_num = i // batch_size + 1
-                     batch_status.info(f"Batch {batch_num}/{num_batches}: embedding & uploading...")
-
-                     batch_start = time.time()
-                     embeddings, token_infos = embedder.embed_images(images, return_token_info=True)
-                     embed_time = time.time() - batch_start
-
-                     points = []
-                     for j, (doc, emb, token_info) in enumerate(zip(batch, embeddings, token_infos)):
-                         doc_id = doc.doc_id if hasattr(doc, 'doc_id') else str(i + j)
-                         source_doc_id = str(doc.payload.get("source_doc_id", doc_id) if hasattr(doc, 'payload') else doc_id)
-
-                         union_doc_id = _union_point_id(
-                             dataset_name=ds_name,
-                             source_doc_id=source_doc_id,
-                             union_namespace=collection,
-                         )
-
-                         emb_np = emb.cpu().numpy() if hasattr(emb, 'cpu') else np.array(emb)
-                         visual_indices = token_info.get("visual_token_indices") or list(range(emb_np.shape[0]))
-                         visual_emb = emb_np[visual_indices].astype(embedder.output_dtype)
-
-                         tile_pooled = embedder.mean_pool_visual_embedding(visual_emb, token_info, target_vectors=32)
-                         experimental = embedder.experimental_pool_visual_embedding(
-                             visual_emb, token_info, target_vectors=32, mean_pool=tile_pooled
-                         )
-                         global_pooled = embedder.global_pool_from_mean_pool(tile_pooled)
-
-                         points.append({
-                             "id": union_doc_id,
-                             "visual_embedding": visual_emb,
-                             "tile_pooled_embedding": tile_pooled,
-                             "experimental_pooled_embedding": experimental,
-                             "global_pooled_embedding": global_pooled,
-                             "metadata": {
-                                 "dataset": ds_name,
+             st.markdown("##### 📊 Phase 3: Processing Datasets")
+
+             all_results = []
+
+             for ds_idx, dataset_name in enumerate(datasets):
+                 ds_short = dataset_name.split("/")[-1]
+                 ds_container = st.container()
+
+                 with ds_container:
+                     st.markdown(
+                         f"**Dataset {ds_idx + 1}/{len(datasets)}: `{ds_short}`**"
+                     )
+
+                     load_status = st.empty()
+                     load_status.info(f"Loading dataset `{ds_short}`...")
+
+                     print(f"[INDEX] Loading dataset: {dataset_name}")
+                     corpus, queries, qrels = load_vidore_beir_dataset(dataset_name)
+                     total_docs = len(corpus)
+                     print(f"[INDEX] Dataset loaded: {total_docs} docs")
+                     load_status.success(f" Loaded {total_docs:,} documents")
+
+                     if max_docs and max_docs < total_docs:
+                         corpus = corpus[:max_docs]
+                         print(f"[INDEX] Limiting to {max_docs} docs")
+
+                     progress_bar = st.progress(0)
+                     status_text = st.empty()
+
+                     uploaded = 0
+                     failed = 0
+                     total = len(corpus)
+
+                     for i, doc in enumerate(corpus):
+                         try:
+                             doc_id = str(doc.doc_id)
+                             image = doc.image
+                             if image is None:
+                                 failed += 1
+                                 continue
+
+                             status_text.text(
+                                 f"Processing {i + 1}/{total}: {doc_id[:30]}..."
+                             )
+
+                             embeddings, token_infos = embedder.embed_images(
+                                 [image],
+                                 return_token_info=True,
+                                 show_progress=False,
+                             )
+                             emb = embeddings[0]
+                             token_info = token_infos[0] if token_infos else {}
+
+                             if hasattr(emb, "cpu"):
+                                 emb = emb.cpu()
+                             emb_np = np.asarray(emb, dtype=output_dtype_obj)
+
+                             initial = emb_np.tolist()
+                             global_pool = emb_np.mean(axis=0).tolist()
+
+                             num_tiles = token_info.get("num_tiles")
+                             mean_pooling = None
+                             experimental_pooling = None
+
+                             if num_tiles and num_tiles > 0:
+                                 try:
+                                     mean_pooling = tile_level_mean_pooling(
+                                         emb_np, num_tiles=num_tiles, patches_per_tile=64
+                                     ).tolist()
+                                 except Exception:
+                                     pass
+
+                                 try:
+                                     exp_pool = embedder.experimental_pool_visual_embedding(
+                                         emb_np, num_tiles=num_tiles
+                                     )
+                                     if exp_pool is not None:
+                                         experimental_pooling = exp_pool.tolist()
+                                 except Exception:
+                                     pass
+
+                             union_doc_id = _union_point_id(
+                                 dataset_name=dataset_name,
+                                 source_doc_id=doc_id,
+                                 union_namespace=collection,
+                             )
+
+                             payload = {
+                                 "dataset": dataset_name,
                                  "doc_id": doc_id,
-                                 "source_doc_id": source_doc_id,
+                                 "source_doc_id": doc_id,
                                  "union_doc_id": union_doc_id,
-                             },
-                         })
-
-                     upload_start = time.time()
-                     indexer.upload_batch(points)
-                     upload_time = time.time() - upload_start
-                     total_uploaded += len(points)
-
-                     progress = (i + len(batch)) / len(corpus)
-                     progress_bar.progress(progress)
-                     batch_status.info(f"Batch {batch_num}/{num_batches} ({int(progress*100)}%) — embed: {embed_time:.1f}s, upload: {upload_time:.1f}s")
-
-                     log_interval = max(2, num_batches // 10)
-                     should_log = batch_num % log_interval == 0 or batch_num == num_batches
-
-                     if should_log and batch_num > 1:
-                         log_lines.append(f"[Batch {batch_num}/{num_batches}] +{len(points)} pts, total={total_uploaded}")
-                         log_area.code("\n".join(log_lines[-8:]), language="text")
-                     print(f"[INDEX] Batch {batch_num}/{num_batches}: +{len(points)} pts, total={total_uploaded}, embed={embed_time:.1f}s, upload={upload_time:.1f}s")
-
-                 ds_time = time.time() - ds_start
-                 total_time += ds_time
-                 progress_bar.progress(1.0)
-                 batch_status.success(f"✅ `{ds_short}` indexed: {len(corpus)} docs in {ds_time:.1f}s")
-                 print(f"[INDEX] Dataset {ds_name} complete: {len(corpus)} docs in {ds_time:.1f}s")
-
+                                 "num_tiles": num_tiles,
+                                 "num_visual_tokens": token_info.get("num_visual_tokens"),
+                             }
+
+                             vectors = {"initial": initial, "global_pooling": global_pool}
+                             if mean_pooling:
+                                 vectors["mean_pooling"] = mean_pooling
+                             if experimental_pooling:
+                                 vectors["experimental_pooling"] = experimental_pooling
+
+                             indexer.upsert_point(
+                                 point_id=union_doc_id,
+                                 vectors=vectors,
+                                 payload=payload,
+                             )
+
+                             uploaded += 1
+
+                         except Exception as e:
+                             print(f"[INDEX] Error on doc {i}: {e}")
+                             failed += 1
+
+                         progress_bar.progress((i + 1) / total)
+
+                     status_text.text(f" Done: {uploaded} uploaded, {failed} failed")
+                     all_results.append(
+                         {
+                             "dataset": dataset_name,
+                             "total": total,
+                             "uploaded": uploaded,
+                             "failed": failed,
+                         }
+                     )
+

          with results_container:
-             st.markdown("##### 📊 Summary")
-
-             docs_per_sec = total_uploaded / total_time if total_time > 0 else 0
-
-             print("=" * 60)
-             print("[INDEX] INDEXING COMPLETE")
-             print(f"[INDEX] Total Uploaded: {total_uploaded:,}")
-             print(f"[INDEX] Datasets: {len(datasets)}")
-             print(f"[INDEX] Collection: {collection}")
-             print(f"[INDEX] Total Time: {total_time:.1f}s")
-             print(f"[INDEX] Throughput: {docs_per_sec:.2f} docs/s")
-             print("=" * 60)
-
-             c1, c2, c3, c4 = st.columns(4)
-             c1.metric("Total Uploaded", f"{total_uploaded:,}")
-             c2.metric("Datasets", len(datasets))
-             c3.metric("Total Time", f"{total_time:.1f}s")
-             c4.metric("Throughput", f"{docs_per_sec:.2f}/s")
-
-             st.success(f"🎉 Indexing complete! {total_uploaded:,} documents indexed to `{collection}`")
-
-             detailed_report = {
-                 "generated_at": datetime.now().isoformat(),
-                 "config": {
-                     "collection": collection,
-                     "model": model,
-                     "datasets": datasets,
-                     "batch_size": batch_size,
-                     "max_docs_per_dataset": max_docs,
-                     "recreate": recreate,
-                     "prefer_grpc": prefer_grpc,
-                     "torch_dtype": torch_dtype,
-                     "qdrant_vector_dtype": qdrant_vector_dtype,
-                 },
-                 "results": {
-                     "total_docs_uploaded": total_uploaded,
-                     "total_time_s": round(total_time, 2),
-                     "throughput_docs_per_s": round(docs_per_sec, 2),
-                     "num_datasets": len(datasets),
-                 },
-             }
-
-             with st.expander("📋 Full Summary"):
-                 st.json(detailed_report)
-
-             report_json = json.dumps(detailed_report, indent=2)
-             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-             filename = f"index_report__{collection}__{timestamp}.json"
-
-             st.download_button(
-                 label="📥 Download Indexing Report",
-                 data=report_json,
-                 file_name=filename,
-                 mime="application/json",
-                 use_container_width=True,
-             )
-
+             st.markdown("##### 📋 Results Summary")
+
+             for r in all_results:
+                 st.write(
+                     f"**{r['dataset'].split('/')[-1]}**: {r['uploaded']:,} uploaded, {r['failed']:,} failed"
+                 )
+
+             st.success(" Indexing complete!")
+
      except Exception as e:
+         st.error(f"Indexing error: {e}")
+         st.code(traceback.format_exc())
          print(f"[INDEX] ERROR: {e}")
-         st.error(f"❌ Error: {e}")
-         with st.expander("🔍 Full Error Details"):
-             st.code(traceback.format_exc(), language="text")
+         traceback.print_exc()
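The rewritten loop uploads one point per document with named vectors: the full multi-vector embedding ("initial"), a single mean vector ("global_pooling"), and, when tile information is available, tile-level poolings. A sketch of what tile_level_mean_pooling computes, assuming the (num_tiles * patches_per_tile, dim) row layout implied by the call site:

    import numpy as np

    def tile_mean_pool(emb: np.ndarray, num_tiles: int, patches_per_tile: int = 64) -> np.ndarray:
        # Average each tile's patch vectors: (num_tiles * patches_per_tile, dim) -> (num_tiles, dim).
        patches = emb[: num_tiles * patches_per_tile]
        return patches.reshape(num_tiles, patches_per_tile, -1).mean(axis=1)

    emb = np.random.rand(4 * 64, 128).astype(np.float32)  # 4 tiles of 64 patch vectors
    print(tile_mean_pool(emb, num_tiles=4).shape)  # (4, 128)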
demo/qdrant_utils.py CHANGED
@@ -8,12 +8,19 @@ import streamlit as st


  def get_qdrant_credentials() -> Tuple[Optional[str], Optional[str]]:
-     url = st.session_state.get("qdrant_url_input") or os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
-     api_key = st.session_state.get("qdrant_key_input") or (
-         os.getenv("SIGIR_QDRANT_KEY")
-         or os.getenv("SIGIR_QDRANT_API_KEY")
-         or os.getenv("DEST_QDRANT_API_KEY")
+     """Get Qdrant credentials from session state or environment variables.
+
+     Priority: session_state > QDRANT_URL/QDRANT_API_KEY > legacy env vars
+     """
+     url = (
+         st.session_state.get("qdrant_url_input")
+         or os.getenv("QDRANT_URL")
+         or os.getenv("SIGIR_QDRANT_URL")  # legacy
+     )
+     api_key = (
+         st.session_state.get("qdrant_key_input")
          or os.getenv("QDRANT_API_KEY")
+         or os.getenv("SIGIR_QDRANT_KEY")  # legacy
      )
      return url, api_key

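A short usage sketch for the cleaned-up helper (run inside the Streamlit app, since it reads st.session_state; assumes a reachable Qdrant instance):

    from qdrant_client import QdrantClient

    from demo.qdrant_utils import get_qdrant_credentials

    url, api_key = get_qdrant_credentials()
    if url:
        client = QdrantClient(url=url, api_key=api_key)
        print(client.get_collections())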