mirage-benchmark 1.0.4 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mirage-benchmark might be problematic.
- mirage/__init__.py +83 -0
- mirage/cli.py +150 -0
- mirage/core/__init__.py +52 -0
- mirage/core/config.py +248 -0
- mirage/core/llm.py +1745 -0
- mirage/core/prompts.py +884 -0
- mirage/embeddings/__init__.py +31 -0
- mirage/embeddings/models.py +512 -0
- mirage/embeddings/rerankers_multimodal.py +766 -0
- mirage/embeddings/rerankers_text.py +149 -0
- mirage/evaluation/__init__.py +26 -0
- mirage/evaluation/metrics.py +2223 -0
- mirage/evaluation/metrics_optimized.py +2172 -0
- mirage/pipeline/__init__.py +45 -0
- mirage/pipeline/chunker.py +545 -0
- mirage/pipeline/context.py +1003 -0
- mirage/pipeline/deduplication.py +491 -0
- mirage/pipeline/domain.py +514 -0
- mirage/pipeline/pdf_processor.py +598 -0
- mirage/pipeline/qa_generator.py +798 -0
- mirage/utils/__init__.py +31 -0
- mirage/utils/ablation.py +360 -0
- mirage/utils/preflight.py +663 -0
- mirage/utils/stats.py +626 -0
- mirage_benchmark-1.0.4.dist-info/METADATA +490 -0
- mirage_benchmark-1.0.4.dist-info/RECORD +30 -0
- mirage_benchmark-1.0.4.dist-info/WHEEL +5 -0
- mirage_benchmark-1.0.4.dist-info/entry_points.txt +3 -0
- mirage_benchmark-1.0.4.dist-info/licenses/LICENSE +190 -0
- mirage_benchmark-1.0.4.dist-info/top_level.txt +1 -0
mirage/pipeline/qa_generator.py
@@ -0,0 +1,798 @@
#!/usr/bin/env python3
"""
QA Generation from Multi-hop Context Completion

Pipeline for each chunk:
1. Load chunk from INPUT_CHUNKS_FILE
2. Build multihop context using retrieval (adds related chunks)
3. Generate Q&A pairs from the complete multihop context
4. Select best Q&A pairs using selection agent
5. Verify Q&A pairs require the context to answer
6. Save successful and failed Q&A pairs
"""

import json
import re
import logging
import os
from typing import Dict, List
from tqdm import tqdm
from mirage.core.llm import call_vlm_interweaved, setup_logging, batch_call_vlm_interweaved
from mirage.pipeline.context import build_complete_context
from mirage.pipeline.domain import fetch_domain_and_role, load_domain_expert_from_env, save_domain_expert_to_env
from mirage.core.prompts import PROMPTS, PROMPTS_CHUNK

# Configuration (override via config.yaml or command line)
INPUT_CHUNKS_FILE = "output/results/chunks.json"
OUTPUT_SUCCESSFUL = "qa_multihop_pass.json"
OUTPUT_FAILED = "qa_multihop_fail.json"
OUTPUT_IRRELEVANT = "irrelevant_chunk.json"
MAX_CHUNKS = None  # Process all chunks (set to integer for testing, e.g., 100)
CHUNK_ADDITION_MODE = "RELATED"  # EXPLANATORY (only direct answers) or RELATED (both)

def call_ai_service(prompt: str, chunks: List[Dict]) -> str:
    """Unified call to VLM using interleaved chunks"""
    return call_vlm_interweaved(prompt, chunks)


def check_chunk_relevance(chunk_content: str, expert_persona: str, domain: str) -> bool:
    """Check if a chunk is relevant to the expert role and domain. Returns True if relevant, False otherwise."""
    print(f"Checking chunk relevance for {expert_persona} in {domain}...")

    # Format prompt - domain should never be None at this point
    domain_context = domain if domain else "unspecified domain"

    prompt = PROMPTS_CHUNK["relevance_check"].format(
        expert_persona=expert_persona,
        domain=domain_context,
        content=chunk_content[:2000]  # Limit content length to avoid token limits
    )

    # Use a simple text-only LLM call for relevance check (no images needed)
    from call_llm import call_llm
    response = call_llm(prompt)

    # Parse response - should be "RELEVANT" or "NOT_RELEVANT"
    response_upper = response.strip().upper()
    is_relevant = "RELEVANT" in response_upper and "NOT_RELEVANT" not in response_upper

    if is_relevant:
        print("  Chunk is RELEVANT")
    else:
        print("  Chunk is NOT_RELEVANT")

    return is_relevant

def generate_qa(chunks: List[Dict], expert_persona: str, domain: str) -> list:
    """Generate one or more Q&A pairs from multihop context chunks using consolidated prompt"""
    print(f"Generating Q&A pairs from {len(chunks)} context chunks...")

    # Format prompt with or without domain
    if domain:
        domain_context = f" in the field of {domain}"
        domain_relevance = f" ({domain})"
    else:
        domain_context = ""
        domain_relevance = ""

    prompt = PROMPTS["question_answer_generation"].format(
        content="[Refer to the chunks provided below]",
        expert_persona=expert_persona,
        domain_context=domain_context,
        domain_relevance=domain_relevance
    )

    response = call_ai_service(prompt, chunks)

    # Parse multiple Q&A pairs from response using delimiter format
    qa_pairs = []
    tuple_delimiter = PROMPTS.get("DEFAULT_TUPLE_DELIMITER", "<|#|>")
    completion_delimiter = PROMPTS.get("DEFAULT_COMPLETION_DELIMITER", "<|#|>END<|#|>")

    try:
        # Remove completion delimiter if present
        if completion_delimiter in response:
            response = response.split(completion_delimiter)[0].strip()

        # Remove START delimiter if present at the beginning
        start_delimiter = tuple_delimiter + "START" + tuple_delimiter
        if response.startswith(start_delimiter):
            response = response[len(start_delimiter):].strip()

        # Split by lines and process each Q&A pair
        lines = [line.strip() for line in response.split('\n') if line.strip()]

        for line in lines:
            # Skip NEXT delimiter lines
            next_delimiter = tuple_delimiter + "NEXT" + tuple_delimiter
            if line == next_delimiter:
                continue

            # Check if line starts with "Question" delimiter pattern (case-insensitive)
            if line.startswith("Question" + tuple_delimiter) or line.startswith("question" + tuple_delimiter):
                # Split by tuple delimiter
                parts = line.split(tuple_delimiter)

                # Expected format: Question<|#|><question_text><|#|>Answer<|#|><answer_text><|#|>Relevance<|#|><score><|#|>Difficulty<|#|><score>
                # Handle both capitalized and lowercase versions
                if len(parts) >= 4 and parts[0].lower() == "question" and parts[2].lower() == "answer":
                    question = parts[1].strip()
                    answer = parts[3].strip()
                    relevance = "0"
                    difficulty = "0"

                    # Try to extract scores if present
                    if len(parts) >= 8:
                        if parts[4].lower() == "relevance":
                            relevance = parts[5].strip()
                        if parts[6].lower() == "difficulty":
                            difficulty = parts[7].strip()

                    if question and answer:
                        qa_pairs.append({
                            "question": question,
                            "answer": answer,
                            "relevance_score": relevance,
                            "difficulty_score": difficulty
                        })

        if not qa_pairs:
            # Fallback: try old format patterns for backward compatibility
            # Try delimiter format with case-insensitive matching
            for line in lines:
                if re.match(r'(?i)^question', line):
                    parts = re.split(r'(?i)question' + re.escape(tuple_delimiter), line, maxsplit=1)
                    if len(parts) >= 2:
                        qa_content = parts[1]
                        qa_parts = qa_content.split(tuple_delimiter)
                        if len(qa_parts) >= 3 and qa_parts[1].lower() == "answer":
                            question = qa_parts[0].strip()
                            answer = qa_parts[2].strip()
                            if question and answer:
                                qa_pairs.append({
                                    "question": question,
                                    "answer": answer,
                                    "relevance_score": "0",
                                    "difficulty_score": "0"
                                })

        # Final fallback: try old Question:/Answer: format
        if not qa_pairs:
            question_matches = re.finditer(r'(?i)Question:\s*(.*?)(?=\nAnswer:|\n\n|$)', response, re.DOTALL)
            answer_matches = re.finditer(r'(?i)Answer:\s*(.*?)(?=\nQuestion:|\n\n|$)', response, re.DOTALL)

            questions = [m.group(1).strip() for m in question_matches]
            answers = [m.group(1).strip() for m in answer_matches]

            if questions and answers and len(questions) == len(answers):
                for q, a in zip(questions, answers):
                    qa_pairs.append({
                        "question": q,
                        "answer": a,
                        "relevance_score": "0",
                        "difficulty_score": "0"
                    })
            elif questions and answers:
                # Mismatched counts, try to pair them
                for i, q in enumerate(questions):
                    if i < len(answers):
                        qa_pairs.append({
                            "question": q,
                            "answer": answers[i],
                            "relevance_score": "0",
                            "difficulty_score": "0"
                        })

        if not qa_pairs:
            print("Could not parse Q&A from response")
            qa_pairs.append({
                "question": response,
                "answer": "",
                "relevance_score": "0",
                "difficulty_score": "0"
            })

        print(f"Generated {len(qa_pairs)} Q&A pair(s)")
        return qa_pairs

    except Exception as e:
        print(f"Error parsing Q&A: {e}")
        import traceback
        print(traceback.format_exc())
        return [{
            "question": response,
            "answer": "",
            "relevance_score": "0",
            "difficulty_score": "0"
        }]

def select_qa_pairs(qa_pairs: list, chunks: List[Dict], expert_persona: str, domain: str) -> tuple[list, list]:
    """Select/filter Q&A pairs using the selection agent. Returns (selected, rejected).
    Uses batch processing when multiple QA pairs need to be evaluated.
    """
    print(f"Selecting Q&A pairs ({len(qa_pairs)} candidates)...")

    if not qa_pairs:
        return [], []

    tuple_delimiter = PROMPTS.get("DEFAULT_TUPLE_DELIMITER", "<|#|>")

    # Format domain context
    if domain:
        domain_context = f" in the field of {domain}"
        domain_relevance = f" ({domain})"
    else:
        domain_context = ""
        domain_relevance = ""

    # Prepare batch requests
    requests = []
    for qa_pair in qa_pairs:
        prompt = PROMPTS["question_answer_selection"].format(
            content="[Refer to the chunks provided below]",
            question=qa_pair["question"],
            answer=qa_pair["answer"],
            expert_persona=expert_persona,
            domain_context=domain_context,
            domain_relevance=domain_relevance
        )
        requests.append((prompt, chunks))

    # Execute batch or sequential based on count
    if len(requests) > 1:
        print(f"  Batch evaluating {len(requests)} QA pairs...")
        responses = batch_call_vlm_interweaved(requests, show_progress=False)
    else:
        responses = [call_ai_service(requests[0][0], chunks)]

    # Process responses
    selected_pairs = []
    rejected_pairs = []

    for idx, (qa_pair, response) in enumerate(zip(qa_pairs, responses), 1):
        try:
            if response and not response.startswith("ERROR:"):
                parts = response.split(tuple_delimiter)
                status = "REJECTED"
                relevance = "0"
                difficulty = "0"
                reason = "No reason provided"

                for i in range(0, len(parts), 2):
                    if i+1 < len(parts):
                        key = parts[i].strip().lower()
                        value = parts[i+1].strip()

                        if key == "status":
                            status = value.upper()
                        elif key == "relevance":
                            relevance = value
                        elif key == "difficulty":
                            difficulty = value
                        elif key == "reason":
                            reason = value

                qa_pair["relevance_score"] = relevance
                qa_pair["difficulty_score"] = difficulty
                qa_pair["selection_reason"] = reason
                qa_pair["selection_status"] = status

                if status == "SELECTED":
                    selected_pairs.append(qa_pair)
                    print(f"  Q{idx} SELECTED (R:{relevance}/D:{difficulty})")
                else:
                    rejected_pairs.append(qa_pair)
                    print(f"  Q{idx} REJECTED: {reason[:60]}...")
            else:
                qa_pair["selection_status"] = "ERROR"
                qa_pair["selection_reason"] = f"API Error: {response}"
                rejected_pairs.append(qa_pair)

        except Exception as e:
            print(f"  Error parsing selection response for Q{idx}: {e}")
            qa_pair["selection_status"] = "ERROR"
            qa_pair["selection_reason"] = f"Parsing error: {str(e)}"
            rejected_pairs.append(qa_pair)

    print(f"Selection complete: {len(selected_pairs)} selected, {len(rejected_pairs)} rejected")
    return selected_pairs, rejected_pairs

def verify_qa(chunks: List[Dict], question: str, answer: str, expert_persona: str, domain: str) -> str:
    """Verify if the question requires the content to be answered"""
    print("Verifying Q&A pair...")

    # Format prompt with or without domain
    if domain:
        domain_context = f" in the field of {domain}"
    else:
        domain_context = ""

    prompt = PROMPTS["question_answer_verification"].format(
        content="[Refer to the chunks provided below]",
        question=question,
        answer=answer,
        expert_persona=expert_persona,
        domain_context=domain_context
    )
    response = call_ai_service(prompt, chunks)
    return response


def batch_verify_qa(chunks: List[Dict], qa_pairs: List[Dict], expert_persona: str, domain: str) -> List[str]:
    """Batch verify multiple Q&A pairs using concurrent API calls

    Args:
        chunks: Context chunks for verification
        qa_pairs: List of dicts with 'question' and 'answer' keys
        expert_persona: Expert role string
        domain: Domain string

    Returns:
        List of verification result strings in same order as qa_pairs
    """
    if not qa_pairs:
        return []

    # Format domain context
    if domain:
        domain_context = f" in the field of {domain}"
    else:
        domain_context = ""

    # Prepare batch requests
    requests = []
    for qa in qa_pairs:
        prompt = PROMPTS["question_answer_verification"].format(
            content="[Refer to the chunks provided below]",
            question=qa['question'],
            answer=qa['answer'],
            expert_persona=expert_persona,
            domain_context=domain_context
        )
        requests.append((prompt, chunks))

    # Execute batch
    if len(requests) > 1:
        print(f"  Batch verifying {len(requests)} QA pairs...")
        responses = batch_call_vlm_interweaved(requests, show_progress=False)
    else:
        responses = [call_ai_service(requests[0][0], chunks)]

    return responses

def process_chunk_for_qa(chunk_data: Dict, expert_persona: str, domain: str) -> Dict:
    """Complete pipeline: build context, generate Q&A using pre-extracted role, verify"""
    chunk_id = chunk_data.get('chunk_id', 'unknown')
    chunk_content = chunk_data.get('content', '')

    print(f"\n{'='*80}")
    print(f"Processing chunk {chunk_id}")
    print(f"{'='*80}")

    # Stage 1: Build complete context using multi-hop retrieval
    print("\nStage 1: Building complete context...")
    context_result = build_complete_context(
        initial_chunk=chunk_data,
        max_depth=3,
        max_breadth=3,
        chunks_per_search=2,
        expert_persona=expert_persona,
        domain=domain,
        chunk_addition_mode=CHUNK_ADDITION_MODE
    )

    final_context = context_result['context']
    context_chunks = context_result.get('chunks', [])
    completion_status = context_result['status']

    print(f"Context status: {completion_status}")
    print(f"Final context length: {len(final_context)} chars")
    print(f"Multihop context: {len(context_chunks)} chunks (original + retrieved)")

    # Stage 2: Use pre-extracted expert role and domain (from BERTopic analysis)
    print(f"\nUsing expert role: {expert_persona}")
    print(f"Using domain: {domain}")

    # Stage 3: Generate Q&A pairs from multihop context
    print(f"\nStage 3: Generating Q&A pairs from multihop context ({len(context_chunks)} chunks)...")
    qa_pairs = generate_qa(context_chunks, expert_persona, domain)

    # Stage 3.5: Select Q&A pairs
    print("\nStage 3.5: Selecting Q&A pairs...")
    selected_pairs, rejected_pairs = select_qa_pairs(qa_pairs, context_chunks, expert_persona, domain)

    # Stage 4: Verify selected Q&A pairs - BATCH PROCESSING
    verified_qa_pairs = []
    print(f"\nStage 4: Verifying {len(selected_pairs)} selected Q&A pair(s)...")

    if selected_pairs:
        # Use batch verification
        verification_results = batch_verify_qa(
            context_chunks, selected_pairs, expert_persona, domain
        )

        for i, (qa_pair, verification_result) in enumerate(zip(selected_pairs, verification_results), 1):
            question = qa_pair["question"]
            answer = qa_pair["answer"]
            print(f"\n  Q&A {i}: {question[:80]}...")
            print(f"  Verification: {verification_result[:100] if verification_result else 'ERROR'}...")

            verified_qa_pairs.append({
                "question": question,
                "answer": answer,
                "relevance_score": qa_pair.get("relevance_score", "0"),
                "difficulty_score": qa_pair.get("difficulty_score", "0"),
                "selection_status": qa_pair.get("selection_status", "SELECTED"),
                "selection_reason": qa_pair.get("selection_reason", ""),
                "verification_result": verification_result or "ERROR"
            })

    return {
        "chunk_id": chunk_id,
        "original_chunk": chunk_content,
        "final_context": final_context,
        "context_chunks": context_chunks,  # Full chunks with image_path for multimodal eval
        "context_status": completion_status,
        "depth_reached": context_result['depth'],
        "chunks_added": context_result['chunks_added'],
        "expert_persona": expert_persona,
        "domain": domain,
        "selected_qa_pairs": verified_qa_pairs,
        "rejected_qa_pairs": rejected_pairs
    }

def is_verification_successful(verification_result: str) -> bool:
    """Check if verification indicates success"""
    required_good = ["QUESTION_CORRECT", "ANSWER_CORRECT", "REQUIRES_CONTENT"]
    bad_values = ["QUESTION_INCORRECT", "ANSWER_INCORRECT", "CAN_ANSWER_WITHOUT_CONTENT"]

    has_bad = any(bad in verification_result for bad in bad_values)
    has_all_good = all(good in verification_result for good in required_good)

    return has_all_good and not has_bad

def correct_failed_qa(chunks: List[Dict], failed_qa_pairs: List[Dict],
                      expert_persona: str, domain: str) -> List[Dict]:
    """Generate corrected QA pairs from failed QA pairs using verification feedback.

    Args:
        chunks: Context chunks for QA generation
        failed_qa_pairs: List of dicts with 'question', 'answer', 'verification_result'
        expert_persona: Expert role string
        domain: Domain string

    Returns:
        List of corrected QA pair dicts with 'question', 'answer', 'relevance_score', 'difficulty_score'
    """
    if not failed_qa_pairs:
        return []

    print(f"Correcting {len(failed_qa_pairs)} failed Q&A pair(s)...")

    # Format the failed QA feedback section
    failed_qa_feedback_parts = []
    for i, qa in enumerate(failed_qa_pairs, 1):
        question = qa.get('question', '')
        answer = qa.get('answer', '')
        verification = qa.get('verification_result', 'No verification feedback available')

        failed_qa_feedback_parts.append(
            f"--- Failed QA #{i} ---\n"
            f"Question: {question}\n"
            f"Answer: {answer}\n"
            f"Verification Feedback: {verification}\n"
        )

    failed_qa_feedback = "\n".join(failed_qa_feedback_parts)

    # Format domain context
    if domain:
        domain_context = f" in the field of {domain}"
    else:
        domain_context = ""

    prompt = PROMPTS["question_answer_generation_corrected"].format(
        content="[Refer to the chunks provided below]",
        expert_persona=expert_persona,
        domain_context=domain_context,
        failed_qa_feedback=failed_qa_feedback
    )

    response = call_ai_service(prompt, chunks)

    # Parse corrected QA pairs (same format as generate_qa)
    corrected_pairs = []
    tuple_delimiter = PROMPTS.get("DEFAULT_TUPLE_DELIMITER", "<|#|>")
    completion_delimiter = PROMPTS.get("DEFAULT_COMPLETION_DELIMITER", "<|#|>END<|#|>")

    try:
        # Remove completion delimiter if present
        if completion_delimiter in response:
            response = response.split(completion_delimiter)[0].strip()

        # Remove START delimiter if present
        start_delimiter = tuple_delimiter + "START" + tuple_delimiter
        if response.startswith(start_delimiter):
            response = response[len(start_delimiter):].strip()

        # Check for empty response (no valid correction possible)
        if not response.strip() or response.strip() == tuple_delimiter + "START" + tuple_delimiter:
            print("  No correction possible - content doesn't support the original topic")
            return []

        # Split by NEXT delimiter for multiple QA pairs
        next_delimiter = tuple_delimiter + "NEXT" + tuple_delimiter
        qa_sections = response.split(next_delimiter)

        for section in qa_sections:
            section = section.strip()
            if not section:
                continue

            # Parse Question, Answer, Relevance, Difficulty
            parts = section.split(tuple_delimiter)
            qa_dict = {}

            for j in range(0, len(parts) - 1, 2):
                key = parts[j].strip()
                value = parts[j + 1].strip() if j + 1 < len(parts) else ""

                if key == "Question":
                    qa_dict["question"] = value
                elif key == "Answer":
                    qa_dict["answer"] = value
                elif key == "Relevance":
                    qa_dict["relevance_score"] = value
                elif key == "Difficulty":
                    qa_dict["difficulty_score"] = value

            if qa_dict.get("question") and qa_dict.get("answer"):
                qa_dict["correction_status"] = "CORRECTED"
                corrected_pairs.append(qa_dict)
                print(f"  Corrected: {qa_dict['question'][:60]}...")

    except Exception as e:
        logging.error(f"Error parsing corrected QA response: {e}")
        print(f"  Error parsing correction response: {e}")

    print(f"  Generated {len(corrected_pairs)} corrected Q&A pair(s)")
    return corrected_pairs

if __name__ == "__main__":
    setup_logging()

    # Load chunks
    print(f"Loading chunks from {INPUT_CHUNKS_FILE}...")
    with open(INPUT_CHUNKS_FILE, 'r') as f:
        chunks = json.load(f)

    # Limit to first N chunks if MAX_CHUNKS is set
    if MAX_CHUNKS is not None:
        chunks = chunks[:MAX_CHUNKS]
        print(f"Processing {len(chunks)} chunks (limited to {MAX_CHUNKS} for testing)...")
    else:
        print(f"Processing all {len(chunks)} chunks...")

    # Extract domain and expert role once for all chunks using BERTopic
    print(f"\n{'='*80}")

    # Priority 1: Check config.yaml
    domain, expert_persona = None, None
    try:
        from config_loader import get_domain_expert_config
        domain_config = get_domain_expert_config()
        config_domain = domain_config.get('domain')
        config_persona = domain_config.get('expert_persona')

        if config_domain and config_persona:
            print("Using domain and expert persona from config.yaml")
            domain, expert_persona = config_domain, config_persona
            save_domain_expert_to_env(domain, expert_persona)
    except ImportError:
        pass

    # Priority 2: Check environment variables
    if not domain or not expert_persona:
        env_domain, env_persona = load_domain_expert_from_env()
        if env_domain and env_persona:
            domain, expert_persona = env_domain, env_persona

    # Priority 3: Auto-detect using BERTopic
    if not domain or not expert_persona:
        print("Auto-detecting domain and expert persona from corpus...")
        domain, expert_persona = fetch_domain_and_role(INPUT_CHUNKS_FILE)
        # Note: fetch_domain_and_role already saves to environment variables

    print(f"Domain: {domain}")
    print(f"Expert Role: {expert_persona}")
    print(f"{'='*80}\n")

    # Initialize result containers
    successful_qa_pairs = []
    failed_qa_pairs = []
    irrelevant_chunks = []

    # Process each chunk: relevance check -> build multihop context -> generate QA -> select -> verify
    for i, chunk in tqdm(enumerate(chunks, 1), total=len(chunks), desc="Processing chunks"):
        tqdm.write(f"\n{'='*80}")
        tqdm.write(f"Processing Chunk {i}/{len(chunks)}")
        tqdm.write("Pipeline: Multihop Context → QA Generation → Selection → Verification")
        tqdm.write(f"{'='*80}")

        try:
            # Extract chunk content and metadata
            if isinstance(chunk, dict):
                chunk_content = chunk.get("content", str(chunk))
                source_document = chunk.get("file_name", "unknown")
                chunk_id = chunk.get("chunk_id", str(i))
            else:
                chunk_content = str(chunk)
                source_document = "unknown"
                chunk_id = str(i)

            # Stage 0: Check chunk relevance
            print("\nStage 0: Checking chunk relevance...")
            is_relevant = check_chunk_relevance(chunk_content, expert_persona, domain)

            if not is_relevant:
                # Store irrelevant chunk and skip processing
                irrelevant_chunks.append({
                    "chunk_id": chunk_id,
                    "source_document": source_document
                })
                tqdm.write(f"Chunk {i} is NOT_RELEVANT - skipping processing")
                continue

            # Prepare chunk data
            if isinstance(chunk, dict):
                chunk_data = chunk
            else:
                chunk_data = {
                    "content": str(chunk),
                    "chunk_id": str(i),
                    "file_name": "unknown",
                    "artifact": "None"
                }

            # Process chunk with pre-extracted domain and role
            result = process_chunk_for_qa(chunk_data, expert_persona, domain)

            # Process selected Q&A pairs
            successful_count = 0
            for qa_pair in result.get("selected_qa_pairs", []):
                if is_verification_successful(qa_pair.get("verification_result", "")):
                    successful_count += 1
                    # Create individual entry for each successful Q&A pair
                    successful_qa_pairs.append({
                        "chunk_id": result["chunk_id"],
                        "original_chunk": result["original_chunk"],
                        "final_context": result["final_context"],
                        "context_chunks": result.get("context_chunks", []),  # Full chunks with image_path
                        "context_status": result["context_status"],
                        "depth_reached": result["depth_reached"],
                        "chunks_added": result["chunks_added"],
                        "expert_persona": result["expert_persona"],
                        "domain": result.get("domain", ""),
                        "question": qa_pair["question"],
                        "answer": qa_pair["answer"],
                        "relevance_score": qa_pair.get("relevance_score", "0"),
                        "difficulty_score": qa_pair.get("difficulty_score", "0"),
                        "selection_status": qa_pair.get("selection_status", "SELECTED"),
                        "selection_reason": qa_pair.get("selection_reason", ""),
                        "verification_result": qa_pair["verification_result"]
                    })
                else:
                    # Create individual entry for each failed verification
                    failed_qa_pairs.append({
                        "chunk_id": result["chunk_id"],
                        "original_chunk": result["original_chunk"],
                        "final_context": result["final_context"],
                        "context_chunks": result.get("context_chunks", []),  # Full chunks with image_path
                        "context_status": result["context_status"],
                        "depth_reached": result["depth_reached"],
                        "chunks_added": result["chunks_added"],
                        "expert_persona": result["expert_persona"],
                        "domain": result.get("domain", ""),
                        "question": qa_pair["question"],
                        "answer": qa_pair["answer"],
                        "relevance_score": qa_pair.get("relevance_score", "0"),
                        "difficulty_score": qa_pair.get("difficulty_score", "0"),
                        "selection_status": qa_pair.get("selection_status", "SELECTED"),
                        "selection_reason": qa_pair.get("selection_reason", ""),
                        "verification_result": qa_pair["verification_result"],
                        "failure_reason": "Failed verification"
                    })

            # Add rejected Q&A pairs to failed list
            for qa_pair in result.get("rejected_qa_pairs", []):
                failed_qa_pairs.append({
                    "chunk_id": result["chunk_id"],
                    "original_chunk": result["original_chunk"],
                    "final_context": result["final_context"],
                    "context_chunks": result.get("context_chunks", []),  # Full chunks with image_path
                    "context_status": result["context_status"],
                    "depth_reached": result["depth_reached"],
                    "chunks_added": result["chunks_added"],
                    "expert_persona": result["expert_persona"],
                    "domain": result.get("domain", ""),
                    "question": qa_pair["question"],
                    "answer": qa_pair["answer"],
                    "relevance_score": qa_pair.get("relevance_score", "0"),
                    "difficulty_score": qa_pair.get("difficulty_score", "0"),
                    "selection_status": qa_pair.get("selection_status", "REJECTED"),
                    "selection_reason": qa_pair.get("selection_reason", ""),
                    "verification_result": "N/A - rejected by selection agent",
                    "failure_reason": "Rejected by selection agent"
                })

            total_qa = len(result.get("selected_qa_pairs", [])) + len(result.get("rejected_qa_pairs", []))
            if successful_count > 0:
                tqdm.write(f"{successful_count}/{total_qa} Q&A pair(s) passed all stages")
            else:
                tqdm.write(f"0/{total_qa} Q&A pairs passed all stages")

        except Exception as e:
            error_msg = f"Error processing chunk {i}: {e}"
            tqdm.write(error_msg)
            logging.error(error_msg)

            import traceback
            logging.error(traceback.format_exc())

            failed_qa_pairs.append({
                "chunk_id": i,
                "chunk_content": chunk_content if 'chunk_content' in locals() else str(chunk),
                "error": str(e),
                "traceback": traceback.format_exc()
            })
            continue

    # Save results
    print(f"\n{'='*80}")
    print("Saving results...")
    print(f"{'='*80}")

    if successful_qa_pairs:
        with open(OUTPUT_SUCCESSFUL, 'w', encoding='utf-8') as f:
            json.dump(successful_qa_pairs, f, indent=2, ensure_ascii=False)
        print(f"Successful QA pairs: {len(successful_qa_pairs)} saved to {OUTPUT_SUCCESSFUL}")

    if failed_qa_pairs:
        with open(OUTPUT_FAILED, 'w', encoding='utf-8') as f:
            json.dump(failed_qa_pairs, f, indent=2, ensure_ascii=False)
        print(f"Failed QA pairs: {len(failed_qa_pairs)} saved to {OUTPUT_FAILED}")

    if irrelevant_chunks:
        with open(OUTPUT_IRRELEVANT, 'w', encoding='utf-8') as f:
            json.dump(irrelevant_chunks, f, indent=2, ensure_ascii=False)
        print(f"Irrelevant chunks: {len(irrelevant_chunks)} saved to {OUTPUT_IRRELEVANT}")

    # Summary
    print(f"\n{'='*80}")
    print("SUMMARY")
    print(f"{'='*80}")
    print(f"Total chunks processed: {len(chunks)}")
    print(f"Relevant chunks processed: {len(chunks) - len(irrelevant_chunks)}")
    print(f"Irrelevant chunks skipped: {len(irrelevant_chunks)}")
    print(f"Successful QA pairs: {len(successful_qa_pairs)}")
    print(f"Failed QA pairs: {len(failed_qa_pairs)}")
    if len(chunks) - len(irrelevant_chunks) > 0:
        print(f"Success rate: {len(successful_qa_pairs)/(len(chunks) - len(irrelevant_chunks))*100:.1f}%")

    # Stats: Count single vs multiple chunks in successful QA pairs
    single_chunk_count = 0
    multiple_chunk_count = 0
    for qa_pair in successful_qa_pairs:
        chunks_added = qa_pair.get("chunks_added", [])
        if isinstance(chunks_added, list):
            if len(chunks_added) == 1:
                single_chunk_count += 1
            elif len(chunks_added) > 1:
                multiple_chunk_count += 1

    print("\nQA STATS")
    print(f"{'='*80}")
    print(f"Number of QA pairs in qa_multihop_fail: {len(failed_qa_pairs)}")
    print(f"Number of QA pairs with single chunk in qa_multihop_pass: {single_chunk_count}")
    print(f"Number of QA pairs with multiple chunks in qa_multihop_pass: {multiple_chunk_count}")