visual-rag-toolkit 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/README.md +101 -0
- benchmarks/__init__.py +11 -0
- benchmarks/analyze_results.py +187 -0
- benchmarks/benchmark_datasets.txt +105 -0
- benchmarks/prepare_submission.py +205 -0
- benchmarks/quick_test.py +566 -0
- benchmarks/run_vidore.py +513 -0
- benchmarks/vidore_beir_qdrant/run_qdrant_beir.py +1365 -0
- benchmarks/vidore_tatdqa_test/COMMANDS.md +83 -0
- benchmarks/vidore_tatdqa_test/__init__.py +6 -0
- benchmarks/vidore_tatdqa_test/dataset_loader.py +363 -0
- benchmarks/vidore_tatdqa_test/metrics.py +44 -0
- benchmarks/vidore_tatdqa_test/run_qdrant.py +799 -0
- benchmarks/vidore_tatdqa_test/sweep_eval.py +372 -0
- demo/__init__.py +10 -0
- demo/app.py +45 -0
- demo/commands.py +334 -0
- demo/config.py +34 -0
- demo/download_models.py +75 -0
- demo/evaluation.py +602 -0
- demo/example_metadata_mapping_sigir.json +37 -0
- demo/indexing.py +286 -0
- demo/qdrant_utils.py +211 -0
- demo/results.py +35 -0
- demo/test_qdrant_connection.py +119 -0
- demo/ui/__init__.py +15 -0
- demo/ui/benchmark.py +355 -0
- demo/ui/header.py +30 -0
- demo/ui/playground.py +339 -0
- demo/ui/sidebar.py +162 -0
- demo/ui/upload.py +487 -0
- visual_rag/__init__.py +98 -0
- visual_rag/cli/__init__.py +1 -0
- visual_rag/cli/main.py +629 -0
- visual_rag/config.py +230 -0
- visual_rag/demo_runner.py +90 -0
- visual_rag/embedding/__init__.py +26 -0
- visual_rag/embedding/pooling.py +343 -0
- visual_rag/embedding/visual_embedder.py +622 -0
- visual_rag/indexing/__init__.py +21 -0
- visual_rag/indexing/cloudinary_uploader.py +274 -0
- visual_rag/indexing/pdf_processor.py +324 -0
- visual_rag/indexing/pipeline.py +628 -0
- visual_rag/indexing/qdrant_indexer.py +478 -0
- visual_rag/preprocessing/__init__.py +3 -0
- visual_rag/preprocessing/crop_empty.py +120 -0
- visual_rag/qdrant_admin.py +222 -0
- visual_rag/retrieval/__init__.py +19 -0
- visual_rag/retrieval/multi_vector.py +222 -0
- visual_rag/retrieval/single_stage.py +126 -0
- visual_rag/retrieval/three_stage.py +173 -0
- visual_rag/retrieval/two_stage.py +471 -0
- visual_rag/visualization/__init__.py +19 -0
- visual_rag/visualization/saliency.py +335 -0
- visual_rag_toolkit-0.1.1.dist-info/METADATA +305 -0
- visual_rag_toolkit-0.1.1.dist-info/RECORD +59 -0
- visual_rag_toolkit-0.1.1.dist-info/WHEEL +4 -0
- visual_rag_toolkit-0.1.1.dist-info/entry_points.txt +3 -0
- visual_rag_toolkit-0.1.1.dist-info/licenses/LICENSE +22 -0
visual_rag/config.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration utilities for Visual RAG Toolkit.
|
|
3
|
+
|
|
4
|
+
Provides:
|
|
5
|
+
- YAML configuration loading with caching
|
|
6
|
+
- Environment variable overrides
|
|
7
|
+
- Convenience getters for common settings
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import copy
|
|
11
|
+
import logging
|
|
12
|
+
import os
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Any, Dict, Optional
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)

# Global config cache (raw YAML only; env overrides applied on demand).
# Holds the most recently loaded, un-overridden YAML mapping so repeated
# load_config() calls avoid file I/O; env overrides are layered per call.
_raw_config_cache: Optional[Dict[str, Any]] = None
# Path the cached config was loaded from (None when no config file was found).
_raw_config_cache_path: Optional[str] = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _env_qdrant_url() -> Optional[str]:
    """Resolve the Qdrant URL from the environment, highest-priority name first.

    Mirrors an ``a or b or c`` chain: the first truthy value wins, and when
    every candidate is falsy the value of the last lookup is returned
    (which may be an empty string if ``QDRANT_URL`` is set but empty).
    """
    url = os.getenv("SIGIR_QDRANT_URL")
    if not url:
        url = os.getenv("DEST_QDRANT_URL")
    if not url:
        url = os.getenv("QDRANT_URL")
    return url
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _env_qdrant_api_key() -> Optional[str]:
    """Resolve the Qdrant API key from the environment, highest-priority name first.

    Equivalent to an ``a or b or c or d`` chain: first truthy value wins;
    when all are falsy, the result of the final lookup is returned.
    """
    key: Optional[str] = None
    for name in (
        "SIGIR_QDRANT_KEY",
        "SIGIR_QDRANT_API_KEY",
        "DEST_QDRANT_API_KEY",
        "QDRANT_API_KEY",
    ):
        key = os.getenv(name)
        if key:
            return key
    return key
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def load_config(
    config_path: Optional[str] = None,
    force_reload: bool = False,
    apply_env_overrides: bool = True,
) -> Dict[str, Any]:
    """
    Load configuration from YAML file.

    Uses a module-level cache of the RAW YAML (no env overrides baked in)
    to avoid repeated file I/O; environment overrides are applied on each
    call so changes to the environment are always reflected.

    Args:
        config_path: Path to config file (auto-detected if None via
            ``VISUALRAG_CONFIG`` then common locations)
        force_reload: Bypass cache and reload from file
        apply_env_overrides: When True (default), layer environment
            variable overrides onto a deep copy of the raw config

    Returns:
        Configuration dictionary (a fresh deep copy; safe for callers to mutate)
    """
    global _raw_config_cache, _raw_config_cache_path

    # Determine the effective config path (used for caching)
    effective_path: Optional[str] = None

    # Find config file: explicit arg > VISUALRAG_CONFIG env > common locations
    if config_path is None:
        config_path = os.getenv("VISUALRAG_CONFIG")

    if config_path is None:
        # Check common locations, first hit wins
        search_paths = [
            Path.cwd() / "config.yaml",
            Path.cwd() / "visual_rag.yaml",
            Path.home() / ".visual_rag" / "config.yaml",
        ]

        for path in search_paths:
            if path.exists():
                config_path = str(path)
                break
    effective_path = str(config_path) if config_path else None

    # Return cached raw config if available.
    # - If caller doesn't specify a path (effective_path is None), use whatever was
    #   loaded most recently (common pattern in apps).
    # - If a path is specified, only reuse cache when it matches.
    if (
        _raw_config_cache is not None
        and not force_reload
        and (effective_path is None or _raw_config_cache_path == effective_path)
    ):
        # Deep-copy so env overrides / caller mutations never corrupt the cache
        cfg = copy.deepcopy(_raw_config_cache)
        return _apply_env_overrides(cfg) if apply_env_overrides else cfg

    # Load YAML if file exists; any failure degrades to an empty config
    config = {}
    if config_path and Path(config_path).exists():
        try:
            # Imported lazily so PyYAML stays an optional dependency
            import yaml

            with open(config_path, "r") as f:
                # safe_load returns None for an empty file; normalize to {}
                config = yaml.safe_load(f) or {}

            logger.info(f"Loaded config from: {config_path}")
        except ImportError:
            logger.warning("PyYAML not installed, using environment variables only")
        except Exception as e:
            # Best-effort: a malformed/unreadable file is logged, not fatal
            logger.warning(f"Could not load config file: {e}")

    # Cache RAW config (no env overrides), keyed by the resolved path
    _raw_config_cache = copy.deepcopy(config)
    _raw_config_cache_path = effective_path

    # Return resolved or raw depending on caller preference
    cfg = copy.deepcopy(config)
    return _apply_env_overrides(cfg) if apply_env_overrides else cfg
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _apply_env_overrides(config: Dict[str, Any]) -> Dict[str, Any]:
    """Overlay environment variables onto *config* in place and return it.

    An override is only type-coerced (bool/int/float) when the target key
    already exists in the config; otherwise the raw string is stored.
    """

    env_mappings = {
        # Qdrant
        "QDRANT_URL": ["qdrant", "url"],
        "QDRANT_API_KEY": ["qdrant", "api_key"],
        "QDRANT_COLLECTION": ["qdrant", "collection"],
        # Model
        "VISUALRAG_MODEL": ["model", "name"],
        "COLPALI_MODEL_NAME": ["model", "name"],  # Alias
        "EMBEDDING_BATCH_SIZE": ["model", "batch_size"],
        # Cloudinary
        "CLOUDINARY_CLOUD_NAME": ["cloudinary", "cloud_name"],
        "CLOUDINARY_API_KEY": ["cloudinary", "api_key"],
        "CLOUDINARY_API_SECRET": ["cloudinary", "api_secret"],
        # Processing
        "PDF_DPI": ["processing", "dpi"],
        "JPEG_QUALITY": ["processing", "jpeg_quality"],
        # Search
        "SEARCH_STRATEGY": ["search", "strategy"],
        "PREFETCH_K": ["search", "prefetch_k"],
        # Special token handling
        "VISUALRAG_INCLUDE_SPECIAL_TOKENS": ["embedding", "include_special_tokens"],
    }

    for env_var, target in env_mappings.items():
        raw = os.getenv(env_var)
        if raw is None:
            continue

        # Walk (creating intermediate dicts) down to the parent mapping.
        node = config
        for part in target[:-1]:
            node = node.setdefault(part, {})

        leaf = target[-1]
        new_value: Any = raw
        if leaf in node:
            # Coerce to match the type of the value already present.
            # `is` comparisons on types keep Ruff E721 happy.
            existing_type = type(node[leaf])
            if existing_type is bool:
                new_value = raw.lower() in ("true", "1", "yes", "on")
            elif existing_type is int:
                new_value = int(raw)
            elif existing_type is float:
                new_value = float(raw)

        node[leaf] = new_value
        logger.debug(f"Config override: {'.'.join(target)} = {new_value}")

    return config
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def get(key: str, default: Any = None) -> Any:
    """
    Get a configuration value by dot-notation path.

    Examples:
        >>> get("qdrant.url")
        >>> get("model.name", "vidore/colSmol-500M")
        >>> get("search.strategy", "multi_vector")
    """
    node: Any = load_config(apply_env_overrides=True)

    for part in key.split("."):
        # Stop as soon as the path leaves dict territory or a key is missing.
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]

    return node
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def get_section(section: str, *, apply_env_overrides: bool = True) -> Dict[str, Any]:
    """Return one top-level configuration section (empty dict when absent)."""
    resolved = load_config(apply_env_overrides=apply_env_overrides)
    return resolved.get(section, {})
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
# Convenience getters
|
|
197
|
+
def get_qdrant_config() -> Dict[str, Any]:
    """Qdrant connection settings; env-derived values act as fallbacks."""
    return dict(
        url=get("qdrant.url", _env_qdrant_url()),
        api_key=get("qdrant.api_key", _env_qdrant_api_key()),
        collection=get("qdrant.collection", "visual_documents"),
    )
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def get_model_config() -> Dict[str, Any]:
    """Model settings with library defaults filled in."""
    return dict(
        name=get("model.name", "vidore/colSmol-500M"),
        batch_size=get("model.batch_size", 4),
        device=get("model.device", "auto"),
    )
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def get_processing_config() -> Dict[str, Any]:
    """PDF/image processing settings with library defaults filled in."""
    return dict(
        dpi=get("processing.dpi", 140),
        jpeg_quality=get("processing.jpeg_quality", 95),
        page_batch_size=get("processing.page_batch_size", 50),
    )
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def get_search_config() -> Dict[str, Any]:
    """Retrieval settings with library defaults filled in."""
    return dict(
        strategy=get("search.strategy", "multi_vector"),
        prefetch_k=get("search.prefetch_k", 200),
        top_k=get("search.top_k", 10),
    )
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Launch the Streamlit demo from an installed package.
|
|
3
|
+
|
|
4
|
+
Why:
|
|
5
|
+
- After `pip install visual-rag-toolkit`, the repo layout isn't present.
|
|
6
|
+
- We package the `demo/` module and expose `visual_rag.demo()` + `visual-rag-demo`.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import argparse
|
|
12
|
+
import importlib
|
|
13
|
+
import os
|
|
14
|
+
import subprocess
|
|
15
|
+
import sys
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Optional
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def demo(
    *,
    host: str = "0.0.0.0",
    port: int = 7860,
    headless: bool = True,
    open_browser: bool = False,
    extra_args: Optional[list[str]] = None,
) -> int:
    """
    Launch the Streamlit demo UI as a subprocess.

    Args:
        host: Address the server binds to, also advertised to the browser.
        port: Server TCP port.
        headless: Run Streamlit without opening a local browser window.
        open_browser: When False, extra browser-suppression flags are appended.
        extra_args: Extra CLI arguments forwarded verbatim to `streamlit run`.

    Requirements:
        - `visual-rag-toolkit[ui,qdrant,embedding,pdf]` (or `visual-rag-toolkit[all]`)

    Returns:
        Streamlit process exit code.

    Raises:
        RuntimeError: If Streamlit is not importable, or the packaged demo
            app module (demo.app) has no resolvable file on disk.
    """
    # Fail fast with an actionable message if the optional UI deps are missing.
    try:
        import streamlit  # noqa: F401
    except Exception as e:  # pragma: no cover
        raise RuntimeError(
            "Streamlit is not installed. Install with:\n"
            ' pip install "visual-rag-toolkit[ui,qdrant,embedding,pdf]"'
        ) from e

    # Resolve the installed demo entrypoint path (works from site-packages,
    # where the repo layout isn't available).
    mod = importlib.import_module("demo.app")
    app_path = Path(getattr(mod, "__file__", "")).resolve()
    if not app_path.exists():  # pragma: no cover
        raise RuntimeError("Could not locate installed demo app (demo.app).")

    # Build a stable Streamlit invocation via `python -m streamlit` so the
    # same interpreter/venv is used.
    cmd = [sys.executable, "-m", "streamlit", "run", str(app_path)]
    cmd += ["--server.address", str(host)]
    cmd += ["--server.port", str(int(port))]
    cmd += ["--server.headless", "true" if headless else "false"]
    cmd += ["--browser.gatherUsageStats", "false"]
    cmd += ["--server.runOnSave", "false"]
    cmd += ["--browser.serverAddress", str(host)]
    if not open_browser:
        cmd += ["--browser.serverPort", str(int(port))]
        # NOTE(review): "--browser.open" does not appear in Streamlit's
        # documented config options — confirm the installed Streamlit version
        # accepts it; unknown flags make `streamlit run` exit with an error.
        cmd += ["--browser.open", "false"]

    if extra_args:
        cmd += list(extra_args)

    env = os.environ.copy()
    # Make sure the demo doesn't spam internal Streamlit warnings in logs.
    env.setdefault("STREAMLIT_BROWSER_GATHER_USAGE_STATS", "false")

    # Blocks until the Streamlit process exits; its return code is ours.
    return subprocess.call(cmd, env=env)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def main() -> None:
    """CLI entry point: parse arguments and launch the Streamlit demo."""
    parser = argparse.ArgumentParser(description="Launch the Visual RAG Toolkit Streamlit demo.")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument(
        "--no-headless", action="store_true", help="Run with a browser window (not headless)."
    )
    parser.add_argument("--open", action="store_true", help="Open browser automatically.")
    # Unrecognized flags are forwarded verbatim to `streamlit run`.
    known, passthrough = parser.parse_known_args()

    exit_code = demo(
        host=known.host,
        port=known.port,
        headless=not known.no_headless,
        open_browser=bool(known.open),
        extra_args=passthrough,
    )
    raise SystemExit(exit_code)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Embedding module - Visual and text embedding generation.
|
|
3
|
+
|
|
4
|
+
Provides:
|
|
5
|
+
- VisualEmbedder: Backend-agnostic visual embedder (ColPali, etc.)
|
|
6
|
+
- Pooling utilities: tile-level, global, MaxSim scoring
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from visual_rag.embedding.pooling import (
|
|
10
|
+
compute_maxsim_batch,
|
|
11
|
+
compute_maxsim_score,
|
|
12
|
+
global_mean_pooling,
|
|
13
|
+
tile_level_mean_pooling,
|
|
14
|
+
)
|
|
15
|
+
from visual_rag.embedding.visual_embedder import ColPaliEmbedder, VisualEmbedder
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
# Main embedder
|
|
19
|
+
"VisualEmbedder",
|
|
20
|
+
"ColPaliEmbedder", # Backward compatibility alias
|
|
21
|
+
# Pooling functions
|
|
22
|
+
"tile_level_mean_pooling",
|
|
23
|
+
"global_mean_pooling",
|
|
24
|
+
"compute_maxsim_score",
|
|
25
|
+
"compute_maxsim_batch",
|
|
26
|
+
]
|
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pooling strategies for multi-vector embeddings.
|
|
3
|
+
|
|
4
|
+
Provides:
|
|
5
|
+
- Tile-level mean pooling: Preserves spatial structure (num_tiles × dim)
|
|
6
|
+
- Global mean pooling: Single vector (1 × dim)
|
|
7
|
+
- MaxSim scoring for ColBERT-style late interaction
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from typing import Optional, Union
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import torch
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _infer_output_dtype(
    embedding: Union[torch.Tensor, np.ndarray],
    output_dtype: Optional[np.dtype] = None,
) -> np.dtype:
    """Pick the numpy dtype for pooled output.

    An explicit *output_dtype* always wins; otherwise fp16 inputs stay fp16
    while bf16 and everything else map to fp32.
    """
    if output_dtype is not None:
        return output_dtype
    is_half = (
        isinstance(embedding, torch.Tensor) and embedding.dtype == torch.float16
    ) or (
        isinstance(embedding, np.ndarray) and embedding.dtype == np.float16
    )
    return np.float16 if is_half else np.float32
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def tile_level_mean_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    num_tiles: int,
    patches_per_tile: int = 64,
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """
    Mean-pool visual tokens per tile, keeping spatial structure.

    Rather than collapsing everything to a single vector, each tile's
    patches are averaged, yielding a [num_tiles, dim] matrix — faster to
    compare than full MaxSim, more informative than global pooling, and
    well suited to two-stage retrieval (pooled prefetch, full rerank).

    Args:
        embedding: Visual token embeddings [num_visual_tokens, dim]
        num_tiles: Number of tiles (including global tile); silently
            corrected to ceil(tokens / patches_per_tile) on mismatch
        patches_per_tile: Patches per tile (64 for ColSmol)
        output_dtype: Output dtype (default: fp16 stays fp16, bf16 → fp32)

    Returns:
        Tile-level pooled embeddings [num_tiles, dim]

    Example:
        >>> pooled = tile_level_mean_pooling(embedding, num_tiles=13)
        >>> print(pooled.shape)  # (13, 128)
    """
    # Same dtype policy as _infer_output_dtype, inlined: explicit wins,
    # fp16 stays fp16, everything else (incl. bf16) becomes fp32.
    if output_dtype is not None:
        out_dtype = output_dtype
    elif isinstance(embedding, torch.Tensor):
        out_dtype = np.float16 if embedding.dtype == torch.float16 else np.float32
    elif isinstance(embedding, np.ndarray) and embedding.dtype == np.float16:
        out_dtype = np.float16
    else:
        out_dtype = np.float32

    # Move to CPU float32 numpy; bf16 has no numpy equivalent, so upcast first.
    if isinstance(embedding, torch.Tensor):
        host = embedding.cpu()
        if host.dtype == torch.bfloat16:
            emb_np = host.float().numpy()
        else:
            emb_np = host.numpy().astype(np.float32)
    else:
        emb_np = np.array(embedding, dtype=np.float32)

    total_tokens = emb_np.shape[0]
    expected_tokens = num_tiles * patches_per_tile

    if total_tokens != expected_tokens:
        logger.debug(f"Token count mismatch: {total_tokens} vs expected {expected_tokens}")
        # Recover: use ceil(total / patches_per_tile) tiles instead.
        num_tiles = (total_tokens + patches_per_tile - 1) // patches_per_tile

    # Slice starts are strictly increasing, so filtering on start < total
    # is equivalent to the original early break; the final tile may be partial.
    pooled = [
        emb_np[start : min(start + patches_per_tile, total_tokens)].mean(axis=0)
        for start in range(0, num_tiles * patches_per_tile, patches_per_tile)
        if start < total_tokens
    ]
    return np.array(pooled, dtype=out_dtype)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def colpali_row_mean_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    grid_size: int = 32,
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """Mean-pool a square ColPali patch grid along its rows.

    Expects exactly grid_size² visual tokens laid out row-major; returns one
    mean vector per grid row → [grid_size, dim].

    Raises:
        ValueError: If the token count does not equal grid_size².
    """
    # Dtype policy (inlined): explicit wins; fp16 stays fp16; bf16 → fp32.
    if output_dtype is not None:
        out_dtype = output_dtype
    elif isinstance(embedding, torch.Tensor):
        out_dtype = np.float16 if embedding.dtype == torch.float16 else np.float32
    elif isinstance(embedding, np.ndarray) and embedding.dtype == np.float16:
        out_dtype = np.float16
    else:
        out_dtype = np.float32

    if isinstance(embedding, torch.Tensor):
        host = embedding.cpu()
        if host.dtype == torch.bfloat16:
            emb_np = host.float().numpy()
        else:
            emb_np = host.numpy().astype(np.float32)
    else:
        emb_np = np.array(embedding, dtype=np.float32)

    num_tokens, dim = emb_np.shape
    side = int(grid_size)
    expected = side * side
    if num_tokens != expected:
        raise ValueError(
            f"Expected {expected} visual tokens for grid_size={grid_size}, got {num_tokens}"
        )

    # [side, side, dim] → average over the column axis → one vector per row.
    return emb_np.reshape(side, side, int(dim)).mean(axis=1).astype(out_dtype)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def colsmol_experimental_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    num_tiles: int,
    patches_per_tile: int = 64,
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """Hybrid pooling: mean-pool every tile except the last, keep it raw.

    Output is [(num_tiles - 1) + last_tile_patches, dim]: one mean vector
    per leading tile followed by the last tile's unpooled patch vectors.
    If num_tiles overshoots the token count it is corrected to
    ceil(tokens / patches_per_tile).

    Raises:
        ValueError: On non-positive num_tiles / patches_per_tile, or when
            there are no tokens at all.
    """
    # Dtype policy (inlined): explicit wins; fp16 stays fp16; bf16 → fp32.
    if output_dtype is not None:
        out_dtype = output_dtype
    elif isinstance(embedding, torch.Tensor):
        out_dtype = np.float16 if embedding.dtype == torch.float16 else np.float32
    elif isinstance(embedding, np.ndarray) and embedding.dtype == np.float16:
        out_dtype = np.float16
    else:
        out_dtype = np.float32

    if isinstance(embedding, torch.Tensor):
        host = embedding.cpu()
        if host.dtype == torch.bfloat16:
            emb_np = host.float().numpy()
        else:
            emb_np = host.numpy().astype(np.float32)
    else:
        emb_np = np.array(embedding, dtype=np.float32)

    num_visual_tokens, dim = emb_np.shape
    if num_tiles <= 0:
        raise ValueError("num_tiles must be > 0")
    if patches_per_tile <= 0:
        raise ValueError("patches_per_tile must be > 0")

    step = int(patches_per_tile)
    last_tile_start = (int(num_tiles) - 1) * step
    if last_tile_start >= num_visual_tokens:
        # num_tiles overshoots: fall back to ceil(tokens / patches_per_tile).
        fitted = int(num_visual_tokens) // step
        if fitted * step != int(num_visual_tokens):
            fitted += 1
        if fitted <= 0:
            raise ValueError(
                f"Not enough tokens for num_tiles={num_tiles}, patches_per_tile={patches_per_tile}: got {num_visual_tokens}"
            )
        num_tiles = fitted
        last_tile_start = (int(num_tiles) - 1) * step

    head = emb_np[:last_tile_start]
    tail = emb_np[last_tile_start : min(last_tile_start + step, num_visual_tokens)]

    if head.size:
        # Leading tiles are complete by construction: reshape then average.
        head_means = head.reshape(-1, step, int(dim)).mean(axis=1)
    else:
        head_means = np.zeros((0, int(dim)), dtype=out_dtype)

    return np.concatenate([head_means.astype(out_dtype), tail.astype(out_dtype)], axis=0)
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def colpali_experimental_pooling_from_rows(
    row_vectors: Union[torch.Tensor, np.ndarray],
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """
    Experimental "convolution-style" pooling with window size 3.

    For N input rows, produces N + 2 output vectors:
    - Position 0: row[0] alone (1 row)
    - Position 1: mean(rows[0:2]) (2 rows)
    - Position 2: mean(rows[0:3]) (3 rows)
    - Positions 3 to N-1: sliding window of 3 (rows[i-2:i+1])
    - Position N: mean(rows[N-2:N]) (last 2 rows)
    - Position N+1: row[N-1] alone (last row)

    For N=32 rows: produces 34 vectors.

    Raises:
        ValueError: If row_vectors is empty.
    """
    # Dtype policy (inlined): explicit wins; fp16 stays fp16; bf16 → fp32.
    if output_dtype is not None:
        out_dtype = output_dtype
    elif isinstance(row_vectors, torch.Tensor):
        out_dtype = np.float16 if row_vectors.dtype == torch.float16 else np.float32
    elif isinstance(row_vectors, np.ndarray) and row_vectors.dtype == np.float16:
        out_dtype = np.float16
    else:
        out_dtype = np.float32

    if isinstance(row_vectors, torch.Tensor):
        host = row_vectors.cpu()
        if host.dtype == torch.bfloat16:
            rows = host.float().numpy()
        else:
            rows = host.numpy().astype(np.float32)
    else:
        rows = np.array(row_vectors, dtype=np.float32)

    n = rows.shape[0]
    if n < 1:
        raise ValueError("row_vectors must be non-empty")
    if n == 1:
        return rows.astype(out_dtype)

    # Express every output position as a (start, stop) averaging window;
    # single-row windows reproduce the raw row exactly.
    if n == 2:
        windows = [(0, 1), (0, 2), (1, 2)]
    else:
        windows = [(0, 1), (0, 2), (0, 3)]
        windows.extend((i - 2, i + 1) for i in range(3, n))
        windows.extend([(n - 2, n), (n - 1, n)])

    return np.stack(
        [rows[a:b].mean(axis=0) for a, b in windows], axis=0
    ).astype(out_dtype)
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def global_mean_pooling(
    embedding: Union[torch.Tensor, np.ndarray],
    output_dtype: Optional[np.dtype] = None,
) -> np.ndarray:
    """
    Collapse a multi-vector embedding to a single mean vector.

    Fastest possible representation but discards all spatial information —
    use only when recall can be traded for speed. Note the mean is computed
    in the input's own precision (no fp32 upcast except for bf16).

    Args:
        embedding: Multi-vector embeddings [num_tokens, dim]
        output_dtype: Output dtype (default: fp16 stays fp16, bf16 → fp32)

    Returns:
        Pooled vector [dim]
    """
    # Dtype policy (inlined): explicit wins; fp16 stays fp16; bf16 → fp32.
    if output_dtype is not None:
        out_dtype = output_dtype
    elif isinstance(embedding, torch.Tensor):
        out_dtype = np.float16 if embedding.dtype == torch.float16 else np.float32
    elif isinstance(embedding, np.ndarray) and embedding.dtype == np.float16:
        out_dtype = np.float16
    else:
        out_dtype = np.float32

    # Unlike the tile pooling paths, the input dtype is deliberately kept
    # (no astype(float32)) so the reduction matches the original precision.
    if isinstance(embedding, torch.Tensor):
        host = embedding.cpu()
        emb_np = host.float().numpy() if host.dtype == torch.bfloat16 else host.numpy()
    else:
        emb_np = np.array(embedding)

    return emb_np.mean(axis=0).astype(out_dtype)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def compute_maxsim_score(
    query_embedding: np.ndarray,
    doc_embedding: np.ndarray,
    normalize: bool = True,
) -> float:
    """
    ColBERT-style MaxSim late-interaction score.

    score = Σ_q max_d sim(q, d): each query token is matched against its
    best document token, and the maxima are summed.

    Args:
        query_embedding: Query embeddings [num_query_tokens, dim]
        doc_embedding: Document embeddings [num_doc_tokens, dim]
        normalize: L2-normalize both sides first (recommended; a 1e-8
            epsilon guards against zero-norm rows)

    Returns:
        MaxSim score (higher is better)

    Example:
        >>> query = embedder.embed_query("budget allocation")
        >>> doc = embeddings[0]  # From embed_images
        >>> score = compute_maxsim_score(query, doc)
    """
    q = query_embedding
    d = doc_embedding
    if normalize:
        q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-8)
        d = d / (np.linalg.norm(d, axis=1, keepdims=True) + 1e-8)

    # [num_query, num_doc] similarities → per-query best match → sum.
    return float((q @ d.T).max(axis=1).sum())
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def compute_maxsim_batch(
    query_embedding: np.ndarray,
    doc_embeddings: list,
    normalize: bool = True,
) -> list:
    """
    MaxSim scores for one query against many documents.

    The query is normalized once up front; each document is normalized
    on the fly (1e-8 epsilon guards zero-norm rows).

    Args:
        query_embedding: Query embeddings [num_query_tokens, dim]
        doc_embeddings: List of document embeddings
        normalize: L2 normalize embeddings

    Returns:
        List of MaxSim scores, in input order
    """
    q = query_embedding
    if normalize:
        q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-8)

    def _score(doc: np.ndarray) -> float:
        # Per-document MaxSim: best doc token per query token, summed.
        if normalize:
            doc = doc / (np.linalg.norm(doc, axis=1, keepdims=True) + 1e-8)
        return float((q @ doc.T).max(axis=1).sum())

    return [_score(doc) for doc in doc_embeddings]
|