visual-rag-toolkit 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
demo/ui/playground.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
"""Playground tab component."""
|
|
2
|
+
|
|
3
|
+
import streamlit as st
|
|
4
|
+
|
|
5
|
+
from demo.config import AVAILABLE_MODELS, RETRIEVAL_MODES, STAGE1_MODES
|
|
6
|
+
from demo.qdrant_utils import (
|
|
7
|
+
get_qdrant_credentials,
|
|
8
|
+
get_collections,
|
|
9
|
+
sample_points_cached,
|
|
10
|
+
search_collection,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def render_playground_tab():
    """Render the Playground tab: model status, sample explorer, and RAG query.

    Resolves the active collection (falling back to the first collection on the
    server), auto-detects the embedding model from the payload of one indexed
    point, lazily warms the retriever once per (collection, model) pair, and
    then renders the sample-point explorer and the query interface.
    """
    st.subheader("🎮 Playground")

    active_collection = st.session_state.get("active_collection")
    url, api_key = get_qdrant_credentials()

    # No explicit selection yet: fall back to the first available collection.
    if not active_collection:
        collections = get_collections(url, api_key)
        if collections:
            active_collection = collections[0]

    if not active_collection:
        st.warning("No collection available. Upload documents or select a collection.")
        return

    # Auto-detect the model from one indexed point's payload; otherwise use
    # the configured default (AVAILABLE_MODELS[1]).
    points_for_model = sample_points_cached(active_collection, 1, 0, url, api_key)
    model_name = None
    if points_for_model:
        model_name = points_for_model[0].get("payload", {}).get("model_name")
    if not model_name:
        model_name = AVAILABLE_MODELS[1]

    model_short = model_name.split("/")[-1] if model_name else "unknown"
    cache_key = f"{active_collection}_{model_name}"

    # Force a reload whenever the (collection, model) pair changed since the
    # last successful load.
    if st.session_state.get("loaded_model_key") != cache_key:
        st.session_state["model_loaded"] = False

    col_info, col_model = st.columns([2, 1])
    with col_info:
        st.info(f"**Collection:** `{active_collection}`")
    with col_model:
        if not st.session_state.get("model_loaded"):
            with st.spinner(f"Loading {model_short}..."):
                try:
                    # Imported lazily so the tab still renders when the heavy
                    # retrieval stack is unavailable.
                    from visual_rag.retrieval import MultiVectorRetriever

                    # Instantiating warms the model; the instance itself is
                    # discarded (queries go through search_collection).
                    _ = MultiVectorRetriever(collection_name=active_collection, model_name=model_name)
                    st.session_state["model_loaded"] = True
                    st.session_state["loaded_model_key"] = cache_key
                    st.session_state["loaded_model_name"] = model_name
                except Exception:  # best-effort: the UI stays usable without the model
                    st.warning(f"Failed: {model_short}")

    if st.session_state.get("model_loaded"):
        st.markdown(f"✅ Found <span style='color:#e74c3c;font-weight:bold;'>{model_short}</span> model", unsafe_allow_html=True)

    with st.expander("📦 Sample Points Explorer", expanded=True):
        render_sample_explorer(active_collection, url, api_key)

    st.divider()

    st.subheader("🔍 RAG Query")
    render_rag_query_interface(active_collection, model_name)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def render_document_details(pt: dict, p: dict, score: float = None, rel_pct: float = None):
    """Render metadata, relevance, and page images for one indexed point.

    Args:
        pt: Full point record as returned by Qdrant (currently unused here;
            kept so existing callers keep working).
        p: Point payload dict carrying indexing-time metadata.
        score: Optional raw retrieval score shown in the relevance section.
        rel_pct: Optional score relative to the best hit, in percent (0-100).
    """

    def _is_missing(v) -> bool:
        """Return True for None, empty containers, and common 'no value' strings."""
        if v is None:
            return True
        if isinstance(v, (list, tuple, dict)) and len(v) == 0:
            return True
        if isinstance(v, str):
            s = v.strip()
            return s == "" or s.lower() in {"na", "n/a", "none", "null", "unknown", "?", "-"}
        return False

    # Payloads come from several indexing pipelines, so each field probes
    # alternative keys before falling back to a placeholder.
    doc_id = p.get("doc_id") or p.get("union_doc_id") or p.get("source_doc_id") or "?"
    corpus_id = p.get("corpus-id") or p.get("source_doc_id") or "?"
    dataset = p.get("dataset") or p.get("source") or None
    model = (p.get("model_name") or p.get("model") or None)
    model = model.split("/")[-1] if isinstance(model, str) else None
    doc_name = p.get("doc-id") or p.get("filename") or "Unknown"

    num_tiles = p.get("num_tiles")
    visual_tokens = p.get("index_recovery_num_visual_tokens") or p.get("num_visual_tokens")
    patches_per_tile = p.get("patches_per_tile")
    torch_dtype = p.get("torch_dtype")

    orig_w = p.get("original_width")
    orig_h = p.get("original_height")
    crop_w = p.get("cropped_width")
    crop_h = p.get("cropped_height")
    resize_w = p.get("resized_width")
    resize_h = p.get("resized_height")
    crop_pct = p.get("crop_empty_percentage_to_remove")
    crop_enabled = bool(p.get("crop_empty_enabled", False))

    col_meta, col_img = st.columns([1, 2])

    with col_meta:
        st.markdown("##### 📄 Document Info")
        st.markdown(f"**📁 Doc:** {doc_name}")
        if not _is_missing(dataset):
            st.markdown(f"**🏛️ Dataset:** {dataset}")
        if not _is_missing(doc_id) and str(doc_id) != "?":
            st.markdown(f"**🔑 Doc ID:** `{str(doc_id)[:20]}...`")
        if not _is_missing(corpus_id) and str(corpus_id) != "?":
            st.markdown(f"**📋 Corpus ID:** {corpus_id}")

        if score is not None:
            st.divider()
            st.markdown("##### 🎯 Relevance")
            if rel_pct is not None:
                st.markdown(f"**Relative:** 🟢 {rel_pct:.1f}%")
                st.progress(rel_pct / 100)
            st.caption(f"Raw score: {score:.4f}")

        st.divider()
        visual_rows = []
        if not _is_missing(model):
            visual_rows.append(("🤖 Model", f"`{model}`"))
        if not _is_missing(num_tiles):
            visual_rows.append(("🔲 Tiles", str(num_tiles)))
        if not _is_missing(visual_tokens):
            visual_rows.append(("🔢 Visual Tokens", str(visual_tokens)))
        if not _is_missing(patches_per_tile):
            visual_rows.append(("📦 Patches/Tile", str(patches_per_tile)))
        if not _is_missing(torch_dtype):
            visual_rows.append(("⚙️ Dtype", str(torch_dtype)))
        if visual_rows:
            st.markdown("##### 🎨 Visual Metadata")
            for k, v in visual_rows:
                st.markdown(f"**{k}:** {v}")

        st.divider()
        dim_rows = []
        if not _is_missing(orig_w) and not _is_missing(orig_h):
            dim_rows.append(("Original", f"{orig_w}×{orig_h}"))
        if not _is_missing(resize_w) and not _is_missing(resize_h):
            dim_rows.append(("Resized", f"{resize_w}×{resize_h}"))
        if crop_enabled and not _is_missing(crop_w) and not _is_missing(crop_h):
            dim_rows.append(("Cropped", f"{crop_w}×{crop_h}"))
        if dim_rows:
            st.markdown("##### 📐 Dimensions")
            for k, v in dim_rows:
                st.markdown(f"**{k}:** {v}")
        if crop_enabled and not _is_missing(crop_pct):
            try:
                # crop_pct looks like a 0-1 fraction (multiplied by 100 here)
                # — TODO confirm against the indexing pipeline.
                st.markdown(f"**Crop %:** {int(float(crop_pct) * 100)}%")
            except Exception:
                pass  # non-numeric value: skip the row rather than crash the UI

    with col_img:
        st.markdown("##### 📷 Document Page")
        tabs = st.tabs(["🖼️ Original", "📷 Resized", "✂️ Cropped"])

        url_o = p.get("original_url")
        url_r = p.get("resized_url") or p.get("page")
        url_c = p.get("cropped_url")

        with tabs[0]:
            if url_o:
                st.image(url_o, width=600)
                # Fix: only show dimensions when recorded in the payload
                # (previously rendered "None×None" for missing values).
                if not _is_missing(orig_w) and not _is_missing(orig_h):
                    st.caption(f"📐 **{orig_w}×{orig_h}**")
            else:
                st.info("No original image available")

        with tabs[1]:
            if url_r:
                st.image(url_r, width=600)
                # Same missing-dimension guard as the Original tab.
                if not _is_missing(resize_w) and not _is_missing(resize_h):
                    st.caption(f"📐 **{resize_w}×{resize_h}**")
            else:
                st.info("No resized image available")

        with tabs[2]:
            if url_c:
                # Display on a checkerboard background to make the crop boundary obvious.
                w_caption = f"{crop_w}×{crop_h}" if (not _is_missing(crop_w) and not _is_missing(crop_h)) else None
                pct_caption = None
                if not _is_missing(crop_pct):
                    try:
                        pct_caption = f"{int(float(crop_pct) * 100)}%"
                    except Exception:
                        pct_caption = None
                st.markdown(
                    f"""
                    <div style="
                        width: 600px;
                        padding: 14px;
                        border-radius: 10px;
                        background-image:
                            linear-gradient(45deg, #e6e6e6 25%, transparent 25%),
                            linear-gradient(-45deg, #e6e6e6 25%, transparent 25%),
                            linear-gradient(45deg, transparent 75%, #e6e6e6 75%),
                            linear-gradient(-45deg, transparent 75%, #e6e6e6 75%);
                        background-size: 24px 24px;
                        background-position: 0 0, 0 12px, 12px -12px, -12px 0px;
                        box-shadow: 0 10px 30px rgba(0,0,0,0.18);
                        display: inline-block;
                    ">
                        <img src="{url_c}" style="width: 100%; border-radius: 6px; display:block;" />
                    </div>
                    """,
                    unsafe_allow_html=True,
                )
                cap = []
                if w_caption:
                    cap.append(f"📐 **{w_caption}**")
                if pct_caption:
                    cap.append(f"Crop: {pct_caption}")
                if cap:
                    st.caption(" | ".join(cap))
            else:
                st.info("No cropped image available")

        with st.expander("🔗 Image URLs"):
            if url_o:
                st.code(url_o, language=None)
            if url_r and url_r != url_o:
                st.code(url_r, language=None)
            if url_c:
                st.code(url_c, language=None)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def render_sample_explorer(collection_name: str, url: str, api_key: str):
    """Render the random-sample explorer for a collection.

    Probes a batch of points to discover dataset filter values, lets the user
    sample N random points (optionally filtered by dataset), and renders the
    details of each sampled point.

    Args:
        collection_name: Qdrant collection to sample from.
        url: Qdrant server URL.
        api_key: Qdrant API key.
    """
    # Probe a batch of points just to discover the dataset values available
    # for filtering. (A previously collected doc-id set was unused dead code
    # and has been removed.)
    sample_for_filters = sample_points_cached(collection_name, 50, 0, url, api_key)
    datasets = set()
    for pt in sample_for_filters:
        if ds := pt.get("payload", {}).get("dataset"):
            datasets.add(ds)

    c1, c2, c3, c4 = st.columns([1, 1, 2, 1])
    with c1:
        n_samples = st.slider("Samples", 1, 20, 3, key="pg_n")
    with c2:
        seed = st.number_input("Seed", 0, 9999, 42, key="pg_seed")
    with c3:
        filter_ds = st.selectbox("Dataset", ["All"] + sorted(datasets), key="pg_filter_ds")
    with c4:
        st.write("")  # spacer, presumably to align the button with the inputs
        do_sample = st.button("🎲 Sample", type="primary", key="pg_sample_btn")

    if do_sample:
        # Over-sample 5x so that dataset filtering still leaves enough points.
        points = sample_points_cached(collection_name, n_samples * 5, seed, url, api_key)
        if filter_ds != "All":
            points = [p for p in points if p.get("payload", {}).get("dataset") == filter_ds]
        points = points[:n_samples]
        st.session_state["pg_points"] = points

    # Read back from session state so results survive Streamlit reruns.
    points = st.session_state.get("pg_points", [])

    if not points:
        st.caption("Click 'Sample' to load documents")
        return

    st.success(f"**{len(points)} points loaded**")

    for i, pt in enumerate(points):
        p = pt.get("payload", {})

        filename = p.get("filename") or p.get("doc_id") or p.get("source_doc_id") or "Unknown"
        page_num = p.get("page_number") or p.get("page") or "?"

        # Only the first sampled point starts expanded.
        with st.expander(f"**{i+1}.** {str(filename)[:40]} - Page {page_num}", expanded=(i == 0)):
            render_document_details(pt, p)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
def render_rag_query_interface(collection_name: str, model_name: str = None):
    """Render the RAG query form and result list for a collection.

    Args:
        collection_name: Target Qdrant collection; no-op when falsy.
        model_name: Embedding model to query with. When omitted, it is
            auto-detected from an indexed point's payload, falling back to
            AVAILABLE_MODELS[1].
    """
    if not collection_name:
        return

    url, api_key = get_qdrant_credentials()

    # Auto-detect the model from one indexed point when not supplied.
    if not model_name:
        points = sample_points_cached(collection_name, 1, 0, url, api_key)
        if points:
            model_name = points[0].get("payload", {}).get("model_name")
        if not model_name:
            model_name = AVAILABLE_MODELS[1]

    st.caption(f"Model: **{model_name.split('/')[-1] if model_name else 'auto'}**")

    # NOTE: c1 (the query box) is deliberately filled *after* the mode-specific
    # controls below, so the per-mode widgets appear between the selectors and
    # the text input.
    c1, c2, c3 = st.columns([2, 1, 1])
    with c2:
        mode = st.selectbox("Mode", RETRIEVAL_MODES, index=0, key="q_mode")
    with c3:
        top_k = st.slider("Top K", 1, 30, 10, key="q_topk")

    # Defaults for all retrieval modes; overridden by the widgets below.
    prefetch_k, stage1_mode, stage1_k, stage2_k = 256, "tokens_vs_tiles", 1000, 300

    if mode == "two_stage":
        cc1, cc2 = st.columns(2)
        with cc1:
            stage1_mode = st.selectbox("Stage1", STAGE1_MODES, key="q_s1mode")
        with cc2:
            prefetch_k = st.slider("Prefetch K", 50, 500, 256, key="q_pk")
    elif mode == "three_stage":
        cc1, cc2 = st.columns(2)
        with cc1:
            stage1_k = st.number_input("Stage1 K", 100, 5000, 1000, key="q_s1k")
        with cc2:
            stage2_k = st.number_input("Stage2 K", 50, 1000, 300, key="q_s2k")

    with c1:
        query = st.text_input("Query", placeholder="Enter your search query...", key="q_text")

    if st.button("🔍 Search", type="primary", disabled=not query, key="q_search"):
        with st.spinner("Searching..."):
            results, err = search_collection(
                collection_name, query, top_k, mode, prefetch_k, stage1_mode, stage1_k, stage2_k, model_name
            )
            if err:
                st.error("Search failed")
                st.code(err)
            else:
                st.session_state["q_results"] = results

    # Results are kept in session state so they survive Streamlit reruns.
    results = st.session_state.get("q_results", [])
    if results:
        st.success(f"**{len(results)} results**")
        # Best score normalizes the relative-relevance bars; `or 1` avoids
        # division by zero when every score is 0.
        max_score = max(r.get("score_final", r.get("score_stage1", 0)) for r in results) or 1

        for i, r in enumerate(results):
            p = r.get("payload", {})
            score = r.get("score_final", r.get("score_stage1", 0))
            rel = score / max_score * 100

            filename = p.get("filename") or p.get("doc_id") or p.get("source_doc_id") or "Unknown"
            page_num = p.get("page_number") or p.get("page") or "?"

            # Only the top three hits start expanded.
            with st.expander(f"**#{i+1}** {str(filename)[:35]} - Page {page_num} | 🎯 {rel:.0f}%", expanded=(i < 3)):
                render_document_details(r, p, score=score, rel_pct=rel)
|
demo/ui/sidebar.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
"""Sidebar component."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import streamlit as st
|
|
5
|
+
|
|
6
|
+
from demo.qdrant_utils import (
|
|
7
|
+
get_qdrant_credentials,
|
|
8
|
+
init_qdrant_client_with_creds,
|
|
9
|
+
get_collections,
|
|
10
|
+
get_collection_stats,
|
|
11
|
+
sample_points_cached,
|
|
12
|
+
get_vector_sizes,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def render_sidebar():
    """Render the sidebar: Qdrant credentials, connection status, collection
    picker with stats, a storage-location admin panel, and a cache refresh
    button."""
    with st.sidebar:
        st.subheader("🔑 Qdrant Credentials")

        # Seed the inputs from the first environment variable that is set,
        # in priority order.
        env_url = os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL") or ""
        env_key = os.getenv("SIGIR_QDRANT_KEY") or os.getenv("SIGIR_QDRANT_API_KEY") or os.getenv("DEST_QDRANT_API_KEY") or os.getenv("QDRANT_API_KEY") or ""

        if "qdrant_url_input" not in st.session_state:
            st.session_state["qdrant_url_input"] = env_url
        if "qdrant_key_input" not in st.session_state:
            st.session_state["qdrant_key_input"] = env_key

        qdrant_url = st.text_input(
            "Qdrant URL",
            value=st.session_state["qdrant_url_input"],
            key="qdrant_url_widget",
            placeholder="https://xxx.cloud.qdrant.io:6333",
        )
        qdrant_key = st.text_input(
            "API Key",
            value=st.session_state["qdrant_key_input"],
            key="qdrant_key_widget",
            type="password",
        )

        # Credentials changed: persist them and drop every cached lookup so
        # subsequent calls hit the new server.
        if qdrant_url != st.session_state["qdrant_url_input"] or qdrant_key != st.session_state["qdrant_key_input"]:
            st.session_state["qdrant_url_input"] = qdrant_url
            st.session_state["qdrant_key_input"] = qdrant_key
            get_collections.clear()
            get_collection_stats.clear()
            sample_points_cached.clear()

        st.divider()

        st.subheader("📡 Status")
        url, api_key = get_qdrant_credentials()
        client, err = init_qdrant_client_with_creds(url, api_key)

        col_s1, col_s2 = st.columns(2)
        with col_s1:
            if client:
                st.success("Qdrant ✓", icon="✅")
            else:
                st.error("Qdrant ✗", icon="❌")
        with col_s2:
            # Cloudinary is considered configured when both env vars are set.
            cloudinary_ok = all([os.getenv("CLOUDINARY_CLOUD_NAME"), os.getenv("CLOUDINARY_API_KEY")])
            if cloudinary_ok:
                st.success("Cloudinary ✓", icon="✅")
            else:
                st.warning("Cloudinary ✗", icon="⚠️")

        st.divider()

        with st.expander("📦 Collection", expanded=True):
            collections = get_collections(url, api_key)
            if collections:
                prev_collection = st.session_state.get("active_collection")
                selected = st.selectbox(
                    "Select Collection",
                    options=collections,
                    key="sidebar_collection",
                    label_visibility="collapsed",
                )
                if selected:
                    # Switching collections invalidates the loaded model.
                    if selected != prev_collection:
                        st.session_state["model_loaded"] = False
                        st.session_state["loaded_model_key"] = None
                    st.session_state["active_collection"] = selected
                    # NOTE(review): called without url/api_key unlike the
                    # other cached helpers — presumably it reads the current
                    # credentials internally; confirm against qdrant_utils.
                    stats = get_collection_stats(selected)
                    if "error" not in stats:
                        col1, col2 = st.columns(2)
                        col1.metric("Points", f"{stats.get('points_count', 0):,}")
                        # Status may arrive as "CollectionStatus.GREEN"; strip
                        # the enum prefix before matching.
                        status_raw = stats.get("status", "unknown").replace("CollectionStatus.", "").lower()
                        status_icon = "🟢" if status_raw == "green" else "🟡" if status_raw == "yellow" else "🔴"
                        col2.metric("Status", status_icon)

                        points = stats.get("points_count", 0)
                        indexed = stats.get("indexed_vectors_count", 0) or 0
                        # HNSW is treated as ready once every point is indexed.
                        is_indexed = indexed >= points and points > 0
                        col3, col4 = st.columns(2)
                        col3.metric("Indexed", f"{indexed:,}")
                        col4.metric("HNSW", "✅" if is_indexed else "⏳")

                        vector_info = stats.get("vector_info", {})
                        if vector_info:
                            st.markdown("---")
                            st.markdown("**🔢 Vectors**")
                            vec_sizes = get_vector_sizes(selected, url, api_key)
                            rows = []
                            # Shortest names first — presumably a display
                            # preference; confirm if ordering matters.
                            sorted_names = sorted(vector_info.keys(), key=lambda x: len(x))
                            for vname in sorted_names:
                                vinfo = vector_info[vname]
                                dim = vinfo.get("size", "?")
                                num_vec = vec_sizes.get(vname, vinfo.get("num_vectors", 1))
                                dtype = vinfo.get("datatype", "?").upper()
                                on_disk = vinfo.get("on_disk", False)
                                disk_icon = "💾" if on_disk else "🧠"
                                dim_str = f"{num_vec}×{dim}"
                                rows.append(f"<tr><td style='text-align:left;padding-right:12px;'><code>{vname}</code></td><td style='text-align:right;'>{dim_str}, {dtype}, {disk_icon}</td></tr>")
                            table_html = f"<table style='width:100%;font-size:0.85em;'>{''.join(rows)}</table>"
                            st.markdown(table_html, unsafe_allow_html=True)
                    else:
                        st.error("Error loading stats")
            else:
                st.info("No collections")

        with st.expander("⚙️ Admin", expanded=False):
            active = st.session_state.get("active_collection")
            if active and client:
                stats = get_collection_stats(active)
                vector_info = stats.get("vector_info", {})
                if vector_info:
                    st.markdown("**Change Storage**")
                    vector_names = sorted(vector_info.keys())
                    sel_vec = st.selectbox("Vector", vector_names, key="admin_vec")
                    if sel_vec:
                        current_on_disk = vector_info.get(sel_vec, {}).get("on_disk", False)
                        current_in_ram = not current_on_disk
                        st.caption(f"Current: {'🧠 RAM' if current_in_ram else '💾 Disk'}")
                        target_in_ram = st.toggle("Move to RAM", value=current_in_ram, key=f"admin_ram_{sel_vec}")
                        if target_in_ram != current_in_ram:
                            if st.button("💾 Apply Change", key="admin_apply"):
                                try:
                                    # Imported lazily; only needed when the
                                    # admin action is actually triggered.
                                    from qdrant_client.models import VectorParamsDiff
                                    client.update_collection(
                                        collection_name=active,
                                        vectors_config={sel_vec: VectorParamsDiff(on_disk=not target_in_ram)}
                                    )
                                    # Invalidate cached stats so the new
                                    # storage location shows after rerun.
                                    get_collection_stats.clear()
                                    st.success(f"Updated {sel_vec}")
                                    st.rerun()
                                except Exception as e:
                                    st.error(f"Failed: {e}")
                        else:
                            st.caption("Toggle to change storage location")
                    else:
                        st.info("No vectors")
                else:
                    st.info("Select a collection")

        st.divider()

        # Manual cache bust: clear every cached Qdrant lookup and rerun.
        if st.button("🔄 Refresh", type="secondary", use_container_width=True):
            get_collections.clear()
            get_collection_stats.clear()
            sample_points_cached.clear()
            st.rerun()
|