visual-rag-toolkit 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- demo/app.py +20 -8
- demo/evaluation.py +5 -45
- demo/indexing.py +180 -192
- demo/qdrant_utils.py +12 -5
- demo/ui/playground.py +1 -1
- demo/ui/sidebar.py +4 -3
- demo/ui/upload.py +5 -4
- visual_rag/__init__.py +43 -1
- visual_rag/config.py +4 -7
- visual_rag/indexing/__init__.py +21 -4
- visual_rag/indexing/qdrant_indexer.py +92 -42
- visual_rag/retrieval/multi_vector.py +63 -65
- visual_rag/retrieval/single_stage.py +7 -0
- visual_rag/retrieval/two_stage.py +8 -10
- {visual_rag_toolkit-0.1.1.dist-info → visual_rag_toolkit-0.1.3.dist-info}/METADATA +98 -17
- {visual_rag_toolkit-0.1.1.dist-info → visual_rag_toolkit-0.1.3.dist-info}/RECORD +19 -20
- benchmarks/vidore_tatdqa_test/COMMANDS.md +0 -83
- {visual_rag_toolkit-0.1.1.dist-info → visual_rag_toolkit-0.1.3.dist-info}/WHEEL +0 -0
- {visual_rag_toolkit-0.1.1.dist-info → visual_rag_toolkit-0.1.3.dist-info}/entry_points.txt +0 -0
- {visual_rag_toolkit-0.1.1.dist-info → visual_rag_toolkit-0.1.3.dist-info}/licenses/LICENSE +0 -0
demo/app.py
CHANGED
```diff
@@ -1,13 +1,23 @@
 """Main entry point for the Visual RAG Toolkit demo application."""
 
+import os
 import sys
 from pathlib import Path
 
-
-
+# Ensure repo root is in sys.path for local development
+# (In HF Space / Docker, PYTHONPATH is already set correctly)
+_app_dir = Path(__file__).resolve().parent
+_repo_root = _app_dir.parent
+if str(_repo_root) not in sys.path:
+    sys.path.insert(0, str(_repo_root))
 
 from dotenv import load_dotenv
-
+
+# Load .env from the repo root (works both locally and in Docker)
+if (_repo_root / ".env").exists():
+    load_dotenv(_repo_root / ".env")
+if (_app_dir / ".env").exists():
+    load_dotenv(_app_dir / ".env")
 
 import streamlit as st
 
@@ -28,15 +38,17 @@ from demo.ui.benchmark import render_benchmark_tab
 def main():
     render_header()
     render_sidebar()
-
-    tab_upload, tab_playground, tab_benchmark = st.tabs(
-
+
+    tab_upload, tab_playground, tab_benchmark = st.tabs(
+        ["📤 Upload", "🎮 Playground", "📊 Benchmarking"]
+    )
+
     with tab_upload:
         render_upload_tab()
-
+
     with tab_playground:
         render_playground_tab()
-
+
     with tab_benchmark:
         render_benchmark_tab()
 
```
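With this bootstrap in place, the demo can be launched with `streamlit run demo/app.py` from a plain checkout, without installing the package first. The same pattern in isolation, as a minimal sketch (only `pathlib` and `python-dotenv` are assumed):

```python
import sys
from pathlib import Path

from dotenv import load_dotenv

# Resolve <repo_root> from <repo_root>/demo/app.py and make it importable,
# so `from demo.ui... import ...` works without an installed package.
repo_root = Path(__file__).resolve().parent.parent
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

# Load environment variables (e.g. QDRANT_URL) before any module reads them.
load_dotenv(repo_root / ".env")
```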
demo/evaluation.py
CHANGED
```diff
@@ -1,20 +1,23 @@
 """Evaluation runner with UI updates."""
 
 import hashlib
-import importlib.util
 import json
 import logging
 import time
 import traceback
 from datetime import datetime
-from pathlib import Path
 from typing import Any, Dict, List, Optional
 
 import numpy as np
 import streamlit as st
 import torch
+from qdrant_client.models import FieldCondition, Filter, MatchValue
 
 from visual_rag import VisualEmbedder
+from visual_rag.retrieval import MultiVectorRetriever
+from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
+from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
+from demo.qdrant_utils import get_qdrant_credentials
 
 
 TORCH_DTYPE_MAP = {
@@ -22,49 +25,6 @@ TORCH_DTYPE_MAP = {
     "float32": torch.float32,
     "bfloat16": torch.bfloat16,
 }
-from qdrant_client.models import Filter, FieldCondition, MatchValue
-
-from visual_rag.retrieval import MultiVectorRetriever
-
-
-def _load_local_benchmark_module(module_filename: str):
-    """
-    Load `benchmarks/vidore_tatdqa_test/<module_filename>` via file path.
-
-    Motivation:
-    - Some environments (notably containers / Spaces) can have a third-party
-      `benchmarks` package installed, causing `import benchmarks...` to resolve
-      to the wrong module.
-    - This fallback guarantees we load the repo's benchmark utilities.
-    """
-    root = Path(__file__).resolve().parents[1]  # demo/.. = repo root
-    target = root / "benchmarks" / "vidore_tatdqa_test" / module_filename
-    if not target.exists():
-        raise ModuleNotFoundError(f"Missing local benchmark module file: {target}")
-
-    name = f"_visual_rag_toolkit_local_{target.stem}"
-    spec = importlib.util.spec_from_file_location(name, str(target))
-    if spec is None or spec.loader is None:
-        raise ModuleNotFoundError(f"Could not load module spec for: {target}")
-    mod = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(mod)  # type: ignore[attr-defined]
-    return mod
-
-
-try:
-    # Preferred: normal import
-    from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
-    from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
-except ModuleNotFoundError:
-    # Robust fallback: load from local file paths
-    _dl = _load_local_benchmark_module("dataset_loader.py")
-    _mx = _load_local_benchmark_module("metrics.py")
-    load_vidore_beir_dataset = _dl.load_vidore_beir_dataset
-    ndcg_at_k = _mx.ndcg_at_k
-    mrr_at_k = _mx.mrr_at_k
-    recall_at_k = _mx.recall_at_k
-
-from demo.qdrant_utils import get_qdrant_credentials
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
```
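With `Filter`, `FieldCondition`, and `MatchValue` now imported at module top, the evaluation code can scope retrieval to a single dataset's points, which the indexer payload-indexes under the `dataset` key (see `demo/indexing.py` below). A minimal sketch of such a filter; the dataset value is illustrative and the actual query call lies outside this hunk:

```python
from qdrant_client.models import FieldCondition, Filter, MatchValue

# Restrict search to points from one source dataset; "dataset" is a
# keyword-indexed payload field written at indexing time.
dataset_filter = Filter(
    must=[
        FieldCondition(key="dataset", match=MatchValue(value="vidore/tatdqa_test"))
    ]
)
```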
demo/indexing.py
CHANGED
```diff
@@ -12,6 +12,10 @@ import streamlit as st
 import torch
 
 from visual_rag import VisualEmbedder
+from visual_rag.embedding.pooling import tile_level_mean_pooling
+from visual_rag.indexing.qdrant_indexer import QdrantIndexer
+from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
+from demo.qdrant_utils import get_qdrant_credentials
 
 
 TORCH_DTYPE_MAP = {
@@ -19,10 +23,6 @@
     "float32": torch.float32,
     "bfloat16": torch.bfloat16,
 }
-from visual_rag.indexing import QdrantIndexer
-from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
-
-from demo.qdrant_utils import get_qdrant_credentials
 
 
 def _stable_uuid(text: str) -> str:
@@ -31,7 +31,9 @@
     return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
 
 
-def _union_point_id(…
+def _union_point_id(
+    *, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]
+) -> str:
     """Generate union point ID (same as benchmark script)."""
     ns = f"{union_namespace}::{dataset_name}" if union_namespace else dataset_name
     return _stable_uuid(f"{ns}::{source_doc_id}")
@@ -39,16 +41,16 @@ def _union_point_id(*, dataset_name: str, source_doc_id: str, union_namespace: O
 
 def run_indexing_with_ui(config: Dict[str, Any]):
     st.divider()
-
+
     print("=" * 60)
     print("[INDEX] Starting indexing via UI")
     print("=" * 60)
-
+
     url, api_key = get_qdrant_credentials()
     if not url:
         st.error("QDRANT_URL not configured")
         return
-
+
     datasets = config.get("datasets", [])
     collection = config["collection"]
     model = config.get("model", "vidore/colpali-v1.3")
@@ -58,42 +60,50 @@
     prefer_grpc = config.get("prefer_grpc", True)
     batch_size = config.get("batch_size", 4)
     max_docs = config.get("max_docs")
-
+
     print(f"[INDEX] Config: collection={collection}, model={model}")
     print(f"[INDEX] Datasets: {datasets}")
-    print(…
-
-
+    print(
+        f"[INDEX] max_docs={max_docs}, batch_size={batch_size}, recreate={recreate}"
+    )
+    print(
+        f"[INDEX] torch_dtype={torch_dtype}, qdrant_dtype={qdrant_vector_dtype}, grpc={prefer_grpc}"
+    )
+
     phase1_container = st.container()
     phase2_container = st.container()
     phase3_container = st.container()
     results_container = st.container()
-
+
     try:
         with phase1_container:
             st.markdown("##### 🤖 Phase 1: Loading Model")
             model_status = st.empty()
             model_status.info(f"Loading `{model.split('/')[-1]}`...")
-
+
             print(f"[INDEX] Loading embedder: {model}")
             torch_dtype_obj = TORCH_DTYPE_MAP.get(torch_dtype, torch.float16)
-            output_dtype_obj = …
+            output_dtype_obj = (
+                np.float16 if qdrant_vector_dtype == "float16" else np.float32
+            )
             embedder = VisualEmbedder(
                 model_name=model,
                 torch_dtype=torch_dtype_obj,
                 output_dtype=output_dtype_obj,
             )
             embedder._load_model()
-            print(…
+            print(
+                f"[INDEX] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_vector_dtype})"
+            )
             model_status.success(f"✅ Model `{model.split('/')[-1]}` loaded")
-
+
         with phase2_container:
             st.markdown("##### 📦 Phase 2: Setting Up Collection")
-
+
             indexer_status = st.empty()
-            indexer_status.info(…
-
-            print(…
+            indexer_status.info("Connecting to Qdrant...")
+
+            print("[INDEX] Connecting to Qdrant...")
             indexer = QdrantIndexer(
                 url=url,
                 api_key=api_key,
@@ -101,186 +111,164 @@
                 prefer_grpc=prefer_grpc,
                 vector_datatype=qdrant_vector_dtype,
             )
-            print(…
-            indexer_status.success(…
-
+            print("[INDEX] Connected to Qdrant")
+            indexer_status.success("✅ Connected to Qdrant")
+
             coll_status = st.empty()
             action = "Recreating" if recreate else "Creating/verifying"
             coll_status.info(f"{action} collection `{collection}`...")
-
+
             print(f"[INDEX] {action} collection: {collection}")
             indexer.create_collection(force_recreate=recreate)
-            indexer.create_payload_indexes(…
-            … (old lines 114–118 not shown in this view)
+            indexer.create_payload_indexes(
+                fields=[
+                    {"field": "dataset", "type": "keyword"},
+                    {"field": "doc_id", "type": "keyword"},
+                    {"field": "source_doc_id", "type": "keyword"},
+                ]
+            )
+            print("[INDEX] Collection ready")
             coll_status.success(f"✅ Collection `{collection}` ready")
-
+
         with phase3_container:
-            st.markdown("#####…
-            … (old lines 123–137 not shown in this view)
-            print(f"[INDEX]…
-            … (old lines 139–194 not shown in this view)
+            st.markdown("##### 📊 Phase 3: Processing Datasets")
+
+            all_results = []
+
+            for ds_idx, dataset_name in enumerate(datasets):
+                ds_short = dataset_name.split("/")[-1]
+                ds_container = st.container()
+
+                with ds_container:
+                    st.markdown(
+                        f"**Dataset {ds_idx + 1}/{len(datasets)}: `{ds_short}`**"
+                    )
+
+                    load_status = st.empty()
+                    load_status.info(f"Loading dataset `{ds_short}`...")
+
+                    print(f"[INDEX] Loading dataset: {dataset_name}")
+                    corpus, queries, qrels = load_vidore_beir_dataset(dataset_name)
+                    total_docs = len(corpus)
+                    print(f"[INDEX] Dataset loaded: {total_docs} docs")
+                    load_status.success(f"✅ Loaded {total_docs:,} documents")
+
+                    if max_docs and max_docs < total_docs:
+                        corpus = corpus[:max_docs]
+                        print(f"[INDEX] Limiting to {max_docs} docs")
+
+                    progress_bar = st.progress(0)
+                    status_text = st.empty()
+
+                    uploaded = 0
+                    failed = 0
+                    total = len(corpus)
+
+                    for i, doc in enumerate(corpus):
+                        try:
+                            doc_id = str(doc.doc_id)
+                            image = doc.image
+                            if image is None:
+                                failed += 1
+                                continue
+
+                            status_text.text(
+                                f"Processing {i + 1}/{total}: {doc_id[:30]}..."
+                            )
+
+                            embeddings, token_infos = embedder.embed_images(
+                                [image],
+                                return_token_info=True,
+                                show_progress=False,
+                            )
+                            emb = embeddings[0]
+                            token_info = token_infos[0] if token_infos else {}
+
+                            if hasattr(emb, "cpu"):
+                                emb = emb.cpu()
+                            emb_np = np.asarray(emb, dtype=output_dtype_obj)
+
+                            initial = emb_np.tolist()
+                            global_pool = emb_np.mean(axis=0).tolist()
+
+                            num_tiles = token_info.get("num_tiles")
+                            mean_pooling = None
+                            experimental_pooling = None
+
+                            if num_tiles and num_tiles > 0:
+                                try:
+                                    mean_pooling = tile_level_mean_pooling(
+                                        emb_np, num_tiles=num_tiles, patches_per_tile=64
+                                    ).tolist()
+                                except Exception:
+                                    pass
+
+                                try:
+                                    exp_pool = embedder.experimental_pool_visual_embedding(
+                                        emb_np, num_tiles=num_tiles
+                                    )
+                                    if exp_pool is not None:
+                                        experimental_pooling = exp_pool.tolist()
+                                except Exception:
+                                    pass
+
+                            union_doc_id = _union_point_id(
+                                dataset_name=dataset_name,
+                                source_doc_id=doc_id,
+                                union_namespace=collection,
+                            )
+
+                            payload = {
+                                "dataset": dataset_name,
                                 "doc_id": doc_id,
-                                "source_doc_id": …
+                                "source_doc_id": doc_id,
                                 "union_doc_id": union_doc_id,
-                            … (old lines 198–223 not shown in this view)
+                                "num_tiles": num_tiles,
+                                "num_visual_tokens": token_info.get("num_visual_tokens"),
+                            }
+
+                            vectors = {"initial": initial, "global_pooling": global_pool}
+                            if mean_pooling:
+                                vectors["mean_pooling"] = mean_pooling
+                            if experimental_pooling:
+                                vectors["experimental_pooling"] = experimental_pooling
+
+                            indexer.upsert_point(
+                                point_id=union_doc_id,
+                                vectors=vectors,
+                                payload=payload,
+                            )
+
+                            uploaded += 1
+
+                        except Exception as e:
+                            print(f"[INDEX] Error on doc {i}: {e}")
+                            failed += 1
+
+                        progress_bar.progress((i + 1) / total)
+
+                    status_text.text(f"✅ Done: {uploaded} uploaded, {failed} failed")
+                    all_results.append(
+                        {
+                            "dataset": dataset_name,
+                            "total": total,
+                            "uploaded": uploaded,
+                            "failed": failed,
+                        }
+                    )
+
         with results_container:
-            st.markdown("#####…
-            … (old lines 226–233 not shown in this view)
-            print(f"[INDEX] Total Time: {total_time:.1f}s")
-            print(f"[INDEX] Throughput: {docs_per_sec:.2f} docs/s")
-            print("=" * 60)
-
-            c1, c2, c3, c4 = st.columns(4)
-            c1.metric("Total Uploaded", f"{total_uploaded:,}")
-            c2.metric("Datasets", len(datasets))
-            c3.metric("Total Time", f"{total_time:.1f}s")
-            c4.metric("Throughput", f"{docs_per_sec:.2f}/s")
-
-            st.success(f"🎉 Indexing complete! {total_uploaded:,} documents indexed to `{collection}`")
-
-            detailed_report = {
-                "generated_at": datetime.now().isoformat(),
-                "config": {
-                    "collection": collection,
-                    "model": model,
-                    "datasets": datasets,
-                    "batch_size": batch_size,
-                    "max_docs_per_dataset": max_docs,
-                    "recreate": recreate,
-                    "prefer_grpc": prefer_grpc,
-                    "torch_dtype": torch_dtype,
-                    "qdrant_vector_dtype": qdrant_vector_dtype,
-                },
-                "results": {
-                    "total_docs_uploaded": total_uploaded,
-                    "total_time_s": round(total_time, 2),
-                    "throughput_docs_per_s": round(docs_per_sec, 2),
-                    "num_datasets": len(datasets),
-                },
-            }
-
-            with st.expander("📋 Full Summary"):
-                st.json(detailed_report)
-
-            report_json = json.dumps(detailed_report, indent=2)
-            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-            filename = f"index_report__{collection}__{timestamp}.json"
-
-            st.download_button(
-                label="📥 Download Indexing Report",
-                data=report_json,
-                file_name=filename,
-                mime="application/json",
-                use_container_width=True,
-            )
-
+            st.markdown("##### 📋 Results Summary")
+
+            for r in all_results:
+                st.write(
+                    f"**{r['dataset'].split('/')[-1]}**: {r['uploaded']:,} uploaded, {r['failed']:,} failed"
+                )
+
+            st.success("✅ Indexing complete!")
+
     except Exception as e:
+        st.error(f"Indexing error: {e}")
+        st.code(traceback.format_exc())
         print(f"[INDEX] ERROR: {e}")
-
-        with st.expander("🔍 Full Error Details"):
-            st.code(traceback.format_exc(), language="text")
+        traceback.print_exc()
```
demo/qdrant_utils.py
CHANGED
```diff
@@ -8,12 +8,19 @@ import streamlit as st
 
 
 def get_qdrant_credentials() -> Tuple[Optional[str], Optional[str]]:
-    … (old lines 11–15 not shown in this view)
+    """Get Qdrant credentials from session state or environment variables.
+
+    Priority: session_state > QDRANT_URL/QDRANT_API_KEY > legacy env vars
+    """
+    url = (
+        st.session_state.get("qdrant_url_input")
+        or os.getenv("QDRANT_URL")
+        or os.getenv("SIGIR_QDRANT_URL")  # legacy
+    )
+    api_key = (
+        st.session_state.get("qdrant_key_input")
         or os.getenv("QDRANT_API_KEY")
+        or os.getenv("SIGIR_QDRANT_KEY")  # legacy
     )
     return url, api_key
 
```
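Callers treat a missing URL as a configuration error rather than letting a connection attempt fail later. A minimal usage sketch, mirroring the check in `demo/indexing.py`:

```python
import streamlit as st

from demo.qdrant_utils import get_qdrant_credentials

url, api_key = get_qdrant_credentials()
if not url:
    # Nothing set via the sidebar inputs, QDRANT_URL, or the legacy SIGIR_* vars
    st.error("QDRANT_URL not configured")
    st.stop()
```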
demo/ui/playground.py
CHANGED
```diff
@@ -9,6 +9,7 @@ from demo.qdrant_utils import (
     sample_points_cached,
     search_collection,
 )
+from visual_rag.retrieval import MultiVectorRetriever
 
 
 def render_playground_tab():
@@ -46,7 +47,6 @@ def render_playground_tab():
     if not st.session_state.get("model_loaded"):
         with st.spinner(f"Loading {model_short}..."):
             try:
-                from visual_rag.retrieval import MultiVectorRetriever
                 _ = MultiVectorRetriever(collection_name=active_collection, model_name=model_name)
                 st.session_state["model_loaded"] = True
                 st.session_state["loaded_model_key"] = cache_key
```
demo/ui/sidebar.py
CHANGED
```diff
@@ -3,6 +3,8 @@
 import os
 import streamlit as st
 
+from qdrant_client.models import VectorParamsDiff
+
 from demo.qdrant_utils import (
     get_qdrant_credentials,
     init_qdrant_client_with_creds,
@@ -17,8 +19,8 @@ def render_sidebar():
     with st.sidebar:
         st.subheader("🔑 Qdrant Credentials")
 
-        env_url = os.getenv("…
-        env_key = os.getenv("…
+        env_url = os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL") or ""
+        env_key = os.getenv("QDRANT_API_KEY") or os.getenv("SIGIR_QDRANT_KEY") or ""
 
         if "qdrant_url_input" not in st.session_state:
             st.session_state["qdrant_url_input"] = env_url
@@ -136,7 +138,6 @@ def render_sidebar():
         if target_in_ram != current_in_ram:
             if st.button("💾 Apply Change", key="admin_apply"):
                 try:
-                    from qdrant_client.models import VectorParamsDiff
                     client.update_collection(
                         collection_name=active,
                         vectors_config={sel_vec: VectorParamsDiff(on_disk=not target_in_ram)}
```