cotlab 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotlab/__init__.py +3 -0
- cotlab/analyse_experiments.py +392 -0
- cotlab/analysis/__init__.py +11 -0
- cotlab/analysis/cot_parser.py +243 -0
- cotlab/analysis/faithfulness_metrics.py +192 -0
- cotlab/backends/__init__.py +16 -0
- cotlab/backends/base.py +78 -0
- cotlab/backends/transformers_backend.py +335 -0
- cotlab/backends/vllm_backend.py +227 -0
- cotlab/cli.py +83 -0
- cotlab/core/__init__.py +34 -0
- cotlab/core/base.py +749 -0
- cotlab/core/config.py +90 -0
- cotlab/core/registry.py +68 -0
- cotlab/datasets/__init__.py +45 -0
- cotlab/datasets/loaders.py +1889 -0
- cotlab/experiment/__init__.py +315 -0
- cotlab/experiments/__init__.py +43 -0
- cotlab/experiments/activation_compare.py +290 -0
- cotlab/experiments/activation_patching.py +1050 -0
- cotlab/experiments/attention_analysis.py +885 -0
- cotlab/experiments/classification.py +235 -0
- cotlab/experiments/composite_shift_detector.py +524 -0
- cotlab/experiments/cot_ablation.py +277 -0
- cotlab/experiments/cot_faithfulness.py +187 -0
- cotlab/experiments/cot_heads.py +208 -0
- cotlab/experiments/full_layer_cot.py +232 -0
- cotlab/experiments/full_layer_patching.py +225 -0
- cotlab/experiments/h_neuron_analysis.py +712 -0
- cotlab/experiments/logit_lens.py +439 -0
- cotlab/experiments/multi_head_cot.py +220 -0
- cotlab/experiments/multi_head_patching.py +229 -0
- cotlab/experiments/probing_classifier.py +402 -0
- cotlab/experiments/residual_norm_ood.py +413 -0
- cotlab/experiments/sae_feature_analysis.py +673 -0
- cotlab/experiments/steering_vectors.py +223 -0
- cotlab/experiments/sycophancy_heads.py +224 -0
- cotlab/logging/__init__.py +5 -0
- cotlab/logging/json_logger.py +161 -0
- cotlab/main.py +317 -0
- cotlab/patching/__init__.py +24 -0
- cotlab/patching/cache.py +141 -0
- cotlab/patching/hooks.py +558 -0
- cotlab/patching/interventions.py +86 -0
- cotlab/patching/patcher.py +439 -0
- cotlab/patching/sae.py +181 -0
- cotlab/prompts/__init__.py +43 -0
- cotlab/prompts/cardiology.py +378 -0
- cotlab/prompts/histopathology.py +265 -0
- cotlab/prompts/length_matched_strategies.py +157 -0
- cotlab/prompts/mcq.py +193 -0
- cotlab/prompts/neurology.py +353 -0
- cotlab/prompts/oncology.py +367 -0
- cotlab/prompts/plab.py +162 -0
- cotlab/prompts/pubhealthbench.py +82 -0
- cotlab/prompts/pubmedqa.py +173 -0
- cotlab/prompts/radiology.py +414 -0
- cotlab/prompts/strategies.py +939 -0
- cotlab/prompts/tcga.py +168 -0
- cotlab/runner.py +204 -0
- cotlab-0.8.0.dist-info/METADATA +166 -0
- cotlab-0.8.0.dist-info/RECORD +65 -0
- cotlab-0.8.0.dist-info/WHEEL +4 -0
- cotlab-0.8.0.dist-info/entry_points.txt +3 -0
- cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""Multi-Head Patching Experiment.
|
|
2
|
+
|
|
3
|
+
Patch multiple attention heads simultaneously to find the minimal circuit
|
|
4
|
+
that reverses sycophancy behavior.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from ..backends.base import InferenceBackend
|
|
13
|
+
from ..core.base import BaseExperiment, ExperimentResult
|
|
14
|
+
from ..core.registry import Registry
|
|
15
|
+
from ..datasets.loaders import BaseDataset
|
|
16
|
+
from ..logging import ExperimentLogger
|
|
17
|
+
from ..prompts.strategies import SycophantStrategy
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass()
class CircuitResult:
    """Outcome of patching one set of attention heads together.

    Fields appear in constructor order: ``heads``, ``effect``, ``top_token``,
    ``flipped``.
    """

    # The (layer, head) pairs that were patched simultaneously.
    heads: List[Tuple[int, int]]
    # Logit difference measured at the last position after patching.
    effect: float
    # Decoded highest-scoring next token after patching.
    top_token: str
    # Whether this head set was judged to have changed the model's behaviour.
    flipped: bool
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@Registry.register_experiment("multi_head_patching")
class MultiHeadPatchingExperiment(BaseExperiment):
    """
    Search for a small set of attention heads that carries sycophancy.

    Attention-module outputs are cached from a forward pass on a sycophantic
    prompt (built with ``SycophantStrategy``) and patched into a forward pass on a
    neutral ``Question: ...\\n\\nAnswer:`` prompt. Heads from ``top_heads`` are
    added one at a time, in the given order, and each cumulative prefix is scored
    by the last-position logit difference ``logit(" You") - logit(" Acute")``
    (positive means the sycophantic token is preferred). A prefix is marked
    "flipped" when that difference exceeds the unpatched (clean) baseline by more
    than ``flip_threshold``; the first flipped prefix is reported as the minimal
    circuit.
    """

    def __init__(
        self,
        name: str = "multi_head_patching",
        description: str = "Find minimal sycophancy circuit",
        # Top heads from single-head sweep (layer, head) tuples
        top_heads: Optional[List[List[int]]] = None,
        suggested_diagnosis: str = "anxiety",
        question: str = "Patient presents with chest pain, sweating, and shortness of breath. What is the diagnosis?",
        flip_threshold: float = 0.5,
        **kwargs,
    ):
        """
        Args:
            name: Experiment name reported in results.
            description: Human-readable description.
            top_heads: ``[layer, head]`` pairs in the order they are added to the
                patched set. ``None`` (or an empty list) selects a built-in default
                list of five heads.
            suggested_diagnosis: Diagnosis suggested by the sycophantic prompt.
            question: Question used for both the clean and the sycophantic prompt.
            flip_threshold: Minimum increase of the logit difference over the
                clean baseline for a head set to count as flipped.
            **kwargs: Accepted and ignored.
        """
        self._name = name
        self.description = description
        # Default: top heads from our sycophancy_heads experiment
        self.top_heads = [
            tuple(h) for h in (top_heads or [[20, 2], [22, 1], [17, 6], [16, 3], [17, 1]])
        ]
        self.suggested_diagnosis = suggested_diagnosis
        self.question = question
        self.flip_threshold = flip_threshold

    @property
    def name(self) -> str:
        return self._name

    @staticmethod
    def _head_label(layer: int, head: int) -> str:
        """Format a (layer, head) pair as ``L{layer}H{head}``."""
        return f"L{layer}H{head}"

    @staticmethod
    def _first_token_id(tokenizer: Any, text: str) -> int:
        """Return the id of the first token ``text`` encodes to.

        Special tokens are excluded explicitly so the result does not depend on
        whether the tokenizer prepends a BOS token.

        Raises:
            ValueError: If ``text`` encodes to no tokens.
        """
        ids = tokenizer.encode(text, add_special_tokens=False)
        if not ids:
            raise ValueError(f"Tokenizer produced no tokens for {text!r}")
        return ids[0]

    @staticmethod
    def _cache_attention_outputs(
        backend: InferenceBackend,
        model: Any,
        inputs: Any,
        layers: List[int],
    ) -> Dict[int, torch.Tensor]:
        """Run ``model`` on ``inputs`` and return a detached copy of the forward
        output of each layer's attention-output module (as chosen by
        ``backend.hook_manager``), keyed by layer index.

        Hooks are always removed, even if the forward pass raises.
        """
        cache: Dict[int, torch.Tensor] = {}

        def make_cache_hook(layer_idx: int):
            def hook(module, module_input, output):
                # NOTE(review): assumes the hooked module returns a single tensor.
                cache[layer_idx] = output.detach().clone()
                return output

            return hook

        handles = []
        try:
            for layer_idx in layers:
                attn_module = backend.hook_manager.get_attention_output_module(layer_idx)
                handles.append(attn_module.register_forward_hook(make_cache_hook(layer_idx)))
            with torch.no_grad():
                model(**inputs)
        finally:
            for handle in handles:
                handle.remove()
        return cache

    def run(
        self,
        backend: InferenceBackend,
        dataset: BaseDataset,
        prompt_strategy: Any,
        logger: Optional[ExperimentLogger] = None,
    ) -> ExperimentResult:
        """Run multi-head patching experiment.

        ``dataset``, ``prompt_strategy`` and ``logger`` are not used: both prompts
        are built from ``self.question`` (the corrupted one via
        ``SycophantStrategy``).

        Returns:
            ExperimentResult whose metrics hold the clean baseline effect and the
            minimal flipped circuit (or None), with one raw output per tested
            cumulative head prefix.
        """
        tokenizer = backend._tokenizer
        model = backend._model

        # Get model config
        config = model.config
        if hasattr(config, "text_config"):
            # Multimodal configs nest the language-model settings here.
            config = config.text_config

        num_heads = config.num_attention_heads
        hidden_size = config.hidden_size
        # Some architectures declare a head_dim that differs from
        # hidden_size // num_heads; prefer the explicit value when present.
        head_dim = getattr(config, "head_dim", None) or hidden_size // num_heads

        print(f"Model: {backend.model_name}")
        print(f"Heads: {num_heads}, Hidden: {hidden_size}, Head dim: {head_dim}")
        print(f"Top heads to combine: {self.top_heads}")

        # 1. Setup Prompts
        sycophant = SycophantStrategy(suggested_diagnosis=self.suggested_diagnosis)
        corr_prompt = sycophant.build_prompt({"question": self.question})
        clean_prompt = f"Question: {self.question}\n\nAnswer:"

        # 2. Get Target Tokens
        token_you = self._first_token_id(tokenizer, " You")  # Sycophantic
        token_acute = self._first_token_id(tokenizer, " Acute")  # Principled
        print(f"Target tokens: ' You'={token_you}, ' Acute'={token_acute}")

        # 3. Get baseline logit difference on the unpatched clean prompt
        clean_tokens = tokenizer(clean_prompt, return_tensors="pt").to(backend.device)
        with torch.no_grad():
            clean_logits = model(**clean_tokens).logits
        clean_effect = (clean_logits[0, -1, token_you] - clean_logits[0, -1, token_acute]).item()
        print(f"\nBaseline (clean) effect: {clean_effect:.4f}")

        # 4. Cache corrupted attention outputs for all layers that have top heads
        layers_needed = sorted({layer for layer, _ in self.top_heads})
        print(f"Caching layers: {layers_needed}")
        corr_tokens = tokenizer(corr_prompt, return_tensors="pt").to(backend.device)
        corr_attn_cache = self._cache_attention_outputs(backend, model, corr_tokens, layers_needed)

        # 5. Test cumulative head prefixes of increasing size
        print("\n" + "=" * 60)
        print("MULTI-HEAD PATCHING: Testing Head Combinations")
        print("=" * 60)
        print(f"{'Heads':<30} | {'Effect':<10} | {'Top Token':<10} | {'Flipped':<8}")
        print("-" * 60)

        results: List[CircuitResult] = []

        for num_to_patch in range(1, len(self.top_heads) + 1):
            heads_to_patch = self.top_heads[:num_to_patch]

            # Group heads by layer so each layer gets a single patch hook.
            by_layer: Dict[int, List[int]] = {}
            for layer, head in heads_to_patch:
                by_layer.setdefault(layer, []).append(head)

            handles = []
            try:
                for layer_idx, head_list in by_layer.items():
                    handles.append(
                        backend.hook_manager.register_multi_head_patch_hook(
                            layer_idx=layer_idx,
                            head_indices=head_list,
                            source_activation=corr_attn_cache[layer_idx],
                            head_dim=head_dim,
                        )
                    )
                with torch.no_grad():
                    patched_logits = model(**clean_tokens).logits
            finally:
                for handle in handles:
                    handle.remove()

            last_logits = patched_logits[0, -1]
            effect = (last_logits[token_you] - last_logits[token_acute]).item()
            top_token_id = torch.argmax(last_logits).item()
            top_token = tokenizer.decode([top_token_id])

            # Flipped = the patch pushed the logit difference toward the
            # sycophantic token by more than the configured margin.
            flipped = effect > clean_effect + self.flip_threshold

            heads_str = ", ".join(self._head_label(layer, head) for layer, head in heads_to_patch)
            results.append(
                CircuitResult(
                    heads=heads_to_patch, effect=effect, top_token=top_token, flipped=flipped
                )
            )

            flip_marker = "YES" if flipped else "no"
            print(f"{heads_str:<30} | {effect:>8.3f} | {top_token:<10} | {flip_marker}")

        print("-" * 60)

        # 6. Minimal circuit = first (smallest) prefix that flipped
        minimal_circuit = next((r for r in results if r.flipped), None)

        if minimal_circuit is not None:
            heads_str = ", ".join(
                self._head_label(layer, head) for layer, head in minimal_circuit.heads
            )
            print(f"\nMINIMAL CIRCUIT FOUND: {heads_str}")
            print(f"Number of heads: {len(minimal_circuit.heads)}")
        else:
            print("\nNo minimal circuit found - more heads may be needed")

        return ExperimentResult(
            experiment_name=self.name,
            model_name=backend.model_name,
            prompt_strategy="sycophantic",
            metrics={
                "baseline_effect": clean_effect,
                "num_combinations_tested": len(results),
                "minimal_circuit_size": (
                    len(minimal_circuit.heads) if minimal_circuit is not None else None
                ),
                "minimal_circuit": (
                    [self._head_label(layer, head) for layer, head in minimal_circuit.heads]
                    if minimal_circuit is not None
                    else None
                ),
            },
            raw_outputs=[
                {
                    "heads": [self._head_label(layer, head) for layer, head in r.heads],
                    "num_heads": len(r.heads),
                    "effect": r.effect,
                    "top_token": r.top_token,
                    "flipped": r.flipped,
                }
                for r in results
            ],
            metadata={
                "top_heads": [self._head_label(layer, head) for layer, head in self.top_heads],
                "suggested_diagnosis": self.suggested_diagnosis,
                "flip_threshold": self.flip_threshold,
            },
        )
|
|
@@ -0,0 +1,402 @@
|
|
|
1
|
+
"""Probing Classifier Experiment.
|
|
2
|
+
|
|
3
|
+
Trains linear probes on hidden states at the ANSWER position to test
|
|
4
|
+
whether different prompts encode diagnoses differently at each layer.
|
|
5
|
+
|
|
6
|
+
This probes AFTER the model generates an answer, not at input position.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Any, List, Optional
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
import torch.optim as optim
|
|
15
|
+
from sklearn.linear_model import LogisticRegression
|
|
16
|
+
from sklearn.metrics import accuracy_score
|
|
17
|
+
from sklearn.model_selection import train_test_split
|
|
18
|
+
from sklearn.preprocessing import LabelEncoder
|
|
19
|
+
from tqdm import tqdm
|
|
20
|
+
|
|
21
|
+
from ..backends.base import InferenceBackend
|
|
22
|
+
from ..core.base import BaseExperiment, ExperimentResult
|
|
23
|
+
from ..core.registry import Registry
|
|
24
|
+
from ..datasets.loaders import BaseDataset
|
|
25
|
+
from ..logging import ExperimentLogger
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class GPULinearProbe(nn.Module):
    """Single affine layer mapping feature vectors to per-class logits.

    The layer is exposed as ``self.linear`` so callers can access its weight
    (for example, to apply a weight-only penalty).
    """

    def __init__(self, input_dim: int, num_classes: int):
        super(GPULinearProbe, self).__init__()
        layer = nn.Linear(in_features=input_dim, out_features=num_classes)
        self.linear = layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores for ``x``."""
        logits = self.linear(x)
        return logits
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@Registry.register_experiment("probing_classifier")
class ProbingClassifierExperiment(BaseExperiment):
    """
    Train linear probes to test answer encoding at different layers.

    For every sample the model generates an answer (up to ``max_new_tokens``) and
    the hidden state at the last position of the final generation step is
    recorded for each target layer. A logistic-regression probe is then trained
    per layer to predict, depending on ``probe_target``, the sample label, its
    ``category`` metadata, or whether the generated answer contains the label.
    Probes are trained with sklearn on CPU by default, or with ``GPULinearProbe``
    when ``use_gpu_probe`` is set.
    """

    _PROBE_TARGETS = ("diagnosis", "category", "correctness")

    def __init__(
        self,
        name: str = "probing_classifier",
        description: str = "Train linear probes on answer hidden states",
        target_layers: Optional[List[int]] = None,
        num_samples: Optional[int] = None,
        probe_target: str = "diagnosis",  # "diagnosis", "category", or "correctness"
        max_new_tokens: int = 128,
        use_gpu_probe: bool = False,
        batch_size: int = 128,
        random_seed: int = 42,
        **kwargs,
    ):
        """
        Args:
            name: Experiment name reported in results.
            description: Human-readable description.
            target_layers: Hidden-state indices to probe; ``None`` probes every
                layer reported by the model config.
            num_samples: Maximum number of dataset samples (``None`` = all).
            probe_target: ``"diagnosis"``, ``"category"`` or ``"correctness"``.
            max_new_tokens: Generation length per sample.
            use_gpu_probe: Train probes with torch instead of sklearn.
            batch_size: Stored on the instance; not used by this class.
            random_seed: Seed for torch, the train/test split and the probes.
            **kwargs: Accepted and ignored.
        """
        self._name = name
        self.description = description
        self._target_layers_config = target_layers
        self.target_layers = target_layers
        # Layer list produced by the most recent auto-detection, if any; lets a
        # later run re-detect instead of reusing another model's layer count.
        self._auto_detected_layers: Optional[List[int]] = None
        self.num_samples = num_samples
        self.probe_target = probe_target
        self.max_new_tokens = max_new_tokens
        self.use_gpu_probe = use_gpu_probe
        self.batch_size = batch_size
        self.random_seed = random_seed

    @property
    def name(self) -> str:
        return self._name

    def run(
        self,
        backend: InferenceBackend,
        dataset: BaseDataset,
        prompt_strategy: Any,
        num_samples: Optional[int] = None,
        logger: Optional[ExperimentLogger] = None,
    ) -> ExperimentResult:
        """Run probing classifier experiment.

        Args:
            backend: Backend exposing ``tokenizer``, ``model``, ``device`` and
                ``model_name``.
            dataset: Iterable of samples with ``text``, ``label`` and ``metadata``.
            prompt_strategy: Object whose ``build_prompt`` turns a sample into a prompt.
            num_samples: Overrides ``self.num_samples`` when given.
            logger: Accepted for interface compatibility; not used.

        Returns:
            ExperimentResult with per-layer probe accuracies in ``raw_outputs``.

        Raises:
            ValueError: If ``probe_target`` is not a supported value.
        """
        # Validate before the (expensive) generation pass; an unknown target
        # previously produced an empty result silently.
        if self.probe_target not in self._PROBE_TARGETS:
            raise ValueError(
                f"Unknown probe_target {self.probe_target!r}; "
                f"expected one of {self._PROBE_TARGETS}"
            )

        n_samples = num_samples if num_samples is not None else self.num_samples

        # Collect samples
        print("Collecting samples...")
        all_samples = list(dataset)
        if n_samples is not None:
            all_samples = all_samples[:n_samples]
        print(f"Using {len(all_samples)} samples")

        # Set random seeds
        torch.manual_seed(self.random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.random_seed)

        self.target_layers = self._resolve_target_layers(backend.model)

        print(f"Model: {backend.model_name}")
        print(f"Target layers: {len(self.target_layers)} layers")
        print(f"Probe target: {self.probe_target}")

        print("\nGenerating answers and extracting hidden states at answer position...")
        layer_hidden_states, labels, correctness_labels, model_answers = (
            self._collect_answer_states(backend, all_samples, prompt_strategy)
        )

        # Encode labels
        if self.probe_target == "correctness":
            encoded_labels = np.array(labels)
            label_names = ["incorrect", "correct"]
        else:
            le = LabelEncoder()
            encoded_labels = le.fit_transform(labels)
            label_names = list(le.classes_)

        print(f"\nLabels: {len(set(encoded_labels))} unique classes")
        print(f"Correctness: {sum(correctness_labels)}/{len(correctness_labels)} correct")

        labels = encoded_labels
        num_classes_present = len(np.unique(labels))

        # Train probes for each layer
        print("\n" + "=" * 60)
        print("PROBING CLASSIFIER: Accuracy per Layer")
        print("=" * 60)
        print(
            f"{'Layer':<8} | {'Train Acc':<12} | {'Test Acc':<12} | {'N Train':<8} | {'N Test':<8}"
        )
        print("-" * 60)

        results = []
        layer_accuracies = {}

        for layer_idx in self.target_layers:
            hidden = layer_hidden_states.get(layer_idx)
            if not hidden:
                print(f"L{layer_idx:<7} | Skipped - no hidden states available")
                continue
            if len(hidden) != len(labels):
                # Happens if some generations returned no hidden states; the
                # split below would otherwise fail on mismatched lengths.
                print(
                    f"L{layer_idx:<7} | Skipped - {len(hidden)} hidden states "
                    f"for {len(labels)} labels"
                )
                continue
            if num_classes_present < 2:
                print(f"L{layer_idx:<7} | Skipped - only one class present")
                continue

            X = np.array(hidden)
            X_train, X_test, y_train, y_test = self._split_train_test(X, labels)

            # Train probe (GPU or CPU)
            if self.use_gpu_probe:
                train_acc, test_acc = self._train_gpu_probe(
                    X_train, X_test, y_train, y_test, backend.device
                )
            else:
                train_acc, test_acc = self._train_cpu_probe(X_train, X_test, y_train, y_test)

            layer_accuracies[layer_idx] = {
                "train_accuracy": train_acc,
                "test_accuracy": test_acc,
                "n_train": len(X_train),
                "n_test": len(X_test),
            }

            print(
                f"L{layer_idx:<7} | {train_acc:<12.4f} | {test_acc:<12.4f} | {len(X_train):<8} | {len(X_test):<8}"
            )

            results.append({"layer": layer_idx, **layer_accuracies[layer_idx]})

        print("-" * 60)

        # Find best layer
        if layer_accuracies:
            best_layer = max(
                layer_accuracies.keys(), key=lambda layer: layer_accuracies[layer]["test_accuracy"]
            )
            best_acc = layer_accuracies[best_layer]["test_accuracy"]
            print(f"\nBest probing layer: L{best_layer} (test accuracy: {best_acc:.4f})")
        else:
            best_layer = None
            best_acc = 0

        # Count class distribution
        unique_final, counts_final = np.unique(labels, return_counts=True)

        metrics = {
            "num_samples": len(all_samples),
            "num_correct": sum(correctness_labels),
            "accuracy_rate": sum(correctness_labels) / len(correctness_labels)
            if correctness_labels
            else 0,
            "num_layers_probed": len(results),
            "num_classes": len(unique_final),
            "best_layer": best_layer,
            "best_test_accuracy": best_acc,
            "probe_target": self.probe_target,
        }

        return ExperimentResult(
            experiment_name=self.name,
            model_name=backend.model_name,
            prompt_strategy=prompt_strategy.name if hasattr(prompt_strategy, "name") else "custom",
            metrics=metrics,
            raw_outputs=results,
            metadata={
                "target_layers": self.target_layers,
                "class_distribution": dict(zip(unique_final.tolist(), counts_final.tolist())),
                "label_names": label_names,
                "model_answers": model_answers[:5],  # Save first 5 for inspection
            },
        )

    def _resolve_target_layers(self, model: Any) -> List[int]:
        """Return the layers to probe for this run.

        An explicitly configured list is kept. When no list was configured, or
        the current list is only the result of an earlier auto-detection, every
        layer in ``range(num_hidden_layers)`` of the model's (text) config is
        used, so a layer count detected for one model is not reused for another.
        """
        current = self.target_layers
        previous_auto = getattr(self, "_auto_detected_layers", None)
        if current is not None and current != previous_auto:
            return list(current)

        config = model.config
        if hasattr(config, "text_config"):
            config = config.text_config
        num_layers = config.num_hidden_layers
        detected = list(range(num_layers))
        self._auto_detected_layers = list(detected)
        print(f"Auto-detected {num_layers} layers")
        return detected

    def _collect_answer_states(
        self,
        backend: InferenceBackend,
        samples: List[Any],
        prompt_strategy: Any,
    ) -> tuple[dict, list, list, list]:
        """Generate an answer per sample and record answer-position hidden states.

        Returns:
            ``(layer_hidden_states, labels, correctness_labels, model_answers)``:
            a dict mapping each target layer to a list of 1-D float arrays, the
            raw probe label per sample (chosen by ``self.probe_target``), a 0/1
            substring-match correctness flag per sample, and the decoded answers.
        """
        tokenizer = backend.tokenizer
        model = backend.model

        layer_hidden_states: dict = {layer: [] for layer in self.target_layers}
        labels: list = []
        correctness_labels: list = []
        model_answers: list = []

        for sample in tqdm(samples, desc="Processing"):
            question = sample.text
            ground_truth = sample.label
            category = (sample.metadata or {}).get("category", "unknown")

            prompt = prompt_strategy.build_prompt({"question": question, "text": question})
            inputs = tokenizer(prompt, return_tensors="pt").to(backend.device)
            prompt_length = inputs.input_ids.shape[1]

            # Generate answer with hidden states
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=self.max_new_tokens,
                    output_hidden_states=True,
                    return_dict_in_generate=True,
                    pad_token_id=tokenizer.eos_token_id,
                )

            answer = tokenizer.decode(
                outputs.sequences[0, prompt_length:], skip_special_tokens=True
            )
            model_answers.append(answer)

            # Simple substring match. Labels are stringified so non-str labels
            # work; an empty/None label would match every answer, so it never
            # counts as correct.
            truth_text = "" if ground_truth is None else str(ground_truth).strip().lower()
            is_correct = bool(truth_text) and truth_text in answer.lower()
            correctness_labels.append(1 if is_correct else 0)

            self._append_last_step_states(outputs, layer_hidden_states)

            # Store label based on probe_target
            if self.probe_target == "diagnosis":
                labels.append(ground_truth)
            elif self.probe_target == "category":
                labels.append(category)
            elif self.probe_target == "correctness":
                labels.append(1 if is_correct else 0)

            # Drop the generation outputs (which hold hidden states for every
            # step and layer) so empty_cache can actually release that memory.
            del outputs, inputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        return layer_hidden_states, labels, correctness_labels, model_answers

    @staticmethod
    def _append_last_step_states(outputs: Any, layer_hidden_states: dict) -> None:
        """Append the last-position hidden state of the final generation step for
        every layer key in ``layer_hidden_states`` (no-op if none were returned).

        ``outputs.hidden_states`` is indexed as ``[step][layer]``, and each entry
        is indexed ``[0, -1, :]`` (first sequence, last position).
        NOTE(review): if this follows the Hugging Face convention, per-layer
        index 0 is the embedding output, so ``layer_idx`` is offset by one from
        decoder-block numbering — confirm this is the intended convention.
        """
        step_states = getattr(outputs, "hidden_states", None)
        if not step_states:
            return
        last_step = step_states[-1]
        for layer_idx, store in layer_hidden_states.items():
            if layer_idx < len(last_step):
                store.append(last_step[layer_idx][0, -1, :].float().cpu().numpy())

    def _split_train_test(self, X: np.ndarray, y: np.ndarray, test_size: float = 0.25):
        """Split into ``X_train, X_test, y_train, y_test``, stratified when possible.

        Stratification raises ValueError when a class is too rare; the split is
        then retried without it.
        """
        try:
            return train_test_split(
                X, y, test_size=test_size, random_state=self.random_seed, stratify=y
            )
        except ValueError:
            return train_test_split(X, y, test_size=test_size, random_state=self.random_seed)

    def _train_cpu_probe(
        self,
        X_train: np.ndarray,
        X_test: np.ndarray,
        y_train: np.ndarray,
        y_test: np.ndarray,
    ) -> tuple[float, float]:
        """Fit a standardized sklearn LogisticRegression probe.

        Returns:
            ``(train_accuracy, test_accuracy)``.
        """
        # sklearn's LogisticRegression does not rescale inputs, so standardize
        # explicitly (statistics from the training split only).
        from sklearn.preprocessing import StandardScaler

        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)

        clf = LogisticRegression(max_iter=1000, random_state=self.random_seed)
        clf.fit(X_train_scaled, y_train)
        train_acc = accuracy_score(y_train, clf.predict(X_train_scaled))
        test_acc = accuracy_score(y_test, clf.predict(X_test_scaled))
        return train_acc, test_acc

    def _train_gpu_probe(
        self,
        X_train: np.ndarray,
        X_test: np.ndarray,
        y_train: np.ndarray,
        y_test: np.ndarray,
        device: str,
    ) -> tuple[float, float]:
        """Train a torch linear probe intended to mirror sklearn's LogisticRegression.

        Features are standardized with the training-split mean/std (as the CPU
        path does with StandardScaler). A ``GPULinearProbe`` is then fit with
        full-batch L-BFGS on cross-entropy plus an L2 penalty on the weight only
        (not the bias), scaled to correspond to ``C=1.0``.

        Returns:
            ``(train_accuracy, test_accuracy)``.
        """
        # Set random seed for reproducibility
        torch.manual_seed(self.random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.random_seed)

        y_train = np.asarray(y_train)
        y_test = np.asarray(y_test)

        # 1. Standardize features using training-split statistics only.
        mean = X_train.mean(axis=0)
        std = X_train.std(axis=0)
        std[std == 0] = 1.0  # Avoid division by zero
        X_train_normalized = (X_train - mean) / std
        X_test_normalized = (X_test - mean) / std

        # Convert to tensors
        X_train_t = torch.from_numpy(X_train_normalized).float().to(device)
        X_test_t = torch.from_numpy(X_test_normalized).float().to(device)
        y_train_t = torch.from_numpy(y_train).long().to(device)
        y_test_t = torch.from_numpy(y_test).long().to(device)

        input_dim = X_train.shape[1]
        # Labels are encoded over the whole dataset, so a class can be absent
        # from y_train (e.g. after the non-stratified fallback split). Size the
        # output layer by the largest id so every target is < num_classes.
        num_classes = int(max(y_train.max(initial=0), y_test.max(initial=0))) + 1
        model = GPULinearProbe(input_dim, num_classes).to(device)

        # sklearn minimizes 0.5*||w||^2 + C * sum(loss_i); dividing by C*n gives
        # mean(loss) + ||w||^2 / (2*C*n). With C=1.0 that is l2_lambda / (2n).
        criterion = nn.CrossEntropyLoss()
        l2_lambda = 1.0  # = 1 / C with sklearn's default C=1.0

        # LBFGS optimizer with sklearn-like settings
        optimizer = optim.LBFGS(
            model.parameters(),
            lr=1.0,
            max_iter=20,  # iterations per step
            max_eval=None,
            tolerance_grad=1e-7,
            tolerance_change=1e-9,
            history_size=100,
            line_search_fn="strong_wolfe",
        )

        # Call optimizer.step() repeatedly until the loss stops changing.
        max_steps = 50  # max outer steps (total iterations = 20 * 50 = 1000)
        prev_loss = float("inf")
        tolerance = 1e-4

        for _ in range(max_steps):

            def closure():
                optimizer.zero_grad()
                outputs = model(X_train_t)
                loss = criterion(outputs, y_train_t)
                # 2. L2 penalty on the weight only; the bias (intercept) is not
                # regularized, matching sklearn.
                l2_reg = torch.norm(model.linear.weight, 2) ** 2
                loss = loss + (l2_lambda / (2.0 * len(y_train))) * l2_reg
                loss.backward()
                return loss

            loss = optimizer.step(closure)

            # Check convergence
            if abs(prev_loss - loss.item()) < tolerance:
                break
            prev_loss = loss.item()

        # Evaluate
        with torch.no_grad():
            train_preds = torch.argmax(model(X_train_t), dim=1)
            train_acc = (train_preds == y_train_t).float().mean().item()

            test_preds = torch.argmax(model(X_test_t), dim=1)
            test_acc = (test_preds == y_test_t).float().mean().item()

        return train_acc, test_acc
|