visual-rag-toolkit 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
demo/indexing.py
ADDED
@@ -0,0 +1,286 @@
"""Indexing runner with UI updates."""

import hashlib
import json
import time
import traceback
from datetime import datetime
from typing import Any, Dict, Optional

import numpy as np
import streamlit as st
import torch

from visual_rag import VisualEmbedder


TORCH_DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}
from visual_rag.indexing import QdrantIndexer
from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset

from demo.qdrant_utils import get_qdrant_credentials


def _stable_uuid(text: str) -> str:
    """Generate a stable UUID from text (same as benchmark script)."""
    hex_str = hashlib.sha256(text.encode("utf-8")).hexdigest()[:32]
    return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"


def _union_point_id(*, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]) -> str:
    """Generate union point ID (same as benchmark script)."""
    ns = f"{union_namespace}::{dataset_name}" if union_namespace else dataset_name
    return _stable_uuid(f"{ns}::{source_doc_id}")


def run_indexing_with_ui(config: Dict[str, Any]):
    st.divider()

    print("=" * 60)
    print("[INDEX] Starting indexing via UI")
    print("=" * 60)

    url, api_key = get_qdrant_credentials()
    if not url:
        st.error("QDRANT_URL not configured")
        return

    datasets = config.get("datasets", [])
    collection = config["collection"]
    model = config.get("model", "vidore/colpali-v1.3")
    recreate = config.get("recreate", False)
    torch_dtype = config.get("torch_dtype", "float16")
    qdrant_vector_dtype = config.get("qdrant_vector_dtype", "float16")
    prefer_grpc = config.get("prefer_grpc", True)
    batch_size = config.get("batch_size", 4)
    max_docs = config.get("max_docs")

    print(f"[INDEX] Config: collection={collection}, model={model}")
    print(f"[INDEX] Datasets: {datasets}")
    print(f"[INDEX] max_docs={max_docs}, batch_size={batch_size}, recreate={recreate}")
    print(f"[INDEX] torch_dtype={torch_dtype}, qdrant_dtype={qdrant_vector_dtype}, grpc={prefer_grpc}")

    phase1_container = st.container()
    phase2_container = st.container()
    phase3_container = st.container()
    results_container = st.container()

    try:
        with phase1_container:
            st.markdown("##### 🤖 Phase 1: Loading Model")
            model_status = st.empty()
            model_status.info(f"Loading `{model.split('/')[-1]}`...")

            print(f"[INDEX] Loading embedder: {model}")
            torch_dtype_obj = TORCH_DTYPE_MAP.get(torch_dtype, torch.float16)
            output_dtype_obj = np.float16 if qdrant_vector_dtype == "float16" else np.float32
            embedder = VisualEmbedder(
                model_name=model,
                torch_dtype=torch_dtype_obj,
                output_dtype=output_dtype_obj,
            )
            embedder._load_model()
            print(f"[INDEX] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_vector_dtype})")
            model_status.success(f"✅ Model `{model.split('/')[-1]}` loaded")

        with phase2_container:
            st.markdown("##### 📦 Phase 2: Setting Up Collection")

            indexer_status = st.empty()
            indexer_status.info(f"Connecting to Qdrant...")

            print(f"[INDEX] Connecting to Qdrant...")
            indexer = QdrantIndexer(
                url=url,
                api_key=api_key,
                collection_name=collection,
                prefer_grpc=prefer_grpc,
                vector_datatype=qdrant_vector_dtype,
            )
            print(f"[INDEX] Connected to Qdrant")
            indexer_status.success(f"✅ Connected to Qdrant")

            coll_status = st.empty()
            action = "Recreating" if recreate else "Creating/verifying"
            coll_status.info(f"{action} collection `{collection}`...")

            print(f"[INDEX] {action} collection: {collection}")
            indexer.create_collection(force_recreate=recreate)
            indexer.create_payload_indexes(fields=[
                {"field": "dataset", "type": "keyword"},
                {"field": "doc_id", "type": "keyword"},
                {"field": "source_doc_id", "type": "keyword"},
            ])
            print(f"[INDEX] Collection ready")
            coll_status.success(f"✅ Collection `{collection}` ready")

        with phase3_container:
            st.markdown("##### 🚀 Phase 3: Indexing Documents")

            total_uploaded = 0
            total_docs = 0
            total_time = 0

            for ds_name in datasets:
                ds_short = ds_name.split("/")[-1]
                ds_header = st.empty()
                ds_header.info(f"📚 Loading `{ds_short}`...")

                print(f"[INDEX] Loading dataset: {ds_name}")
                corpus, queries, qrels = load_vidore_beir_dataset(ds_name)

                if max_docs and max_docs > 0 and len(corpus) > max_docs:
                    corpus = corpus[:max_docs]
                    print(f"[INDEX] Limited to {len(corpus)} docs (max_docs={max_docs})")

                total_docs += len(corpus)
                print(f"[INDEX] Dataset {ds_name}: {len(corpus)} documents to index")
                ds_header.success(f"📚 `{ds_short}`: {len(corpus)} documents")

                progress_bar = st.progress(0.0)
                batch_status = st.empty()
                log_area = st.empty()
                log_lines = []

                num_batches = (len(corpus) + batch_size - 1) // batch_size
                ds_start = time.time()

                for i in range(0, len(corpus), batch_size):
                    batch = corpus[i:i + batch_size]
                    images = [doc.image for doc in batch if hasattr(doc, 'image') and doc.image]

                    if not images:
                        continue

                    batch_num = i // batch_size + 1
                    batch_status.info(f"Batch {batch_num}/{num_batches}: embedding & uploading...")

                    batch_start = time.time()
                    embeddings, token_infos = embedder.embed_images(images, return_token_info=True)
                    embed_time = time.time() - batch_start

                    points = []
                    for j, (doc, emb, token_info) in enumerate(zip(batch, embeddings, token_infos)):
                        doc_id = doc.doc_id if hasattr(doc, 'doc_id') else str(i + j)
                        source_doc_id = str(doc.payload.get("source_doc_id", doc_id) if hasattr(doc, 'payload') else doc_id)

                        union_doc_id = _union_point_id(
                            dataset_name=ds_name,
                            source_doc_id=source_doc_id,
                            union_namespace=collection,
                        )

                        emb_np = emb.cpu().numpy() if hasattr(emb, 'cpu') else np.array(emb)
                        visual_indices = token_info.get("visual_token_indices") or list(range(emb_np.shape[0]))
                        visual_emb = emb_np[visual_indices].astype(embedder.output_dtype)

                        tile_pooled = embedder.mean_pool_visual_embedding(visual_emb, token_info, target_vectors=32)
                        experimental = embedder.experimental_pool_visual_embedding(
                            visual_emb, token_info, target_vectors=32, mean_pool=tile_pooled
                        )
                        global_pooled = embedder.global_pool_from_mean_pool(tile_pooled)

                        points.append({
                            "id": union_doc_id,
                            "visual_embedding": visual_emb,
                            "tile_pooled_embedding": tile_pooled,
                            "experimental_pooled_embedding": experimental,
                            "global_pooled_embedding": global_pooled,
                            "metadata": {
                                "dataset": ds_name,
                                "doc_id": doc_id,
                                "source_doc_id": source_doc_id,
                                "union_doc_id": union_doc_id,
                            },
                        })

                    upload_start = time.time()
                    indexer.upload_batch(points)
                    upload_time = time.time() - upload_start
                    total_uploaded += len(points)

                    progress = (i + len(batch)) / len(corpus)
                    progress_bar.progress(progress)
                    batch_status.info(f"Batch {batch_num}/{num_batches} ({int(progress*100)}%) — embed: {embed_time:.1f}s, upload: {upload_time:.1f}s")

                    log_interval = max(2, num_batches // 10)
                    should_log = batch_num % log_interval == 0 or batch_num == num_batches

                    if should_log and batch_num > 1:
                        log_lines.append(f"[Batch {batch_num}/{num_batches}] +{len(points)} pts, total={total_uploaded}")
                        log_area.code("\n".join(log_lines[-8:]), language="text")
                        print(f"[INDEX] Batch {batch_num}/{num_batches}: +{len(points)} pts, total={total_uploaded}, embed={embed_time:.1f}s, upload={upload_time:.1f}s")

                ds_time = time.time() - ds_start
                total_time += ds_time
                progress_bar.progress(1.0)
                batch_status.success(f"✅ `{ds_short}` indexed: {len(corpus)} docs in {ds_time:.1f}s")
                print(f"[INDEX] Dataset {ds_name} complete: {len(corpus)} docs in {ds_time:.1f}s")

        with results_container:
            st.markdown("##### 📊 Summary")

            docs_per_sec = total_uploaded / total_time if total_time > 0 else 0

            print("=" * 60)
            print("[INDEX] INDEXING COMPLETE")
            print(f"[INDEX] Total Uploaded: {total_uploaded:,}")
            print(f"[INDEX] Datasets: {len(datasets)}")
            print(f"[INDEX] Collection: {collection}")
            print(f"[INDEX] Total Time: {total_time:.1f}s")
            print(f"[INDEX] Throughput: {docs_per_sec:.2f} docs/s")
            print("=" * 60)

            c1, c2, c3, c4 = st.columns(4)
            c1.metric("Total Uploaded", f"{total_uploaded:,}")
            c2.metric("Datasets", len(datasets))
            c3.metric("Total Time", f"{total_time:.1f}s")
            c4.metric("Throughput", f"{docs_per_sec:.2f}/s")

            st.success(f"🎉 Indexing complete! {total_uploaded:,} documents indexed to `{collection}`")

            detailed_report = {
                "generated_at": datetime.now().isoformat(),
                "config": {
                    "collection": collection,
                    "model": model,
                    "datasets": datasets,
                    "batch_size": batch_size,
                    "max_docs_per_dataset": max_docs,
                    "recreate": recreate,
                    "prefer_grpc": prefer_grpc,
                    "torch_dtype": torch_dtype,
                    "qdrant_vector_dtype": qdrant_vector_dtype,
                },
                "results": {
                    "total_docs_uploaded": total_uploaded,
                    "total_time_s": round(total_time, 2),
                    "throughput_docs_per_s": round(docs_per_sec, 2),
                    "num_datasets": len(datasets),
                },
            }

            with st.expander("📋 Full Summary"):
                st.json(detailed_report)

            report_json = json.dumps(detailed_report, indent=2)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"index_report__{collection}__{timestamp}.json"

            st.download_button(
                label="📥 Download Indexing Report",
                data=report_json,
                file_name=filename,
                mime="application/json",
                use_container_width=True,
            )

    except Exception as e:
        print(f"[INDEX] ERROR: {e}")
        st.error(f"❌ Error: {e}")
        with st.expander("🔍 Full Error Details"):
            st.code(traceback.format_exc(), language="text")
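Not part of the packaged file above, but a minimal usage sketch for orientation: run_indexing_with_ui is driven by a plain config dict, and the keys below are exactly the ones it reads via config.get(...). The dataset id and other values are illustrative assumptions, not settings shipped with the toolkit, except where the comment notes the function's own default.

# Hypothetical invocation from a Streamlit page; values are placeholders.
from demo.indexing import run_indexing_with_ui

config = {
    "collection": "vidore_demo",          # required: read as config["collection"]
    "datasets": ["vidore/tatdqa_test"],   # illustrative dataset id, not a shipped default
    "model": "vidore/colpali-v1.3",       # function default
    "recreate": False,
    "torch_dtype": "float16",
    "qdrant_vector_dtype": "float16",
    "prefer_grpc": True,
    "batch_size": 4,
    "max_docs": 100,                      # optional per-dataset cap; None indexes everything
}
# Qdrant credentials come from the sidebar inputs or the env vars
# handled by demo.qdrant_utils.get_qdrant_credentials (e.g. QDRANT_URL, QDRANT_API_KEY).
run_indexing_with_ui(config)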
demo/qdrant_utils.py
ADDED
@@ -0,0 +1,211 @@
"""Qdrant connection and utility functions."""

import os
import traceback
from typing import Any, Dict, List, Optional, Tuple

import streamlit as st


def get_qdrant_credentials() -> Tuple[Optional[str], Optional[str]]:
    url = st.session_state.get("qdrant_url_input") or os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
    api_key = st.session_state.get("qdrant_key_input") or (
        os.getenv("SIGIR_QDRANT_KEY")
        or os.getenv("SIGIR_QDRANT_API_KEY")
        or os.getenv("DEST_QDRANT_API_KEY")
        or os.getenv("QDRANT_API_KEY")
    )
    return url, api_key


def init_qdrant_client_with_creds(url: str, api_key: str):
    try:
        from qdrant_client import QdrantClient
        if not url:
            return None, "QDRANT_URL not configured"
        client = QdrantClient(url=url, api_key=api_key, timeout=60)
        client.get_collections()
        return client, None
    except Exception as e:
        return None, str(e)


@st.cache_resource(show_spinner="Connecting to Qdrant...")
def init_qdrant_client():
    url, api_key = get_qdrant_credentials()
    return init_qdrant_client_with_creds(url, api_key)


@st.cache_resource(show_spinner="Loading embedding model...")
def init_embedder(model_name: str):
    try:
        from visual_rag import VisualEmbedder
        return VisualEmbedder(model_name=model_name), None
    except Exception as e:
        return None, f"{e}\n\n{traceback.format_exc()}"


@st.cache_data(ttl=300, show_spinner="Fetching collections...")
def get_collections(_url: str, _api_key: str) -> List[str]:
    client, err = init_qdrant_client_with_creds(_url, _api_key)
    if client is None:
        return []
    try:
        collections = client.get_collections().collections
        return sorted([c.name for c in collections])
    except Exception:
        return []


@st.cache_data(ttl=120, show_spinner="Loading collection stats...")
def get_collection_stats(collection_name: str) -> Dict[str, Any]:
    url, api_key = get_qdrant_credentials()
    client, err = init_qdrant_client_with_creds(url, api_key)
    if client is None:
        return {"error": err}
    try:
        info = client.get_collection(collection_name)
        vectors_config = getattr(getattr(getattr(info, "config", None), "params", None), "vectors", None)
        vector_info = {}
        if vectors_config is not None:
            if hasattr(vectors_config, "items"):
                for name, cfg in vectors_config.items():
                    size = getattr(cfg, "size", None)
                    multivec = getattr(cfg, "multivector_config", None)
                    on_disk = getattr(cfg, "on_disk", None)
                    datatype = str(getattr(cfg, "datatype", "Float32")).replace("Datatype.", "")
                    quantization = getattr(cfg, "quantization_config", None)
                    num_vectors = 1
                    if multivec is not None:
                        comparator = getattr(multivec, "comparator", None)
                        num_vectors = "N" if comparator else 1
                    vector_info[name] = {
                        "size": size,
                        "num_vectors": num_vectors,
                        "is_multivector": multivec is not None,
                        "on_disk": on_disk,
                        "datatype": datatype,
                        "quantization": quantization is not None,
                    }
            elif hasattr(vectors_config, "size"):
                on_disk = getattr(vectors_config, "on_disk", None)
                datatype = str(getattr(vectors_config, "datatype", "Float32")).replace("Datatype.", "")
                multivec = getattr(vectors_config, "multivector_config", None)
                vector_info["default"] = {
                    "size": getattr(vectors_config, "size", None),
                    "num_vectors": "N" if multivec else 1,
                    "is_multivector": multivec is not None,
                    "on_disk": on_disk,
                    "datatype": datatype,
                }
        return {
            "points_count": getattr(info, "points_count", 0),
            "vectors_count": getattr(info, "vectors_count", getattr(info, "points_count", 0)),
            "status": str(getattr(info, "status", "unknown")),
            "vector_info": vector_info,
            "indexed_vectors_count": getattr(info, "indexed_vectors_count", None),
        }
    except Exception as e:
        return {"error": f"{e}\n\n{traceback.format_exc()}"}


@st.cache_data(ttl=60)
def sample_points_cached(collection_name: str, n: int, seed: int, _url: str, _api_key: str) -> List[Dict[str, Any]]:
    client, err = init_qdrant_client_with_creds(_url, _api_key)
    if client is None:
        return []
    try:
        import random
        rng = random.Random(seed)
        points, _ = client.scroll(
            collection_name=collection_name,
            limit=min(n * 10, 100),
            with_payload=True,
            with_vectors=False,
        )
        if not points:
            return []
        sampled = rng.sample(points, min(n, len(points)))
        results = []
        for p in sampled:
            payload = dict(p.payload) if p.payload else {}
            results.append({
                "id": str(p.id),
                "payload": payload,
            })
        return results
    except Exception:
        return []


@st.cache_data(ttl=300)
def get_vector_sizes(collection_name: str, _url: str, _api_key: str) -> Dict[str, int]:
    client, err = init_qdrant_client_with_creds(_url, _api_key)
    if client is None:
        return {}
    try:
        points, _ = client.scroll(
            collection_name=collection_name,
            limit=1,
            with_payload=False,
            with_vectors=True,
        )
        if not points:
            return {}
        vectors = points[0].vector
        sizes = {}
        if isinstance(vectors, dict):
            for name, vec in vectors.items():
                if isinstance(vec, list):
                    if vec and isinstance(vec[0], list):
                        sizes[name] = len(vec)
                    else:
                        sizes[name] = 1
                else:
                    sizes[name] = 1
        return sizes
    except Exception:
        return {}


def search_collection(
    collection_name: str,
    query: str,
    top_k: int = 10,
    mode: str = "single_full",
    prefetch_k: int = 256,
    stage1_mode: str = "tokens_vs_tiles",
    stage1_k: int = 1000,
    stage2_k: int = 300,
    model_name: str = "vidore/colSmol-500M",
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
    try:
        import traceback
        from visual_rag.retrieval import MultiVectorRetriever
        retriever = MultiVectorRetriever(
            collection_name=collection_name,
            model_name=model_name,
        )
        if mode == "three_stage":
            q_emb = retriever.embedder.embed_query(query)
            if hasattr(q_emb, "cpu"):
                q_emb = q_emb.cpu().numpy()
            results = retriever.search_embedded(
                query_embedding=q_emb,
                top_k=top_k,
                mode=mode,
                stage1_k=stage1_k,
                stage2_k=stage2_k,
            )
        else:
            results = retriever.search(
                query=query,
                top_k=top_k,
                mode=mode,
                prefetch_k=prefetch_k,
                stage1_mode=stage1_mode,
            )
        return results, None
    except Exception as e:
        import traceback
        return [], f"{e}\n\n{traceback.format_exc()}"
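Again not part of the wheel, just a hedged sketch of how the helpers above could be exercised. The collection name and query string are placeholders; search_collection's (results, error) tuple is unpacked as defined in the file, with the error string checked before the hits are used.

# Illustrative only: placeholder collection name and query.
from demo.qdrant_utils import get_qdrant_credentials, get_collection_stats, search_collection

url, api_key = get_qdrant_credentials()           # sidebar inputs first, then env vars
stats = get_collection_stats("my_collection")     # placeholder collection
print(stats.get("points_count"), stats.get("vector_info"))

results, err = search_collection(
    collection_name="my_collection",              # placeholder
    query="quarterly revenue table",              # placeholder query
    top_k=5,
    mode="single_full",                           # default mode from the signature
)
if err:
    print("search failed:", err)
else:
    for hit in results:
        print(hit)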
demo/results.py
ADDED
@@ -0,0 +1,35 @@
"""Results file handling utilities."""

import json
from pathlib import Path
from typing import Any, Dict, List, Optional


def load_results_file(path: Path) -> Optional[Dict[str, Any]]:
    try:
        with open(path, "r") as f:
            return json.load(f)
    except Exception:
        return None


def get_available_results() -> List[Path]:
    results_dir = Path(__file__).parent.parent / "results"
    if not results_dir.exists():
        return []
    results = []
    for subdir in results_dir.iterdir():
        if subdir.is_dir():
            for f in subdir.glob("*.json"):
                if "index_failures" not in f.name:
                    results.append(f)
    return sorted(results, key=lambda x: x.stat().st_mtime, reverse=True)


def find_main_result_file(collection: str, mode: str) -> Optional[Path]:
    results = get_available_results()
    for r in results:
        if collection in str(r) and mode in r.name:
            if "__vidore_" not in r.name:
                return r
    return results[0] if results else None
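A small illustrative sketch (placeholder arguments) of reading saved results back with the helpers above:

# Illustrative: list available result files and load one.
from demo.results import get_available_results, load_results_file, find_main_result_file

for path in get_available_results():
    print(path.name)

latest = find_main_result_file(collection="my_collection", mode="two_stage")  # placeholder args
if latest:
    data = load_results_file(latest)
    print(sorted(data.keys()) if data else "could not parse")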
demo/test_qdrant_connection.py
ADDED
@@ -0,0 +1,119 @@
#!/usr/bin/env python3
"""Test Qdrant connection and collection creation."""

import os
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from dotenv import load_dotenv
load_dotenv(Path(__file__).parent.parent / ".env")
load_dotenv(Path(__file__).parent.parent.parent / ".env")

def test_connection():
    from qdrant_client import QdrantClient
    from qdrant_client.http import models

    url = os.getenv("QDRANT_URL")
    api_key = os.getenv("QDRANT_API_KEY")

    print(f"URL: {url}")
    print(f"API Key: {'***' + api_key[-4:] if api_key else 'NOT SET'}")

    if not url or not api_key:
        print("ERROR: QDRANT_URL or QDRANT_API_KEY not set")
        return

    print("\n1. Creating client...")
    client = QdrantClient(url=url, api_key=api_key, timeout=60)

    print("\n2. Getting collections...")
    try:
        collections = client.get_collections()
        print(f"   Found {len(collections.collections)} collections:")
        for c in collections.collections:
            print(f"   - {c.name}")
    except Exception as e:
        print(f"   ERROR: {e}")
        return

    test_collection = "_test_visual_rag_toolkit"

    print(f"\n3. Checking if '{test_collection}' exists...")
    exists = any(c.name == test_collection for c in collections.collections)
    print(f"   Exists: {exists}")

    if exists:
        print(f"\n4. Deleting test collection...")
        try:
            client.delete_collection(test_collection)
            print("   Deleted")
        except Exception as e:
            print(f"   ERROR: {e}")

    print(f"\n5. Creating SIMPLE collection (single vector)...")
    try:
        client.create_collection(
            collection_name=test_collection,
            vectors_config=models.VectorParams(
                size=128,
                distance=models.Distance.COSINE,
            ),
        )
        print("   SUCCESS: Simple collection created")
    except Exception as e:
        print(f"   ERROR: {e}")
        print("\n   This means basic collection creation is failing.")
        print("   Check your Qdrant Cloud cluster status/limits.")
        return

    print(f"\n6. Deleting test collection...")
    try:
        client.delete_collection(test_collection)
        print("   Deleted")
    except Exception as e:
        print(f"   ERROR: {e}")

    print(f"\n7. Creating MULTI-VECTOR collection (like visual-rag)...")
    try:
        client.create_collection(
            collection_name=test_collection,
            vectors_config={
                "initial": models.VectorParams(
                    size=128,
                    distance=models.Distance.COSINE,
                    multivector_config=models.MultiVectorConfig(
                        comparator=models.MultiVectorComparator.MAX_SIM
                    ),
                ),
                "mean_pooling": models.VectorParams(
                    size=128,
                    distance=models.Distance.COSINE,
                    multivector_config=models.MultiVectorConfig(
                        comparator=models.MultiVectorComparator.MAX_SIM
                    ),
                ),
            },
        )
        print("   SUCCESS: Multi-vector collection created")
    except Exception as e:
        print(f"   ERROR: {e}")
        print("\n   Multi-vector collection failed but simple worked.")
        print("   Your Qdrant version may not support multi-vector.")
        return

    print(f"\n8. Final cleanup...")
    try:
        client.delete_collection(test_collection)
        print("   Deleted")
    except Exception as e:
        print(f"   ERROR: {e}")

    print("\n" + "="*50)
    print("ALL TESTS PASSED - Qdrant connection is working!")
    print("="*50)


if __name__ == "__main__":
    test_connection()
demo/ui/__init__.py
ADDED
@@ -0,0 +1,15 @@
"""UI components for the demo app."""

from demo.ui.header import render_header
from demo.ui.sidebar import render_sidebar
from demo.ui.upload import render_upload_tab
from demo.ui.playground import render_playground_tab
from demo.ui.benchmark import render_benchmark_tab

__all__ = [
    "render_header",
    "render_sidebar",
    "render_upload_tab",
    "render_playground_tab",
    "render_benchmark_tab",
]