visual_rag_toolkit-0.1.1-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (59)
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
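For orientation before the hunks below, a minimal install-and-import sketch; the import paths are taken from the file list above, and nothing else about the public API is assumed:

    pip install visual-rag-toolkit==0.1.1

    # Python, after installation
    from visual_rag.embedding.visual_embedder import VisualEmbedder
    from visual_rag.indexing import ProcessingPipeline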
visual_rag/embedding/visual_embedder.py
@@ -0,0 +1,622 @@
+"""
+Visual Embedder - Generate visual and text embeddings for document retrieval.
+
+This module provides a flexible interface that supports:
+- ColPali models (ColSmol, ColPali, ColQwen2)
+- Other vision-language models (future)
+- Image embedding with tile-aware processing
+- Query embedding with special token filtering
+
+The embedder is BACKEND-AGNOSTIC - configure which model to use via the
+`backend` parameter or model_name.
+"""
+
+import gc
+import logging
+import os
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL import Image
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+class VisualEmbedder:
+    """
+    Visual document embedder supporting multiple backends.
+
+    Currently supports:
+    - ColPali family (ColSmol-500M, ColPali, ColQwen2)
+    - More backends can be added
+
+    Args:
+        model_name: HuggingFace model name (e.g., "vidore/colSmol-500M")
+        backend: Backend type ("colpali", "auto"). "auto" detects from model_name.
+        device: Device to use (auto, cuda, mps, cpu)
+        torch_dtype: Data type for model weights
+        batch_size: Batch size for image processing
+        filter_special_tokens: Filter special tokens from query embeddings
+
+    Example:
+        >>> # Auto-detect backend from model name
+        >>> embedder = VisualEmbedder(model_name="vidore/colSmol-500M")
+        >>>
+        >>> # Embed images
+        >>> image_embeddings = embedder.embed_images(images)
+        >>>
+        >>> # Embed query
+        >>> query_embedding = embedder.embed_query("What is the budget?")
+        >>>
+        >>> # Get token info for saliency maps
+        >>> embeddings, token_infos = embedder.embed_images(
+        ...     images, return_token_info=True
+        ... )
+    """
+
+    # Known model families and their backends
+    MODEL_BACKENDS = {
+        "colsmol": "colpali",
+        "colpali": "colpali",
+        "colqwen": "colpali",
+        "colidefics": "colpali",
+    }
+
+    def __init__(
+        self,
+        model_name: str = "vidore/colSmol-500M",
+        backend: str = "auto",
+        device: Optional[str] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+        output_dtype: Optional[np.dtype] = None,
+        batch_size: int = 4,
+        filter_special_tokens: bool = True,
+        processor_speed: str = "fast",
+    ):
+        self.model_name = model_name
+        self.batch_size = batch_size
+        self.filter_special_tokens = filter_special_tokens
+        if processor_speed not in ("fast", "slow", "auto"):
+            raise ValueError("processor_speed must be one of: fast, slow, auto")
+        self.processor_speed = processor_speed
+
+        if os.getenv("VISUALRAG_INCLUDE_SPECIAL_TOKENS"):
+            self.filter_special_tokens = False
+            logger.info("Special token filtering disabled via VISUALRAG_INCLUDE_SPECIAL_TOKENS")
+
+        if backend == "auto":
+            backend = self._detect_backend(model_name)
+        self.backend = backend
+
+        if device is None:
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.backends.mps.is_available():
+                device = "mps"
+            else:
+                device = "cpu"
+        self.device = device
+
+        if torch_dtype is None:
+            if device == "cuda":
+                torch_dtype = torch.bfloat16
+            else:
+                torch_dtype = torch.float32
+        self.torch_dtype = torch_dtype
+
+        if output_dtype is None:
+            if torch_dtype == torch.float16:
+                output_dtype = np.float16
+            else:
+                output_dtype = np.float32
+        self.output_dtype = output_dtype
+
+        self._model = None
+        self._processor = None
+        self._image_token_id = None
+
+        logger.info("🤖 VisualEmbedder initialized")
+        logger.info(f" Model: {model_name}")
+        logger.info(f" Backend: {backend}")
+        logger.info(
+            f" Device: {device}, torch_dtype: {torch_dtype}, output_dtype: {output_dtype}"
+        )
+
+    def _detect_backend(self, model_name: str) -> str:
+        """Auto-detect backend from model name."""
+        model_lower = model_name.lower()
+
+        for key, backend in self.MODEL_BACKENDS.items():
+            if key in model_lower:
+                logger.debug(f"Detected backend '{backend}' from model name")
+                return backend
+
+        # Default to colpali for unknown models
+        logger.warning(f"Unknown model '{model_name}', defaulting to 'colpali' backend")
+        return "colpali"
+
+    def _load_model(self):
+        """Lazy load the model when first needed."""
+        if self._model is not None:
+            return
+
+        if self.backend == "colpali":
+            self._load_colpali_model()
+        else:
+            raise ValueError(f"Unknown backend: {self.backend}")
+
+    def _load_colpali_model(self):
+        """Load ColPali-family model."""
+        try:
+            from colpali_engine.models import (
+                ColIdefics3,
+                ColIdefics3Processor,
+                ColPali,
+                ColPaliProcessor,
+                ColQwen2,
+                ColQwen2Processor,
+            )
+        except ImportError:
+            raise ImportError(
+                "colpali_engine not installed. Install with: "
+                "pip install visual-rag-toolkit[embedding] or "
+                "pip install colpali-engine"
+            )
+
+        logger.info(f"🤖 Loading ColPali model: {self.model_name}")
+        logger.info(f" Device: {self.device}, dtype: {self.torch_dtype}")
+
+        def _processor_kwargs():
+            if self.processor_speed == "auto":
+                return {}
+            return {"use_fast": self.processor_speed == "fast"}
+
+        from transformers import AutoConfig
+
+        cfg = AutoConfig.from_pretrained(self.model_name)
+        model_type = str(getattr(cfg, "model_type", "") or "").lower()
+
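+        # Dispatch below goes: native ColPali checkpoints first, then ColQwen2,
+        # with ColIdefics3 (the ColSmol family) as the fallback for everything else.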
+        if model_type == "colpali" or "colpali" in (self.model_name or "").lower():
+            self._model = ColPali.from_pretrained(
+                self.model_name,
+                torch_dtype=self.torch_dtype,
+                device_map=self.device,
+            ).eval()
+            try:
+                self._processor = ColPaliProcessor.from_pretrained(
+                    self.model_name, **_processor_kwargs()
+                )
+            except TypeError:
+                self._processor = ColPaliProcessor.from_pretrained(self.model_name)
+            except Exception:
+                if self.processor_speed == "fast":
+                    self._processor = ColPaliProcessor.from_pretrained(
+                        self.model_name, use_fast=False
+                    )
+                else:
+                    raise
+            self._image_token_id = self._processor.image_token_id
+            logger.info("✅ Loaded ColPali backend")
+            return
+
+        if model_type.startswith("qwen2") or "colqwen" in (self.model_name or "").lower():
+            self._model = ColQwen2.from_pretrained(
+                self.model_name,
+                dtype=self.torch_dtype,
+                device_map=self.device,
+            ).eval()
+            try:
+                self._processor = ColQwen2Processor.from_pretrained(
+                    self.model_name, device_map=self.device, **_processor_kwargs()
+                )
+            except TypeError:
+                self._processor = ColQwen2Processor.from_pretrained(
+                    self.model_name, device_map=self.device
+                )
+            except Exception:
+                if self.processor_speed == "fast":
+                    self._processor = ColQwen2Processor.from_pretrained(
+                        self.model_name, device_map=self.device, use_fast=False
+                    )
+                else:
+                    raise
+            self._image_token_id = self._processor.image_token_id
+            logger.info("✅ Loaded ColQwen2 backend")
+            return
+
+        attn_implementation = "eager"
+        if self.device != "cpu":
+            try:
+                import flash_attn  # noqa
+
+                attn_implementation = "flash_attention_2"
+                logger.info(" Using FlashAttention2")
+            except ImportError:
+                pass
+
+        self._model = ColIdefics3.from_pretrained(
+            self.model_name,
+            dtype=self.torch_dtype,
+            device_map=self.device,
+            attn_implementation=attn_implementation,
+        ).eval()
+        try:
+            self._processor = ColIdefics3Processor.from_pretrained(
+                self.model_name, **_processor_kwargs()
+            )
+        except TypeError:
+            self._processor = ColIdefics3Processor.from_pretrained(self.model_name)
+        except Exception:
+            if self.processor_speed == "fast":
+                self._processor = ColIdefics3Processor.from_pretrained(
+                    self.model_name, use_fast=False
+                )
+            else:
+                raise
+        self._image_token_id = self._processor.image_token_id
+
+        logger.info("✅ Model loaded successfully")
+
+    @property
+    def model(self):
+        self._load_model()
+        return self._model
+
+    @property
+    def processor(self):
+        self._load_model()
+        return self._processor
+
+    @property
+    def image_token_id(self):
+        self._load_model()
+        return self._image_token_id
+
+    def embed_query(
+        self,
+        query_text: str,
+        filter_special_tokens: Optional[bool] = None,
+    ) -> torch.Tensor:
+        """
+        Generate embedding for a text query.
+
+        By default, filters out special tokens (CLS, SEP, PAD) to keep only
+        meaningful text tokens for better MaxSim matching.
+
+        Args:
+            query_text: Natural language query string
+            filter_special_tokens: Override instance-level setting
+
+        Returns:
+            Query embedding tensor of shape [num_tokens, embedding_dim]
+        """
+        should_filter = (
+            filter_special_tokens
+            if filter_special_tokens is not None
+            else self.filter_special_tokens
+        )
+
+        with torch.no_grad():
+            processed = self.processor.process_queries([query_text]).to(self.model.device)
+            embedding = self.model(**processed)
+
+        # Remove batch dimension: [1, tokens, dim] -> [tokens, dim]
+        if embedding.dim() == 3:
+            embedding = embedding.squeeze(0)
+
+        if should_filter:
+            # Filter special tokens based on attention mask
+            attention_mask = processed.get("attention_mask")
+            if attention_mask is not None:
+                # Keep only tokens with attention_mask = 1
+                valid_mask = attention_mask.squeeze(0).bool()
+                embedding = embedding[valid_mask]
+
+                # Additionally filter padding tokens if present
+                input_ids = processed.get("input_ids")
+                if input_ids is not None:
+                    input_ids = input_ids.squeeze(0)[valid_mask]
+                    # Filter common special token IDs
+                    # IDs >= 4 are usually real tokens for most tokenizers
+                    non_special_mask = input_ids >= 4
+                    if non_special_mask.any():
+                        embedding = embedding[non_special_mask]
+
+            logger.debug(f"Query embedding: {embedding.shape[0]} tokens after filtering")
+        else:
+            logger.debug(f"Query embedding: {embedding.shape[0]} tokens (unfiltered)")
+
+        return embedding
+
+    def embed_queries(
+        self,
+        query_texts: List[str],
+        batch_size: Optional[int] = None,
+        filter_special_tokens: Optional[bool] = None,
+        show_progress: bool = True,
+    ) -> List[torch.Tensor]:
+        """
+        Generate embeddings for a list of text queries.
+
+        Returns a list of tensors, each of shape [num_tokens, embedding_dim].
+        """
+        should_filter = (
+            filter_special_tokens
+            if filter_special_tokens is not None
+            else self.filter_special_tokens
+        )
+        batch_size = batch_size or self.batch_size
+
+        outputs: List[torch.Tensor] = []
+        iterator = range(0, len(query_texts), batch_size)
+        if show_progress:
+            iterator = tqdm(iterator, desc="📝 Embedding queries", unit="batch")
+
+        for i in iterator:
+            batch = query_texts[i : i + batch_size]
+            with torch.no_grad():
+                processed = self.processor.process_queries(batch).to(self.model.device)
+                batch_embeddings = self.model(**processed)
+
+            if isinstance(batch_embeddings, torch.Tensor) and batch_embeddings.dim() == 3:
+                attn = processed.get("attention_mask") if should_filter else None
+                input_ids = processed.get("input_ids") if should_filter else None
+
+                for j in range(batch_embeddings.shape[0]):
+                    emb = batch_embeddings[j]
+                    if should_filter and attn is not None:
+                        valid_mask = attn[j].bool()
+                        emb = emb[valid_mask]
+                        if input_ids is not None:
+                            ids = input_ids[j][valid_mask]
+                            non_special_mask = ids >= 4
+                            if non_special_mask.any():
+                                emb = emb[non_special_mask]
+                    outputs.append(emb)
+            else:
+                outputs.extend(batch_embeddings)
+
+            del processed, batch_embeddings
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif torch.backends.mps.is_available():
+                torch.mps.empty_cache()
+
+        return outputs
+
+    def embed_images(
+        self,
+        images: List[Image.Image],
+        batch_size: Optional[int] = None,
+        return_token_info: bool = False,
+        show_progress: bool = True,
+    ) -> Union[List[torch.Tensor], Tuple[List[torch.Tensor], List[Dict[str, Any]]]]:
+        """
+        Generate embeddings for a list of images.
+
+        Args:
+            images: List of PIL Images
+            batch_size: Override instance batch size
+            return_token_info: Also return token metadata (for saliency maps)
+            show_progress: Show progress bar
+
+        Returns:
+            If return_token_info=False:
+                List of embedding tensors [num_patches, dim]
+            If return_token_info=True:
+                Tuple of (embeddings, token_infos)
+
+            Token info contains:
+            - visual_token_indices: Indices of visual tokens in embedding
+            - num_visual_tokens: Count of visual tokens
+            - n_rows, n_cols: Tile grid dimensions
+            - num_tiles: Total tiles (n_rows × n_cols + 1 global)
+        """
+        batch_size = batch_size or self.batch_size
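+        # ColPali on Apple MPS is clamped to a batch size of 1 below,
+        # presumably as a stability/memory workaround on that backend.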
+        if (
+            self.device == "mps"
+            and "colpali" in (self.model_name or "").lower()
+            and int(batch_size) > 1
+        ):
+            batch_size = 1
+
+        embeddings = []
+        token_infos = [] if return_token_info else None
+
+        iterator = range(0, len(images), batch_size)
+        if show_progress:
+            iterator = tqdm(iterator, desc="🎨 Embedding", unit="batch")
+
+        for i in iterator:
+            batch = images[i : i + batch_size]
+
+            with torch.no_grad():
+                processed = self.processor.process_images(batch).to(self.model.device)
+
+                # Extract token info before model forward
+                if return_token_info:
+                    input_ids = processed["input_ids"]
+                    batch_n_rows = processed.get("n_rows")
+                    batch_n_cols = processed.get("n_cols")
+
+                    for j in range(input_ids.shape[0]):
+                        # Find visual token indices
+                        image_token_mask = input_ids[j] == self.image_token_id
+                        visual_indices = torch.where(image_token_mask)[0].cpu().numpy().tolist()
+
+                        n_rows = batch_n_rows[j].item() if batch_n_rows is not None else None
+                        n_cols = batch_n_cols[j].item() if batch_n_cols is not None else None
+
+                        token_infos.append(
+                            {
+                                "visual_token_indices": visual_indices,
+                                "num_visual_tokens": len(visual_indices),
+                                "n_rows": n_rows,
+                                "n_cols": n_cols,
+                                "num_tiles": (n_rows * n_cols + 1) if n_rows and n_cols else None,
+                            }
+                        )
+
+                # Generate embeddings
+                batch_embeddings = self.model(**processed)
+
+            # Extract per-image embeddings
+            if isinstance(batch_embeddings, torch.Tensor) and batch_embeddings.dim() == 3:
+                for j in range(batch_embeddings.shape[0]):
+                    embeddings.append(batch_embeddings[j].cpu())
+            else:
+                embeddings.extend([e.cpu() for e in batch_embeddings])
+
+            # Memory cleanup
+            del processed, batch_embeddings
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif torch.backends.mps.is_available():
+                torch.mps.empty_cache()
+
+        if return_token_info:
+            return embeddings, token_infos
+        return embeddings
+
+    def extract_visual_embedding(
+        self,
+        full_embedding: torch.Tensor,
+        token_info: Dict[str, Any],
+    ) -> np.ndarray:
+        """
+        Extract only visual token embeddings from full embedding.
+
+        Filters out special tokens, keeping only visual patches for MaxSim.
+
+        Args:
+            full_embedding: Full embedding [all_tokens, dim]
+            token_info: Token info dict from embed_images
+
+        Returns:
+            Visual embedding array [num_visual_tokens, dim]
+        """
+        visual_indices = token_info["visual_token_indices"]
+
+        if isinstance(full_embedding, torch.Tensor):
+            if full_embedding.dtype == torch.bfloat16:
+                visual_emb = full_embedding[visual_indices].cpu().float().numpy()
+            else:
+                visual_emb = full_embedding[visual_indices].cpu().numpy()
+        else:
+            visual_emb = np.array(full_embedding)[visual_indices]
+
+        return visual_emb.astype(self.output_dtype)
+
+    def mean_pool_visual_embedding(
+        self,
+        visual_embedding: Union[torch.Tensor, np.ndarray],
+        token_info: Optional[Dict[str, Any]] = None,
+        *,
+        target_vectors: int = 32,
+    ) -> np.ndarray:
+        from visual_rag.embedding.pooling import colpali_row_mean_pooling, tile_level_mean_pooling
+
+        model_lower = (self.model_name or "").lower()
+        is_colsmol = "colsmol" in model_lower
+
+        if isinstance(visual_embedding, torch.Tensor):
+            if visual_embedding.dtype == torch.bfloat16:
+                visual_np = visual_embedding.cpu().float().numpy()
+            else:
+                visual_np = visual_embedding.cpu().numpy().astype(np.float32)
+        else:
+            visual_np = np.array(visual_embedding, dtype=np.float32)
+
+        if is_colsmol:
+            n_rows = (token_info or {}).get("n_rows")
+            n_cols = (token_info or {}).get("n_cols")
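+            # Fallback of 13 below = 12 local tiles + 1 global tile when the grid is unknown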
+            num_tiles = int(n_rows) * int(n_cols) + 1 if n_rows and n_cols else 13
+            return tile_level_mean_pooling(
+                visual_np, num_tiles=num_tiles, patches_per_tile=64, output_dtype=self.output_dtype
+            )
+
+        num_tokens = int(visual_np.shape[0])
+        grid = int(round(float(num_tokens) ** 0.5))
+        if grid * grid != num_tokens:
+            raise ValueError(
+                f"Cannot infer square grid from num_visual_tokens={num_tokens} for model={self.model_name}"
+            )
+        if int(target_vectors) != int(grid):
+            raise ValueError(
+                f"target_vectors={target_vectors} does not match inferred grid_size={grid} for model={self.model_name}"
+            )
+        return colpali_row_mean_pooling(
+            visual_np, grid_size=int(target_vectors), output_dtype=self.output_dtype
+        )
+
+    def global_pool_from_mean_pool(self, mean_pool: np.ndarray) -> np.ndarray:
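+        """Average the mean-pooled rows into one global vector; an empty input yields a 128-dim zero vector."""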
+        if mean_pool.size == 0:
+            return np.zeros((128,), dtype=self.output_dtype)
+        return mean_pool.mean(axis=0).astype(self.output_dtype)
+
+    def experimental_pool_visual_embedding(
+        self,
+        visual_embedding: Union[torch.Tensor, np.ndarray],
+        token_info: Optional[Dict[str, Any]] = None,
+        *,
+        target_vectors: int = 32,
+        mean_pool: Optional[np.ndarray] = None,
+    ) -> np.ndarray:
+        from visual_rag.embedding.pooling import (
+            colpali_experimental_pooling_from_rows,
+            colsmol_experimental_pooling,
+        )
+
+        model_lower = (self.model_name or "").lower()
+        is_colsmol = "colsmol" in model_lower
+
+        if isinstance(visual_embedding, torch.Tensor):
+            if visual_embedding.dtype == torch.bfloat16:
+                visual_np = visual_embedding.cpu().float().numpy()
+            else:
+                visual_np = visual_embedding.cpu().numpy().astype(np.float32)
+        else:
+            visual_np = np.array(visual_embedding, dtype=np.float32)
+
+        if is_colsmol:
+            if (
+                mean_pool is not None
+                and getattr(mean_pool, "shape", None) is not None
+                and int(mean_pool.shape[0]) > 0
+            ):
+                num_tiles = int(mean_pool.shape[0])
+            else:
+                num_tiles = (token_info or {}).get("num_tiles")
+                if num_tiles is None:
+                    num_visual_tokens = (token_info or {}).get("num_visual_tokens")
+                    if num_visual_tokens is None:
+                        num_visual_tokens = int(visual_np.shape[0])
+                    patches_per_tile = 64
+                    num_tiles = int(num_visual_tokens) // patches_per_tile
+                    if int(num_tiles) * patches_per_tile != int(num_visual_tokens):
+                        num_tiles = int(num_tiles) + 1
+                num_tiles = int(num_tiles)
+            return colsmol_experimental_pooling(
+                visual_np, num_tiles=num_tiles, patches_per_tile=64, output_dtype=self.output_dtype
+            )
+
+        rows = (
+            mean_pool
+            if mean_pool is not None
+            else self.mean_pool_visual_embedding(
+                visual_np, token_info, target_vectors=target_vectors
+            )
+        )
+        if int(rows.shape[0]) != int(target_vectors):
+            raise ValueError(
+                f"experimental pooling expects mean_pool to have {target_vectors} rows, got {rows.shape[0]} for model={self.model_name}"
+            )
+        return colpali_experimental_pooling_from_rows(rows, output_dtype=self.output_dtype)
+
+
+# Backward compatibility alias
+ColPaliEmbedder = VisualEmbedder
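To tie this file's pieces together, a minimal usage sketch built only from the methods defined above (illustrative: "page.png" is a placeholder input, and model weights download lazily on first use):

    from PIL import Image
    from visual_rag.embedding.visual_embedder import VisualEmbedder

    embedder = VisualEmbedder(model_name="vidore/colSmol-500M")
    pages = [Image.open("page.png")]  # placeholder image path

    # Multi-vector page embeddings plus tile metadata (for pooling/saliency)
    full_embs, infos = embedder.embed_images(pages, return_token_info=True)

    # Keep only visual-patch tokens, then derive the compact representations
    visual = embedder.extract_visual_embedding(full_embs[0], infos[0])
    mean_pooled = embedder.mean_pool_visual_embedding(visual, infos[0])
    global_vec = embedder.global_pool_from_mean_pool(mean_pooled)

    # Query-side multi-vector embedding for MaxSim scoring
    query_emb = embedder.embed_query("What is the budget?")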
visual_rag/indexing/__init__.py
@@ -0,0 +1,21 @@
+"""
+Indexing module - PDF processing, embedding storage, and CDN uploads.

+Components:
+- PDFProcessor: Convert PDFs to images and extract text
+- QdrantIndexer: Upload embeddings to Qdrant vector database
+- CloudinaryUploader: Upload images to Cloudinary CDN
+- ProcessingPipeline: End-to-end PDF → Qdrant pipeline
+"""
+
+from visual_rag.indexing.cloudinary_uploader import CloudinaryUploader
+from visual_rag.indexing.pdf_processor import PDFProcessor
+from visual_rag.indexing.pipeline import ProcessingPipeline
+from visual_rag.indexing.qdrant_indexer import QdrantIndexer
+
+__all__ = [
+    "PDFProcessor",
+    "QdrantIndexer",
+    "CloudinaryUploader",
+    "ProcessingPipeline",
+]
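A short sketch of the import surface this __init__ exposes (constructor arguments are not shown in this hunk, so none are assumed here):

    from visual_rag.indexing import (
        CloudinaryUploader,
        PDFProcessor,
        ProcessingPipeline,
        QdrantIndexer,
    )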