cotlab 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65):
  1. cotlab/__init__.py +3 -0
  2. cotlab/analyse_experiments.py +392 -0
  3. cotlab/analysis/__init__.py +11 -0
  4. cotlab/analysis/cot_parser.py +243 -0
  5. cotlab/analysis/faithfulness_metrics.py +192 -0
  6. cotlab/backends/__init__.py +16 -0
  7. cotlab/backends/base.py +78 -0
  8. cotlab/backends/transformers_backend.py +335 -0
  9. cotlab/backends/vllm_backend.py +227 -0
  10. cotlab/cli.py +83 -0
  11. cotlab/core/__init__.py +34 -0
  12. cotlab/core/base.py +749 -0
  13. cotlab/core/config.py +90 -0
  14. cotlab/core/registry.py +68 -0
  15. cotlab/datasets/__init__.py +45 -0
  16. cotlab/datasets/loaders.py +1889 -0
  17. cotlab/experiment/__init__.py +315 -0
  18. cotlab/experiments/__init__.py +43 -0
  19. cotlab/experiments/activation_compare.py +290 -0
  20. cotlab/experiments/activation_patching.py +1050 -0
  21. cotlab/experiments/attention_analysis.py +885 -0
  22. cotlab/experiments/classification.py +235 -0
  23. cotlab/experiments/composite_shift_detector.py +524 -0
  24. cotlab/experiments/cot_ablation.py +277 -0
  25. cotlab/experiments/cot_faithfulness.py +187 -0
  26. cotlab/experiments/cot_heads.py +208 -0
  27. cotlab/experiments/full_layer_cot.py +232 -0
  28. cotlab/experiments/full_layer_patching.py +225 -0
  29. cotlab/experiments/h_neuron_analysis.py +712 -0
  30. cotlab/experiments/logit_lens.py +439 -0
  31. cotlab/experiments/multi_head_cot.py +220 -0
  32. cotlab/experiments/multi_head_patching.py +229 -0
  33. cotlab/experiments/probing_classifier.py +402 -0
  34. cotlab/experiments/residual_norm_ood.py +413 -0
  35. cotlab/experiments/sae_feature_analysis.py +673 -0
  36. cotlab/experiments/steering_vectors.py +223 -0
  37. cotlab/experiments/sycophancy_heads.py +224 -0
  38. cotlab/logging/__init__.py +5 -0
  39. cotlab/logging/json_logger.py +161 -0
  40. cotlab/main.py +317 -0
  41. cotlab/patching/__init__.py +24 -0
  42. cotlab/patching/cache.py +141 -0
  43. cotlab/patching/hooks.py +558 -0
  44. cotlab/patching/interventions.py +86 -0
  45. cotlab/patching/patcher.py +439 -0
  46. cotlab/patching/sae.py +181 -0
  47. cotlab/prompts/__init__.py +43 -0
  48. cotlab/prompts/cardiology.py +378 -0
  49. cotlab/prompts/histopathology.py +265 -0
  50. cotlab/prompts/length_matched_strategies.py +157 -0
  51. cotlab/prompts/mcq.py +193 -0
  52. cotlab/prompts/neurology.py +353 -0
  53. cotlab/prompts/oncology.py +367 -0
  54. cotlab/prompts/plab.py +162 -0
  55. cotlab/prompts/pubhealthbench.py +82 -0
  56. cotlab/prompts/pubmedqa.py +173 -0
  57. cotlab/prompts/radiology.py +414 -0
  58. cotlab/prompts/strategies.py +939 -0
  59. cotlab/prompts/tcga.py +168 -0
  60. cotlab/runner.py +204 -0
  61. cotlab-0.8.0.dist-info/METADATA +166 -0
  62. cotlab-0.8.0.dist-info/RECORD +65 -0
  63. cotlab-0.8.0.dist-info/WHEEL +4 -0
  64. cotlab-0.8.0.dist-info/entry_points.txt +3 -0
  65. cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,235 @@
1
+ """Generic classification experiment for medical reports (binary and multi-class)."""
2
+
3
+ from collections import Counter
4
+ from typing import Optional
5
+
6
+ from tqdm import tqdm
7
+
8
+ from ..backends.base import InferenceBackend
9
+ from ..core.base import BaseExperiment, BasePromptStrategy, ExperimentResult
10
+ from ..core.registry import Registry
11
+ from ..datasets.loaders import BaseDataset
12
+ from ..logging import ExperimentLogger
13
+
14
+
15
@Registry.register_experiment("classification")
class ClassificationExperiment(BaseExperiment):
    """
    Generic classification experiment for medical reports.

    Supports both binary classification (True/False labels) and
    multi-class classification (string labels like cancer types).
    Answer extraction is delegated to the prompt strategy's
    ``parse_response``; the field holding the prediction is taken from the
    strategy's ``get_prediction_field`` when it defines one.

    A run is scored as multi-class when the first scored ground-truth label
    is a string; otherwise truthiness-based binary metrics are computed.
    """

    def __init__(
        self,
        name: str = "classification",
        description: str = "Classification from medical reports",
        num_samples: int = -1,  # Default to -1 (all samples)
        **kwargs,
    ):
        """Store the experiment's name, description and default sample count.

        Args:
            name: Identifier reported in the ``ExperimentResult``.
            description: Free-text description stored in the result metadata.
            num_samples: Default number of samples for ``run``; any value
                ``<= 0`` means "use the whole dataset".
            **kwargs: Accepted for configuration compatibility and ignored.
        """
        # NOTE(review): BaseExperiment.__init__ is not invoked here; confirm
        # the base class needs no initialisation of its own.
        self._name = name
        self.description = description
        self.num_samples = num_samples

    @property
    def name(self) -> str:
        """Name of this experiment, as passed to the constructor."""
        return self._name

    def _compute_multiclass_metrics(self, y_true: list, y_pred: list, labels: list) -> dict:
        """Compute multi-class classification metrics.

        Args:
            y_true: Ground-truth labels, aligned with ``y_pred``.
            y_pred: Predicted labels.
            labels: Label set (and ordering) used for every sklearn call and
                for the rows/columns of the confusion matrix.

        Returns:
            Dict with the sklearn classification report, macro/weighted
            averages, the confusion matrix as nested lists, the ten most
            frequent off-diagonal (true, predicted, count) pairs and the
            true/predicted class distributions.
        """
        # Imported lazily so scikit-learn is only required when multi-class
        # metrics are actually computed.
        from sklearn.metrics import (
            classification_report,
            confusion_matrix,
            f1_score,
            precision_score,
            recall_score,
        )

        metrics = {}

        # Per-class metrics via classification_report
        metrics["classification_report"] = classification_report(
            y_true, y_pred, labels=labels, output_dict=True, zero_division=0
        )

        # Macro and weighted averages
        metrics["macro_precision"] = precision_score(
            y_true, y_pred, labels=labels, average="macro", zero_division=0
        )
        metrics["macro_recall"] = recall_score(
            y_true, y_pred, labels=labels, average="macro", zero_division=0
        )
        metrics["macro_f1"] = f1_score(
            y_true, y_pred, labels=labels, average="macro", zero_division=0
        )
        metrics["weighted_f1"] = f1_score(
            y_true, y_pred, labels=labels, average="weighted", zero_division=0
        )

        # Confusion matrix (rows: true label, columns: predicted label)
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        metrics["confusion_matrix"] = cm.tolist()
        metrics["class_labels"] = labels

        # Off-diagonal cells are misclassifications; keep the most frequent.
        confused_pairs = [
            (true_label, pred_label, int(cm[i][j]))
            for i, true_label in enumerate(labels)
            for j, pred_label in enumerate(labels)
            if i != j and cm[i][j] > 0
        ]
        confused_pairs.sort(key=lambda pair: pair[2], reverse=True)
        metrics["top_confused_pairs"] = confused_pairs[:10]  # Top 10

        # Class distribution
        metrics["true_class_distribution"] = dict(Counter(y_true))
        metrics["pred_class_distribution"] = dict(Counter(y_pred))

        return metrics

    @staticmethod
    def _compute_binary_metrics(y_true: list, y_pred: list) -> dict:
        """Compute confusion counts, precision, recall and F1 by truthiness.

        Labels and predictions are interpreted through their truth value, so
        any truthy value counts as the positive class.
        """
        tp = tn = fp = fn = 0
        for truth, pred in zip(y_true, y_pred):
            if truth and pred:
                tp += 1
            elif truth:
                fn += 1
            elif pred:
                fp += 1
            else:
                tn += 1

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        return {
            "true_positives": tp,
            "true_negatives": tn,
            "false_positives": fp,
            "false_negatives": fn,
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }

    @staticmethod
    def _resolve_system_prompt(prompt_strategy):
        """Return the strategy's system prompt, or None if it provides none.

        ``get_system_message`` is tried first; ``get_system_prompt`` is used
        only when the former is absent or returns None.
        """
        for accessor_name in ("get_system_message", "get_system_prompt"):
            accessor = getattr(prompt_strategy, accessor_name, None)
            if callable(accessor):
                system_prompt = accessor()
                if system_prompt is not None:
                    return system_prompt
        return None

    def run(
        self,
        backend: InferenceBackend,
        dataset: BaseDataset,
        prompt_strategy: BasePromptStrategy,
        num_samples: Optional[int] = None,
        logger: Optional[ExperimentLogger] = None,
        **kwargs,
    ) -> ExperimentResult:
        """Run the classification experiment.

        Args:
            backend: Inference backend; its ``generate_batch`` produces one
                output per prompt.
            dataset: Dataset to classify. It is sub-sampled with
                ``dataset.sample`` only when the requested count is positive
                and smaller than the dataset.
            prompt_strategy: Builds prompts and parses model responses.
            num_samples: Per-call override of the constructor's
                ``num_samples``; ``None`` keeps the constructor value.
            logger: Optional logger that receives every per-sample record.
            **kwargs: Forwarded unchanged to ``backend.generate_batch``.

        Returns:
            ExperimentResult holding the aggregate metrics and the
            per-sample records.
        """
        n_samples = num_samples if num_samples is not None else self.num_samples

        if 0 < n_samples < len(dataset):
            samples = dataset.sample(n_samples)
        else:
            samples = list(dataset)

        results = []
        metrics = {
            "correct": 0,
            "incorrect": 0,
            "parse_errors": 0,
            # Responses that parsed without error but carried no value for the
            # prediction field. They are excluded from accuracy, so they are
            # counted here to keep correct + incorrect + parse_errors +
            # missing_predictions equal to the number of samples.
            "missing_predictions": 0,
        }

        # Aligned label lists used for the aggregate metrics
        y_true = []
        y_pred = []

        # Field of the parsed response that holds the prediction; falls back
        # to "pathological_fracture" when the strategy does not declare one.
        prediction_field = getattr(
            prompt_strategy, "get_prediction_field", lambda: "pathological_fracture"
        )()
        print(f"Running Classification Experiment on {len(samples)} samples...")
        print(f"  Prediction field: {prediction_field}")

        # Prepare inputs (the report text is exposed under both keys)
        inputs = [{"text": s.text, "report": s.text, "metadata": s.metadata} for s in samples]

        # Batch Generate
        print("Generating responses...")
        prompts = [prompt_strategy.build_prompt(item) for item in inputs]
        system_prompt = self._resolve_system_prompt(prompt_strategy)

        outputs = backend.generate_batch(prompts, system_prompt=system_prompt, **kwargs)

        # Process results
        print("Analyzing results...")
        for i, sample in enumerate(tqdm(samples, desc="Analyzing reports")):
            output = outputs[i]
            prompt = prompts[i]

            # Parse response
            parsed = prompt_strategy.parse_response(output.text)

            # Extract prediction
            if parsed.get("parse_error"):
                metrics["parse_errors"] += 1
                predicted = None
            else:
                predicted = parsed.get(prediction_field, None)
                if predicted is None:
                    metrics["missing_predictions"] += 1

            # Ground truth
            ground_truth = sample.label

            # Only samples with a usable prediction are scored.
            if predicted is None:
                is_correct = None
            else:
                is_correct = predicted == ground_truth
                y_true.append(ground_truth)
                y_pred.append(predicted)
                if is_correct:
                    metrics["correct"] += 1
                else:
                    metrics["incorrect"] += 1

            result = {
                "sample_idx": sample.idx,
                # Long reports are truncated to keep the stored record small.
                "input": sample.text[:500] + "..." if len(sample.text) > 500 else sample.text,
                "prompt": prompt,
                "system_prompt": system_prompt,
                "response": output.text,
                "predicted": predicted,
                "ground_truth": ground_truth,
                "correct": is_correct,
                "reasoning": parsed.get("reasoning", ""),
            }
            results.append(result)

            if logger:
                logger.log_sample(sample.idx, result)

        # Calculate final metrics
        n = len(samples)
        total_valid = metrics["correct"] + metrics["incorrect"]

        # Accuracy is over scored samples only; unusable responses are
        # reported separately.
        metrics["accuracy"] = metrics["correct"] / total_valid if total_valid > 0 else 0
        metrics["parse_error_rate"] = metrics["parse_errors"] / n if n > 0 else 0

        # Determine if multi-class or binary from the first scored label
        is_multiclass = len(y_true) > 0 and isinstance(y_true[0], str)

        if is_multiclass:
            # The parsed prediction may be a non-string value (number, bool,
            # list). Mixed or unhashable labels cannot be put in a set or
            # sorted, so labels are normalised to strings for the aggregate
            # metrics only; this is a no-op when every label is already a
            # string. Per-sample records keep the raw values.
            true_labels = [str(label) for label in y_true]
            pred_labels = [str(label) for label in y_pred]
            all_labels = sorted(set(true_labels) | set(pred_labels))
            metrics.update(
                self._compute_multiclass_metrics(true_labels, pred_labels, all_labels)
            )
            metrics["num_classes"] = len(all_labels)
        else:
            # Binary metrics (legacy support)
            metrics.update(self._compute_binary_metrics(y_true, y_pred))

        return ExperimentResult(
            experiment_name=self.name,
            model_name=backend.model_name or "unknown",
            prompt_strategy=prompt_strategy.name,
            metrics=metrics,
            raw_outputs=results,
            metadata={"num_samples": n, "description": self.description},
        )