visual-rag-toolkit 0.1.2__tar.gz → 0.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/PKG-INFO +24 -15
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/README.md +23 -14
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/app.py +20 -8
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/evaluation.py +5 -45
- visual_rag_toolkit-0.1.3/demo/indexing.py +274 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/qdrant_utils.py +12 -5
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/ui/playground.py +1 -1
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/ui/sidebar.py +4 -3
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/ui/upload.py +5 -4
- visual_rag_toolkit-0.1.3/examples/COMMANDS.md +83 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/examples/config.yaml +6 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/examples/process_pdfs.py +6 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/examples/search_demo.py +6 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/pyproject.toml +1 -1
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/__init__.py +43 -1
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/config.py +4 -7
- visual_rag_toolkit-0.1.3/visual_rag/indexing/__init__.py +38 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/indexing/qdrant_indexer.py +92 -42
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/retrieval/multi_vector.py +63 -65
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/retrieval/single_stage.py +7 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/retrieval/two_stage.py +8 -10
- visual_rag_toolkit-0.1.2/demo/indexing.py +0 -315
- visual_rag_toolkit-0.1.2/visual_rag/indexing/__init__.py +0 -21
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/.github/workflows/ci.yaml +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/.github/workflows/publish_pypi.yaml +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/.gitignore +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/LICENSE +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/README.md +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/analyze_results.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/benchmark_datasets.txt +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/prepare_submission.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/quick_test.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/run_vidore.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/vidore_tatdqa_test/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/vidore_tatdqa_test/dataset_loader.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/vidore_tatdqa_test/metrics.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/vidore_tatdqa_test/run_qdrant.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/vidore_tatdqa_test/sweep_eval.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/commands.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/config.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/download_models.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/example_metadata_mapping_sigir.json +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/results.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/test_qdrant_connection.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/ui/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/ui/benchmark.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/ui/header.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/requirements.txt +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/tests/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/tests/test_config.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/tests/test_pdf_processor.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/tests/test_pooling.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/tests/test_strategies.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/cli/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/cli/main.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/demo_runner.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/embedding/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/embedding/pooling.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/embedding/visual_embedder.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/indexing/cloudinary_uploader.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/indexing/pdf_processor.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/indexing/pipeline.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/preprocessing/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/preprocessing/crop_empty.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/qdrant_admin.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/retrieval/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/retrieval/three_stage.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/visualization/__init__.py +0 -0
- {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/visualization/saliency.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: visual-rag-toolkit
|
|
3
|
-
Version: 0.1.2
|
|
3
|
+
Version: 0.1.3
|
|
4
4
|
Summary: End-to-end visual document retrieval with ColPali, featuring two-stage pooling for scalable search
|
|
5
5
|
Project-URL: Homepage, https://github.com/Ara-Yeroyan/visual-rag-toolkit
|
|
6
6
|
Project-URL: Documentation, https://github.com/Ara-Yeroyan/visual-rag-toolkit#readme
|
|
@@ -88,11 +88,6 @@ Description-Content-Type: text/markdown
|
|
|
88
88
|
[](https://pypi.org/project/visual-rag-toolkit/)
|
|
89
89
|
[](https://pypi.org/project/visual-rag-toolkit/)
|
|
90
90
|
[](LICENSE)
|
|
91
|
-
[](https://github.com/Ara-Yeroyan/visual-rag-toolkit/actions/workflows/ci.yaml)
|
|
92
|
-
|
|
93
|
-
Note:
|
|
94
|
-
- The **PyPI badge** shows “not found” until the first release is published.
|
|
95
|
-
- The **CI badge** requires the GitHub repo to be **public** (GitHub does not serve Actions badges for private repos).
|
|
96
91
|
|
|
97
92
|
End-to-end visual document retrieval toolkit featuring **fast multi-stage retrieval** (prefetch with pooled vectors + exact MaxSim reranking).
|
|
98
93
|
|
|
@@ -162,7 +157,7 @@ for r in results[:3]:
|
|
|
162
157
|
|
|
163
158
|
### End-to-end: ingest PDFs (with cropping) → index in Qdrant
|
|
164
159
|
|
|
165
|
-
This is the
|
|
160
|
+
This is the "SDK-style" pipeline: PDF → images → optional crop → embed → store vectors + payload in Qdrant.
|
|
166
161
|
|
|
167
162
|
```python
|
|
168
163
|
import os
|
|
@@ -174,8 +169,8 @@ import torch
|
|
|
174
169
|
from visual_rag import VisualEmbedder
|
|
175
170
|
from visual_rag.indexing import ProcessingPipeline, QdrantIndexer
|
|
176
171
|
|
|
177
|
-
QDRANT_URL = os.environ["
|
|
178
|
-
QDRANT_KEY = os.getenv("
|
|
172
|
+
QDRANT_URL = os.environ["QDRANT_URL"]
|
|
173
|
+
QDRANT_KEY = os.getenv("QDRANT_API_KEY", "")
|
|
179
174
|
|
|
180
175
|
collection = "my_visual_docs"
|
|
181
176
|
|
|
@@ -193,6 +188,8 @@ indexer = QdrantIndexer(
|
|
|
193
188
|
prefer_grpc=True,
|
|
194
189
|
vector_datatype="float16",
|
|
195
190
|
)
|
|
191
|
+
|
|
192
|
+
# Creates collection + required payload indexes (e.g., "filename" for skip_existing)
|
|
196
193
|
indexer.create_collection(force_recreate=False)
|
|
197
194
|
|
|
198
195
|
pipeline = ProcessingPipeline(
|
|
@@ -208,19 +205,32 @@ pipeline = ProcessingPipeline(
|
|
|
208
205
|
|
|
209
206
|
pdfs = [Path("docs/a.pdf"), Path("docs/b.pdf")]
|
|
210
207
|
for pdf_path in pdfs:
|
|
211
|
-
pipeline.process_pdf(
|
|
208
|
+
result = pipeline.process_pdf(
|
|
212
209
|
pdf_path,
|
|
213
|
-
skip_existing=True,
|
|
210
|
+
skip_existing=True, # Skip pages already in Qdrant (uses filename index)
|
|
214
211
|
upload_to_cloudinary=False,
|
|
215
212
|
upload_to_qdrant=True,
|
|
216
213
|
)
|
|
214
|
+
# Logs automatically shown:
|
|
215
|
+
# [10:23:45] 📚 Processing PDF: a.pdf
|
|
216
|
+
# [10:23:45] 🖼️ Converting PDF to images...
|
|
217
|
+
# [10:23:46] ✅ Converted 12 pages
|
|
218
|
+
# [10:23:46] 📦 Processing pages 1-8/12
|
|
219
|
+
# [10:23:46] 🤖 Generating embeddings for 8 pages...
|
|
220
|
+
# [10:23:48] 📤 Uploading batch of 8 pages...
|
|
221
|
+
# [10:23:48] ✅ Uploaded 8 points to Qdrant
|
|
222
|
+
# [10:23:48] 📦 Processing pages 9-12/12
|
|
223
|
+
# [10:23:48] 🤖 Generating embeddings for 4 pages...
|
|
224
|
+
# [10:23:50] 📤 Uploading batch of 4 pages...
|
|
225
|
+
# [10:23:50] ✅ Uploaded 4 points to Qdrant
|
|
226
|
+
# [10:23:50] ✅ Completed a.pdf: 12 uploaded, 0 skipped, 0 failed
|
|
217
227
|
```
|
|
218
228
|
|
|
219
229
|
CLI equivalent:
|
|
220
230
|
|
|
221
231
|
```bash
|
|
222
|
-
export
|
|
223
|
-
export
|
|
232
|
+
export QDRANT_URL="https://YOUR_QDRANT"
|
|
233
|
+
export QDRANT_API_KEY="YOUR_KEY"
|
|
224
234
|
|
|
225
235
|
visual-rag process \
|
|
226
236
|
--reports-dir ./docs \
|
|
@@ -263,7 +273,7 @@ Stage 2: Exact MaxSim reranking on candidates
|
|
|
263
273
|
└── Return top-k results (e.g., 10)
|
|
264
274
|
```
|
|
265
275
|
|
|
266
|
-
Three-stage extends this with an additional
|
|
276
|
+
Three-stage extends this with an additional "cheap prefetch" stage before stage 2.
|
|
267
277
|
|
|
268
278
|
## 📁 Package Structure
|
|
269
279
|
|
|
@@ -374,4 +384,3 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|
|
374
384
|
- [Qdrant](https://qdrant.tech/) - Vector database with multi-vector support
|
|
375
385
|
- [ColPali](https://github.com/illuin-tech/colpali) - Visual document retrieval models
|
|
376
386
|
- [ViDoRe](https://huggingface.co/spaces/vidore/vidore-leaderboard) - Benchmark dataset
|
|
377
|
-
|
|
@@ -3,11 +3,6 @@
|
|
|
3
3
|
[](https://pypi.org/project/visual-rag-toolkit/)
|
|
4
4
|
[](https://pypi.org/project/visual-rag-toolkit/)
|
|
5
5
|
[](LICENSE)
|
|
6
|
-
[](https://github.com/Ara-Yeroyan/visual-rag-toolkit/actions/workflows/ci.yaml)
|
|
7
|
-
|
|
8
|
-
Note:
|
|
9
|
-
- The **PyPI badge** shows “not found” until the first release is published.
|
|
10
|
-
- The **CI badge** requires the GitHub repo to be **public** (GitHub does not serve Actions badges for private repos).
|
|
11
6
|
|
|
12
7
|
End-to-end visual document retrieval toolkit featuring **fast multi-stage retrieval** (prefetch with pooled vectors + exact MaxSim reranking).
|
|
13
8
|
|
|
@@ -77,7 +72,7 @@ for r in results[:3]:
|
|
|
77
72
|
|
|
78
73
|
### End-to-end: ingest PDFs (with cropping) → index in Qdrant
|
|
79
74
|
|
|
80
|
-
This is the
|
|
75
|
+
This is the "SDK-style" pipeline: PDF → images → optional crop → embed → store vectors + payload in Qdrant.
|
|
81
76
|
|
|
82
77
|
```python
|
|
83
78
|
import os
|
|
@@ -89,8 +84,8 @@ import torch
|
|
|
89
84
|
from visual_rag import VisualEmbedder
|
|
90
85
|
from visual_rag.indexing import ProcessingPipeline, QdrantIndexer
|
|
91
86
|
|
|
92
|
-
QDRANT_URL = os.environ["
|
|
93
|
-
QDRANT_KEY = os.getenv("
|
|
87
|
+
QDRANT_URL = os.environ["QDRANT_URL"]
|
|
88
|
+
QDRANT_KEY = os.getenv("QDRANT_API_KEY", "")
|
|
94
89
|
|
|
95
90
|
collection = "my_visual_docs"
|
|
96
91
|
|
|
@@ -108,6 +103,8 @@ indexer = QdrantIndexer(
|
|
|
108
103
|
prefer_grpc=True,
|
|
109
104
|
vector_datatype="float16",
|
|
110
105
|
)
|
|
106
|
+
|
|
107
|
+
# Creates collection + required payload indexes (e.g., "filename" for skip_existing)
|
|
111
108
|
indexer.create_collection(force_recreate=False)
|
|
112
109
|
|
|
113
110
|
pipeline = ProcessingPipeline(
|
|
@@ -123,19 +120,32 @@ pipeline = ProcessingPipeline(
|
|
|
123
120
|
|
|
124
121
|
pdfs = [Path("docs/a.pdf"), Path("docs/b.pdf")]
|
|
125
122
|
for pdf_path in pdfs:
|
|
126
|
-
pipeline.process_pdf(
|
|
123
|
+
result = pipeline.process_pdf(
|
|
127
124
|
pdf_path,
|
|
128
|
-
skip_existing=True,
|
|
125
|
+
skip_existing=True, # Skip pages already in Qdrant (uses filename index)
|
|
129
126
|
upload_to_cloudinary=False,
|
|
130
127
|
upload_to_qdrant=True,
|
|
131
128
|
)
|
|
129
|
+
# Logs automatically shown:
|
|
130
|
+
# [10:23:45] 📚 Processing PDF: a.pdf
|
|
131
|
+
# [10:23:45] 🖼️ Converting PDF to images...
|
|
132
|
+
# [10:23:46] ✅ Converted 12 pages
|
|
133
|
+
# [10:23:46] 📦 Processing pages 1-8/12
|
|
134
|
+
# [10:23:46] 🤖 Generating embeddings for 8 pages...
|
|
135
|
+
# [10:23:48] 📤 Uploading batch of 8 pages...
|
|
136
|
+
# [10:23:48] ✅ Uploaded 8 points to Qdrant
|
|
137
|
+
# [10:23:48] 📦 Processing pages 9-12/12
|
|
138
|
+
# [10:23:48] 🤖 Generating embeddings for 4 pages...
|
|
139
|
+
# [10:23:50] 📤 Uploading batch of 4 pages...
|
|
140
|
+
# [10:23:50] ✅ Uploaded 4 points to Qdrant
|
|
141
|
+
# [10:23:50] ✅ Completed a.pdf: 12 uploaded, 0 skipped, 0 failed
|
|
132
142
|
```
|
|
133
143
|
|
|
134
144
|
CLI equivalent:
|
|
135
145
|
|
|
136
146
|
```bash
|
|
137
|
-
export
|
|
138
|
-
export
|
|
147
|
+
export QDRANT_URL="https://YOUR_QDRANT"
|
|
148
|
+
export QDRANT_API_KEY="YOUR_KEY"
|
|
139
149
|
|
|
140
150
|
visual-rag process \
|
|
141
151
|
--reports-dir ./docs \
|
|
@@ -178,7 +188,7 @@ Stage 2: Exact MaxSim reranking on candidates
|
|
|
178
188
|
└── Return top-k results (e.g., 10)
|
|
179
189
|
```
|
|
180
190
|
|
|
181
|
-
Three-stage extends this with an additional
|
|
191
|
+
Three-stage extends this with an additional "cheap prefetch" stage before stage 2.
|
|
182
192
|
|
|
183
193
|
## 📁 Package Structure
|
|
184
194
|
|
|
@@ -289,4 +299,3 @@ MIT License - see [LICENSE](LICENSE) for details.
|
|
|
289
299
|
- [Qdrant](https://qdrant.tech/) - Vector database with multi-vector support
|
|
290
300
|
- [ColPali](https://github.com/illuin-tech/colpali) - Visual document retrieval models
|
|
291
301
|
- [ViDoRe](https://huggingface.co/spaces/vidore/vidore-leaderboard) - Benchmark dataset
|
|
292
|
-
|
|
@@ -1,13 +1,23 @@
|
|
|
1
1
|
"""Main entry point for the Visual RAG Toolkit demo application."""
|
|
2
2
|
|
|
3
|
+
import os
|
|
3
4
|
import sys
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
|
|
6
|
-
|
|
7
|
-
|
|
7
|
+
# Ensure repo root is in sys.path for local development
|
|
8
|
+
# (In HF Space / Docker, PYTHONPATH is already set correctly)
|
|
9
|
+
_app_dir = Path(__file__).resolve().parent
|
|
10
|
+
_repo_root = _app_dir.parent
|
|
11
|
+
if str(_repo_root) not in sys.path:
|
|
12
|
+
sys.path.insert(0, str(_repo_root))
|
|
8
13
|
|
|
9
14
|
from dotenv import load_dotenv
|
|
10
|
-
|
|
15
|
+
|
|
16
|
+
# Load .env from the repo root (works both locally and in Docker)
|
|
17
|
+
if (_repo_root / ".env").exists():
|
|
18
|
+
load_dotenv(_repo_root / ".env")
|
|
19
|
+
if (_app_dir / ".env").exists():
|
|
20
|
+
load_dotenv(_app_dir / ".env")
|
|
11
21
|
|
|
12
22
|
import streamlit as st
|
|
13
23
|
|
|
@@ -28,15 +38,17 @@ from demo.ui.benchmark import render_benchmark_tab
|
|
|
28
38
|
def main():
|
|
29
39
|
render_header()
|
|
30
40
|
render_sidebar()
|
|
31
|
-
|
|
32
|
-
tab_upload, tab_playground, tab_benchmark = st.tabs(
|
|
33
|
-
|
|
41
|
+
|
|
42
|
+
tab_upload, tab_playground, tab_benchmark = st.tabs(
|
|
43
|
+
["📤 Upload", "🎮 Playground", "📊 Benchmarking"]
|
|
44
|
+
)
|
|
45
|
+
|
|
34
46
|
with tab_upload:
|
|
35
47
|
render_upload_tab()
|
|
36
|
-
|
|
48
|
+
|
|
37
49
|
with tab_playground:
|
|
38
50
|
render_playground_tab()
|
|
39
|
-
|
|
51
|
+
|
|
40
52
|
with tab_benchmark:
|
|
41
53
|
render_benchmark_tab()
|
|
42
54
|
|
|
@@ -1,20 +1,23 @@
|
|
|
1
1
|
"""Evaluation runner with UI updates."""
|
|
2
2
|
|
|
3
3
|
import hashlib
|
|
4
|
-
import importlib.util
|
|
5
4
|
import json
|
|
6
5
|
import logging
|
|
7
6
|
import time
|
|
8
7
|
import traceback
|
|
9
8
|
from datetime import datetime
|
|
10
|
-
from pathlib import Path
|
|
11
9
|
from typing import Any, Dict, List, Optional
|
|
12
10
|
|
|
13
11
|
import numpy as np
|
|
14
12
|
import streamlit as st
|
|
15
13
|
import torch
|
|
14
|
+
from qdrant_client.models import FieldCondition, Filter, MatchValue
|
|
16
15
|
|
|
17
16
|
from visual_rag import VisualEmbedder
|
|
17
|
+
from visual_rag.retrieval import MultiVectorRetriever
|
|
18
|
+
from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
|
|
19
|
+
from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
|
|
20
|
+
from demo.qdrant_utils import get_qdrant_credentials
|
|
18
21
|
|
|
19
22
|
|
|
20
23
|
TORCH_DTYPE_MAP = {
|
|
@@ -22,49 +25,6 @@ TORCH_DTYPE_MAP = {
|
|
|
22
25
|
"float32": torch.float32,
|
|
23
26
|
"bfloat16": torch.bfloat16,
|
|
24
27
|
}
|
|
25
|
-
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
|
26
|
-
|
|
27
|
-
from visual_rag.retrieval import MultiVectorRetriever
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def _load_local_benchmark_module(module_filename: str):
|
|
31
|
-
"""
|
|
32
|
-
Load `benchmarks/vidore_tatdqa_test/<module_filename>` via file path.
|
|
33
|
-
|
|
34
|
-
Motivation:
|
|
35
|
-
- Some environments (notably containers / Spaces) can have a third-party
|
|
36
|
-
`benchmarks` package installed, causing `import benchmarks...` to resolve
|
|
37
|
-
to the wrong module.
|
|
38
|
-
- This fallback guarantees we load the repo's benchmark utilities.
|
|
39
|
-
"""
|
|
40
|
-
root = Path(__file__).resolve().parents[1] # demo/.. = repo root
|
|
41
|
-
target = root / "benchmarks" / "vidore_tatdqa_test" / module_filename
|
|
42
|
-
if not target.exists():
|
|
43
|
-
raise ModuleNotFoundError(f"Missing local benchmark module file: {target}")
|
|
44
|
-
|
|
45
|
-
name = f"_visual_rag_toolkit_local_{target.stem}"
|
|
46
|
-
spec = importlib.util.spec_from_file_location(name, str(target))
|
|
47
|
-
if spec is None or spec.loader is None:
|
|
48
|
-
raise ModuleNotFoundError(f"Could not load module spec for: {target}")
|
|
49
|
-
mod = importlib.util.module_from_spec(spec)
|
|
50
|
-
spec.loader.exec_module(mod) # type: ignore[attr-defined]
|
|
51
|
-
return mod
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
try:
|
|
55
|
-
# Preferred: normal import
|
|
56
|
-
from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
|
|
57
|
-
from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
|
|
58
|
-
except ModuleNotFoundError:
|
|
59
|
-
# Robust fallback: load from local file paths
|
|
60
|
-
_dl = _load_local_benchmark_module("dataset_loader.py")
|
|
61
|
-
_mx = _load_local_benchmark_module("metrics.py")
|
|
62
|
-
load_vidore_beir_dataset = _dl.load_vidore_beir_dataset
|
|
63
|
-
ndcg_at_k = _mx.ndcg_at_k
|
|
64
|
-
mrr_at_k = _mx.mrr_at_k
|
|
65
|
-
recall_at_k = _mx.recall_at_k
|
|
66
|
-
|
|
67
|
-
from demo.qdrant_utils import get_qdrant_credentials
|
|
68
28
|
|
|
69
29
|
logger = logging.getLogger(__name__)
|
|
70
30
|
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
|
|
@@ -0,0 +1,274 @@
|
|
|
1
|
+
"""Indexing runner with UI updates."""
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
import time
|
|
6
|
+
import traceback
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from typing import Any, Dict, Optional
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import streamlit as st
|
|
12
|
+
import torch
|
|
13
|
+
|
|
14
|
+
from visual_rag import VisualEmbedder
|
|
15
|
+
from visual_rag.embedding.pooling import tile_level_mean_pooling
|
|
16
|
+
from visual_rag.indexing.qdrant_indexer import QdrantIndexer
|
|
17
|
+
from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
|
|
18
|
+
from demo.qdrant_utils import get_qdrant_credentials
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
TORCH_DTYPE_MAP = {
|
|
22
|
+
"float16": torch.float16,
|
|
23
|
+
"float32": torch.float32,
|
|
24
|
+
"bfloat16": torch.bfloat16,
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _stable_uuid(text: str) -> str:
|
|
29
|
+
"""Generate a stable UUID from text (same as benchmark script)."""
|
|
30
|
+
hex_str = hashlib.sha256(text.encode("utf-8")).hexdigest()[:32]
|
|
31
|
+
return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _union_point_id(
|
|
35
|
+
*, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]
|
|
36
|
+
) -> str:
|
|
37
|
+
"""Generate union point ID (same as benchmark script)."""
|
|
38
|
+
ns = f"{union_namespace}::{dataset_name}" if union_namespace else dataset_name
|
|
39
|
+
return _stable_uuid(f"{ns}::{source_doc_id}")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def run_indexing_with_ui(config: Dict[str, Any]):
|
|
43
|
+
st.divider()
|
|
44
|
+
|
|
45
|
+
print("=" * 60)
|
|
46
|
+
print("[INDEX] Starting indexing via UI")
|
|
47
|
+
print("=" * 60)
|
|
48
|
+
|
|
49
|
+
url, api_key = get_qdrant_credentials()
|
|
50
|
+
if not url:
|
|
51
|
+
st.error("QDRANT_URL not configured")
|
|
52
|
+
return
|
|
53
|
+
|
|
54
|
+
datasets = config.get("datasets", [])
|
|
55
|
+
collection = config["collection"]
|
|
56
|
+
model = config.get("model", "vidore/colpali-v1.3")
|
|
57
|
+
recreate = config.get("recreate", False)
|
|
58
|
+
torch_dtype = config.get("torch_dtype", "float16")
|
|
59
|
+
qdrant_vector_dtype = config.get("qdrant_vector_dtype", "float16")
|
|
60
|
+
prefer_grpc = config.get("prefer_grpc", True)
|
|
61
|
+
batch_size = config.get("batch_size", 4)
|
|
62
|
+
max_docs = config.get("max_docs")
|
|
63
|
+
|
|
64
|
+
print(f"[INDEX] Config: collection={collection}, model={model}")
|
|
65
|
+
print(f"[INDEX] Datasets: {datasets}")
|
|
66
|
+
print(
|
|
67
|
+
f"[INDEX] max_docs={max_docs}, batch_size={batch_size}, recreate={recreate}"
|
|
68
|
+
)
|
|
69
|
+
print(
|
|
70
|
+
f"[INDEX] torch_dtype={torch_dtype}, qdrant_dtype={qdrant_vector_dtype}, grpc={prefer_grpc}"
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
phase1_container = st.container()
|
|
74
|
+
phase2_container = st.container()
|
|
75
|
+
phase3_container = st.container()
|
|
76
|
+
results_container = st.container()
|
|
77
|
+
|
|
78
|
+
try:
|
|
79
|
+
with phase1_container:
|
|
80
|
+
st.markdown("##### 🤖 Phase 1: Loading Model")
|
|
81
|
+
model_status = st.empty()
|
|
82
|
+
model_status.info(f"Loading `{model.split('/')[-1]}`...")
|
|
83
|
+
|
|
84
|
+
print(f"[INDEX] Loading embedder: {model}")
|
|
85
|
+
torch_dtype_obj = TORCH_DTYPE_MAP.get(torch_dtype, torch.float16)
|
|
86
|
+
output_dtype_obj = (
|
|
87
|
+
np.float16 if qdrant_vector_dtype == "float16" else np.float32
|
|
88
|
+
)
|
|
89
|
+
embedder = VisualEmbedder(
|
|
90
|
+
model_name=model,
|
|
91
|
+
torch_dtype=torch_dtype_obj,
|
|
92
|
+
output_dtype=output_dtype_obj,
|
|
93
|
+
)
|
|
94
|
+
embedder._load_model()
|
|
95
|
+
print(
|
|
96
|
+
f"[INDEX] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_vector_dtype})"
|
|
97
|
+
)
|
|
98
|
+
model_status.success(f"✅ Model `{model.split('/')[-1]}` loaded")
|
|
99
|
+
|
|
100
|
+
with phase2_container:
|
|
101
|
+
st.markdown("##### 📦 Phase 2: Setting Up Collection")
|
|
102
|
+
|
|
103
|
+
indexer_status = st.empty()
|
|
104
|
+
indexer_status.info("Connecting to Qdrant...")
|
|
105
|
+
|
|
106
|
+
print("[INDEX] Connecting to Qdrant...")
|
|
107
|
+
indexer = QdrantIndexer(
|
|
108
|
+
url=url,
|
|
109
|
+
api_key=api_key,
|
|
110
|
+
collection_name=collection,
|
|
111
|
+
prefer_grpc=prefer_grpc,
|
|
112
|
+
vector_datatype=qdrant_vector_dtype,
|
|
113
|
+
)
|
|
114
|
+
print("[INDEX] Connected to Qdrant")
|
|
115
|
+
indexer_status.success("✅ Connected to Qdrant")
|
|
116
|
+
|
|
117
|
+
coll_status = st.empty()
|
|
118
|
+
action = "Recreating" if recreate else "Creating/verifying"
|
|
119
|
+
coll_status.info(f"{action} collection `{collection}`...")
|
|
120
|
+
|
|
121
|
+
print(f"[INDEX] {action} collection: {collection}")
|
|
122
|
+
indexer.create_collection(force_recreate=recreate)
|
|
123
|
+
indexer.create_payload_indexes(
|
|
124
|
+
fields=[
|
|
125
|
+
{"field": "dataset", "type": "keyword"},
|
|
126
|
+
{"field": "doc_id", "type": "keyword"},
|
|
127
|
+
{"field": "source_doc_id", "type": "keyword"},
|
|
128
|
+
]
|
|
129
|
+
)
|
|
130
|
+
print("[INDEX] Collection ready")
|
|
131
|
+
coll_status.success(f"✅ Collection `{collection}` ready")
|
|
132
|
+
|
|
133
|
+
with phase3_container:
|
|
134
|
+
st.markdown("##### 📊 Phase 3: Processing Datasets")
|
|
135
|
+
|
|
136
|
+
all_results = []
|
|
137
|
+
|
|
138
|
+
for ds_idx, dataset_name in enumerate(datasets):
|
|
139
|
+
ds_short = dataset_name.split("/")[-1]
|
|
140
|
+
ds_container = st.container()
|
|
141
|
+
|
|
142
|
+
with ds_container:
|
|
143
|
+
st.markdown(
|
|
144
|
+
f"**Dataset {ds_idx + 1}/{len(datasets)}: `{ds_short}`**"
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
load_status = st.empty()
|
|
148
|
+
load_status.info(f"Loading dataset `{ds_short}`...")
|
|
149
|
+
|
|
150
|
+
print(f"[INDEX] Loading dataset: {dataset_name}")
|
|
151
|
+
corpus, queries, qrels = load_vidore_beir_dataset(dataset_name)
|
|
152
|
+
total_docs = len(corpus)
|
|
153
|
+
print(f"[INDEX] Dataset loaded: {total_docs} docs")
|
|
154
|
+
load_status.success(f"✅ Loaded {total_docs:,} documents")
|
|
155
|
+
|
|
156
|
+
if max_docs and max_docs < total_docs:
|
|
157
|
+
corpus = corpus[:max_docs]
|
|
158
|
+
print(f"[INDEX] Limiting to {max_docs} docs")
|
|
159
|
+
|
|
160
|
+
progress_bar = st.progress(0)
|
|
161
|
+
status_text = st.empty()
|
|
162
|
+
|
|
163
|
+
uploaded = 0
|
|
164
|
+
failed = 0
|
|
165
|
+
total = len(corpus)
|
|
166
|
+
|
|
167
|
+
for i, doc in enumerate(corpus):
|
|
168
|
+
try:
|
|
169
|
+
doc_id = str(doc.doc_id)
|
|
170
|
+
image = doc.image
|
|
171
|
+
if image is None:
|
|
172
|
+
failed += 1
|
|
173
|
+
continue
|
|
174
|
+
|
|
175
|
+
status_text.text(
|
|
176
|
+
f"Processing {i + 1}/{total}: {doc_id[:30]}..."
|
|
177
|
+
)
|
|
178
|
+
|
|
179
|
+
embeddings, token_infos = embedder.embed_images(
|
|
180
|
+
[image],
|
|
181
|
+
return_token_info=True,
|
|
182
|
+
show_progress=False,
|
|
183
|
+
)
|
|
184
|
+
emb = embeddings[0]
|
|
185
|
+
token_info = token_infos[0] if token_infos else {}
|
|
186
|
+
|
|
187
|
+
if hasattr(emb, "cpu"):
|
|
188
|
+
emb = emb.cpu()
|
|
189
|
+
emb_np = np.asarray(emb, dtype=output_dtype_obj)
|
|
190
|
+
|
|
191
|
+
initial = emb_np.tolist()
|
|
192
|
+
global_pool = emb_np.mean(axis=0).tolist()
|
|
193
|
+
|
|
194
|
+
num_tiles = token_info.get("num_tiles")
|
|
195
|
+
mean_pooling = None
|
|
196
|
+
experimental_pooling = None
|
|
197
|
+
|
|
198
|
+
if num_tiles and num_tiles > 0:
|
|
199
|
+
try:
|
|
200
|
+
mean_pooling = tile_level_mean_pooling(
|
|
201
|
+
emb_np, num_tiles=num_tiles, patches_per_tile=64
|
|
202
|
+
).tolist()
|
|
203
|
+
except Exception:
|
|
204
|
+
pass
|
|
205
|
+
|
|
206
|
+
try:
|
|
207
|
+
exp_pool = embedder.experimental_pool_visual_embedding(
|
|
208
|
+
emb_np, num_tiles=num_tiles
|
|
209
|
+
)
|
|
210
|
+
if exp_pool is not None:
|
|
211
|
+
experimental_pooling = exp_pool.tolist()
|
|
212
|
+
except Exception:
|
|
213
|
+
pass
|
|
214
|
+
|
|
215
|
+
union_doc_id = _union_point_id(
|
|
216
|
+
dataset_name=dataset_name,
|
|
217
|
+
source_doc_id=doc_id,
|
|
218
|
+
union_namespace=collection,
|
|
219
|
+
)
|
|
220
|
+
|
|
221
|
+
payload = {
|
|
222
|
+
"dataset": dataset_name,
|
|
223
|
+
"doc_id": doc_id,
|
|
224
|
+
"source_doc_id": doc_id,
|
|
225
|
+
"union_doc_id": union_doc_id,
|
|
226
|
+
"num_tiles": num_tiles,
|
|
227
|
+
"num_visual_tokens": token_info.get("num_visual_tokens"),
|
|
228
|
+
}
|
|
229
|
+
|
|
230
|
+
vectors = {"initial": initial, "global_pooling": global_pool}
|
|
231
|
+
if mean_pooling:
|
|
232
|
+
vectors["mean_pooling"] = mean_pooling
|
|
233
|
+
if experimental_pooling:
|
|
234
|
+
vectors["experimental_pooling"] = experimental_pooling
|
|
235
|
+
|
|
236
|
+
indexer.upsert_point(
|
|
237
|
+
point_id=union_doc_id,
|
|
238
|
+
vectors=vectors,
|
|
239
|
+
payload=payload,
|
|
240
|
+
)
|
|
241
|
+
|
|
242
|
+
uploaded += 1
|
|
243
|
+
|
|
244
|
+
except Exception as e:
|
|
245
|
+
print(f"[INDEX] Error on doc {i}: {e}")
|
|
246
|
+
failed += 1
|
|
247
|
+
|
|
248
|
+
progress_bar.progress((i + 1) / total)
|
|
249
|
+
|
|
250
|
+
status_text.text(f"✅ Done: {uploaded} uploaded, {failed} failed")
|
|
251
|
+
all_results.append(
|
|
252
|
+
{
|
|
253
|
+
"dataset": dataset_name,
|
|
254
|
+
"total": total,
|
|
255
|
+
"uploaded": uploaded,
|
|
256
|
+
"failed": failed,
|
|
257
|
+
}
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
with results_container:
|
|
261
|
+
st.markdown("##### 📋 Results Summary")
|
|
262
|
+
|
|
263
|
+
for r in all_results:
|
|
264
|
+
st.write(
|
|
265
|
+
f"**{r['dataset'].split('/')[-1]}**: {r['uploaded']:,} uploaded, {r['failed']:,} failed"
|
|
266
|
+
)
|
|
267
|
+
|
|
268
|
+
st.success("✅ Indexing complete!")
|
|
269
|
+
|
|
270
|
+
except Exception as e:
|
|
271
|
+
st.error(f"Indexing error: {e}")
|
|
272
|
+
st.code(traceback.format_exc())
|
|
273
|
+
print(f"[INDEX] ERROR: {e}")
|
|
274
|
+
traceback.print_exc()
|
|
@@ -8,12 +8,19 @@ import streamlit as st
|
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
def get_qdrant_credentials() -> Tuple[Optional[str], Optional[str]]:
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
11
|
+
"""Get Qdrant credentials from session state or environment variables.
|
|
12
|
+
|
|
13
|
+
Priority: session_state > QDRANT_URL/QDRANT_API_KEY > legacy env vars
|
|
14
|
+
"""
|
|
15
|
+
url = (
|
|
16
|
+
st.session_state.get("qdrant_url_input")
|
|
17
|
+
or os.getenv("QDRANT_URL")
|
|
18
|
+
or os.getenv("SIGIR_QDRANT_URL") # legacy
|
|
19
|
+
)
|
|
20
|
+
api_key = (
|
|
21
|
+
st.session_state.get("qdrant_key_input")
|
|
16
22
|
or os.getenv("QDRANT_API_KEY")
|
|
23
|
+
or os.getenv("SIGIR_QDRANT_KEY") # legacy
|
|
17
24
|
)
|
|
18
25
|
return url, api_key
|
|
19
26
|
|
|
@@ -9,6 +9,7 @@ from demo.qdrant_utils import (
|
|
|
9
9
|
sample_points_cached,
|
|
10
10
|
search_collection,
|
|
11
11
|
)
|
|
12
|
+
from visual_rag.retrieval import MultiVectorRetriever
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
def render_playground_tab():
|
|
@@ -46,7 +47,6 @@ def render_playground_tab():
|
|
|
46
47
|
if not st.session_state.get("model_loaded"):
|
|
47
48
|
with st.spinner(f"Loading {model_short}..."):
|
|
48
49
|
try:
|
|
49
|
-
from visual_rag.retrieval import MultiVectorRetriever
|
|
50
50
|
_ = MultiVectorRetriever(collection_name=active_collection, model_name=model_name)
|
|
51
51
|
st.session_state["model_loaded"] = True
|
|
52
52
|
st.session_state["loaded_model_key"] = cache_key
|
|
@@ -3,6 +3,8 @@
|
|
|
3
3
|
import os
|
|
4
4
|
import streamlit as st
|
|
5
5
|
|
|
6
|
+
from qdrant_client.models import VectorParamsDiff
|
|
7
|
+
|
|
6
8
|
from demo.qdrant_utils import (
|
|
7
9
|
get_qdrant_credentials,
|
|
8
10
|
init_qdrant_client_with_creds,
|
|
@@ -17,8 +19,8 @@ def render_sidebar():
|
|
|
17
19
|
with st.sidebar:
|
|
18
20
|
st.subheader("🔑 Qdrant Credentials")
|
|
19
21
|
|
|
20
|
-
env_url = os.getenv("
|
|
21
|
-
env_key = os.getenv("
|
|
22
|
+
env_url = os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL") or ""
|
|
23
|
+
env_key = os.getenv("QDRANT_API_KEY") or os.getenv("SIGIR_QDRANT_KEY") or ""
|
|
22
24
|
|
|
23
25
|
if "qdrant_url_input" not in st.session_state:
|
|
24
26
|
st.session_state["qdrant_url_input"] = env_url
|
|
@@ -136,7 +138,6 @@ def render_sidebar():
|
|
|
136
138
|
if target_in_ram != current_in_ram:
|
|
137
139
|
if st.button("💾 Apply Change", key="admin_apply"):
|
|
138
140
|
try:
|
|
139
|
-
from qdrant_client.models import VectorParamsDiff
|
|
140
141
|
client.update_collection(
|
|
141
142
|
collection_name=active,
|
|
142
143
|
vectors_config={sel_vec: VectorParamsDiff(on_disk=not target_in_ram)}
|