visual-rag-toolkit 0.1.1 (visual_rag_toolkit-0.1.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
demo/ui/upload.py
ADDED
@@ -0,0 +1,487 @@
+"""Upload tab component."""
+
+import os
+import tempfile
+import time
+import traceback
+import json
+import inspect
+from datetime import datetime
+from pathlib import Path
+
+import streamlit as st
+
+from demo.config import AVAILABLE_MODELS
+from demo.qdrant_utils import (
+    get_qdrant_credentials,
+    get_collection_stats,
+    sample_points_cached,
+)
+
+
+VECTOR_TYPES = ["initial", "mean_pooling", "experimental_pooling", "global_pooling"]
+
+def _load_metadata_mapping_from_uploaded_json(uploaded_json_file) -> tuple[dict, str]:
+    """
+    Load a filename->metadata mapping from an uploaded JSON file.
+
+    Supported formats:
+    - Flat dict:
+      { "Some Report 2023": {"year": 2023, "source": "...", ...}, ... }
+    - Nested dict:
+      { "filenames": { "Some Report 2023": {...}, ... }, ... }
+
+    Keys are normalized to: lowercase, trimmed, without ".pdf".
+    """
+    if uploaded_json_file is None:
+        return {}, ""
+
+    try:
+        raw = uploaded_json_file.getvalue()
+        if not raw:
+            return {}, "Empty metadata file"
+        data = json.loads(raw.decode("utf-8"))
+        if not isinstance(data, dict):
+            return {}, "Metadata file must be a JSON object"
+
+        mapping = data.get("filenames") if isinstance(data.get("filenames"), dict) else data
+
+        # Drop non-mapping keys (common pattern: _description, _usage)
+        mapping = {k: v for k, v in mapping.items() if isinstance(k, str) and not k.startswith("_")}
+
+        normalized: dict[str, dict] = {}
+        bad = 0
+        for k, v in mapping.items():
+            if not isinstance(k, str) or not isinstance(v, dict):
+                bad += 1
+                continue
+            key = k.strip().lower()
+            if key.endswith(".pdf"):
+                key = key[:-4]
+            if not key:
+                bad += 1
+                continue
+            normalized[key] = v
+
+        msg = f"Loaded {len(normalized):,} filename metadata mappings"
+        if bad:
+            msg += f" (ignored {bad:,} non-mapping entries)"
+        return normalized, msg
+    except Exception as e:
+        return {}, f"Failed to parse metadata JSON: {str(e)[:120]}"
+
+
+def render_upload_tab():
+    if "upload_success" in st.session_state:
+        msg = st.session_state.pop("upload_success")
+        st.toast(f"✅ {msg}", icon="🎉")
+        st.balloons()
+
+    st.subheader("📤 PDF Upload & Processing")
+
+    col_upload, col_config = st.columns([3, 2])
+
+    with col_config:
+        st.markdown("##### Configuration")
+
+        c1, c2 = st.columns(2)
+        with c1:
+            model_name = st.selectbox("Model", AVAILABLE_MODELS, index=1, key="upload_model")
+        with c2:
+            collection_name = st.text_input("Collection", value="my_collection", key="upload_collection_input")
+
+        c3, c4 = st.columns(2)
+        with c3:
+            vector_dtype = st.selectbox("Vector Dtype", ["float16", "float32"], index=0, key="upload_dtype")
+        with c4:
+            use_cloudinary = st.toggle("Cloudinary", value=True, key="upload_cloudinary")
+
+        st.markdown("**Performance**")
+        p1, p2, p3 = st.columns(3)
+        with p1:
+            dpi = st.slider(
+                "PDF DPI",
+                min_value=90,
+                max_value=220,
+                value=int(st.session_state.get("upload_dpi", 140) or 140),
+                step=10,
+                key="upload_dpi",
+                help="Lower DPI is faster. 120–150 is a good default for PDFs.",
+            )
+        with p2:
+            embed_batch_size = st.slider(
+                "Embedding batch",
+                min_value=1,
+                max_value=32,
+                value=int(st.session_state.get("upload_embed_batch", 8) or 8),
+                step=1,
+                key="upload_embed_batch",
+                help="Higher = faster (until you hit GPU/VRAM limits).",
+            )
+        with p3:
+            upload_batch_size = st.slider(
+                "Upload batch",
+                min_value=1,
+                max_value=32,
+                value=int(st.session_state.get("upload_upload_batch", 8) or 8),
+                step=1,
+                key="upload_upload_batch",
+                help="How many pages to upsert to Qdrant per batch.",
+            )
+
+        vectors_to_index = st.multiselect(
+            "Vectors to Index",
+            VECTOR_TYPES,
+            default=VECTOR_TYPES,
+            key="upload_vectors",
+            help="Which vector types to store in Qdrant"
+        )
+
+        st.markdown("**Crop Settings**")
+        cc1, cc2 = st.columns(2)
+        with cc1:
+            crop_empty = st.toggle("Crop Margins", value=True, key="upload_crop")
+        with cc2:
+            # Use a slider (instead of free typing) to avoid locale confusion like "0,00".
+            # Threshold is std-dev of grayscale intensity (0..255). Smaller = stricter uniformity.
+            uniform_rowcol_std_threshold = 0.0
+            if crop_empty:
+                uniform_rowcol_std_threshold = st.select_slider(
+                    "Uniform row/col threshold (any color)",
+                    options=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 12.0, 16.0],
+                    value=float(st.session_state.get("upload_uniform_rowcol_std_threshold", 0.0) or 0.0),
+                    key="upload_uniform_rowcol_std_threshold",
+                    help=(
+                        "0 = off (default). Higher values remove more uniform borders, even if they are gray/black. "
+                        "Rule: we skip a scanned row/col if std(pixels) ≤ threshold.\n\n"
+                        "Suggested:\n"
+                        "- 1–2: clean solid borders\n"
+                        "- 3–5: light scanner shading\n"
+                        "- 8+: aggressive (may remove faint content)"
+                    ),
+                )
+
+        if crop_empty:
+            crop_pct = st.slider("Crop %", 0.90, 0.99, 0.99, 0.01, key="upload_crop_pct",
+                                 help="Remove margins with this % empty space")
+        else:
+            crop_pct = 0.99
+
+        st.markdown("**File Metadata (optional)**")
+        meta_file = st.file_uploader(
+            "Metadata mapping (JSON)",
+            type=["json"],
+            key="upload_metadata_json",
+            help=(
+                "Optional JSON file that maps PDF filenames to extra metadata fields "
+                "(e.g., year/source/district). Supported formats match `filename_metadata.json` "
+                "and `metadata_mapping.json`."
+            ),
+        )
+        metadata_mapping, meta_msg = _load_metadata_mapping_from_uploaded_json(meta_file)
+        if meta_file:
+            if metadata_mapping:
+                st.success(meta_msg)
+                # Show a tiny preview without overwhelming the UI
+                with st.expander("Preview (first 3 entries)", expanded=False):
+                    preview_items = list(metadata_mapping.items())[:3]
+                    st.json({k: v for k, v in preview_items})
+            else:
+                st.warning(meta_msg or "No mappings loaded")
+        else:
+            metadata_mapping = {}
+
+    with col_upload:
+        uploaded_files = st.file_uploader(
+            "Select PDF files",
+            type=["pdf"],
+            accept_multiple_files=True,
+            key="pdf_uploader",
+        )
+
+        if uploaded_files:
+            st.success(f"**{len(uploaded_files)} file(s) selected**")
+
+            if st.button("🚀 Process PDFs", type="primary", key="process_btn"):
+                config = {
+                    "model_name": model_name,
+                    "collection_name": collection_name,
+                    "vector_dtype": vector_dtype,
+                    "vectors_to_index": vectors_to_index,
+                    "crop_empty": crop_empty,
+                    "crop_pct": crop_pct,
+                    "uniform_rowcol_std_threshold": float(uniform_rowcol_std_threshold or 0.0),
+                    "use_cloudinary": use_cloudinary,
+                    "metadata_mapping": metadata_mapping,
+                    "dpi": int(dpi),
+                    "embed_batch_size": int(embed_batch_size),
+                    "upload_batch_size": int(upload_batch_size),
+                }
+                process_pdfs(uploaded_files, config)
+
+    if st.session_state.get("last_upload_result"):
+        st.divider()
+        render_upload_results()
+
+
+def process_pdfs(uploaded_files, config):
+    model_name = config["model_name"]
+    collection_name = config["collection_name"]
+    vector_dtype = config["vector_dtype"]
+    crop_empty = config["crop_empty"]
+    crop_pct = config["crop_pct"]
+    uniform_rowcol_std_threshold = float(config.get("uniform_rowcol_std_threshold") or 0.0)
+    use_cloudinary = config["use_cloudinary"]
+    metadata_mapping = config.get("metadata_mapping") or {}
+    dpi = int(config.get("dpi") or 140)
+    embed_batch_size = int(config.get("embed_batch_size") or 8)
+    upload_batch_size = int(config.get("upload_batch_size") or 8)
+
+    st.divider()
+
+    phase1 = st.container()
+    phase2 = st.container()
+    phase3 = st.container()
+    results_container = st.container()
+
+    try:
+        with phase1:
+            st.markdown("##### 🤖 Phase 1: Loading Model")
+            model_status = st.empty()
+            model_short = model_name.split("/")[-1]
+            model_status.info(f"Loading `{model_short}`...")
+
+            import numpy as np
+            from visual_rag import VisualEmbedder
+            from visual_rag.indexing import QdrantIndexer, CloudinaryUploader, ProcessingPipeline
+
+            output_dtype = np.float16 if vector_dtype == "float16" else np.float32
+            embedder_key = f"{model_name}::{vector_dtype}"
+            embedder = None
+            if st.session_state.get("upload_embedder_key") == embedder_key:
+                embedder = st.session_state.get("upload_embedder")
+            if embedder is None:
+                embedder = VisualEmbedder(model_name=model_name, output_dtype=output_dtype)
+                embedder._load_model()
+                st.session_state["upload_embedder_key"] = embedder_key
+                st.session_state["upload_embedder"] = embedder
+            model_status.success(f"✅ Model `{model_short}` loaded ({vector_dtype})")
+
+        with phase2:
+            st.markdown("##### 📦 Phase 2: Setting Up Collection")
+
+            url, api_key = get_qdrant_credentials()
+            if not url or not api_key:
+                st.error("Qdrant credentials not configured")
+                return
+
+            qdrant_status = st.empty()
+            qdrant_status.info(f"Connecting to Qdrant...")
+
+            indexer = QdrantIndexer(
+                url=url,
+                api_key=api_key,
+                collection_name=collection_name,
+                prefer_grpc=False,
+                vector_datatype=vector_dtype,
+                timeout=180,
+            )
+            qdrant_status.success(f"✅ Connected to Qdrant")
+
+            coll_status = st.empty()
+            collection_exists = False
+            try:
+                collection_exists = indexer.collection_exists()
+            except Exception:
+                pass
+
+            if collection_exists:
+                coll_status.success(f"✅ Collection `{collection_name}` exists (will append)")
+            else:
+                coll_status.info(f"Creating collection `{collection_name}`...")
+                for attempt in range(3):
+                    try:
+                        indexer.create_collection(force_recreate=False)
+                        break
+                    except Exception as e:
+                        if attempt < 2:
+                            time.sleep(2)
+                        else:
+                            raise
+                coll_status.success(f"✅ Collection `{collection_name}` created")
+
+            idx_status = st.empty()
+            idx_status.info("Setting up indexes...")
+            try:
+                indexer.create_payload_indexes(fields=[
+                    {"field": "filename", "type": "keyword"},
+                    {"field": "page_number", "type": "integer"},
+                ])
+            except Exception:
+                pass
+            idx_status.success("✅ Indexes ready")
+
+            cloud_status = st.empty()
+            cloudinary_uploader = None
+            if use_cloudinary:
+                cloud_status.info("Connecting to Cloudinary...")
+                try:
+                    cloudinary_uploader = CloudinaryUploader()
+                    cloud_status.success("✅ Cloudinary ready")
+                except Exception as e:
+                    cloud_status.warning(f"⚠️ Cloudinary unavailable: {str(e)[:30]}")
+            else:
+                cloud_status.info("☁️ Cloudinary disabled")
+
+            pipeline = ProcessingPipeline(
+                embedder=embedder, indexer=indexer, cloudinary_uploader=cloudinary_uploader,
+                metadata_mapping=metadata_mapping,
+                config={
+                    "processing": {"dpi": dpi},
+                    "batching": {
+                        "embedding_batch_size": embed_batch_size,
+                        "upload_batch_size": upload_batch_size,
+                    },
+                },
+                crop_empty=crop_empty, crop_empty_percentage_to_remove=crop_pct,
+                **({
+                    "crop_empty_uniform_rowcol_std_threshold": uniform_rowcol_std_threshold
+                } if "crop_empty_uniform_rowcol_std_threshold" in inspect.signature(ProcessingPipeline.__init__).parameters else {}),
+            )
+
+        with phase3:
+            st.markdown("##### 📄 Phase 3: Processing PDFs")
+
+            overall_progress = st.progress(0.0)
+            file_status = st.empty()
+            log_area = st.empty()
+            log_lines = []
+
+            total_uploaded, total_skipped, total_failed = 0, 0, 0
+            file_results = []
+
+            page_status = st.empty()
+
+            for i, f in enumerate(uploaded_files):
+                original_filename = f.name
+                file_status.info(f"📄 Processing `{original_filename}` ({i+1}/{len(uploaded_files)})")
+                t0 = time.perf_counter()
+
+                with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
+                    tmp.write(f.getvalue())
+                    tmp_path = Path(tmp.name)
+
+                def progress_cb(stage, current, total, message):
+                    if stage == "process" and total > 0:
+                        page_status.caption(f" └─ Page {current}/{total}")
+                    elif stage == "embed" and total > 0:
+                        # Never show internal function names; keep this human-friendly.
+                        page_status.caption(f" └─ Embedding pages… ({current+1}-{min(current + 1 + (pipeline.embedding_batch_size - 1), total)}/{total})")
+                    elif stage == "convert" and total > 0:
+                        page_status.caption(f" └─ {total} pages found")
+
+                try:
+                    result = pipeline.process_pdf(
+                        tmp_path,
+                        original_filename=original_filename,
+                        progress_callback=progress_cb,
+                    )
+                    elapsed_s = time.perf_counter() - t0
+                    uploaded = result.get("uploaded", 0)
+                    skipped = result.get("skipped", 0)
+                    total_uploaded += uploaded
+                    total_skipped += skipped
+                    total_pages = int(result.get("total_pages") or 0)
+                    sec_per_page = (elapsed_s / total_pages) if total_pages > 0 else None
+                    file_results.append({
+                        "file": original_filename,
+                        "uploaded": uploaded,
+                        "skipped": skipped,
+                        "total_pages": total_pages,
+                        "elapsed_s": float(elapsed_s),
+                        "sec_per_page": float(sec_per_page) if sec_per_page is not None else None,
+                    })
+                    timing_str = f"{elapsed_s:.1f}s" + (f" ({sec_per_page:.2f}s/page)" if sec_per_page is not None else "")
+                    log_lines.append(f"✓ {original_filename}: {uploaded} uploaded, {skipped} skipped | {timing_str}")
+                except Exception as e:
+                    total_failed += 1
+                    log_lines.append(f"✗ {original_filename}: {str(e)[:50]}")
+                finally:
+                    os.unlink(tmp_path)
+
+                page_status.empty()
+                overall_progress.progress((i + 1) / len(uploaded_files))
+                log_area.code("\n".join(log_lines[-10:]), language="text")
+
+            overall_progress.progress(1.0)
+            file_status.success(f"✅ Processed {len(uploaded_files)} file(s)")
+
+        with results_container:
+            st.markdown("##### 📊 Results")
+
+            st.success(f"✅ **{total_uploaded} pages** uploaded to `{collection_name}`" +
+                       (f" ({total_skipped} skipped)" if total_skipped else "") +
+                       (f" ({total_failed} failed)" if total_failed else ""))
+
+            if file_results:
+                with st.expander("📋 File Details", expanded=False):
+                    for fr in file_results:
+                        timing = ""
+                        if fr.get("elapsed_s") is not None:
+                            timing = f" | {fr['elapsed_s']:.1f}s"
+                            if fr.get("sec_per_page") is not None:
+                                timing += f" ({fr['sec_per_page']:.2f}s/page)"
+                        st.text(
+                            f"• {fr['file']}: {fr['uploaded']} uploaded"
+                            + (f", {fr['skipped']} skipped" if fr.get("skipped") else "")
+                            + (f", {fr['total_pages']} pages" if fr.get("total_pages") else "")
+                            + timing
+                        )
+
+            st.session_state["last_upload_result"] = {
+                "total_uploaded": total_uploaded, "total_skipped": total_skipped, "total_failed": total_failed,
+                "file_results": file_results, "collection": collection_name,
+            }
+
+            get_collection_stats.clear()
+            sample_points_cached.clear()
+
+            if total_uploaded > 0:
+                st.session_state["upload_success"] = f"Uploaded {total_uploaded} pages to {collection_name}"
+                st.balloons()
+
+    except Exception as e:
+        st.error(f"❌ Processing error: {e}")
+        with st.expander("Traceback"):
+            st.code(traceback.format_exc())
+
+
+def render_upload_results():
+    result = st.session_state.get("last_upload_result", {})
+    if not result:
+        return
+
+    uploaded = result.get("total_uploaded", 0)
+    skipped = result.get("total_skipped", 0)
+    failed = result.get("total_failed", 0)
+    collection = result.get("collection", "")
+    file_results = result.get("file_results", [])
+
+    st.success(f"✅ **{uploaded} pages** uploaded to `{collection}`" +
+               (f" ({skipped} skipped)" if skipped else "") +
+               (f" ({failed} failed)" if failed else ""))
+
+    if file_results:
+        with st.expander("📋 Details", expanded=False):
+            for fr in file_results:
+                timing = ""
+                if fr.get("elapsed_s") is not None:
+                    timing = f" | {fr['elapsed_s']:.1f}s"
+                    if fr.get("sec_per_page") is not None:
+                        timing += f" ({fr['sec_per_page']:.2f}s/page)"
+                st.text(
+                    f"• {fr['file']}: {fr['uploaded']} uploaded"
+                    + (f", {fr['skipped']} skipped" if fr.get("skipped") else "")
+                    + (f", {fr['total_pages']} pages" if fr.get("total_pages") else "")
+                    + timing
+                )
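Editor's note: the docstring of _load_metadata_mapping_from_uploaded_json describes the two accepted mapping shapes but gives no end-to-end sample. The sketch below writes a nested-format file the loader would accept; the filenames are invented, and the field names (year/source/district) come from the uploader's own help text.

import json

# Hypothetical mapping file in the nested format. In this layout only the
# "filenames" object is read; in the flat layout, keys starting with "_"
# (e.g. "_description") are dropped by the loader.
example = {
    "_description": "Extra metadata fields, keyed by PDF filename",
    "filenames": {
        "Annual Report 2023.PDF": {"year": 2023, "source": "finance"},
        "  District Overview ": {"year": 2021, "district": "North"},
    },
}

with open("metadata_mapping.json", "w", encoding="utf-8") as fh:
    json.dump(example, fh, indent=2)

# After the loader normalizes keys (trim, lowercase, strip ".pdf"), these two
# entries are stored under "annual report 2023" and "district overview".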
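Editor's note: the "Uniform row/col threshold" help compresses the border-trimming rule into one line: a scanned row or column counts as uniform border when the standard deviation of its pixel intensities is at or below the threshold. Below is a minimal numpy sketch of that stated rule for rows only, under our own function name; it is not the package's visual_rag/preprocessing/crop_empty.py implementation, which also applies the separate crop_pct empty-space criterion.

import numpy as np

def trim_uniform_border_rows(gray: np.ndarray, std_threshold: float) -> np.ndarray:
    """Drop leading/trailing rows whose intensity std is <= std_threshold.

    gray: (H, W) uint8 grayscale page; 0 disables trimming, matching the UI
    default. Columns would be handled the same way on gray.T.
    """
    if std_threshold <= 0:
        return gray
    row_std = gray.astype(np.float32).std(axis=1)   # per-row intensity spread
    keep = np.flatnonzero(row_std > std_threshold)  # rows with real content
    if keep.size == 0:
        return gray                                 # never trim the whole page
    return gray[keep[0] : keep[-1] + 1]

page = np.full((100, 80), 255, dtype=np.uint8)  # all-white page
page[10:90, ::2] = 0                            # content band: alternating stripes
assert trim_uniform_border_rows(page, 2.0).shape == (80, 80)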
visual_rag/__init__.py
ADDED
@@ -0,0 +1,98 @@
+"""
+Visual RAG Toolkit - End-to-end visual document retrieval with two-stage pooling.
+
+A modular toolkit for building visual document retrieval systems:
+
+Components:
+-----------
+- embedding: Visual and text embedding generation (ColPali, etc.)
+- indexing: PDF processing, Qdrant indexing, Cloudinary uploads
+- retrieval: Single-stage and two-stage retrieval with MaxSim
+- visualization: Saliency maps and attention visualization
+- cli: Command-line interface
+
+Quick Start:
+------------
+>>> from visual_rag import VisualEmbedder, PDFProcessor, TwoStageRetriever
+>>>
+>>> # Process PDFs
+>>> processor = PDFProcessor(dpi=140)
+>>> images, texts = processor.process_pdf("report.pdf")
+>>>
+>>> # Generate embeddings
+>>> embedder = VisualEmbedder()
+>>> embeddings = embedder.embed_images(images)
+>>> query_emb = embedder.embed_query("What is the budget?")
+>>>
+>>> # Search with two-stage retrieval
+>>> retriever = TwoStageRetriever(qdrant_client, "my_collection")
+>>> results = retriever.search(query_emb, top_k=10)
+
+Each component works independently - use only what you need.
+"""
+
+__version__ = "0.1.0"
+
+# Import main classes at package level for convenience
+# These are optional - if dependencies aren't installed, we catch the error
+
+try:
+    from visual_rag.embedding.visual_embedder import VisualEmbedder
+except ImportError:
+    VisualEmbedder = None
+
+try:
+    from visual_rag.indexing.pdf_processor import PDFProcessor
+except ImportError:
+    PDFProcessor = None
+
+try:
+    from visual_rag.indexing.qdrant_indexer import QdrantIndexer
+except ImportError:
+    QdrantIndexer = None
+
+try:
+    from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
+except ImportError:
+    CloudinaryUploader = None
+
+try:
+    from visual_rag.retrieval.two_stage import TwoStageRetriever
+except ImportError:
+    TwoStageRetriever = None
+
+try:
+    from visual_rag.retrieval.multi_vector import MultiVectorRetriever
+except ImportError:
+    MultiVectorRetriever = None
+
+try:
+    from visual_rag.qdrant_admin import QdrantAdmin
+except ImportError:
+    QdrantAdmin = None
+
+try:
+    from visual_rag.demo_runner import demo
+except ImportError:
+    demo = None
+
+# Config utilities (always available)
+from visual_rag.config import get, get_section, load_config
+
+__all__ = [
+    # Version
+    "__version__",
+    # Main classes
+    "VisualEmbedder",
+    "PDFProcessor",
+    "QdrantIndexer",
+    "CloudinaryUploader",
+    "TwoStageRetriever",
+    "MultiVectorRetriever",
+    "QdrantAdmin",
+    "demo",
+    # Config utilities
+    "load_config",
+    "get",
+    "get_section",
+]
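Editor's note: because each optional import above degrades the name to None instead of raising, downstream code has to check before instantiating. A minimal sketch of the guard a caller might add; the error message is our own illustration, not part of the package.

import visual_rag

# VisualEmbedder is None when its dependencies failed to import.
if visual_rag.VisualEmbedder is None:
    raise ImportError(
        "visual_rag.VisualEmbedder is unavailable; install the embedding "
        "dependencies before creating an embedder."
    )

embedder = visual_rag.VisualEmbedder()  # safe after the guard above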
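Editor's note: the components list mentions retrieval "with MaxSim" without spelling out the score. The standard late-interaction MaxSim used by ColPali-style retrievers sums, over query token vectors, the maximum dot product against the document's patch vectors. A minimal numpy sketch of just that scoring rule, with invented names; the package's own implementations live in visual_rag/retrieval/.

import numpy as np

def maxsim_score(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    """MaxSim(Q, D) = sum_i max_j <q_i, d_j>.

    query_emb: (num_query_tokens, dim); doc_emb: (num_patches, dim).
    """
    sim = query_emb @ doc_emb.T          # (tokens, patches) similarity matrix
    return float(sim.max(axis=1).sum())  # best patch per query token, summed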
visual_rag/cli/__init__.py
ADDED
@@ -0,0 +1 @@
+"""CLI entry point for visual-rag-toolkit."""