mirage_benchmark-1.0.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- mirage/__init__.py +83 -0
- mirage/cli.py +150 -0
- mirage/core/__init__.py +52 -0
- mirage/core/config.py +248 -0
- mirage/core/llm.py +1745 -0
- mirage/core/prompts.py +884 -0
- mirage/embeddings/__init__.py +31 -0
- mirage/embeddings/models.py +512 -0
- mirage/embeddings/rerankers_multimodal.py +766 -0
- mirage/embeddings/rerankers_text.py +149 -0
- mirage/evaluation/__init__.py +26 -0
- mirage/evaluation/metrics.py +2223 -0
- mirage/evaluation/metrics_optimized.py +2172 -0
- mirage/pipeline/__init__.py +45 -0
- mirage/pipeline/chunker.py +545 -0
- mirage/pipeline/context.py +1003 -0
- mirage/pipeline/deduplication.py +491 -0
- mirage/pipeline/domain.py +514 -0
- mirage/pipeline/pdf_processor.py +598 -0
- mirage/pipeline/qa_generator.py +798 -0
- mirage/utils/__init__.py +31 -0
- mirage/utils/ablation.py +360 -0
- mirage/utils/preflight.py +663 -0
- mirage/utils/stats.py +626 -0
- mirage_benchmark-1.0.4.dist-info/METADATA +490 -0
- mirage_benchmark-1.0.4.dist-info/RECORD +30 -0
- mirage_benchmark-1.0.4.dist-info/WHEEL +5 -0
- mirage_benchmark-1.0.4.dist-info/entry_points.txt +3 -0
- mirage_benchmark-1.0.4.dist-info/licenses/LICENSE +190 -0
- mirage_benchmark-1.0.4.dist-info/top_level.txt +1 -0
mirage/embeddings/__init__.py

@@ -0,0 +1,31 @@
"""
Embeddings module for MiRAGE - Embedding models and rerankers.

Imports are lazy to avoid loading heavy dependencies at import time.
"""

_LAZY_IMPORTS = {
    # Embedding models
    "BaseEmbeddingModel": ("models", "BaseEmbeddingModel"),
    "NomicVLEmbed": ("models", "NomicVLEmbed"),
    "SentenceTransformerEmbedder": ("models", "SentenceTransformerEmbedder"),
    "HuggingFaceAPIEmbedder": ("models", "HuggingFaceAPIEmbedder"),
    "get_best_embedding_model": ("models", "get_best_embedding_model"),
    # Multimodal rerankers
    "BaseReranker": ("rerankers_multimodal", "BaseReranker"),
    "MonoVLMReranker": ("rerankers_multimodal", "MonoVLMReranker"),
    "VLMReranker": ("rerankers_multimodal", "VLMReranker"),
    "TextEmbeddingReranker": ("rerankers_multimodal", "TextEmbeddingReranker"),
    # Text rerankers
    "LLMReranker": ("rerankers_text", "LLMReranker"),
}


def __getattr__(name):
    """Lazy import to avoid loading heavy dependencies at import time."""
    if name in _LAZY_IMPORTS:
        module_name, attr_name = _LAZY_IMPORTS[name]
        import importlib
        module = importlib.import_module(f"mirage.embeddings.{module_name}")
        return getattr(module, attr_name)
    raise AttributeError(f"module 'mirage.embeddings' has no attribute '{name}'")
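The `__getattr__` hook above is the module-level `__getattr__` mechanism from PEP 562: importing `mirage.embeddings` only builds the name table, and the backend module behind a name is imported the first time that attribute is accessed. A minimal usage sketch, assuming the wheel is installed (aside from whatever `mirage/__init__.py` itself pulls in, nothing heavy is loaded until the lookup):

    import mirage.embeddings as embeddings

    # First access triggers __getattr__, which imports mirage.embeddings.rerankers_text
    reranker_cls = embeddings.LLMReranker

    # `from ... import ...` goes through the same hook
    from mirage.embeddings import get_best_embedding_model
    print(get_best_embedding_model())  # "BAAI/bge-m3" (importing models.py needs torch, PIL, requests installed)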
mirage/embeddings/models.py

@@ -0,0 +1,512 @@

import torch
import numpy as np
from abc import ABC, abstractmethod
from typing import Optional, List, Union
from pathlib import Path
from PIL import Image
import os
import requests
import base64

# Text Embedding Configuration
EMBEDDING_MODELS_TEXT = {
    "bge_m3": "BAAI/bge-m3"
}

def get_best_embedding_model():
    """Returns the best text embedding model (BGE-M3)"""
    return EMBEDDING_MODELS_TEXT["bge_m3"]

def get_device_map_for_gpus(gpus: Optional[List[int]] = None) -> str:
    """Returns device_map string for specified GPUs"""
    if gpus and len(gpus) > 0:
        # Use first specified GPU as primary
        return f"cuda:{gpus[0]}"
    return "cuda" if torch.cuda.is_available() else "cpu"


# Multimodal Embedding Classes
class BaseMultimodalEmbedder(ABC):
    """Abstract base class for multimodal embedders"""

    @abstractmethod
    def embed_text(self, text: str) -> torch.Tensor:
        """Embed a single text string. Internal method."""
        pass

    @abstractmethod
    def embed_image(self, image_path: str) -> torch.Tensor:
        pass

    @abstractmethod
    def embed_multimodal(self, text: str, image_path: Optional[str] = None) -> torch.Tensor:
        pass

    def encode(
        self,
        sentences: Union[str, List[str]],
        convert_to_tensor: bool = False,
        convert_to_numpy: bool = False,
        show_progress_bar: bool = False,
        **kwargs
    ) -> Union[torch.Tensor, np.ndarray, List]:
        """
        Encode sentences to embeddings. Matches SentenceTransformer API.

        Args:
            sentences: Single string or list of strings to encode
            convert_to_tensor: If True, return torch.Tensor
            convert_to_numpy: If True, return numpy array
            show_progress_bar: Ignored (for API compatibility)
            **kwargs: Additional arguments (ignored for compatibility)

        Returns:
            Embeddings as tensor, numpy array, or list depending on flags
        """
        # Handle single string
        if isinstance(sentences, str):
            embedding = self.embed_text(sentences)
            if convert_to_numpy:
                return embedding.cpu().float().numpy() if isinstance(embedding, torch.Tensor) else np.array(embedding)
            if convert_to_tensor:
                return embedding if isinstance(embedding, torch.Tensor) else torch.tensor(embedding)
            return embedding.cpu().float().numpy() if isinstance(embedding, torch.Tensor) else embedding

        # Handle list of strings
        embeddings = []
        for text in sentences:
            emb = self.embed_text(text)
            embeddings.append(emb)

        # Stack embeddings
        if embeddings:
            stacked = torch.stack(embeddings) if isinstance(embeddings[0], torch.Tensor) else torch.tensor(embeddings)
            if convert_to_numpy:
                return stacked.cpu().float().numpy()
            if convert_to_tensor:
                return stacked
            return stacked.cpu().float().numpy()

        return np.array([]) if convert_to_numpy else torch.tensor([])

class NomicVLEmbed(BaseMultimodalEmbedder):
    """
    Nomic Embed Multimodal 7B
    """

    def __init__(self, model_name: str = "nomic-ai/nomic-embed-multimodal-7b", gpus: Optional[List[int]] = None):
        from transformers import BitsAndBytesConfig
        from colpali_engine.models import BiQwen2_5, BiQwen2_5_Processor

        print(f"Loading Nomic: {model_name}")
        self._setup_hf_auth()

        # Use specified GPUs or default to cuda
        self.device = get_device_map_for_gpus(gpus)
        self.gpus = gpus

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Read attention implementation from config (default to sdpa for stability)
        attn_impl = "sdpa"  # Default to PyTorch native attention
        try:
            from config_loader import get_embedding_config
            embed_config = get_embedding_config()
            nomic_config = embed_config.get('models', {}).get('nomic', {})
            attn_impl = nomic_config.get('attn_implementation', 'sdpa')
        except Exception:
            pass

        is_cuda = self.device.startswith("cuda")
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        ) if is_cuda else None

        self.model = BiQwen2_5.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map=self.device,
            attn_implementation=attn_impl,
            quantization_config=quantization_config,
        ).eval()

        self.processor = BiQwen2_5_Processor.from_pretrained(model_name)
        print(f"✅ Nomic loaded on {self.device}")

    def _setup_hf_auth(self):
        api_key_path = os.environ.get("HF_TOKEN_PATH", os.path.expanduser("~/.config/huggingface/token"))
        if os.path.exists(api_key_path):
            with open(api_key_path, 'r') as f:
                os.environ["HUGGING_FACE_HUB_TOKEN"] = f.read().strip()

    def embed_text(self, text: str) -> torch.Tensor:
        inputs = self.processor.process_queries([text]).to(self.device)
        with torch.no_grad():
            embeddings = self.model(**inputs)
        return embeddings.flatten()

    def embed_image(self, image_path: str) -> torch.Tensor:
        if not Path(image_path).exists():
            return torch.zeros(128, device=self.device)

        image = Image.open(image_path).convert('RGB')
        inputs = self.processor.process_images([image]).to(self.device)
        with torch.no_grad():
            embeddings = self.model(**inputs)
        return embeddings.flatten()

    def embed_multimodal(self, text: str, image_path: Optional[str] = None) -> torch.Tensor:
        if image_path and Path(image_path).exists():
            image = Image.open(image_path).convert('RGB')
            batch_images = self.processor.process_images([image]).to(self.device)
            batch_queries = self.processor.process_queries([text]).to(self.device)

            with torch.no_grad():
                query_emb = self.model(**batch_queries)
                image_emb = self.model(**batch_images)
                # Normalize and combine text and image embeddings
                query_emb = torch.nn.functional.normalize(query_emb, dim=-1)
                image_emb = torch.nn.functional.normalize(image_emb, dim=-1)
                combined = (query_emb + image_emb) / 2
                combined = torch.nn.functional.normalize(combined, dim=-1)

            return combined.flatten()
        return self.embed_text(text)

class Qwen2VLEmbed(BaseMultimodalEmbedder):
    """
    Qwen2-VL for multimodal embeddings
    """

    def __init__(self, model_name: str = "Qwen/Qwen2-VL-7B-Instruct"):
        from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig

        print(f"Loading Qwen2-VL: {model_name}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        ) if self.device == "cuda" else None

        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            quantization_config=quantization_config,
        ).eval()

        self.processor = AutoProcessor.from_pretrained(model_name)
        print(f"✅ Qwen2-VL loaded on {self.device}")

    def embed_text(self, text: str) -> torch.Tensor:
        messages = [{"role": "user", "content": [{"type": "text", "text": text}]}]
        inputs = self.processor.apply_chat_template(
            messages, tokenize=True, return_dict=True, return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            embedding = outputs.hidden_states[-1].mean(dim=1).squeeze()

        return embedding.flatten()

    def embed_image(self, image_path: str) -> torch.Tensor:
        if not Path(image_path).exists():
            return torch.zeros(1536, device=self.device)

        image = Image.open(image_path).convert('RGB')
        messages = [{"role": "user", "content": [{"type": "image", "image": image}]}]
        inputs = self.processor.apply_chat_template(
            messages, tokenize=True, return_dict=True, return_tensors="pt"
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            embedding = outputs.hidden_states[-1].mean(dim=1).squeeze()

        return embedding.flatten()

    def embed_multimodal(self, text: str, image_path: Optional[str] = None) -> torch.Tensor:
        if image_path and Path(image_path).exists():
            image = Image.open(image_path).convert('RGB')
            messages = [{
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": text}
                ]
            }]
            inputs = self.processor.apply_chat_template(
                messages, tokenize=True, return_dict=True, return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs, output_hidden_states=True)
                embedding = outputs.hidden_states[-1].mean(dim=1).squeeze()

            return embedding.flatten()
        return self.embed_text(text)

class VLMDescriptionEmbed(BaseMultimodalEmbedder):
    """
    VLM Description-based Embedder
    """

    def __init__(self,
                 text_model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
                 vlm_api_url: str = "https://api.openai.com/v1/chat/completions",
                 vlm_model_name: str = "gpt-4o"):
        from sentence_transformers import SentenceTransformer

        print(f"Loading VLM Description Embedder with text model: {text_model_name}")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load text embedding model
        self.text_model = SentenceTransformer(text_model_name, device=self.device)

        # Setup VLM API
        self.vlm_api_url = vlm_api_url
        self.vlm_model_name = vlm_model_name

        # Load API key (use environment or config file)
        api_key_path = os.environ.get("OPENAI_API_KEY_PATH", os.path.expanduser("~/.config/openai/api_key.txt"))
        with open(api_key_path, 'r') as f:
            self.api_key = f.read().strip()

        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        print(f"✅ VLM Description Embedder loaded on {self.device}")

    def _describe_image(self, image_path: str) -> str:
        """Use VLM API to generate textual description of image"""

        # Read and encode image
        with open(image_path, 'rb') as f:
            image_data = base64.b64encode(f.read()).decode('utf-8')

        # Prepare API request
        payload = {
            "model": self.vlm_model_name,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{image_data}"
                            }
                        },
                        {
                            "type": "text",
                            "text": "Describe this image with concise details, focusing on technical content, diagrams, charts, tables, and text visible in the image."
                        }
                    ]
                }
            ],
            "max_tokens": 500
        }

        # Call API with timeout (increased to 180 seconds for large/complex images)
        response = requests.post(self.vlm_api_url, headers=self.headers, json=payload, timeout=180)
        response.raise_for_status()

        result = response.json()
        description = result['choices'][0]['message']['content']

        return description

    def embed_text(self, text: str) -> torch.Tensor:
        embedding = self.text_model.encode(text, convert_to_tensor=True, device=self.device)
        return embedding.flatten()

    def embed_image(self, image_path: str) -> torch.Tensor:
        if not Path(image_path).exists():
            return torch.zeros(384, device=self.device)

        description = self._describe_image(image_path)
        return self.embed_text(description)

    def embed_multimodal(self, text: str, image_path: Optional[str] = None) -> torch.Tensor:
        if image_path and Path(image_path).exists():
            description = self._describe_image(image_path)
            combined_text = f"{text}\n\nImage description: {description}"
            return self.embed_text(combined_text)
        return self.embed_text(text)

class BGEVLEmbed(BaseMultimodalEmbedder):
    """
    BGE-VL-v1.5-mmeb (MLLM variant)
    """

    def __init__(self, model_name: str = "BAAI/BGE-VL-v1.5-zs"):
        from transformers import AutoModel

        print(f"Loading BGE-VL-v1.5: {model_name}")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = AutoModel.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.model.eval()

        # Add missing image_newline attribute if needed
        if not hasattr(self.model, 'image_newline'):
            try:
                hidden_size = self.model.config.text_config.hidden_size
            except:
                try:
                    hidden_size = self.model.config.hidden_size
                except:
                    hidden_size = 4096
            import torch.nn as nn
            self.model.image_newline = nn.Parameter(
                torch.zeros(hidden_size, dtype=torch.float16)
            )

        # Set processor
        with torch.no_grad():
            self.model.set_processor(model_name)

        # Fix processor patch_size issue
        if hasattr(self.model, 'processor'):
            proc = self.model.processor
            if hasattr(proc, 'patch_size') and proc.patch_size is None:
                proc.patch_size = 14
            if hasattr(proc, 'image_processor'):
                img_proc = proc.image_processor
                if hasattr(img_proc, 'patch_size') and img_proc.patch_size is None:
                    img_proc.patch_size = 14

        self._patch_forward_method()

        # Get embedding dimension
        with torch.no_grad():
            try:
                test_inputs = self.model.data_process(text="test", q_or_c="c")
                test_outputs = self.model(**test_inputs, output_hidden_states=True)
                if hasattr(test_outputs, 'hidden_states'):
                    test_emb = test_outputs.hidden_states[-1][:, -1, :]
                else:
                    test_emb = test_outputs[:, -1, :]
                self.embedding_dim = test_emb.shape[-1]
            except Exception:
                self.embedding_dim = 4096

        print(f"✅ BGE-VL-v1.5 loaded on {self.device}, dim: {self.embedding_dim}")

    def _patch_forward_method(self):
        import types
        if hasattr(self.model, 'pack_image_features'):
            original_pack = self.model.pack_image_features
            def fixed_pack_image_features(self_model, image_features, image_sizes, **kwargs):
                result, feature_lens = original_pack(image_features, image_sizes, **kwargs)
                if isinstance(result, list):
                    if len(result) > 0 and isinstance(result[0], torch.Tensor):
                        try:
                            result = torch.stack(result, dim=0) if len(result) > 1 else result[0]
                        except:
                            try:
                                result = torch.cat(result, dim=0)
                            except:
                                result = result[0]
                return result, feature_lens
            self.model.pack_image_features = types.MethodType(fixed_pack_image_features, self.model)

        model_class = self.model.__class__
        if not hasattr(model_class, '_bgevl_original_forward'):
            model_class._bgevl_original_forward = model_class.forward
            def patched_forward(self, *args, **kwargs):
                if hasattr(self, 'pack_image_features'):
                    original_pack = self.pack_image_features
                    def fixed_pack_image_features(image_features, image_sizes, **kwargs):
                        result, feature_lens = original_pack(image_features, image_sizes, **kwargs)
                        if isinstance(result, list):
                            if len(result) > 0 and isinstance(result[0], torch.Tensor):
                                try:
                                    result = torch.stack(result, dim=0) if len(result) > 1 else result[0]
                                except:
                                    try:
                                        result = torch.cat(result, dim=0)
                                    except:
                                        result = result[0]
                        return result, feature_lens
                    self.pack_image_features = fixed_pack_image_features
                    try:
                        return model_class._bgevl_original_forward(self, *args, **kwargs)
                    finally:
                        self.pack_image_features = original_pack
                else:
                    return model_class._bgevl_original_forward(self, *args, **kwargs)
            model_class.forward = patched_forward

        if hasattr(self.model, 'vision_tower'):
            vt_class = self.model.vision_tower.__class__
            if not hasattr(vt_class, '_original_vt_forward'):
                vt_class._original_vt_forward = vt_class.forward
                def fixed_vt_forward(vt_self, pixel_values, *args, **kwargs):
                    if isinstance(pixel_values, torch.Tensor) and pixel_values.dim() == 5:
                        b, n, c, h, w = pixel_values.shape
                        pixel_values = pixel_values.reshape(b * n, c, h, w)
                    return vt_class._original_vt_forward(vt_self, pixel_values, *args, **kwargs)
                vt_class.forward = fixed_vt_forward

    def embed_text(self, text: str) -> torch.Tensor:
        with torch.no_grad():
            inputs = self.model.data_process(text=text, q_or_c="c")
            outputs = self.model(**inputs, output_hidden_states=True)
            if hasattr(outputs, 'hidden_states'):
                embedding = outputs.hidden_states[-1][:, -1, :]
            else:
                embedding = outputs[:, -1, :]
            embedding = torch.nn.functional.normalize(embedding, dim=-1)
            return embedding.to(device=self.device, dtype=torch.float32).flatten()

    def embed_image(self, image_path: str) -> torch.Tensor:
        if not Path(image_path).exists():
            return torch.zeros(self.embedding_dim, device=self.device, dtype=torch.float32)
        with torch.no_grad():
            inputs = self.model.data_process(images=str(image_path), q_or_c="c")
            outputs = self.model(**inputs, output_hidden_states=True)
            if hasattr(outputs, 'hidden_states'):
                embedding = outputs.hidden_states[-1][:, -1, :]
            else:
                embedding = outputs[:, -1, :]
            embedding = torch.nn.functional.normalize(embedding, dim=-1)
            return embedding.to(device=self.device, dtype=torch.float32).flatten()

    def embed_multimodal(self, text: str, image_path: Optional[str] = None) -> torch.Tensor:
        if image_path and Path(image_path).exists():
            with torch.no_grad():
                inputs = self.model.data_process(
                    text=text,
                    images=str(image_path),
                    q_or_c="c"
                )
                outputs = self.model(**inputs, output_hidden_states=True)
                if hasattr(outputs, 'hidden_states'):
                    embedding = outputs.hidden_states[-1][:, -1, :]
                else:
                    embedding = outputs[:, -1, :]
                embedding = torch.nn.functional.normalize(embedding, dim=-1)
                return embedding.to(device=self.device, dtype=torch.float32).flatten()
        return self.embed_text(text)
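Two design points in models.py are worth spelling out. First, `BaseMultimodalEmbedder.encode` deliberately mirrors the `SentenceTransformer.encode` signature, so any of these embedders can be dropped into code that expects a sentence-transformers model. Second, `NomicVLEmbed.embed_multimodal` fuses modalities by averaging the L2-normalized text and image vectors and re-normalizing. A minimal sketch of the `encode` contract follows; `ToyEmbedder` is hypothetical and exists only to exercise the base-class logic without downloading weights or needing a GPU:

    import torch
    from mirage.embeddings.models import BaseMultimodalEmbedder

    class ToyEmbedder(BaseMultimodalEmbedder):
        """Hypothetical stand-in that returns a constant 4-dim vector."""
        def embed_text(self, text: str) -> torch.Tensor:
            return torch.ones(4)
        def embed_image(self, image_path: str) -> torch.Tensor:
            return torch.ones(4)
        def embed_multimodal(self, text, image_path=None) -> torch.Tensor:
            return torch.ones(4)

    emb = ToyEmbedder()
    single = emb.encode("hello", convert_to_numpy=True)      # numpy array, shape (4,)
    batch = emb.encode(["a", "b"], convert_to_tensor=True)   # torch.Tensor, shape (2, 4)

The real backends expose the same surface, e.g. `NomicVLEmbed(gpus=[0]).embed_multimodal(question, image_path)`, but they additionally require their heavy dependencies (colpali_engine, bitsandbytes for the 4-bit path, an OpenAI-style API key for VLMDescriptionEmbed) and the Hugging Face token that `_setup_hf_auth()` reads.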