visual-rag-toolkit 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
demo/ui/playground.py ADDED
@@ -0,0 +1,339 @@
1
+ """Playground tab component."""
2
+
3
+ import streamlit as st
4
+
5
+ from demo.config import AVAILABLE_MODELS, RETRIEVAL_MODES, STAGE1_MODES
6
+ from demo.qdrant_utils import (
7
+ get_qdrant_credentials,
8
+ get_collections,
9
+ sample_points_cached,
10
+ search_collection,
11
+ )
12
+
13
+
14
def render_playground_tab():
    """Render the Playground tab: collection/model banner, sample explorer and RAG query UI.

    The collection comes from ``st.session_state["active_collection"]``; when unset, the
    first collection returned by ``get_collections`` is used. The model name is read from
    the ``model_name`` field of one sampled point's payload, falling back to
    ``AVAILABLE_MODELS[1]``.
    """
    import html

    st.subheader("🎮 Playground")

    active_collection = st.session_state.get("active_collection")
    url, api_key = get_qdrant_credentials()

    if not active_collection:
        collections = get_collections(url, api_key)
        if collections:
            active_collection = collections[0]

    if not active_collection:
        st.warning("No collection available. Upload documents or select a collection.")
        return

    points_for_model = sample_points_cached(active_collection, 1, 0, url, api_key)
    model_name = None
    if points_for_model:
        # A payload key that is present but None is treated like an empty payload.
        model_name = (points_for_model[0].get("payload") or {}).get("model_name")
    if not model_name:
        model_name = AVAILABLE_MODELS[1]

    model_short = model_name.split("/")[-1] if model_name else "unknown"
    # The "model loaded" flag is only valid for this exact (collection, model) pair.
    cache_key = f"{active_collection}_{model_name}"

    if st.session_state.get("loaded_model_key") != cache_key:
        st.session_state["model_loaded"] = False

    col_info, col_model = st.columns([2, 1])
    with col_info:
        st.info(f"**Collection:** `{active_collection}`")
    with col_model:
        if not st.session_state.get("model_loaded"):
            with st.spinner(f"Loading {model_short}..."):
                try:
                    from visual_rag.retrieval import MultiVectorRetriever

                    # Instantiated only to check that the model loads; the instance is discarded.
                    _ = MultiVectorRetriever(collection_name=active_collection, model_name=model_name)
                    st.session_state["model_loaded"] = True
                    st.session_state["loaded_model_key"] = cache_key
                    st.session_state["loaded_model_name"] = model_name
                except Exception as e:
                    # Surface the cause instead of silently discarding it.
                    st.warning(f"Failed: {model_short}")
                    st.caption(f"{type(e).__name__}: {e}")

        if st.session_state.get("model_loaded"):
            # model_short originates from stored payload data; escape it before raw HTML.
            st.markdown(
                f"✅ Found <span style='color:#e74c3c;font-weight:bold;'>{html.escape(model_short)}</span> model",
                unsafe_allow_html=True,
            )

    # NOTE(review): render_sample_explorer creates one st.expander per point inside this
    # expander; some Streamlit versions reject nested expanders — confirm the pinned version.
    with st.expander("📦 Sample Points Explorer", expanded=True):
        render_sample_explorer(active_collection, url, api_key)

    st.divider()

    st.subheader("🔍 RAG Query")
    render_rag_query_interface(active_collection, model_name)
67
+
68
+
69
def render_document_details(pt: dict, p: dict, score: float = None, rel_pct: float = None):
    """Render one document page: metadata in the left column, image tabs in the right.

    Args:
        pt: The full point/result dict. Not read here; kept for interface compatibility.
        p: The point payload. Read for ids (``doc_id``/``union_doc_id``/``source_doc_id``,
            ``corpus-id``), ``dataset``/``source``, ``model_name``/``model``, the display
            name (``doc-id``/``filename``), visual metadata, original/cropped/resized
            dimensions, crop settings and the image URLs.
        score: Raw retrieval score. When not None, a "Relevance" section is shown.
        rel_pct: Score relative to the best result, in percent. The progress bar value is
            clamped to [0, 1] because ``st.progress`` rejects floats outside that range.
    """
    import html

    def _is_missing(v) -> bool:
        """Return True for None, empty containers and blank/placeholder strings."""
        if v is None:
            return True
        if isinstance(v, (list, tuple, dict)) and len(v) == 0:
            return True
        if isinstance(v, str):
            s = v.strip()
            return s == "" or s.lower() in {"na", "n/a", "none", "null", "unknown", "?", "-"}
        return False

    def _dims(w, h):
        """Return "W×H" when both sides are present, else None."""
        if _is_missing(w) or _is_missing(h):
            return None
        return f"{w}×{h}"

    def _pct(v):
        """Format a fraction (e.g. 0.15) as "15%"; None when missing or not numeric."""
        if _is_missing(v):
            return None
        try:
            return f"{int(float(v) * 100)}%"
        except (TypeError, ValueError, OverflowError):
            return None

    doc_id = p.get("doc_id") or p.get("union_doc_id") or p.get("source_doc_id") or "?"
    corpus_id = p.get("corpus-id") or p.get("source_doc_id") or "?"
    dataset = p.get("dataset") or p.get("source") or None
    model = p.get("model_name") or p.get("model") or None
    model = model.split("/")[-1] if isinstance(model, str) else None
    doc_name = p.get("doc-id") or p.get("filename") or "Unknown"

    num_tiles = p.get("num_tiles")
    visual_tokens = p.get("index_recovery_num_visual_tokens") or p.get("num_visual_tokens")
    patches_per_tile = p.get("patches_per_tile")
    torch_dtype = p.get("torch_dtype")

    orig_dims = _dims(p.get("original_width"), p.get("original_height"))
    crop_dims = _dims(p.get("cropped_width"), p.get("cropped_height"))
    resize_dims = _dims(p.get("resized_width"), p.get("resized_height"))
    crop_pct_text = _pct(p.get("crop_empty_percentage_to_remove"))
    crop_enabled = bool(p.get("crop_empty_enabled", False))

    col_meta, col_img = st.columns([1, 2])

    with col_meta:
        st.markdown("##### 📄 Document Info")
        st.markdown(f"**📁 Doc:** {doc_name}")
        if not _is_missing(dataset):
            st.markdown(f"**🏛️ Dataset:** {dataset}")
        # "?" is one of _is_missing's placeholders, so the "?" fallback is skipped here too.
        if not _is_missing(doc_id):
            doc_id_text = str(doc_id)
            # Only add an ellipsis when the id was actually truncated.
            if len(doc_id_text) > 20:
                doc_id_text = doc_id_text[:20] + "..."
            st.markdown(f"**🔑 Doc ID:** `{doc_id_text}`")
        if not _is_missing(corpus_id):
            st.markdown(f"**📋 Corpus ID:** {corpus_id}")

        if score is not None:
            st.divider()
            st.markdown("##### 🎯 Relevance")
            if rel_pct is not None:
                st.markdown(f"**Relative:** 🟢 {rel_pct:.1f}%")
                st.progress(min(max(rel_pct / 100, 0.0), 1.0))
            st.caption(f"Raw score: {score:.4f}")

        st.divider()
        visual_rows = []
        if not _is_missing(model):
            visual_rows.append(("🤖 Model", f"`{model}`"))
        if not _is_missing(num_tiles):
            visual_rows.append(("🔲 Tiles", str(num_tiles)))
        if not _is_missing(visual_tokens):
            visual_rows.append(("🔢 Visual Tokens", str(visual_tokens)))
        if not _is_missing(patches_per_tile):
            visual_rows.append(("📦 Patches/Tile", str(patches_per_tile)))
        if not _is_missing(torch_dtype):
            visual_rows.append(("⚙️ Dtype", str(torch_dtype)))
        if visual_rows:
            st.markdown("##### 🎨 Visual Metadata")
            for k, v in visual_rows:
                st.markdown(f"**{k}:** {v}")

        st.divider()
        dim_rows = []
        if orig_dims:
            dim_rows.append(("Original", orig_dims))
        if resize_dims:
            dim_rows.append(("Resized", resize_dims))
        if crop_enabled and crop_dims:
            dim_rows.append(("Cropped", crop_dims))
        if dim_rows:
            st.markdown("##### 📐 Dimensions")
            for k, v in dim_rows:
                st.markdown(f"**{k}:** {v}")
        if crop_enabled and crop_pct_text:
            st.markdown(f"**Crop %:** {crop_pct_text}")

    with col_img:
        st.markdown("##### 📷 Document Page")
        tabs = st.tabs(["🖼️ Original", "📷 Resized", "✂️ Cropped"])

        url_o = p.get("original_url")
        url_r = p.get("resized_url") or p.get("page")
        url_c = p.get("cropped_url")

        with tabs[0]:
            if url_o:
                st.image(url_o, width=600)
                if orig_dims:
                    st.caption(f"📐 **{orig_dims}**")
            else:
                st.info("No original image available")

        with tabs[1]:
            if url_r:
                st.image(url_r, width=600)
                if resize_dims:
                    st.caption(f"📐 **{resize_dims}**")
            else:
                st.info("No resized image available")

        with tabs[2]:
            if url_c:
                # The URL comes from stored payload data and is placed in raw HTML: escape it.
                safe_src = html.escape(str(url_c), quote=True)
                # Display on a checkerboard background to make the crop boundary obvious.
                st.markdown(
                    f"""
                    <div style="
                        width: 600px;
                        padding: 14px;
                        border-radius: 10px;
                        background-image:
                            linear-gradient(45deg, #e6e6e6 25%, transparent 25%),
                            linear-gradient(-45deg, #e6e6e6 25%, transparent 25%),
                            linear-gradient(45deg, transparent 75%, #e6e6e6 75%),
                            linear-gradient(-45deg, transparent 75%, #e6e6e6 75%);
                        background-size: 24px 24px;
                        background-position: 0 0, 0 12px, 12px -12px, -12px 0px;
                        box-shadow: 0 10px 30px rgba(0,0,0,0.18);
                        display: inline-block;
                    ">
                        <img src="{safe_src}" style="width: 100%; border-radius: 6px; display:block;" />
                    </div>
                    """,
                    unsafe_allow_html=True,
                )
                cap = []
                if crop_dims:
                    cap.append(f"📐 **{crop_dims}**")
                if crop_pct_text:
                    cap.append(f"Crop: {crop_pct_text}")
                if cap:
                    st.caption(" | ".join(cap))
            else:
                st.info("No cropped image available")

        # NOTE(review): callers render this function inside st.expander blocks; some
        # Streamlit versions reject nested expanders — confirm the pinned version allows it.
        with st.expander("🔗 Image URLs"):
            if url_o:
                st.code(url_o, language=None)
            if url_r and url_r != url_o:
                st.code(url_r, language=None)
            if url_c:
                st.code(url_c, language=None)
226
+
227
+
228
def render_sample_explorer(collection_name: str, url: str, api_key: str):
    """Render controls to sample points from a collection and show each point's details.

    A fixed sample (50 points, seed 0) populates the dataset filter. Clicking "Sample"
    fetches ``n_samples * 5`` points with the chosen seed, applies the dataset filter and
    keeps the first ``n_samples``. The result is stored in ``st.session_state["pg_points"]``
    together with the collection it came from (``pg_points_collection``), so a sample
    taken from another collection is not displayed after switching.
    """
    sample_for_filters = sample_points_cached(collection_name, 50, 0, url, api_key) or []
    datasets = set()
    for pt in sample_for_filters:
        ds = (pt.get("payload") or {}).get("dataset")
        if ds:
            datasets.add(ds)

    c1, c2, c3, c4 = st.columns([1, 1, 2, 1])
    with c1:
        n_samples = st.slider("Samples", 1, 20, 3, key="pg_n")
    with c2:
        seed = st.number_input("Seed", 0, 9999, 42, key="pg_seed")
    with c3:
        filter_ds = st.selectbox("Dataset", ["All"] + sorted(datasets), key="pg_filter_ds")
    with c4:
        st.write("")
        do_sample = st.button("🎲 Sample", type="primary", key="pg_sample_btn")

    if do_sample:
        # Over-fetch so that the dataset filter still leaves enough points.
        points = sample_points_cached(collection_name, n_samples * 5, seed, url, api_key) or []
        if filter_ds != "All":
            points = [pt for pt in points if (pt.get("payload") or {}).get("dataset") == filter_ds]
        st.session_state["pg_points"] = points[:n_samples]
        st.session_state["pg_points_collection"] = collection_name

    points = st.session_state.get("pg_points", [])
    owner = st.session_state.get("pg_points_collection")
    # Hide a sample that was taken from a different collection.
    if owner is not None and owner != collection_name:
        points = []

    if not points:
        if do_sample:
            st.info("No matching points found.")
        else:
            st.caption("Click 'Sample' to load documents")
        return

    st.success(f"**{len(points)} points loaded**")

    for i, pt in enumerate(points):
        p = pt.get("payload") or {}

        filename = p.get("filename") or p.get("doc_id") or p.get("source_doc_id") or "Unknown"
        page_num = p.get("page_number") or p.get("page") or "?"

        with st.expander(f"**{i+1}.** {str(filename)[:40]} - Page {page_num}", expanded=(i == 0)):
            render_document_details(pt, p)
273
+
274
+
275
def _result_score(result: dict) -> float:
    """Return a search result's score: ``score_final`` if set, else ``score_stage1``, else 0.

    A key that is present but None falls through to the next candidate, so None never
    reaches the arithmetic below.
    """
    for key in ("score_final", "score_stage1"):
        value = result.get(key)
        if value is not None:
            return float(value)
    return 0.0


def _relative_pct(score: float, max_score: float) -> float:
    """Return ``score`` as a percentage of ``max_score``, clamped to [0, 100].

    A non-positive ``max_score`` is not a usable reference, so 1.0 is used as the
    denominator (this generalizes the previous ``or 1`` fallback for a zero maximum).
    The clamp keeps the value valid for ``st.progress`` in the details view.
    """
    denom = max_score if max_score > 0 else 1.0
    return max(0.0, min(100.0, score / denom * 100))


def render_rag_query_interface(collection_name: str, model_name: str = None):
    """Render the query form and the ranked results for ``collection_name``.

    Args:
        collection_name: Collection to search; nothing is rendered when empty.
        model_name: Embedding model to use. When falsy, it is read from a sampled
            point's ``model_name`` payload field, falling back to ``AVAILABLE_MODELS[1]``.

    Results are kept in ``st.session_state["q_results"]`` with the collection they
    belong to (``q_results_collection``), so they are not shown for a different one.
    """
    if not collection_name:
        return

    url, api_key = get_qdrant_credentials()

    if not model_name:
        points = sample_points_cached(collection_name, 1, 0, url, api_key)
        if points:
            model_name = (points[0].get("payload") or {}).get("model_name")
        if not model_name:
            model_name = AVAILABLE_MODELS[1]

    st.caption(f"Model: **{model_name.split('/')[-1] if model_name else 'auto'}**")

    c1, c2, c3 = st.columns([2, 1, 1])
    with c2:
        mode = st.selectbox("Mode", RETRIEVAL_MODES, index=0, key="q_mode")
    with c3:
        top_k = st.slider("Top K", 1, 30, 10, key="q_topk")

    # Defaults for the parameters only exposed by the two-/three-stage modes.
    prefetch_k, stage1_mode, stage1_k, stage2_k = 256, "tokens_vs_tiles", 1000, 300

    if mode == "two_stage":
        cc1, cc2 = st.columns(2)
        with cc1:
            stage1_mode = st.selectbox("Stage1", STAGE1_MODES, key="q_s1mode")
        with cc2:
            prefetch_k = st.slider("Prefetch K", 50, 500, 256, key="q_pk")
    elif mode == "three_stage":
        cc1, cc2 = st.columns(2)
        with cc1:
            stage1_k = st.number_input("Stage1 K", 100, 5000, 1000, key="q_s1k")
        with cc2:
            stage2_k = st.number_input("Stage2 K", 50, 1000, 300, key="q_s2k")

    with c1:
        query = st.text_input("Query", placeholder="Enter your search query...", key="q_text")

    # Whitespace-only input should not enable a search.
    query_text = (query or "").strip()
    if st.button("🔍 Search", type="primary", disabled=not query_text, key="q_search"):
        with st.spinner("Searching..."):
            results, err = search_collection(
                collection_name, query_text, top_k, mode, prefetch_k, stage1_mode, stage1_k, stage2_k, model_name
            )
        st.session_state["q_results_collection"] = collection_name
        if err:
            st.error("Search failed")
            st.code(str(err))
            # Do not keep showing results of an earlier query under the error.
            st.session_state["q_results"] = []
        else:
            st.session_state["q_results"] = results or []

    results = st.session_state.get("q_results", [])
    owner = st.session_state.get("q_results_collection")
    if owner is not None and owner != collection_name:
        results = []
    if not results:
        return

    st.success(f"**{len(results)} results**")
    scores = [_result_score(r) for r in results]
    max_score = max(scores)

    for i, (r, score) in enumerate(zip(results, scores)):
        p = r.get("payload") or {}
        rel = _relative_pct(score, max_score)

        filename = p.get("filename") or p.get("doc_id") or p.get("source_doc_id") or "Unknown"
        page_num = p.get("page_number") or p.get("page") or "?"

        with st.expander(f"**#{i+1}** {str(filename)[:35]} - Page {page_num} | 🎯 {rel:.0f}%", expanded=(i < 3)):
            render_document_details(r, p, score=score, rel_pct=rel)
demo/ui/sidebar.py ADDED
@@ -0,0 +1,162 @@
1
+ """Sidebar component."""
2
+
3
+ import os
4
+ import streamlit as st
5
+
6
+ from demo.qdrant_utils import (
7
+ get_qdrant_credentials,
8
+ init_qdrant_client_with_creds,
9
+ get_collections,
10
+ get_collection_stats,
11
+ sample_points_cached,
12
+ get_vector_sizes,
13
+ )
14
+
15
+
16
def _as_text(value, default: str) -> str:
    """Return ``value`` as text, or ``default`` when it is None.

    ``str`` instances (including str-based enums) are returned as-is, so string methods
    operate on their value rather than on an ``Enum.MEMBER``-style repr.
    """
    if value is None:
        return default
    if isinstance(value, str):
        return value
    return str(value)


def _clear_qdrant_caches():
    """Clear the cached Qdrant lookups (collections, stats, sampled points)."""
    get_collections.clear()
    get_collection_stats.clear()
    sample_points_cached.clear()


def _render_credentials():
    """Render the Qdrant URL / API-key inputs; clear caches when the values change.

    Initial values come from the first non-empty of several environment variables and
    are kept in ``st.session_state`` as ``qdrant_url_input`` / ``qdrant_key_input``.
    """
    st.subheader("🔑 Qdrant Credentials")

    env_url = os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL") or ""
    env_key = (
        os.getenv("SIGIR_QDRANT_KEY")
        or os.getenv("SIGIR_QDRANT_API_KEY")
        or os.getenv("DEST_QDRANT_API_KEY")
        or os.getenv("QDRANT_API_KEY")
        or ""
    )

    if "qdrant_url_input" not in st.session_state:
        st.session_state["qdrant_url_input"] = env_url
    if "qdrant_key_input" not in st.session_state:
        st.session_state["qdrant_key_input"] = env_key

    qdrant_url = st.text_input(
        "Qdrant URL",
        value=st.session_state["qdrant_url_input"],
        key="qdrant_url_widget",
        placeholder="https://xxx.cloud.qdrant.io:6333",
    )
    qdrant_key = st.text_input(
        "API Key",
        value=st.session_state["qdrant_key_input"],
        key="qdrant_key_widget",
        type="password",
    )

    if qdrant_url != st.session_state["qdrant_url_input"] or qdrant_key != st.session_state["qdrant_key_input"]:
        st.session_state["qdrant_url_input"] = qdrant_url
        st.session_state["qdrant_key_input"] = qdrant_key
        _clear_qdrant_caches()


def _render_status(client, err):
    """Render Qdrant and Cloudinary status badges side by side."""
    col_s1, col_s2 = st.columns(2)
    with col_s1:
        if client:
            st.success("Qdrant ✓", icon="✅")
        else:
            st.error("Qdrant ✗", icon="❌")
            # The connection error used to be discarded; show it to aid troubleshooting.
            if err:
                st.caption(str(err))
    with col_s2:
        cloudinary_ok = all([os.getenv("CLOUDINARY_CLOUD_NAME"), os.getenv("CLOUDINARY_API_KEY")])
        if cloudinary_ok:
            st.success("Cloudinary ✓", icon="✅")
        else:
            st.warning("Cloudinary ✗", icon="⚠️")


def _render_vector_table(collection_name: str, vector_info: dict, url: str, api_key: str):
    """Render an HTML table listing count×dim, dtype and RAM/disk storage per named vector."""
    import html

    st.markdown("---")
    st.markdown("**🔢 Vectors**")
    vec_sizes = get_vector_sizes(collection_name, url, api_key) or {}
    rows = []
    for vname in sorted(vector_info, key=len):
        vinfo = vector_info[vname] or {}
        dim = vinfo.get("size", "?")
        num_vec = vec_sizes.get(vname, vinfo.get("num_vectors", 1))
        dtype = _as_text(vinfo.get("datatype"), "?").upper()
        disk_icon = "💾" if vinfo.get("on_disk", False) else "🧠"
        dim_str = f"{num_vec}×{dim}"
        # Names and dtypes come from the collection config; escape them for raw HTML.
        rows.append(
            "<tr><td style='text-align:left;padding-right:12px;'>"
            f"<code>{html.escape(str(vname))}</code></td>"
            f"<td style='text-align:right;'>{html.escape(dim_str)}, {html.escape(dtype)}, {disk_icon}</td></tr>"
        )
    table_html = f"<table style='width:100%;font-size:0.85em;'>{''.join(rows)}</table>"
    st.markdown(table_html, unsafe_allow_html=True)


def _render_collection_panel(url: str, api_key: str):
    """Render the collection picker and the selected collection's stats.

    Sets ``st.session_state["active_collection"]`` and resets the playground's
    loaded-model markers when the selection changes.
    """
    collections = get_collections(url, api_key)
    if not collections:
        st.info("No collections")
        return

    prev_collection = st.session_state.get("active_collection")
    selected = st.selectbox(
        "Select Collection",
        options=collections,
        key="sidebar_collection",
        label_visibility="collapsed",
    )
    if not selected:
        return
    if selected != prev_collection:
        st.session_state["model_loaded"] = False
        st.session_state["loaded_model_key"] = None
    st.session_state["active_collection"] = selected

    stats = get_collection_stats(selected)
    if not isinstance(stats, dict) or "error" in stats:
        st.error("Error loading stats")
        return

    # Counts may be None (Qdrant's CollectionInfo declares them Optional); treat as 0.
    points = stats.get("points_count") or 0
    indexed = stats.get("indexed_vectors_count") or 0

    col1, col2 = st.columns(2)
    col1.metric("Points", f"{points:,}")
    status_raw = _as_text(stats.get("status"), "unknown").replace("CollectionStatus.", "").lower()
    status_icon = {"green": "🟢", "yellow": "🟡"}.get(status_raw, "🔴")
    col2.metric("Status", status_icon)

    is_indexed = indexed >= points and points > 0
    col3, col4 = st.columns(2)
    col3.metric("Indexed", f"{indexed:,}")
    col4.metric("HNSW", "✅" if is_indexed else "⏳")

    vector_info = stats.get("vector_info") or {}
    if vector_info:
        _render_vector_table(selected, vector_info, url, api_key)


def _render_admin_panel(client):
    """Render controls to move one named vector of the active collection between RAM and disk."""
    active = st.session_state.get("active_collection")
    if not (active and client):
        st.info("Select a collection")
        return

    stats = get_collection_stats(active)
    vector_info = (stats.get("vector_info") if isinstance(stats, dict) else None) or {}
    if not vector_info:
        st.info("No vectors")
        return

    st.markdown("**Change Storage**")
    sel_vec = st.selectbox("Vector", sorted(vector_info.keys()), key="admin_vec")
    if not sel_vec:
        return

    current_in_ram = not (vector_info.get(sel_vec) or {}).get("on_disk", False)
    st.caption(f"Current: {'🧠 RAM' if current_in_ram else '💾 Disk'}")
    target_in_ram = st.toggle("Move to RAM", value=current_in_ram, key=f"admin_ram_{sel_vec}")
    if target_in_ram == current_in_ram:
        st.caption("Toggle to change storage location")
        return

    if st.button("💾 Apply Change", key="admin_apply"):
        # Keep the try narrow: only the Qdrant call is expected to fail; st.rerun() stays outside.
        try:
            from qdrant_client.models import VectorParamsDiff

            client.update_collection(
                collection_name=active,
                vectors_config={sel_vec: VectorParamsDiff(on_disk=not target_in_ram)},
            )
        except Exception as e:
            st.error(f"Failed: {e}")
            return
        get_collection_stats.clear()
        st.success(f"Updated {sel_vec}")
        st.rerun()


def render_sidebar():
    """Render the sidebar: credentials, service status, collection panel, admin and refresh."""
    with st.sidebar:
        _render_credentials()

        st.divider()

        st.subheader("📡 Status")
        url, api_key = get_qdrant_credentials()
        client, err = init_qdrant_client_with_creds(url, api_key)
        _render_status(client, err)

        st.divider()

        with st.expander("📦 Collection", expanded=True):
            _render_collection_panel(url, api_key)

        with st.expander("⚙️ Admin", expanded=False):
            _render_admin_panel(client)

        st.divider()

        if st.button("🔄 Refresh", type="secondary", use_container_width=True):
            _clear_qdrant_caches()
            st.rerun()