visual-rag-toolkit 0.1.2__tar.gz → 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/PKG-INFO +28 -16
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/README.md +27 -15
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/__init__.py +1 -1
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/app.py +20 -8
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/evaluation.py +5 -45
- visual_rag_toolkit-0.1.4/demo/indexing.py +274 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/qdrant_utils.py +12 -5
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/ui/playground.py +1 -1
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/ui/sidebar.py +26 -3
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/ui/upload.py +6 -5
- visual_rag_toolkit-0.1.4/examples/COMMANDS.md +83 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/examples/config.yaml +6 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/examples/process_pdfs.py +6 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/examples/search_demo.py +6 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/pyproject.toml +25 -1
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/__init__.py +63 -6
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/config.py +4 -7
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/demo_runner.py +3 -5
- visual_rag_toolkit-0.1.4/visual_rag/indexing/__init__.py +38 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/indexing/qdrant_indexer.py +94 -42
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/retrieval/multi_vector.py +62 -65
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/retrieval/single_stage.py +7 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/retrieval/two_stage.py +7 -10
- visual_rag_toolkit-0.1.2/demo/example_metadata_mapping_sigir.json +0 -37
- visual_rag_toolkit-0.1.2/demo/indexing.py +0 -315
- visual_rag_toolkit-0.1.2/visual_rag/indexing/__init__.py +0 -21
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/.github/workflows/ci.yaml +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/.github/workflows/publish_pypi.yaml +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/.gitignore +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/LICENSE +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/README.md +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/analyze_results.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/benchmark_datasets.txt +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/prepare_submission.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/quick_test.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/run_vidore.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/vidore_tatdqa_test/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/vidore_tatdqa_test/dataset_loader.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/vidore_tatdqa_test/metrics.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/vidore_tatdqa_test/run_qdrant.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/benchmarks/vidore_tatdqa_test/sweep_eval.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/commands.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/config.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/download_models.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/results.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/test_qdrant_connection.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/ui/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/ui/benchmark.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/demo/ui/header.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/requirements.txt +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/tests/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/tests/test_config.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/tests/test_pdf_processor.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/tests/test_pooling.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/tests/test_strategies.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/cli/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/cli/main.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/embedding/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/embedding/pooling.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/embedding/visual_embedder.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/indexing/cloudinary_uploader.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/indexing/pdf_processor.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/indexing/pipeline.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/preprocessing/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/preprocessing/crop_empty.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/qdrant_admin.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/retrieval/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/retrieval/three_stage.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/visualization/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.4}/visual_rag/visualization/saliency.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: visual-rag-toolkit
|
|
3
|
-
Version: 0.1.2
|
|
3
|
+
Version: 0.1.4
|
|
4
4
|
Summary: End-to-end visual document retrieval with ColPali, featuring two-stage pooling for scalable search
|
|
5
5
|
Project-URL: Homepage, https://github.com/Ara-Yeroyan/visual-rag-toolkit
|
|
6
6
|
Project-URL: Documentation, https://github.com/Ara-Yeroyan/visual-rag-toolkit#readme
|
|
@@ -88,14 +88,12 @@ Description-Content-Type: text/markdown
|
|
|
88
88
|
[](https://pypi.org/project/visual-rag-toolkit/)
|
|
89
89
|
[](https://pypi.org/project/visual-rag-toolkit/)
|
|
90
90
|
[](LICENSE)
|
|
91
|
-
[](https://huggingface.co/spaces/Yeroyan/visual-rag-toolkit)
|
|
96
92
|
|
|
97
93
|
End-to-end visual document retrieval toolkit featuring **fast multi-stage retrieval** (prefetch with pooled vectors + exact MaxSim reranking).
|
|
98
94
|
|
|
95
|
+
**[Try the Live Demo](https://huggingface.co/spaces/Yeroyan/visual-rag-toolkit)** - Upload PDFs, index to Qdrant, and query with visual retrieval.
|
|
96
|
+
|
|
99
97
|
This repo contains:
|
|
100
98
|
- a **Python package** (`visual_rag`)
|
|
101
99
|
- a **Streamlit demo app** (`demo/`)
|
|
@@ -162,7 +160,7 @@ for r in results[:3]:
|
|
|
162
160
|
|
|
163
161
|
### End-to-end: ingest PDFs (with cropping) → index in Qdrant
|
|
164
162
|
|
|
165
|
-
This is the
|
|
163
|
+
This is the "SDK-style" pipeline: PDF → images → optional crop → embed → store vectors + payload in Qdrant.
|
|
166
164
|
|
|
167
165
|
```python
|
|
168
166
|
import os
|
|
@@ -174,8 +172,8 @@ import torch
|
|
|
174
172
|
from visual_rag import VisualEmbedder
|
|
175
173
|
from visual_rag.indexing import ProcessingPipeline, QdrantIndexer
|
|
176
174
|
|
|
177
|
-
QDRANT_URL = os.environ["
|
|
178
|
-
QDRANT_KEY = os.getenv("
|
|
175
|
+
QDRANT_URL = os.environ["QDRANT_URL"]
|
|
176
|
+
QDRANT_KEY = os.getenv("QDRANT_API_KEY", "")
|
|
179
177
|
|
|
180
178
|
collection = "my_visual_docs"
|
|
181
179
|
|
|
@@ -193,6 +191,8 @@ indexer = QdrantIndexer(
|
|
|
193
191
|
prefer_grpc=True,
|
|
194
192
|
vector_datatype="float16",
|
|
195
193
|
)
|
|
194
|
+
|
|
195
|
+
# Creates collection + required payload indexes (e.g., "filename" for skip_existing)
|
|
196
196
|
indexer.create_collection(force_recreate=False)
|
|
197
197
|
|
|
198
198
|
pipeline = ProcessingPipeline(
|
|
@@ -208,19 +208,32 @@ pipeline = ProcessingPipeline(
|
|
|
208
208
|
|
|
209
209
|
pdfs = [Path("docs/a.pdf"), Path("docs/b.pdf")]
|
|
210
210
|
for pdf_path in pdfs:
|
|
211
|
-
pipeline.process_pdf(
|
|
211
|
+
result = pipeline.process_pdf(
|
|
212
212
|
pdf_path,
|
|
213
|
-
skip_existing=True,
|
|
213
|
+
skip_existing=True, # Skip pages already in Qdrant (uses filename index)
|
|
214
214
|
upload_to_cloudinary=False,
|
|
215
215
|
upload_to_qdrant=True,
|
|
216
216
|
)
|
|
217
|
+
# Logs automatically shown:
|
|
218
|
+
# [10:23:45] 📚 Processing PDF: a.pdf
|
|
219
|
+
# [10:23:45] 🖼️ Converting PDF to images...
|
|
220
|
+
# [10:23:46] ✅ Converted 12 pages
|
|
221
|
+
# [10:23:46] 📦 Processing pages 1-8/12
|
|
222
|
+
# [10:23:46] 🤖 Generating embeddings for 8 pages...
|
|
223
|
+
# [10:23:48] 📤 Uploading batch of 8 pages...
|
|
224
|
+
# [10:23:48] ✅ Uploaded 8 points to Qdrant
|
|
225
|
+
# [10:23:48] 📦 Processing pages 9-12/12
|
|
226
|
+
# [10:23:48] 🤖 Generating embeddings for 4 pages...
|
|
227
|
+
# [10:23:50] 📤 Uploading batch of 4 pages...
|
|
228
|
+
# [10:23:50] ✅ Uploaded 4 points to Qdrant
|
|
229
|
+
# [10:23:50] ✅ Completed a.pdf: 12 uploaded, 0 skipped, 0 failed
|
|
217
230
|
```
|
|
218
231
|
|
|
219
232
|
CLI equivalent:
|
|
220
233
|
|
|
221
234
|
```bash
|
|
222
|
-
export
|
|
223
|
-
export
|
|
235
|
+
export QDRANT_URL="https://YOUR_QDRANT"
|
|
236
|
+
export QDRANT_API_KEY="YOUR_KEY"
|
|
224
237
|
|
|
225
238
|
visual-rag process \
|
|
226
239
|
--reports-dir ./docs \
|
|
@@ -263,7 +276,7 @@ Stage 2: Exact MaxSim reranking on candidates
|
|
|
263
276
|
└── Return top-k results (e.g., 10)
|
|
264
277
|
```
|
|
265
278
|
|
|
266
|
-
Three-stage extends this with an additional
|
|
279
|
+
Three-stage extends this with an additional "cheap prefetch" stage before stage 2.
|
|
267
280
|
|
|
268
281
|
## 📁 Package Structure
|
|
269
282
|
|
|
@@ -358,7 +371,7 @@ If you use this toolkit in your research, please cite:
|
|
|
358
371
|
|
|
359
372
|
```bibtex
|
|
360
373
|
@software{visual_rag_toolkit,
|
|
361
|
-
title = {Visual RAG Toolkit: Scalable Visual Document Retrieval with
|
|
374
|
+
title = {Visual RAG Toolkit: Scalable Visual Document Retrieval with 1D Convolutional Pooling},
|
|
362
375
|
author = {Ara Yeroyan},
|
|
363
376
|
year = {2026},
|
|
364
377
|
url = {https://github.com/Ara-Yeroyan/visual-rag-toolkit}
|
|
@@ -374,4 +387,3 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|
|
374
387
|
- [Qdrant](https://qdrant.tech/) - Vector database with multi-vector support
|
|
375
388
|
- [ColPali](https://github.com/illuin-tech/colpali) - Visual document retrieval models
|
|
376
389
|
- [ViDoRe](https://huggingface.co/spaces/vidore/vidore-leaderboard) - Benchmark dataset
|
|
377
|
-
|
|
@@ -3,14 +3,12 @@
|
|
|
3
3
|
[](https://pypi.org/project/visual-rag-toolkit/)
|
|
4
4
|
[](https://pypi.org/project/visual-rag-toolkit/)
|
|
5
5
|
[](LICENSE)
|
|
6
|
-
[](https://huggingface.co/spaces/Yeroyan/visual-rag-toolkit)
|
|
11
7
|
|
|
12
8
|
End-to-end visual document retrieval toolkit featuring **fast multi-stage retrieval** (prefetch with pooled vectors + exact MaxSim reranking).
|
|
13
9
|
|
|
10
|
+
**[Try the Live Demo](https://huggingface.co/spaces/Yeroyan/visual-rag-toolkit)** - Upload PDFs, index to Qdrant, and query with visual retrieval.
|
|
11
|
+
|
|
14
12
|
This repo contains:
|
|
15
13
|
- a **Python package** (`visual_rag`)
|
|
16
14
|
- a **Streamlit demo app** (`demo/`)
|
|
@@ -77,7 +75,7 @@ for r in results[:3]:
|
|
|
77
75
|
|
|
78
76
|
### End-to-end: ingest PDFs (with cropping) → index in Qdrant
|
|
79
77
|
|
|
80
|
-
This is the
|
|
78
|
+
This is the "SDK-style" pipeline: PDF → images → optional crop → embed → store vectors + payload in Qdrant.
|
|
81
79
|
|
|
82
80
|
```python
|
|
83
81
|
import os
|
|
@@ -89,8 +87,8 @@ import torch
|
|
|
89
87
|
from visual_rag import VisualEmbedder
|
|
90
88
|
from visual_rag.indexing import ProcessingPipeline, QdrantIndexer
|
|
91
89
|
|
|
92
|
-
QDRANT_URL = os.environ["
|
|
93
|
-
QDRANT_KEY = os.getenv("
|
|
90
|
+
QDRANT_URL = os.environ["QDRANT_URL"]
|
|
91
|
+
QDRANT_KEY = os.getenv("QDRANT_API_KEY", "")
|
|
94
92
|
|
|
95
93
|
collection = "my_visual_docs"
|
|
96
94
|
|
|
@@ -108,6 +106,8 @@ indexer = QdrantIndexer(
|
|
|
108
106
|
prefer_grpc=True,
|
|
109
107
|
vector_datatype="float16",
|
|
110
108
|
)
|
|
109
|
+
|
|
110
|
+
# Creates collection + required payload indexes (e.g., "filename" for skip_existing)
|
|
111
111
|
indexer.create_collection(force_recreate=False)
|
|
112
112
|
|
|
113
113
|
pipeline = ProcessingPipeline(
|
|
@@ -123,19 +123,32 @@ pipeline = ProcessingPipeline(
|
|
|
123
123
|
|
|
124
124
|
pdfs = [Path("docs/a.pdf"), Path("docs/b.pdf")]
|
|
125
125
|
for pdf_path in pdfs:
|
|
126
|
-
pipeline.process_pdf(
|
|
126
|
+
result = pipeline.process_pdf(
|
|
127
127
|
pdf_path,
|
|
128
|
-
skip_existing=True,
|
|
128
|
+
skip_existing=True, # Skip pages already in Qdrant (uses filename index)
|
|
129
129
|
upload_to_cloudinary=False,
|
|
130
130
|
upload_to_qdrant=True,
|
|
131
131
|
)
|
|
132
|
+
# Logs automatically shown:
|
|
133
|
+
# [10:23:45] 📚 Processing PDF: a.pdf
|
|
134
|
+
# [10:23:45] 🖼️ Converting PDF to images...
|
|
135
|
+
# [10:23:46] ✅ Converted 12 pages
|
|
136
|
+
# [10:23:46] 📦 Processing pages 1-8/12
|
|
137
|
+
# [10:23:46] 🤖 Generating embeddings for 8 pages...
|
|
138
|
+
# [10:23:48] 📤 Uploading batch of 8 pages...
|
|
139
|
+
# [10:23:48] ✅ Uploaded 8 points to Qdrant
|
|
140
|
+
# [10:23:48] 📦 Processing pages 9-12/12
|
|
141
|
+
# [10:23:48] 🤖 Generating embeddings for 4 pages...
|
|
142
|
+
# [10:23:50] 📤 Uploading batch of 4 pages...
|
|
143
|
+
# [10:23:50] ✅ Uploaded 4 points to Qdrant
|
|
144
|
+
# [10:23:50] ✅ Completed a.pdf: 12 uploaded, 0 skipped, 0 failed
|
|
132
145
|
```
|
|
133
146
|
|
|
134
147
|
CLI equivalent:
|
|
135
148
|
|
|
136
149
|
```bash
|
|
137
|
-
export
|
|
138
|
-
export
|
|
150
|
+
export QDRANT_URL="https://YOUR_QDRANT"
|
|
151
|
+
export QDRANT_API_KEY="YOUR_KEY"
|
|
139
152
|
|
|
140
153
|
visual-rag process \
|
|
141
154
|
--reports-dir ./docs \
|
|
@@ -178,7 +191,7 @@ Stage 2: Exact MaxSim reranking on candidates
|
|
|
178
191
|
└── Return top-k results (e.g., 10)
|
|
179
192
|
```
|
|
180
193
|
|
|
181
|
-
Three-stage extends this with an additional
|
|
194
|
+
Three-stage extends this with an additional "cheap prefetch" stage before stage 2.
|
|
182
195
|
|
|
183
196
|
## 📁 Package Structure
|
|
184
197
|
|
|
@@ -273,7 +286,7 @@ If you use this toolkit in your research, please cite:
|
|
|
273
286
|
|
|
274
287
|
```bibtex
|
|
275
288
|
@software{visual_rag_toolkit,
|
|
276
|
-
title = {Visual RAG Toolkit: Scalable Visual Document Retrieval with
|
|
289
|
+
title = {Visual RAG Toolkit: Scalable Visual Document Retrieval with 1D Convolutional Pooling},
|
|
277
290
|
author = {Ara Yeroyan},
|
|
278
291
|
year = {2026},
|
|
279
292
|
url = {https://github.com/Ara-Yeroyan/visual-rag-toolkit}
|
|
@@ -289,4 +302,3 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|
|
289
302
|
- [Qdrant](https://qdrant.tech/) - Vector database with multi-vector support
|
|
290
303
|
- [ColPali](https://github.com/illuin-tech/colpali) - Visual document retrieval models
|
|
291
304
|
- [ViDoRe](https://huggingface.co/spaces/vidore/vidore-leaderboard) - Benchmark dataset
|
|
292
|
-
|
|
@@ -1,13 +1,23 @@
|
|
|
1
1
|
"""Main entry point for the Visual RAG Toolkit demo application."""
|
|
2
2
|
|
|
3
|
+
import os
|
|
3
4
|
import sys
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
|
|
6
|
-
|
|
7
|
-
|
|
7
|
+
# Ensure repo root is in sys.path for local development
|
|
8
|
+
# (In HF Space / Docker, PYTHONPATH is already set correctly)
|
|
9
|
+
_app_dir = Path(__file__).resolve().parent
|
|
10
|
+
_repo_root = _app_dir.parent
|
|
11
|
+
if str(_repo_root) not in sys.path:
|
|
12
|
+
sys.path.insert(0, str(_repo_root))
|
|
8
13
|
|
|
9
14
|
from dotenv import load_dotenv
|
|
10
|
-
|
|
15
|
+
|
|
16
|
+
# Load .env from the repo root (works both locally and in Docker)
|
|
17
|
+
if (_repo_root / ".env").exists():
|
|
18
|
+
load_dotenv(_repo_root / ".env")
|
|
19
|
+
if (_app_dir / ".env").exists():
|
|
20
|
+
load_dotenv(_app_dir / ".env")
|
|
11
21
|
|
|
12
22
|
import streamlit as st
|
|
13
23
|
|
|
@@ -28,15 +38,17 @@ from demo.ui.benchmark import render_benchmark_tab
|
|
|
28
38
|
def main():
|
|
29
39
|
render_header()
|
|
30
40
|
render_sidebar()
|
|
31
|
-
|
|
32
|
-
tab_upload, tab_playground, tab_benchmark = st.tabs(
|
|
33
|
-
|
|
41
|
+
|
|
42
|
+
tab_upload, tab_playground, tab_benchmark = st.tabs(
|
|
43
|
+
["📤 Upload", "🎮 Playground", "📊 Benchmarking"]
|
|
44
|
+
)
|
|
45
|
+
|
|
34
46
|
with tab_upload:
|
|
35
47
|
render_upload_tab()
|
|
36
|
-
|
|
48
|
+
|
|
37
49
|
with tab_playground:
|
|
38
50
|
render_playground_tab()
|
|
39
|
-
|
|
51
|
+
|
|
40
52
|
with tab_benchmark:
|
|
41
53
|
render_benchmark_tab()
|
|
42
54
|
|
|
@@ -1,20 +1,23 @@
|
|
|
1
1
|
"""Evaluation runner with UI updates."""
|
|
2
2
|
|
|
3
3
|
import hashlib
|
|
4
|
-
import importlib.util
|
|
5
4
|
import json
|
|
6
5
|
import logging
|
|
7
6
|
import time
|
|
8
7
|
import traceback
|
|
9
8
|
from datetime import datetime
|
|
10
|
-
from pathlib import Path
|
|
11
9
|
from typing import Any, Dict, List, Optional
|
|
12
10
|
|
|
13
11
|
import numpy as np
|
|
14
12
|
import streamlit as st
|
|
15
13
|
import torch
|
|
14
|
+
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
|
16
15
|
|
|
17
16
|
from visual_rag import VisualEmbedder
|
|
17
|
+
from visual_rag.retrieval import MultiVectorRetriever
|
|
18
|
+
from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
|
|
19
|
+
from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
|
|
20
|
+
from demo.qdrant_utils import get_qdrant_credentials
|
|
18
21
|
|
|
19
22
|
|
|
20
23
|
TORCH_DTYPE_MAP = {
|
|
@@ -22,49 +25,6 @@ TORCH_DTYPE_MAP = {
|
|
|
22
25
|
"float32": torch.float32,
|
|
23
26
|
"bfloat16": torch.bfloat16,
|
|
24
27
|
}
|
|
25
|
-
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
|
26
|
-
|
|
27
|
-
from visual_rag.retrieval import MultiVectorRetriever
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def _load_local_benchmark_module(module_filename: str):
|
|
31
|
-
"""
|
|
32
|
-
Load `benchmarks/vidore_tatdqa_test/<module_filename>` via file path.
|
|
33
|
-
|
|
34
|
-
Motivation:
|
|
35
|
-
- Some environments (notably containers / Spaces) can have a third-party
|
|
36
|
-
`benchmarks` package installed, causing `import benchmarks...` to resolve
|
|
37
|
-
to the wrong module.
|
|
38
|
-
- This fallback guarantees we load the repo's benchmark utilities.
|
|
39
|
-
"""
|
|
40
|
-
root = Path(__file__).resolve().parents[1] # demo/.. = repo root
|
|
41
|
-
target = root / "benchmarks" / "vidore_tatdqa_test" / module_filename
|
|
42
|
-
if not target.exists():
|
|
43
|
-
raise ModuleNotFoundError(f"Missing local benchmark module file: {target}")
|
|
44
|
-
|
|
45
|
-
name = f"_visual_rag_toolkit_local_{target.stem}"
|
|
46
|
-
spec = importlib.util.spec_from_file_location(name, str(target))
|
|
47
|
-
if spec is None or spec.loader is None:
|
|
48
|
-
raise ModuleNotFoundError(f"Could not load module spec for: {target}")
|
|
49
|
-
mod = importlib.util.module_from_spec(spec)
|
|
50
|
-
spec.loader.exec_module(mod) # type: ignore[attr-defined]
|
|
51
|
-
return mod
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
try:
|
|
55
|
-
# Preferred: normal import
|
|
56
|
-
from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
|
|
57
|
-
from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
|
|
58
|
-
except ModuleNotFoundError:
|
|
59
|
-
# Robust fallback: load from local file paths
|
|
60
|
-
_dl = _load_local_benchmark_module("dataset_loader.py")
|
|
61
|
-
_mx = _load_local_benchmark_module("metrics.py")
|
|
62
|
-
load_vidore_beir_dataset = _dl.load_vidore_beir_dataset
|
|
63
|
-
ndcg_at_k = _mx.ndcg_at_k
|
|
64
|
-
mrr_at_k = _mx.mrr_at_k
|
|
65
|
-
recall_at_k = _mx.recall_at_k
|
|
66
|
-
|
|
67
|
-
from demo.qdrant_utils import get_qdrant_credentials
|
|
68
28
|
|
|
69
29
|
logger = logging.getLogger(__name__)
|
|
70
30
|
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
|
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
"""Indexing runner with UI updates."""
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
import time
|
|
6
|
+
import traceback
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from typing import Any, Dict, Optional
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import streamlit as st
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
from visual_rag import VisualEmbedder
|
|
15
|
+
from visual_rag.embedding.pooling import tile_level_mean_pooling
|
|
16
|
+
from visual_rag.indexing.qdrant_indexer import QdrantIndexer
|
|
17
|
+
from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
|
|
18
|
+
from demo.qdrant_utils import get_qdrant_credentials
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
TORCH_DTYPE_MAP = {
|
|
22
|
+
"float16": torch.float16,
|
|
23
|
+
"float32": torch.float32,
|
|
24
|
+
"bfloat16": torch.bfloat16,
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _stable_uuid(text: str) -> str:
|
|
29
|
+
"""Generate a stable UUID from text (same as benchmark script)."""
|
|
30
|
+
hex_str = hashlib.sha256(text.encode("utf-8")).hexdigest()[:32]
|
|
31
|
+
return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _union_point_id(
|
|
35
|
+
*, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]
|
|
36
|
+
) -> str:
|
|
37
|
+
"""Generate union point ID (same as benchmark script)."""
|
|
38
|
+
ns = f"{union_namespace}::{dataset_name}" if union_namespace else dataset_name
|
|
39
|
+
return _stable_uuid(f"{ns}::{source_doc_id}")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def run_indexing_with_ui(config: Dict[str, Any]):
|
|
43
|
+
st.divider()
|
|
44
|
+
|
|
45
|
+
print("=" * 60)
|
|
46
|
+
print("[INDEX] Starting indexing via UI")
|
|
47
|
+
print("=" * 60)
|
|
48
|
+
|
|
49
|
+
url, api_key = get_qdrant_credentials()
|
|
50
|
+
if not url:
|
|
51
|
+
st.error("QDRANT_URL not configured")
|
|
52
|
+
return
|
|
53
|
+
|
|
54
|
+
datasets = config.get("datasets", [])
|
|
55
|
+
collection = config["collection"]
|
|
56
|
+
model = config.get("model", "vidore/colpali-v1.3")
|
|
57
|
+
recreate = config.get("recreate", False)
|
|
58
|
+
torch_dtype = config.get("torch_dtype", "float16")
|
|
59
|
+
qdrant_vector_dtype = config.get("qdrant_vector_dtype", "float16")
|
|
60
|
+
prefer_grpc = config.get("prefer_grpc", True)
|
|
61
|
+
batch_size = config.get("batch_size", 4)
|
|
62
|
+
max_docs = config.get("max_docs")
|
|
63
|
+
|
|
64
|
+
print(f"[INDEX] Config: collection={collection}, model={model}")
|
|
65
|
+
print(f"[INDEX] Datasets: {datasets}")
|
|
66
|
+
print(
|
|
67
|
+
f"[INDEX] max_docs={max_docs}, batch_size={batch_size}, recreate={recreate}"
|
|
68
|
+
)
|
|
69
|
+
print(
|
|
70
|
+
f"[INDEX] torch_dtype={torch_dtype}, qdrant_dtype={qdrant_vector_dtype}, grpc={prefer_grpc}"
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
phase1_container = st.container()
|
|
74
|
+
phase2_container = st.container()
|
|
75
|
+
phase3_container = st.container()
|
|
76
|
+
results_container = st.container()
|
|
77
|
+
|
|
78
|
+
try:
|
|
79
|
+
with phase1_container:
|
|
80
|
+
st.markdown("##### 🤖 Phase 1: Loading Model")
|
|
81
|
+
model_status = st.empty()
|
|
82
|
+
model_status.info(f"Loading `{model.split('/')[-1]}`...")
|
|
83
|
+
|
|
84
|
+
print(f"[INDEX] Loading embedder: {model}")
|
|
85
|
+
torch_dtype_obj = TORCH_DTYPE_MAP.get(torch_dtype, torch.float16)
|
|
86
|
+
output_dtype_obj = (
|
|
87
|
+
np.float16 if qdrant_vector_dtype == "float16" else np.float32
|
|
88
|
+
)
|
|
89
|
+
embedder = VisualEmbedder(
|
|
90
|
+
model_name=model,
|
|
91
|
+
torch_dtype=torch_dtype_obj,
|
|
92
|
+
output_dtype=output_dtype_obj,
|
|
93
|
+
)
|
|
94
|
+
embedder._load_model()
|
|
95
|
+
print(
|
|
96
|
+
f"[INDEX] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_vector_dtype})"
|
|
97
|
+
)
|
|
98
|
+
model_status.success(f"✅ Model `{model.split('/')[-1]}` loaded")
|
|
99
|
+
|
|
100
|
+
with phase2_container:
|
|
101
|
+
st.markdown("##### 📦 Phase 2: Setting Up Collection")
|
|
102
|
+
|
|
103
|
+
indexer_status = st.empty()
|
|
104
|
+
indexer_status.info("Connecting to Qdrant...")
|
|
105
|
+
|
|
106
|
+
print("[INDEX] Connecting to Qdrant...")
|
|
107
|
+
indexer = QdrantIndexer(
|
|
108
|
+
url=url,
|
|
109
|
+
api_key=api_key,
|
|
110
|
+
collection_name=collection,
|
|
111
|
+
prefer_grpc=prefer_grpc,
|
|
112
|
+
vector_datatype=qdrant_vector_dtype,
|
|
113
|
+
)
|
|
114
|
+
print("[INDEX] Connected to Qdrant")
|
|
115
|
+
indexer_status.success("✅ Connected to Qdrant")
|
|
116
|
+
|
|
117
|
+
coll_status = st.empty()
|
|
118
|
+
action = "Recreating" if recreate else "Creating/verifying"
|
|
119
|
+
coll_status.info(f"{action} collection `{collection}`...")
|
|
120
|
+
|
|
121
|
+
print(f"[INDEX] {action} collection: {collection}")
|
|
122
|
+
indexer.create_collection(force_recreate=recreate)
|
|
123
|
+
indexer.create_payload_indexes(
|
|
124
|
+
fields=[
|
|
125
|
+
{"field": "dataset", "type": "keyword"},
|
|
126
|
+
{"field": "doc_id", "type": "keyword"},
|
|
127
|
+
{"field": "source_doc_id", "type": "keyword"},
|
|
128
|
+
]
|
|
129
|
+
)
|
|
130
|
+
print("[INDEX] Collection ready")
|
|
131
|
+
coll_status.success(f"✅ Collection `{collection}` ready")
|
|
132
|
+
|
|
133
|
+
with phase3_container:
|
|
134
|
+
st.markdown("##### 📊 Phase 3: Processing Datasets")
|
|
135
|
+
|
|
136
|
+
all_results = []
|
|
137
|
+
|
|
138
|
+
for ds_idx, dataset_name in enumerate(datasets):
|
|
139
|
+
ds_short = dataset_name.split("/")[-1]
|
|
140
|
+
ds_container = st.container()
|
|
141
|
+
|
|
142
|
+
with ds_container:
|
|
143
|
+
st.markdown(
|
|
144
|
+
f"**Dataset {ds_idx + 1}/{len(datasets)}: `{ds_short}`**"
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
load_status = st.empty()
|
|
148
|
+
load_status.info(f"Loading dataset `{ds_short}`...")
|
|
149
|
+
|
|
150
|
+
print(f"[INDEX] Loading dataset: {dataset_name}")
|
|
151
|
+
corpus, queries, qrels = load_vidore_beir_dataset(dataset_name)
|
|
152
|
+
total_docs = len(corpus)
|
|
153
|
+
print(f"[INDEX] Dataset loaded: {total_docs} docs")
|
|
154
|
+
load_status.success(f"✅ Loaded {total_docs:,} documents")
|
|
155
|
+
|
|
156
|
+
if max_docs and max_docs < total_docs:
|
|
157
|
+
corpus = corpus[:max_docs]
|
|
158
|
+
print(f"[INDEX] Limiting to {max_docs} docs")
|
|
159
|
+
|
|
160
|
+
progress_bar = st.progress(0)
|
|
161
|
+
status_text = st.empty()
|
|
162
|
+
|
|
163
|
+
uploaded = 0
|
|
164
|
+
failed = 0
|
|
165
|
+
total = len(corpus)
|
|
166
|
+
|
|
167
|
+
for i, doc in enumerate(corpus):
|
|
168
|
+
try:
|
|
169
|
+
doc_id = str(doc.doc_id)
|
|
170
|
+
image = doc.image
|
|
171
|
+
if image is None:
|
|
172
|
+
failed += 1
|
|
173
|
+
continue
|
|
174
|
+
|
|
175
|
+
status_text.text(
|
|
176
|
+
f"Processing {i + 1}/{total}: {doc_id[:30]}..."
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
embeddings, token_infos = embedder.embed_images(
|
|
180
|
+
[image],
|
|
181
|
+
return_token_info=True,
|
|
182
|
+
show_progress=False,
|
|
183
|
+
)
|
|
184
|
+
emb = embeddings[0]
|
|
185
|
+
token_info = token_infos[0] if token_infos else {}
|
|
186
|
+
|
|
187
|
+
if hasattr(emb, "cpu"):
|
|
188
|
+
emb = emb.cpu()
|
|
189
|
+
emb_np = np.asarray(emb, dtype=output_dtype_obj)
|
|
190
|
+
|
|
191
|
+
initial = emb_np.tolist()
|
|
192
|
+
global_pool = emb_np.mean(axis=0).tolist()
|
|
193
|
+
|
|
194
|
+
num_tiles = token_info.get("num_tiles")
|
|
195
|
+
mean_pooling = None
|
|
196
|
+
experimental_pooling = None
|
|
197
|
+
|
|
198
|
+
if num_tiles and num_tiles > 0:
|
|
199
|
+
try:
|
|
200
|
+
mean_pooling = tile_level_mean_pooling(
|
|
201
|
+
emb_np, num_tiles=num_tiles, patches_per_tile=64
|
|
202
|
+
).tolist()
|
|
203
|
+
except Exception:
|
|
204
|
+
pass
|
|
205
|
+
|
|
206
|
+
try:
|
|
207
|
+
exp_pool = embedder.experimental_pool_visual_embedding(
|
|
208
|
+
emb_np, num_tiles=num_tiles
|
|
209
|
+
)
|
|
210
|
+
if exp_pool is not None:
|
|
211
|
+
experimental_pooling = exp_pool.tolist()
|
|
212
|
+
except Exception:
|
|
213
|
+
pass
|
|
214
|
+
|
|
215
|
+
union_doc_id = _union_point_id(
|
|
216
|
+
dataset_name=dataset_name,
|
|
217
|
+
source_doc_id=doc_id,
|
|
218
|
+
union_namespace=collection,
|
|
219
|
+
)
|
|
220
|
+
|
|
221
|
+
payload = {
|
|
222
|
+
"dataset": dataset_name,
|
|
223
|
+
"doc_id": doc_id,
|
|
224
|
+
"source_doc_id": doc_id,
|
|
225
|
+
"union_doc_id": union_doc_id,
|
|
226
|
+
"num_tiles": num_tiles,
|
|
227
|
+
"num_visual_tokens": token_info.get("num_visual_tokens"),
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
vectors = {"initial": initial, "global_pooling": global_pool}
|
|
231
|
+
if mean_pooling:
|
|
232
|
+
vectors["mean_pooling"] = mean_pooling
|
|
233
|
+
if experimental_pooling:
|
|
234
|
+
vectors["experimental_pooling"] = experimental_pooling
|
|
235
|
+
|
|
236
|
+
indexer.upsert_point(
|
|
237
|
+
point_id=union_doc_id,
|
|
238
|
+
vectors=vectors,
|
|
239
|
+
payload=payload,
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
uploaded += 1
|
|
243
|
+
|
|
244
|
+
except Exception as e:
|
|
245
|
+
print(f"[INDEX] Error on doc {i}: {e}")
|
|
246
|
+
failed += 1
|
|
247
|
+
|
|
248
|
+
progress_bar.progress((i + 1) / total)
|
|
249
|
+
|
|
250
|
+
status_text.text(f"✅ Done: {uploaded} uploaded, {failed} failed")
|
|
251
|
+
all_results.append(
|
|
252
|
+
{
|
|
253
|
+
"dataset": dataset_name,
|
|
254
|
+
"total": total,
|
|
255
|
+
"uploaded": uploaded,
|
|
256
|
+
"failed": failed,
|
|
257
|
+
}
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
with results_container:
|
|
261
|
+
st.markdown("##### 📋 Results Summary")
|
|
262
|
+
|
|
263
|
+
for r in all_results:
|
|
264
|
+
st.write(
|
|
265
|
+
f"**{r['dataset'].split('/')[-1]}**: {r['uploaded']:,} uploaded, {r['failed']:,} failed"
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
st.success("✅ Indexing complete!")
|
|
269
|
+
|
|
270
|
+
except Exception as e:
|
|
271
|
+
st.error(f"Indexing error: {e}")
|
|
272
|
+
st.code(traceback.format_exc())
|
|
273
|
+
print(f"[INDEX] ERROR: {e}")
|
|
274
|
+
traceback.print_exc()
|
|
@@ -8,12 +8,19 @@ import streamlit as st
|
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
def get_qdrant_credentials() -> Tuple[Optional[str], Optional[str]]:
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
11
|
+
"""Get Qdrant credentials from session state or environment variables.
|
|
12
|
+
|
|
13
|
+
Priority: session_state > QDRANT_URL/QDRANT_API_KEY > legacy env vars
|
|
14
|
+
"""
|
|
15
|
+
url = (
|
|
16
|
+
st.session_state.get("qdrant_url_input")
|
|
17
|
+
or os.getenv("QDRANT_URL")
|
|
18
|
+
or os.getenv("SIGIR_QDRANT_URL") # legacy
|
|
19
|
+
)
|
|
20
|
+
api_key = (
|
|
21
|
+
st.session_state.get("qdrant_key_input")
|
|
16
22
|
or os.getenv("QDRANT_API_KEY")
|
|
23
|
+
or os.getenv("SIGIR_QDRANT_KEY") # legacy
|
|
17
24
|
)
|
|
18
25
|
return url, api_key
|
|
19
26
|
|
|
@@ -9,6 +9,7 @@ from demo.qdrant_utils import (
|
|
|
9
9
|
sample_points_cached,
|
|
10
10
|
search_collection,
|
|
11
11
|
)
|
|
12
|
+
from visual_rag.retrieval import MultiVectorRetriever
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
def render_playground_tab():
|
|
@@ -46,7 +47,6 @@ def render_playground_tab():
|
|
|
46
47
|
if not st.session_state.get("model_loaded"):
|
|
47
48
|
with st.spinner(f"Loading {model_short}..."):
|
|
48
49
|
try:
|
|
49
|
-
from visual_rag.retrieval import MultiVectorRetriever
|
|
50
50
|
_ = MultiVectorRetriever(collection_name=active_collection, model_name=model_name)
|
|
51
51
|
st.session_state["model_loaded"] = True
|
|
52
52
|
st.session_state["loaded_model_key"] = cache_key
|