mirage-benchmark 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mirage-benchmark might be problematic.
- mirage/__init__.py +83 -0
- mirage/cli.py +150 -0
- mirage/core/__init__.py +52 -0
- mirage/core/config.py +248 -0
- mirage/core/llm.py +1745 -0
- mirage/core/prompts.py +884 -0
- mirage/embeddings/__init__.py +31 -0
- mirage/embeddings/models.py +512 -0
- mirage/embeddings/rerankers_multimodal.py +766 -0
- mirage/embeddings/rerankers_text.py +149 -0
- mirage/evaluation/__init__.py +26 -0
- mirage/evaluation/metrics.py +2223 -0
- mirage/evaluation/metrics_optimized.py +2172 -0
- mirage/pipeline/__init__.py +45 -0
- mirage/pipeline/chunker.py +545 -0
- mirage/pipeline/context.py +1003 -0
- mirage/pipeline/deduplication.py +491 -0
- mirage/pipeline/domain.py +514 -0
- mirage/pipeline/pdf_processor.py +598 -0
- mirage/pipeline/qa_generator.py +798 -0
- mirage/utils/__init__.py +31 -0
- mirage/utils/ablation.py +360 -0
- mirage/utils/preflight.py +663 -0
- mirage/utils/stats.py +626 -0
- mirage_benchmark-1.0.4.dist-info/METADATA +490 -0
- mirage_benchmark-1.0.4.dist-info/RECORD +30 -0
- mirage_benchmark-1.0.4.dist-info/WHEEL +5 -0
- mirage_benchmark-1.0.4.dist-info/entry_points.txt +3 -0
- mirage_benchmark-1.0.4.dist-info/licenses/LICENSE +190 -0
- mirage_benchmark-1.0.4.dist-info/top_level.txt +1 -0
mirage/pipeline/deduplication.py (new file)
@@ -0,0 +1,491 @@

import json
import logging
import os
from typing import List, Dict

import numpy as np
import torch
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

# Import helper modules
from mirage.embeddings.models import get_best_embedding_model
from mirage.embeddings.rerankers_text import LLMReranker
from mirage.core.llm import setup_logging, call_vlm_with_multiple_images
from mirage.core.prompts import PROMPTS

# Try to load configuration from config.yaml
try:
    from mirage.core.config import load_config
    _cfg = load_config()
    _dedup = _cfg.get('deduplication', {})

    QUESTION_SIMILARITY_THRESHOLD = _dedup.get('question_similarity_threshold', 0.75)
    _ans_sim = _dedup.get('answer_similarity', {})
    ANSWER_SIMILARITY_HIGH = _ans_sim.get('high', 0.95)
    ANSWER_SIMILARITY_MEDIUM = _ans_sim.get('medium', 0.85)
    ANSWER_SIMILARITY_LOW = _ans_sim.get('low', 0.70)
    MIN_COMMUNITY_SIZE = _dedup.get('min_community_size', 2)
    # α parameter: weight for semantic similarity vs chunk lineage (Eq. 10 in manuscript)
    ALPHA = _dedup.get('alpha', 0.6)
    print(f"✅ Deduplication config loaded: α={ALPHA}, question_threshold={QUESTION_SIMILARITY_THRESHOLD}")
except ImportError:
    print("⚠️ mirage.core.config not available, using default deduplication configuration")
    QUESTION_SIMILARITY_THRESHOLD = 0.75
    ANSWER_SIMILARITY_HIGH = 0.95
    ANSWER_SIMILARITY_MEDIUM = 0.85
    ANSWER_SIMILARITY_LOW = 0.70
    MIN_COMMUNITY_SIZE = 2
    ALPHA = 0.6  # Default α for Eq. 10: α * semantic_sim + (1 - α) * jaccard

# Configuration
OUTPUT_DIR = "output"
INPUT_FILE = os.path.join(OUTPUT_DIR, "qa_multihop_pass.json")
OUTPUT_FILE = os.path.join(OUTPUT_DIR, "qa_dataset_deduplicated.json")
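
# Illustrative only: given the .get() fallbacks above, a config.yaml containing
# a section shaped like the following (keys inferred from this module, not
# shipped in this diff) would override the defaults via load_config():
#
#   deduplication:
#     question_similarity_threshold: 0.75
#     answer_similarity:
#       high: 0.95
#       medium: 0.85
#       low: 0.70
#     min_community_size: 2
#     alpha: 0.6
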
def load_dataset(filepath):
    print(f"📂 Loading dataset from {filepath}...")
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)

def save_dataset(data, filepath):
    print(f"💾 Saving {len(data)} items to {filepath}...")
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)

def compute_chunk_overlap(qa1: Dict, qa2: Dict) -> float:
    """
    Compute Jaccard similarity based on chunk lineage.
    Returns overlap ratio [0, 1] of chunks used to generate each QA.
    """
    # Extract chunk identifiers from chunks_added (list of dicts with file_name and chunk_id)
    def extract_chunk_ids(qa):
        ids = {qa.get('chunk_id', -1)}
        for chunk in qa.get('chunks_added', []):
            if isinstance(chunk, dict):
                ids.add((chunk.get('file_name', ''), chunk.get('chunk_id', '')))
            else:
                ids.add(chunk)
        return ids

    chunks1 = extract_chunk_ids(qa1)
    chunks2 = extract_chunk_ids(qa2)

    # Remove invalid IDs
    chunks1.discard(-1)
    chunks2.discard(-1)

    if not chunks1 or not chunks2:
        return 0.0

    intersection = len(chunks1 & chunks2)
    union = len(chunks1 | chunks2)

    return intersection / union if union > 0 else 0.0
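
# Worked example (illustrative values): if qa1 descends from chunks
# {('a.pdf', 1), ('a.pdf', 2)} and qa2 from {('a.pdf', 2), ('b.pdf', 5)},
# intersection = 1 and union = 3, so compute_chunk_overlap returns 1/3 ≈ 0.33.
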
def select_best_qa(cluster_items: List[Dict]) -> Dict:
    """
    Select the best QA from exact duplicates based on quality metrics.
    Prioritizes: relevance score > difficulty score > answer length
    """
    best = cluster_items[0]
    best_score = (
        float(best.get('relevance_score', 0)),
        float(best.get('difficulty_score', 0)),
        len(best.get('answer', ''))
    )

    for item in cluster_items[1:]:
        score = (
            float(item.get('relevance_score', 0)),
            float(item.get('difficulty_score', 0)),
            len(item.get('answer', ''))
        )
        if score > best_score:
            best = item
            best_score = score

    return best
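
# Tuple comparison is lexicographic, so relevance dominates: for example
# (0.9, 0.1, 40) > (0.8, 0.99, 5000); difficulty and answer length only break
# ties on the earlier fields.
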
def reorganize_qa_packs(merged_items: List[Dict], base_metadata: Dict,
                        expert_persona: str,
                        domain: str) -> List[Dict]:
    """
    Reorganize merged QA pairs into balanced question-answer packs.
    Groups related questions together while keeping packs balanced.

    Args:
        merged_items: List of merged QA dicts with 'question' and 'answer' keys
        base_metadata: Metadata to propagate to reorganized items
        expert_persona: Expert role for domain-specific organization
        domain: Domain context for organization

    Returns:
        List of reorganized QA dicts
    """
    if len(merged_items) <= 1:
        return merged_items

    # Prepare merged questions and answers for the prompt
    merged_questions = "\n".join(
        f"{i + 1}. {item['question']}"
        for i, item in enumerate(merged_items)
    )
    merged_answers = "\n".join(
        f"{i + 1}. {item['answer']}"
        for i, item in enumerate(merged_items)
    )

    prompt_template = PROMPTS.get("deduplication_reorganize", "")
    if not prompt_template:
        logging.warning("deduplication_reorganize prompt not found, skipping reorganization.")
        return merged_items

    # Format prompt with domain and expert role
    formatted_prompt = prompt_template.format(expert_persona=expert_persona, domain=domain)
    prompt = f"{formatted_prompt}\n\nInput:\nMerged Questions:\n{merged_questions}\n\nMerged Answers:\n{merged_answers}"

    try:
        # No images passed: the VLM helper is used here for text-only reorganization
        response = call_vlm_with_multiple_images(prompt, [])
        reorganized = parse_reorganized_packs(response, base_metadata)

        if not reorganized:
            logging.warning("LLM returned empty reorganization, keeping merged items.")
            return merged_items

        return reorganized

    except Exception as e:
        logging.error(f"Error in LLM reorganization: {e}")
        return merged_items
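
# For two merged items, the prompt assembled above ends with (illustrative):
#
#   Input:
#   Merged Questions:
#   1. <question A>
#   2. <question B>
#
#   Merged Answers:
#   1. <answer A>
#   2. <answer B>
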
def parse_reorganized_packs(response_text: str, base_metadata: Dict) -> List[Dict]:
    """
    Parse the LLM response containing reorganized QA packs.

    Returns:
        List of reorganized QA dicts
    """
    tuple_delimiter = PROMPTS.get("DEFAULT_TUPLE_DELIMITER", "<|#|>")
    completion_delimiter = PROMPTS.get("DEFAULT_COMPLETION_DELIMITER", "<|#|>END<|#|>")

    qa_packs = []

    try:
        # Remove completion delimiter if present
        if completion_delimiter in response_text:
            response_text = response_text.split(completion_delimiter)[0].strip()

        # Remove START delimiter if present
        start_delimiter = tuple_delimiter + "START" + tuple_delimiter
        if start_delimiter in response_text:
            response_text = response_text.split(start_delimiter, 1)[-1].strip()

        # Split by NEXT delimiter to get individual packs
        next_delimiter = tuple_delimiter + "NEXT" + tuple_delimiter
        pack_texts = response_text.split(next_delimiter)

        for pack_text in pack_texts:
            pack_text = pack_text.strip()
            if not pack_text:
                continue

            # Parse Question and Answer from the pack
            # Expected format: Question<|#|><questions><|#|>Answer<|#|><answer>
            if "Question" + tuple_delimiter in pack_text:
                parts = pack_text.split(tuple_delimiter)

                question = None
                answer = None

                for i, part in enumerate(parts):
                    if part.lower() == "question" and i + 1 < len(parts):
                        question = parts[i + 1].strip()
                    elif part.lower() == "answer" and i + 1 < len(parts):
                        answer = parts[i + 1].strip()

                if question and answer:
                    new_item = base_metadata.copy()
                    new_item["question"] = question
                    new_item["answer"] = answer
                    new_item["reorganized"] = True
                    qa_packs.append(new_item)

        return qa_packs

    except Exception as e:
        logging.error(f"Error parsing reorganized packs: {e}")
        return []
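
# Parsing example (illustrative): with the default "<|#|>" delimiters, a response
#   <|#|>START<|#|>Question<|#|>Q1?<|#|>Answer<|#|>A1<|#|>NEXT<|#|>Question<|#|>Q2?<|#|>Answer<|#|>A2<|#|>END<|#|>
# yields two dicts, each a copy of base_metadata plus its question/answer and
# reorganized=True.
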
def hierarchical_clustering(
    data: List[Dict],
    question_embeddings: torch.Tensor,
    answer_embeddings: torch.Tensor
) -> List[List[int]]:
    """
    Two-stage hierarchical clustering:
    1. Cluster questions by semantic similarity (topic/intent)
    2. Within each question cluster, sub-cluster answers by similarity

    Returns: List of clusters (each cluster is a list of indices)
    """
    print(f"\n🔍 Stage 1: Clustering questions by topic (threshold: {QUESTION_SIMILARITY_THRESHOLD})...")

    # Stage 1: Cluster questions to group by topic/intent
    question_clusters = util.community_detection(
        question_embeddings,
        threshold=QUESTION_SIMILARITY_THRESHOLD,
        min_community_size=1
    )

    print(f"✅ Found {len(question_clusters)} question-based topic groups")

    # Stage 2: Within each question cluster, sub-cluster by answer similarity
    final_clusters = []
    singleton_count = 0

    print("\n🔍 Stage 2: Sub-clustering answers within each topic group...")

    for q_cluster in tqdm(question_clusters, desc="Processing question clusters"):
        if len(q_cluster) == 1:
            singleton_count += 1
            continue  # No duplicates possible

        # Extract answer embeddings for this question cluster
        q_cluster_list = list(q_cluster)
        q_cluster_answer_embs = answer_embeddings[q_cluster_list]

        # Build the chunk overlap matrix so shared lineage can prioritize merging
        chunk_overlap_matrix = np.zeros((len(q_cluster_list), len(q_cluster_list)))
        for i, idx_i in enumerate(q_cluster_list):
            for j, idx_j in enumerate(q_cluster_list):
                if i != j:
                    chunk_overlap_matrix[i, j] = compute_chunk_overlap(data[idx_i], data[idx_j])

        # Compute answer similarity matrix
        answer_sim_matrix = util.pytorch_cos_sim(q_cluster_answer_embs, q_cluster_answer_embs)

        # Combined similarity per Eq. 10: α * cos(e_ai, e_aj) + (1 - α) * J(C^s_i, C^s_j)
        # α weights semantic similarity; (1 - α) weights chunk-lineage Jaccard overlap
        combined_sim = ALPHA * answer_sim_matrix.cpu().numpy() + (1 - ALPHA) * chunk_overlap_matrix

        # Greedily group high-similarity pairs
        visited = set()
        for i in range(len(q_cluster_list)):
            if i in visited:
                continue

            cluster = [q_cluster_list[i]]
            visited.add(i)

            for j in range(i + 1, len(q_cluster_list)):
                if j not in visited and combined_sim[i, j] >= ANSWER_SIMILARITY_LOW:
                    cluster.append(q_cluster_list[j])
                    visited.add(j)

            if len(cluster) >= MIN_COMMUNITY_SIZE:
                final_clusters.append(cluster)

    print(f"✅ Found {len(final_clusters)} answer clusters requiring merge")
    print(f"ℹ️ {singleton_count} singletons (unique QAs)")

    return final_clusters
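
# Worked example for Eq. 10 (illustrative numbers): with ALPHA = 0.6, an answer
# cosine of 0.80 and a chunk-lineage Jaccard of 0.50 combine to
# 0.6 * 0.80 + 0.4 * 0.50 = 0.68 < ANSWER_SIMILARITY_LOW (0.70), so the pair is
# left apart; a Jaccard of 0.55 instead gives exactly 0.70 and the pair is
# grouped for merging. Shared source chunks can thus tip borderline pairs.
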
def process_cluster_by_similarity(
    cluster_items: List[Dict],
    cluster_indices: List[int],
    answer_embeddings: torch.Tensor,
    llm_merger: LLMReranker,
    expert_persona: str,
    domain: str,
    enable_reorganization: bool = True
) -> List[Dict]:
    """
    Process a cluster with stratified handling based on answer similarity.
    - High (max similarity >= 0.95): exact duplicates → select best
    - Medium (avg similarity >= 0.85): partial overlap → LLM merge → reorganize
    - Low (0.70 <= combined similarity < 0.85): related → LLM evaluate → reorganize

    After LLM merging, optionally reorganizes into balanced QA packs.

    Args:
        cluster_items: List of QA dicts in the cluster
        cluster_indices: Original indices of items in the cluster
        answer_embeddings: Tensor of answer embeddings
        llm_merger: LLMReranker instance for merging
        expert_persona: Expert role for domain-specific handling
        domain: Domain context for handling
        enable_reorganization: Whether to reorganize after merging
    """
    if len(cluster_items) < 2:
        return cluster_items

    # Compute pairwise answer similarities within cluster
    cluster_answer_embs = answer_embeddings[cluster_indices]
    sim_matrix = util.pytorch_cos_sim(cluster_answer_embs, cluster_answer_embs).cpu().numpy()

    # Get max/avg similarity (excluding diagonal)
    np.fill_diagonal(sim_matrix, 0)
    max_similarity = np.max(sim_matrix)
    avg_similarity = np.mean(sim_matrix[np.triu_indices_from(sim_matrix, k=1)])

    # Base metadata for propagation
    base_metadata = cluster_items[0].copy()
    base_metadata.pop('question', None)
    base_metadata.pop('answer', None)
    base_metadata['merged_from_count'] = len(cluster_items)

    # Stratified handling
    if max_similarity >= ANSWER_SIMILARITY_HIGH:
        # Tier 1: Exact duplicates - just pick the best one (no reorganization needed)
        logging.info(f"Cluster of {len(cluster_items)} with max_sim={max_similarity:.3f}: Selecting best (exact duplicates)")
        best = select_best_qa(cluster_items)
        best['dedup_method'] = 'select_best'
        best['merged_from_count'] = len(cluster_items)
        return [best]

    elif avg_similarity >= ANSWER_SIMILARITY_MEDIUM:
        # Tier 2: High overlap - merge with LLM, then reorganize
        logging.info(f"Cluster of {len(cluster_items)} with avg_sim={avg_similarity:.3f}: LLM merge (high overlap)")
        merged = llm_merger.deduplicate_and_merge(cluster_items)

        # Reorganize into balanced packs if enabled and multiple items
        if enable_reorganization and len(merged) > 1:
            logging.info(f"Reorganizing {len(merged)} merged items into balanced packs")
            merged = reorganize_qa_packs(merged, base_metadata, expert_persona=expert_persona, domain=domain)

        for item in merged:
            item['dedup_method'] = 'llm_merge_high'
        return merged

    else:
        # Tier 3: Related but potentially distinct - let LLM decide, then reorganize
        logging.info(f"Cluster of {len(cluster_items)} with avg_sim={avg_similarity:.3f}: LLM evaluate (medium overlap)")
        merged = llm_merger.deduplicate_and_merge(cluster_items)

        # Reorganize into balanced packs if enabled and multiple items
        if enable_reorganization and len(merged) > 1:
            logging.info(f"Reorganizing {len(merged)} merged items into balanced packs")
            merged = reorganize_qa_packs(merged, base_metadata, expert_persona=expert_persona, domain=domain)

        for item in merged:
            item['dedup_method'] = 'llm_merge_medium'
        return merged
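
# Routing example (illustrative numbers): a 3-item cluster with pairwise answer
# cosines (0.97, 0.91, 0.90) has max_similarity = 0.97 >= 0.95 and collapses to
# the single best item; cosines (0.88, 0.87, 0.86) give max < 0.95 but
# avg = 0.87 >= 0.85, routing the cluster to the 'llm_merge_high' tier.
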
def deduplicate_dataset():
    setup_logging()

    # 1. Load Data
    if not os.path.exists(INPUT_FILE):
        print(f"❌ Input file {INPUT_FILE} not found.")
        return

    data = load_dataset(INPUT_FILE)
    if not data:
        print("⚠️ Dataset is empty.")
        return

    print(f"\n{'='*80}")
    print("🎯 HIERARCHICAL DEDUPLICATION: Questions → Answers")
    print(f"{'='*80}\n")

    # 2. Prepare separate embeddings for questions and answers
    print("⚙️ Preparing text for embedding...")
    questions = [item['question'] for item in data]
    answers = [item['answer'] for item in data]

    # 3. Load Embedding Model & Embed
    model_name = get_best_embedding_model()
    print(f"🤖 Loading embedding model: {model_name}")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    embedder = SentenceTransformer(model_name, device=device)

    print(f"📊 Generating question embeddings for {len(questions)} QA pairs...")
    question_embeddings = embedder.encode(questions, convert_to_tensor=True, show_progress_bar=True)

    print(f"📊 Generating answer embeddings for {len(answers)} QA pairs...")
    answer_embeddings = embedder.encode(answers, convert_to_tensor=True, show_progress_bar=True)

    # 4. Hierarchical Clustering: Questions first, then Answers
    clusters = hierarchical_clustering(data, question_embeddings, answer_embeddings)

    # 5. Track processed items
    clustered_indices = set()
    for cluster in clusters:
        clustered_indices.update(cluster)

    # Initialize LLM merger
    llm_merger = LLMReranker()
    final_dataset = []

    # Statistics
    stats = {
        'original': len(data),
        'singletons': 0,
        'clusters_processed': 0,
        'items_in_clusters': 0,
        'exact_duplicates': 0,
        'llm_merges': 0,
        'reorganized_packs': 0
    }

    # Add singletons (items not in any cluster)
    for i in range(len(data)):
        if i not in clustered_indices:
            final_dataset.append(data[i])
            stats['singletons'] += 1

    print(f"\nℹ️ Added {stats['singletons']} unique (singleton) items.")

    # 6. Process clusters with stratified handling
    print(f"\n🔄 Processing {len(clusters)} clusters with stratified merge strategy...")

    for cluster in tqdm(clusters, desc="Merging clusters"):
        cluster_items = [data[idx] for idx in cluster]
        stats['clusters_processed'] += 1
        stats['items_in_clusters'] += len(cluster_items)

        # Process with stratified approach. expert_persona/domain are required
        # parameters of process_cluster_by_similarity; they are filled from item
        # metadata when present (key names assumed), with generic fallbacks.
        merged_items = process_cluster_by_similarity(
            cluster_items,
            cluster,
            answer_embeddings,
            llm_merger,
            expert_persona=cluster_items[0].get('expert_persona', 'domain expert'),
            domain=cluster_items[0].get('domain', 'general')
        )

        # Track merge type
        if merged_items and merged_items[0].get('dedup_method') == 'select_best':
            stats['exact_duplicates'] += 1
        else:
            stats['llm_merges'] += 1

        # Track reorganized packs
        reorganized_count = sum(1 for item in merged_items if item.get('reorganized', False))
        if reorganized_count > 0:
            stats['reorganized_packs'] += reorganized_count

        final_dataset.extend(merged_items)

    # 7. Save Results
    print("\n" + "=" * 80)
    print("📊 HIERARCHICAL DEDUPLICATION SUMMARY")
    print("=" * 80)
    print(f"Original count: {stats['original']}")
    print(f"Final count: {len(final_dataset)}")
    print(f"Reduction: {stats['original'] - len(final_dataset)} ({100 * (stats['original'] - len(final_dataset)) / stats['original']:.1f}%)")
    print("---")
    print(f"Singleton items: {stats['singletons']}")
    print(f"Clusters processed: {stats['clusters_processed']}")
    print(f"Items in clusters: {stats['items_in_clusters']}")
    print(f"Exact duplicates removed: {stats['exact_duplicates']}")
    print(f"LLM merges performed: {stats['llm_merges']}")
    print(f"Reorganized QA packs: {stats['reorganized_packs']}")
    print("=" * 80)

    save_dataset(final_dataset, OUTPUT_FILE)

if __name__ == "__main__":
    deduplicate_dataset()
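Taken together, the module is a self-contained pipeline stage: run as a script (for example via python -m mirage.pipeline.deduplication, assuming the package layout shown in RECORD above), it reads output/qa_multihop_pass.json, clusters and merges near-duplicate QA pairs, and writes output/qa_dataset_deduplicated.json. The package also exposes a console entry point (entry_points.txt) through mirage/cli.py.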