visual-rag-toolkit 0.1.1__py3-none-any.whl

Files changed (59)
  1. benchmarks/README.md +101 -0
  2. benchmarks/__init__.py +11 -0
  3. benchmarks/analyze_results.py +187 -0
  4. benchmarks/benchmark_datasets.txt +105 -0
  5. benchmarks/prepare_submission.py +205 -0
  6. benchmarks/quick_test.py +566 -0
  7. benchmarks/run_vidore.py +513 -0
  8. benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
  9. benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
  10. benchmarks/vidore_tatdqa_test/__init__.py +6 -0
  11. benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
  12. benchmarks/vidore_tatdqa_test/metrics.py +44 -0
  13. benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
  14. benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
  15. demo/__init__.py +10 -0
  16. demo/app.py +45 -0
  17. demo/commands.py +334 -0
  18. demo/config.py +34 -0
  19. demo/download_models.py +75 -0
  20. demo/evaluation.py +602 -0
  21. demo/example_metadata_mapping_sigir.json +37 -0
  22. demo/indexing.py +286 -0
  23. demo/qdrant_utils.py +211 -0
  24. demo/results.py +35 -0
  25. demo/test_qdrant_connection.py +119 -0
  26. demo/ui/__init__.py +15 -0
  27. demo/ui/benchmark.py +355 -0
  28. demo/ui/header.py +30 -0
  29. demo/ui/playground.py +339 -0
  30. demo/ui/sidebar.py +162 -0
  31. demo/ui/upload.py +487 -0
  32. visual_rag/__init__.py +98 -0
  33. visual_rag/cli/__init__.py +1 -0
  34. visual_rag/cli/main.py +629 -0
  35. visual_rag/config.py +230 -0
  36. visual_rag/demo_runner.py +90 -0
  37. visual_rag/embedding/__init__.py +26 -0
  38. visual_rag/embedding/pooling.py +343 -0
  39. visual_rag/embedding/visual_embedder.py +622 -0
  40. visual_rag/indexing/__init__.py +21 -0
  41. visual_rag/indexing/cloudinary_uploader.py +274 -0
  42. visual_rag/indexing/pdf_processor.py +324 -0
  43. visual_rag/indexing/pipeline.py +628 -0
  44. visual_rag/indexing/qdrant_indexer.py +478 -0
  45. visual_rag/preprocessing/__init__.py +3 -0
  46. visual_rag/preprocessing/crop_empty.py +120 -0
  47. visual_rag/qdrant_admin.py +222 -0
  48. visual_rag/retrieval/__init__.py +19 -0
  49. visual_rag/retrieval/multi_vector.py +222 -0
  50. visual_rag/retrieval/single_stage.py +126 -0
  51. visual_rag/retrieval/three_stage.py +173 -0
  52. visual_rag/retrieval/two_stage.py +471 -0
  53. visual_rag/visualization/__init__.py +19 -0
  54. visual_rag/visualization/saliency.py +335 -0
  55. visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
  56. visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
  57. visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
  58. visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
  59. visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
visual_rag/config.py ADDED
@@ -0,0 +1,230 @@
"""
Configuration utilities for Visual RAG Toolkit.

Provides:
- YAML configuration loading with caching
- Environment variable overrides
- Convenience getters for common settings
"""

import copy
import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)

# Global config cache (raw YAML only; env overrides applied on demand)
_raw_config_cache: Optional[Dict[str, Any]] = None
_raw_config_cache_path: Optional[str] = None


def _env_qdrant_url() -> Optional[str]:
    return os.getenv("SIGIR_QDRANT_URL") or os.getenv("DEST_QDRANT_URL") or os.getenv("QDRANT_URL")


def _env_qdrant_api_key() -> Optional[str]:
    return (
        os.getenv("SIGIR_QDRANT_KEY")
        or os.getenv("SIGIR_QDRANT_API_KEY")
        or os.getenv("DEST_QDRANT_API_KEY")
        or os.getenv("QDRANT_API_KEY")
    )


def load_config(
    config_path: Optional[str] = None,
    force_reload: bool = False,
    apply_env_overrides: bool = True,
) -> Dict[str, Any]:
    """
    Load configuration from YAML file.

    Uses caching to avoid repeated file I/O.
    Environment variables can override config values.

    Args:
        config_path: Path to config file (auto-detected if None)
        force_reload: Bypass cache and reload from file

    Returns:
        Configuration dictionary
    """
    global _raw_config_cache, _raw_config_cache_path

    # Determine the effective config path (used for caching)
    effective_path: Optional[str] = None

    # Find config file
    if config_path is None:
        config_path = os.getenv("VISUALRAG_CONFIG")

    if config_path is None:
        # Check common locations
        search_paths = [
            Path.cwd() / "config.yaml",
            Path.cwd() / "visual_rag.yaml",
            Path.home() / ".visual_rag" / "config.yaml",
        ]

        for path in search_paths:
            if path.exists():
                config_path = str(path)
                break
    effective_path = str(config_path) if config_path else None

    # Return cached raw config if available.
    # - If caller doesn't specify a path (effective_path is None), use whatever was
    #   loaded most recently (common pattern in apps).
    # - If a path is specified, only reuse cache when it matches.
    if (
        _raw_config_cache is not None
        and not force_reload
        and (effective_path is None or _raw_config_cache_path == effective_path)
    ):
        cfg = copy.deepcopy(_raw_config_cache)
        return _apply_env_overrides(cfg) if apply_env_overrides else cfg

    # Load YAML if file exists
    config = {}
    if config_path and Path(config_path).exists():
        try:
            import yaml

            with open(config_path, "r") as f:
                config = yaml.safe_load(f) or {}

            logger.info(f"Loaded config from: {config_path}")
        except ImportError:
            logger.warning("PyYAML not installed, using environment variables only")
        except Exception as e:
            logger.warning(f"Could not load config file: {e}")

    # Cache RAW config (no env overrides)
    _raw_config_cache = copy.deepcopy(config)
    _raw_config_cache_path = effective_path

    # Return resolved or raw depending on caller preference
    cfg = copy.deepcopy(config)
    return _apply_env_overrides(cfg) if apply_env_overrides else cfg


def _apply_env_overrides(config: Dict[str, Any]) -> Dict[str, Any]:
    """Apply environment variable overrides."""

    env_mappings = {
        # Qdrant
        "QDRANT_URL": ["qdrant", "url"],
        "QDRANT_API_KEY": ["qdrant", "api_key"],
        "QDRANT_COLLECTION": ["qdrant", "collection"],
        # Model
        "VISUALRAG_MODEL": ["model", "name"],
        "COLPALI_MODEL_NAME": ["model", "name"],  # Alias
        "EMBEDDING_BATCH_SIZE": ["model", "batch_size"],
        # Cloudinary
        "CLOUDINARY_CLOUD_NAME": ["cloudinary", "cloud_name"],
        "CLOUDINARY_API_KEY": ["cloudinary", "api_key"],
        "CLOUDINARY_API_SECRET": ["cloudinary", "api_secret"],
        # Processing
        "PDF_DPI": ["processing", "dpi"],
        "JPEG_QUALITY": ["processing", "jpeg_quality"],
        # Search
        "SEARCH_STRATEGY": ["search", "strategy"],
        "PREFETCH_K": ["search", "prefetch_k"],
        # Special token handling
        "VISUALRAG_INCLUDE_SPECIAL_TOKENS": ["embedding", "include_special_tokens"],
    }

    for env_var, path in env_mappings.items():
        value = os.getenv(env_var)
        if value is not None:
            # Navigate to the right place in config
            current = config
            for key in path[:-1]:
                if key not in current:
                    current[key] = {}
                current = current[key]

            # Convert value to appropriate type
            final_key = path[-1]
            if final_key in current:
                existing_type = type(current[final_key])
                # Use `is` for type comparisons (Ruff E721).
                if existing_type is bool:
                    value = value.lower() in ("true", "1", "yes", "on")
                elif existing_type is int:
                    value = int(value)
                elif existing_type is float:
                    value = float(value)

            current[final_key] = value
            logger.debug(f"Config override: {'.'.join(path)} = {value}")

    return config


def get(key: str, default: Any = None) -> Any:
    """
    Get a configuration value by dot-notation path.

    Examples:
        >>> get("qdrant.url")
        >>> get("model.name", "vidore/colSmol-500M")
        >>> get("search.strategy", "multi_vector")
    """
    config = load_config(apply_env_overrides=True)

    keys = key.split(".")
    current = config

    for k in keys:
        if isinstance(current, dict) and k in current:
            current = current[k]
        else:
            return default

    return current


def get_section(section: str, *, apply_env_overrides: bool = True) -> Dict[str, Any]:
    """Get an entire configuration section."""
    config = load_config(apply_env_overrides=apply_env_overrides)
    return config.get(section, {})


# Convenience getters
def get_qdrant_config() -> Dict[str, Any]:
    """Get Qdrant configuration with defaults."""
    return {
        "url": get("qdrant.url", _env_qdrant_url()),
        "api_key": get("qdrant.api_key", _env_qdrant_api_key()),
        "collection": get("qdrant.collection", "visual_documents"),
    }


def get_model_config() -> Dict[str, Any]:
    """Get model configuration with defaults."""
    return {
        "name": get("model.name", "vidore/colSmol-500M"),
        "batch_size": get("model.batch_size", 4),
        "device": get("model.device", "auto"),
    }


def get_processing_config() -> Dict[str, Any]:
    """Get processing configuration with defaults."""
    return {
        "dpi": get("processing.dpi", 140),
        "jpeg_quality": get("processing.jpeg_quality", 95),
        "page_batch_size": get("processing.page_batch_size", 50),
    }


def get_search_config() -> Dict[str, Any]:
    """Get search configuration with defaults."""
    return {
        "strategy": get("search.strategy", "multi_vector"),
        "prefetch_k": get("search.prefetch_k", 200),
        "top_k": get("search.top_k", 10),
    }
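For reference, a minimal usage sketch of the configuration API defined above. The function names and the QDRANT_URL / PDF_DPI override variables come from the module itself; the example YAML contents and the local URL are illustrative only.

# Illustrative sketch (not part of the wheel): exercising visual_rag.config.
import os

from visual_rag import config

# A ./config.yaml in the working directory would be picked up automatically, e.g.:
#   qdrant:
#     url: "http://localhost:6333"
#     collection: "visual_documents"
#   processing:
#     dpi: 140

os.environ["QDRANT_URL"] = "http://localhost:6333"  # overrides qdrant.url
os.environ["PDF_DPI"] = "200"  # cast to int only if the loaded YAML value was an int

cfg = config.load_config(force_reload=True)
print(config.get("qdrant.url"))                        # -> http://localhost:6333
print(config.get("model.name", "vidore/colSmol-500M"))
print(config.get_qdrant_config())                      # url / api_key / collection with defaults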
visual_rag/demo_runner.py ADDED
@@ -0,0 +1,90 @@
"""
Launch the Streamlit demo from an installed package.

Why:
- After `pip install visual-rag-toolkit`, the repo layout isn't present.
- We package the `demo/` module and expose `visual_rag.demo()` + `visual-rag-demo`.
"""

from __future__ import annotations

import argparse
import importlib
import os
import subprocess
import sys
from pathlib import Path
from typing import Optional


def demo(
    *,
    host: str = "0.0.0.0",
    port: int = 7860,
    headless: bool = True,
    open_browser: bool = False,
    extra_args: Optional[list[str]] = None,
) -> int:
    """
    Launch the Streamlit demo UI.

    Requirements:
      - `visual-rag-toolkit[ui,qdrant,embedding,pdf]` (or `visual-rag-toolkit[all]`)

    Returns:
        Streamlit process exit code.
    """
    try:
        import streamlit  # noqa: F401
    except Exception as e:  # pragma: no cover
        raise RuntimeError(
            "Streamlit is not installed. Install with:\n"
            '  pip install "visual-rag-toolkit[ui,qdrant,embedding,pdf]"'
        ) from e

    # Resolve the installed demo entrypoint path.
    mod = importlib.import_module("demo.app")
    app_path = Path(getattr(mod, "__file__", "")).resolve()
    if not app_path.exists():  # pragma: no cover
        raise RuntimeError("Could not locate installed demo app (demo.app).")

    # Build a stable Streamlit invocation.
    cmd = [sys.executable, "-m", "streamlit", "run", str(app_path)]
    cmd += ["--server.address", str(host)]
    cmd += ["--server.port", str(int(port))]
    cmd += ["--server.headless", "true" if headless else "false"]
    cmd += ["--browser.gatherUsageStats", "false"]
    cmd += ["--server.runOnSave", "false"]
    cmd += ["--browser.serverAddress", str(host)]
    if not open_browser:
        cmd += ["--browser.serverPort", str(int(port))]
        cmd += ["--browser.open", "false"]

    if extra_args:
        cmd += list(extra_args)

    env = os.environ.copy()
    # Make sure the demo doesn't spam internal Streamlit warnings in logs.
    env.setdefault("STREAMLIT_BROWSER_GATHER_USAGE_STATS", "false")

    return subprocess.call(cmd, env=env)


def main() -> None:
    p = argparse.ArgumentParser(description="Launch the Visual RAG Toolkit Streamlit demo.")
    p.add_argument("--host", default="0.0.0.0")
    p.add_argument("--port", type=int, default=7860)
    p.add_argument(
        "--no-headless", action="store_true", help="Run with a browser window (not headless)."
    )
    p.add_argument("--open", action="store_true", help="Open browser automatically.")
    args, unknown = p.parse_known_args()

    rc = demo(
        host=args.host,
        port=args.port,
        headless=(not args.no_headless),
        open_browser=bool(args.open),
        extra_args=unknown,
    )
    raise SystemExit(rc)
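A short sketch of how this launcher is meant to be invoked, based on the module docstring above; the `visual-rag-demo` console command and `visual_rag.demo()` export are the ones named there, while the host/port values are just examples.

# Illustrative sketch (not part of the wheel): launching the packaged demo.
# Equivalent shell invocation via the console entry point named in the docstring:
#   visual-rag-demo --port 7860 --open
from visual_rag import demo_runner

# Runs `python -m streamlit run .../demo/app.py` and blocks until Streamlit exits.
exit_code = demo_runner.demo(host="127.0.0.1", port=7860, headless=False, open_browser=True)
print(f"Streamlit exited with code {exit_code}")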
visual_rag/embedding/__init__.py ADDED
@@ -0,0 +1,26 @@
"""
Embedding module - Visual and text embedding generation.

Provides:
- VisualEmbedder: Backend-agnostic visual embedder (ColPali, etc.)
- Pooling utilities: tile-level, global, MaxSim scoring
"""

from visual_rag.embedding.pooling import (
    compute_maxsim_batch,
    compute_maxsim_score,
    global_mean_pooling,
    tile_level_mean_pooling,
)
from visual_rag.embedding.visual_embedder import ColPaliEmbedder, VisualEmbedder

__all__ = [
    # Main embedder
    "VisualEmbedder",
    "ColPaliEmbedder",  # Backward compatibility alias
    # Pooling functions
    "tile_level_mean_pooling",
    "global_mean_pooling",
    "compute_maxsim_score",
    "compute_maxsim_batch",
]
visual_rag/embedding/pooling.py ADDED
@@ -0,0 +1,343 @@
"""
Pooling strategies for multi-vector embeddings.

Provides:
- Tile-level mean pooling: Preserves spatial structure (num_tiles × dim)
- Global mean pooling: Single vector (1 × dim)
- MaxSim scoring for ColBERT-style late interaction
"""

import logging
from typing import Optional, Union

import numpy as np
import torch

logger = logging.getLogger(__name__)


def _infer_output_dtype(
    embedding: Union[torch.Tensor, np.ndarray],
    output_dtype: Optional[np.dtype] = None,
) -> np.dtype:
    """Infer output dtype: use provided, else match input (fp16→fp16, bf16→fp32, fp32→fp32)."""
    if output_dtype is not None:
        return output_dtype
    if isinstance(embedding, torch.Tensor):
        if embedding.dtype == torch.float16:
            return np.float16
        return np.float32
    if isinstance(embedding, np.ndarray) and embedding.dtype == np.float16:
        return np.float16
    return np.float32


def tile_level_mean_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    num_tiles: int,
    patches_per_tile: int = 64,
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """
    Compute tile-level mean pooling for multi-vector embeddings.

    Instead of collapsing to 1×dim (global pooling), this preserves spatial
    structure by computing mean per tile → num_tiles × dim.

    This is our NOVEL contribution for scalable visual retrieval:
    - Faster than full MaxSim (fewer vectors to compare)
    - More accurate than global pooling (preserves spatial info)
    - Ideal for two-stage retrieval (prefetch with pooled, rerank with full)

    Args:
        embedding: Visual token embeddings [num_visual_tokens, dim]
        num_tiles: Number of tiles (including global tile)
        patches_per_tile: Patches per tile (64 for ColSmol)
        output_dtype: Output dtype (default: infer from input, fp16→fp16, bf16→fp32)

    Returns:
        Tile-level pooled embeddings [num_tiles, dim]

    Example:
        >>> # Image with 4×3 tiles + 1 global = 13 tiles
        >>> # Each tile has 64 patches → 832 visual tokens
        >>> pooled = tile_level_mean_pooling(embedding, num_tiles=13)
        >>> print(pooled.shape)  # (13, 128)
    """
    out_dtype = _infer_output_dtype(embedding, output_dtype)
    if isinstance(embedding, torch.Tensor):
        if embedding.dtype == torch.bfloat16:
            emb_np = embedding.cpu().float().numpy()
        else:
            emb_np = embedding.cpu().numpy().astype(np.float32)
    else:
        emb_np = np.array(embedding, dtype=np.float32)

    num_visual_tokens = emb_np.shape[0]
    expected_tokens = num_tiles * patches_per_tile

    if num_visual_tokens != expected_tokens:
        logger.debug(f"Token count mismatch: {num_visual_tokens} vs expected {expected_tokens}")
        actual_tiles = num_visual_tokens // patches_per_tile
        if actual_tiles * patches_per_tile != num_visual_tokens:
            actual_tiles += 1
        num_tiles = actual_tiles

    tile_embeddings = []
    for tile_idx in range(num_tiles):
        start_idx = tile_idx * patches_per_tile
        end_idx = min(start_idx + patches_per_tile, num_visual_tokens)

        if start_idx >= num_visual_tokens:
            break

        tile_patches = emb_np[start_idx:end_idx]
        tile_mean = tile_patches.mean(axis=0)
        tile_embeddings.append(tile_mean)

    return np.array(tile_embeddings, dtype=out_dtype)


def colpali_row_mean_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    grid_size: int = 32,
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    out_dtype = _infer_output_dtype(embedding, output_dtype)
    if isinstance(embedding, torch.Tensor):
        if embedding.dtype == torch.bfloat16:
            emb_np = embedding.cpu().float().numpy()
        else:
            emb_np = embedding.cpu().numpy().astype(np.float32)
    else:
        emb_np = np.array(embedding, dtype=np.float32)

    num_tokens, dim = emb_np.shape
    expected = int(grid_size) * int(grid_size)
    if num_tokens != expected:
        raise ValueError(
            f"Expected {expected} visual tokens for grid_size={grid_size}, got {num_tokens}"
        )

    grid = emb_np.reshape(int(grid_size), int(grid_size), int(dim))
    pooled = grid.mean(axis=1)
    return pooled.astype(out_dtype)


def colsmol_experimental_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    num_tiles: int,
    patches_per_tile: int = 64,
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    out_dtype = _infer_output_dtype(embedding, output_dtype)
    if isinstance(embedding, torch.Tensor):
        if embedding.dtype == torch.bfloat16:
            emb_np = embedding.cpu().float().numpy()
        else:
            emb_np = embedding.cpu().numpy().astype(np.float32)
    else:
        emb_np = np.array(embedding, dtype=np.float32)

    num_visual_tokens, dim = emb_np.shape
    if num_tiles <= 0:
        raise ValueError("num_tiles must be > 0")
    if patches_per_tile <= 0:
        raise ValueError("patches_per_tile must be > 0")

    last_tile_start = (int(num_tiles) - 1) * int(patches_per_tile)
    if last_tile_start >= num_visual_tokens:
        actual_tiles = int(num_visual_tokens) // int(patches_per_tile)
        if actual_tiles * int(patches_per_tile) != int(num_visual_tokens):
            actual_tiles += 1
        if actual_tiles <= 0:
            raise ValueError(
                f"Not enough tokens for num_tiles={num_tiles}, patches_per_tile={patches_per_tile}: got {num_visual_tokens}"
            )
        num_tiles = actual_tiles
        last_tile_start = (int(num_tiles) - 1) * int(patches_per_tile)

    prefix = emb_np[:last_tile_start]
    last_tile = emb_np[
        last_tile_start : min(last_tile_start + int(patches_per_tile), num_visual_tokens)
    ]

    if prefix.size:
        prefix_tiles = prefix.reshape(-1, int(patches_per_tile), int(dim))
        prefix_means = prefix_tiles.mean(axis=1)
    else:
        prefix_means = np.zeros((0, int(dim)), dtype=out_dtype)

    return np.concatenate([prefix_means.astype(out_dtype), last_tile.astype(out_dtype)], axis=0)


def colpali_experimental_pooling_from_rows(
    row_vectors: Union[torch.Tensor, np.ndarray],
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """
    Experimental "convolution-style" pooling with window size 3.

    For N input rows, produces N + 2 output vectors:
    - Position 0: row[0] alone (1 row)
    - Position 1: mean(rows[0:2]) (2 rows)
    - Position 2: mean(rows[0:3]) (3 rows)
    - Positions 3 to N-1: sliding window of 3 (rows[i-2:i+1])
    - Position N: mean(rows[N-2:N]) (last 2 rows)
    - Position N+1: row[N-1] alone (last row)

    For N=32 rows: produces 34 vectors.
    """
    out_dtype = _infer_output_dtype(row_vectors, output_dtype)
    if isinstance(row_vectors, torch.Tensor):
        if row_vectors.dtype == torch.bfloat16:
            rows = row_vectors.cpu().float().numpy()
        else:
            rows = row_vectors.cpu().numpy().astype(np.float32)
    else:
        rows = np.array(row_vectors, dtype=np.float32)

    n, dim = rows.shape
    if n < 1:
        raise ValueError("row_vectors must be non-empty")
    if n == 1:
        return rows.astype(out_dtype)
    if n == 2:
        return np.stack([rows[0], rows[:2].mean(axis=0), rows[1]], axis=0).astype(out_dtype)
    if n == 3:
        return np.stack(
            [
                rows[0],
                rows[:2].mean(axis=0),
                rows[:3].mean(axis=0),
                rows[1:3].mean(axis=0),
                rows[2],
            ],
            axis=0,
        ).astype(out_dtype)

    out = np.zeros((n + 2, dim), dtype=np.float32)
    out[0] = rows[0]
    out[1] = rows[:2].mean(axis=0)
    out[2] = rows[:3].mean(axis=0)
    for i in range(3, n):
        out[i] = rows[i - 2 : i + 1].mean(axis=0)
    out[n] = rows[n - 2 : n].mean(axis=0)
    out[n + 1] = rows[n - 1]
    return out.astype(out_dtype)


def global_mean_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """
    Compute global mean pooling → single vector.

    This is the simplest pooling but loses all spatial information.
    Use for fastest retrieval when accuracy can be sacrificed.

    Args:
        embedding: Multi-vector embeddings [num_tokens, dim]
        output_dtype: Output dtype (default: infer from input, fp16→fp16, bf16→fp32)

    Returns:
        Pooled vector [dim]
    """
    out_dtype = _infer_output_dtype(embedding, output_dtype)
    if isinstance(embedding, torch.Tensor):
        if embedding.dtype == torch.bfloat16:
            emb_np = embedding.cpu().float().numpy()
        else:
            emb_np = embedding.cpu().numpy()
    else:
        emb_np = np.array(embedding)

    return emb_np.mean(axis=0).astype(out_dtype)


def compute_maxsim_score(
    query_embedding: np.ndarray,
    doc_embedding: np.ndarray,
    normalize: bool = True,
) -> float:
    """
    Compute ColBERT-style MaxSim late interaction score.

    For each query token, finds max similarity with any document token,
    then sums across query tokens.

    This is the standard scoring for ColBERT/ColPali:
        score = Σ_q max_d sim(q, d)

    Args:
        query_embedding: Query embeddings [num_query_tokens, dim]
        doc_embedding: Document embeddings [num_doc_tokens, dim]
        normalize: L2 normalize embeddings before scoring (recommended)

    Returns:
        MaxSim score (higher is better)

    Example:
        >>> query = embedder.embed_query("budget allocation")
        >>> doc = embeddings[0]  # From embed_images
        >>> score = compute_maxsim_score(query, doc)
    """
    if normalize:
        # L2 normalize
        query_norm = query_embedding / (
            np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-8
        )
        doc_norm = doc_embedding / (np.linalg.norm(doc_embedding, axis=1, keepdims=True) + 1e-8)
    else:
        query_norm = query_embedding
        doc_norm = doc_embedding

    # Compute similarity matrix: [num_query, num_doc]
    similarity_matrix = np.dot(query_norm, doc_norm.T)

    # MaxSim: For each query token, take max similarity with any doc token
    max_similarities = similarity_matrix.max(axis=1)

    # Sum across query tokens
    score = float(max_similarities.sum())

    return score


def compute_maxsim_batch(
    query_embedding: np.ndarray,
    doc_embeddings: list,
    normalize: bool = True,
) -> list:
    """
    Compute MaxSim scores for multiple documents efficiently.

    Args:
        query_embedding: Query embeddings [num_query_tokens, dim]
        doc_embeddings: List of document embeddings
        normalize: L2 normalize embeddings

    Returns:
        List of MaxSim scores
    """
    # Pre-normalize query once
    if normalize:
        query_norm = query_embedding / (
            np.linalg.norm(query_embedding, axis=1, keepdims=True) + 1e-8
        )
    else:
        query_norm = query_embedding

    scores = []
    for doc_emb in doc_embeddings:
        if normalize:
            doc_norm = doc_emb / (np.linalg.norm(doc_emb, axis=1, keepdims=True) + 1e-8)
        else:
            doc_norm = doc_emb

        sim_matrix = np.dot(query_norm, doc_norm.T)
        max_sims = sim_matrix.max(axis=1)
        scores.append(float(max_sims.sum()))

    return scores
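To make the pooling/MaxSim interplay concrete, here is a small self-contained sketch using random arrays in place of real ColSmol outputs; the 13-tile × 64-patch × 128-dim shapes mirror the example in the tile_level_mean_pooling docstring, and everything else is illustrative.

# Illustrative sketch (not part of the wheel): tile-level prefetch vs. exact MaxSim rerank
# on random data shaped like the ColSmol example above (13 tiles x 64 patches, dim 128).
import numpy as np

from visual_rag.embedding.pooling import (
    compute_maxsim_batch,
    compute_maxsim_score,
    global_mean_pooling,
    tile_level_mean_pooling,
)

rng = np.random.default_rng(0)
doc = rng.standard_normal((13 * 64, 128)).astype(np.float32)   # full multi-vector page embedding
query = rng.standard_normal((20, 128)).astype(np.float32)      # ~20 query token embeddings

pooled = tile_level_mean_pooling(doc, num_tiles=13)   # (13, 128): cheap prefetch representation
global_vec = global_mean_pooling(doc)                 # (128,): cheapest, no spatial info

# Stage 1 (prefetch): MaxSim against the 13 pooled tile vectors.
prefetch_score = compute_maxsim_score(query, pooled)
# Stage 2 (rerank): exact MaxSim against all 832 token vectors.
exact_score = compute_maxsim_score(query, doc)
print(pooled.shape, global_vec.shape, prefetch_score, exact_score)

# Batch form over several candidate pages:
docs = [rng.standard_normal((13 * 64, 128)).astype(np.float32) for _ in range(3)]
print(compute_maxsim_batch(query, docs))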