mirage-benchmark 1.0.4 (mirage_benchmark-1.0.4-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mirage-benchmark might be problematic.
- mirage/__init__.py +83 -0
- mirage/cli.py +150 -0
- mirage/core/__init__.py +52 -0
- mirage/core/config.py +248 -0
- mirage/core/llm.py +1745 -0
- mirage/core/prompts.py +884 -0
- mirage/embeddings/__init__.py +31 -0
- mirage/embeddings/models.py +512 -0
- mirage/embeddings/rerankers_multimodal.py +766 -0
- mirage/embeddings/rerankers_text.py +149 -0
- mirage/evaluation/__init__.py +26 -0
- mirage/evaluation/metrics.py +2223 -0
- mirage/evaluation/metrics_optimized.py +2172 -0
- mirage/pipeline/__init__.py +45 -0
- mirage/pipeline/chunker.py +545 -0
- mirage/pipeline/context.py +1003 -0
- mirage/pipeline/deduplication.py +491 -0
- mirage/pipeline/domain.py +514 -0
- mirage/pipeline/pdf_processor.py +598 -0
- mirage/pipeline/qa_generator.py +798 -0
- mirage/utils/__init__.py +31 -0
- mirage/utils/ablation.py +360 -0
- mirage/utils/preflight.py +663 -0
- mirage/utils/stats.py +626 -0
- mirage_benchmark-1.0.4.dist-info/METADATA +490 -0
- mirage_benchmark-1.0.4.dist-info/RECORD +30 -0
- mirage_benchmark-1.0.4.dist-info/WHEEL +5 -0
- mirage_benchmark-1.0.4.dist-info/entry_points.txt +3 -0
- mirage_benchmark-1.0.4.dist-info/licenses/LICENSE +190 -0
- mirage_benchmark-1.0.4.dist-info/top_level.txt +1 -0

mirage/embeddings/rerankers_multimodal.py
@@ -0,0 +1,766 @@
import torch
import re
from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Tuple
from pathlib import Path
from PIL import Image

class BaseReranker(ABC):
    """Abstract base class for image-based rerankers"""

    @abstractmethod
    def rerank(self, query: str, image_paths: List[str], top_k: int = 10) -> List[int]:
        pass

class ChunkReranker(ABC):
    """Abstract base class for chunk-based rerankers (text + optional images)"""

    @abstractmethod
    def rerank(self, query: str, chunks: List[Dict[str, str]], top_k: int = 1) -> List[Tuple[int, float, Dict[str, str]]]:
        """
        Rerank chunks based on query relevance

        Args:
            query: User query string
            chunks: List of dicts with 'text' and optional 'image_path' keys
            top_k: Number of top chunks to return

        Returns:
            List of tuples: (original_index, relevance_score, chunk_dict)
        """
        pass

|
+
class MMR5Reranker(BaseReranker):
|
|
34
|
+
"""MM-R5: MultiModal Reasoning-Enhanced ReRanker"""
|
|
35
|
+
|
|
36
|
+
SYSTEM_PROMPT = (
|
|
37
|
+
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
|
|
38
|
+
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
|
|
39
|
+
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
|
|
40
|
+
"<think> reasoning process here </think><answer> answer here </answer>"
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
QUESTION_TEMPLATE = (
|
|
44
|
+
"Please rank the following images according to their relevance to the question. "
|
|
45
|
+
"Provide your response in the format: <think>your reasoning process here</think><answer>[image_id_1, image_id_2, ...]</answer> "
|
|
46
|
+
"where the numbers in the list represent the ranking order of images'id from most to least relevant. "
|
|
47
|
+
"Before outputting the answer, you need to analyze each image and provide your analysis process."
|
|
48
|
+
"For example: <think>Image 1 shows the most relevant content because...</think><answer>[id_most_relevant, id_second_relevant, ...]</answer>"
|
|
49
|
+
"\nThe question is: {Question}"
|
|
50
|
+
"\n\nThere are {image_num} images, id from 1 to {image_num_end}, Image ID to image mapping:\n"
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
def __init__(self, model_name: str = "i2vec/MM-R5"):
|
|
54
|
+
print(f"Loading MM-R5: {model_name}")
|
|
55
|
+
|
|
56
|
+
# Try official MM-R5 package first
|
|
57
|
+
try:
|
|
58
|
+
from reranker import QueryReranker # type: ignore
|
|
59
|
+
self.reranker = QueryReranker(model_name)
|
|
60
|
+
self.use_official_package = True
|
|
61
|
+
print(f"✅ MM-R5 loaded via official package")
|
|
62
|
+
except ImportError:
|
|
63
|
+
# Fallback to direct implementation using official code
|
|
64
|
+
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
|
|
65
|
+
from qwen_vl_utils import process_vision_info
|
|
66
|
+
|
|
67
|
+
attn_kwargs = {}
|
|
68
|
+
if self._flash_attn_available():
|
|
69
|
+
attn_kwargs["attn_implementation"] = "flash_attention_2"
|
|
70
|
+
else:
|
|
71
|
+
print("⚠️ flash_attn not available. Falling back to default attention implementation.")
|
|
72
|
+
|
|
73
|
+
# Load without device_map="auto" to avoid meta tensor issues in parallel processing
|
|
74
|
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
75
|
+
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
|
76
|
+
model_name,
|
|
77
|
+
torch_dtype=torch.bfloat16,
|
|
78
|
+
low_cpu_mem_usage=True,
|
|
79
|
+
**attn_kwargs,
|
|
80
|
+
).to(self.device).eval()
|
|
81
|
+
|
|
82
|
+
self.processor = Qwen2_5_VLProcessor.from_pretrained(model_name)
|
|
83
|
+
self.use_official_package = False
|
|
84
|
+
print(f"✅ MM-R5 loaded via direct implementation")
|
|
85
|
+
|
|
86
|
+
def rerank(self, query: str, image_paths: List[str], top_k: int = 10) -> List[int]:
|
|
87
|
+
"""Rerank images based on query relevance"""
|
|
88
|
+
import re
|
|
89
|
+
|
|
90
|
+
if self.use_official_package:
|
|
91
|
+
# Use official package API
|
|
92
|
+
predicted_order = self.reranker.rerank(query, image_paths)
|
|
93
|
+
else:
|
|
94
|
+
# Use direct implementation following official code
|
|
95
|
+
from qwen_vl_utils import process_vision_info
|
|
96
|
+
|
|
97
|
+
device = self.model.device
|
|
98
|
+
|
|
99
|
+
messages = [
|
|
100
|
+
{
|
|
101
|
+
"role": "system",
|
|
102
|
+
"content": [
|
|
103
|
+
{
|
|
104
|
+
"type": "text",
|
|
105
|
+
"text": self.SYSTEM_PROMPT,
|
|
106
|
+
},
|
|
107
|
+
],
|
|
108
|
+
},
|
|
109
|
+
{
|
|
110
|
+
"role": "user",
|
|
111
|
+
"content": [
|
|
112
|
+
{
|
|
113
|
+
"type": "text",
|
|
114
|
+
"text": self.QUESTION_TEMPLATE.format(
|
|
115
|
+
Question=query,
|
|
116
|
+
image_num=len(image_paths),
|
|
117
|
+
image_num_end=len(image_paths)
|
|
118
|
+
),
|
|
119
|
+
},
|
|
120
|
+
],
|
|
121
|
+
},
|
|
122
|
+
]
|
|
123
|
+
|
|
124
|
+
# Add images to messages
|
|
125
|
+
for i, image_path in enumerate(image_paths):
|
|
126
|
+
messages[-1]["content"].extend(
|
|
127
|
+
[
|
|
128
|
+
{"type": "text", "text": f"\nImage {i+1}: "},
|
|
129
|
+
{"type": "image", "image": image_path},
|
|
130
|
+
]
|
|
131
|
+
)
|
|
132
|
+
|
|
133
|
+
text = self.processor.apply_chat_template(
|
|
134
|
+
messages, tokenize=False, add_generation_prompt=True
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
image_inputs, video_inputs = process_vision_info(messages)
|
|
138
|
+
|
|
139
|
+
inputs = self.processor(
|
|
140
|
+
text=[text],
|
|
141
|
+
images=image_inputs,
|
|
142
|
+
videos=video_inputs,
|
|
143
|
+
padding=True,
|
|
144
|
+
return_tensors="pt",
|
|
145
|
+
)
|
|
146
|
+
|
|
147
|
+
inputs = inputs.to(device)
|
|
148
|
+
|
|
149
|
+
generated_ids = self.model.generate(
|
|
150
|
+
**inputs,
|
|
151
|
+
do_sample=True,
|
|
152
|
+
temperature=0.3,
|
|
153
|
+
max_new_tokens=8192,
|
|
154
|
+
use_cache=True,
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
generated_ids_trimmed = [
|
|
158
|
+
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
|
159
|
+
]
|
|
160
|
+
|
|
161
|
+
output_text = self.processor.batch_decode(
|
|
162
|
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
|
163
|
+
)[0]
|
|
164
|
+
|
|
165
|
+
# Parse output results
|
|
166
|
+
match = re.search(r'<answer>\[(.*?)\]</answer>', output_text)
|
|
167
|
+
|
|
168
|
+
if match:
|
|
169
|
+
try:
|
|
170
|
+
tmp_predicted_order = []
|
|
171
|
+
predicted_order = [int(x) - 1 for x in match.group(1).strip().split(',') if x.strip()]
|
|
172
|
+
|
|
173
|
+
for idx in predicted_order:
|
|
174
|
+
if 0 <= idx < len(image_paths):
|
|
175
|
+
tmp_predicted_order.append(idx)
|
|
176
|
+
|
|
177
|
+
predicted_order = tmp_predicted_order
|
|
178
|
+
|
|
179
|
+
# Handle missing indices
|
|
180
|
+
if len(set(predicted_order)) < len(image_paths):
|
|
181
|
+
missing_ids = set(range(len(image_paths))) - set(predicted_order)
|
|
182
|
+
predicted_order.extend(sorted(list(missing_ids)))
|
|
183
|
+
|
|
184
|
+
except Exception as e:
|
|
185
|
+
predicted_order = [i for i in range(len(image_paths))]
|
|
186
|
+
print(f"⚠️ Parsing error: {str(e)}, output text: {output_text[:200]}...")
|
|
187
|
+
else:
|
|
188
|
+
predicted_order = [i for i in range(len(image_paths))]
|
|
189
|
+
print(f"⚠️ Could not parse ranking from output: {output_text[:200]}...")
|
|
190
|
+
|
|
191
|
+
return predicted_order[:top_k]
|
|
192
|
+
|
|
193
|
+
@staticmethod
|
|
194
|
+
def _flash_attn_available() -> bool:
|
|
195
|
+
try:
|
|
196
|
+
import flash_attn # noqa: F401
|
|
197
|
+
return True
|
|
198
|
+
except Exception:
|
|
199
|
+
return False
|
|
200
|
+
|
|
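# --- Illustrative usage sketch (added for clarity; not part of the packaged file) ---
# Assumptions: the MM-R5 weights are downloadable and the page images below are
# hypothetical local files.
#
#   reranker = MMR5Reranker()
#   order = reranker.rerank(
#       "Which page shows the wiring diagram?",
#       ["page_1.png", "page_2.png", "page_3.png"],
#       top_k=2,
#   )
#   # 'order' holds 0-based indices into image_paths, most relevant first, e.g. [2, 0]
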
class Florence2Reranker(BaseReranker):
    """Florence-2-large for visual document reranking"""

    def __init__(self, model_name: str = "microsoft/Florence-2-large"):
        from transformers import AutoProcessor, AutoModelForCausalLM
        try:
            from transformers import BitsAndBytesConfig
        except ImportError:
            BitsAndBytesConfig = None

        print(f"Loading Florence-2: {model_name}")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        # Note: When using quantization, device_map is needed for BitsAndBytes
        # When not using quantization, load explicitly to avoid meta tensor issues
        if self.device == "cuda":
            try:
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                )
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    torch_dtype=self.torch_dtype,
                    trust_remote_code=True,
                    quantization_config=quantization_config,
                    device_map="auto"  # Required for BitsAndBytes quantization
                ).eval()
            except Exception as e:
                print(f"⚠️ Quantization failed ({e}), loading without quantization...")
                self.model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    torch_dtype=self.torch_dtype,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                ).to(self.device).eval()
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=self.torch_dtype,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            ).to(self.device).eval()

        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        print(f"✅ Florence-2 loaded on {self.device}")

    def _score_image(self, query: str, image_path: str) -> float:
        """Score a single image based on query relevance"""
        try:
            if not Path(image_path).exists():
                return 0.0

            image = Image.open(image_path).convert('RGB')

            # Use caption task to understand image
            prompt = "<CAPTION>"
            inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device)

            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=1024,
                    num_beams=3,
                    do_sample=False,
                    use_cache=False,
                )

            generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            parsed_answer = self.processor.post_process_generation(
                generated_text, task="<CAPTION>", image_size=(image.width, image.height)
            )

            caption = parsed_answer.get('<CAPTION>', '')

            # Simple relevance score based on query keyword overlap
            query_words = set(query.lower().split())
            caption_words = set(caption.lower().split())
            overlap = len(query_words & caption_words)
            score = overlap / max(len(query_words), 1)

            return score

        except Exception as e:
            print(f"⚠️ Scoring failed for {image_path}: {e}")
            return 0.0

    def rerank(self, query: str, image_paths: List[str], top_k: int = 10) -> List[int]:
        """Rerank images based on query relevance using Florence-2"""
        try:
            scores = []
            for img_path in image_paths:
                score = self._score_image(query, img_path)
                scores.append(score)

            # Get top-k indices sorted by score
            ranked_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
            return ranked_indices[:top_k]

        except Exception as e:
            print(f"⚠️ Reranking failed: {e}")
            return list(range(min(top_k, len(image_paths))))

class VLMReranker(ChunkReranker):
    """VLM-based reranker using Motor Maven endpoint with multiple images"""

    def __init__(self):
        from call_llm import call_vlm_with_multiple_images
        from prompt import PROMPTS

        self.call_vlm_multi = call_vlm_with_multiple_images
        self.rerank_prompt = PROMPTS["rerank_vlm"]
        print("✅ VLM Reranker initialized")

    def rerank(self, query: str, chunks: List[Dict[str, str]], top_k: int = 1) -> List[Tuple[int, float, Dict[str, str]]]:
        """
        Rerank chunks based on query relevance with multiple images

        Args:
            query: User query string
            chunks: List of dicts with 'text' and optional 'image_path' keys
            top_k: Number of top chunks to return (default: 1)

        Returns:
            List of tuples: (original_index, relevance_score, chunk_dict)
        """
        try:
            # Collect chunk data and images
            chunk_data = []
            image_paths = []
            chunk_to_image_map = {}  # Maps chunk index to image index

            for i, chunk in enumerate(chunks):
                chunk_info = {
                    'index': i,
                    'text': chunk.get('text', ''),
                    'image_path': chunk.get('image_path', None),
                    'has_image': False
                }

                # Track images
                if chunk_info['image_path'] and Path(chunk_info['image_path']).exists():
                    chunk_info['has_image'] = True
                    chunk_to_image_map[i] = len(image_paths)
                    image_paths.append(chunk_info['image_path'])

                chunk_data.append(chunk_info)

            # Build structured prompt with explicit chunk boundaries
            formatted_chunks = []
            for chunk_info in chunk_data:
                i = chunk_info['index'] + 1  # 1-indexed for display
                chunk_lines = [f"<CHUNK_START id={i}>"]
                chunk_lines.append(chunk_info['text'])

                if chunk_info['has_image']:
                    img_idx = chunk_to_image_map[chunk_info['index']] + 1
                    chunk_lines.append(f"<IMAGE_START id={img_idx} relates_to_chunk={i}>")
                    chunk_lines.append(f"[Image {img_idx} displayed here]")
                    chunk_lines.append(f"<IMAGE_END id={img_idx}>")

                chunk_lines.append(f"<CHUNK_END id={i}>")
                formatted_chunks.append("\n".join(chunk_lines))

            # Build full prompt
            full_prompt = f"""{self.rerank_prompt}

Query: {query}

Chunks to rank:

{chr(10).join(formatted_chunks)}"""

            # Call VLM with all images
            if not image_paths:
                # No images - use LLM fallback instead of VLM
                from call_llm import call_llm_simple
                response = call_llm_simple(full_prompt)
            else:
                response = self.call_vlm_multi(full_prompt, image_paths)

            # Parse response to extract rankings
            rankings = self._parse_rankings(response, chunk_data)

            # Return top-k with chunk data
            result = []
            for idx, score in rankings[:top_k]:
                result.append((idx, score, chunks[idx]))

            return result

        except Exception as e:
            print(f"⚠️ Reranking failed: {e}")
            import traceback
            traceback.print_exc()
            # Return original order with default scores
            return [(i, 1.0, chunks[i]) for i in range(min(top_k, len(chunks)))]

    def _parse_rankings(self, response: str, chunk_data: List[Dict]) -> List[Tuple[int, float]]:
        """Parse VLM response to extract chunk rankings from structured format"""
        rankings = []
        num_chunks = len(chunk_data)

        # Primary pattern: <Rank X>Chunk Y (simplified format)
        rank_pattern = r'<Rank\s+(\d+)>\s*Chunk\s+(\d+)'
        matches = re.findall(rank_pattern, response, re.IGNORECASE)

        seen_indices = set()
        for rank_num, chunk_num in matches:
            idx = int(chunk_num) - 1  # Convert to 0-indexed
            rank = int(rank_num)

            # Calculate score based on rank (higher rank = lower score)
            # Rank 1 gets highest score (1.0), decreasing linearly
            relevance = 1.0 - ((rank - 1) / max(num_chunks, 1))

            # Ensure valid index and no duplicates
            if 0 <= idx < num_chunks and idx not in seen_indices:
                rankings.append((idx, relevance))
                seen_indices.add(idx)

        # If parsing failed or incomplete, fill remaining chunks with low scores
        if len(rankings) < num_chunks:
            missing = set(range(num_chunks)) - seen_indices
            for idx in missing:
                rankings.append((idx, 0.0))
            print(f"⚠️ Parsed {len(seen_indices)}/{num_chunks} chunks from response")
            if len(seen_indices) == 0:
                # Debug: print response when parsing completely fails
                print(f"🔍 Debug - VLM Response (first 500 chars):\n{response[:500]}")
                print(f"🔍 Debug - VLM Response (last 500 chars):\n{response[-500:]}")

        # Sort by relevance score (highest first)
        rankings.sort(key=lambda x: x[1], reverse=True)

        return rankings

class MonoVLMReranker(ChunkReranker):
    """MonoVLM reranker using lightonai/MonoQwen2-VL-v0.1"""

    def __init__(
        self,
        model_name: str = "lightonai/MonoQwen2-VL-v0.1",
        processor_name: str = "Qwen/Qwen2-VL-2B-Instruct",
    ):
        print(f"Loading MonoVLM: {model_name}")
        from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

        self.processor = AutoProcessor.from_pretrained(processor_name, trust_remote_code=True)

        # Load model without device_map="auto" to avoid meta tensor issues
        # First load to CPU, then move to GPU explicitly
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=self.torch_dtype,
            trust_remote_code=True,
            low_cpu_mem_usage=True,  # Reduces CPU memory during loading
        )

        # Move to device and set to eval mode
        self.model = self.model.to(self.device).eval()

        # Cache token IDs for True/False classification
        tokenizer = getattr(self.processor, "tokenizer", None)
        if tokenizer is None:
            raise ValueError("Processor does not expose a tokenizer needed for MonoVLM scoring.")

        self.true_token_id = tokenizer.convert_tokens_to_ids("True")
        self.false_token_id = tokenizer.convert_tokens_to_ids("False")

        if self.true_token_id is None or self.false_token_id is None:
            raise ValueError("Tokenizer missing True/False tokens required for MonoVLM scoring.")

        print(f"✅ MonoVLM loaded")

    def _build_prompt(self, query: str, chunk_text: str) -> str:
        chunk_text = chunk_text.strip() if chunk_text else "[No text provided]"
        return (
            "Assert the relevance of the provided document (text and/or image) to the query.\n"
            "Respond with a single word: True if relevant, otherwise False.\n\n"
            f"Query:\n{query}\n\nDocument:\n{chunk_text}"
        )

    def _score_chunk(self, query: str, chunk: Dict[str, str]) -> float:
        image = None
        image_path = chunk.get('image_path')

        try:
            if image_path and Path(image_path).exists():
                with Image.open(image_path) as img:
                    image = img.convert("RGB")
        except Exception as img_err:
            print(f"⚠️ Failed to load image {image_path}: {img_err}")
            image = None

        prompt = self._build_prompt(query, chunk.get('text', ''))

        messages = [
            {
                "role": "user",
                "content": (
                    [{"type": "image", "image": image}] if image else []
                ) + [
                    {
                        "type": "text",
                        "text": prompt,
                    }
                ],
            }
        ]

        try:
            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

            processor_kwargs = {
                "text": text,
                "return_tensors": "pt",
            }

            if image is not None:
                processor_kwargs["images"] = image

            inputs = self.processor(**processor_kwargs)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)

            logits = outputs.logits[:, -1, :]
            relevance = torch.softmax(
                logits[:, [self.true_token_id, self.false_token_id]],
                dim=-1
            )

            return relevance[0, 0].item()

        except Exception as e:
            print(f"⚠️ MonoVLM scoring failed: {e}")
            return 0.0

    def rerank(self, query: str, chunks: List[Dict[str, str]], top_k: int = 1) -> List[Tuple[int, float, Dict[str, str]]]:
        """
        Rerank chunks based on query relevance using MonoVLM

        Args:
            query: User query string
            chunks: List of dicts with 'text' and optional 'image_path' keys
            top_k: Number of top chunks to return (default: 1)

        Returns:
            List of tuples: (original_index, relevance_score, chunk_dict)
        """
        scores = []
        for idx, chunk in enumerate(chunks):
            score = self._score_chunk(query, chunk)
            scores.append((idx, score))

        scores.sort(key=lambda x: x[1], reverse=True)

        result = []
        for idx, score in scores[:top_k]:
            if 0 <= idx < len(chunks):
                result.append((idx, score, chunks[idx]))

        return result

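# --- Illustrative usage sketch (added for clarity; not part of the packaged file) ---
# Shows the ChunkReranker contract: rerank() returns (original_index, relevance_score,
# chunk_dict) tuples. The chunks below are hypothetical.
#
#   chunks = [
#       {"text": "Motor torque curve across load conditions.", "image_path": None},
#       {"text": "Firmware installation steps for the controller.", "image_path": None},
#   ]
#   reranker = MonoVLMReranker()  # downloads lightonai/MonoQwen2-VL-v0.1 weights
#   for idx, score, chunk in reranker.rerank("How does torque vary with load?", chunks, top_k=2):
#       print(idx, round(score, 3), chunk["text"])
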
class TextEmbeddingReranker(ChunkReranker):
    """Text embedding reranker using BAAI/bge-large-en-v1.5 with image descriptions"""

    def __init__(self, model_name: str = "BAAI/bge-large-en-v1.5"):
        from sentence_transformers import SentenceTransformer

        print(f"Loading text embedding model: {model_name}")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = SentenceTransformer(model_name, device=self.device)
        print(f"✅ Text embedding model loaded on {self.device}")

        # For generating image descriptions (reuse VLM)
        from call_llm import call_vlm_simple
        from prompt import PROMPTS
        self.call_vlm = call_vlm_simple
        self.desc_prompt = PROMPTS.get("rerank_image_desc", "Generate a concise 100-word technical description of this image.")

    def _generate_image_description(self, image_path: str) -> str:
        """Generate description for image using VLM"""
        try:
            return self.call_vlm(self.desc_prompt, image_path)
        except Exception as e:
            print(f"⚠️ Image description failed for {image_path}: {e}")
            return "[Image description unavailable]"

    def rerank(self, query: str, chunks: List[Dict[str, str]], top_k: int = 1) -> List[Tuple[int, float, Dict[str, str]]]:
        """
        Rerank chunks based on query relevance using text embeddings

        Args:
            query: User query string
            chunks: List of dicts with 'text' and optional 'image_path' keys
            top_k: Number of top chunks to return (default: 1)

        Returns:
            List of tuples: (original_index, relevance_score, chunk_dict)
        """
        try:
            # Generate image descriptions and combine with text
            chunk_texts = []
            for chunk in chunks:
                text = chunk.get('text', '')

                # Add image description if image exists
                if chunk.get('image_path') and Path(chunk['image_path']).exists():
                    img_desc = self._generate_image_description(chunk['image_path'])
                    combined_text = f"{text}\n[Image Description: {img_desc}]"
                else:
                    combined_text = text

                chunk_texts.append(combined_text)

            # Compute embeddings
            query_embedding = self.model.encode([query], convert_to_tensor=True, normalize_embeddings=True)
            chunk_embeddings = self.model.encode(chunk_texts, convert_to_tensor=True, normalize_embeddings=True)

            # Compute cosine similarities
            similarities = torch.nn.functional.cosine_similarity(
                query_embedding, chunk_embeddings, dim=1
            )

            # Get top-k indices sorted by similarity
            top_indices = torch.argsort(similarities, descending=True)[:top_k].cpu().tolist()

            # Build results with scores
            results = []
            for idx in top_indices:
                if 0 <= idx < len(chunks):
                    score = float(similarities[idx].cpu())
                    results.append((idx, score, chunks[idx]))

            return results

        except Exception as e:
            print(f"⚠️ Reranking failed: {e}")
            return [(i, 1.0, chunks[i]) for i in range(min(top_k, len(chunks)))]

if __name__ == "__main__":
    from pathlib import Path

    print("=" * 60)
    print("Testing Multimodal Rerankers")
    print("=" * 60)

    # Prepare test data (text-only chunks for testing without specific images)
    test_chunks = [
        {
            "text": "Machine learning models require careful hyperparameter tuning to achieve optimal performance. Common parameters include learning rate, batch size, and regularization strength.",
            "image_path": None
        },
        {
            "text": "This figure shows the characteristic relationship between model accuracy and training epochs. The data demonstrates typical learning curve behavior with diminishing returns after initial rapid improvement.",
            "image_path": None  # Set to actual image path for multimodal testing
        },
        {
            "text": "Neural network architectures vary widely in depth and complexity. Convolutional neural networks excel at image tasks, while transformers dominate natural language processing applications.",
            "image_path": None
        },
        {
            "text": "Data preprocessing is essential for model performance. Standard techniques include normalization, handling missing values, and feature engineering for tabular data.",
            "image_path": None
        },
        {
            "text": "This flowchart illustrates the machine learning pipeline from data collection through model deployment, showing how each stage contributes to the final system performance.",
            "image_path": None  # Set to actual image path for multimodal testing
        }
    ]

    # Extract valid image paths for image-based rerankers
    image_paths = [chunk['image_path'] for chunk in test_chunks if chunk.get('image_path') and Path(chunk['image_path']).exists()]
    valid_chunks = [chunk for chunk in test_chunks if chunk['image_path'] is None or (chunk.get('image_path') and Path(chunk['image_path']).exists())]
    query = "How do machine learning models improve with training, and what are the key stages in the ML pipeline?"

    # Test 1: VLM Reranker
    print("\n1. Testing VLM Reranker...")
    print("-" * 60)
    try:
        if valid_chunks:
            reranker = VLMReranker()
            print(f"Testing with {len(valid_chunks)} chunks")
            print(f"Chunks with images: {sum(1 for c in valid_chunks if c.get('image_path'))}")
            print(f"\nQuery: {query}\n")
            results = reranker.rerank(query, valid_chunks, top_k=3)
            print("\nTop 3 Reranked Chunks:")
            for i, (orig_idx, score, chunk) in enumerate(results, 1):
                print(f"\n{i}. Original Index: {orig_idx}, Relevance Score: {score:.3f}")
                print(f" Text preview: {chunk['text'][:100]}...")
                if chunk.get('image_path'):
                    print(f" Has image: {Path(chunk['image_path']).name}")
                else:
                    print(f" Text-only chunk")
            print("\n✅ VLM Reranker test completed!")
        else:
            print("⚠️ No valid test chunks found.")
    except Exception as e:
        print(f"❌ VLM Reranker test failed: {e}")
        import traceback
        traceback.print_exc()

    # Test 2: MonoVLM Reranker
    print("\n2. Testing MonoVLM Reranker (lightonai/MonoQwen2-VL-v0.1)...")
    print("-" * 60)
    try:
        if valid_chunks:
            reranker = MonoVLMReranker()
            print(f"Testing with {len(valid_chunks)} chunks")
            print(f"Chunks with images: {sum(1 for c in valid_chunks if c.get('image_path'))}")
            print(f"\nQuery: {query}\n")
            results = reranker.rerank(query, valid_chunks, top_k=3)
            print("\nTop 3 Reranked Chunks:")
            for i, (orig_idx, score, chunk) in enumerate(results, 1):
                print(f"\n{i}. Original Index: {orig_idx}, Relevance Score: {score:.3f}")
                print(f" Text preview: {chunk['text'][:100]}...")
                if chunk.get('image_path'):
                    print(f" Has image: {Path(chunk['image_path']).name}")
                else:
                    print(f" Text-only chunk")
            print("\n✅ MonoVLM Reranker test completed!")
        else:
            print("⚠️ No valid test chunks found.")
    except Exception as e:
        print(f"⚠️ MonoVLM Reranker test skipped: {e}")
        import traceback
        traceback.print_exc()

    # Test 3: Text Embedding Reranker (BAAI/bge-large-en-v1.5)
    print("\n3. Testing Text Embedding Reranker (BAAI/bge-large-en-v1.5)...")
    print("-" * 60)
    try:
        if valid_chunks:
            reranker = TextEmbeddingReranker()
            print(f"Testing with {len(valid_chunks)} chunks")
            print(f"Chunks with images: {sum(1 for c in valid_chunks if c.get('image_path'))}")
            print(f"\nQuery: {query}\n")
            results = reranker.rerank(query, valid_chunks, top_k=3)
            print("\nTop 3 Reranked Chunks:")
            for i, (orig_idx, score, chunk) in enumerate(results, 1):
                print(f"\n{i}. Original Index: {orig_idx}, Relevance Score: {score:.3f}")
                print(f" Text preview: {chunk['text'][:100]}...")
                if chunk.get('image_path'):
                    print(f" Has image: {Path(chunk['image_path']).name}")
                else:
                    print(f" Text-only chunk")
            print("\n✅ Text Embedding Reranker test completed!")
        else:
            print("⚠️ No valid test chunks found.")
    except Exception as e:
        print(f"⚠️ Text Embedding Reranker test skipped: {e}")
        import traceback
        traceback.print_exc()

    print("\n" + "=" * 60)
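
For orientation, here is a minimal sketch of how one of the self-contained rerankers above might be driven from outside the package. Assumptions: the module is importable as mirage.embeddings.rerankers_multimodal (matching the file list above) and torch, transformers, and Pillow are installed; note that VLMReranker and TextEmbeddingReranker additionally import call_llm and prompt at construction time, so those modules must be resolvable on the path. The query and image files below are made up.

    from mirage.embeddings.rerankers_multimodal import Florence2Reranker

    reranker = Florence2Reranker()  # downloads microsoft/Florence-2-large on first use
    ranking = reranker.rerank(
        "Which page shows the training curve?",
        ["figures/page_1.png", "figures/page_2.png"],
        top_k=1,
    )
    print(ranking)  # e.g. [1]: the 0-based index of the most relevant image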