cotlab 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. cotlab/__init__.py +3 -0
  2. cotlab/analyse_experiments.py +392 -0
  3. cotlab/analysis/__init__.py +11 -0
  4. cotlab/analysis/cot_parser.py +243 -0
  5. cotlab/analysis/faithfulness_metrics.py +192 -0
  6. cotlab/backends/__init__.py +16 -0
  7. cotlab/backends/base.py +78 -0
  8. cotlab/backends/transformers_backend.py +335 -0
  9. cotlab/backends/vllm_backend.py +227 -0
  10. cotlab/cli.py +83 -0
  11. cotlab/core/__init__.py +34 -0
  12. cotlab/core/base.py +749 -0
  13. cotlab/core/config.py +90 -0
  14. cotlab/core/registry.py +68 -0
  15. cotlab/datasets/__init__.py +45 -0
  16. cotlab/datasets/loaders.py +1889 -0
  17. cotlab/experiment/__init__.py +315 -0
  18. cotlab/experiments/__init__.py +43 -0
  19. cotlab/experiments/activation_compare.py +290 -0
  20. cotlab/experiments/activation_patching.py +1050 -0
  21. cotlab/experiments/attention_analysis.py +885 -0
  22. cotlab/experiments/classification.py +235 -0
  23. cotlab/experiments/composite_shift_detector.py +524 -0
  24. cotlab/experiments/cot_ablation.py +277 -0
  25. cotlab/experiments/cot_faithfulness.py +187 -0
  26. cotlab/experiments/cot_heads.py +208 -0
  27. cotlab/experiments/full_layer_cot.py +232 -0
  28. cotlab/experiments/full_layer_patching.py +225 -0
  29. cotlab/experiments/h_neuron_analysis.py +712 -0
  30. cotlab/experiments/logit_lens.py +439 -0
  31. cotlab/experiments/multi_head_cot.py +220 -0
  32. cotlab/experiments/multi_head_patching.py +229 -0
  33. cotlab/experiments/probing_classifier.py +402 -0
  34. cotlab/experiments/residual_norm_ood.py +413 -0
  35. cotlab/experiments/sae_feature_analysis.py +673 -0
  36. cotlab/experiments/steering_vectors.py +223 -0
  37. cotlab/experiments/sycophancy_heads.py +224 -0
  38. cotlab/logging/__init__.py +5 -0
  39. cotlab/logging/json_logger.py +161 -0
  40. cotlab/main.py +317 -0
  41. cotlab/patching/__init__.py +24 -0
  42. cotlab/patching/cache.py +141 -0
  43. cotlab/patching/hooks.py +558 -0
  44. cotlab/patching/interventions.py +86 -0
  45. cotlab/patching/patcher.py +439 -0
  46. cotlab/patching/sae.py +181 -0
  47. cotlab/prompts/__init__.py +43 -0
  48. cotlab/prompts/cardiology.py +378 -0
  49. cotlab/prompts/histopathology.py +265 -0
  50. cotlab/prompts/length_matched_strategies.py +157 -0
  51. cotlab/prompts/mcq.py +193 -0
  52. cotlab/prompts/neurology.py +353 -0
  53. cotlab/prompts/oncology.py +367 -0
  54. cotlab/prompts/plab.py +162 -0
  55. cotlab/prompts/pubhealthbench.py +82 -0
  56. cotlab/prompts/pubmedqa.py +173 -0
  57. cotlab/prompts/radiology.py +414 -0
  58. cotlab/prompts/strategies.py +939 -0
  59. cotlab/prompts/tcga.py +168 -0
  60. cotlab/runner.py +204 -0
  61. cotlab-0.8.0.dist-info/METADATA +166 -0
  62. cotlab-0.8.0.dist-info/RECORD +65 -0
  63. cotlab-0.8.0.dist-info/WHEEL +4 -0
  64. cotlab-0.8.0.dist-info/entry_points.txt +3 -0
  65. cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
cotlab/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """CoTLab - Chain of Thought Research Toolkit."""
2
+
3
+ __version__ = "0.8.0"
@@ -0,0 +1,392 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Analyze CoTLab experiment results with improved answer extraction.
4
+
5
+ Usage:
6
+ python -m cotlab.analyse_experiments <results_dir>
7
+ python -m cotlab.analyse_experiments /path/to/experiment/results
8
+ """
9
+
10
+ import json
11
+ import re
12
+ from collections import defaultdict
13
+ from pathlib import Path
14
+ from typing import Optional
15
+
16
+
17
def extract_answer(text: str) -> str:
    """Extract the final answer/diagnosis from a model response.

    The heuristics below are tried in order and the first one that yields a
    value wins:

    1. The first LaTeX ``boxed{...}`` expression.
    2. Text following "final answer" / "answer" on the same line.
    3. Text following "diagnosis" up to the next comma or newline.
    4. A very short response (five words or fewer) is the answer itself.
    5. The last ``**bold**`` span.
    6. The first 50 characters of the response.

    Args:
        text: Raw response text. Falsy input yields ``""``.

    Returns:
        The lower-cased answer string (possibly empty).
    """
    if not text:
        return ""

    # Lower-case once up front so every pattern below is case-insensitive.
    text = text.strip().lower()

    # 1. \boxed{...}, optionally wrapped in inline-math dollar signs.
    boxed = re.findall(r"\$?\\boxed\{([^}]+)\}\$?", text)
    if boxed:
        return boxed[0].strip().lower()

    # 2. "Final Answer: ..." or "**Final Answer:** ..."
    final_answer = re.search(
        r"(?:final answer|answer)[:\s]*(?:the final answer is\s*)?[:\s]*([^\n$]+)",
        text,
        re.IGNORECASE,
    )
    if final_answer:
        answer = final_answer.group(1).strip()
        # Drop anything from a stray "$" onwards, then surrounding markdown.
        answer = re.sub(r"\$.*$", "", answer).strip()
        answer = re.sub(r"^[*\s]+|[*\s]+$", "", answer)
        if answer:
            return answer.lower()

    # 3. "Diagnosis: ..."
    diagnosis = re.search(r"diagnosis[:\s]+([^\n,]+)", text, re.IGNORECASE)
    if diagnosis:
        return diagnosis.group(1).strip().lower()

    # 4. A very short response is the answer itself. Keep the whole phrase:
    #    returning only the first word would turn "myocardial infarction"
    #    into "myocardial" and lose the actual diagnosis.
    words = text.split()
    if words and len(words) <= 5:
        return " ".join(words).strip("*.,!?\"' ")

    # 5. Last bold span (**diagnosis**).
    bold = re.findall(r"\*\*([^*]+)\*\*", text)
    if bold:
        return bold[-1].strip().lower()

    # 6. Nothing recognisable: fall back to the start of the response.
    return text[:50].lower()
58
+
59
+
60
def normalize_answer(answer) -> str:
    """Normalize an answer for comparison.

    The value is converted to a lower-cased string, a leading article
    ("the", "a", "an") is dropped, punctuation is removed, and a trailing
    "disease" / "syndrome" / "disorder" is stripped.

    Args:
        answer: The answer to normalize. Non-string values (e.g. numeric
            labels) are converted with ``str()`` and normalized the same way
            as strings; ``None`` yields ``""``.

    Returns:
        The normalized answer string (possibly empty).
    """
    # Only None means "no answer"; a label such as 0 is a real value.
    if answer is None:
        return ""

    text = str(answer).lower().strip()
    # The article is stripped before punctuation so that an option label such
    # as "a. aspirin" keeps its letter ("a aspirin").
    text = re.sub(r"^(the\s+|a\s+|an\s+)", "", text)
    text = re.sub(r"[^\w\s]", "", text).strip()
    # The suffix is stripped after punctuation so that a trailing "." or ")"
    # cannot hide it ("crohn's disease." -> "crohns").
    text = re.sub(r"\s+(disease|syndrome|disorder)$", "", text)
    return text.strip()
69
+
70
+
71
def answers_match(answer1: str, answer2: str) -> bool:
    """Check whether two answers match, using fuzzy comparison.

    Both answers are normalized with ``normalize_answer`` and then compared.
    They match when they are equal, when one occurs inside the other as a
    run of whole words, or when their first words are identical.

    Args:
        answer1: First answer.
        answer2: Second answer.

    Returns:
        True if the answers are considered the same, False otherwise
        (including when either answer normalizes to an empty string).
    """
    # Collapse internal whitespace so the word-boundary test below is exact.
    a1 = " ".join(normalize_answer(answer1).split())
    a2 = " ".join(normalize_answer(answer2).split())

    if not a1 or not a2:
        return False

    if a1 == a2:
        return True

    # Containment is checked on whole words. A raw substring test would let a
    # one-letter option such as "a" match any answer containing that letter.
    if f" {a1} " in f" {a2} " or f" {a2} " in f" {a1} ":
        return True

    # Deliberately lenient: identical leading word counts as a match.
    return a1.split()[0] == a2.split()[0]
89
+
90
+
91
def analyse_experiment(results_path: Path) -> Optional[dict]:
    """Analyse a single experiment's results.json file.

    Args:
        results_path: Path to the JSON results file.

    Returns:
        The metrics dict produced by the classification or faithfulness
        analyser, or None when the file contains no samples.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(results_path, encoding="utf-8") as f:
        data = json.load(f)

    samples = data.get("samples", [])
    if not samples:
        return None

    # Samples carrying both "predicted" and "ground_truth" come from a
    # classification (e.g. radiology) experiment; anything else is treated
    # as a CoT faithfulness experiment.
    first_sample = samples[0]
    if "predicted" in first_sample and "ground_truth" in first_sample:
        return analyse_classification_experiment(data)
    return analyse_faithfulness_experiment(data)
108
+
109
+
110
def analyse_classification_experiment(data: dict) -> dict:
    """Analyse a classification experiment (e.g., radiology).

    Metrics already stored under ``data["metrics"]`` are preferred. Accuracy
    falls back to the share of samples flagged ``"correct"``; the other
    metrics fall back to 0.

    Args:
        data: Parsed results with optional "samples" and "metrics" entries.
            Either entry may be missing or null.

    Returns:
        Dict of classification metrics, plus the faithfulness-style keys the
        CSV export expects.
    """
    # "or" also covers keys that are present but null in the JSON.
    samples = data.get("samples") or []
    metrics = data.get("metrics") or {}

    correct = sum(1 for s in samples if s.get("correct", False))
    n = len(samples)
    accuracy = metrics.get("accuracy", correct / n if n > 0 else 0)

    return {
        "num_samples": n,
        "experiment_type": "classification",
        # Use metrics from the experiment if available
        "accuracy": accuracy,
        "precision": metrics.get("precision", 0),
        "recall": metrics.get("recall", 0),
        "f1": metrics.get("f1", 0),
        "true_positives": metrics.get("true_positives", 0),
        "true_negatives": metrics.get("true_negatives", 0),
        "false_positives": metrics.get("false_positives", 0),
        "false_negatives": metrics.get("false_negatives", 0),
        # For compatibility with CSV export
        "agreement_rate": 0,
        "cot_accuracy": accuracy,
        "direct_accuracy": 0,
        "agreements": 0,
        "correct_cot": correct,
        "correct_direct": 0,
    }
138
+
139
+
140
def analyse_faithfulness_experiment(data: dict) -> dict:
    """Analyse a CoT faithfulness experiment.

    For each sample the answers extracted from the chain-of-thought response
    and from the direct response are compared with each other and with the
    expected answer.

    Args:
        data: Parsed results with a "samples" list.

    Returns:
        Dict with the sample count, the agreement / accuracy rates and the
        raw counts behind them.
    """
    entries = data.get("samples", [])
    total = len(entries)

    n_agree = 0
    n_cot_right = 0
    n_direct_right = 0

    for entry in entries:
        # Prefer the full response; fall back to a pre-extracted answer field.
        cot_text = entry.get("cot_response", "") or entry.get("cot_answer", "")
        direct_text = entry.get("direct_response", "") or entry.get("direct_answer", "")

        from_cot = extract_answer(cot_text)
        from_direct = extract_answer(direct_text)
        gold = normalize_answer(entry.get("expected_answer", ""))

        if answers_match(from_cot, from_direct):
            n_agree += 1
        if answers_match(from_cot, gold):
            n_cot_right += 1
        if answers_match(from_direct, gold):
            n_direct_right += 1

    has_samples = total > 0
    return {
        "num_samples": total,
        "experiment_type": "faithfulness",
        "agreement_rate": n_agree / total if has_samples else 0,
        "cot_accuracy": n_cot_right / total if has_samples else 0,
        "direct_accuracy": n_direct_right / total if has_samples else 0,
        "agreements": n_agree,
        "correct_cot": n_cot_right,
        "correct_direct": n_direct_right,
    }
176
+
177
+
178
def analyse_experiments_dir(results_dir: Path) -> list:
    """Analyse all experiments in a directory.

    Every sub-directory of ``results_dir`` that contains a ``results.json``
    is analysed with ``analyse_experiment``. Unreadable or malformed result
    files are skipped with a warning instead of aborting the whole run.

    Args:
        results_dir: Directory whose sub-directories hold experiment results.

    Returns:
        List of metrics dicts, each extended with "experiment", "dataset"
        and "prompt" keys. Experiments without samples are omitted.
    """
    import warnings

    all_results = []

    for exp_dir in sorted(results_dir.iterdir()):
        if not exp_dir.is_dir():
            continue

        results_file = exp_dir / "results.json"
        if not results_file.exists():
            continue

        name = exp_dir.name

        # Parse once here to read the naming metadata. A file that cannot be
        # parsed cannot be analysed either, so skip it rather than fall back
        # to folder-name parsing and then fail on the second parse.
        try:
            with open(results_file, encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, UnicodeDecodeError, json.JSONDecodeError) as err:
            warnings.warn(f"Skipping {results_file}: {err}")
            continue

        if not isinstance(data, dict):
            warnings.warn(f"Skipping {results_file}: top-level JSON value is not an object")
            continue

        # Prompt comes from results.json, else the folder name.
        prompt = data.get("prompt_strategy", "") or name

        # Dataset name lives at metadata.config.dataset.name when present.
        # Any missing or non-dict level is treated as "not available".
        node = data
        for key in ("metadata", "config", "dataset"):
            node = node.get(key) if isinstance(node, dict) else None
        dataset = node.get("name", "") if isinstance(node, dict) else ""
        if not dataset:
            # Fallback: extract from folder name prefix
            dataset = name.split("_")[0] or "unknown"

        metrics = analyse_experiment(results_file)
        if metrics:
            metrics["experiment"] = name
            metrics["dataset"] = dataset
            metrics["prompt"] = prompt
            all_results.append(metrics)

    return all_results
224
+
225
+
226
def _faithfulness_totals(results: list) -> tuple:
    """Sum sample, agreement, CoT-correct and direct-correct counts."""
    return (
        sum(r["num_samples"] for r in results),
        sum(r["agreements"] for r in results),
        sum(r["correct_cot"] for r in results),
        sum(r["correct_direct"] for r in results),
    )


def _as_percent(count, total):
    """Return ``count`` as a percentage of ``total`` (0 when total is 0)."""
    return 100 * count / total if total > 0 else 0


def _print_faithfulness_table(results: list, key: str, label: str, width: int, rule: int) -> None:
    """Print one aggregated faithfulness table, grouped by ``r[key]``."""
    groups = defaultdict(list)
    for r in results:
        groups[r[key]].append(r)

    print(f"{label:<{width}} {'Agree%':>8} {'CoT Acc':>8} {'Direct Acc':>10} {'Samples':>8}")
    print("-" * rule)

    for group in sorted(groups):
        n, agree, cot, direct = _faithfulness_totals(groups[group])
        print(
            f"{group:<{width}} {_as_percent(agree, n):>7.1f}% "
            f"{_as_percent(cot, n):>7.1f}% {_as_percent(direct, n):>9.1f}% {n:>8}"
        )


def print_analysis_report(all_results: list, title: str = "Experiment Analysis"):
    """Print a formatted analysis report to stdout.

    Classification experiments are listed one per row. Faithfulness
    experiments are aggregated by prompt, then by dataset, then overall.

    Args:
        all_results: Metrics dicts as produced by ``analyse_experiments_dir``.
        title: Heading printed at the top of the report.
    """
    print("=" * 80)
    print(title)
    print("=" * 80)
    print()

    # Separate classification and faithfulness experiments
    classification_results = [
        r for r in all_results if r.get("experiment_type") == "classification"
    ]
    faithfulness_results = [r for r in all_results if r.get("experiment_type") != "classification"]

    # Print classification experiments first
    if classification_results:
        print("CLASSIFICATION EXPERIMENTS")
        print("-" * 40)
        print(f"{'Experiment':<30} {'Acc':>8} {'Prec':>8} {'Recall':>8} {'F1':>8} {'N':>6}")
        print("-" * 70)
        for r in classification_results:
            exp_name = f"{r['dataset']}_{r['prompt']}"
            print(
                f"{exp_name:<30} {100 * r.get('accuracy', 0):>7.1f}% "
                f"{100 * r.get('precision', 0):>7.1f}% {100 * r.get('recall', 0):>7.1f}% "
                f"{r.get('f1', 0):>7.2f} {r['num_samples']:>6}"
            )
        print()

    if not faithfulness_results:
        return

    print("COT FAITHFULNESS EXPERIMENTS")
    print("-" * 40)
    _print_faithfulness_table(faithfulness_results, "prompt", "Prompt", 25, 60)

    print()
    print("=" * 80)
    print("SUMMARY BY DATASET")
    print("=" * 80)
    _print_faithfulness_table(faithfulness_results, "dataset", "Dataset", 20, 55)

    # Overall. Guarded like the tables above so that results carrying zero
    # samples cannot raise ZeroDivisionError.
    print()
    total_samples, total_agree, total_cot, total_direct = _faithfulness_totals(
        faithfulness_results
    )

    print(f"OVERALL (Faithfulness): {total_samples} samples")
    print(f" - Agreement: {_as_percent(total_agree, total_samples):.1f}%")
    print(f" - CoT Accuracy: {_as_percent(total_cot, total_samples):.1f}%")
    print(f" - Direct Accuracy: {_as_percent(total_direct, total_samples):.1f}%")
320
+
321
+
322
def export_to_csv(all_results: list, output_path: Path):
    """Export analysis results to a CSV file.

    One row is written per result dict, preceded by a header row. The path
    of the written file is printed to stdout afterwards.

    Args:
        all_results: Metrics dicts as produced by ``analyse_experiments_dir``.
        output_path: Destination file; overwritten if it already exists.
    """
    import csv

    header = [
        "experiment",
        "dataset",
        "prompt",
        "num_samples",
        "agreement_rate",
        "cot_accuracy",
        "direct_accuracy",
        "cot_correct",
        "direct_correct",
        "agreements",
    ]

    # Explicit UTF-8 so dataset/prompt names with non-ASCII characters are
    # written identically on every platform.
    with open(output_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)

        for r in all_results:
            writer.writerow(
                [
                    r["experiment"],
                    r["dataset"],
                    r["prompt"],
                    r["num_samples"],
                    f"{r['agreement_rate']:.4f}",
                    f"{r['cot_accuracy']:.4f}",
                    f"{r['direct_accuracy']:.4f}",
                    r["correct_cot"],
                    r["correct_direct"],
                    r["agreements"],
                ]
            )

    print(f"\nResults saved to: {output_path}")
363
+
364
+
365
def main():
    """Command-line entry point.

    Reads the results directory from the first command-line argument,
    prints the analysis report and writes ``analysis_results.csv`` into
    that directory. Exits with status 1 when the argument is missing, the
    path is not a directory, or no experiment results are found.
    """
    import sys

    # A results directory is required; there is no machine-independent
    # default location to fall back to.
    if len(sys.argv) < 2:
        print("Usage: python -m cotlab.analyse_experiments <results_dir>")
        sys.exit(1)

    results_dir = Path(sys.argv[1])

    # is_dir() rather than exists(): a regular file would otherwise slip
    # through and fail later when the directory is iterated.
    if not results_dir.is_dir():
        print(f"Error: Directory not found: {results_dir}")
        sys.exit(1)

    all_results = analyse_experiments_dir(results_dir)

    if not all_results:
        print(f"No experiment results found in {results_dir}")
        sys.exit(1)

    print_analysis_report(all_results, f"Analysis: {results_dir.name}")

    # Export to CSV
    csv_path = results_dir / "analysis_results.csv"
    export_to_csv(all_results, csv_path)


if __name__ == "__main__":
    main()
@@ -0,0 +1,11 @@
1
+ """Analysis and metrics module."""
2
+
3
+ from .cot_parser import CoTParser, ReasoningStep
4
+ from .faithfulness_metrics import FaithfulnessMetrics, FaithfulnessScore
5
+
6
+ __all__ = [
7
+ "CoTParser",
8
+ "ReasoningStep",
9
+ "FaithfulnessMetrics",
10
+ "FaithfulnessScore",
11
+ ]
@@ -0,0 +1,243 @@
1
+ """CoT Parser for extracting and analyzing reasoning steps."""
2
+
3
+ import re
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict, List, Optional
6
+
7
+
8
+ @dataclass
9
+ class ReasoningStep:
10
+ """A single step in chain of thought reasoning."""
11
+
12
+ index: int
13
+ text: str
14
+ is_claim: bool = False
15
+ is_conclusion: bool = False
16
+
17
+
18
+ class CoTParser:
19
+ """
20
+ Extract structure from Chain of Thought outputs.
21
+
22
+ Parses model outputs to identify:
23
+ - Numbered reasoning steps
24
+ - Factual claims
25
+ - Hedging/uncertainty language
26
+ - Final conclusions
27
+ """
28
+
29
+ # Patterns for step extraction
30
+ STEP_PATTERNS = [
31
+ r"(?:^|\n)\s*(\d+)[.):]\s*(.+?)(?=\n\s*\d+[.):)]|\n\n|$)", # 1. Step
32
+ r"(?:^|\n)\s*[-•*]\s*(.+?)(?=\n\s*[-•*]|\n\n|$)", # Bullet points
33
+ r"(?:^|\n)\s*(First|Second|Third|Then|Next|Finally)[,:]?\s*(.+?)(?=\n|$)", # Word numbered
34
+ ]
35
+
36
+ # Hedging indicators
37
+ HEDGING_WORDS = [
38
+ "might",
39
+ "could",
40
+ "possibly",
41
+ "perhaps",
42
+ "maybe",
43
+ "uncertain",
44
+ "unsure",
45
+ "likely",
46
+ "probably",
47
+ "appears",
48
+ "seems",
49
+ "suggests",
50
+ "may",
51
+ "I think",
52
+ "I believe",
53
+ "not sure",
54
+ "unclear",
55
+ "would guess",
56
+ ]
57
+
58
+ # Confidence indicators
59
+ CONFIDENCE_WORDS = [
60
+ "definitely",
61
+ "certainly",
62
+ "clearly",
63
+ "obviously",
64
+ "must be",
65
+ "undoubtedly",
66
+ "without doubt",
67
+ "absolutely",
68
+ "100%",
69
+ "confident",
70
+ "sure",
71
+ ]
72
+
73
+ # Conclusion markers
74
+ CONCLUSION_MARKERS = [
75
+ "therefore",
76
+ "thus",
77
+ "so",
78
+ "hence",
79
+ "consequently",
80
+ "in conclusion",
81
+ "final answer",
82
+ "the answer is",
83
+ "this means",
84
+ "we can conclude",
85
+ ]
86
+
87
+ def extract_steps(self, cot_text: str) -> List[ReasoningStep]:
88
+ """
89
+ Parse numbered/bulleted reasoning steps from CoT.
90
+
91
+ Args:
92
+ cot_text: Raw CoT output
93
+
94
+ Returns:
95
+ List of ReasoningStep objects
96
+ """
97
+ steps = []
98
+
99
+ # Try numbered pattern first
100
+ numbered = re.findall(
101
+ r"(?:^|\n)\s*(\d+)[.):]\s*(.+?)(?=\n\s*\d+[.):)]|\n\n|$)", cot_text, re.DOTALL
102
+ )
103
+
104
+ if numbered:
105
+ for idx, (num, text) in enumerate(numbered):
106
+ step = ReasoningStep(
107
+ index=idx, text=text.strip(), is_conclusion=self._is_conclusion(text)
108
+ )
109
+ steps.append(step)
110
+ else:
111
+ # Fall back to sentence-based splitting
112
+ sentences = re.split(r"(?<=[.!?])\s+", cot_text)
113
+ for idx, sent in enumerate(sentences):
114
+ if sent.strip():
115
+ steps.append(
116
+ ReasoningStep(
117
+ index=idx, text=sent.strip(), is_conclusion=self._is_conclusion(sent)
118
+ )
119
+ )
120
+
121
+ return steps
122
+
123
+ def identify_claims(self, cot_text: str) -> List[Dict[str, Any]]:
124
+ """
125
+ Extract factual claims from reasoning.
126
+
127
+ A claim is a statement that asserts something as true.
128
+
129
+ Returns:
130
+ List of dicts with 'text' and 'confidence' keys
131
+ """
132
+ claims = []
133
+ sentences = re.split(r"(?<=[.!?])\s+", cot_text)
134
+
135
+ for sent in sentences:
136
+ sent = sent.strip()
137
+ if not sent:
138
+ continue
139
+
140
+ # Skip questions
141
+ if sent.endswith("?"):
142
+ continue
143
+
144
+ # Check if it's a claim vs procedural text
145
+ is_claim = any(
146
+ [
147
+ " is " in sent.lower(),
148
+ " are " in sent.lower(),
149
+ " has " in sent.lower(),
150
+ " have " in sent.lower(),
151
+ " indicates " in sent.lower(),
152
+ " suggests " in sent.lower(),
153
+ " shows " in sent.lower(),
154
+ ]
155
+ )
156
+
157
+ if is_claim:
158
+ confidence = self._estimate_confidence(sent)
159
+ claims.append(
160
+ {"text": sent, "confidence": confidence, "has_hedging": confidence < 0.5}
161
+ )
162
+
163
+ return claims
164
+
165
+ def detect_hedging(self, cot_text: str) -> float:
166
+ """
167
+ Measure uncertainty expressions in CoT.
168
+
169
+ Returns:
170
+ Score from 0 (no hedging) to 1 (heavy hedging)
171
+ """
172
+ text_lower = cot_text.lower()
173
+
174
+ hedging_count = sum(1 for word in self.HEDGING_WORDS if word.lower() in text_lower)
175
+ confidence_count = sum(1 for word in self.CONFIDENCE_WORDS if word.lower() in text_lower)
176
+
177
+ total = hedging_count + confidence_count
178
+ if total == 0:
179
+ return 0.3 # Neutral default
180
+
181
+ return hedging_count / total
182
+
183
+ def extract_conclusion(self, cot_text: str) -> Optional[str]:
184
+ """
185
+ Extract the final conclusion/answer from CoT.
186
+
187
+ Returns:
188
+ The conclusion text, or None if not found
189
+ """
190
+ text_lower = cot_text.lower()
191
+
192
+ for marker in self.CONCLUSION_MARKERS:
193
+ pattern = rf"{marker}\s*[,:]?\s*(.+?)(?:\.|$)"
194
+ match = re.search(pattern, text_lower, re.IGNORECASE)
195
+ if match:
196
+ # Find the actual text in original case
197
+ start = match.start(1)
198
+ end = match.end(1)
199
+ return cot_text[start:end].strip()
200
+
201
+ # Fall back to last sentence
202
+ sentences = re.split(r"(?<=[.!?])\s+", cot_text)
203
+ if sentences:
204
+ return sentences[-1].strip()
205
+
206
+ return None
207
+
208
+ def analyze(self, cot_text: str) -> Dict[str, Any]:
209
+ """
210
+ Full analysis of a CoT output.
211
+
212
+ Returns:
213
+ Dict with steps, claims, hedging score, and conclusion
214
+ """
215
+ return {
216
+ "steps": self.extract_steps(cot_text),
217
+ "claims": self.identify_claims(cot_text),
218
+ "hedging_score": self.detect_hedging(cot_text),
219
+ "conclusion": self.extract_conclusion(cot_text),
220
+ "num_steps": len(self.extract_steps(cot_text)),
221
+ "word_count": len(cot_text.split()),
222
+ }
223
+
224
+ def _is_conclusion(self, text: str) -> bool:
225
+ """Check if text is a conclusion."""
226
+ text_lower = text.lower()
227
+ return any(marker in text_lower for marker in self.CONCLUSION_MARKERS)
228
+
229
+ def _estimate_confidence(self, text: str) -> float:
230
+ """Estimate confidence level of a claim."""
231
+ text_lower = text.lower()
232
+
233
+ has_hedging = any(w in text_lower for w in self.HEDGING_WORDS)
234
+ has_confidence = any(w in text_lower for w in self.CONFIDENCE_WORDS)
235
+
236
+ if has_hedging and not has_confidence:
237
+ return 0.3
238
+ elif has_confidence and not has_hedging:
239
+ return 0.9
240
+ elif has_hedging and has_confidence:
241
+ return 0.5
242
+ else:
243
+ return 0.6 # Neutral