visual_rag_toolkit-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
visual_rag/embedding/visual_embedder.py
@@ -0,0 +1,622 @@
"""
Visual Embedder - Generate visual and text embeddings for document retrieval.

This module provides a flexible interface that supports:
- ColPali models (ColSmol, ColPali, ColQwen2)
- Other vision-language models (future)
- Image embedding with tile-aware processing
- Query embedding with special token filtering

The embedder is BACKEND-AGNOSTIC - configure which model to use via the
`backend` parameter or model_name.
"""

import gc
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

logger = logging.getLogger(__name__)


class VisualEmbedder:
    """
    Visual document embedder supporting multiple backends.

    Currently supports:
    - ColPali family (ColSmol-500M, ColPali, ColQwen2)
    - More backends can be added

    Args:
        model_name: HuggingFace model name (e.g., "vidore/colSmol-500M")
        backend: Backend type ("colpali", "auto"). "auto" detects from model_name.
        device: Device to use (auto, cuda, mps, cpu)
        torch_dtype: Data type for model weights
        batch_size: Batch size for image processing
        filter_special_tokens: Filter special tokens from query embeddings

    Example:
        >>> # Auto-detect backend from model name
        >>> embedder = VisualEmbedder(model_name="vidore/colSmol-500M")
        >>>
        >>> # Embed images
        >>> image_embeddings = embedder.embed_images(images)
        >>>
        >>> # Embed query
        >>> query_embedding = embedder.embed_query("What is the budget?")
        >>>
        >>> # Get token info for saliency maps
        >>> embeddings, token_infos = embedder.embed_images(
        ...     images, return_token_info=True
        ... )
    """

    # Known model families and their backends
    MODEL_BACKENDS = {
        "colsmol": "colpali",
        "colpali": "colpali",
        "colqwen": "colpali",
        "colidefics": "colpali",
    }

    def __init__(
        self,
        model_name: str = "vidore/colSmol-500M",
        backend: str = "auto",
        device: Optional[str] = None,
        torch_dtype: Optional[torch.dtype] = None,
        output_dtype: Optional[np.dtype] = None,
        batch_size: int = 4,
        filter_special_tokens: bool = True,
        processor_speed: str = "fast",
    ):
        self.model_name = model_name
        self.batch_size = batch_size
        self.filter_special_tokens = filter_special_tokens
        if processor_speed not in ("fast", "slow", "auto"):
            raise ValueError("processor_speed must be one of: fast, slow, auto")
        self.processor_speed = processor_speed

        if os.getenv("VISUALRAG_INCLUDE_SPECIAL_TOKENS"):
            self.filter_special_tokens = False
            logger.info("Special token filtering disabled via VISUALRAG_INCLUDE_SPECIAL_TOKENS")

        if backend == "auto":
            backend = self._detect_backend(model_name)
        self.backend = backend

        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self.device = device

        if torch_dtype is None:
            if device == "cuda":
                torch_dtype = torch.bfloat16
            else:
                torch_dtype = torch.float32
        self.torch_dtype = torch_dtype

        if output_dtype is None:
            if torch_dtype == torch.float16:
                output_dtype = np.float16
            else:
                output_dtype = np.float32
        self.output_dtype = output_dtype

        self._model = None
        self._processor = None
        self._image_token_id = None

        logger.info("🤖 VisualEmbedder initialized")
        logger.info(f" Model: {model_name}")
        logger.info(f" Backend: {backend}")
        logger.info(
            f" Device: {device}, torch_dtype: {torch_dtype}, output_dtype: {output_dtype}"
        )

    def _detect_backend(self, model_name: str) -> str:
        """Auto-detect backend from model name."""
        model_lower = model_name.lower()

        for key, backend in self.MODEL_BACKENDS.items():
            if key in model_lower:
                logger.debug(f"Detected backend '{backend}' from model name")
                return backend

        # Default to colpali for unknown models
        logger.warning(f"Unknown model '{model_name}', defaulting to 'colpali' backend")
        return "colpali"

    def _load_model(self):
        """Lazy load the model when first needed."""
        if self._model is not None:
            return

        if self.backend == "colpali":
            self._load_colpali_model()
        else:
            raise ValueError(f"Unknown backend: {self.backend}")

    def _load_colpali_model(self):
        """Load ColPali-family model."""
        try:
            from colpali_engine.models import (
                ColIdefics3,
                ColIdefics3Processor,
                ColPali,
                ColPaliProcessor,
                ColQwen2,
                ColQwen2Processor,
            )
        except ImportError:
            raise ImportError(
                "colpali_engine not installed. Install with: "
                "pip install visual-rag-toolkit[embedding] or "
                "pip install colpali-engine"
            )

        logger.info(f"🤖 Loading ColPali model: {self.model_name}")
        logger.info(f" Device: {self.device}, dtype: {self.torch_dtype}")

        def _processor_kwargs():
            if self.processor_speed == "auto":
                return {}
            return {"use_fast": self.processor_speed == "fast"}

        from transformers import AutoConfig

        cfg = AutoConfig.from_pretrained(self.model_name)
        model_type = str(getattr(cfg, "model_type", "") or "").lower()

        if model_type == "colpali" or "colpali" in (self.model_name or "").lower():
            self._model = ColPali.from_pretrained(
                self.model_name,
                torch_dtype=self.torch_dtype,
                device_map=self.device,
            ).eval()
            try:
                self._processor = ColPaliProcessor.from_pretrained(
                    self.model_name, **_processor_kwargs()
                )
            except TypeError:
                self._processor = ColPaliProcessor.from_pretrained(self.model_name)
            except Exception:
                if self.processor_speed == "fast":
                    self._processor = ColPaliProcessor.from_pretrained(
                        self.model_name, use_fast=False
                    )
                else:
                    raise
            self._image_token_id = self._processor.image_token_id
            logger.info("✅ Loaded ColPali backend")
            return

        if model_type.startswith("qwen2") or "colqwen" in (self.model_name or "").lower():
            self._model = ColQwen2.from_pretrained(
                self.model_name,
                dtype=self.torch_dtype,
                device_map=self.device,
            ).eval()
            try:
                self._processor = ColQwen2Processor.from_pretrained(
                    self.model_name, device_map=self.device, **_processor_kwargs()
                )
            except TypeError:
                self._processor = ColQwen2Processor.from_pretrained(
                    self.model_name, device_map=self.device
                )
            except Exception:
                if self.processor_speed == "fast":
                    self._processor = ColQwen2Processor.from_pretrained(
                        self.model_name, device_map=self.device, use_fast=False
                    )
                else:
                    raise
            self._image_token_id = self._processor.image_token_id
            logger.info("✅ Loaded ColQwen2 backend")
            return

        attn_implementation = "eager"
        if self.device != "cpu":
            try:
                import flash_attn  # noqa

                attn_implementation = "flash_attention_2"
                logger.info(" Using FlashAttention2")
            except ImportError:
                pass

        self._model = ColIdefics3.from_pretrained(
            self.model_name,
            dtype=self.torch_dtype,
            device_map=self.device,
            attn_implementation=attn_implementation,
        ).eval()
        try:
            self._processor = ColIdefics3Processor.from_pretrained(
                self.model_name, **_processor_kwargs()
            )
        except TypeError:
            self._processor = ColIdefics3Processor.from_pretrained(self.model_name)
        except Exception:
            if self.processor_speed == "fast":
                self._processor = ColIdefics3Processor.from_pretrained(
                    self.model_name, use_fast=False
                )
            else:
                raise
        self._image_token_id = self._processor.image_token_id

        logger.info("✅ Model loaded successfully")

    @property
    def model(self):
        self._load_model()
        return self._model

    @property
    def processor(self):
        self._load_model()
        return self._processor

    @property
    def image_token_id(self):
        self._load_model()
        return self._image_token_id

    def embed_query(
        self,
        query_text: str,
        filter_special_tokens: Optional[bool] = None,
    ) -> torch.Tensor:
        """
        Generate embedding for a text query.

        By default, filters out special tokens (CLS, SEP, PAD) to keep only
        meaningful text tokens for better MaxSim matching.

        Args:
            query_text: Natural language query string
            filter_special_tokens: Override instance-level setting

        Returns:
            Query embedding tensor of shape [num_tokens, embedding_dim]
        """
        should_filter = (
            filter_special_tokens
            if filter_special_tokens is not None
            else self.filter_special_tokens
        )

        with torch.no_grad():
            processed = self.processor.process_queries([query_text]).to(self.model.device)
            embedding = self.model(**processed)

        # Remove batch dimension: [1, tokens, dim] -> [tokens, dim]
        if embedding.dim() == 3:
            embedding = embedding.squeeze(0)

        if should_filter:
            # Filter special tokens based on attention mask
            attention_mask = processed.get("attention_mask")
            if attention_mask is not None:
                # Keep only tokens with attention_mask = 1
                valid_mask = attention_mask.squeeze(0).bool()
                embedding = embedding[valid_mask]

                # Additionally filter padding tokens if present
                input_ids = processed.get("input_ids")
                if input_ids is not None:
                    input_ids = input_ids.squeeze(0)[valid_mask]
                    # Filter common special token IDs
                    # IDs >= 4 are usually real tokens for most tokenizers
                    non_special_mask = input_ids >= 4
                    if non_special_mask.any():
                        embedding = embedding[non_special_mask]

            logger.debug(f"Query embedding: {embedding.shape[0]} tokens after filtering")
        else:
            logger.debug(f"Query embedding: {embedding.shape[0]} tokens (unfiltered)")

        return embedding

    def embed_queries(
        self,
        query_texts: List[str],
        batch_size: Optional[int] = None,
        filter_special_tokens: Optional[bool] = None,
        show_progress: bool = True,
    ) -> List[torch.Tensor]:
        """
        Generate embeddings for a list of text queries.

        Returns a list of tensors, each of shape [num_tokens, embedding_dim].
        """
        should_filter = (
            filter_special_tokens
            if filter_special_tokens is not None
            else self.filter_special_tokens
        )
        batch_size = batch_size or self.batch_size

        outputs: List[torch.Tensor] = []
        iterator = range(0, len(query_texts), batch_size)
        if show_progress:
            iterator = tqdm(iterator, desc="📝 Embedding queries", unit="batch")

        for i in iterator:
            batch = query_texts[i : i + batch_size]
            with torch.no_grad():
                processed = self.processor.process_queries(batch).to(self.model.device)
                batch_embeddings = self.model(**processed)

            if isinstance(batch_embeddings, torch.Tensor) and batch_embeddings.dim() == 3:
                attn = processed.get("attention_mask") if should_filter else None
                input_ids = processed.get("input_ids") if should_filter else None

                for j in range(batch_embeddings.shape[0]):
                    emb = batch_embeddings[j]
                    if should_filter and attn is not None:
                        valid_mask = attn[j].bool()
                        emb = emb[valid_mask]
                        if input_ids is not None:
                            ids = input_ids[j][valid_mask]
                            non_special_mask = ids >= 4
                            if non_special_mask.any():
                                emb = emb[non_special_mask]
                    outputs.append(emb)
            else:
                outputs.extend(batch_embeddings)

            del processed, batch_embeddings
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            elif torch.backends.mps.is_available():
                torch.mps.empty_cache()

        return outputs

    def embed_images(
        self,
        images: List[Image.Image],
        batch_size: Optional[int] = None,
        return_token_info: bool = False,
        show_progress: bool = True,
    ) -> Union[List[torch.Tensor], Tuple[List[torch.Tensor], List[Dict[str, Any]]]]:
        """
        Generate embeddings for a list of images.

        Args:
            images: List of PIL Images
            batch_size: Override instance batch size
            return_token_info: Also return token metadata (for saliency maps)
            show_progress: Show progress bar

        Returns:
            If return_token_info=False:
                List of embedding tensors [num_patches, dim]
            If return_token_info=True:
                Tuple of (embeddings, token_infos)

        Token info contains:
            - visual_token_indices: Indices of visual tokens in embedding
            - num_visual_tokens: Count of visual tokens
            - n_rows, n_cols: Tile grid dimensions
            - num_tiles: Total tiles (n_rows × n_cols + 1 global)
        """
        batch_size = batch_size or self.batch_size
        if (
            self.device == "mps"
            and "colpali" in (self.model_name or "").lower()
            and int(batch_size) > 1
        ):
            batch_size = 1

        embeddings = []
        token_infos = [] if return_token_info else None

        iterator = range(0, len(images), batch_size)
        if show_progress:
            iterator = tqdm(iterator, desc="🎨 Embedding", unit="batch")

        for i in iterator:
            batch = images[i : i + batch_size]

            with torch.no_grad():
                processed = self.processor.process_images(batch).to(self.model.device)

                # Extract token info before model forward
                if return_token_info:
                    input_ids = processed["input_ids"]
                    batch_n_rows = processed.get("n_rows")
                    batch_n_cols = processed.get("n_cols")

                    for j in range(input_ids.shape[0]):
                        # Find visual token indices
                        image_token_mask = input_ids[j] == self.image_token_id
                        visual_indices = torch.where(image_token_mask)[0].cpu().numpy().tolist()

                        n_rows = batch_n_rows[j].item() if batch_n_rows is not None else None
                        n_cols = batch_n_cols[j].item() if batch_n_cols is not None else None

                        token_infos.append(
                            {
                                "visual_token_indices": visual_indices,
                                "num_visual_tokens": len(visual_indices),
                                "n_rows": n_rows,
                                "n_cols": n_cols,
                                "num_tiles": (n_rows * n_cols + 1) if n_rows and n_cols else None,
                            }
                        )

                # Generate embeddings
                batch_embeddings = self.model(**processed)

            # Extract per-image embeddings
            if isinstance(batch_embeddings, torch.Tensor) and batch_embeddings.dim() == 3:
                for j in range(batch_embeddings.shape[0]):
                    embeddings.append(batch_embeddings[j].cpu())
            else:
                embeddings.extend([e.cpu() for e in batch_embeddings])

            # Memory cleanup
            del processed, batch_embeddings
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            elif torch.backends.mps.is_available():
                torch.mps.empty_cache()

        if return_token_info:
            return embeddings, token_infos
        return embeddings

    def extract_visual_embedding(
        self,
        full_embedding: torch.Tensor,
        token_info: Dict[str, Any],
    ) -> np.ndarray:
        """
        Extract only visual token embeddings from full embedding.

        Filters out special tokens, keeping only visual patches for MaxSim.

        Args:
            full_embedding: Full embedding [all_tokens, dim]
            token_info: Token info dict from embed_images

        Returns:
            Visual embedding array [num_visual_tokens, dim]
        """
        visual_indices = token_info["visual_token_indices"]

        if isinstance(full_embedding, torch.Tensor):
            if full_embedding.dtype == torch.bfloat16:
                visual_emb = full_embedding[visual_indices].cpu().float().numpy()
            else:
                visual_emb = full_embedding[visual_indices].cpu().numpy()
        else:
            visual_emb = np.array(full_embedding)[visual_indices]

        return visual_emb.astype(self.output_dtype)

    def mean_pool_visual_embedding(
        self,
        visual_embedding: Union[torch.Tensor, np.ndarray],
        token_info: Optional[Dict[str, Any]] = None,
        *,
        target_vectors: int = 32,
    ) -> np.ndarray:
        from visual_rag.embedding.pooling import colpali_row_mean_pooling, tile_level_mean_pooling

        model_lower = (self.model_name or "").lower()
        is_colsmol = "colsmol" in model_lower

        if isinstance(visual_embedding, torch.Tensor):
            if visual_embedding.dtype == torch.bfloat16:
                visual_np = visual_embedding.cpu().float().numpy()
            else:
                visual_np = visual_embedding.cpu().numpy().astype(np.float32)
        else:
            visual_np = np.array(visual_embedding, dtype=np.float32)

        if is_colsmol:
            n_rows = (token_info or {}).get("n_rows")
            n_cols = (token_info or {}).get("n_cols")
            num_tiles = int(n_rows) * int(n_cols) + 1 if n_rows and n_cols else 13
            return tile_level_mean_pooling(
                visual_np, num_tiles=num_tiles, patches_per_tile=64, output_dtype=self.output_dtype
            )

        num_tokens = int(visual_np.shape[0])
        grid = int(round(float(num_tokens) ** 0.5))
        if grid * grid != num_tokens:
            raise ValueError(
                f"Cannot infer square grid from num_visual_tokens={num_tokens} for model={self.model_name}"
            )
        if int(target_vectors) != int(grid):
            raise ValueError(
                f"target_vectors={target_vectors} does not match inferred grid_size={grid} for model={self.model_name}"
            )
        return colpali_row_mean_pooling(
            visual_np, grid_size=int(target_vectors), output_dtype=self.output_dtype
        )

    def global_pool_from_mean_pool(self, mean_pool: np.ndarray) -> np.ndarray:
        if mean_pool.size == 0:
            return np.zeros((128,), dtype=self.output_dtype)
        return mean_pool.mean(axis=0).astype(self.output_dtype)

    def experimental_pool_visual_embedding(
        self,
        visual_embedding: Union[torch.Tensor, np.ndarray],
        token_info: Optional[Dict[str, Any]] = None,
        *,
        target_vectors: int = 32,
        mean_pool: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        from visual_rag.embedding.pooling import (
            colpali_experimental_pooling_from_rows,
            colsmol_experimental_pooling,
        )

        model_lower = (self.model_name or "").lower()
        is_colsmol = "colsmol" in model_lower

        if isinstance(visual_embedding, torch.Tensor):
            if visual_embedding.dtype == torch.bfloat16:
                visual_np = visual_embedding.cpu().float().numpy()
            else:
                visual_np = visual_embedding.cpu().numpy().astype(np.float32)
        else:
            visual_np = np.array(visual_embedding, dtype=np.float32)

        if is_colsmol:
            if (
                mean_pool is not None
                and getattr(mean_pool, "shape", None) is not None
                and int(mean_pool.shape[0]) > 0
            ):
                num_tiles = int(mean_pool.shape[0])
            else:
                num_tiles = (token_info or {}).get("num_tiles")
                if num_tiles is None:
                    num_visual_tokens = (token_info or {}).get("num_visual_tokens")
                    if num_visual_tokens is None:
                        num_visual_tokens = int(visual_np.shape[0])
                    patches_per_tile = 64
                    num_tiles = int(num_visual_tokens) // patches_per_tile
                    if int(num_tiles) * patches_per_tile != int(num_visual_tokens):
                        num_tiles = int(num_tiles) + 1
                num_tiles = int(num_tiles)
            return colsmol_experimental_pooling(
                visual_np, num_tiles=num_tiles, patches_per_tile=64, output_dtype=self.output_dtype
            )

        rows = (
            mean_pool
            if mean_pool is not None
            else self.mean_pool_visual_embedding(
                visual_np, token_info, target_vectors=target_vectors
            )
        )
        if int(rows.shape[0]) != int(target_vectors):
            raise ValueError(
                f"experimental pooling expects mean_pool to have {target_vectors} rows, got {rows.shape[0]} for model={self.model_name}"
            )
        return colpali_experimental_pooling_from_rows(rows, output_dtype=self.output_dtype)


# Backward compatibility alias
ColPaliEmbedder = VisualEmbedder
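For readers skimming the diff, here is a minimal usage sketch built only from the methods and docstrings shown in the file above. The page image path is illustrative, and the MaxSim (late-interaction) scoring helper at the end is a generic implementation of the scoring the docstrings refer to, not a function shipped by the package.

import numpy as np
from PIL import Image

from visual_rag.embedding.visual_embedder import VisualEmbedder

# Hypothetical input page image; any PIL.Image works.
page = Image.open("page_001.png")

embedder = VisualEmbedder(model_name="vidore/colSmol-500M")

# Multi-vector page embeddings plus tile metadata (used for pooling/saliency).
embeddings, token_infos = embedder.embed_images([page], return_token_info=True)

# Keep only the visual patch tokens, then pool them for cheaper storage.
visual = embedder.extract_visual_embedding(embeddings[0], token_infos[0])
mean_pool = embedder.mean_pool_visual_embedding(visual, token_infos[0])
global_vec = embedder.global_pool_from_mean_pool(mean_pool)

# Query side: one vector per (non-special) query token.
query_emb = embedder.embed_query("What is the budget?")

# Generic MaxSim score: for each query token, take its best-matching patch,
# then sum over query tokens.
q = query_emb.float().cpu().numpy()
d = visual.astype(np.float32)
score = np.max(q @ d.T, axis=1).sum()
print(f"MaxSim score: {score:.3f}, pooled vectors: {mean_pool.shape}, global: {global_vec.shape}")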
visual_rag/indexing/__init__.py
@@ -0,0 +1,21 @@
"""
Indexing module - PDF processing, embedding storage, and CDN uploads.

Components:
- PDFProcessor: Convert PDFs to images and extract text
- QdrantIndexer: Upload embeddings to Qdrant vector database
- CloudinaryUploader: Upload images to Cloudinary CDN
- ProcessingPipeline: End-to-end PDF → Qdrant pipeline
"""

from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
from visual_rag.indexing.pdf_processor import PDFProcessor
from visual_rag.indexing.pipeline import ProcessingPipeline
from visual_rag.indexing.qdrant_indexer import QdrantIndexer

__all__ = [
    "PDFProcessor",
    "QdrantIndexer",
    "CloudinaryUploader",
    "ProcessingPipeline",
]