visual-rag-toolkit 0.1.2 → 0.1.3 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- demo/app.py +20 -8
- demo/evaluation.py +5 -45
- demo/indexing.py +180 -221
- demo/qdrant_utils.py +12 -5
- demo/ui/playground.py +1 -1
- demo/ui/sidebar.py +4 -3
- demo/ui/upload.py +5 -4
- visual_rag/__init__.py +43 -1
- visual_rag/config.py +4 -7
- visual_rag/indexing/__init__.py +21 -4
- visual_rag/indexing/qdrant_indexer.py +92 -42
- visual_rag/retrieval/multi_vector.py +63 -65
- visual_rag/retrieval/single_stage.py +7 -0
- visual_rag/retrieval/two_stage.py +8 -10
- {visual_rag_toolkit-0.1.2.dist-info → visual_rag_toolkit-0.1.3.dist-info}/METADATA +24 -15
- {visual_rag_toolkit-0.1.2.dist-info → visual_rag_toolkit-0.1.3.dist-info}/RECORD +19 -19
- {visual_rag_toolkit-0.1.2.dist-info → visual_rag_toolkit-0.1.3.dist-info}/WHEEL +0 -0
- {visual_rag_toolkit-0.1.2.dist-info → visual_rag_toolkit-0.1.3.dist-info}/entry_points.txt +0 -0
- {visual_rag_toolkit-0.1.2.dist-info → visual_rag_toolkit-0.1.3.dist-info}/licenses/LICENSE +0 -0
demo/app.py
CHANGED

```diff
@@ -1,13 +1,23 @@
 """Main entry point for the Visual RAG Toolkit demo application."""
 
+import os
 import sys
 from pathlib import Path
 
-
-
+# Ensure repo root is in sys.path for local development
+# (In HF Space / Docker, PYTHONPATH is already set correctly)
+_app_dir = Path(__file__).resolve().parent
+_repo_root = _app_dir.parent
+if str(_repo_root) not in sys.path:
+    sys.path.insert(0, str(_repo_root))
 
 from dotenv import load_dotenv
-
+
+# Load .env from the repo root (works both locally and in Docker)
+if (_repo_root / ".env").exists():
+    load_dotenv(_repo_root / ".env")
+if (_app_dir / ".env").exists():
+    load_dotenv(_app_dir / ".env")
 
 import streamlit as st
 
@@ -28,15 +38,17 @@ from demo.ui.benchmark import render_benchmark_tab
 def main():
     render_header()
     render_sidebar()
-
-    tab_upload, tab_playground, tab_benchmark = st.tabs(
-
+
+    tab_upload, tab_playground, tab_benchmark = st.tabs(
+        ["📤 Upload", "🎮 Playground", "📊 Benchmarking"]
+    )
+
     with tab_upload:
         render_upload_tab()
-
+
     with tab_playground:
         render_playground_tab()
-
+
     with tab_benchmark:
         render_benchmark_tab()
 
```
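
The path bootstrap added above is self-contained; a minimal sketch of the same pattern, with names taken from the diff (how the app is actually launched is not part of this diff and is only assumed here):

```python
import sys
from pathlib import Path

from dotenv import load_dotenv

# Make the repository root importable when the app runs as a plain script
# rather than as an installed package.
_app_dir = Path(__file__).resolve().parent
_repo_root = _app_dir.parent
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))

# Load environment variables: repo-root .env first, then demo/.env if present.
for env_file in (_repo_root / ".env", _app_dir / ".env"):
    if env_file.exists():
        load_dotenv(env_file)
```

Per the comment in the diff, the `sys.path` insert is effectively a no-op in the HF Space / Docker case, where `PYTHONPATH` already points at the repo root.
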
demo/evaluation.py
CHANGED

```diff
@@ -1,20 +1,23 @@
 """Evaluation runner with UI updates."""
 
 import hashlib
-import importlib.util
 import json
 import logging
 import time
 import traceback
 from datetime import datetime
-from pathlib import Path
 from typing import Any, Dict, List, Optional
 
 import numpy as np
 import streamlit as st
 import torch
+from qdrant_client.models import FieldCondition, Filter, MatchValue
 
 from visual_rag import VisualEmbedder
+from visual_rag.retrieval import MultiVectorRetriever
+from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
+from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
+from demo.qdrant_utils import get_qdrant_credentials
 
 
 TORCH_DTYPE_MAP = {
@@ -22,49 +25,6 @@ TORCH_DTYPE_MAP = {
     "float32": torch.float32,
     "bfloat16": torch.bfloat16,
 }
-from qdrant_client.models import Filter, FieldCondition, MatchValue
-
-from visual_rag.retrieval import MultiVectorRetriever
-
-
-def _load_local_benchmark_module(module_filename: str):
-    """
-    Load `benchmarks/vidore_tatdqa_test/<module_filename>` via file path.
-
-    Motivation:
-    - Some environments (notably containers / Spaces) can have a third-party
-      `benchmarks` package installed, causing `import benchmarks...` to resolve
-      to the wrong module.
-    - This fallback guarantees we load the repo's benchmark utilities.
-    """
-    root = Path(__file__).resolve().parents[1]  # demo/.. = repo root
-    target = root / "benchmarks" / "vidore_tatdqa_test" / module_filename
-    if not target.exists():
-        raise ModuleNotFoundError(f"Missing local benchmark module file: {target}")
-
-    name = f"_visual_rag_toolkit_local_{target.stem}"
-    spec = importlib.util.spec_from_file_location(name, str(target))
-    if spec is None or spec.loader is None:
-        raise ModuleNotFoundError(f"Could not load module spec for: {target}")
-    mod = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(mod)  # type: ignore[attr-defined]
-    return mod
-
-
-try:
-    # Preferred: normal import
-    from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
-    from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
-except ModuleNotFoundError:
-    # Robust fallback: load from local file paths
-    _dl = _load_local_benchmark_module("dataset_loader.py")
-    _mx = _load_local_benchmark_module("metrics.py")
-    load_vidore_beir_dataset = _dl.load_vidore_beir_dataset
-    ndcg_at_k = _mx.ndcg_at_k
-    mrr_at_k = _mx.mrr_at_k
-    recall_at_k = _mx.recall_at_k
-
-from demo.qdrant_utils import get_qdrant_credentials
 
 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
```
demo/indexing.py
CHANGED

```diff
@@ -1,12 +1,10 @@
 """Indexing runner with UI updates."""
 
 import hashlib
-import importlib.util
 import json
 import time
 import traceback
 from datetime import datetime
-from pathlib import Path
 from typing import Any, Dict, Optional
 
 import numpy as np
@@ -14,6 +12,10 @@ import streamlit as st
 import torch
 
 from visual_rag import VisualEmbedder
+from visual_rag.embedding.pooling import tile_level_mean_pooling
+from visual_rag.indexing.qdrant_indexer import QdrantIndexer
+from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
+from demo.qdrant_utils import get_qdrant_credentials
 
 
 TORCH_DTYPE_MAP = {
@@ -22,37 +24,6 @@ TORCH_DTYPE_MAP = {
     "bfloat16": torch.bfloat16,
 }
 
-# --- Robust imports (Spaces-friendly) ---
-# Some environments can have a third-party `benchmarks` package installed, or
-# resolve `visual_rag.indexing` oddly. These fallbacks keep the demo working.
-try:
-    from visual_rag.indexing import QdrantIndexer
-except Exception:  # pragma: no cover
-    from visual_rag.indexing.qdrant_indexer import QdrantIndexer
-
-
-def _load_local_benchmark_module(module_filename: str):
-    root = Path(__file__).resolve().parents[1]  # demo/.. = repo root
-    target = root / "benchmarks" / "vidore_tatdqa_test" / module_filename
-    if not target.exists():
-        raise ModuleNotFoundError(f"Missing local benchmark module file: {target}")
-    name = f"_visual_rag_toolkit_local_{target.stem}"
-    spec = importlib.util.spec_from_file_location(name, str(target))
-    if spec is None or spec.loader is None:
-        raise ModuleNotFoundError(f"Could not load module spec for: {target}")
-    mod = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(mod)  # type: ignore[attr-defined]
-    return mod
-
-
-try:
-    from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
-except ModuleNotFoundError:  # pragma: no cover
-    _dl = _load_local_benchmark_module("dataset_loader.py")
-    load_vidore_beir_dataset = _dl.load_vidore_beir_dataset
-
-from demo.qdrant_utils import get_qdrant_credentials
-
 
 def _stable_uuid(text: str) -> str:
     """Generate a stable UUID from text (same as benchmark script)."""
@@ -60,7 +31,9 @@ def _stable_uuid(text: str) -> str:
     return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
 
 
-def _union_point_id(
+def _union_point_id(
+    *, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]
+) -> str:
     """Generate union point ID (same as benchmark script)."""
     ns = f"{union_namespace}::{dataset_name}" if union_namespace else dataset_name
     return _stable_uuid(f"{ns}::{source_doc_id}")
@@ -68,16 +41,16 @@ def _union_point_id(*, dataset_name: str, source_doc_id: str, union_namespace: O
 
 def run_indexing_with_ui(config: Dict[str, Any]):
     st.divider()
-
+
     print("=" * 60)
     print("[INDEX] Starting indexing via UI")
     print("=" * 60)
-
+
     url, api_key = get_qdrant_credentials()
     if not url:
         st.error("QDRANT_URL not configured")
         return
-
+
     datasets = config.get("datasets", [])
     collection = config["collection"]
     model = config.get("model", "vidore/colpali-v1.3")
@@ -87,42 +60,50 @@ def run_indexing_with_ui(config: Dict[str, Any]):
     prefer_grpc = config.get("prefer_grpc", True)
     batch_size = config.get("batch_size", 4)
     max_docs = config.get("max_docs")
-
+
     print(f"[INDEX] Config: collection={collection}, model={model}")
     print(f"[INDEX] Datasets: {datasets}")
-    print(
-
-
+    print(
+        f"[INDEX] max_docs={max_docs}, batch_size={batch_size}, recreate={recreate}"
+    )
+    print(
+        f"[INDEX] torch_dtype={torch_dtype}, qdrant_dtype={qdrant_vector_dtype}, grpc={prefer_grpc}"
+    )
+
     phase1_container = st.container()
     phase2_container = st.container()
     phase3_container = st.container()
     results_container = st.container()
-
+
     try:
         with phase1_container:
             st.markdown("##### 🤖 Phase 1: Loading Model")
             model_status = st.empty()
             model_status.info(f"Loading `{model.split('/')[-1]}`...")
-
+
             print(f"[INDEX] Loading embedder: {model}")
             torch_dtype_obj = TORCH_DTYPE_MAP.get(torch_dtype, torch.float16)
-            output_dtype_obj =
+            output_dtype_obj = (
+                np.float16 if qdrant_vector_dtype == "float16" else np.float32
+            )
             embedder = VisualEmbedder(
                 model_name=model,
                 torch_dtype=torch_dtype_obj,
                 output_dtype=output_dtype_obj,
             )
             embedder._load_model()
-            print(
+            print(
+                f"[INDEX] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_vector_dtype})"
+            )
             model_status.success(f"✅ Model `{model.split('/')[-1]}` loaded")
-
+
         with phase2_container:
             st.markdown("##### 📦 Phase 2: Setting Up Collection")
-
+
             indexer_status = st.empty()
-            indexer_status.info(
-
-            print(
+            indexer_status.info("Connecting to Qdrant...")
+
+            print("[INDEX] Connecting to Qdrant...")
             indexer = QdrantIndexer(
                 url=url,
                 api_key=api_key,
@@ -130,186 +111,164 @@ def run_indexing_with_ui(config: Dict[str, Any]):
                 prefer_grpc=prefer_grpc,
                 vector_datatype=qdrant_vector_dtype,
             )
-            print(
-            indexer_status.success(
-
+            print("[INDEX] Connected to Qdrant")
+            indexer_status.success("✅ Connected to Qdrant")
+
             coll_status = st.empty()
             action = "Recreating" if recreate else "Creating/verifying"
             coll_status.info(f"{action} collection `{collection}`...")
-
+
             print(f"[INDEX] {action} collection: {collection}")
             indexer.create_collection(force_recreate=recreate)
-            indexer.create_payload_indexes(
- …
+            indexer.create_payload_indexes(
+                fields=[
+                    {"field": "dataset", "type": "keyword"},
+                    {"field": "doc_id", "type": "keyword"},
+                    {"field": "source_doc_id", "type": "keyword"},
+                ]
+            )
+            print("[INDEX] Collection ready")
             coll_status.success(f"✅ Collection `{collection}` ready")
-
+
         with phase3_container:
-            st.markdown("#####
- …
-            print(f"[INDEX]
- …
+            st.markdown("##### 📊 Phase 3: Processing Datasets")
+
+            all_results = []
+
+            for ds_idx, dataset_name in enumerate(datasets):
+                ds_short = dataset_name.split("/")[-1]
+                ds_container = st.container()
+
+                with ds_container:
+                    st.markdown(
+                        f"**Dataset {ds_idx + 1}/{len(datasets)}: `{ds_short}`**"
+                    )
+
+                    load_status = st.empty()
+                    load_status.info(f"Loading dataset `{ds_short}`...")
+
+                    print(f"[INDEX] Loading dataset: {dataset_name}")
+                    corpus, queries, qrels = load_vidore_beir_dataset(dataset_name)
+                    total_docs = len(corpus)
+                    print(f"[INDEX] Dataset loaded: {total_docs} docs")
+                    load_status.success(f"✅ Loaded {total_docs:,} documents")
+
+                    if max_docs and max_docs < total_docs:
+                        corpus = corpus[:max_docs]
+                        print(f"[INDEX] Limiting to {max_docs} docs")
+
+                    progress_bar = st.progress(0)
+                    status_text = st.empty()
+
+                    uploaded = 0
+                    failed = 0
+                    total = len(corpus)
+
+                    for i, doc in enumerate(corpus):
+                        try:
+                            doc_id = str(doc.doc_id)
+                            image = doc.image
+                            if image is None:
+                                failed += 1
+                                continue
+
+                            status_text.text(
+                                f"Processing {i + 1}/{total}: {doc_id[:30]}..."
+                            )
+
+                            embeddings, token_infos = embedder.embed_images(
+                                [image],
+                                return_token_info=True,
+                                show_progress=False,
+                            )
+                            emb = embeddings[0]
+                            token_info = token_infos[0] if token_infos else {}
+
+                            if hasattr(emb, "cpu"):
+                                emb = emb.cpu()
+                            emb_np = np.asarray(emb, dtype=output_dtype_obj)
+
+                            initial = emb_np.tolist()
+                            global_pool = emb_np.mean(axis=0).tolist()
+
+                            num_tiles = token_info.get("num_tiles")
+                            mean_pooling = None
+                            experimental_pooling = None
+
+                            if num_tiles and num_tiles > 0:
+                                try:
+                                    mean_pooling = tile_level_mean_pooling(
+                                        emb_np, num_tiles=num_tiles, patches_per_tile=64
+                                    ).tolist()
+                                except Exception:
+                                    pass
+
+                            try:
+                                exp_pool = embedder.experimental_pool_visual_embedding(
+                                    emb_np, num_tiles=num_tiles
+                                )
+                                if exp_pool is not None:
+                                    experimental_pooling = exp_pool.tolist()
+                            except Exception:
+                                pass
+
+                            union_doc_id = _union_point_id(
+                                dataset_name=dataset_name,
+                                source_doc_id=doc_id,
+                                union_namespace=collection,
+                            )
+
+                            payload = {
+                                "dataset": dataset_name,
                                 "doc_id": doc_id,
-                                "source_doc_id":
+                                "source_doc_id": doc_id,
                                 "union_doc_id": union_doc_id,
- …
+                                "num_tiles": num_tiles,
+                                "num_visual_tokens": token_info.get("num_visual_tokens"),
+                            }
+
+                            vectors = {"initial": initial, "global_pooling": global_pool}
+                            if mean_pooling:
+                                vectors["mean_pooling"] = mean_pooling
+                            if experimental_pooling:
+                                vectors["experimental_pooling"] = experimental_pooling
+
+                            indexer.upsert_point(
+                                point_id=union_doc_id,
+                                vectors=vectors,
+                                payload=payload,
+                            )
+
+                            uploaded += 1
+
+                        except Exception as e:
+                            print(f"[INDEX] Error on doc {i}: {e}")
+                            failed += 1
+
+                        progress_bar.progress((i + 1) / total)
+
+                    status_text.text(f"✅ Done: {uploaded} uploaded, {failed} failed")
+                    all_results.append(
+                        {
+                            "dataset": dataset_name,
+                            "total": total,
+                            "uploaded": uploaded,
+                            "failed": failed,
+                        }
+                    )
+
         with results_container:
-            st.markdown("#####
- …
-            print(f"[INDEX] Total Time: {total_time:.1f}s")
-            print(f"[INDEX] Throughput: {docs_per_sec:.2f} docs/s")
-            print("=" * 60)
-
-            c1, c2, c3, c4 = st.columns(4)
-            c1.metric("Total Uploaded", f"{total_uploaded:,}")
-            c2.metric("Datasets", len(datasets))
-            c3.metric("Total Time", f"{total_time:.1f}s")
-            c4.metric("Throughput", f"{docs_per_sec:.2f}/s")
-
-            st.success(f"🎉 Indexing complete! {total_uploaded:,} documents indexed to `{collection}`")
-
-            detailed_report = {
-                "generated_at": datetime.now().isoformat(),
-                "config": {
-                    "collection": collection,
-                    "model": model,
-                    "datasets": datasets,
-                    "batch_size": batch_size,
-                    "max_docs_per_dataset": max_docs,
-                    "recreate": recreate,
-                    "prefer_grpc": prefer_grpc,
-                    "torch_dtype": torch_dtype,
-                    "qdrant_vector_dtype": qdrant_vector_dtype,
-                },
-                "results": {
-                    "total_docs_uploaded": total_uploaded,
-                    "total_time_s": round(total_time, 2),
-                    "throughput_docs_per_s": round(docs_per_sec, 2),
-                    "num_datasets": len(datasets),
-                },
-            }
-
-            with st.expander("📋 Full Summary"):
-                st.json(detailed_report)
-
-            report_json = json.dumps(detailed_report, indent=2)
-            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-            filename = f"index_report__{collection}__{timestamp}.json"
-
-            st.download_button(
-                label="📥 Download Indexing Report",
-                data=report_json,
-                file_name=filename,
-                mime="application/json",
-                use_container_width=True,
-            )
-
+            st.markdown("##### 📋 Results Summary")
+
+            for r in all_results:
+                st.write(
+                    f"**{r['dataset'].split('/')[-1]}**: {r['uploaded']:,} uploaded, {r['failed']:,} failed"
+                )
+
+            st.success("✅ Indexing complete!")
+
     except Exception as e:
+        st.error(f"Indexing error: {e}")
+        st.code(traceback.format_exc())
         print(f"[INDEX] ERROR: {e}")
-
-        with st.expander("🔍 Full Error Details"):
-            st.code(traceback.format_exc(), language="text")
+        traceback.print_exc()
```
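
The reworked Phase 3 above stores several named vectors per page: the full multi-vector (`initial`), a single mean over all visual tokens (`global_pooling`), and optionally a per-tile mean (`mean_pooling`). A toy NumPy sketch of those shapes follows; `tile_level_mean_pooling_sketch` is a simplified stand-in for `visual_rag.embedding.pooling.tile_level_mean_pooling`, assuming it averages each tile's 64 patch vectors:

```python
import numpy as np

# Toy dimensions: 3 tiles x 64 patches per tile, 128-dim patch embeddings.
num_tiles, patches_per_tile, dim = 3, 64, 128
emb_np = np.random.rand(num_tiles * patches_per_tile, dim).astype(np.float32)

# "initial": one vector per visual token (the full multi-vector).
initial = emb_np.tolist()

# "global_pooling": a single vector, the mean over all tokens (as in the diff).
global_pool = emb_np.mean(axis=0)


def tile_level_mean_pooling_sketch(
    emb: np.ndarray, num_tiles: int, patches_per_tile: int
) -> np.ndarray:
    """Simplified stand-in: one mean vector per tile of consecutive patches."""
    tiles = emb[: num_tiles * patches_per_tile].reshape(num_tiles, patches_per_tile, -1)
    return tiles.mean(axis=1)


mean_pooling = tile_level_mean_pooling_sketch(emb_np, num_tiles, patches_per_tile)
print(len(initial), global_pool.shape, mean_pooling.shape)  # 192 (128,) (3, 128)
```

Each pooled variant is uploaded as a separate named vector on the same Qdrant point, which is what lets the retrievers in `visual_rag.retrieval` pick a coarse or fine representation per stage.
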
demo/qdrant_utils.py
CHANGED

```diff
@@ -8,12 +8,19 @@ import streamlit as st
 
 
 def get_qdrant_credentials() -> Tuple[Optional[str], Optional[str]]:
- …
+    """Get Qdrant credentials from session state or environment variables.
+
+    Priority: session_state > QDRANT_URL/QDRANT_API_KEY > legacy env vars
+    """
+    url = (
+        st.session_state.get("qdrant_url_input")
+        or os.getenv("QDRANT_URL")
+        or os.getenv("SIGIR_QDRANT_URL")  # legacy
+    )
+    api_key = (
+        st.session_state.get("qdrant_key_input")
         or os.getenv("QDRANT_API_KEY")
+        or os.getenv("SIGIR_QDRANT_KEY")  # legacy
     )
     return url, api_key
 
```
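
The new `get_qdrant_credentials` resolves credentials as sidebar input, then `QDRANT_URL`/`QDRANT_API_KEY`, then the legacy `SIGIR_*` variables. The environment-variable part of that chain reduces to a plain `or`-fallback; a minimal sketch with placeholder values (no Streamlit session state involved):

```python
import os

# Placeholder values only; the variable names match the diff.
os.environ.pop("QDRANT_URL", None)
os.environ.pop("QDRANT_API_KEY", None)
os.environ["SIGIR_QDRANT_URL"] = "https://legacy-qdrant.example:6333"  # legacy name

# Same fallback order as the new helper, minus the session-state layer.
url = os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL")
api_key = os.getenv("QDRANT_API_KEY") or os.getenv("SIGIR_QDRANT_KEY")

print(url)      # https://legacy-qdrant.example:6333
print(api_key)  # None
```
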
demo/ui/playground.py
CHANGED

```diff
@@ -9,6 +9,7 @@ from demo.qdrant_utils import (
     sample_points_cached,
     search_collection,
 )
+from visual_rag.retrieval import MultiVectorRetriever
 
 
 def render_playground_tab():
@@ -46,7 +47,6 @@ def render_playground_tab():
     if not st.session_state.get("model_loaded"):
         with st.spinner(f"Loading {model_short}..."):
             try:
-                from visual_rag.retrieval import MultiVectorRetriever
                 _ = MultiVectorRetriever(collection_name=active_collection, model_name=model_name)
                 st.session_state["model_loaded"] = True
                 st.session_state["loaded_model_key"] = cache_key
```