cotlab 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotlab/__init__.py +3 -0
- cotlab/analyse_experiments.py +392 -0
- cotlab/analysis/__init__.py +11 -0
- cotlab/analysis/cot_parser.py +243 -0
- cotlab/analysis/faithfulness_metrics.py +192 -0
- cotlab/backends/__init__.py +16 -0
- cotlab/backends/base.py +78 -0
- cotlab/backends/transformers_backend.py +335 -0
- cotlab/backends/vllm_backend.py +227 -0
- cotlab/cli.py +83 -0
- cotlab/core/__init__.py +34 -0
- cotlab/core/base.py +749 -0
- cotlab/core/config.py +90 -0
- cotlab/core/registry.py +68 -0
- cotlab/datasets/__init__.py +45 -0
- cotlab/datasets/loaders.py +1889 -0
- cotlab/experiment/__init__.py +315 -0
- cotlab/experiments/__init__.py +43 -0
- cotlab/experiments/activation_compare.py +290 -0
- cotlab/experiments/activation_patching.py +1050 -0
- cotlab/experiments/attention_analysis.py +885 -0
- cotlab/experiments/classification.py +235 -0
- cotlab/experiments/composite_shift_detector.py +524 -0
- cotlab/experiments/cot_ablation.py +277 -0
- cotlab/experiments/cot_faithfulness.py +187 -0
- cotlab/experiments/cot_heads.py +208 -0
- cotlab/experiments/full_layer_cot.py +232 -0
- cotlab/experiments/full_layer_patching.py +225 -0
- cotlab/experiments/h_neuron_analysis.py +712 -0
- cotlab/experiments/logit_lens.py +439 -0
- cotlab/experiments/multi_head_cot.py +220 -0
- cotlab/experiments/multi_head_patching.py +229 -0
- cotlab/experiments/probing_classifier.py +402 -0
- cotlab/experiments/residual_norm_ood.py +413 -0
- cotlab/experiments/sae_feature_analysis.py +673 -0
- cotlab/experiments/steering_vectors.py +223 -0
- cotlab/experiments/sycophancy_heads.py +224 -0
- cotlab/logging/__init__.py +5 -0
- cotlab/logging/json_logger.py +161 -0
- cotlab/main.py +317 -0
- cotlab/patching/__init__.py +24 -0
- cotlab/patching/cache.py +141 -0
- cotlab/patching/hooks.py +558 -0
- cotlab/patching/interventions.py +86 -0
- cotlab/patching/patcher.py +439 -0
- cotlab/patching/sae.py +181 -0
- cotlab/prompts/__init__.py +43 -0
- cotlab/prompts/cardiology.py +378 -0
- cotlab/prompts/histopathology.py +265 -0
- cotlab/prompts/length_matched_strategies.py +157 -0
- cotlab/prompts/mcq.py +193 -0
- cotlab/prompts/neurology.py +353 -0
- cotlab/prompts/oncology.py +367 -0
- cotlab/prompts/plab.py +162 -0
- cotlab/prompts/pubhealthbench.py +82 -0
- cotlab/prompts/pubmedqa.py +173 -0
- cotlab/prompts/radiology.py +414 -0
- cotlab/prompts/strategies.py +939 -0
- cotlab/prompts/tcga.py +168 -0
- cotlab/runner.py +204 -0
- cotlab-0.8.0.dist-info/METADATA +166 -0
- cotlab-0.8.0.dist-info/RECORD +65 -0
- cotlab-0.8.0.dist-info/WHEEL +4 -0
- cotlab-0.8.0.dist-info/entry_points.txt +3 -0
- cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
"""Generic classification experiment for medical reports (binary and multi-class)."""
|
|
2
|
+
|
|
3
|
+
from collections import Counter
|
|
4
|
+
from typing import Optional
|
|
5
|
+
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
|
|
8
|
+
from ..backends.base import InferenceBackend
|
|
9
|
+
from ..core.base import BaseExperiment, BasePromptStrategy, ExperimentResult
|
|
10
|
+
from ..core.registry import Registry
|
|
11
|
+
from ..datasets.loaders import BaseDataset
|
|
12
|
+
from ..logging import ExperimentLogger
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@Registry.register_experiment("classification")
class ClassificationExperiment(BaseExperiment):
    """
    Generic classification experiment for medical reports.

    Supports both binary classification (True/False labels) and
    multi-class classification (string labels like cancer types).
    Uses structured JSON output for reliable answer extraction.
    """

    # Field read from the parsed response when the prompt strategy does not
    # provide ``get_prediction_field()``. Kept as the historical default.
    default_prediction_field = "pathological_fracture"

    def __init__(
        self,
        name: str = "classification",
        description: str = "Classification from medical reports",
        num_samples: int = -1,  # Default to -1 (all samples)
        **kwargs,
    ):
        """
        Args:
            name: Experiment name reported in results.
            description: Free-text description stored in result metadata.
            num_samples: Number of samples to evaluate. Non-positive values
                (or values >= the dataset size) mean "use every sample".
            **kwargs: Extra configuration keys; accepted and ignored here.
        """
        self._name = name
        self.description = description
        self.num_samples = num_samples

    @property
    def name(self) -> str:
        return self._name

    def _compute_multiclass_metrics(self, y_true: list, y_pred: list, labels: list) -> dict:
        """Compute multi-class classification metrics.

        Args:
            y_true: Ground-truth labels, one per evaluated sample.
            y_pred: Predicted labels, aligned with ``y_true``.
            labels: Ordered label set; defines the row/column order of the
                confusion matrix.

        Returns:
            Dict with the sklearn classification report, macro
            precision/recall/F1, weighted F1, the confusion matrix (as nested
            lists), the label order, the 10 most frequent off-diagonal
            (true, predicted, count) pairs, and true/predicted class counts.
        """
        from sklearn.metrics import (
            classification_report,
            confusion_matrix,
            f1_score,
            precision_score,
            recall_score,
        )

        metrics = {}

        # Per-class metrics via classification_report
        metrics["classification_report"] = classification_report(
            y_true, y_pred, labels=labels, output_dict=True, zero_division=0
        )

        # Macro and weighted averages
        metrics["macro_precision"] = precision_score(
            y_true, y_pred, labels=labels, average="macro", zero_division=0
        )
        metrics["macro_recall"] = recall_score(
            y_true, y_pred, labels=labels, average="macro", zero_division=0
        )
        metrics["macro_f1"] = f1_score(
            y_true, y_pred, labels=labels, average="macro", zero_division=0
        )
        metrics["weighted_f1"] = f1_score(
            y_true, y_pred, labels=labels, average="weighted", zero_division=0
        )

        # Confusion matrix: rows are true labels, columns are predicted labels,
        # both in the order given by ``labels``.
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        metrics["confusion_matrix"] = cm.tolist()
        metrics["class_labels"] = labels

        # Off-diagonal cells, most frequent confusions first.
        confused_pairs = [
            (true_label, pred_label, int(cm[i][j]))
            for i, true_label in enumerate(labels)
            for j, pred_label in enumerate(labels)
            if i != j and cm[i][j] > 0
        ]
        confused_pairs.sort(key=lambda pair: pair[2], reverse=True)
        metrics["top_confused_pairs"] = confused_pairs[:10]  # Top 10

        # Class distribution
        metrics["true_class_distribution"] = dict(Counter(y_true))
        metrics["pred_class_distribution"] = dict(Counter(y_pred))

        return metrics

    @staticmethod
    def _compute_binary_metrics(y_true: list, y_pred: list) -> dict:
        """Compute confusion counts and precision/recall/F1 for binary labels.

        Labels and predictions are interpreted by truthiness. Ratios with a
        zero denominator are reported as 0.
        """
        pairs = list(zip(y_true, y_pred))
        tp = sum(1 for yt, yp in pairs if yt and yp)
        tn = sum(1 for yt, yp in pairs if not yt and not yp)
        fp = sum(1 for yt, yp in pairs if not yt and yp)
        fn = sum(1 for yt, yp in pairs if yt and not yp)

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = (
            2 * (precision * recall) / (precision + recall)
            if (precision + recall) > 0
            else 0
        )
        return {
            "true_positives": tp,
            "true_negatives": tn,
            "false_positives": fp,
            "false_negatives": fn,
            "precision": precision,
            "recall": recall,
            "f1": f1,
        }

    @staticmethod
    def _resolve_system_prompt(prompt_strategy):
        """Return the strategy's system prompt, or None if it provides none.

        ``get_system_message()`` takes precedence. ``get_system_prompt()`` is
        consulted only when the former is absent or returns None.
        """
        for attr in ("get_system_message", "get_system_prompt"):
            getter = getattr(prompt_strategy, attr, None)
            if callable(getter):
                system_prompt = getter()
                if system_prompt is not None:
                    return system_prompt
        return None

    def run(
        self,
        backend: InferenceBackend,
        dataset: BaseDataset,
        prompt_strategy: BasePromptStrategy,
        num_samples: Optional[int] = None,
        logger: Optional[ExperimentLogger] = None,
        **kwargs,
    ) -> ExperimentResult:
        """Run the classification experiment.

        Builds one prompt per sample, generates every response with a single
        ``backend.generate_batch`` call, parses each response with the prompt
        strategy and compares the extracted prediction with ``sample.label``.

        Args:
            backend: Inference backend used for batched generation.
            dataset: Dataset whose samples expose ``text``, ``label``, ``idx``
                and ``metadata``.
            prompt_strategy: Builds prompts and parses responses. It may
                optionally provide ``get_prediction_field()``,
                ``get_system_message()`` or ``get_system_prompt()``.
            num_samples: Overrides ``self.num_samples`` when not None.
            logger: Optional logger; each per-sample record is passed to
                ``logger.log_sample``.
            **kwargs: Forwarded unchanged to ``backend.generate_batch``.

        Returns:
            ExperimentResult. ``metrics`` contains counts (``correct``,
            ``incorrect``, ``parse_errors``, ``missing_predictions``),
            ``accuracy`` over samples with a usable prediction,
            ``parse_error_rate`` over all samples, and either multi-class
            metrics (string labels) or binary metrics.

        Raises:
            ValueError: If the backend returns a different number of outputs
                than prompts, so outputs cannot be paired with samples.
        """
        n_samples = num_samples if num_samples is not None else self.num_samples

        # None, non-positive, or >= dataset size all mean "use every sample".
        if n_samples is not None and 0 < n_samples < len(dataset):
            samples = dataset.sample(n_samples)
        else:
            samples = list(dataset)

        results = []
        metrics = {
            "correct": 0,
            "incorrect": 0,
            "parse_errors": 0,
            # Parsed successfully but the prediction field was absent/null.
            "missing_predictions": 0,
        }

        # Aligned (ground truth, prediction) pairs for samples with a prediction.
        y_true = []
        y_pred = []

        # Get prediction field from prompt strategy
        prediction_field = getattr(
            prompt_strategy,
            "get_prediction_field",
            lambda: self.default_prediction_field,
        )()
        print(f"Running Classification Experiment on {len(samples)} samples...")
        print(f" Prediction field: {prediction_field}")

        # Prepare inputs
        inputs = [{"text": s.text, "report": s.text, "metadata": s.metadata} for s in samples]

        # Batch Generate
        print("Generating responses...")
        prompts = [prompt_strategy.build_prompt(item) for item in inputs]
        system_prompt = self._resolve_system_prompt(prompt_strategy)

        outputs = backend.generate_batch(prompts, system_prompt=system_prompt, **kwargs)
        # Outputs are paired with samples by position; a length mismatch would
        # silently misattribute responses (or fail with a bare IndexError).
        if len(outputs) != len(prompts):
            raise ValueError(
                f"Backend returned {len(outputs)} outputs for {len(prompts)} prompts"
            )

        # Process results
        print("Analyzing results...")
        for sample, prompt, output in tqdm(
            zip(samples, prompts, outputs), total=len(samples), desc="Analyzing reports"
        ):
            parsed = prompt_strategy.parse_response(output.text)

            # Extract prediction
            if parsed.get("parse_error"):
                metrics["parse_errors"] += 1
                predicted = None
            else:
                predicted = parsed.get(prediction_field, None)
                if predicted is None:
                    metrics["missing_predictions"] += 1

            ground_truth = sample.label
            is_correct = predicted == ground_truth if predicted is not None else None

            if predicted is not None:
                y_true.append(ground_truth)
                y_pred.append(predicted)
                if is_correct:
                    metrics["correct"] += 1
                else:
                    metrics["incorrect"] += 1

            result = {
                "sample_idx": sample.idx,
                "input": (sample.text[:500] + "...") if len(sample.text) > 500 else sample.text,
                "prompt": prompt,
                "system_prompt": system_prompt,
                "response": output.text,
                "predicted": predicted,
                "ground_truth": ground_truth,
                "correct": is_correct,
                "reasoning": parsed.get("reasoning", ""),
            }
            results.append(result)

            if logger:
                logger.log_sample(sample.idx, result)

        # Calculate final metrics
        n = len(samples)
        total_valid = metrics["correct"] + metrics["incorrect"]

        metrics["accuracy"] = metrics["correct"] / total_valid if total_valid > 0 else 0
        metrics["parse_error_rate"] = metrics["parse_errors"] / n if n > 0 else 0

        # Label type of the first ground truth decides multi-class vs binary.
        is_multiclass = len(y_true) > 0 and isinstance(y_true[0], str)

        if is_multiclass:
            # A model may emit non-string values (bool/int/dict) for a string
            # label field; mixing types makes sorting and sklearn fail after
            # all generation work is done, so compare labels as strings.
            true_labels = [y if isinstance(y, str) else str(y) for y in y_true]
            pred_labels = [y if isinstance(y, str) else str(y) for y in y_pred]
            all_labels = sorted(set(true_labels) | set(pred_labels))
            metrics.update(
                self._compute_multiclass_metrics(true_labels, pred_labels, all_labels)
            )
            metrics["num_classes"] = len(all_labels)
        else:
            # Binary metrics (legacy support)
            metrics.update(self._compute_binary_metrics(y_true, y_pred))

        return ExperimentResult(
            experiment_name=self.name,
            model_name=backend.model_name or "unknown",
            prompt_strategy=prompt_strategy.name,
            metrics=metrics,
            raw_outputs=results,
            metadata={"num_samples": n, "description": self.description},
        )
|