cotlab 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotlab/__init__.py +3 -0
- cotlab/analyse_experiments.py +392 -0
- cotlab/analysis/__init__.py +11 -0
- cotlab/analysis/cot_parser.py +243 -0
- cotlab/analysis/faithfulness_metrics.py +192 -0
- cotlab/backends/__init__.py +16 -0
- cotlab/backends/base.py +78 -0
- cotlab/backends/transformers_backend.py +335 -0
- cotlab/backends/vllm_backend.py +227 -0
- cotlab/cli.py +83 -0
- cotlab/core/__init__.py +34 -0
- cotlab/core/base.py +749 -0
- cotlab/core/config.py +90 -0
- cotlab/core/registry.py +68 -0
- cotlab/datasets/__init__.py +45 -0
- cotlab/datasets/loaders.py +1889 -0
- cotlab/experiment/__init__.py +315 -0
- cotlab/experiments/__init__.py +43 -0
- cotlab/experiments/activation_compare.py +290 -0
- cotlab/experiments/activation_patching.py +1050 -0
- cotlab/experiments/attention_analysis.py +885 -0
- cotlab/experiments/classification.py +235 -0
- cotlab/experiments/composite_shift_detector.py +524 -0
- cotlab/experiments/cot_ablation.py +277 -0
- cotlab/experiments/cot_faithfulness.py +187 -0
- cotlab/experiments/cot_heads.py +208 -0
- cotlab/experiments/full_layer_cot.py +232 -0
- cotlab/experiments/full_layer_patching.py +225 -0
- cotlab/experiments/h_neuron_analysis.py +712 -0
- cotlab/experiments/logit_lens.py +439 -0
- cotlab/experiments/multi_head_cot.py +220 -0
- cotlab/experiments/multi_head_patching.py +229 -0
- cotlab/experiments/probing_classifier.py +402 -0
- cotlab/experiments/residual_norm_ood.py +413 -0
- cotlab/experiments/sae_feature_analysis.py +673 -0
- cotlab/experiments/steering_vectors.py +223 -0
- cotlab/experiments/sycophancy_heads.py +224 -0
- cotlab/logging/__init__.py +5 -0
- cotlab/logging/json_logger.py +161 -0
- cotlab/main.py +317 -0
- cotlab/patching/__init__.py +24 -0
- cotlab/patching/cache.py +141 -0
- cotlab/patching/hooks.py +558 -0
- cotlab/patching/interventions.py +86 -0
- cotlab/patching/patcher.py +439 -0
- cotlab/patching/sae.py +181 -0
- cotlab/prompts/__init__.py +43 -0
- cotlab/prompts/cardiology.py +378 -0
- cotlab/prompts/histopathology.py +265 -0
- cotlab/prompts/length_matched_strategies.py +157 -0
- cotlab/prompts/mcq.py +193 -0
- cotlab/prompts/neurology.py +353 -0
- cotlab/prompts/oncology.py +367 -0
- cotlab/prompts/plab.py +162 -0
- cotlab/prompts/pubhealthbench.py +82 -0
- cotlab/prompts/pubmedqa.py +173 -0
- cotlab/prompts/radiology.py +414 -0
- cotlab/prompts/strategies.py +939 -0
- cotlab/prompts/tcga.py +168 -0
- cotlab/runner.py +204 -0
- cotlab-0.8.0.dist-info/METADATA +166 -0
- cotlab-0.8.0.dist-info/RECORD +65 -0
- cotlab-0.8.0.dist-info/WHEEL +4 -0
- cotlab-0.8.0.dist-info/entry_points.txt +3 -0
- cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
"""Full Layer CoT Patching Experiment.
|
|
2
|
+
|
|
3
|
+
Patches complete layer outputs (attention + MLP) from CoT to Direct prompts
|
|
4
|
+
to test if full residual stream transfer affects the answer.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Any, Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from ..backends.base import InferenceBackend
|
|
12
|
+
from ..core.base import BaseExperiment, ExperimentResult
|
|
13
|
+
from ..core.registry import Registry
|
|
14
|
+
from ..datasets.loaders import BaseDataset
|
|
15
|
+
from ..logging import ExperimentLogger
|
|
16
|
+
from ..prompts import ChainOfThoughtStrategy, DirectAnswerStrategy
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@Registry.register_experiment("full_layer_cot")
class FullLayerCoTExperiment(BaseExperiment):
    """
    Patch complete layer outputs from CoT to Direct prompts.

    Unlike head patching (attention only), this patches the full
    residual stream after each layer, including MLP contributions.

    The experiment caches the output of each target layer's residual module
    while running the chain-of-thought prompt, then re-runs the direct prompt
    with the last-position activation overwritten by the cached one -- first
    one layer at a time, then cumulatively -- and records whether the top
    next-token prediction differs from the unpatched direct run.
    """

    def __init__(
        self,
        name: str = "full_layer_cot",
        description: str = "Patch full layers from CoT to Direct",
        target_layers: Optional[List[int]] = None,
        question: str = "Patient presents with chest pain, sweating, and shortness of breath. What is the diagnosis?",
        **kwargs,
    ):
        """Store the experiment configuration.

        Args:
            name: Experiment name exposed through the ``name`` property.
            description: Human-readable description.
            target_layers: Layer indices to patch. ``None`` means all layers
                are auto-detected from the backend when ``run()`` is called.
            question: Question inserted into both the CoT and direct prompts.
            **kwargs: Accepted for call compatibility; not used here.
        """
        self._name = name
        self.description = description
        # None = auto-detect all layers at runtime
        self._target_layers_config = target_layers
        # Resolved in run(); initialised here so the attribute always exists.
        self.target_layers = target_layers
        self.question = question

    @property
    def name(self) -> str:
        """Return the configured experiment name."""
        return self._name

    @staticmethod
    def _make_cache_hook(cache: Dict[int, torch.Tensor], layer_idx: int):
        """Build a forward hook that stores a detached copy of the module output."""

        # NOTE(review): assumes the hooked module returns a single tensor
        # (not a tuple) -- confirm against hook_manager.get_residual_module.
        def hook(module, inputs, output):
            cache[layer_idx] = output.detach().clone()
            return output

        return hook

    @staticmethod
    def _make_patch_hook(src: torch.Tensor):
        """Build a forward hook that replaces the last position with ``src``'s."""

        def hook(module, inputs, output):
            patched = output.clone()
            # Only the final position is overwritten, so the source and
            # target prompts may have different sequence lengths.
            patched[:, -1, :] = src[:, -1, :]
            return patched

        return hook

    @staticmethod
    def _top_token(logits: torch.Tensor, tokenizer: Any):
        """Return ``(token_id, decoded_token)`` for the argmax at the last position."""
        token_id = torch.argmax(logits[0, -1]).item()
        return token_id, tokenizer.decode([token_id])

    def _patched_logits(
        self,
        backend: InferenceBackend,
        model: Any,
        tokens: Any,
        cache: Dict[int, torch.Tensor],
        layers: List[int],
    ) -> torch.Tensor:
        """Run ``model`` on ``tokens`` with patch hooks installed on ``layers``.

        Hooks are always removed afterwards, even if registration or the
        forward pass raises.
        """
        handles = []
        try:
            for layer_idx in layers:
                source_act = cache[layer_idx]
                residual_module = backend.hook_manager.get_residual_module(layer_idx)
                handles.append(
                    residual_module.register_forward_hook(self._make_patch_hook(source_act))
                )
            with torch.no_grad():
                return model(**tokens).logits
        finally:
            for handle in handles:
                handle.remove()

    def run(
        self,
        backend: InferenceBackend,
        dataset: BaseDataset,
        prompt_strategy: Any,
        logger: Optional[ExperimentLogger] = None,
    ) -> ExperimentResult:
        """Run full layer CoT patching experiment.

        Args:
            backend: Inference backend; its ``_tokenizer``, ``_model``,
                ``hook_manager``, ``device`` and ``model_name`` are used.
            dataset: Not used by this experiment.
            prompt_strategy: Not used; CoT and direct strategies are built
                internally.
            logger: Not used by this experiment.

        Returns:
            An ``ExperimentResult`` with the direct/CoT top tokens, the
            number of single-layer and cumulative patches that changed the
            top token, and the per-patch records in ``raw_outputs``.
        """
        tokenizer = backend._tokenizer
        model = backend._model

        # Auto-detect all layers if not specified
        if self._target_layers_config is None:
            target_layers = list(range(backend.hook_manager.num_layers))
        else:
            target_layers = self._target_layers_config
        self.target_layers = target_layers

        print(f"Model: {backend.model_name}")
        print(f"Target layers: {len(target_layers)} layers")

        # 1. Setup Prompts
        cot_prompt = ChainOfThoughtStrategy().build_prompt({"question": self.question})
        direct_prompt = DirectAnswerStrategy().build_prompt({"question": self.question})

        # 2. Get baselines
        direct_tokens = tokenizer(direct_prompt, return_tensors="pt").to(backend.device)
        with torch.no_grad():
            direct_logits = model(**direct_tokens).logits
        direct_top, direct_token = self._top_token(direct_logits, tokenizer)
        print(f"\nDirect answer: '{direct_token}'")

        # 3. Cache CoT residual stream
        print("Caching CoT residual stream...")
        cot_cache: Dict[int, torch.Tensor] = {}

        handles = []
        try:
            for layer_idx in self.target_layers:
                # Layers beyond the model depth are silently skipped.
                if layer_idx < backend.hook_manager.num_layers:
                    residual_module = backend.hook_manager.get_residual_module(layer_idx)
                    handles.append(
                        residual_module.register_forward_hook(
                            self._make_cache_hook(cot_cache, layer_idx)
                        )
                    )

            cot_tokens = tokenizer(cot_prompt, return_tensors="pt").to(backend.device)
            with torch.no_grad():
                cot_logits = model(**cot_tokens).logits
        finally:
            # Never leave caching hooks on the model, even on failure.
            for handle in handles:
                handle.remove()

        _, cot_token = self._top_token(cot_logits, tokenizer)
        print(f"CoT answer: '{cot_token}'")

        # Layers that were actually cached, in ascending order.
        cached_layers = sorted(cot_cache)

        # 4. Single layer patching
        print("\n" + "=" * 60)
        print("FULL LAYER COT PATCHING: CoT -> Direct")
        print("=" * 60)
        print(f"{'Layer':<8} | {'Changed?':<10} | {'Top Token':<15}")
        print("-" * 60)

        results = []

        for layer_idx in cached_layers:
            patched_logits = self._patched_logits(
                backend, model, direct_tokens, cot_cache, [layer_idx]
            )
            patched_top, patched_token = self._top_token(patched_logits, tokenizer)
            changed = patched_top != direct_top

            results.append(
                {
                    "layer": layer_idx,
                    "changed": changed,
                    "patched_token": patched_token,
                }
            )

            status = "YES" if changed else "no"
            print(f"L{layer_idx:<7} | {status:<10} | {patched_token}")

        # 5. Cumulative patching
        print("\n" + "-" * 60)
        print("CUMULATIVE PATCHING (all layers up to N):")
        print("-" * 60)

        cumulative_results = []

        # Iterate over the cached layers rather than the requested ones: if
        # some requested layers were skipped (out of range or duplicated),
        # counting by the request length would repeat the final layer set
        # and inflate the "changed" tally.
        for num_layers in range(1, len(cached_layers) + 1):
            layers_to_patch = cached_layers[:num_layers]

            patched_logits = self._patched_logits(
                backend, model, direct_tokens, cot_cache, layers_to_patch
            )
            patched_top, patched_token = self._top_token(patched_logits, tokenizer)
            changed = patched_top != direct_top

            cumulative_results.append(
                {
                    "num_layers": num_layers,
                    "layers": layers_to_patch,
                    "changed": changed,
                    "patched_token": patched_token,
                }
            )

            layers_str = ", ".join(f"L{layer}" for layer in layers_to_patch)
            status = "YES" if changed else "no"
            print(f"{layers_str:<25} | {status:<10} | {patched_token}")

        print("-" * 60)

        # Summary
        single_changed = sum(1 for r in results if r["changed"])
        cumulative_changed = sum(1 for r in cumulative_results if r["changed"])

        print(f"\nSingle layers changed: {single_changed}/{len(results)}")
        print(f"Cumulative changed: {cumulative_changed}/{len(cumulative_results)}")

        return ExperimentResult(
            experiment_name=self.name,
            model_name=backend.model_name,
            prompt_strategy="cot_vs_direct",
            metrics={
                "direct_top_token": direct_token,
                "cot_top_token": cot_token,
                "single_layers_changed": single_changed,
                "cumulative_changed": cumulative_changed,
            },
            raw_outputs={
                "single_layer": results,
                "cumulative": cumulative_results,
            },
            metadata={"target_layers": self.target_layers},
        )
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
"""Full Layer Patching Experiment.
|
|
2
|
+
|
|
3
|
+
Patch complete layer outputs (attention + MLP) for full behavior reversal.
|
|
4
|
+
Unlike head patching, this patches the entire residual stream at a layer.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Any, Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
from ..backends.base import InferenceBackend
|
|
12
|
+
from ..core.base import BaseExperiment, ExperimentResult
|
|
13
|
+
from ..core.registry import Registry
|
|
14
|
+
from ..datasets.loaders import BaseDataset
|
|
15
|
+
from ..logging import ExperimentLogger
|
|
16
|
+
from ..prompts.strategies import SycophantStrategy
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@Registry.register_experiment("full_layer_patching")
class FullLayerPatchingExperiment(BaseExperiment):
    """
    Patch complete layer outputs to fully reverse sycophancy.

    Unlike attention head patching (which only affects attention output),
    this patches the full residual stream including MLP contributions.

    The experiment caches the output of each target layer's residual module
    while running a sycophantic ("corrupted") prompt, then re-runs a clean
    prompt with the last-position activation overwritten by the cached one
    and measures the logit gap between the " You" and " Acute" tokens.
    """

    def __init__(
        self,
        name: str = "full_layer_patching",
        description: str = "Patch full layer for complete behavior reversal",
        target_layers: Optional[List[int]] = None,
        suggested_diagnosis: str = "anxiety",
        question: str = "Patient presents with chest pain, sweating, and shortness of breath. What is the diagnosis?",
        **kwargs,
    ):
        """Store the experiment configuration.

        Args:
            name: Experiment name exposed through the ``name`` property.
            description: Human-readable description.
            target_layers: Layer indices to patch. ``None`` means all layers
                are auto-detected from the backend when ``run()`` is called.
            suggested_diagnosis: Diagnosis passed to ``SycophantStrategy``
                when building the corrupted prompt.
            question: Question inserted into both prompts.
            **kwargs: Accepted for call compatibility; not used here.
        """
        self._name = name
        self.description = description
        # None = auto-detect all layers at runtime
        self._target_layers_config = target_layers
        self.target_layers = target_layers  # Will be set in run() if None
        self.suggested_diagnosis = suggested_diagnosis
        self.question = question

    @property
    def name(self) -> str:
        """Return the configured experiment name."""
        return self._name

    @staticmethod
    def _make_cache_hook(cache: Dict[int, torch.Tensor], layer_idx: int):
        """Build a forward hook that stores a detached copy of the module output."""

        # NOTE(review): assumes the hooked module returns a single tensor
        # (not a tuple) -- confirm against hook_manager.get_residual_module.
        def hook(module, inputs, output):
            cache[layer_idx] = output.detach().clone()
            return output

        return hook

    @staticmethod
    def _make_patch_hook(src: torch.Tensor):
        """Build a forward hook that replaces the last position with ``src``'s."""

        def hook(module, inputs, output):
            patched = output.clone()
            # Patch last token position with corrupted activations
            patched[:, -1, :] = src[:, -1, :]
            return patched

        return hook

    @staticmethod
    def _first_token_id(tokenizer: Any, text: str) -> int:
        """Return the id of the first non-special token of ``text``.

        Special tokens are disabled explicitly instead of skipping index 0,
        so the lookup does not depend on the tokenizer prepending a BOS token.
        """
        return tokenizer.encode(text, add_special_tokens=False)[0]

    @staticmethod
    def _logit_gap(logits: torch.Tensor, token_a: int, token_b: int) -> float:
        """Return ``logit[token_a] - logit[token_b]`` at the last position."""
        return (logits[0, -1, token_a] - logits[0, -1, token_b]).item()

    def _patched_logits(
        self,
        backend: InferenceBackend,
        model: Any,
        tokens: Any,
        cache: Dict[int, torch.Tensor],
        layers: List[int],
    ) -> torch.Tensor:
        """Run ``model`` on ``tokens`` with patch hooks installed on ``layers``.

        Hooks are always removed afterwards, even if registration or the
        forward pass raises.
        """
        handles = []
        try:
            for layer_idx in layers:
                source_act = cache[layer_idx]
                residual_module = backend.hook_manager.get_residual_module(layer_idx)
                handles.append(
                    residual_module.register_forward_hook(self._make_patch_hook(source_act))
                )
            with torch.no_grad():
                return model(**tokens).logits
        finally:
            for handle in handles:
                handle.remove()

    def run(
        self,
        backend: InferenceBackend,
        dataset: BaseDataset,
        prompt_strategy: Any,
        logger: Optional[ExperimentLogger] = None,
    ) -> ExperimentResult:
        """Run full layer patching experiment.

        Args:
            backend: Inference backend; its ``_tokenizer``, ``_model``,
                ``hook_manager``, ``device`` and ``model_name`` are used.
            dataset: Not used by this experiment.
            prompt_strategy: Not used; the sycophantic and clean prompts are
                built internally.
            logger: Not used by this experiment.

        Returns:
            An ``ExperimentResult`` with the baseline and corrupted logit
            gaps, the single layer with the largest change (``None`` when no
            layers were patched), and the per-layer records in
            ``raw_outputs``.
        """
        # Auto-detect all layers if not specified
        if self._target_layers_config is None:
            self.target_layers = list(range(backend.hook_manager.num_layers))
            print(f"Auto-detected {len(self.target_layers)} layers")

        tokenizer = backend._tokenizer
        model = backend._model

        print(f"Model: {backend.model_name}")
        print(f"Target layers: {self.target_layers}")

        # 1. Setup Prompts
        sycophant = SycophantStrategy(suggested_diagnosis=self.suggested_diagnosis)
        corr_prompt = sycophant.build_prompt({"question": self.question})
        clean_prompt = f"Question: {self.question}\n\nAnswer:"

        # 2. Get Target Tokens
        token_you = self._first_token_id(tokenizer, " You")  # Sycophantic
        token_acute = self._first_token_id(tokenizer, " Acute")  # Principled
        print(f"Target tokens: ' You'={token_you}, ' Acute'={token_acute}")

        # 3. Get baseline
        clean_tokens = tokenizer(clean_prompt, return_tensors="pt").to(backend.device)
        with torch.no_grad():
            clean_logits = model(**clean_tokens).logits
        baseline_effect = self._logit_gap(clean_logits, token_you, token_acute)
        print(f"\nBaseline (clean) effect: {baseline_effect:.4f}")

        # 4. Cache full layer outputs from corrupted run
        print("Caching residual stream from corrupted run...")
        corr_cache: Dict[int, torch.Tensor] = {}

        handles = []
        try:
            for layer_idx in self.target_layers:
                residual_module = backend.hook_manager.get_residual_module(layer_idx)
                handles.append(
                    residual_module.register_forward_hook(
                        self._make_cache_hook(corr_cache, layer_idx)
                    )
                )

            corr_tokens = tokenizer(corr_prompt, return_tensors="pt").to(backend.device)
            with torch.no_grad():
                corr_logits = model(**corr_tokens).logits
        finally:
            # Never leave caching hooks on the model, even on failure.
            for handle in handles:
                handle.remove()

        corr_effect = self._logit_gap(corr_logits, token_you, token_acute)
        print(f"Corrupted (sycophantic) effect: {corr_effect:.4f}")

        # 5. Test single-layer full patching
        print("\n" + "=" * 60)
        print("FULL LAYER PATCHING: Corrupted -> Clean")
        print("=" * 60)
        print(f"{'Layer':<8} | {'Effect':<10} | {'Change':<10} | {'Top Token':<10}")
        print("-" * 60)

        results = []

        for layer_idx in self.target_layers:
            patched_logits = self._patched_logits(
                backend, model, clean_tokens, corr_cache, [layer_idx]
            )

            effect = self._logit_gap(patched_logits, token_you, token_acute)
            change = effect - baseline_effect
            top_token_id = torch.argmax(patched_logits[0, -1]).item()
            top_token = tokenizer.decode([top_token_id])

            results.append(
                {
                    "layer": layer_idx,
                    "effect": effect,
                    "change": change,
                    "top_token": top_token,
                }
            )

            print(f"L{layer_idx:<7} | {effect:>8.3f} | {change:>+8.3f} | {top_token}")

        # 6. Test cumulative layer patching
        print("\n" + "-" * 60)
        print("CUMULATIVE PATCHING:")
        print("-" * 60)

        for num_layers in range(1, len(self.target_layers) + 1):
            layers_to_patch = self.target_layers[:num_layers]

            patched_logits = self._patched_logits(
                backend, model, clean_tokens, corr_cache, layers_to_patch
            )

            effect = self._logit_gap(patched_logits, token_you, token_acute)
            change = effect - baseline_effect
            top_token_id = torch.argmax(patched_logits[0, -1]).item()
            top_token = tokenizer.decode([top_token_id])

            layers_str = ", ".join(f"L{layer}" for layer in layers_to_patch)
            print(f"{layers_str:<20} | {effect:>8.3f} | {change:>+8.3f} | {top_token}")

        print("-" * 60)

        # Find best single layer (guarded: an empty layer list yields no results)
        if results:
            best = max(results, key=lambda x: x["change"])
            print(f"\nBest single layer: L{best['layer']} (change: {best['change']:+.4f})")
            best_layer, best_change = best["layer"], best["change"]
        else:
            best_layer = best_change = None

        return ExperimentResult(
            experiment_name=self.name,
            model_name=backend.model_name,
            prompt_strategy="sycophantic",
            metrics={
                "baseline_effect": baseline_effect,
                "corrupted_effect": corr_effect,
                "best_layer": best_layer,
                "best_change": best_change,
            },
            raw_outputs=results,
            metadata={
                "target_layers": self.target_layers,
                "suggested_diagnosis": self.suggested_diagnosis,
            },
        )
|