mirage-benchmark 1.0.4 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mirage-benchmark might be problematic.

@@ -0,0 +1,766 @@
1
+ import torch
2
+ import re
3
+ from abc import ABC, abstractmethod
4
+ from typing import List, Optional, Dict, Tuple
5
+ from pathlib import Path
6
+ from PIL import Image
7
+
8
+ class BaseReranker(ABC):
9
+ """Abstract base class for image-based rerankers"""
10
+
11
+ @abstractmethod
12
+ def rerank(self, query: str, image_paths: List[str], top_k: int = 10) -> List[int]:
13
+ pass
14
+
15
+ class ChunkReranker(ABC):
16
+ """Abstract base class for chunk-based rerankers (text + optional images)"""
17
+
18
+ @abstractmethod
19
+ def rerank(self, query: str, chunks: List[Dict[str, str]], top_k: int = 1) -> List[Tuple[int, float, Dict[str, str]]]:
20
+ """
21
+ Rerank chunks based on query relevance
22
+
23
+ Args:
24
+ query: User query string
25
+ chunks: List of dicts with 'text' and optional 'image_path' keys
26
+ top_k: Number of top chunks to return
27
+
28
+ Returns:
29
+ List of tuples: (original_index, relevance_score, chunk_dict)
30
+ """
31
+ pass
32
+
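# Illustrative sketch (not part of the released file): a trivial, hypothetical
# ChunkReranker implementation showing the input/output contract that the concrete
# rerankers below follow (chunks are dicts with 'text' and optional 'image_path';
# results are (original_index, relevance_score, chunk_dict) tuples).
class DummyChunkReranker(ChunkReranker):
    """Toy reranker that scores chunks by naive keyword overlap with the query."""

    def rerank(self, query: str, chunks: List[Dict[str, str]], top_k: int = 1) -> List[Tuple[int, float, Dict[str, str]]]:
        query_words = set(query.lower().split())
        scored = []
        for i, chunk in enumerate(chunks):
            words = set(chunk.get("text", "").lower().split())
            score = len(query_words & words) / max(len(query_words), 1)
            scored.append((i, score, chunk))
        # Highest-scoring chunks first, truncated to top_k
        scored.sort(key=lambda item: item[1], reverse=True)
        return scored[:top_k]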
33
+ class MMR5Reranker(BaseReranker):
34
+ """MM-R5: MultiModal Reasoning-Enhanced ReRanker"""
35
+
36
+ SYSTEM_PROMPT = (
37
+ "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
38
+ "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
39
+ "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
40
+ "<think> reasoning process here </think><answer> answer here </answer>"
41
+ )
42
+
43
+ QUESTION_TEMPLATE = (
44
+ "Please rank the following images according to their relevance to the question. "
45
+ "Provide your response in the format: <think>your reasoning process here</think><answer>[image_id_1, image_id_2, ...]</answer> "
46
+ "where the numbers in the list represent the ranking order of images'id from most to least relevant. "
47
+ "Before outputting the answer, you need to analyze each image and provide your analysis process."
48
+ "For example: <think>Image 1 shows the most relevant content because...</think><answer>[id_most_relevant, id_second_relevant, ...]</answer>"
49
+ "\nThe question is: {Question}"
50
+ "\n\nThere are {image_num} images, id from 1 to {image_num_end}, Image ID to image mapping:\n"
51
+ )
52
+
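# Illustrative example (not part of the released file): for a hypothetical query and three
# candidate images, QUESTION_TEMPLATE.format(Question=query, image_num=3, image_num_end=3)
# produces a prompt ending with:
#   "The question is: <query>"
#   "There are 3 images, id from 1 to 3, Image ID to image mapping:"
# rerank() then appends "Image 1: ", "Image 2: ", ... interleaved with the actual images.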
53
+ def __init__(self, model_name: str = "i2vec/MM-R5"):
54
+ print(f"Loading MM-R5: {model_name}")
55
+
56
+ # Try official MM-R5 package first
57
+ try:
58
+ from reranker import QueryReranker # type: ignore
59
+ self.reranker = QueryReranker(model_name)
60
+ self.use_official_package = True
61
+ print(f"✅ MM-R5 loaded via official package")
62
+ except ImportError:
63
+ # Fallback to direct implementation using official code
64
+ from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
65
+ from qwen_vl_utils import process_vision_info
66
+
67
+ attn_kwargs = {}
68
+ if self._flash_attn_available():
69
+ attn_kwargs["attn_implementation"] = "flash_attention_2"
70
+ else:
71
+ print("⚠️ flash_attn not available. Falling back to default attention implementation.")
72
+
73
+ # Load without device_map="auto" to avoid meta tensor issues in parallel processing
74
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
75
+ self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
76
+ model_name,
77
+ torch_dtype=torch.bfloat16,
78
+ low_cpu_mem_usage=True,
79
+ **attn_kwargs,
80
+ ).to(self.device).eval()
81
+
82
+ self.processor = Qwen2_5_VLProcessor.from_pretrained(model_name)
83
+ self.use_official_package = False
84
+ print(f"✅ MM-R5 loaded via direct implementation")
85
+
86
+ def rerank(self, query: str, image_paths: List[str], top_k: int = 10) -> List[int]:
87
+ """Rerank images based on query relevance"""
88
+ import re
89
+
90
+ if self.use_official_package:
91
+ # Use official package API
92
+ predicted_order = self.reranker.rerank(query, image_paths)
93
+ else:
94
+ # Use direct implementation following official code
95
+ from qwen_vl_utils import process_vision_info
96
+
97
+ device = self.model.device
98
+
99
+ messages = [
100
+ {
101
+ "role": "system",
102
+ "content": [
103
+ {
104
+ "type": "text",
105
+ "text": self.SYSTEM_PROMPT,
106
+ },
107
+ ],
108
+ },
109
+ {
110
+ "role": "user",
111
+ "content": [
112
+ {
113
+ "type": "text",
114
+ "text": self.QUESTION_TEMPLATE.format(
115
+ Question=query,
116
+ image_num=len(image_paths),
117
+ image_num_end=len(image_paths)
118
+ ),
119
+ },
120
+ ],
121
+ },
122
+ ]
123
+
124
+ # Add images to messages
125
+ for i, image_path in enumerate(image_paths):
126
+ messages[-1]["content"].extend(
127
+ [
128
+ {"type": "text", "text": f"\nImage {i+1}: "},
129
+ {"type": "image", "image": image_path},
130
+ ]
131
+ )
132
+
133
+ text = self.processor.apply_chat_template(
134
+ messages, tokenize=False, add_generation_prompt=True
135
+ )
136
+
137
+ image_inputs, video_inputs = process_vision_info(messages)
138
+
139
+ inputs = self.processor(
140
+ text=[text],
141
+ images=image_inputs,
142
+ videos=video_inputs,
143
+ padding=True,
144
+ return_tensors="pt",
145
+ )
146
+
147
+ inputs = inputs.to(device)
148
+
149
+ generated_ids = self.model.generate(
150
+ **inputs,
151
+ do_sample=True,
152
+ temperature=0.3,
153
+ max_new_tokens=8192,
154
+ use_cache=True,
155
+ )
156
+
157
+ generated_ids_trimmed = [
158
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
159
+ ]
160
+
161
+ output_text = self.processor.batch_decode(
162
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
163
+ )[0]
164
+
165
+ # Parse output results
166
+ match = re.search(r'<answer>\[(.*?)\]</answer>', output_text)
167
+
168
+ if match:
169
+ try:
170
+ tmp_predicted_order = []
171
+ predicted_order = [int(x) - 1 for x in match.group(1).strip().split(',') if x.strip()]
172
+
173
+ for idx in predicted_order:
174
+ if 0 <= idx < len(image_paths):
175
+ tmp_predicted_order.append(idx)
176
+
177
+ predicted_order = tmp_predicted_order
178
+
179
+ # Handle missing indices
180
+ if len(set(predicted_order)) < len(image_paths):
181
+ missing_ids = set(range(len(image_paths))) - set(predicted_order)
182
+ predicted_order.extend(sorted(list(missing_ids)))
183
+
184
+ except Exception as e:
185
+ predicted_order = [i for i in range(len(image_paths))]
186
+ print(f"⚠️ Parsing error: {str(e)}, output text: {output_text[:200]}...")
187
+ else:
188
+ predicted_order = [i for i in range(len(image_paths))]
189
+ print(f"⚠️ Could not parse ranking from output: {output_text[:200]}...")
190
+
191
+ return predicted_order[:top_k]
192
+
193
+ @staticmethod
194
+ def _flash_attn_available() -> bool:
195
+ try:
196
+ import flash_attn # noqa: F401
197
+ return True
198
+ except Exception:
199
+ return False
200
+
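# Illustrative sketch (not part of the released file): how the <answer>[...]</answer>
# block produced by MM-R5 maps to the 0-indexed ranking that MMR5Reranker.rerank()
# returns. The sample output string in the usage comment is hypothetical.
def _demo_parse_mmr5_answer(output_text: str, num_images: int) -> List[int]:
    """Parse '<answer>[2, 1, 3]</answer>' into [1, 0, 2], padding any missing indices."""
    match = re.search(r'<answer>\[(.*?)\]</answer>', output_text)
    if not match:
        return list(range(num_images))
    order = [int(x) - 1 for x in match.group(1).split(',') if x.strip()]
    order = [i for i in order if 0 <= i < num_images]
    order.extend(sorted(set(range(num_images)) - set(order)))
    return order

# _demo_parse_mmr5_answer("<think>...</think><answer>[2, 1, 3]</answer>", 4) -> [1, 0, 2, 3]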
201
+ class Florence2Reranker(BaseReranker):
202
+ """Florence-2-large for visual document reranking"""
203
+
204
+ def __init__(self, model_name: str = "microsoft/Florence-2-large"):
205
+ from transformers import AutoProcessor, AutoModelForCausalLM
206
+ try:
207
+ from transformers import BitsAndBytesConfig
208
+ except ImportError:
209
+ BitsAndBytesConfig = None
210
+
211
+ print(f"Loading Florence-2: {model_name}")
212
+
213
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
214
+ self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
215
+
216
+ # Note: When using quantization, device_map is needed for BitsAndBytes
217
+ # When not using quantization, load explicitly to avoid meta tensor issues
218
+ if self.device == "cuda" and BitsAndBytesConfig is not None:
219
+ try:
220
+ quantization_config = BitsAndBytesConfig(
221
+ load_in_4bit=True,
222
+ bnb_4bit_compute_dtype=torch.float16,
223
+ bnb_4bit_use_double_quant=True,
224
+ )
225
+ self.model = AutoModelForCausalLM.from_pretrained(
226
+ model_name,
227
+ torch_dtype=self.torch_dtype,
228
+ trust_remote_code=True,
229
+ quantization_config=quantization_config,
230
+ device_map="auto" # Required for BitsAndBytes quantization
231
+ ).eval()
232
+ except Exception as e:
233
+ print(f"⚠️ Quantization failed ({e}), loading without quantization...")
234
+ self.model = AutoModelForCausalLM.from_pretrained(
235
+ model_name,
236
+ torch_dtype=self.torch_dtype,
237
+ trust_remote_code=True,
238
+ low_cpu_mem_usage=True,
239
+ ).to(self.device).eval()
240
+ else:
241
+ self.model = AutoModelForCausalLM.from_pretrained(
242
+ model_name,
243
+ torch_dtype=self.torch_dtype,
244
+ trust_remote_code=True,
245
+ low_cpu_mem_usage=True,
246
+ ).to(self.device).eval()
247
+
248
+ self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
249
+ print(f"✅ Florence-2 loaded on {self.device}")
250
+
251
+ def _score_image(self, query: str, image_path: str) -> float:
252
+ """Score a single image based on query relevance"""
253
+ try:
254
+ if not Path(image_path).exists():
255
+ return 0.0
256
+
257
+ image = Image.open(image_path).convert('RGB')
258
+
259
+ # Use caption task to understand image
260
+ prompt = "<CAPTION>"
261
+ inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device)
262
+
263
+ with torch.no_grad():
264
+ generated_ids = self.model.generate(
265
+ input_ids=inputs["input_ids"],
266
+ pixel_values=inputs["pixel_values"],
267
+ max_new_tokens=1024,
268
+ num_beams=3,
269
+ do_sample=False,
270
+ use_cache=False,
271
+ )
272
+
273
+ generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
274
+ parsed_answer = self.processor.post_process_generation(
275
+ generated_text, task="<CAPTION>", image_size=(image.width, image.height)
276
+ )
277
+
278
+ caption = parsed_answer.get('<CAPTION>', '')
279
+
280
+ # Simple relevance score based on query keyword overlap
281
+ query_words = set(query.lower().split())
282
+ caption_words = set(caption.lower().split())
283
+ overlap = len(query_words & caption_words)
284
+ score = overlap / max(len(query_words), 1)
285
+
286
+ return score
287
+
288
+ except Exception as e:
289
+ print(f"⚠️ Scoring failed for {image_path}: {e}")
290
+ return 0.0
291
+
292
+ def rerank(self, query: str, image_paths: List[str], top_k: int = 10) -> List[int]:
293
+ """Rerank images based on query relevance using Florence-2"""
294
+ try:
295
+ scores = []
296
+ for img_path in image_paths:
297
+ score = self._score_image(query, img_path)
298
+ scores.append(score)
299
+
300
+ # Get top-k indices sorted by score
301
+ ranked_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
302
+ return ranked_indices[:top_k]
303
+
304
+ except Exception as e:
305
+ print(f"⚠️ Reranking failed: {e}")
306
+ return list(range(min(top_k, len(image_paths))))
307
+
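# Illustrative sketch (not part of the released file): the keyword-overlap heuristic that
# Florence2Reranker._score_image() applies to the generated caption. With the made-up
# query/caption below, 2 of the 4 distinct query words overlap, giving a score of 0.5.
def _demo_overlap_score(query: str, caption: str) -> float:
    query_words = set(query.lower().split())
    caption_words = set(caption.lower().split())
    return len(query_words & caption_words) / max(len(query_words), 1)

# _demo_overlap_score("pump impeller cross section", "diagram of a pump impeller") -> 0.5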
308
+ class VLMReranker(ChunkReranker):
309
+ """VLM-based reranker using Motor Maven endpoint with multiple images"""
310
+
311
+ def __init__(self):
312
+ from call_llm import call_vlm_with_multiple_images
313
+ from prompt import PROMPTS
314
+
315
+ self.call_vlm_multi = call_vlm_with_multiple_images
316
+ self.rerank_prompt = PROMPTS["rerank_vlm"]
317
+ print("✅ VLM Reranker initialized")
318
+
319
+ def rerank(self, query: str, chunks: List[Dict[str, str]], top_k: int = 1) -> List[Tuple[int, float, Dict[str, str]]]:
320
+ """
321
+ Rerank chunks based on query relevance with multiple images
322
+
323
+ Args:
324
+ query: User query string
325
+ chunks: List of dicts with 'text' and optional 'image_path' keys
326
+ top_k: Number of top chunks to return (default: 1)
327
+
328
+ Returns:
329
+ List of tuples: (original_index, relevance_score, chunk_dict)
330
+ """
331
+ try:
332
+ # Collect chunk data and images
333
+ chunk_data = []
334
+ image_paths = []
335
+ chunk_to_image_map = {} # Maps chunk index to image index
336
+
337
+ for i, chunk in enumerate(chunks):
338
+ chunk_info = {
339
+ 'index': i,
340
+ 'text': chunk.get('text', ''),
341
+ 'image_path': chunk.get('image_path', None),
342
+ 'has_image': False
343
+ }
344
+
345
+ # Track images
346
+ if chunk_info['image_path'] and Path(chunk_info['image_path']).exists():
347
+ chunk_info['has_image'] = True
348
+ chunk_to_image_map[i] = len(image_paths)
349
+ image_paths.append(chunk_info['image_path'])
350
+
351
+ chunk_data.append(chunk_info)
352
+
353
+ # Build structured prompt with explicit chunk boundaries
354
+ formatted_chunks = []
355
+ for chunk_info in chunk_data:
356
+ i = chunk_info['index'] + 1 # 1-indexed for display
357
+ chunk_lines = [f"<CHUNK_START id={i}>"]
358
+ chunk_lines.append(chunk_info['text'])
359
+
360
+ if chunk_info['has_image']:
361
+ img_idx = chunk_to_image_map[chunk_info['index']] + 1
362
+ chunk_lines.append(f"<IMAGE_START id={img_idx} relates_to_chunk={i}>")
363
+ chunk_lines.append(f"[Image {img_idx} displayed here]")
364
+ chunk_lines.append(f"<IMAGE_END id={img_idx}>")
365
+
366
+ chunk_lines.append(f"<CHUNK_END id={i}>")
367
+ formatted_chunks.append("\n".join(chunk_lines))
368
+
369
+ # Build full prompt
370
+ full_prompt = f"""{self.rerank_prompt}
371
+
372
+ Query: {query}
373
+
374
+ Chunks to rank:
375
+
376
+ {chr(10).join(formatted_chunks)}"""
377
+
378
+ # Call VLM with all images
379
+ if not image_paths:
380
+ # No images - use LLM fallback instead of VLM
381
+ from call_llm import call_llm_simple
382
+ response = call_llm_simple(full_prompt)
383
+ else:
384
+ response = self.call_vlm_multi(full_prompt, image_paths)
385
+
386
+ # Parse response to extract rankings
387
+ rankings = self._parse_rankings(response, chunk_data)
388
+
389
+ # Return top-k with chunk data
390
+ result = []
391
+ for idx, score in rankings[:top_k]:
392
+ result.append((idx, score, chunks[idx]))
393
+
394
+ return result
395
+
396
+ except Exception as e:
397
+ print(f"⚠️ Reranking failed: {e}")
398
+ import traceback
399
+ traceback.print_exc()
400
+ # Return original order with default scores
401
+ return [(i, 1.0, chunks[i]) for i in range(min(top_k, len(chunks)))]
402
+
403
+ def _parse_rankings(self, response: str, chunk_data: List[Dict]) -> List[Tuple[int, float]]:
404
+ """Parse VLM response to extract chunk rankings from structured format"""
405
+ rankings = []
406
+ num_chunks = len(chunk_data)
407
+
408
+ # Primary pattern: <Rank X>Chunk Y (simplified format)
409
+ rank_pattern = r'<Rank\s+(\d+)>\s*Chunk\s+(\d+)'
410
+ matches = re.findall(rank_pattern, response, re.IGNORECASE)
411
+
412
+ seen_indices = set()
413
+ for rank_num, chunk_num in matches:
414
+ idx = int(chunk_num) - 1 # Convert to 0-indexed
415
+ rank = int(rank_num)
416
+
417
+ # Calculate score based on rank (higher rank = lower score)
418
+ # Rank 1 gets highest score (1.0), decreasing linearly
419
+ relevance = 1.0 - ((rank - 1) / max(num_chunks, 1))
420
+
421
+ # Ensure valid index and no duplicates
422
+ if 0 <= idx < num_chunks and idx not in seen_indices:
423
+ rankings.append((idx, relevance))
424
+ seen_indices.add(idx)
425
+
426
+ # If parsing failed or incomplete, fill remaining chunks with low scores
427
+ if len(rankings) < num_chunks:
428
+ missing = set(range(num_chunks)) - seen_indices
429
+ for idx in missing:
430
+ rankings.append((idx, 0.0))
431
+ print(f"⚠️ Parsed {len(seen_indices)}/{num_chunks} chunks from response")
432
+ if len(seen_indices) == 0:
433
+ # Debug: print response when parsing completely fails
434
+ print(f"🔍 Debug - VLM Response (first 500 chars):\n{response[:500]}")
435
+ print(f"🔍 Debug - VLM Response (last 500 chars):\n{response[-500:]}")
436
+
437
+ # Sort by relevance score (highest first)
438
+ rankings.sort(key=lambda x: x[1], reverse=True)
439
+
440
+ return rankings
441
+
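# Illustrative sketch (not part of the released file): the response format that
# VLMReranker._parse_rankings() expects and the linear rank-to-score mapping it applies.
# The response string in the usage comment is hypothetical.
def _demo_parse_rank_lines(response: str, num_chunks: int) -> List[Tuple[int, float]]:
    """Turn '<Rank 1>Chunk 3 <Rank 2>Chunk 1 ...' into [(chunk_index, score), ...]."""
    rankings = []
    seen = set()
    for rank_num, chunk_num in re.findall(r'<Rank\s+(\d+)>\s*Chunk\s+(\d+)', response, re.IGNORECASE):
        idx = int(chunk_num) - 1                                  # chunks are 1-indexed in the prompt
        score = 1.0 - (int(rank_num) - 1) / max(num_chunks, 1)    # rank 1 -> 1.0, decreasing linearly
        if 0 <= idx < num_chunks and idx not in seen:
            rankings.append((idx, score))
            seen.add(idx)
    return sorted(rankings, key=lambda item: item[1], reverse=True)

# _demo_parse_rank_lines("<Rank 1>Chunk 3\n<Rank 2>Chunk 1\n<Rank 3>Chunk 2", 3)
# -> [(2, 1.0), (0, ~0.67), (1, ~0.33)]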
442
+ class MonoVLMReranker(ChunkReranker):
443
+ """MonoVLM reranker using lightonai/MonoQwen2-VL-v0.1"""
444
+
445
+ def __init__(
446
+ self,
447
+ model_name: str = "lightonai/MonoQwen2-VL-v0.1",
448
+ processor_name: str = "Qwen/Qwen2-VL-2B-Instruct",
449
+ ):
450
+ print(f"Loading MonoVLM: {model_name}")
451
+ from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
452
+
453
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
454
+ self.torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
455
+
456
+ self.processor = AutoProcessor.from_pretrained(processor_name, trust_remote_code=True)
457
+
458
+ # Load model without device_map="auto" to avoid meta tensor issues
459
+ # First load to CPU, then move to GPU explicitly
460
+ self.model = Qwen2VLForConditionalGeneration.from_pretrained(
461
+ model_name,
462
+ torch_dtype=self.torch_dtype,
463
+ trust_remote_code=True,
464
+ low_cpu_mem_usage=True, # Reduces CPU memory during loading
465
+ )
466
+
467
+ # Move to device and set to eval mode
468
+ self.model = self.model.to(self.device).eval()
469
+
470
+ # Cache token IDs for True/False classification
471
+ tokenizer = getattr(self.processor, "tokenizer", None)
472
+ if tokenizer is None:
473
+ raise ValueError("Processor does not expose a tokenizer needed for MonoVLM scoring.")
474
+
475
+ self.true_token_id = tokenizer.convert_tokens_to_ids("True")
476
+ self.false_token_id = tokenizer.convert_tokens_to_ids("False")
477
+
478
+ if self.true_token_id is None or self.false_token_id is None:
479
+ raise ValueError("Tokenizer missing True/False tokens required for MonoVLM scoring.")
480
+
481
+ print(f"✅ MonoVLM loaded")
482
+
483
+ def _build_prompt(self, query: str, chunk_text: str) -> str:
484
+ chunk_text = chunk_text.strip() if chunk_text else "[No text provided]"
485
+ return (
486
+ "Assert the relevance of the provided document (text and/or image) to the query.\n"
487
+ "Respond with a single word: True if relevant, otherwise False.\n\n"
488
+ f"Query:\n{query}\n\nDocument:\n{chunk_text}"
489
+ )
490
+
491
+ def _score_chunk(self, query: str, chunk: Dict[str, str]) -> float:
492
+ image = None
493
+ image_path = chunk.get('image_path')
494
+
495
+ try:
496
+ if image_path and Path(image_path).exists():
497
+ with Image.open(image_path) as img:
498
+ image = img.convert("RGB")
499
+ except Exception as img_err:
500
+ print(f"⚠️ Failed to load image {image_path}: {img_err}")
501
+ image = None
502
+
503
+ prompt = self._build_prompt(query, chunk.get('text', ''))
504
+
505
+ messages = [
506
+ {
507
+ "role": "user",
508
+ "content": (
509
+ [{"type": "image", "image": image}] if image else []
510
+ ) + [
511
+ {
512
+ "type": "text",
513
+ "text": prompt,
514
+ }
515
+ ],
516
+ }
517
+ ]
518
+
519
+ try:
520
+ text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
521
+
522
+ processor_kwargs = {
523
+ "text": text,
524
+ "return_tensors": "pt",
525
+ }
526
+
527
+ if image is not None:
528
+ processor_kwargs["images"] = image
529
+
530
+ inputs = self.processor(**processor_kwargs)
531
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
532
+
533
+ with torch.no_grad():
534
+ outputs = self.model(**inputs)
535
+
536
+ logits = outputs.logits[:, -1, :]
537
+ relevance = torch.softmax(
538
+ logits[:, [self.true_token_id, self.false_token_id]],
539
+ dim=-1
540
+ )
541
+
542
+ return relevance[0, 0].item()
543
+
544
+ except Exception as e:
545
+ print(f"⚠️ MonoVLM scoring failed: {e}")
546
+ return 0.0
547
+
548
+ def rerank(self, query: str, chunks: List[Dict[str, str]], top_k: int = 1) -> List[Tuple[int, float, Dict[str, str]]]:
549
+ """
550
+ Rerank chunks based on query relevance using MonoVLM
551
+
552
+ Args:
553
+ query: User query string
554
+ chunks: List of dicts with 'text' and optional 'image_path' keys
555
+ top_k: Number of top chunks to return (default: 1)
556
+
557
+ Returns:
558
+ List of tuples: (original_index, relevance_score, chunk_dict)
559
+ """
560
+ scores = []
561
+ for idx, chunk in enumerate(chunks):
562
+ score = self._score_chunk(query, chunk)
563
+ scores.append((idx, score))
564
+
565
+ scores.sort(key=lambda x: x[1], reverse=True)
566
+
567
+ result = []
568
+ for idx, score in scores[:top_k]:
569
+ if 0 <= idx < len(chunks):
570
+ result.append((idx, score, chunks[idx]))
571
+
572
+ return result
573
+
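# Illustrative sketch (not part of the released file): the True/False classification trick
# used by MonoVLMReranker._score_chunk(). The relevance score is the softmax probability of
# the "True" token against the "False" token at the last position; logit values are made up.
def _demo_true_false_score(true_logit: float, false_logit: float) -> float:
    """Softmax over the two class logits; returns P('True')."""
    pair = torch.tensor([[true_logit, false_logit]])
    return torch.softmax(pair, dim=-1)[0, 0].item()

# _demo_true_false_score(2.0, -1.0) -> ~0.95 (likely relevant)
# _demo_true_false_score(-1.0, 2.0) -> ~0.05 (likely irrelevant)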
574
+ class TextEmbeddingReranker(ChunkReranker):
575
+ """Text embedding reranker using BAAI/bge-large-en-v1.5 with image descriptions"""
576
+
577
+ def __init__(self, model_name: str = "BAAI/bge-large-en-v1.5"):
578
+ from sentence_transformers import SentenceTransformer
579
+
580
+ print(f"Loading text embedding model: {model_name}")
581
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
582
+
583
+ self.model = SentenceTransformer(model_name, device=self.device)
584
+ print(f"✅ Text embedding model loaded on {self.device}")
585
+
586
+ # For generating image descriptions (reuse VLM)
587
+ from call_llm import call_vlm_simple
588
+ from prompt import PROMPTS
589
+ self.call_vlm = call_vlm_simple
590
+ self.desc_prompt = PROMPTS.get("rerank_image_desc", "Generate a concise 100-word technical description of this image.")
591
+
592
+ def _generate_image_description(self, image_path: str) -> str:
593
+ """Generate description for image using VLM"""
594
+ try:
595
+ return self.call_vlm(self.desc_prompt, image_path)
596
+ except Exception as e:
597
+ print(f"⚠️ Image description failed for {image_path}: {e}")
598
+ return "[Image description unavailable]"
599
+
600
+ def rerank(self, query: str, chunks: List[Dict[str, str]], top_k: int = 1) -> List[Tuple[int, float, Dict[str, str]]]:
601
+ """
602
+ Rerank chunks based on query relevance using text embeddings
603
+
604
+ Args:
605
+ query: User query string
606
+ chunks: List of dicts with 'text' and optional 'image_path' keys
607
+ top_k: Number of top chunks to return (default: 1)
608
+
609
+ Returns:
610
+ List of tuples: (original_index, relevance_score, chunk_dict)
611
+ """
612
+ try:
613
+ # Generate image descriptions and combine with text
614
+ chunk_texts = []
615
+ for chunk in chunks:
616
+ text = chunk.get('text', '')
617
+
618
+ # Add image description if image exists
619
+ if chunk.get('image_path') and Path(chunk['image_path']).exists():
620
+ img_desc = self._generate_image_description(chunk['image_path'])
621
+ combined_text = f"{text}\n[Image Description: {img_desc}]"
622
+ else:
623
+ combined_text = text
624
+
625
+ chunk_texts.append(combined_text)
626
+
627
+ # Compute embeddings
628
+ query_embedding = self.model.encode([query], convert_to_tensor=True, normalize_embeddings=True)
629
+ chunk_embeddings = self.model.encode(chunk_texts, convert_to_tensor=True, normalize_embeddings=True)
630
+
631
+ # Compute cosine similarities
632
+ similarities = torch.nn.functional.cosine_similarity(
633
+ query_embedding, chunk_embeddings, dim=1
634
+ )
635
+
636
+ # Get top-k indices sorted by similarity
637
+ top_indices = torch.argsort(similarities, descending=True)[:top_k].cpu().tolist()
638
+
639
+ # Build results with scores
640
+ results = []
641
+ for idx in top_indices:
642
+ if 0 <= idx < len(chunks):
643
+ score = float(similarities[idx].cpu())
644
+ results.append((idx, score, chunks[idx]))
645
+
646
+ return results
647
+
648
+ except Exception as e:
649
+ print(f"⚠️ Reranking failed: {e}")
650
+ return [(i, 1.0, chunks[i]) for i in range(min(top_k, len(chunks)))]
651
+
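# Illustrative sketch (not part of the released file): the cosine-similarity ranking step
# used by TextEmbeddingReranker.rerank(), shown on made-up embedding vectors so it runs
# without loading the sentence-transformers model.
def _demo_rank_by_cosine(query_emb: torch.Tensor, chunk_embs: torch.Tensor, top_k: int) -> List[Tuple[int, float]]:
    """Return (chunk_index, cosine_similarity) pairs for the top_k most similar chunks."""
    sims = torch.nn.functional.cosine_similarity(query_emb, chunk_embs, dim=1)
    top = torch.argsort(sims, descending=True)[:top_k]
    return [(int(i), float(sims[i])) for i in top]

# _demo_rank_by_cosine(torch.tensor([[1.0, 0.0]]),
#                      torch.tensor([[0.9, 0.1], [0.0, 1.0], [0.7, 0.7]]), top_k=2)
# -> [(0, ~0.99), (2, ~0.71)]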
652
+ if __name__ == "__main__":
653
+ from pathlib import Path
654
+
655
+ print("=" * 60)
656
+ print("Testing Multimodal Rerankers")
657
+ print("=" * 60)
658
+
659
+ # Prepare test data (text-only chunks for testing without specific images)
660
+ test_chunks = [
661
+ {
662
+ "text": "Machine learning models require careful hyperparameter tuning to achieve optimal performance. Common parameters include learning rate, batch size, and regularization strength.",
663
+ "image_path": None
664
+ },
665
+ {
666
+ "text": "This figure shows the characteristic relationship between model accuracy and training epochs. The data demonstrates typical learning curve behavior with diminishing returns after initial rapid improvement.",
667
+ "image_path": None # Set to actual image path for multimodal testing
668
+ },
669
+ {
670
+ "text": "Neural network architectures vary widely in depth and complexity. Convolutional neural networks excel at image tasks, while transformers dominate natural language processing applications.",
671
+ "image_path": None
672
+ },
673
+ {
674
+ "text": "Data preprocessing is essential for model performance. Standard techniques include normalization, handling missing values, and feature engineering for tabular data.",
675
+ "image_path": None
676
+ },
677
+ {
678
+ "text": "This flowchart illustrates the machine learning pipeline from data collection through model deployment, showing how each stage contributes to the final system performance.",
679
+ "image_path": None # Set to actual image path for multimodal testing
680
+ }
681
+ ]
682
+
683
+ # Extract valid image paths for image-based rerankers
684
+ image_paths = [chunk['image_path'] for chunk in test_chunks if chunk.get('image_path') and Path(chunk['image_path']).exists()]
685
+ valid_chunks = [chunk for chunk in test_chunks if chunk['image_path'] is None or (chunk.get('image_path') and Path(chunk['image_path']).exists())]
686
+ query = "How do machine learning models improve with training, and what are the key stages in the ML pipeline?"
687
+
688
+ # Test 1: VLM Reranker
689
+ print("\n1. Testing VLM Reranker...")
690
+ print("-" * 60)
691
+ try:
692
+ if valid_chunks:
693
+ reranker = VLMReranker()
694
+ print(f"Testing with {len(valid_chunks)} chunks")
695
+ print(f"Chunks with images: {sum(1 for c in valid_chunks if c.get('image_path'))}")
696
+ print(f"\nQuery: {query}\n")
697
+ results = reranker.rerank(query, valid_chunks, top_k=3)
698
+ print("\nTop 3 Reranked Chunks:")
699
+ for i, (orig_idx, score, chunk) in enumerate(results, 1):
700
+ print(f"\n{i}. Original Index: {orig_idx}, Relevance Score: {score:.3f}")
701
+ print(f" Text preview: {chunk['text'][:100]}...")
702
+ if chunk.get('image_path'):
703
+ print(f" Has image: {Path(chunk['image_path']).name}")
704
+ else:
705
+ print(f" Text-only chunk")
706
+ print("\n✅ VLM Reranker test completed!")
707
+ else:
708
+ print("⚠️ No valid test chunks found.")
709
+ except Exception as e:
710
+ print(f"❌ VLM Reranker test failed: {e}")
711
+ import traceback
712
+ traceback.print_exc()
713
+
714
+ # Test 2: MonoVLM Reranker
715
+ print("\n2. Testing MonoVLM Reranker (lightonai/MonoQwen2-VL-v0.1)...")
716
+ print("-" * 60)
717
+ try:
718
+ if valid_chunks:
719
+ reranker = MonoVLMReranker()
720
+ print(f"Testing with {len(valid_chunks)} chunks")
721
+ print(f"Chunks with images: {sum(1 for c in valid_chunks if c.get('image_path'))}")
722
+ print(f"\nQuery: {query}\n")
723
+ results = reranker.rerank(query, valid_chunks, top_k=3)
724
+ print("\nTop 3 Reranked Chunks:")
725
+ for i, (orig_idx, score, chunk) in enumerate(results, 1):
726
+ print(f"\n{i}. Original Index: {orig_idx}, Relevance Score: {score:.3f}")
727
+ print(f" Text preview: {chunk['text'][:100]}...")
728
+ if chunk.get('image_path'):
729
+ print(f" Has image: {Path(chunk['image_path']).name}")
730
+ else:
731
+ print(f" Text-only chunk")
732
+ print("\n✅ MonoVLM Reranker test completed!")
733
+ else:
734
+ print("⚠️ No valid test chunks found.")
735
+ except Exception as e:
736
+ print(f"⚠️ MonoVLM Reranker test skipped: {e}")
737
+ import traceback
738
+ traceback.print_exc()
739
+
740
+ # Test 3: Text Embedding Reranker (BAAI/bge-large-en-v1.5)
741
+ print("\n3. Testing Text Embedding Reranker (BAAI/bge-large-en-v1.5)...")
742
+ print("-" * 60)
743
+ try:
744
+ if valid_chunks:
745
+ reranker = TextEmbeddingReranker()
746
+ print(f"Testing with {len(valid_chunks)} chunks")
747
+ print(f"Chunks with images: {sum(1 for c in valid_chunks if c.get('image_path'))}")
748
+ print(f"\nQuery: {query}\n")
749
+ results = reranker.rerank(query, valid_chunks, top_k=3)
750
+ print("\nTop 3 Reranked Chunks:")
751
+ for i, (orig_idx, score, chunk) in enumerate(results, 1):
752
+ print(f"\n{i}. Original Index: {orig_idx}, Relevance Score: {score:.3f}")
753
+ print(f" Text preview: {chunk['text'][:100]}...")
754
+ if chunk.get('image_path'):
755
+ print(f" Has image: {Path(chunk['image_path']).name}")
756
+ else:
757
+ print(f" Text-only chunk")
758
+ print("\n✅ Text Embedding Reranker test completed!")
759
+ else:
760
+ print("⚠️ No valid test chunks found.")
761
+ except Exception as e:
762
+ print(f"⚠️ Text Embedding Reranker test skipped: {e}")
763
+ import traceback
764
+ traceback.print_exc()
765
+
766
+ print("\n" + "=" * 60)