visual-rag-toolkit 0.1.2__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/PKG-INFO +24 -15
  2. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/README.md +23 -14
  3. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/app.py +20 -8
  4. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/evaluation.py +5 -45
  5. visual_rag_toolkit-0.1.3/demo/indexing.py +274 -0
  6. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/qdrant_utils.py +12 -5
  7. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/ui/playground.py +1 -1
  8. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/ui/sidebar.py +4 -3
  9. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/ui/upload.py +5 -4
  10. visual_rag_toolkit-0.1.3/examples/COMMANDS.md +83 -0
  11. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/examples/config.yaml +6 -0
  12. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/examples/process_pdfs.py +6 -0
  13. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/examples/search_demo.py +6 -0
  14. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/pyproject.toml +1 -1
  15. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/__init__.py +43 -1
  16. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/config.py +4 -7
  17. visual_rag_toolkit-0.1.3/visual_rag/indexing/__init__.py +38 -0
  18. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/indexing/qdrant_indexer.py +92 -42
  19. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/retrieval/multi_vector.py +63 -65
  20. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/retrieval/single_stage.py +7 -0
  21. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/retrieval/two_stage.py +8 -10
  22. visual_rag_toolkit-0.1.2/demo/indexing.py +0 -315
  23. visual_rag_toolkit-0.1.2/visual_rag/indexing/__init__.py +0 -21
  24. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/.github/workflows/ci.yaml +0 -0
  25. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/.github/workflows/publish_pypi.yaml +0 -0
  26. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/.gitignore +0 -0
  27. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/LICENSE +0 -0
  28. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/README.md +0 -0
  29. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/__init__.py +0 -0
  30. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/analyze_results.py +0 -0
  31. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/benchmark_datasets.txt +0 -0
  32. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/prepare_submission.py +0 -0
  33. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/quick_test.py +0 -0
  34. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/run_vidore.py +0 -0
  35. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +0 -0
  36. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/vidore_tatdqa_test/__init__.py +0 -0
  37. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/vidore_tatdqa_test/dataset_loader.py +0 -0
  38. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/vidore_tatdqa_test/metrics.py +0 -0
  39. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/vidore_tatdqa_test/run_qdrant.py +0 -0
  40. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/benchmarks/vidore_tatdqa_test/sweep_eval.py +0 -0
  41. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/__init__.py +0 -0
  42. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/commands.py +0 -0
  43. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/config.py +0 -0
  44. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/download_models.py +0 -0
  45. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/example_metadata_mapping_sigir.json +0 -0
  46. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/results.py +0 -0
  47. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/test_qdrant_connection.py +0 -0
  48. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/ui/__init__.py +0 -0
  49. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/ui/benchmark.py +0 -0
  50. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/demo/ui/header.py +0 -0
  51. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/requirements.txt +0 -0
  52. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/tests/__init__.py +0 -0
  53. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/tests/test_config.py +0 -0
  54. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/tests/test_pdf_processor.py +0 -0
  55. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/tests/test_pooling.py +0 -0
  56. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/tests/test_strategies.py +0 -0
  57. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/cli/__init__.py +0 -0
  58. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/cli/main.py +0 -0
  59. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/demo_runner.py +0 -0
  60. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/embedding/__init__.py +0 -0
  61. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/embedding/pooling.py +0 -0
  62. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/embedding/visual_embedder.py +0 -0
  63. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/indexing/cloudinary_uploader.py +0 -0
  64. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/indexing/pdf_processor.py +0 -0
  65. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/indexing/pipeline.py +0 -0
  66. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/preprocessing/__init__.py +0 -0
  67. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/preprocessing/crop_empty.py +0 -0
  68. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/qdrant_admin.py +0 -0
  69. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/retrieval/__init__.py +0 -0
  70. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/retrieval/three_stage.py +0 -0
  71. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/visualization/__init__.py +0 -0
  72. {visual_rag_toolkit-0.1.2 → visual_rag_toolkit-0.1.3}/visual_rag/visualization/saliency.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: visual-rag-toolkit
3
- Version: 0.1.2
3
+ Version: 0.1.3
4
4
  Summary: End-to-end visual document retrieval with ColPali, featuring two-stage pooling for scalable search
5
5
  Project-URL: Homepage, https://github.com/Ara-Yeroyan/visual-rag-toolkit
6
6
  Project-URL: Documentation, https://github.com/Ara-Yeroyan/visual-rag-toolkit#readme
@@ -88,11 +88,6 @@ Description-Content-Type: text/markdown
88
88
  [![PyPI](https://img.shields.io/pypi/v/visual-rag-toolkit)](https://pypi.org/project/visual-rag-toolkit/)
89
89
  [![Python](https://img.shields.io/pypi/pyversions/visual-rag-toolkit)](https://pypi.org/project/visual-rag-toolkit/)
90
90
  [![License](https://img.shields.io/pypi/l/visual-rag-toolkit)](LICENSE)
91
- [![CI](https://img.shields.io/github/actions/workflow/status/Ara-Yeroyan/visual-rag-toolkit/ci.yaml?branch=main)](https://github.com/Ara-Yeroyan/visual-rag-toolkit/actions/workflows/ci.yaml)
92
-
93
- Note:
94
- - The **PyPI badge** shows “not found” until the first release is published.
95
- - The **CI badge** requires the GitHub repo to be **public** (GitHub does not serve Actions badges for private repos).
96
91
 
97
92
  End-to-end visual document retrieval toolkit featuring **fast multi-stage retrieval** (prefetch with pooled vectors + exact MaxSim reranking).
98
93
 
@@ -162,7 +157,7 @@ for r in results[:3]:
162
157
 
163
158
  ### End-to-end: ingest PDFs (with cropping) → index in Qdrant
164
159
 
165
- This is the SDK-style pipeline: PDF → images → optional crop → embed → store vectors + payload in Qdrant.
160
+ This is the "SDK-style" pipeline: PDF → images → optional crop → embed → store vectors + payload in Qdrant.
166
161
 
167
162
  ```python
168
163
  import os
@@ -174,8 +169,8 @@ import torch
174
169
  from visual_rag import VisualEmbedder
175
170
  from visual_rag.indexing import ProcessingPipeline, QdrantIndexer
176
171
 
177
- QDRANT_URL = os.environ["SIGIR_QDRANT_URL"] # or QDRANT_URL
178
- QDRANT_KEY = os.getenv("SIGIR_QDRANT_KEY", "") # or QDRANT_API_KEY
172
+ QDRANT_URL = os.environ["QDRANT_URL"]
173
+ QDRANT_KEY = os.getenv("QDRANT_API_KEY", "")
179
174
 
180
175
  collection = "my_visual_docs"
181
176
 
@@ -193,6 +188,8 @@ indexer = QdrantIndexer(
193
188
  prefer_grpc=True,
194
189
  vector_datatype="float16",
195
190
  )
191
+
192
+ # Creates collection + required payload indexes (e.g., "filename" for skip_existing)
196
193
  indexer.create_collection(force_recreate=False)
197
194
 
198
195
  pipeline = ProcessingPipeline(
@@ -208,19 +205,32 @@ pipeline = ProcessingPipeline(
208
205
 
209
206
  pdfs = [Path("docs/a.pdf"), Path("docs/b.pdf")]
210
207
  for pdf_path in pdfs:
211
- pipeline.process_pdf(
208
+ result = pipeline.process_pdf(
212
209
  pdf_path,
213
- skip_existing=True,
210
+ skip_existing=True, # Skip pages already in Qdrant (uses filename index)
214
211
  upload_to_cloudinary=False,
215
212
  upload_to_qdrant=True,
216
213
  )
214
+ # Logs automatically shown:
215
+ # [10:23:45] 📚 Processing PDF: a.pdf
216
+ # [10:23:45] 🖼️ Converting PDF to images...
217
+ # [10:23:46] ✅ Converted 12 pages
218
+ # [10:23:46] 📦 Processing pages 1-8/12
219
+ # [10:23:46] 🤖 Generating embeddings for 8 pages...
220
+ # [10:23:48] 📤 Uploading batch of 8 pages...
221
+ # [10:23:48] ✅ Uploaded 8 points to Qdrant
222
+ # [10:23:48] 📦 Processing pages 9-12/12
223
+ # [10:23:48] 🤖 Generating embeddings for 4 pages...
224
+ # [10:23:50] 📤 Uploading batch of 4 pages...
225
+ # [10:23:50] ✅ Uploaded 4 points to Qdrant
226
+ # [10:23:50] ✅ Completed a.pdf: 12 uploaded, 0 skipped, 0 failed
217
227
  ```
218
228
 
219
229
  CLI equivalent:
220
230
 
221
231
  ```bash
222
- export SIGIR_QDRANT_URL="https://YOUR_QDRANT"
223
- export SIGIR_QDRANT_KEY="YOUR_KEY"
232
+ export QDRANT_URL="https://YOUR_QDRANT"
233
+ export QDRANT_API_KEY="YOUR_KEY"
224
234
 
225
235
  visual-rag process \
226
236
  --reports-dir ./docs \
@@ -263,7 +273,7 @@ Stage 2: Exact MaxSim reranking on candidates
263
273
  └── Return top-k results (e.g., 10)
264
274
  ```
265
275
 
266
- Three-stage extends this with an additional cheap prefetch stage before stage 2.
276
+ Three-stage extends this with an additional "cheap prefetch" stage before stage 2.
267
277
 
268
278
  ## 📁 Package Structure
269
279
 
@@ -374,4 +384,3 @@ MIT License - see [LICENSE](LICENSE) for details.
374
384
  - [Qdrant](https://qdrant.tech/) - Vector database with multi-vector support
375
385
  - [ColPali](https://github.com/illuin-tech/colpali) - Visual document retrieval models
376
386
  - [ViDoRe](https://huggingface.co/spaces/vidore/vidore-leaderboard) - Benchmark dataset
377
-
@@ -3,11 +3,6 @@
3
3
  [![PyPI](https://img.shields.io/pypi/v/visual-rag-toolkit)](https://pypi.org/project/visual-rag-toolkit/)
4
4
  [![Python](https://img.shields.io/pypi/pyversions/visual-rag-toolkit)](https://pypi.org/project/visual-rag-toolkit/)
5
5
  [![License](https://img.shields.io/pypi/l/visual-rag-toolkit)](LICENSE)
6
- [![CI](https://img.shields.io/github/actions/workflow/status/Ara-Yeroyan/visual-rag-toolkit/ci.yaml?branch=main)](https://github.com/Ara-Yeroyan/visual-rag-toolkit/actions/workflows/ci.yaml)
7
-
8
- Note:
9
- - The **PyPI badge** shows “not found” until the first release is published.
10
- - The **CI badge** requires the GitHub repo to be **public** (GitHub does not serve Actions badges for private repos).
11
6
 
12
7
  End-to-end visual document retrieval toolkit featuring **fast multi-stage retrieval** (prefetch with pooled vectors + exact MaxSim reranking).
13
8
 
@@ -77,7 +72,7 @@ for r in results[:3]:
77
72
 
78
73
  ### End-to-end: ingest PDFs (with cropping) → index in Qdrant
79
74
 
80
- This is the SDK-style pipeline: PDF → images → optional crop → embed → store vectors + payload in Qdrant.
75
+ This is the "SDK-style" pipeline: PDF → images → optional crop → embed → store vectors + payload in Qdrant.
81
76
 
82
77
  ```python
83
78
  import os
@@ -89,8 +84,8 @@ import torch
89
84
  from visual_rag import VisualEmbedder
90
85
  from visual_rag.indexing import ProcessingPipeline, QdrantIndexer
91
86
 
92
- QDRANT_URL = os.environ["SIGIR_QDRANT_URL"] # or QDRANT_URL
93
- QDRANT_KEY = os.getenv("SIGIR_QDRANT_KEY", "") # or QDRANT_API_KEY
87
+ QDRANT_URL = os.environ["QDRANT_URL"]
88
+ QDRANT_KEY = os.getenv("QDRANT_API_KEY", "")
94
89
 
95
90
  collection = "my_visual_docs"
96
91
 
@@ -108,6 +103,8 @@ indexer = QdrantIndexer(
108
103
  prefer_grpc=True,
109
104
  vector_datatype="float16",
110
105
  )
106
+
107
+ # Creates collection + required payload indexes (e.g., "filename" for skip_existing)
111
108
  indexer.create_collection(force_recreate=False)
112
109
 
113
110
  pipeline = ProcessingPipeline(
@@ -123,19 +120,32 @@ pipeline = ProcessingPipeline(
123
120
 
124
121
  pdfs = [Path("docs/a.pdf"), Path("docs/b.pdf")]
125
122
  for pdf_path in pdfs:
126
- pipeline.process_pdf(
123
+ result = pipeline.process_pdf(
127
124
  pdf_path,
128
- skip_existing=True,
125
+ skip_existing=True, # Skip pages already in Qdrant (uses filename index)
129
126
  upload_to_cloudinary=False,
130
127
  upload_to_qdrant=True,
131
128
  )
129
+ # Logs automatically shown:
130
+ # [10:23:45] 📚 Processing PDF: a.pdf
131
+ # [10:23:45] 🖼️ Converting PDF to images...
132
+ # [10:23:46] ✅ Converted 12 pages
133
+ # [10:23:46] 📦 Processing pages 1-8/12
134
+ # [10:23:46] 🤖 Generating embeddings for 8 pages...
135
+ # [10:23:48] 📤 Uploading batch of 8 pages...
136
+ # [10:23:48] ✅ Uploaded 8 points to Qdrant
137
+ # [10:23:48] 📦 Processing pages 9-12/12
138
+ # [10:23:48] 🤖 Generating embeddings for 4 pages...
139
+ # [10:23:50] 📤 Uploading batch of 4 pages...
140
+ # [10:23:50] ✅ Uploaded 4 points to Qdrant
141
+ # [10:23:50] ✅ Completed a.pdf: 12 uploaded, 0 skipped, 0 failed
132
142
  ```
133
143
 
134
144
  CLI equivalent:
135
145
 
136
146
  ```bash
137
- export SIGIR_QDRANT_URL="https://YOUR_QDRANT"
138
- export SIGIR_QDRANT_KEY="YOUR_KEY"
147
+ export QDRANT_URL="https://YOUR_QDRANT"
148
+ export QDRANT_API_KEY="YOUR_KEY"
139
149
 
140
150
  visual-rag process \
141
151
  --reports-dir ./docs \
@@ -178,7 +188,7 @@ Stage 2: Exact MaxSim reranking on candidates
178
188
  └── Return top-k results (e.g., 10)
179
189
  ```
180
190
 
181
- Three-stage extends this with an additional cheap prefetch stage before stage 2.
191
+ Three-stage extends this with an additional "cheap prefetch" stage before stage 2.
182
192
 
183
193
  ## 📁 Package Structure
184
194
 
@@ -289,4 +299,3 @@ MIT License - see [LICENSE](LICENSE) for details.
289
299
  - [Qdrant](https://qdrant.tech/) - Vector database with multi-vector support
290
300
  - [ColPali](https://github.com/illuin-tech/colpali) - Visual document retrieval models
291
301
  - [ViDoRe](https://huggingface.co/spaces/vidore/vidore-leaderboard) - Benchmark dataset
292
-
@@ -1,13 +1,23 @@
1
1
  """Main entry point for the Visual RAG Toolkit demo application."""
2
2
 
3
+ import os
3
4
  import sys
4
5
  from pathlib import Path
5
6
 
6
- ROOT_DIR = Path(__file__).parent.parent
7
- sys.path.insert(0, str(ROOT_DIR))
7
+ # Ensure repo root is in sys.path for local development
8
+ # (In HF Space / Docker, PYTHONPATH is already set correctly)
9
+ _app_dir = Path(__file__).resolve().parent
10
+ _repo_root = _app_dir.parent
11
+ if str(_repo_root) not in sys.path:
12
+ sys.path.insert(0, str(_repo_root))
8
13
 
9
14
  from dotenv import load_dotenv
10
- load_dotenv(ROOT_DIR / ".env")
15
+
16
+ # Load .env from the repo root (works both locally and in Docker)
17
+ if (_repo_root / ".env").exists():
18
+ load_dotenv(_repo_root / ".env")
19
+ if (_app_dir / ".env").exists():
20
+ load_dotenv(_app_dir / ".env")
11
21
 
12
22
  import streamlit as st
13
23
 
@@ -28,15 +38,17 @@ from demo.ui.benchmark import render_benchmark_tab
28
38
  def main():
29
39
  render_header()
30
40
  render_sidebar()
31
-
32
- tab_upload, tab_playground, tab_benchmark = st.tabs(["📤 Upload", "🎮 Playground", "📊 Benchmarking"])
33
-
41
+
42
+ tab_upload, tab_playground, tab_benchmark = st.tabs(
43
+ ["📤 Upload", "🎮 Playground", "📊 Benchmarking"]
44
+ )
45
+
34
46
  with tab_upload:
35
47
  render_upload_tab()
36
-
48
+
37
49
  with tab_playground:
38
50
  render_playground_tab()
39
-
51
+
40
52
  with tab_benchmark:
41
53
  render_benchmark_tab()
42
54
 
@@ -1,20 +1,23 @@
1
1
  """Evaluation runner with UI updates."""
2
2
 
3
3
  import hashlib
4
- import importlib.util
5
4
  import json
6
5
  import logging
7
6
  import time
8
7
  import traceback
9
8
  from datetime import datetime
10
- from pathlib import Path
11
9
  from typing import Any, Dict, List, Optional
12
10
 
13
11
  import numpy as np
14
12
  import streamlit as st
15
13
  import torch
14
+ from qdrant_client.models import FieldCondition, Filter, MatchValue
16
15
 
17
16
  from visual_rag import VisualEmbedder
17
+ from visual_rag.retrieval import MultiVectorRetriever
18
+ from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
19
+ from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
20
+ from demo.qdrant_utils import get_qdrant_credentials
18
21
 
19
22
 
20
23
  TORCH_DTYPE_MAP = {
@@ -22,49 +25,6 @@ TORCH_DTYPE_MAP = {
22
25
  "float32": torch.float32,
23
26
  "bfloat16": torch.bfloat16,
24
27
  }
25
- from qdrant_client.models import Filter, FieldCondition, MatchValue
26
-
27
- from visual_rag.retrieval import MultiVectorRetriever
28
-
29
-
30
- def _load_local_benchmark_module(module_filename: str):
31
- """
32
- Load `benchmarks/vidore_tatdqa_test/<module_filename>` via file path.
33
-
34
- Motivation:
35
- - Some environments (notably containers / Spaces) can have a third-party
36
- `benchmarks` package installed, causing `import benchmarks...` to resolve
37
- to the wrong module.
38
- - This fallback guarantees we load the repo's benchmark utilities.
39
- """
40
- root = Path(__file__).resolve().parents[1] # demo/.. = repo root
41
- target = root / "benchmarks" / "vidore_tatdqa_test" / module_filename
42
- if not target.exists():
43
- raise ModuleNotFoundError(f"Missing local benchmark module file: {target}")
44
-
45
- name = f"_visual_rag_toolkit_local_{target.stem}"
46
- spec = importlib.util.spec_from_file_location(name, str(target))
47
- if spec is None or spec.loader is None:
48
- raise ModuleNotFoundError(f"Could not load module spec for: {target}")
49
- mod = importlib.util.module_from_spec(spec)
50
- spec.loader.exec_module(mod) # type: ignore[attr-defined]
51
- return mod
52
-
53
-
54
- try:
55
- # Preferred: normal import
56
- from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
57
- from benchmarks.vidore_tatdqa_test.metrics import ndcg_at_k, mrr_at_k, recall_at_k
58
- except ModuleNotFoundError:
59
- # Robust fallback: load from local file paths
60
- _dl = _load_local_benchmark_module("dataset_loader.py")
61
- _mx = _load_local_benchmark_module("metrics.py")
62
- load_vidore_beir_dataset = _dl.load_vidore_beir_dataset
63
- ndcg_at_k = _mx.ndcg_at_k
64
- mrr_at_k = _mx.mrr_at_k
65
- recall_at_k = _mx.recall_at_k
66
-
67
- from demo.qdrant_utils import get_qdrant_credentials
68
28
 
69
29
  logger = logging.getLogger(__name__)
70
30
  logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
@@ -0,0 +1,274 @@
1
+ """Indexing runner with UI updates."""
2
+
3
+ import hashlib
4
+ import json
5
+ import time
6
+ import traceback
7
+ from datetime import datetime
8
+ from typing import Any, Dict, Optional
9
+
10
+ import numpy as np
11
+ import streamlit as st
12
+ import torch
13
+
14
+ from visual_rag import VisualEmbedder
15
+ from visual_rag.embedding.pooling import tile_level_mean_pooling
16
+ from visual_rag.indexing.qdrant_indexer import QdrantIndexer
17
+ from benchmarks.vidore_tatdqa_test.dataset_loader import load_vidore_beir_dataset
18
+ from demo.qdrant_utils import get_qdrant_credentials
19
+
20
+
21
+ TORCH_DTYPE_MAP = {
22
+ "float16": torch.float16,
23
+ "float32": torch.float32,
24
+ "bfloat16": torch.bfloat16,
25
+ }
26
+
27
+
28
+ def _stable_uuid(text: str) -> str:
29
+ """Generate a stable UUID from text (same as benchmark script)."""
30
+ hex_str = hashlib.sha256(text.encode("utf-8")).hexdigest()[:32]
31
+ return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:32]}"
32
+
33
+
34
+ def _union_point_id(
35
+ *, dataset_name: str, source_doc_id: str, union_namespace: Optional[str]
36
+ ) -> str:
37
+ """Generate union point ID (same as benchmark script)."""
38
+ ns = f"{union_namespace}::{dataset_name}" if union_namespace else dataset_name
39
+ return _stable_uuid(f"{ns}::{source_doc_id}")
40
+
41
+
42
+ def run_indexing_with_ui(config: Dict[str, Any]):
43
+ st.divider()
44
+
45
+ print("=" * 60)
46
+ print("[INDEX] Starting indexing via UI")
47
+ print("=" * 60)
48
+
49
+ url, api_key = get_qdrant_credentials()
50
+ if not url:
51
+ st.error("QDRANT_URL not configured")
52
+ return
53
+
54
+ datasets = config.get("datasets", [])
55
+ collection = config["collection"]
56
+ model = config.get("model", "vidore/colpali-v1.3")
57
+ recreate = config.get("recreate", False)
58
+ torch_dtype = config.get("torch_dtype", "float16")
59
+ qdrant_vector_dtype = config.get("qdrant_vector_dtype", "float16")
60
+ prefer_grpc = config.get("prefer_grpc", True)
61
+ batch_size = config.get("batch_size", 4)
62
+ max_docs = config.get("max_docs")
63
+
64
+ print(f"[INDEX] Config: collection={collection}, model={model}")
65
+ print(f"[INDEX] Datasets: {datasets}")
66
+ print(
67
+ f"[INDEX] max_docs={max_docs}, batch_size={batch_size}, recreate={recreate}"
68
+ )
69
+ print(
70
+ f"[INDEX] torch_dtype={torch_dtype}, qdrant_dtype={qdrant_vector_dtype}, grpc={prefer_grpc}"
71
+ )
72
+
73
+ phase1_container = st.container()
74
+ phase2_container = st.container()
75
+ phase3_container = st.container()
76
+ results_container = st.container()
77
+
78
+ try:
79
+ with phase1_container:
80
+ st.markdown("##### 🤖 Phase 1: Loading Model")
81
+ model_status = st.empty()
82
+ model_status.info(f"Loading `{model.split('/')[-1]}`...")
83
+
84
+ print(f"[INDEX] Loading embedder: {model}")
85
+ torch_dtype_obj = TORCH_DTYPE_MAP.get(torch_dtype, torch.float16)
86
+ output_dtype_obj = (
87
+ np.float16 if qdrant_vector_dtype == "float16" else np.float32
88
+ )
89
+ embedder = VisualEmbedder(
90
+ model_name=model,
91
+ torch_dtype=torch_dtype_obj,
92
+ output_dtype=output_dtype_obj,
93
+ )
94
+ embedder._load_model()
95
+ print(
96
+ f"[INDEX] Embedder loaded (torch_dtype={torch_dtype}, output_dtype={qdrant_vector_dtype})"
97
+ )
98
+ model_status.success(f"✅ Model `{model.split('/')[-1]}` loaded")
99
+
100
+ with phase2_container:
101
+ st.markdown("##### 📦 Phase 2: Setting Up Collection")
102
+
103
+ indexer_status = st.empty()
104
+ indexer_status.info("Connecting to Qdrant...")
105
+
106
+ print("[INDEX] Connecting to Qdrant...")
107
+ indexer = QdrantIndexer(
108
+ url=url,
109
+ api_key=api_key,
110
+ collection_name=collection,
111
+ prefer_grpc=prefer_grpc,
112
+ vector_datatype=qdrant_vector_dtype,
113
+ )
114
+ print("[INDEX] Connected to Qdrant")
115
+ indexer_status.success("✅ Connected to Qdrant")
116
+
117
+ coll_status = st.empty()
118
+ action = "Recreating" if recreate else "Creating/verifying"
119
+ coll_status.info(f"{action} collection `{collection}`...")
120
+
121
+ print(f"[INDEX] {action} collection: {collection}")
122
+ indexer.create_collection(force_recreate=recreate)
123
+ indexer.create_payload_indexes(
124
+ fields=[
125
+ {"field": "dataset", "type": "keyword"},
126
+ {"field": "doc_id", "type": "keyword"},
127
+ {"field": "source_doc_id", "type": "keyword"},
128
+ ]
129
+ )
130
+ print("[INDEX] Collection ready")
131
+ coll_status.success(f"✅ Collection `{collection}` ready")
132
+
133
+ with phase3_container:
134
+ st.markdown("##### 📊 Phase 3: Processing Datasets")
135
+
136
+ all_results = []
137
+
138
+ for ds_idx, dataset_name in enumerate(datasets):
139
+ ds_short = dataset_name.split("/")[-1]
140
+ ds_container = st.container()
141
+
142
+ with ds_container:
143
+ st.markdown(
144
+ f"**Dataset {ds_idx + 1}/{len(datasets)}: `{ds_short}`**"
145
+ )
146
+
147
+ load_status = st.empty()
148
+ load_status.info(f"Loading dataset `{ds_short}`...")
149
+
150
+ print(f"[INDEX] Loading dataset: {dataset_name}")
151
+ corpus, queries, qrels = load_vidore_beir_dataset(dataset_name)
152
+ total_docs = len(corpus)
153
+ print(f"[INDEX] Dataset loaded: {total_docs} docs")
154
+ load_status.success(f"✅ Loaded {total_docs:,} documents")
155
+
156
+ if max_docs and max_docs < total_docs:
157
+ corpus = corpus[:max_docs]
158
+ print(f"[INDEX] Limiting to {max_docs} docs")
159
+
160
+ progress_bar = st.progress(0)
161
+ status_text = st.empty()
162
+
163
+ uploaded = 0
164
+ failed = 0
165
+ total = len(corpus)
166
+
167
+ for i, doc in enumerate(corpus):
168
+ try:
169
+ doc_id = str(doc.doc_id)
170
+ image = doc.image
171
+ if image is None:
172
+ failed += 1
173
+ continue
174
+
175
+ status_text.text(
176
+ f"Processing {i + 1}/{total}: {doc_id[:30]}..."
177
+ )
178
+
179
+ embeddings, token_infos = embedder.embed_images(
180
+ [image],
181
+ return_token_info=True,
182
+ show_progress=False,
183
+ )
184
+ emb = embeddings[0]
185
+ token_info = token_infos[0] if token_infos else {}
186
+
187
+ if hasattr(emb, "cpu"):
188
+ emb = emb.cpu()
189
+ emb_np = np.asarray(emb, dtype=output_dtype_obj)
190
+
191
+ initial = emb_np.tolist()
192
+ global_pool = emb_np.mean(axis=0).tolist()
193
+
194
+ num_tiles = token_info.get("num_tiles")
195
+ mean_pooling = None
196
+ experimental_pooling = None
197
+
198
+ if num_tiles and num_tiles > 0:
199
+ try:
200
+ mean_pooling = tile_level_mean_pooling(
201
+ emb_np, num_tiles=num_tiles, patches_per_tile=64
202
+ ).tolist()
203
+ except Exception:
204
+ pass
205
+
206
+ try:
207
+ exp_pool = embedder.experimental_pool_visual_embedding(
208
+ emb_np, num_tiles=num_tiles
209
+ )
210
+ if exp_pool is not None:
211
+ experimental_pooling = exp_pool.tolist()
212
+ except Exception:
213
+ pass
214
+
215
+ union_doc_id = _union_point_id(
216
+ dataset_name=dataset_name,
217
+ source_doc_id=doc_id,
218
+ union_namespace=collection,
219
+ )
220
+
221
+ payload = {
222
+ "dataset": dataset_name,
223
+ "doc_id": doc_id,
224
+ "source_doc_id": doc_id,
225
+ "union_doc_id": union_doc_id,
226
+ "num_tiles": num_tiles,
227
+ "num_visual_tokens": token_info.get("num_visual_tokens"),
228
+ }
229
+
230
+ vectors = {"initial": initial, "global_pooling": global_pool}
231
+ if mean_pooling:
232
+ vectors["mean_pooling"] = mean_pooling
233
+ if experimental_pooling:
234
+ vectors["experimental_pooling"] = experimental_pooling
235
+
236
+ indexer.upsert_point(
237
+ point_id=union_doc_id,
238
+ vectors=vectors,
239
+ payload=payload,
240
+ )
241
+
242
+ uploaded += 1
243
+
244
+ except Exception as e:
245
+ print(f"[INDEX] Error on doc {i}: {e}")
246
+ failed += 1
247
+
248
+ progress_bar.progress((i + 1) / total)
249
+
250
+ status_text.text(f"✅ Done: {uploaded} uploaded, {failed} failed")
251
+ all_results.append(
252
+ {
253
+ "dataset": dataset_name,
254
+ "total": total,
255
+ "uploaded": uploaded,
256
+ "failed": failed,
257
+ }
258
+ )
259
+
260
+ with results_container:
261
+ st.markdown("##### 📋 Results Summary")
262
+
263
+ for r in all_results:
264
+ st.write(
265
+ f"**{r['dataset'].split('/')[-1]}**: {r['uploaded']:,} uploaded, {r['failed']:,} failed"
266
+ )
267
+
268
+ st.success("✅ Indexing complete!")
269
+
270
+ except Exception as e:
271
+ st.error(f"Indexing error: {e}")
272
+ st.code(traceback.format_exc())
273
+ print(f"[INDEX] ERROR: {e}")
274
+ traceback.print_exc()
@@ -8,12 +8,19 @@ import streamlit as st
8
8
 
9
9
 
10
10
  def get_qdrant_credentials() -> Tuple[Optional[str], Optional[str]]:
11
- url = st.session_state.get("qdrant_url_input") or os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")
12
- api_key = st.session_state.get("qdrant_key_input") or (
13
- os.getenv("SIGIR_QDRANT_KEY")
14
- or os.getenv("SIGIR_QDRANT_API_KEY")
15
- or os.getenv("DEST_QDRANT_API_KEY")
11
+ """Get Qdrant credentials from session state or environment variables.
12
+
13
+ Priority: session_state > QDRANT_URL/QDRANT_API_KEY > legacy env vars
14
+ """
15
+ url = (
16
+ st.session_state.get("qdrant_url_input")
17
+ or os.getenv("QDRANT_URL")
18
+ or os.getenv("SIGIR_QDRANT_URL") # legacy
19
+ )
20
+ api_key = (
21
+ st.session_state.get("qdrant_key_input")
16
22
  or os.getenv("QDRANT_API_KEY")
23
+ or os.getenv("SIGIR_QDRANT_KEY") # legacy
17
24
  )
18
25
  return url, api_key
19
26
 
@@ -9,6 +9,7 @@ from demo.qdrant_utils import (
9
9
  sample_points_cached,
10
10
  search_collection,
11
11
  )
12
+ from visual_rag.retrieval import MultiVectorRetriever
12
13
 
13
14
 
14
15
  def render_playground_tab():
@@ -46,7 +47,6 @@ def render_playground_tab():
46
47
  if not st.session_state.get("model_loaded"):
47
48
  with st.spinner(f"Loading {model_short}..."):
48
49
  try:
49
- from visual_rag.retrieval import MultiVectorRetriever
50
50
  _ = MultiVectorRetriever(collection_name=active_collection, model_name=model_name)
51
51
  st.session_state["model_loaded"] = True
52
52
  st.session_state["loaded_model_key"] = cache_key
@@ -3,6 +3,8 @@
3
3
  import os
4
4
  import streamlit as st
5
5
 
6
+ from qdrant_client.models import VectorParamsDiff
7
+
6
8
  from demo.qdrant_utils import (
7
9
  get_qdrant_credentials,
8
10
  init_qdrant_client_with_creds,
@@ -17,8 +19,8 @@ def render_sidebar():
17
19
  with st.sidebar:
18
20
  st.subheader("🔑 Qdrant Credentials")
19
21
 
20
- env_url = os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL") or ""
21
- env_key = os.getenv("SIGIR_QDRANT_KEY") or os.getenv("SIGIR_QDRANT_API_KEY") or os.getenv("DEST_QDRANT_API_KEY") or os.getenv("QDRANT_API_KEY") or ""
22
+ env_url = os.getenv("QDRANT_URL") or os.getenv("SIGIR_QDRANT_URL") or ""
23
+ env_key = os.getenv("QDRANT_API_KEY") or os.getenv("SIGIR_QDRANT_KEY") or ""
22
24
 
23
25
  if "qdrant_url_input" not in st.session_state:
24
26
  st.session_state["qdrant_url_input"] = env_url
@@ -136,7 +138,6 @@ def render_sidebar():
136
138
  if target_in_ram != current_in_ram:
137
139
  if st.button("💾 Apply Change", key="admin_apply"):
138
140
  try:
139
- from qdrant_client.models import VectorParamsDiff
140
141
  client.update_collection(
141
142
  collection_name=active,
142
143
  vectors_config={sel_vec: VectorParamsDiff(on_disk=not target_in_ram)}