visual-rag-toolkit 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
demo/ui/upload.py ADDED
@@ -0,0 +1,487 @@
1
+ """Upload tab component."""
2
+
3
+ import os
4
+ import tempfile
5
+ import time
6
+ import traceback
7
+ import json
8
+ import inspect
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+
12
+ import streamlit as st
13
+
14
+ from demo.config import AVAILABLE_MODELS
15
+ from demo.qdrant_utils import (
16
+ get_qdrant_credentials,
17
+ get_collection_stats,
18
+ sample_points_cached,
19
+ )
20
+
21
+
22
+ VECTOR_TYPES = ["initial", "mean_pooling", "experimental_pooling", "global_pooling"]
23
+
24
def _load_metadata_mapping_from_uploaded_json(uploaded_json_file) -> tuple[dict, str]:
    """
    Load a filename->metadata mapping from an uploaded JSON file.

    Supported formats:
      - Flat dict:
          { "Some Report 2023": {"year": 2023, "source": "...", ...}, ... }
      - Nested dict:
          { "filenames": { "Some Report 2023": {...}, ... }, ... }

    Keys are normalized to: lowercase, trimmed, without ".pdf".

    Args:
        uploaded_json_file: Streamlit ``UploadedFile`` (or any object exposing
            ``getvalue() -> bytes``), or ``None`` when no file was uploaded.

    Returns:
        ``(mapping, message)`` where ``mapping`` maps the normalized filename
        to its metadata dict and ``message`` is a human-readable status or
        error string ("" when no file was provided).
    """
    if uploaded_json_file is None:
        return {}, ""

    try:
        raw = uploaded_json_file.getvalue()
        if not raw:
            return {}, "Empty metadata file"
        data = json.loads(raw.decode("utf-8"))
        if not isinstance(data, dict):
            return {}, "Metadata file must be a JSON object"

        # Prefer the nested "filenames" sub-object when present; otherwise
        # treat the whole top-level object as the mapping.
        mapping = data.get("filenames") if isinstance(data.get("filenames"), dict) else data

        # Drop non-mapping keys (common pattern: _description, _usage)
        mapping = {k: v for k, v in mapping.items() if isinstance(k, str) and not k.startswith("_")}

        normalized: dict[str, dict] = {}
        bad = 0
        for k, v in mapping.items():
            # Each value must itself be a dict of metadata fields.
            if not isinstance(v, dict):
                bad += 1
                continue
            # Normalize: trim, lowercase, strip a trailing ".pdf".
            key = k.strip().lower().removesuffix(".pdf")
            if not key:
                # Key collapsed to nothing (e.g. "  .pdf  ") — unusable.
                bad += 1
                continue
            normalized[key] = v

        msg = f"Loaded {len(normalized):,} filename metadata mappings"
        if bad:
            msg += f" (ignored {bad:,} non-mapping entries)"
        return normalized, msg
    except Exception as e:
        # Defensive catch-all: any decode/parse error becomes a short,
        # user-facing message instead of crashing the upload tab.
        return {}, f"Failed to parse metadata JSON: {str(e)[:120]}"
72
+
73
+
74
def render_upload_tab():
    """Render the "📤 PDF Upload & Processing" tab.

    Layout is a two-column row: the right column gathers all processing
    configuration (model, collection, dtype, batch sizes, crop settings and
    an optional filename->metadata JSON), the left column hosts the PDF
    picker and the "Process" button.  Pressing the button bundles the chosen
    options into a plain dict and hands off to ``process_pdfs``.  The summary
    of the previous run (if any) is re-rendered at the bottom.
    """
    # One-shot success toast: process_pdfs stores "upload_success" before the
    # Streamlit rerun; pop() guarantees it is shown exactly once.
    if "upload_success" in st.session_state:
        msg = st.session_state.pop("upload_success")
        st.toast(f"✅ {msg}", icon="🎉")
        st.balloons()

    st.subheader("📤 PDF Upload & Processing")

    col_upload, col_config = st.columns([3, 2])

    with col_config:
        st.markdown("##### Configuration")

        c1, c2 = st.columns(2)
        with c1:
            model_name = st.selectbox("Model", AVAILABLE_MODELS, index=1, key="upload_model")
        with c2:
            collection_name = st.text_input("Collection", value="my_collection", key="upload_collection_input")

        c3, c4 = st.columns(2)
        with c3:
            vector_dtype = st.selectbox("Vector Dtype", ["float16", "float32"], index=0, key="upload_dtype")
        with c4:
            use_cloudinary = st.toggle("Cloudinary", value=True, key="upload_cloudinary")

        st.markdown("**Performance**")
        # NOTE(review): these sliders pass both value= (seeded from session
        # state) and key=; Streamlit can warn when the key already exists in
        # session state — confirm this pattern is intentional.
        p1, p2, p3 = st.columns(3)
        with p1:
            dpi = st.slider(
                "PDF DPI",
                min_value=90,
                max_value=220,
                value=int(st.session_state.get("upload_dpi", 140) or 140),
                step=10,
                key="upload_dpi",
                help="Lower DPI is faster. 120–150 is a good default for PDFs.",
            )
        with p2:
            embed_batch_size = st.slider(
                "Embedding batch",
                min_value=1,
                max_value=32,
                value=int(st.session_state.get("upload_embed_batch", 8) or 8),
                step=1,
                key="upload_embed_batch",
                help="Higher = faster (until you hit GPU/VRAM limits).",
            )
        with p3:
            upload_batch_size = st.slider(
                "Upload batch",
                min_value=1,
                max_value=32,
                value=int(st.session_state.get("upload_upload_batch", 8) or 8),
                step=1,
                key="upload_upload_batch",
                help="How many pages to upsert to Qdrant per batch.",
            )

        vectors_to_index = st.multiselect(
            "Vectors to Index",
            VECTOR_TYPES,
            default=VECTOR_TYPES,
            key="upload_vectors",
            help="Which vector types to store in Qdrant"
        )

        st.markdown("**Crop Settings**")
        cc1, cc2 = st.columns(2)
        with cc1:
            crop_empty = st.toggle("Crop Margins", value=True, key="upload_crop")
        with cc2:
            # Use a slider (instead of free typing) to avoid locale confusion like "0,00".
            # Threshold is std-dev of grayscale intensity (0..255). Smaller = stricter uniformity.
            uniform_rowcol_std_threshold = 0.0
            if crop_empty:
                uniform_rowcol_std_threshold = st.select_slider(
                    "Uniform row/col threshold (any color)",
                    options=[0.0, 1.0, 2.0, 3.0, 5.0, 8.0, 12.0, 16.0],
                    value=float(st.session_state.get("upload_uniform_rowcol_std_threshold", 0.0) or 0.0),
                    key="upload_uniform_rowcol_std_threshold",
                    help=(
                        "0 = off (default). Higher values remove more uniform borders, even if they are gray/black. "
                        "Rule: we skip a scanned row/col if std(pixels) ≤ threshold.\n\n"
                        "Suggested:\n"
                        "- 1–2: clean solid borders\n"
                        "- 3–5: light scanner shading\n"
                        "- 8+: aggressive (may remove faint content)"
                    ),
                )

        # Crop percentage only matters when cropping is on; 0.99 is the
        # pass-through default used otherwise.
        if crop_empty:
            crop_pct = st.slider("Crop %", 0.90, 0.99, 0.99, 0.01, key="upload_crop_pct",
                                 help="Remove margins with this % empty space")
        else:
            crop_pct = 0.99

        st.markdown("**File Metadata (optional)**")
        meta_file = st.file_uploader(
            "Metadata mapping (JSON)",
            type=["json"],
            key="upload_metadata_json",
            help=(
                "Optional JSON file that maps PDF filenames to extra metadata fields "
                "(e.g., year/source/district). Supported formats match `filename_metadata.json` "
                "and `metadata_mapping.json`."
            ),
        )
        metadata_mapping, meta_msg = _load_metadata_mapping_from_uploaded_json(meta_file)
        if meta_file:
            if metadata_mapping:
                st.success(meta_msg)
                # Show a tiny preview without overwhelming the UI
                with st.expander("Preview (first 3 entries)", expanded=False):
                    preview_items = list(metadata_mapping.items())[:3]
                    st.json({k: v for k, v in preview_items})
            else:
                st.warning(meta_msg or "No mappings loaded")
        else:
            # Defensive reset — the loader already returns {} for None, but
            # this keeps the invariant explicit.
            metadata_mapping = {}

    with col_upload:
        uploaded_files = st.file_uploader(
            "Select PDF files",
            type=["pdf"],
            accept_multiple_files=True,
            key="pdf_uploader",
        )

        if uploaded_files:
            st.success(f"**{len(uploaded_files)} file(s) selected**")

            if st.button("🚀 Process PDFs", type="primary", key="process_btn"):
                # Snapshot every widget value into a plain dict so the
                # processing function is decoupled from the widgets.
                config = {
                    "model_name": model_name,
                    "collection_name": collection_name,
                    "vector_dtype": vector_dtype,
                    "vectors_to_index": vectors_to_index,
                    "crop_empty": crop_empty,
                    "crop_pct": crop_pct,
                    "uniform_rowcol_std_threshold": float(uniform_rowcol_std_threshold or 0.0),
                    "use_cloudinary": use_cloudinary,
                    "metadata_mapping": metadata_mapping,
                    "dpi": int(dpi),
                    "embed_batch_size": int(embed_batch_size),
                    "upload_batch_size": int(upload_batch_size),
                }
                process_pdfs(uploaded_files, config)

    # Persisted summary from the last completed run (survives reruns).
    if st.session_state.get("last_upload_result"):
        st.divider()
        render_upload_results()
225
+
226
+
227
def process_pdfs(uploaded_files, config):
    """Run the full indexing pipeline for a batch of uploaded PDFs.

    Renders three UI phases:
      1. Load (or reuse from session state) the visual embedding model.
      2. Connect to Qdrant, ensure the target collection and payload indexes
         exist, and optionally connect to Cloudinary.
      3. Feed each PDF through the ProcessingPipeline with per-page progress.

    A results summary is rendered and snapshotted into
    ``st.session_state["last_upload_result"]`` so it survives reruns.

    Args:
        uploaded_files: list of Streamlit ``UploadedFile`` objects (PDFs).
        config: plain dict built by ``render_upload_tab`` (model_name,
            collection_name, vector_dtype, crop settings, use_cloudinary,
            metadata_mapping, dpi, embed/upload batch sizes, ...).
    """
    model_name = config["model_name"]
    collection_name = config["collection_name"]
    vector_dtype = config["vector_dtype"]
    crop_empty = config["crop_empty"]
    crop_pct = config["crop_pct"]
    uniform_rowcol_std_threshold = float(config.get("uniform_rowcol_std_threshold") or 0.0)
    use_cloudinary = config["use_cloudinary"]
    metadata_mapping = config.get("metadata_mapping") or {}
    dpi = int(config.get("dpi") or 140)
    embed_batch_size = int(config.get("embed_batch_size") or 8)
    upload_batch_size = int(config.get("upload_batch_size") or 8)
    # NOTE(review): config["vectors_to_index"] is never read here — presumably
    # the pipeline decides which vectors to store; confirm against
    # ProcessingPipeline.

    st.divider()

    # Pre-create containers so each phase renders in a stable position.
    phase1 = st.container()
    phase2 = st.container()
    phase3 = st.container()
    results_container = st.container()

    try:
        with phase1:
            st.markdown("##### 🤖 Phase 1: Loading Model")
            model_status = st.empty()
            model_short = model_name.split("/")[-1]
            model_status.info(f"Loading `{model_short}`...")

            # Heavy imports are deferred so the tab renders fast when idle.
            import numpy as np
            from visual_rag import VisualEmbedder
            from visual_rag.indexing import QdrantIndexer, CloudinaryUploader, ProcessingPipeline

            output_dtype = np.float16 if vector_dtype == "float16" else np.float32
            # Cache the embedder in session state keyed by (model, dtype) so
            # repeated runs with the same settings skip the reload.
            embedder_key = f"{model_name}::{vector_dtype}"
            embedder = None
            if st.session_state.get("upload_embedder_key") == embedder_key:
                embedder = st.session_state.get("upload_embedder")
            if embedder is None:
                embedder = VisualEmbedder(model_name=model_name, output_dtype=output_dtype)
                embedder._load_model()
                st.session_state["upload_embedder_key"] = embedder_key
                st.session_state["upload_embedder"] = embedder
            model_status.success(f"✅ Model `{model_short}` loaded ({vector_dtype})")

        with phase2:
            st.markdown("##### 📦 Phase 2: Setting Up Collection")

            url, api_key = get_qdrant_credentials()
            if not url or not api_key:
                st.error("Qdrant credentials not configured")
                return

            qdrant_status = st.empty()
            qdrant_status.info(f"Connecting to Qdrant...")

            indexer = QdrantIndexer(
                url=url,
                api_key=api_key,
                collection_name=collection_name,
                prefer_grpc=False,
                vector_datatype=vector_dtype,
                timeout=180,
            )
            qdrant_status.success(f"✅ Connected to Qdrant")

            coll_status = st.empty()
            # Best-effort existence check — any error falls through to the
            # create path below.
            collection_exists = False
            try:
                collection_exists = indexer.collection_exists()
            except Exception:
                pass

            if collection_exists:
                coll_status.success(f"✅ Collection `{collection_name}` exists (will append)")
            else:
                coll_status.info(f"Creating collection `{collection_name}`...")
                # Retry up to 3 times with a short backoff; the final failure
                # propagates to the outer handler.
                for attempt in range(3):
                    try:
                        indexer.create_collection(force_recreate=False)
                        break
                    except Exception as e:
                        if attempt < 2:
                            time.sleep(2)
                        else:
                            raise
                coll_status.success(f"✅ Collection `{collection_name}` created")

            idx_status = st.empty()
            idx_status.info("Setting up indexes...")
            # Payload indexes are a nice-to-have; swallow failures.
            try:
                indexer.create_payload_indexes(fields=[
                    {"field": "filename", "type": "keyword"},
                    {"field": "page_number", "type": "integer"},
                ])
            except Exception:
                pass
            idx_status.success("✅ Indexes ready")

            cloud_status = st.empty()
            cloudinary_uploader = None
            if use_cloudinary:
                cloud_status.info("Connecting to Cloudinary...")
                try:
                    cloudinary_uploader = CloudinaryUploader()
                    cloud_status.success("✅ Cloudinary ready")
                except Exception as e:
                    # Degrade gracefully: indexing proceeds without image CDN.
                    cloud_status.warning(f"⚠️ Cloudinary unavailable: {str(e)[:30]}")
            else:
                cloud_status.info("☁️ Cloudinary disabled")

            pipeline = ProcessingPipeline(
                embedder=embedder, indexer=indexer, cloudinary_uploader=cloudinary_uploader,
                metadata_mapping=metadata_mapping,
                config={
                    "processing": {"dpi": dpi},
                    "batching": {
                        "embedding_batch_size": embed_batch_size,
                        "upload_batch_size": upload_batch_size,
                    },
                },
                crop_empty=crop_empty, crop_empty_percentage_to_remove=crop_pct,
                # Forward-compat: only pass the threshold kwarg when the
                # installed ProcessingPipeline version accepts it.
                **({
                    "crop_empty_uniform_rowcol_std_threshold": uniform_rowcol_std_threshold
                } if "crop_empty_uniform_rowcol_std_threshold" in inspect.signature(ProcessingPipeline.__init__).parameters else {}),
            )

        with phase3:
            st.markdown("##### 📄 Phase 3: Processing PDFs")

            overall_progress = st.progress(0.0)
            file_status = st.empty()
            log_area = st.empty()
            log_lines = []

            total_uploaded, total_skipped, total_failed = 0, 0, 0
            file_results = []

            page_status = st.empty()

            for i, f in enumerate(uploaded_files):
                original_filename = f.name
                file_status.info(f"📄 Processing `{original_filename}` ({i+1}/{len(uploaded_files)})")
                t0 = time.perf_counter()

                # The pipeline needs a real file on disk; delete=False so the
                # path stays valid after the context closes (removed in the
                # finally below).
                with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                    tmp.write(f.getvalue())
                    tmp_path = Path(tmp.name)

                def progress_cb(stage, current, total, message):
                    # Per-stage caption updates; `message` is intentionally
                    # ignored so internal details never leak into the UI.
                    if stage == "process" and total > 0:
                        page_status.caption(f" └─ Page {current}/{total}")
                    elif stage == "embed" and total > 0:
                        # Never show internal function names; keep this human-friendly.
                        page_status.caption(f" └─ Embedding pages… ({current+1}-{min(current + 1 + (pipeline.embedding_batch_size - 1), total)}/{total})")
                    elif stage == "convert" and total > 0:
                        page_status.caption(f" └─ {total} pages found")

                try:
                    result = pipeline.process_pdf(
                        tmp_path,
                        original_filename=original_filename,
                        progress_callback=progress_cb,
                    )
                    elapsed_s = time.perf_counter() - t0
                    uploaded = result.get("uploaded", 0)
                    skipped = result.get("skipped", 0)
                    total_uploaded += uploaded
                    total_skipped += skipped
                    total_pages = int(result.get("total_pages") or 0)
                    sec_per_page = (elapsed_s / total_pages) if total_pages > 0 else None
                    file_results.append({
                        "file": original_filename,
                        "uploaded": uploaded,
                        "skipped": skipped,
                        "total_pages": total_pages,
                        "elapsed_s": float(elapsed_s),
                        "sec_per_page": float(sec_per_page) if sec_per_page is not None else None,
                    })
                    timing_str = f"{elapsed_s:.1f}s" + (f" ({sec_per_page:.2f}s/page)" if sec_per_page is not None else "")
                    log_lines.append(f"✓ {original_filename}: {uploaded} uploaded, {skipped} skipped | {timing_str}")
                except Exception as e:
                    # Per-file failure: count it and keep processing the rest.
                    total_failed += 1
                    log_lines.append(f"✗ {original_filename}: {str(e)[:50]}")
                finally:
                    # Always remove the temp PDF, success or not.
                    os.unlink(tmp_path)

                page_status.empty()
                overall_progress.progress((i + 1) / len(uploaded_files))
                # Show only the last 10 log lines to keep the UI compact.
                log_area.code("\n".join(log_lines[-10:]), language="text")

            overall_progress.progress(1.0)
            file_status.success(f"✅ Processed {len(uploaded_files)} file(s)")

        with results_container:
            st.markdown("##### 📊 Results")

            st.success(f"✅ **{total_uploaded} pages** uploaded to `{collection_name}`" +
                       (f" ({total_skipped} skipped)" if total_skipped else "") +
                       (f" ({total_failed} failed)" if total_failed else ""))

            if file_results:
                with st.expander("📋 File Details", expanded=False):
                    for fr in file_results:
                        timing = ""
                        if fr.get("elapsed_s") is not None:
                            timing = f" | {fr['elapsed_s']:.1f}s"
                        if fr.get("sec_per_page") is not None:
                            timing += f" ({fr['sec_per_page']:.2f}s/page)"
                        st.text(
                            f"• {fr['file']}: {fr['uploaded']} uploaded"
                            + (f", {fr['skipped']} skipped" if fr.get("skipped") else "")
                            + (f", {fr['total_pages']} pages" if fr.get("total_pages") else "")
                            + timing
                        )

            # Snapshot for render_upload_results() on later reruns.
            st.session_state["last_upload_result"] = {
                "total_uploaded": total_uploaded, "total_skipped": total_skipped, "total_failed": total_failed,
                "file_results": file_results, "collection": collection_name,
            }

            # Invalidate cached Qdrant views so other tabs see the new points.
            get_collection_stats.clear()
            sample_points_cached.clear()

            if total_uploaded > 0:
                st.session_state["upload_success"] = f"Uploaded {total_uploaded} pages to {collection_name}"
                st.balloons()

    except Exception as e:
        # Top-level boundary: surface the error plus a full traceback expander.
        st.error(f"❌ Processing error: {e}")
        with st.expander("Traceback"):
            st.code(traceback.format_exc())
457
+
458
+
459
def render_upload_results():
    """Re-render the summary of the most recent upload run.

    Reads the snapshot stored in ``st.session_state["last_upload_result"]``
    by ``process_pdfs`` and draws a success banner plus an optional per-file
    details expander. No-op when no upload has happened yet.
    """
    last = st.session_state.get("last_upload_result", {})
    if not last:
        return

    total_up = last.get("total_uploaded", 0)
    total_skip = last.get("total_skipped", 0)
    total_fail = last.get("total_failed", 0)
    coll = last.get("collection", "")
    per_file = last.get("file_results", [])

    # Assemble the banner from optional fragments, then join once.
    banner = [f"✅ **{total_up} pages** uploaded to `{coll}`"]
    if total_skip:
        banner.append(f" ({total_skip} skipped)")
    if total_fail:
        banner.append(f" ({total_fail} failed)")
    st.success("".join(banner))

    if not per_file:
        return

    with st.expander("📋 Details", expanded=False):
        for entry in per_file:
            # Same fragment technique for each per-file line.
            pieces = [f"• {entry['file']}: {entry['uploaded']} uploaded"]
            if entry.get("skipped"):
                pieces.append(f", {entry['skipped']} skipped")
            if entry.get("total_pages"):
                pieces.append(f", {entry['total_pages']} pages")
            if entry.get("elapsed_s") is not None:
                pieces.append(f" | {entry['elapsed_s']:.1f}s")
            if entry.get("sec_per_page") is not None:
                pieces.append(f" ({entry['sec_per_page']:.2f}s/page)")
            st.text("".join(pieces))
visual_rag/__init__.py ADDED
@@ -0,0 +1,98 @@
1
+ """
2
+ Visual RAG Toolkit - End-to-end visual document retrieval with two-stage pooling.
3
+
4
+ A modular toolkit for building visual document retrieval systems:
5
+
6
+ Components:
7
+ -----------
8
+ - embedding: Visual and text embedding generation (ColPali, etc.)
9
+ - indexing: PDF processing, Qdrant indexing, Cloudinary uploads
10
+ - retrieval: Single-stage and two-stage retrieval with MaxSim
11
+ - visualization: Saliency maps and attention visualization
12
+ - cli: Command-line interface
13
+
14
+ Quick Start:
15
+ ------------
16
+ >>> from visual_rag import VisualEmbedder, PDFProcessor, TwoStageRetriever
17
+ >>>
18
+ >>> # Process PDFs
19
+ >>> processor = PDFProcessor(dpi=140)
20
+ >>> images, texts = processor.process_pdf("report.pdf")
21
+ >>>
22
+ >>> # Generate embeddings
23
+ >>> embedder = VisualEmbedder()
24
+ >>> embeddings = embedder.embed_images(images)
25
+ >>> query_emb = embedder.embed_query("What is the budget?")
26
+ >>>
27
+ >>> # Search with two-stage retrieval
28
+ >>> retriever = TwoStageRetriever(qdrant_client, "my_collection")
29
+ >>> results = retriever.search(query_emb, top_k=10)
30
+
31
+ Each component works independently - use only what you need.
32
+ """
33
+
34
+ __version__ = "0.1.0"
35
+
36
+ # Import main classes at package level for convenience
37
+ # These are optional - if dependencies aren't installed, we catch the error
38
+
39
+ try:
40
+ from visual_rag.embedding.visual_embedder import VisualEmbedder
41
+ except ImportError:
42
+ VisualEmbedder = None
43
+
44
+ try:
45
+ from visual_rag.indexing.pdf_processor import PDFProcessor
46
+ except ImportError:
47
+ PDFProcessor = None
48
+
49
+ try:
50
+ from visual_rag.indexing.qdrant_indexer import QdrantIndexer
51
+ except ImportError:
52
+ QdrantIndexer = None
53
+
54
+ try:
55
+ from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
56
+ except ImportError:
57
+ CloudinaryUploader = None
58
+
59
+ try:
60
+ from visual_rag.retrieval.two_stage import TwoStageRetriever
61
+ except ImportError:
62
+ TwoStageRetriever = None
63
+
64
+ try:
65
+ from visual_rag.retrieval.multi_vector import MultiVectorRetriever
66
+ except ImportError:
67
+ MultiVectorRetriever = None
68
+
69
+ try:
70
+ from visual_rag.qdrant_admin import QdrantAdmin
71
+ except ImportError:
72
+ QdrantAdmin = None
73
+
74
+ try:
75
+ from visual_rag.demo_runner import demo
76
+ except ImportError:
77
+ demo = None
78
+
79
+ # Config utilities (always available)
80
+ from visual_rag.config import get, get_section, load_config
81
+
82
+ __all__ = [
83
+ # Version
84
+ "__version__",
85
+ # Main classes
86
+ "VisualEmbedder",
87
+ "PDFProcessor",
88
+ "QdrantIndexer",
89
+ "CloudinaryUploader",
90
+ "TwoStageRetriever",
91
+ "MultiVectorRetriever",
92
+ "QdrantAdmin",
93
+ "demo",
94
+ # Config utilities
95
+ "load_config",
96
+ "get",
97
+ "get_section",
98
+ ]
@@ -0,0 +1 @@
1
+ """CLI entry point for visual-rag-toolkit."""