mirage-benchmark 1.0.4__py3-none-any.whl

@@ -0,0 +1,491 @@
+
+import json
+import logging
+import os
+from tqdm import tqdm
+from sentence_transformers import SentenceTransformer, util
+import torch
+import numpy as np
+from typing import List, Dict
+
+# Import helper modules
+from mirage.embeddings.models import get_best_embedding_model
+from mirage.embeddings.rerankers_text import LLMReranker
+from mirage.core.llm import setup_logging, call_vlm_with_multiple_images
+from mirage.core.prompts import PROMPTS
+
+# Try to load configuration from config.yaml
+try:
+    from mirage.core.config import load_config
+    _cfg = load_config()
+    _dedup = _cfg.get('deduplication', {})
+
+    QUESTION_SIMILARITY_THRESHOLD = _dedup.get('question_similarity_threshold', 0.75)
+    _ans_sim = _dedup.get('answer_similarity', {})
+    ANSWER_SIMILARITY_HIGH = _ans_sim.get('high', 0.95)
+    ANSWER_SIMILARITY_MEDIUM = _ans_sim.get('medium', 0.85)
+    ANSWER_SIMILARITY_LOW = _ans_sim.get('low', 0.70)
+    MIN_COMMUNITY_SIZE = _dedup.get('min_community_size', 2)
+    # α parameter: weight for semantic similarity vs. chunk lineage (Eq. 10 in the manuscript)
+    ALPHA = _dedup.get('alpha', 0.6)
+    print(f"✅ Deduplication config loaded: α={ALPHA}, question_threshold={QUESTION_SIMILARITY_THRESHOLD}")
+except ImportError:
+    print("⚠️ mirage.core.config not available, using default deduplication configuration")
+    QUESTION_SIMILARITY_THRESHOLD = 0.75
+    ANSWER_SIMILARITY_HIGH = 0.95
+    ANSWER_SIMILARITY_MEDIUM = 0.85
+    ANSWER_SIMILARITY_LOW = 0.70
+    MIN_COMMUNITY_SIZE = 2
+    ALPHA = 0.6  # Default α for Eq. 10: α * semantic_sim + (1 - α) * jaccard
+
+# Configuration
+OUTPUT_DIR = "output"
+INPUT_FILE = os.path.join(OUTPUT_DIR, "qa_multihop_pass.json")
+OUTPUT_FILE = os.path.join(OUTPUT_DIR, "qa_dataset_deduplicated.json")
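+
+# Assumed config.yaml shape for the deduplication keys read above
+# (illustrative sketch only, inferred from the .get() calls; not shipped in the release):
+#
+# deduplication:
+#   question_similarity_threshold: 0.75
+#   answer_similarity:
+#     high: 0.95
+#     medium: 0.85
+#     low: 0.70
+#   min_community_size: 2
+#   alpha: 0.6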
+
+def load_dataset(filepath):
+    print(f"📂 Loading dataset from {filepath}...")
+    with open(filepath, 'r', encoding='utf-8') as f:
+        return json.load(f)
+
+def save_dataset(data, filepath):
+    print(f"💾 Saving {len(data)} items to {filepath}...")
+    with open(filepath, 'w', encoding='utf-8') as f:
+        json.dump(data, f, indent=2, ensure_ascii=False)
+
+def compute_chunk_overlap(qa1: Dict, qa2: Dict) -> float:
+    """
+    Compute Jaccard similarity based on chunk lineage.
+    Returns the overlap ratio [0, 1] of the chunks used to generate each QA.
+    """
+    # Extract chunk identifiers from chunks_added (a list of dicts with file_name and chunk_id)
+    def extract_chunk_ids(qa):
+        ids = {qa.get('chunk_id', -1)}
+        for chunk in qa.get('chunks_added', []):
+            if isinstance(chunk, dict):
+                ids.add((chunk.get('file_name', ''), chunk.get('chunk_id', '')))
+            else:
+                ids.add(chunk)
+        return ids
+
+    chunks1 = extract_chunk_ids(qa1)
+    chunks2 = extract_chunk_ids(qa2)
+
+    # Remove invalid IDs (the -1 sentinel for a missing chunk_id)
+    chunks1.discard(-1)
+    chunks2.discard(-1)
+
+    if not chunks1 or not chunks2:
+        return 0.0
+
+    intersection = len(chunks1 & chunks2)
+    union = len(chunks1 | chunks2)
+
+    return intersection / union if union > 0 else 0.0
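+    # Worked example: lineages {A, B, C} and {B, C, D} give |{B, C}| / |{A, B, C, D}| = 2/4 = 0.5.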
+
+def select_best_qa(cluster_items: List[Dict]) -> Dict:
+    """
+    Select the best QA from exact duplicates based on quality metrics.
+    Prioritizes: relevance score > difficulty score > answer length.
+    """
+    best = cluster_items[0]
+    best_score = (
+        float(best.get('relevance_score', 0)),
+        float(best.get('difficulty_score', 0)),
+        len(best.get('answer', ''))
+    )
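+    # Python compares these tuples lexicographically: relevance decides first,
+    # difficulty breaks ties, and answer length is the final tie-breaker.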
+
+    for item in cluster_items[1:]:
+        score = (
+            float(item.get('relevance_score', 0)),
+            float(item.get('difficulty_score', 0)),
+            len(item.get('answer', ''))
+        )
+        if score > best_score:
+            best = item
+            best_score = score
+
+    return best
+
+def reorganize_qa_packs(merged_items: List[Dict], base_metadata: Dict,
+                        expert_persona: str,
+                        domain: str) -> List[Dict]:
+    """
+    Reorganize merged QA pairs into balanced question-answer packs.
+    Groups related questions together while keeping packs balanced.
+
+    Args:
+        merged_items: List of merged QA dicts with 'question' and 'answer' keys
+        base_metadata: Metadata to propagate to reorganized items
+        expert_persona: Expert role for domain-specific organization
+        domain: Domain context for organization
+
+    Returns:
+        List of reorganized QA dicts
+    """
+    if len(merged_items) <= 1:
+        return merged_items
+
+    # Prepare merged questions and answers for the prompt
+    merged_questions = "\n".join([
+        f"{i+1}. {item['question']}"
+        for i, item in enumerate(merged_items)
+    ])
+    merged_answers = "\n".join([
+        f"{i+1}. {item['answer']}"
+        for i, item in enumerate(merged_items)
+    ])
+
+    prompt_template = PROMPTS.get("deduplication_reorganize", "")
+    if not prompt_template:
+        logging.warning("deduplication_reorganize prompt not found, skipping reorganization.")
+        return merged_items
+
+    # Format the prompt with the domain and expert role
+    formatted_prompt = prompt_template.format(expert_persona=expert_persona, domain=domain)
+    prompt = f"{formatted_prompt}\n\nInput:\nMerged Questions:\n{merged_questions}\n\nMerged Answers:\n{merged_answers}"
+
+    try:
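+        # The image list is empty: the multimodal endpoint is used here for a text-only prompt.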
+        response = call_vlm_with_multiple_images(prompt, [])
+        reorganized = parse_reorganized_packs(response, base_metadata)
+
+        if not reorganized:
+            logging.warning("LLM returned empty reorganization, keeping merged items.")
+            return merged_items
+
+        return reorganized
+
+    except Exception as e:
+        logging.error(f"Error in LLM reorganization: {e}")
+        return merged_items
+
+def parse_reorganized_packs(response_text: str, base_metadata: Dict) -> List[Dict]:
+    """
+    Parse the LLM response containing reorganized QA packs.
+
+    Returns:
+        List of reorganized QA dicts
+    """
+    tuple_delimiter = PROMPTS.get("DEFAULT_TUPLE_DELIMITER", "<|#|>")
+    completion_delimiter = PROMPTS.get("DEFAULT_COMPLETION_DELIMITER", "<|#|>END<|#|>")
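+    # Assumed response shape, reconstructed from the parsing logic below (delimiters come from PROMPTS):
+    #   <|#|>START<|#|>Question<|#|>...<|#|>Answer<|#|>...<|#|>NEXT<|#|>Question<|#|>...<|#|>END<|#|>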
+
+    qa_packs = []
+
+    try:
+        # Remove the completion delimiter if present
+        if completion_delimiter in response_text:
+            response_text = response_text.split(completion_delimiter)[0].strip()
+
+        # Remove the START delimiter if present
+        start_delimiter = tuple_delimiter + "START" + tuple_delimiter
+        if start_delimiter in response_text:
+            response_text = response_text.split(start_delimiter, 1)[-1].strip()
+
+        # Split by the NEXT delimiter to get individual packs
+        next_delimiter = tuple_delimiter + "NEXT" + tuple_delimiter
+        pack_texts = response_text.split(next_delimiter)
+
+        for pack_text in pack_texts:
+            pack_text = pack_text.strip()
+            if not pack_text:
+                continue
+
+            # Parse Question and Answer from the pack
+            # Expected format: Question<|#|><questions><|#|>Answer<|#|><answer>
+            if "Question" + tuple_delimiter in pack_text:
+                parts = pack_text.split(tuple_delimiter)
+
+                question = None
+                answer = None
+
+                for i, part in enumerate(parts):
+                    # strip() guards against stray whitespace around the delimiter tokens
+                    if part.strip().lower() == "question" and i + 1 < len(parts):
+                        question = parts[i + 1].strip()
+                    elif part.strip().lower() == "answer" and i + 1 < len(parts):
+                        answer = parts[i + 1].strip()
+
+                if question and answer:
+                    new_item = base_metadata.copy()
+                    new_item["question"] = question
+                    new_item["answer"] = answer
+                    new_item["reorganized"] = True
+                    qa_packs.append(new_item)
+
+        return qa_packs
+
+    except Exception as e:
+        logging.error(f"Error parsing reorganized packs: {e}")
+        return []
+
+def hierarchical_clustering(
+    data: List[Dict],
+    question_embeddings: torch.Tensor,
+    answer_embeddings: torch.Tensor
+) -> List[List[int]]:
+    """
+    Two-stage hierarchical clustering:
+    1. Cluster questions by semantic similarity (topic/intent)
+    2. Within each question cluster, sub-cluster answers by similarity
+
+    Returns: List of clusters (each cluster is a list of indices)
+    """
+    print(f"\n🔍 Stage 1: Clustering questions by topic (threshold: {QUESTION_SIMILARITY_THRESHOLD})...")
+
+    # Stage 1: Cluster questions to group by topic/intent
+    question_clusters = util.community_detection(
+        question_embeddings,
+        threshold=QUESTION_SIMILARITY_THRESHOLD,
+        min_community_size=1
+    )
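+    # With min_community_size=1, items with no near neighbour come back as singleton
+    # communities; they are skipped below and so pass through deduplication untouched.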
+
+    print(f"✅ Found {len(question_clusters)} question-based topic groups")
+
+    # Stage 2: Within each question cluster, sub-cluster by answer similarity
+    final_clusters = []
+    singleton_count = 0
+
+    print(f"\n🔍 Stage 2: Sub-clustering answers within each topic group...")
+
+    for q_cluster in tqdm(question_clusters, desc="Processing question clusters"):
+        if len(q_cluster) == 1:
+            singleton_count += 1
+            continue  # No duplicates possible
+
+        # Extract answer embeddings for this question cluster
+        q_cluster_list = list(q_cluster)
+        q_cluster_answer_embs = answer_embeddings[q_cluster_list]
+
+        # Check for chunk overlap to prioritize merging:
+        # build the pairwise chunk overlap matrix
+        chunk_overlap_matrix = np.zeros((len(q_cluster_list), len(q_cluster_list)))
+        for i, idx_i in enumerate(q_cluster_list):
+            for j, idx_j in enumerate(q_cluster_list):
+                if i != j:
+                    chunk_overlap_matrix[i, j] = compute_chunk_overlap(data[idx_i], data[idx_j])
+
+        # Compute the answer similarity matrix
+        answer_sim_matrix = util.pytorch_cos_sim(q_cluster_answer_embs, q_cluster_answer_embs)
+
+        # Combined similarity per Eq. 10: α * cos(e_ai, e_aj) + (1 - α) * J(C^s_i, C^s_j)
+        # α weights semantic similarity; (1 - α) weights chunk-lineage Jaccard overlap
+        combined_sim = ALPHA * answer_sim_matrix.cpu().numpy() + (1 - ALPHA) * chunk_overlap_matrix
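+        # e.g. with α = 0.6, a cosine of 0.80 and a Jaccard of 0.50 combine to 0.6*0.80 + 0.4*0.50 = 0.68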
+
+        # Find high-similarity pairs and group them
+        visited = set()
+        for i in range(len(q_cluster_list)):
+            if i in visited:
+                continue
+
+            cluster = [q_cluster_list[i]]
+            visited.add(i)
+
+            for j in range(i + 1, len(q_cluster_list)):
+                if j not in visited and combined_sim[i, j] >= ANSWER_SIMILARITY_LOW:
+                    cluster.append(q_cluster_list[j])
+                    visited.add(j)
+
+            if len(cluster) >= MIN_COMMUNITY_SIZE:
+                final_clusters.append(cluster)
+
+    print(f"✅ Found {len(final_clusters)} answer clusters requiring merge")
+    print(f"ℹ️ {singleton_count} singletons (unique QAs)")
+
+    return final_clusters
+
+def process_cluster_by_similarity(
+    cluster_items: List[Dict],
+    cluster_indices: List[int],
+    answer_embeddings: torch.Tensor,
+    llm_merger: LLMReranker,
+    expert_persona: str = "expert",  # fallback so callers without persona context still work
+    domain: str = "general",  # fallback so callers without domain context still work
+    enable_reorganization: bool = True
+) -> List[Dict]:
+    """
+    Process a cluster with stratified handling based on answer similarity.
+    - High similarity (>= 0.95): Exact duplicates → select best
+    - Medium similarity (0.85-0.95): Partial overlap → LLM merge → reorganize
+    - Low similarity (0.70-0.85): Related → LLM evaluate → reorganize
+
+    After LLM merging, optionally reorganizes into balanced QA packs.
+
+    Args:
+        cluster_items: List of QA dicts in the cluster
+        cluster_indices: Original indices of items in the cluster
+        answer_embeddings: Tensor of answer embeddings
+        llm_merger: LLMReranker instance for merging
+        expert_persona: Expert role for domain-specific handling
+        domain: Domain context for handling
+        enable_reorganization: Whether to reorganize after merging
+    """
+    if len(cluster_items) < 2:
+        return cluster_items
+
+    # Compute pairwise answer similarities within the cluster
+    cluster_answer_embs = answer_embeddings[cluster_indices]
+    sim_matrix = util.pytorch_cos_sim(cluster_answer_embs, cluster_answer_embs).cpu().numpy()
+
+    # Get max similarity (excluding the diagonal)
+    np.fill_diagonal(sim_matrix, 0)
+    max_similarity = np.max(sim_matrix)
+    avg_similarity = np.mean(sim_matrix[np.triu_indices_from(sim_matrix, k=1)])
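+    # triu_indices_from(k=1) selects each unordered pair once, so the mean is over distinct pairs only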
+
+    # Base metadata for propagation
+    base_metadata = cluster_items[0].copy()
+    base_metadata.pop('question', None)
+    base_metadata.pop('answer', None)
+    base_metadata['merged_from_count'] = len(cluster_items)
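+    # Note: merged items inherit metadata from the first cluster member; per-item lineage is collapsed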
+
+    # Stratified handling
+    if max_similarity >= ANSWER_SIMILARITY_HIGH:
+        # Tier 1: Exact duplicates - just pick the best one (no reorganization needed)
+        logging.info(f"Cluster of {len(cluster_items)} with max_sim={max_similarity:.3f}: Selecting best (exact duplicates)")
+        best = select_best_qa(cluster_items)
+        best['dedup_method'] = 'select_best'
+        best['merged_from_count'] = len(cluster_items)
+        return [best]
+
+    elif avg_similarity >= ANSWER_SIMILARITY_MEDIUM:
+        # Tier 2: High overlap - merge with the LLM, then reorganize
+        logging.info(f"Cluster of {len(cluster_items)} with avg_sim={avg_similarity:.3f}: LLM merge (high overlap)")
+        merged = llm_merger.deduplicate_and_merge(cluster_items)
+
+        # Reorganize into balanced packs if enabled and there are multiple items
+        if enable_reorganization and len(merged) > 1:
+            logging.info(f"Reorganizing {len(merged)} merged items into balanced packs")
+            merged = reorganize_qa_packs(merged, base_metadata, expert_persona=expert_persona, domain=domain)
+
+        for item in merged:
+            item['dedup_method'] = 'llm_merge_high'
+        return merged
+
+    else:
+        # Tier 3: Related but potentially distinct - let the LLM decide, then reorganize
+        logging.info(f"Cluster of {len(cluster_items)} with avg_sim={avg_similarity:.3f}: LLM evaluate (lower overlap)")
+        merged = llm_merger.deduplicate_and_merge(cluster_items)
+
+        # Reorganize into balanced packs if enabled and there are multiple items
+        if enable_reorganization and len(merged) > 1:
+            logging.info(f"Reorganizing {len(merged)} merged items into balanced packs")
+            merged = reorganize_qa_packs(merged, base_metadata, expert_persona=expert_persona, domain=domain)
+
+        for item in merged:
+            item['dedup_method'] = 'llm_merge_medium'
+        return merged
+
+def deduplicate_dataset():
+    setup_logging()
+
+    # 1. Load Data
+    if not os.path.exists(INPUT_FILE):
+        print(f"❌ Input file {INPUT_FILE} not found.")
+        return
+
+    data = load_dataset(INPUT_FILE)
+    if not data:
+        print("⚠️ Dataset is empty.")
+        return
+
+    print(f"\n{'='*80}")
+    print("🎯 HIERARCHICAL DEDUPLICATION: Questions → Answers")
+    print(f"{'='*80}\n")
+
+    # 2. Prepare separate embeddings for questions and answers
+    print("⚙️ Preparing text for embedding...")
+    questions = [item['question'] for item in data]
+    answers = [item['answer'] for item in data]
+
+    # 3. Load Embedding Model & Embed
+    model_name = get_best_embedding_model()
+    print(f"🤖 Loading embedding model: {model_name}")
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    embedder = SentenceTransformer(model_name, device=device)
+
+    print(f"📊 Generating question embeddings for {len(questions)} QA pairs...")
+    question_embeddings = embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)
+
+    print(f"📊 Generating answer embeddings for {len(answers)} QA pairs...")
+    answer_embeddings = embedder.encode(answers, convert_to_tensor=True, show_progress_bar=True)
+
+    # 4. Hierarchical Clustering: Questions first, then Answers
+    clusters = hierarchical_clustering(data, question_embeddings, answer_embeddings)
+
+    # 5. Track processed items
+    clustered_indices = set()
+    for cluster in clusters:
+        clustered_indices.update(cluster)
+
+    # Initialize the LLM merger
+    llm_merger = LLMReranker()
+    final_dataset = []
+
+    # Statistics
+    stats = {
+        'original': len(data),
+        'singletons': 0,
+        'clusters_processed': 0,
+        'items_in_clusters': 0,
+        'exact_duplicates': 0,
+        'llm_merges': 0,
+        'reorganized_packs': 0
+    }
+
+    # Add singletons (items not in any cluster)
+    for i in range(len(data)):
+        if i not in clustered_indices:
+            final_dataset.append(data[i])
+            stats['singletons'] += 1
+
+    print(f"\nℹ️ Added {stats['singletons']} unique (singleton) items.")
+
+    # 6. Process clusters with stratified handling
+    print(f"\n🔄 Processing {len(clusters)} clusters with stratified merge strategy...")
+
+    for cluster in tqdm(clusters, desc="Merging clusters"):
+        cluster_items = [data[idx] for idx in cluster]
+        stats['clusters_processed'] += 1
+        stats['items_in_clusters'] += len(cluster_items)
+
+        # Process with the stratified approach
+        merged_items = process_cluster_by_similarity(
+            cluster_items,
+            cluster,
+            answer_embeddings,
+            llm_merger
+        )
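+        # expert_persona and domain are left at the defaults declared on
+        # process_cluster_by_similarity; this stage carries no persona/domain context.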
+
+        # Track merge type
+        if merged_items and merged_items[0].get('dedup_method') == 'select_best':
+            stats['exact_duplicates'] += 1
+        else:
+            stats['llm_merges'] += 1
+
+        # Track reorganized packs
+        reorganized_count = sum(1 for item in merged_items if item.get('reorganized', False))
+        if reorganized_count > 0:
+            stats['reorganized_packs'] += reorganized_count
+
+        final_dataset.extend(merged_items)
+
+    # 7. Save Results
+    print("\n" + "="*80)
+    print("📊 HIERARCHICAL DEDUPLICATION SUMMARY")
+    print("="*80)
+    print(f"Original count: {stats['original']}")
+    print(f"Final count: {len(final_dataset)}")
+    print(f"Reduction: {stats['original'] - len(final_dataset)} ({100*(stats['original'] - len(final_dataset))/stats['original']:.1f}%)")
+    print("---")
+    print(f"Singleton items: {stats['singletons']}")
+    print(f"Clusters processed: {stats['clusters_processed']}")
+    print(f"Items in clusters: {stats['items_in_clusters']}")
+    print(f"Exact duplicates removed: {stats['exact_duplicates']}")
+    print(f"LLM merges performed: {stats['llm_merges']}")
+    print(f"Reorganized QA packs: {stats['reorganized_packs']}")
+    print("="*80)
+
+    save_dataset(final_dataset, OUTPUT_FILE)
+
+if __name__ == "__main__":
+    deduplicate_dataset()