visual-rag-toolkit 0.1.2__tar.gz → 0.1.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/PKG-INFO +28 -16
  2. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/README.md +27 -15
  3. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/__init__.py +1 -1
  4. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/app.py +20 -8
  5. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/evaluation.py +5 -45
  6. visual_rag_toolkit-0.1.4/demo/indexing.py +274 -0
  7. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/qdrant_utils.py +12 -5
  8. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/ui/playground.py +1 -1
  9. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/ui/sidebar.py +26 -3
  10. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/ui/upload.py +6 -5
  11. visual_rag_toolkit-0.1.4/examples/COMMANDS.md +83 -0
  12. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/examples/config.yaml +6 -0
  13. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/examples/process_pdfs.py +6 -0
  14. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/examples/search_demo.py +6 -0
  15. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/pyproject.toml +25 -1
  16. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/__init__.py +63 -6
  17. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/config.py +4 -7
  18. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/demo_runner.py +3 -5
  19. visual_rag_toolkit-0.1.4/visual_rag/indexing/__init__.py +38 -0
  20. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/indexing/qdrant_indexer.py +94 -42
  21. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/retrieval/multi_vector.py +62 -65
  22. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/retrieval/single_stage.py +7 -0
  23. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/retrieval/two_stage.py +7 -10
  24. visual_rag_toolkit-0.1.2/demo/example_metadata_mapping_sigir.json +0 -37
  25. visual_rag_toolkit-0.1.2/demo/indexing.py +0 -315
  26. visual_rag_toolkit-0.1.2/visual_rag/indexing/__init__.py +0 -21
  27. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/.github/workflows/ci.yaml +0 -0
  28. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/.github/workflows/publish_pypi.yaml +0 -0
  29. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/.gitignore +0 -0
  30. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/LICENSE +0 -0
  31. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/README.md +0 -0
  32. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/__init__.py +0 -0
  33. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/analyze_results.py +0 -0
  34. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/benchmark_datasets.txt +0 -0
  35. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/prepare_submission.py +0 -0
  36. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/quick_test.py +0 -0
  37. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/run_vidore.py +0 -0
  38. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +0 -0
  39. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/vidore_tatdqa_test/__init__.py +0 -0
  40. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/vidore_tatdqa_test/dataset_loader.py +0 -0
  41. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/vidore_tatdqa_test/metrics.py +0 -0
  42. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/vidore_tatdqa_test/run_qdrant.py +0 -0
  43. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/vidore_tatdqa_test/sweep_eval.py +0 -0
  44. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/commands.py +0 -0
  45. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/config.py +0 -0
  46. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/download_models.py +0 -0
  47. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/results.py +0 -0
  48. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/test_qdrant_connection.py +0 -0
  49. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/ui/__init__.py +0 -0
  50. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/ui/benchmark.py +0 -0
  51. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/ui/header.py +0 -0
  52. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/requirements.txt +0 -0
  53. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/tests/__init__.py +0 -0
  54. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/tests/test_config.py +0 -0
  55. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/tests/test_pdf_processor.py +0 -0
  56. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/tests/test_pooling.py +0 -0
  57. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/tests/test_strategies.py +0 -0
  58. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/cli/__init__.py +0 -0
  59. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/cli/main.py +0 -0
  60. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/embedding/__init__.py +0 -0
  61. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/embedding/pooling.py +0 -0
  62. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/embedding/visual_embedder.py +0 -0
  63. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/indexing/cloudinary_uploader.py +0 -0
  64. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/indexing/pdf_processor.py +0 -0
  65. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/indexing/pipeline.py +0 -0
  66. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/preprocessing/__init__.py +0 -0
  67. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/preprocessing/crop_empty.py +0 -0
  68. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/qdrant_admin.py +0 -0
  69. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/retrieval/__init__.py +0 -0
  70. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/retrieval/three_stage.py +0 -0
  71. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/visualization/__init__.py +0 -0
  72. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/visualization/saliency.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: visual-rag-toolkit
3
- Version: 0.1.2
3
+ Version: 0.1.4
4
4
  Summary: End-to-end visual document retrieval with ColPali, featuring two-stage pooling for scalable search
5
5
  Project-URL: Homepage, https://github.com/Ara-Yeroyan/visual-rag-toolkit
6
6
  Project-URL: Documentation, https://github.com/Ara-Yeroyan/visual-rag-toolkit#readme
@@ -88,14 +88,12 @@ Description-Content-Type: text/markdown
88
88
  [![PyPI](https://img.shields.io/pypi/v/visual-rag-toolkit)](https://pypi.org/project/visual-rag-toolkit/)
89
89
  [![Python](https://img.shields.io/pypi/pyversions/visual-rag-toolkit)](https://pypi.org/project/visual-rag-toolkit/)
90
90
  [![License](https://img.shields.io/pypi/l/visual-rag-toolkit)](LICENSE)
91
- [![CI](https://img.shields.io/github/actions/workflow/status/Ara-Yeroyan/visual-rag-toolkit/ci.yaml?branch=main)](https://github.com/Ara-Yeroyan/visual-rag-toolkit/actions/workflows/ci.yaml)
92
-
93
- Note:
94
- - The **PyPI badge** shows “not found” until the first release is published.
95
- - The **CI badge** requires the GitHub repo to be **public** (GitHub does not serve Actions badges for private repos).
91
+ [![Demo](https://img.shields.io/badge/Demo-Hugging%20Face-yellow)](https://huggingface.co/spaces/Yeroyan/visual-rag-toolkit)
96
92
 
97
93
  End-to-end visual document retrieval toolkit featuring **fast multi-stage retrieval** (prefetch with pooled vectors + exact MaxSim reranking).
98
94
 
95
+ **[Try the Live Demo](https://huggingface.co/spaces/Yeroyan/visual-rag-toolkit)** - Upload PDFs, index to Qdrant, and query with visual retrieval.
96
+
99
97
  This repo contains:
100
98
  - a **Python package** (`visual_rag`)
101
99
  - a **Streamlit demo app** (`demo/`)
@@ -162,7 +160,7 @@ for r in results[:3]:
162
160
 
163
161
  ### End-to-end: ingest PDFs (with cropping) → index in Qdrant
164
162
 
165
- This is the SDK-style pipeline: PDF → images → optional crop → embed → store vectors + payload in Qdrant.
163
+ This is the "SDK-style" pipeline: PDF → images → optional crop → embed → store vectors + payload in Qdrant.
166
164
 
167
165
  ```python
168
166
  import os
@@ -174,8 +172,8 @@ import torch
174
172
  from visual_rag import VisualEmbedder
175
173
  from visual_rag.indexing import ProcessingPipeline, QdrantIndexer
176
174
 
177
- QDRANT_URL = os.environ["SIGIR_QDRANT_URL"] # or QDRANT_URL
178
- QDRANT_KEY = os.getenv("SIGIR_QDRANT_KEY", "") # or QDRANT_API_KEY
175
+ QDRANT_URL = os.environ["QDRANT_URL"]
176
+ QDRANT_KEY = os.getenv("QDRANT_API_KEY", "")
179
177
 
180
178
  collection = "my_visual_docs"
181
179
 
@@ -193,6 +191,8 @@ indexer = QdrantIndexer(
193
191
  prefer_grpc=True,
194
192
  vector_datatype="float16",
195
193
  )
194
+
195
+ # Creates collection + required payload indexes (e.g., "filename" for skip_existing)
196
196
  indexer.create_collection(force_recreate=False)
197
197
 
198
198
  pipeline = ProcessingPipeline(
@@ -208,19 +208,32 @@ pipeline = ProcessingPipeline(
208
208
 
209
209
  pdfs = [Path("docs/a.pdf"), Path("docs/b.pdf")]
210
210
  for pdf_path in pdfs:
211
- pipeline.process_pdf(
211
+ result = pipeline.process_pdf(
212
212
  pdf_path,
213
- skip_existing=True,
213
+ skip_existing=True, # Skip pages already in Qdrant (uses filename index)
214
214
  upload_to_cloudinary=False,
215
215
  upload_to_qdrant=True,
216
216
  )
217
+ # Logs automatically shown:
218
+ # [10:23:45] 📚 Processing PDF: a.pdf
219
+ # [10:23:45] 🖼️ Converting PDF to images...
220
+ # [10:23:46] ✅ Converted 12 pages
221
+ # [10:23:46] 📦 Processing pages 1-8/12
222
+ # [10:23:46] 🤖 Generating embeddings for 8 pages...
223
+ # [10:23:48] 📤 Uploading batch of 8 pages...
224
+ # [10:23:48] ✅ Uploaded 8 points to Qdrant
225
+ # [10:23:48] 📦 Processing pages 9-12/12
226
+ # [10:23:48] 🤖 Generating embeddings for 4 pages...
227
+ # [10:23:50] 📤 Uploading batch of 4 pages...
228
+ # [10:23:50] ✅ Uploaded 4 points to Qdrant
229
+ # [10:23:50] ✅ Completed a.pdf: 12 uploaded, 0 skipped, 0 failed
217
230
  ```
218
231
 
219
232
  CLI equivalent:
220
233
 
221
234
  ```bash
222
- export SIGIR_QDRANT_URL="https://YOUR_QDRANT"
223
- export SIGIR_QDRANT_KEY="YOUR_KEY"
235
+ export QDRANT_URL="https://YOUR_QDRANT"
236
+ export QDRANT_API_KEY="YOUR_KEY"
224
237
 
225
238
  visual-rag process \
226
239
  --reports-dir ./docs \
@@ -263,7 +276,7 @@ Stage 2: Exact MaxSim reranking on candidates
263
276
  └── Return top-k results (e.g., 10)
264
277
  ```
265
278
 
266
- Three-stage extends this with an additional cheap prefetch stage before stage 2.
279
+ Three-stage extends this with an additional "cheap prefetch" stage before stage 2.
267
280
 
268
281
  ## 📁 Package Structure
269
282
 
@@ -358,7 +371,7 @@ If you use this toolkit in your research, please cite:
358
371
 
359
372
  ```bibtex
360
373
  @software{visual_rag_toolkit,
361
- title = {Visual RAG Toolkit: Scalable Visual Document Retrieval with Two-Stage Pooling},
374
+ title = {Visual RAG Toolkit: Scalable Visual Document Retrieval with 1D Convolutional Pooling},
362
375
  author = {Ara Yeroyan},
363
376
  year = {2026},
364
377
  url = {https://github.com/Ara-Yeroyan/visual-rag-toolkit}
@@ -374,4 +387,3 @@ MIT License - see [LICENSE](LICENSE) for details.
374
387
  - [Qdrant](https://qdrant.tech/) - Vector database with multi-vector support
375
388
  - [ColPali](https://github.com/illuin-tech/colpali) - Visual document retrieval models
376
389
  - [ViDoRe](https://huggingface.co/spaces/vidore/vidore-leaderboard) - Benchmark dataset
377
-
@@ -3,14 +3,12 @@
3
3
  [![PyPI](https://img.shields.io/pypi/v/visual-rag-toolkit)](https://pypi.org/project/visual-rag-toolkit/)
4
4
  [![Python](https://img.shields.io/pypi/pyversions/visual-rag-toolkit)](https://pypi.org/project/visual-rag-toolkit/)
5
5
  [![License](https://img.shields.io/pypi/l/visual-rag-toolkit)](LICENSE)
6
- [![CI](https://img.shields.io/github/actions/workflow/status/Ara-Yeroyan/visual-rag-toolkit/ci.yaml?branch=main)](https://github.com/Ara-Yeroyan/visual-rag-toolkit/actions/workflows/ci.yaml)
7
-
8
- Note:
9
- - The **PyPI badge** shows “not found” until the first release is published.
10
- - The **CI badge** requires the GitHub repo to be **public** (GitHub does not serve Actions badges for private repos).
6
+ [![Demo](https://img.shields.io/badge/Demo-Hugging%20Face-yellow)](https://huggingface.co/spaces/Yeroyan/visual-rag-toolkit)
11
7
 
12
8
  End-to-end visual document retrieval toolkit featuring **fast multi-stage retrieval** (prefetch with pooled vectors + exact MaxSim reranking).
13
9
 
10
+ **[Try the Live Demo](https://huggingface.co/spaces/Yeroyan/visual-rag-toolkit)** - Upload PDFs, index to Qdrant, and query with visual retrieval.
11
+
14
12
  This repo contains:
15
13
  - a **Python package** (`visual_rag`)
16
14
  - a **Streamlit demo app** (`demo/`)
@@ -77,7 +75,7 @@ for r in results[:3]:
77
75
 
78
76
  ### End-to-end: ingest PDFs (with cropping) → index in Qdrant
79
77
 
80
- This is the SDK-style pipeline: PDF → images → optional crop → embed → store vectors + payload in Qdrant.
78
+ This is the "SDK-style" pipeline: PDF → images → optional crop → embed → store vectors + payload in Qdrant.
81
79
 
82
80
  ```python
83
81
  import os
@@ -89,8 +87,8 @@ import torch
89
87
  from visual_rag import VisualEmbedder
90
88
  from visual_rag.indexing import ProcessingPipeline, QdrantIndexer
91
89
 
92
- QDRANT_URL = os.environ["SIGIR_QDRANT_URL"] # or QDRANT_URL
93
- QDRANT_KEY = os.getenv("SIGIR_QDRANT_KEY", "") # or QDRANT_API_KEY
90
+ QDRANT_URL = os.environ["QDRANT_URL"]
91
+ QDRANT_KEY = os.getenv("QDRANT_API_KEY", "")
94
92
 
95
93
  collection = "my_visual_docs"
96
94
 
@@ -108,6 +106,8 @@ indexer = QdrantIndexer(
108
106
  prefer_grpc=True,
109
107
  vector_datatype="float16",
110
108
  )
109
+
110
+ # Creates collection + required payload indexes (e.g., "filename" for skip_existing)
111
111
  indexer.create_collection(force_recreate=False)
112
112
 
113
113
  pipeline = ProcessingPipeline(
@@ -123,19 +123,32 @@ pipeline = ProcessingPipeline(
123
123
 
124
124
  pdfs = [Path("docs/a.pdf"), Path("docs/b.pdf")]
125
125
  for pdf_path in pdfs:
126
- pipeline.process_pdf(
126
+ result = pipeline.process_pdf(
127
127
  pdf_path,
128
- skip_existing=True,
128
+ skip_existing=True, # Skip pages already in Qdrant (uses filename index)
129
129
  upload_to_cloudinary=False,
130
130
  upload_to_qdrant=True,
131
131
  )
132
+ # Logs automatically shown:
133
+ # [10:23:45] 📚 Processing PDF: a.pdf
134
+ # [10:23:45] 🖼️ Converting PDF to images...
135
+ # [10:23:46] ✅ Converted 12 pages
136
+ # [10:23:46] 📦 Processing pages 1-8/12
137
+ # [10:23:46] 🤖 Generating embeddings for 8 pages...
138
+ # [10:23:48] 📤 Uploading batch of 8 pages...
139
+ # [10:23:48] ✅ Uploaded 8 points to Qdrant
140
+ # [10:23:48] 📦 Processing pages 9-12/12
141
+ # [10:23:48] 🤖 Generating embeddings for 4 pages...
142
+ # [10:23:50] 📤 Uploading batch of 4 pages...
143
+ # [10:23:50] ✅ Uploaded 4 points to Qdrant
144
+ # [10:23:50] ✅ Completed a.pdf: 12 uploaded, 0 skipped, 0 failed
132
145
  ```
133
146
 
134
147
  CLI equivalent:
135
148
 
136
149
  ```bash
137
- export SIGIR_QDRANT_URL="https://YOUR_QDRANT"
138
- export SIGIR_QDRANT_KEY="YOUR_KEY"
150
+ export QDRANT_URL="https://YOUR_QDRANT"
151
+ export QDRANT_API_KEY="YOUR_KEY"
139
152
 
140
153
  visual-rag process \
141
154
  --reports-dir ./docs \
@@ -178,7 +191,7 @@ Stage 2: Exact MaxSim reranking on candidates
178
191
  └── Return top-k results (e.g., 10)
179
192
  ```
180
193
 
181
- Three-stage extends this with an additional cheap prefetch stage before stage 2.
194
+ Three-stage extends this with an additional "cheap prefetch" stage before stage 2.
182
195
 
183
196
  ## 📁 Package Structure
184
197
 
@@ -273,7 +286,7 @@ If you use this toolkit in your research, please cite:
273
286
 
274
287
  ```bibtex
275
288
  @software{visual_rag_toolkit,
276
- title = {Visual RAG Toolkit: Scalable Visual Document Retrieval with Two-Stage Pooling},
289
+ title = {Visual RAG Toolkit: Scalable Visual Document Retrieval with 1D Convolutional Pooling},
277
290
  author = {Ara Yeroyan},
278
291
  year = {2026},
279
292
  url = {https://github.com/Ara-Yeroyan/visual-rag-toolkit}
@@ -289,4 +302,3 @@ MIT License - see [LICENSE](LICENSE) for details.
289
302
  - [Qdrant](https://qdrant.tech/) - Vector database with multi-vector support
290
303
  - [ColPali](https://github.com/illuin-tech/colpali) - Visual document retrieval models
291
304
  - [ViDoRe](https://huggingface.co/spaces/vidore/vidore-leaderboard) - Benchmark dataset
292
-
@@ -7,4 +7,4 @@ A Streamlit-based UI for:
7
7
  - Interactive playground for visual search
8
8
  """
9
9
 
10
- __version__ = "0.1.0"
10
+ __version__ = "0.1.4"
@@ -1,13 +1,23 @@
1
1
  """Main entry point for the Visual RAG Toolkit demo application."""
2
2
 
3
+ import os
3
4
  import sys
4
5
  from pathlib import Path
5
6
 
6
# Make the repository root importable when the app is launched straight from a
# checkout. In the HF Space / Docker image PYTHONPATH already covers it, so the
# root is only prepended when it is not on sys.path yet.
_app_dir = Path(__file__).resolve().parent
_repo_root = _app_dir.parent
_repo_root_str = str(_repo_root)
if _repo_root_str not in sys.path:
    sys.path.insert(0, _repo_root_str)

from dotenv import load_dotenv

# Load an optional .env from the repo root, then one from demo/ itself; missing
# files are skipped. load_dotenv does not override variables that are already
# set, so for duplicated keys the real environment wins, then the repo-root file.
for _dotenv_path in (_repo_root / ".env", _app_dir / ".env"):
    if _dotenv_path.exists():
        load_dotenv(_dotenv_path)
del _repo_root_str, _dotenv_path
11
21
 
12
22
  import streamlit as st
13
23
 
@@ -28,15 +38,17 @@ from demo.ui.benchmark import render_benchmark_tab
28
38
  def main():
29
39
  render_header()
30
40
  render_sidebar()
31
-
32
- tab_upload, tab_playground, tab_benchmark = st.tabs(["📤 Upload", "🎮 Playground", "📊 Benchmarking"])
33
-
41
+
42
+ tab_upload, tab_playground, tab_benchmark = st.tabs(
43
+ ["📤 Upload", "🎮 Playground", "📊 Benchmarking"]
44
+ )
45
+
34
46
  with tab_upload:
35
47
  render_upload_tab()
36
-
48
+
37
49
  with tab_playground:
38
50
  render_playground_tab()
39
-
51
+
40
52
  with tab_benchmark:
41
53
  render_benchmark_tab()
42
54
 
@@ -1,20 +1,23 @@
1
1
  """Evaluation runner with UI updates."""
2
2
 
3
3
  import hashlib
4
- import importlib.util
5
4
  import json
6
5
  import logging
7
6
  import time
8
7
  import traceback
9
8
  from datetime import datetime
10
- from pathlib import Path
11
9
  from typing import Any, Dict, List, Optional
12
10
 
13
11
  import numpy as np
14
12
  import streamlit as st
15
13
  import torch
14
+ from qdrant_client.models import FieldCondition, Filter, MatchValue
16
15
 
17
16
  from visual_rag import VisualEmbedder
17
+ from visual_rag.retrieval import MultiVectorRetriever
18
+ from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
19
+ from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
20
+ from demo.qdrant_utils import get_qdrant_credentials
18
21
 
19
22
 
20
23
  TORCH_DTYPE_MAP = {
@@ -22,49 +25,6 @@ TORCH_DTYPE_MAP = {
22
25
  "float32": torch.float32,
23
26
  "bfloat16": torch.bfloat16,
24
27
  }
25
- from qdrant_client.models import Filter, FieldCondition, MatchValue
26
-
27
- from visual_rag.retrieval import MultiVectorRetriever
28
-
29
-
30
- def _load_local_benchmark_module(module_filename: str):
31
- """
32
- Load `benchmarks/vidore_tatdqa_test/<module_filename>` via file path.
33
-
34
- Motivation:
35
- - Some environments (notably containers / Spaces) can have a third-party
36
- `benchmarks` package installed, causing `import benchmarks...` to resolve
37
- to the wrong module.
38
- - This fallback guarantees we load the repo's benchmark utilities.
39
- """
40
- root = Path(__file__).resolve().parents[1] # demo/.. = repo root
41
- target = root / "benchmarks" / "vidore_tatdqa_test" / module_filename
42
- if not target.exists():
43
- raise ModuleNotFoundError(f"Missing local benchmark module file: {target}")
44
-
45
- name = f"_visual_rag_toolkit_local_{target.stem}"
46
- spec = importlib.util.spec_from_file_location(name, str(target))
47
- if spec is None or spec.loader is None:
48
- raise ModuleNotFoundError(f"Could not load module spec for: {target}")
49
- mod = importlib.util.module_from_spec(spec)
50
- spec.loader.exec_module(mod) # type: ignore[attr-defined]
51
- return mod
52
-
53
-
54
- try:
55
- # Preferred: normal import
56
- from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
57
- from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
58
- except ModuleNotFoundError:
59
- # Robust fallback: load from local file paths
60
- _dl = _load_local_benchmark_module("dataset_loader.py")
61
- _mx = _load_local_benchmark_module("metrics.py")
62
- load_vidore_beir_dataset = _dl.load_vidore_beir_dataset
63
- ndcg_at_k = _mx.ndcg_at_k
64
- mrr_at_k = _mx.mrr_at_k
65
- recall_at_k = _mx.recall_at_k
66
-
67
- from demo.qdrant_utils import get_qdrant_credentials
68
28
 
69
29
  logger = logging.getLogger(__name__)
70
30
  logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
@@ -0,0 +1,274 @@
1
+ """Indexing runner with UI updates."""
2
+
3
+ import hashlib
4
+ import json
5
+ import time
6
+ import traceback
7
+ from datetime import datetime
8
+ from typing import Any, Dict, Optional
9
+
10
+ import numpy as np
11
+ import streamlit as st
12
+ import torch
13
+
14
+ from visual_rag import VisualEmbedder
15
+ from visual_rag.embedding.pooling import tile_level_mean_pooling
16
+ from visual_rag.indexing.qdrant_indexer import QdrantIndexer
17
+ from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
18
+ from demo.qdrant_utils import get_qdrant_credentials
19
+
20
+
21
# Config dtype name -> torch dtype used when loading the embedding model.
TORCH_DTYPE_MAP = dict(
    float16=torch.float16,
    float32=torch.float32,
    bfloat16=torch.bfloat16,
)
26
+
27
+
28
def _stable_uuid(text: str) -> str:
    """Derive a deterministic, UUID-shaped string from *text*.

    The first 32 hex digits of the SHA-256 digest of the UTF-8 encoded text are
    grouped 8-4-4-4-12, so identical text always yields the same point ID (same
    scheme as the benchmark script). Version/variant bits are not set, so the
    result is UUID-formatted rather than an RFC 4122 UUID of a particular version.
    """
    digest = hashlib.sha256(text.encode("utf-8")).hexdigest()
    groups = (digest[0:8], digest[8:12], digest[12:16], digest[16:20], digest[20:32])
    return "-".join(groups)
32
+
33
+
34
def _union_point_id(
    *, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]
) -> str:
    """Return the Qdrant point ID for *source_doc_id* within *dataset_name*.

    The key ``"<union_namespace>::<dataset_name>::<source_doc_id>"`` is hashed
    with ``_stable_uuid``. When *union_namespace* is falsy, the key is
    ``"<dataset_name>::<source_doc_id>"``. This mirrors the benchmark script so
    both produce identical IDs.
    """
    prefix = dataset_name
    if union_namespace:
        prefix = f"{union_namespace}::{dataset_name}"
    return _stable_uuid(f"{prefix}::{source_doc_id}")
40
+
41
+
42
def run_indexing_with_ui(config: Dict[str, Any]):
    """Embed ViDoRe BEIR dataset corpora and upsert them into Qdrant, showing progress in Streamlit.

    The run is rendered in three phases plus a summary:
    1. load the embedding model,
    2. create (or recreate) the collection and its keyword payload indexes,
    3. embed each corpus document and upsert it as one point,
    then print per-dataset uploaded/failed counts.

    Args:
        config: Run settings. ``collection`` is required. Optional keys and
            defaults: ``datasets`` (``[]``), ``model`` (``"vidore/colpali-v1.3"``),
            ``recreate`` (``False``), ``torch_dtype`` (``"float16"``),
            ``qdrant_vector_dtype`` (``"float16"``), ``prefer_grpc`` (``True``),
            ``batch_size`` (``4``) and ``max_docs`` (``None`` = no limit).

    Raises:
        KeyError: if ``config`` has no ``collection`` key.

    If no Qdrant URL is configured, an error is shown and the function returns.
    Errors raised after the phases start are shown in the UI and printed rather
    than raised. A failure on a single document is counted and the run continues.
    """
    st.divider()

    print("=" * 60)
    print("[INDEX] Starting indexing via UI")
    print("=" * 60)

    url, api_key = get_qdrant_credentials()
    if not url:
        st.error("QDRANT_URL not configured")
        return

    datasets = config.get("datasets", [])
    collection = config["collection"]
    model = config.get("model", "vidore/colpali-v1.3")
    recreate = config.get("recreate", False)
    torch_dtype = config.get("torch_dtype", "float16")
    qdrant_vector_dtype = config.get("qdrant_vector_dtype", "float16")
    prefer_grpc = config.get("prefer_grpc", True)
    # NOTE(review): batch_size is only logged; documents are embedded one at a time below.
    batch_size = config.get("batch_size", 4)
    max_docs = config.get("max_docs")

    print(f"[INDEX] Config: collection={collection}, model={model}")
    print(f"[INDEX] Datasets: {datasets}")
    print(f"[INDEX] max_docs={max_docs}, batch_size={batch_size}, recreate={recreate}")
    print(
        f"[INDEX] torch_dtype={torch_dtype}, qdrant_dtype={qdrant_vector_dtype}, grpc={prefer_grpc}"
    )

    # Reserve one container per phase up front so each phase renders in a fixed slot.
    phase1_container = st.container()
    phase2_container = st.container()
    phase3_container = st.container()
    results_container = st.container()

    try:
        with phase1_container:
            st.markdown("##### 🤖 Phase 1: Loading Model")
            model_status = st.empty()
            model_status.info(f"Loading `{model.split('/')[-1]}`...")

            print(f"[INDEX] Loading embedder: {model}")
            # Unknown dtype names fall back to float16.
            torch_dtype_obj = TORCH_DTYPE_MAP.get(torch_dtype, torch.float16)
            # Embeddings are emitted as float16 only when Qdrant stores float16; otherwise float32.
            output_dtype_obj = np.float16 if qdrant_vector_dtype == "float16" else np.float32
            embedder = VisualEmbedder(
                model_name=model,
                torch_dtype=torch_dtype_obj,
                output_dtype=output_dtype_obj,
            )
            # Load the model explicitly while the Phase 1 status is on screen.
            embedder._load_model()
            print(
                f"[INDEX] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_vector_dtype})"
            )
            model_status.success(f"✅ Model `{model.split('/')[-1]}` loaded")

        with phase2_container:
            st.markdown("##### 📦 Phase 2: Setting Up Collection")

            indexer_status = st.empty()
            indexer_status.info("Connecting to Qdrant...")

            print("[INDEX] Connecting to Qdrant...")
            indexer = QdrantIndexer(
                url=url,
                api_key=api_key,
                collection_name=collection,
                prefer_grpc=prefer_grpc,
                vector_datatype=qdrant_vector_dtype,
            )
            print("[INDEX] Connected to Qdrant")
            indexer_status.success("✅ Connected to Qdrant")

            coll_status = st.empty()
            action = "Recreating" if recreate else "Creating/verifying"
            coll_status.info(f"{action} collection `{collection}`...")

            print(f"[INDEX] {action} collection: {collection}")
            indexer.create_collection(force_recreate=recreate)
            # Keyword indexes on identifier fields written into every payload below.
            indexer.create_payload_indexes(
                fields=[
                    {"field": "dataset", "type": "keyword"},
                    {"field": "doc_id", "type": "keyword"},
                    {"field": "source_doc_id", "type": "keyword"},
                ]
            )
            print("[INDEX] Collection ready")
            coll_status.success(f"✅ Collection `{collection}` ready")

        with phase3_container:
            st.markdown("##### 📊 Phase 3: Processing Datasets")

            all_results = []

            for ds_idx, dataset_name in enumerate(datasets):
                ds_short = dataset_name.split("/")[-1]
                ds_container = st.container()

                with ds_container:
                    st.markdown(f"**Dataset {ds_idx + 1}/{len(datasets)}: `{ds_short}`**")

                    load_status = st.empty()
                    load_status.info(f"Loading dataset `{ds_short}`...")

                    print(f"[INDEX] Loading dataset: {dataset_name}")
                    # Only the corpus is indexed; queries and qrels are not needed here.
                    corpus, _, _ = load_vidore_beir_dataset(dataset_name)
                    total_docs = len(corpus)
                    print(f"[INDEX] Dataset loaded: {total_docs} docs")
                    load_status.success(f"✅ Loaded {total_docs:,} documents")

                    if max_docs and max_docs < total_docs:
                        corpus = corpus[:max_docs]
                        print(f"[INDEX] Limiting to {max_docs} docs")

                    progress_bar = st.progress(0)
                    status_text = st.empty()

                    uploaded = 0
                    failed = 0
                    total = len(corpus)

                    for i, doc in enumerate(corpus):
                        try:
                            doc_id = str(doc.doc_id)
                            image = doc.image
                            if image is None:
                                # Nothing to embed: count it as a failure and move on.
                                failed += 1
                                continue

                            status_text.text(f"Processing {i + 1}/{total}: {doc_id[:30]}...")

                            embeddings, token_infos = embedder.embed_images(
                                [image],
                                return_token_info=True,
                                show_progress=False,
                            )
                            emb = embeddings[0]
                            token_info = token_infos[0] if token_infos else {}

                            # Move tensors off the accelerator before converting to numpy.
                            if hasattr(emb, "cpu"):
                                emb = emb.cpu()
                            emb_np = np.asarray(emb, dtype=output_dtype_obj)

                            initial = emb_np.tolist()
                            global_pool = emb_np.mean(axis=0).tolist()

                            num_tiles = token_info.get("num_tiles")
                            mean_pooling = None
                            experimental_pooling = None

                            # Tile-based pooled vectors are optional: on failure the point is
                            # still indexed with the remaining vectors, but the error is logged.
                            if num_tiles and num_tiles > 0:
                                try:
                                    # NOTE(review): 64 patches per tile is hard-coded; confirm it
                                    # matches the configured model.
                                    mean_pooling = tile_level_mean_pooling(
                                        emb_np, num_tiles=num_tiles, patches_per_tile=64
                                    ).tolist()
                                except Exception as pool_err:
                                    print(
                                        f"[INDEX] mean_pooling failed for doc {doc_id}: {pool_err}"
                                    )

                                try:
                                    exp_pool = embedder.experimental_pool_visual_embedding(
                                        emb_np, num_tiles=num_tiles
                                    )
                                    if exp_pool is not None:
                                        experimental_pooling = exp_pool.tolist()
                                except Exception as pool_err:
                                    print(
                                        f"[INDEX] experimental_pooling failed for doc {doc_id}: {pool_err}"
                                    )

                            # Namespacing by collection keeps IDs stable across re-runs.
                            union_doc_id = _union_point_id(
                                dataset_name=dataset_name,
                                source_doc_id=doc_id,
                                union_namespace=collection,
                            )

                            payload = {
                                "dataset": dataset_name,
                                "doc_id": doc_id,
                                "source_doc_id": doc_id,
                                "union_doc_id": union_doc_id,
                                "num_tiles": num_tiles,
                                "num_visual_tokens": token_info.get("num_visual_tokens"),
                            }

                            vectors = {"initial": initial, "global_pooling": global_pool}
                            if mean_pooling:
                                vectors["mean_pooling"] = mean_pooling
                            if experimental_pooling:
                                vectors["experimental_pooling"] = experimental_pooling

                            indexer.upsert_point(
                                point_id=union_doc_id,
                                vectors=vectors,
                                payload=payload,
                            )

                            uploaded += 1

                        except Exception as e:
                            print(f"[INDEX] Error on doc {i}: {e}")
                            failed += 1
                        finally:
                            # Advance for every document, including skipped ones, so the
                            # bar always reaches 100%.
                            progress_bar.progress((i + 1) / total)

                    status_text.text(f"✅ Done: {uploaded} uploaded, {failed} failed")
                    all_results.append(
                        {
                            "dataset": dataset_name,
                            "total": total,
                            "uploaded": uploaded,
                            "failed": failed,
                        }
                    )

        with results_container:
            st.markdown("##### 📋 Results Summary")

            for r in all_results:
                st.write(
                    f"**{r['dataset'].split('/')[-1]}**: {r['uploaded']:,} uploaded, {r['failed']:,} failed"
                )

            st.success("✅ Indexing complete!")

    except Exception as e:
        st.error(f"Indexing error: {e}")
        st.code(traceback.format_exc())
        print(f"[INDEX] ERROR: {e}")
        traceback.print_exc()
@@ -8,12 +8,19 @@ import streamlit as st
8
8
 
9
9
 
10
10
  def get_qdrant_credentials() -> Tuple[Optional[str], Optional[str]]:
11
- url = st.session_state.get("qdrant_url_input") or os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
12
- api_key = st.session_state.get("qdrant_key_input") or (
13
- os.getenv("SIGIR_QDRANT_KEY")
14
- or os.getenv("SIGIR_QDRANT_API_KEY")
15
- or os.getenv("DEST_QDRANT_API_KEY")
11
+ """Get Qdrant credentials from session state or environment variables.
12
+
13
+ Priority: session_state > QDRANT_URL/QDRANT_API_KEY > legacy env vars
14
+ """
15
+ url = (
16
+ st.session_state.get("qdrant_url_input")
17
+ or os.getenv("QDRANT_URL")
18
+ or os.getenv("SIGIR_QDRANT_URL") # legacy
19
+ )
20
+ api_key = (
21
+ st.session_state.get("qdrant_key_input")
16
22
  or os.getenv("QDRANT_API_KEY")
23
+ or os.getenv("SIGIR_QDRANT_KEY") # legacy
17
24
  )
18
25
  return url, api_key
19
26
 
@@ -9,6 +9,7 @@ from demo.qdrant_utils import (
9
9
  sample_points_cached,
10
10
  search_collection,
11
11
  )
12
+ from visual_rag.retrieval import MultiVectorRetriever
12
13
 
13
14
 
14
15
  def render_playground_tab():
@@ -46,7 +47,6 @@ def render_playground_tab():
46
47
  if not st.session_state.get("model_loaded"):
47
48
  with st.spinner(f"Loading {model_short}..."):
48
49
  try:
49
- from visual_rag.retrieval import MultiVectorRetriever
50
50
  _ = MultiVectorRetriever(collection_name=active_collection, model_name=model_name)
51
51
  st.session_state["model_loaded"] = True
52
52
  st.session_state["loaded_model_key"] = cache_key