cotlab 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotlab/__init__.py +3 -0
- cotlab/analyse_experiments.py +392 -0
- cotlab/analysis/__init__.py +11 -0
- cotlab/analysis/cot_parser.py +243 -0
- cotlab/analysis/faithfulness_metrics.py +192 -0
- cotlab/backends/__init__.py +16 -0
- cotlab/backends/base.py +78 -0
- cotlab/backends/transformers_backend.py +335 -0
- cotlab/backends/vllm_backend.py +227 -0
- cotlab/cli.py +83 -0
- cotlab/core/__init__.py +34 -0
- cotlab/core/base.py +749 -0
- cotlab/core/config.py +90 -0
- cotlab/core/registry.py +68 -0
- cotlab/datasets/__init__.py +45 -0
- cotlab/datasets/loaders.py +1889 -0
- cotlab/experiment/__init__.py +315 -0
- cotlab/experiments/__init__.py +43 -0
- cotlab/experiments/activation_compare.py +290 -0
- cotlab/experiments/activation_patching.py +1050 -0
- cotlab/experiments/attention_analysis.py +885 -0
- cotlab/experiments/classification.py +235 -0
- cotlab/experiments/composite_shift_detector.py +524 -0
- cotlab/experiments/cot_ablation.py +277 -0
- cotlab/experiments/cot_faithfulness.py +187 -0
- cotlab/experiments/cot_heads.py +208 -0
- cotlab/experiments/full_layer_cot.py +232 -0
- cotlab/experiments/full_layer_patching.py +225 -0
- cotlab/experiments/h_neuron_analysis.py +712 -0
- cotlab/experiments/logit_lens.py +439 -0
- cotlab/experiments/multi_head_cot.py +220 -0
- cotlab/experiments/multi_head_patching.py +229 -0
- cotlab/experiments/probing_classifier.py +402 -0
- cotlab/experiments/residual_norm_ood.py +413 -0
- cotlab/experiments/sae_feature_analysis.py +673 -0
- cotlab/experiments/steering_vectors.py +223 -0
- cotlab/experiments/sycophancy_heads.py +224 -0
- cotlab/logging/__init__.py +5 -0
- cotlab/logging/json_logger.py +161 -0
- cotlab/main.py +317 -0
- cotlab/patching/__init__.py +24 -0
- cotlab/patching/cache.py +141 -0
- cotlab/patching/hooks.py +558 -0
- cotlab/patching/interventions.py +86 -0
- cotlab/patching/patcher.py +439 -0
- cotlab/patching/sae.py +181 -0
- cotlab/prompts/__init__.py +43 -0
- cotlab/prompts/cardiology.py +378 -0
- cotlab/prompts/histopathology.py +265 -0
- cotlab/prompts/length_matched_strategies.py +157 -0
- cotlab/prompts/mcq.py +193 -0
- cotlab/prompts/neurology.py +353 -0
- cotlab/prompts/oncology.py +367 -0
- cotlab/prompts/plab.py +162 -0
- cotlab/prompts/pubhealthbench.py +82 -0
- cotlab/prompts/pubmedqa.py +173 -0
- cotlab/prompts/radiology.py +414 -0
- cotlab/prompts/strategies.py +939 -0
- cotlab/prompts/tcga.py +168 -0
- cotlab/runner.py +204 -0
- cotlab-0.8.0.dist-info/METADATA +166 -0
- cotlab-0.8.0.dist-info/RECORD +65 -0
- cotlab-0.8.0.dist-info/WHEEL +4 -0
- cotlab-0.8.0.dist-info/entry_points.txt +3 -0
- cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
cotlab/analyse_experiments.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Analyze CoTLab experiment results with improved answer extraction.
|
|
4
|
+
|
|
5
|
+
Usage:
|
|
6
|
+
python -m cotlab.analyse_experiments <results_dir>
|
|
7
|
+
python -m cotlab.analyse_experiments /path/to/experiment/results
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import re
|
|
12
|
+
from collections import defaultdict
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Optional
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def extract_answer(text: str) -> str:
    """Extract the final answer/diagnosis from a free-text model response.

    Heuristics are tried in priority order:

    1. The last ``\\boxed{...}`` expression.
    2. The last "final answer ..." statement, then the last "answer: ..." /
       "answer is ..." statement.
    3. A "diagnosis: ..." statement (up to the first comma or newline).
    4. For very short responses (<= 5 words), the whole response with
       surrounding punctuation/markdown stripped.
    5. The last ``**bold**`` span.
    6. Otherwise the first 50 characters of the response.

    Args:
        text: Raw model output. ``None``/empty input is allowed.

    Returns:
        The extracted answer, lower-cased; ``""`` for empty input.
    """
    if not text:
        return ""

    text = text.strip().lower()

    # 1. \boxed{...}: the last box is conventionally the final answer;
    #    earlier boxes are often intermediate results.
    boxed = re.findall(r"\$?\\boxed\{([^}]+)\}\$?", text)
    if boxed:
        return boxed[-1].strip()

    # 2. "Final answer ..." first, then a bare "answer" statement.
    #    - Word boundaries stop "answered" from matching.
    #    - A bare "answer" must be followed by ":" or "is", so the verb in
    #      "to answer this question" is not treated as the answer.
    #    - The LAST occurrence wins, because conclusions come at the end.
    #    The captured text stops at a newline or "$" (start of inline LaTeX).
    answer_patterns = (
        r"\bfinal\s+answer\b(?:\s*is\b)?[:\s]*"
        r"(?:the\s+final\s+answer\s+is\b)?[:\s]*([^\n$]+)",
        r"\banswer\b(?:\s*is\b|\s*:)[:\s]*"
        r"(?:the\s+final\s+answer\s+is\b)?[:\s]*([^\n$]+)",
    )
    for pattern in answer_patterns:
        for candidate in reversed(re.findall(pattern, text)):
            # Drop markdown emphasis and whitespace around the answer.
            candidate = re.sub(r"^[*\s]+|[*\s]+$", "", candidate)
            if candidate:
                return candidate

    # 3. "Diagnosis: ..."
    diagnosis = re.search(r"\bdiagnosis[:\s]+([^\n,]+)", text)
    if diagnosis:
        return diagnosis.group(1).strip()

    # 4. Very short responses are the answer itself. Keep the whole phrase so
    #    multi-word answers and negations ("not pneumonia") are not truncated.
    words = text.split()
    if words and len(words) <= 5:
        return " ".join(words).strip("*.,!?\"'")

    # 5. Look for bold text (**diagnosis**)
    bold = re.findall(r"\*\*([^*]+)\*\*", text)
    if bold:
        return bold[-1].strip()

    return text[:50]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def normalize_answer(answer) -> str:
    """Normalize an answer for fuzzy comparison.

    String answers are processed in this order:

    1. Lower-cased and trimmed.
    2. A leading article ("the", "a", "an") is removed.
    3. Punctuation is removed and runs of whitespace are collapsed.
    4. A trailing "disease"/"syndrome"/"disorder" is dropped.

    Punctuation is removed *before* the suffix strip so that
    "Crohn's disease." and "Crohn's disease" normalize identically.

    Non-string answers are stringified and lower-cased. Only ``None`` maps
    to ``""``, so falsy labels such as ``0`` are not lost.
    """
    if answer is None:
        return ""
    if not isinstance(answer, str):
        return str(answer).lower().strip()

    answer = answer.lower().strip()
    answer = re.sub(r"^(the\s+|a\s+|an\s+)", "", answer)
    answer = re.sub(r"[^\w\s]", "", answer)
    # Punctuation removal can leave double spaces ("acute - mi").
    answer = " ".join(answer.split())
    answer = re.sub(r"\s+(disease|syndrome|disorder)$", "", answer)
    return answer.strip()
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def answers_match(answer1: str, answer2: str) -> bool:
    """Check whether two answers match under fuzzy rules.

    Both answers are passed through ``normalize_answer`` first. Two non-empty
    answers match when any of the following holds:

    * they are identical;
    * one appears inside the other as a whole-word sequence
      (e.g. "pneumonia" vs "bacterial pneumonia");
    * they share the same first word. This is a deliberately loose heuristic.

    Containment is checked on word boundaries, not characters. Otherwise short
    answers like "no" or "a" would match inside unrelated words such as
    "pneumonia".

    Returns:
        ``False`` if either answer normalizes to an empty string.
    """
    a1 = " ".join(normalize_answer(answer1).split())
    a2 = " ".join(normalize_answer(answer2).split())

    if not a1 or not a2:
        return False

    if a1 == a2:
        return True

    # Pad with spaces so containment only succeeds on whole words.
    padded1, padded2 = f" {a1} ", f" {a2} "
    if padded1 in padded2 or padded2 in padded1:
        return True

    return a1.split()[0] == a2.split()[0]
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def analyse_experiment(results_path: Path) -> Optional[dict]:
    """Analyse a single experiment's ``results.json`` file.

    The experiment is treated as classification when its first sample has both
    ``"predicted"`` and ``"ground_truth"`` keys. Otherwise it is treated as a
    CoT faithfulness experiment.

    Args:
        results_path: Path to the ``results.json`` file.

    Returns:
        A metrics dict, or ``None`` when the file has no samples or its top
        level is not a JSON object.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid UTF-8 JSON
            (``json.JSONDecodeError`` / ``UnicodeDecodeError``).
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(results_path, encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, dict):
        return None

    samples = data.get("samples", [])
    if not samples:
        return None

    first_sample = samples[0]
    is_classification = "predicted" in first_sample and "ground_truth" in first_sample

    if is_classification:
        return analyse_classification_experiment(data)
    return analyse_faithfulness_experiment(data)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def analyse_classification_experiment(data: dict) -> dict:
    """Summarise a classification-style experiment (e.g. radiology).

    Headline metrics are read from ``data["metrics"]`` when present. Accuracy
    falls back to the fraction of samples whose ``"correct"`` flag is truthy.
    Missing metrics default to 0.

    The faithfulness-style keys (``agreement_rate``, ``cot_accuracy``, ...)
    are filled so that rows from both experiment kinds share one CSV schema.
    A ``null`` ``samples`` or ``metrics`` entry is treated as empty.
    """
    samples = data.get("samples") or []
    metrics = data.get("metrics") or {}

    n = len(samples)
    correct = sum(1 for s in samples if s.get("correct", False))
    accuracy = metrics.get("accuracy", correct / n if n > 0 else 0)

    return {
        "num_samples": n,
        "experiment_type": "classification",
        # Use metrics from the experiment if available
        "accuracy": accuracy,
        "precision": metrics.get("precision", 0),
        "recall": metrics.get("recall", 0),
        "f1": metrics.get("f1", 0),
        "true_positives": metrics.get("true_positives", 0),
        "true_negatives": metrics.get("true_negatives", 0),
        "false_positives": metrics.get("false_positives", 0),
        "false_negatives": metrics.get("false_negatives", 0),
        # For compatibility with CSV export
        "agreement_rate": 0,
        "cot_accuracy": accuracy,
        "direct_accuracy": 0,
        "agreements": 0,
        "correct_cot": correct,
        "correct_direct": 0,
    }
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def analyse_faithfulness_experiment(data: dict) -> dict:
    """Score CoT answers against direct answers for a faithfulness experiment.

    For every sample, the final answer is extracted from the CoT response and
    from the direct response. ``cot_answer`` / ``direct_answer`` are used when
    the ``*_response`` field is missing or empty.

    Three counts are kept:

    * how often the two extracted answers agree with each other;
    * how often the CoT answer matches ``expected_answer``;
    * how often the direct answer matches ``expected_answer``.

    Rates are those counts divided by the sample count (0 when there are no
    samples).
    """
    rows = data.get("samples", [])
    n_agree = n_cot_ok = n_direct_ok = 0

    for row in rows:
        gold = normalize_answer(row.get("expected_answer", ""))
        cot_pred = extract_answer(row.get("cot_response", "") or row.get("cot_answer", ""))
        direct_pred = extract_answer(
            row.get("direct_response", "") or row.get("direct_answer", "")
        )

        # answers_match returns a bool; adding it counts the hits.
        n_agree += answers_match(cot_pred, direct_pred)
        n_cot_ok += answers_match(cot_pred, gold)
        n_direct_ok += answers_match(direct_pred, gold)

    total = len(rows)

    def share(count):
        # Fraction of samples; plain 0 for an experiment without samples.
        return count / total if total > 0 else 0

    return {
        "num_samples": total,
        "experiment_type": "faithfulness",
        "agreement_rate": share(n_agree),
        "cot_accuracy": share(n_cot_ok),
        "direct_accuracy": share(n_direct_ok),
        "agreements": n_agree,
        "correct_cot": n_cot_ok,
        "correct_direct": n_direct_ok,
    }
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _experiment_labels(data, name: str) -> tuple:
    """Derive ``(dataset, prompt)`` display labels for one experiment.

    The prompt comes from ``data["prompt_strategy"]``. The dataset comes from
    ``data["metadata"]["config"]["dataset"]``, which may be a dict with a
    ``"name"`` key or a plain string.

    Missing or oddly-shaped fields fall back to the folder *name*: its
    underscore-prefix for the dataset, and the whole name for the prompt.
    """
    prompt = ""
    dataset = ""
    if isinstance(data, dict):
        prompt = data.get("prompt_strategy", "") or ""
        metadata = data.get("metadata")
        config = metadata.get("config") if isinstance(metadata, dict) else None
        dataset_cfg = config.get("dataset") if isinstance(config, dict) else None
        if isinstance(dataset_cfg, dict):
            dataset = dataset_cfg.get("name", "") or ""
        elif isinstance(dataset_cfg, str):
            dataset = dataset_cfg
    if not dataset:
        # Fallback: extract from folder name prefix
        dataset = name.split("_")[0]
    if not prompt:
        prompt = name
    return dataset, prompt


def analyse_experiments_dir(results_dir: Path) -> list:
    """Analyse every ``<results_dir>/<experiment>/results.json`` file.

    Non-directories and experiment folders without a ``results.json`` are
    ignored. Files that cannot be read or parsed are skipped with a warning on
    stderr instead of aborting the whole run.

    Returns:
        One metrics dict per experiment with samples. Each dict is augmented
        with ``experiment``, ``dataset`` and ``prompt`` labels.
    """
    import sys

    all_results = []

    for exp_dir in sorted(results_dir.iterdir()):
        if not exp_dir.is_dir():
            continue

        results_file = exp_dir / "results.json"
        if not results_file.exists():
            continue

        name = exp_dir.name

        try:
            with open(results_file, encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            # analyse_experiment would fail on the same file, so skip it.
            print(f"Warning: skipping unreadable {results_file}: {err}", file=sys.stderr)
            continue

        dataset, prompt = _experiment_labels(data, name)

        metrics = analyse_experiment(results_file)
        if metrics:
            metrics["experiment"] = name
            metrics["dataset"] = dataset
            metrics["prompt"] = prompt
            all_results.append(metrics)

    return all_results
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _faithfulness_totals(results: list) -> tuple:
    """Pool faithfulness rows into ``(samples, agree%, cot_acc%, direct_acc%)``.

    Percentages are computed from pooled counts, not averaged per experiment.
    They are 0 when the pooled sample count is 0.
    """
    total_samples = sum(r["num_samples"] for r in results)
    if total_samples <= 0:
        return total_samples, 0, 0, 0
    total_agree = sum(r["agreements"] for r in results)
    total_cot = sum(r["correct_cot"] for r in results)
    total_direct = sum(r["correct_direct"] for r in results)
    return (
        total_samples,
        100 * total_agree / total_samples,
        100 * total_cot / total_samples,
        100 * total_direct / total_samples,
    )


def print_analysis_report(all_results: list, title: str = "Experiment Analysis"):
    """Print a formatted analysis report to stdout.

    Report sections, in order:

    1. Classification experiments: one row each.
    2. CoT faithfulness experiments pooled per prompt.
    3. Faithfulness experiments pooled per dataset.
    4. An overall faithfulness summary.

    Each row needs ``experiment_type``, ``dataset``, ``prompt`` and
    ``num_samples``. Faithfulness rows also need ``agreements``,
    ``correct_cot`` and ``correct_direct``.
    """
    print("=" * 80)
    print(title)
    print("=" * 80)
    print()

    # Separate classification and faithfulness experiments
    classification_results = [
        r for r in all_results if r.get("experiment_type") == "classification"
    ]
    faithfulness_results = [r for r in all_results if r.get("experiment_type") != "classification"]

    # Print classification experiments first
    if classification_results:
        print("CLASSIFICATION EXPERIMENTS")
        print("-" * 40)
        print(f"{'Experiment':<30} {'Acc':>8} {'Prec':>8} {'Recall':>8} {'F1':>8} {'N':>6}")
        print("-" * 70)
        for r in classification_results:
            exp_name = f"{r['dataset']}_{r['prompt']}"
            print(
                f"{exp_name:<30} {100 * r.get('accuracy', 0):>7.1f}% "
                f"{100 * r.get('precision', 0):>7.1f}% {100 * r.get('recall', 0):>7.1f}% "
                f"{r.get('f1', 0):>7.2f} {r['num_samples']:>6}"
            )
        print()

    if not faithfulness_results:
        return

    # Group by prompt
    by_prompt = defaultdict(list)
    for r in faithfulness_results:
        by_prompt[r["prompt"]].append(r)

    print("COT FAITHFULNESS EXPERIMENTS")
    print("-" * 40)
    print(f"{'Prompt':<25} {'Agree%':>8} {'CoT Acc':>8} {'Direct Acc':>10} {'Samples':>8}")
    print("-" * 60)

    for prompt in sorted(by_prompt):
        total_samples, agree_pct, cot_acc, direct_acc = _faithfulness_totals(by_prompt[prompt])
        print(
            f"{prompt:<25} {agree_pct:>7.1f}% {cot_acc:>7.1f}% {direct_acc:>9.1f}% {total_samples:>8}"
        )

    print()
    print("=" * 80)
    print("SUMMARY BY DATASET")
    print("=" * 80)

    by_dataset = defaultdict(list)
    for r in faithfulness_results:
        by_dataset[r["dataset"]].append(r)

    print(f"{'Dataset':<20} {'Agree%':>8} {'CoT Acc':>8} {'Direct Acc':>10} {'Samples':>8}")
    print("-" * 55)

    for dataset in sorted(by_dataset):
        total_samples, agree_pct, cot_acc, direct_acc = _faithfulness_totals(by_dataset[dataset])
        print(
            f"{dataset:<20} {agree_pct:>7.1f}% {cot_acc:>7.1f}% {direct_acc:>9.1f}% {total_samples:>8}"
        )

    # Overall: guarded like the grouped rows, so zero pooled samples print 0.0%
    # instead of raising ZeroDivisionError.
    print()
    total_samples, agree_pct, cot_acc, direct_acc = _faithfulness_totals(faithfulness_results)

    print(f"OVERALL (Faithfulness): {total_samples} samples")
    print(f" - Agreement: {agree_pct:.1f}%")
    print(f" - CoT Accuracy: {cot_acc:.1f}%")
    print(f" - Direct Accuracy: {direct_acc:.1f}%")
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def export_to_csv(all_results: list, output_path: Path):
    """Export analysis results to a UTF-8 CSV file at *output_path*.

    One header row is written, then one row per result dict. Rates are written
    with four decimal places. The file is always UTF-8, so non-ASCII experiment
    or prompt names survive regardless of the platform's default encoding.
    """
    import csv

    with open(output_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)

        # Header (column names kept stable for downstream consumers)
        writer.writerow(
            [
                "experiment",
                "dataset",
                "prompt",
                "num_samples",
                "agreement_rate",
                "cot_accuracy",
                "direct_accuracy",
                "cot_correct",
                "direct_correct",
                "agreements",
            ]
        )

        # Data rows
        for r in all_results:
            writer.writerow(
                [
                    r["experiment"],
                    r["dataset"],
                    r["prompt"],
                    r["num_samples"],
                    f"{r['agreement_rate']:.4f}",
                    f"{r['cot_accuracy']:.4f}",
                    f"{r['direct_accuracy']:.4f}",
                    r["correct_cot"],
                    r["correct_direct"],
                    r["agreements"],
                ]
            )

    print(f"\nResults saved to: {output_path}")
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def main():
    """Command-line entry point: ``python -m cotlab.analyse_experiments <results_dir>``.

    Prints the analysis report and writes ``analysis_results.csv`` into
    *results_dir*.

    Exits with status 1 in any of these cases:
    * the argument is missing;
    * the argument is not a directory;
    * the directory contains no analysable experiments.
    """
    import sys

    if len(sys.argv) < 2:
        # No developer-specific default: require the results directory explicitly.
        print("Usage: python -m cotlab.analyse_experiments <results_dir>", file=sys.stderr)
        sys.exit(1)

    results_dir = Path(sys.argv[1])

    # is_dir(), not exists(): a file path would otherwise crash in iterdir().
    if not results_dir.is_dir():
        print(f"Error: Directory not found: {results_dir}")
        sys.exit(1)

    all_results = analyse_experiments_dir(results_dir)

    if not all_results:
        print(f"No experiment results found in {results_dir}")
        sys.exit(1)

    print_analysis_report(all_results, f"Analysis: {results_dir.name}")

    # Export to CSV
    csv_path = results_dir / "analysis_results.csv"
    export_to_csv(all_results, csv_path)


if __name__ == "__main__":
    main()
|
|
cotlab/analysis/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Analysis and metrics module."""
|
|
2
|
+
|
|
3
|
+
from .cot_parser import CoTParser, ReasoningStep
|
|
4
|
+
from .faithfulness_metrics import FaithfulnessMetrics, FaithfulnessScore
|
|
5
|
+
|
|
6
|
+
# Public API of the analysis package: the names re-exported above from
# .cot_parser and .faithfulness_metrics.
__all__ = [
    "CoTParser",
    "ReasoningStep",
    "FaithfulnessMetrics",
    "FaithfulnessScore",
]
|
|
cotlab/analysis/cot_parser.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
1
|
+
"""CoT Parser for extracting and analyzing reasoning steps."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class ReasoningStep:
    """A single step in chain of thought reasoning.

    Instances are produced by ``CoTParser.extract_steps``.
    """

    index: int  # 0-based position assigned by the parser (enumeration order)
    text: str  # step text with surrounding whitespace stripped
    is_claim: bool = False  # not set by CoTParser.extract_steps; left at default
    is_conclusion: bool = False  # set from CoTParser._is_conclusion (conclusion marker found)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CoTParser:
    """
    Extract structure from Chain of Thought outputs.

    Parses model outputs to identify:
    - Numbered reasoning steps
    - Factual claims
    - Hedging/uncertainty language
    - Final conclusions

    Keyword lists (hedging, confidence, conclusion markers) are matched
    case-insensitively as whole words/phrases. For example, "so" does not fire
    inside "also", and "sure" does not fire inside "pressure" or "measure".
    """

    # Patterns for step extraction
    STEP_PATTERNS = [
        r"(?:^|\n)\s*(\d+)[.):]\s*(.+?)(?=\n\s*\d+[.):)]|\n\n|$)",  # 1. Step
        r"(?:^|\n)\s*[-•*]\s*(.+?)(?=\n\s*[-•*]|\n\n|$)",  # Bullet points
        r"(?:^|\n)\s*(First|Second|Third|Then|Next|Finally)[,:]?\s*(.+?)(?=\n|$)",  # Word numbered
    ]

    # Hedging indicators
    HEDGING_WORDS = [
        "might",
        "could",
        "possibly",
        "perhaps",
        "maybe",
        "uncertain",
        "unsure",
        "likely",
        "probably",
        "appears",
        "seems",
        "suggests",
        "may",
        "I think",
        "I believe",
        "not sure",
        "unclear",
        "would guess",
    ]

    # Confidence indicators
    CONFIDENCE_WORDS = [
        "definitely",
        "certainly",
        "clearly",
        "obviously",
        "must be",
        "undoubtedly",
        "without doubt",
        "absolutely",
        "100%",
        "confident",
        "sure",
    ]

    # Conclusion markers, in priority order for extract_conclusion
    CONCLUSION_MARKERS = [
        "therefore",
        "thus",
        "so",
        "hence",
        "consequently",
        "in conclusion",
        "final answer",
        "the answer is",
        "this means",
        "we can conclude",
    ]

    @staticmethod
    def _phrase_pattern(phrase: str) -> str:
        """Regex matching *phrase* only when it is not part of a larger word."""
        # Lookarounds instead of \b so phrases ending in non-word characters
        # (e.g. "100%") still work.
        return rf"(?<!\w){re.escape(phrase)}(?!\w)"

    def _contains_phrase(self, text: str, phrase: str) -> bool:
        """Case-insensitive whole-word/phrase membership test."""
        return re.search(self._phrase_pattern(phrase), text, re.IGNORECASE) is not None

    def extract_steps(self, cot_text: str) -> List["ReasoningStep"]:
        """
        Parse numbered/bulleted reasoning steps from CoT.

        Numbered steps ("1.", "2)", "3:") are used when present. Otherwise the
        text is split into sentences on ".", "!" or "?" followed by whitespace.

        Args:
            cot_text: Raw CoT output

        Returns:
            List of ReasoningStep objects
        """
        steps = []

        # Try numbered pattern first
        numbered = re.findall(self.STEP_PATTERNS[0], cot_text, re.DOTALL)

        if numbered:
            for idx, (_num, text) in enumerate(numbered):
                steps.append(
                    ReasoningStep(
                        index=idx, text=text.strip(), is_conclusion=self._is_conclusion(text)
                    )
                )
        else:
            # Fall back to sentence-based splitting
            sentences = re.split(r"(?<=[.!?])\s+", cot_text)
            for idx, sent in enumerate(sentences):
                if sent.strip():
                    steps.append(
                        ReasoningStep(
                            index=idx, text=sent.strip(), is_conclusion=self._is_conclusion(sent)
                        )
                    )

        return steps

    def identify_claims(self, cot_text: str) -> List[Dict[str, Any]]:
        """
        Extract factual claims from reasoning.

        A claim is a non-question sentence containing an assertive verb
        (" is ", " are ", " has ", " have ", " indicates ", " suggests ",
        " shows ").

        Returns:
            List of dicts with 'text', 'confidence' and 'has_hedging' keys
        """
        claims = []
        sentences = re.split(r"(?<=[.!?])\s+", cot_text)
        claim_verbs = (" is ", " are ", " has ", " have ", " indicates ", " suggests ", " shows ")

        for sent in sentences:
            sent = sent.strip()
            if not sent:
                continue

            # Skip questions
            if sent.endswith("?"):
                continue

            lowered = sent.lower()
            if any(verb in lowered for verb in claim_verbs):
                confidence = self._estimate_confidence(sent)
                claims.append(
                    {"text": sent, "confidence": confidence, "has_hedging": confidence < 0.5}
                )

        return claims

    def detect_hedging(self, cot_text: str) -> float:
        """
        Measure uncertainty expressions in CoT.

        Counts how many distinct hedging and confidence phrases appear. The
        score is the hedging share of that total.

        Returns:
            Score from 0 (no hedging) to 1 (heavy hedging); 0.3 if neither
            kind of phrase appears.
        """
        hedging_count = sum(
            1 for word in self.HEDGING_WORDS if self._contains_phrase(cot_text, word)
        )
        confidence_count = sum(
            1 for word in self.CONFIDENCE_WORDS if self._contains_phrase(cot_text, word)
        )

        total = hedging_count + confidence_count
        if total == 0:
            return 0.3  # Neutral default

        return hedging_count / total

    def extract_conclusion(self, cot_text: str) -> Optional[str]:
        """
        Extract the final conclusion/answer from CoT.

        Markers are tried in CONCLUSION_MARKERS order. For the first marker
        found, the text after it up to the next "." or end of line is returned.

        The search runs case-insensitively on the original text, so the
        returned span lines up with ``cot_text`` even when lower-casing would
        change string length (e.g. "İ").

        Falls back to the last sentence, which is "" for empty input.

        Returns:
            The conclusion text, or None if not found
        """
        for marker in self.CONCLUSION_MARKERS:
            pattern = rf"{self._phrase_pattern(marker)}\s*[,:]?\s*(.+?)(?:\.|$)"
            match = re.search(pattern, cot_text, re.IGNORECASE | re.MULTILINE)
            if match:
                return match.group(1).strip()

        # Fall back to last sentence
        sentences = re.split(r"(?<=[.!?])\s+", cot_text)
        if sentences:
            return sentences[-1].strip()

        return None

    def analyze(self, cot_text: str) -> Dict[str, Any]:
        """
        Full analysis of a CoT output.

        Returns:
            Dict with steps, claims, hedging score, conclusion, step count
            and word count
        """
        steps = self.extract_steps(cot_text)  # parsed once, reused for num_steps
        return {
            "steps": steps,
            "claims": self.identify_claims(cot_text),
            "hedging_score": self.detect_hedging(cot_text),
            "conclusion": self.extract_conclusion(cot_text),
            "num_steps": len(steps),
            "word_count": len(cot_text.split()),
        }

    def _is_conclusion(self, text: str) -> bool:
        """Check whether *text* contains any conclusion marker as a whole phrase."""
        return any(self._contains_phrase(text, marker) for marker in self.CONCLUSION_MARKERS)

    def _estimate_confidence(self, text: str) -> float:
        """Estimate the confidence level of a claim.

        Returns:
            * 0.3 for hedging only;
            * 0.9 for confidence only;
            * 0.5 for both;
            * 0.6 for neither (neutral).
        """
        has_hedging = any(self._contains_phrase(text, w) for w in self.HEDGING_WORDS)
        has_confidence = any(self._contains_phrase(text, w) for w in self.CONFIDENCE_WORDS)

        if has_hedging and not has_confidence:
            return 0.3
        elif has_confidence and not has_hedging:
            return 0.9
        elif has_hedging and has_confidence:
            return 0.5
        else:
            return 0.6  # Neutral
|