cotlab 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotlab/__init__.py +3 -0
- cotlab/analyse_experiments.py +392 -0
- cotlab/analysis/__init__.py +11 -0
- cotlab/analysis/cot_parser.py +243 -0
- cotlab/analysis/faithfulness_metrics.py +192 -0
- cotlab/backends/__init__.py +16 -0
- cotlab/backends/base.py +78 -0
- cotlab/backends/transformers_backend.py +335 -0
- cotlab/backends/vllm_backend.py +227 -0
- cotlab/cli.py +83 -0
- cotlab/core/__init__.py +34 -0
- cotlab/core/base.py +749 -0
- cotlab/core/config.py +90 -0
- cotlab/core/registry.py +68 -0
- cotlab/datasets/__init__.py +45 -0
- cotlab/datasets/loaders.py +1889 -0
- cotlab/experiment/__init__.py +315 -0
- cotlab/experiments/__init__.py +43 -0
- cotlab/experiments/activation_compare.py +290 -0
- cotlab/experiments/activation_patching.py +1050 -0
- cotlab/experiments/attention_analysis.py +885 -0
- cotlab/experiments/classification.py +235 -0
- cotlab/experiments/composite_shift_detector.py +524 -0
- cotlab/experiments/cot_ablation.py +277 -0
- cotlab/experiments/cot_faithfulness.py +187 -0
- cotlab/experiments/cot_heads.py +208 -0
- cotlab/experiments/full_layer_cot.py +232 -0
- cotlab/experiments/full_layer_patching.py +225 -0
- cotlab/experiments/h_neuron_analysis.py +712 -0
- cotlab/experiments/logit_lens.py +439 -0
- cotlab/experiments/multi_head_cot.py +220 -0
- cotlab/experiments/multi_head_patching.py +229 -0
- cotlab/experiments/probing_classifier.py +402 -0
- cotlab/experiments/residual_norm_ood.py +413 -0
- cotlab/experiments/sae_feature_analysis.py +673 -0
- cotlab/experiments/steering_vectors.py +223 -0
- cotlab/experiments/sycophancy_heads.py +224 -0
- cotlab/logging/__init__.py +5 -0
- cotlab/logging/json_logger.py +161 -0
- cotlab/main.py +317 -0
- cotlab/patching/__init__.py +24 -0
- cotlab/patching/cache.py +141 -0
- cotlab/patching/hooks.py +558 -0
- cotlab/patching/interventions.py +86 -0
- cotlab/patching/patcher.py +439 -0
- cotlab/patching/sae.py +181 -0
- cotlab/prompts/__init__.py +43 -0
- cotlab/prompts/cardiology.py +378 -0
- cotlab/prompts/histopathology.py +265 -0
- cotlab/prompts/length_matched_strategies.py +157 -0
- cotlab/prompts/mcq.py +193 -0
- cotlab/prompts/neurology.py +353 -0
- cotlab/prompts/oncology.py +367 -0
- cotlab/prompts/plab.py +162 -0
- cotlab/prompts/pubhealthbench.py +82 -0
- cotlab/prompts/pubmedqa.py +173 -0
- cotlab/prompts/radiology.py +414 -0
- cotlab/prompts/strategies.py +939 -0
- cotlab/prompts/tcga.py +168 -0
- cotlab/runner.py +204 -0
- cotlab-0.8.0.dist-info/METADATA +166 -0
- cotlab-0.8.0.dist-info/RECORD +65 -0
- cotlab-0.8.0.dist-info/WHEEL +4 -0
- cotlab-0.8.0.dist-info/entry_points.txt +3 -0
- cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
cotlab/patching/hooks.py
ADDED
|
@@ -0,0 +1,558 @@
|
|
|
1
|
+
"""PyTorch forward hook utilities for activation extraction and patching."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class HookManager:
    """
    Manage PyTorch forward hooks for activation extraction and patching.

    The manager resolves the model's numbered layer modules once, when it is
    constructed, and records the handle of every hook it registers so that
    they can all be removed with a single call.

    This provides a clean interface for:
    - Registering hooks on specific layers
    - Caching activations during forward pass
    - Patching activations with custom values
    - Cleanup of all registered hooks

    Example:
    >>> manager = HookManager(model)
    >>> cache = ActivationCache()
    >>> manager.register_cache_hooks(cache, layers=[0, 5, 10])
    >>> output = model(input_ids)  # Activations now in cache
    >>> manager.remove_all_hooks()
    """
|
|
26
|
+
|
|
27
|
+
def __init__(self, model: nn.Module):
|
|
28
|
+
"""
|
|
29
|
+
Initialize hook manager for a model.
|
|
30
|
+
|
|
31
|
+
Args:
|
|
32
|
+
model: The transformer model to hook into
|
|
33
|
+
"""
|
|
34
|
+
self.model = model
|
|
35
|
+
self.handles: List[torch.utils.hooks.RemovableHandle] = []
|
|
36
|
+
self._layer_modules = self._build_layer_mapping()
|
|
37
|
+
|
|
38
|
+
# Architecture names (HF ``config.model_type`` values) grouped by the module
# layout they share. These tuples exist only to build the tables below.
_GEMMA_ARCHS = ("gemma", "gemma2", "gemma3", "gemma3_text")  # Gemma 3 / MedGemma family
_LLAMA_STYLE_ARCHS = ("mistral", "qwen2", "olmo", "olmo2", "llama")  # Mistral / Qwen / Olmo / Llama

# model_type -> dotted path of the container holding the numbered layers.
# GPT-2 is included for testing.
LAYER_PATHS = {
    **dict.fromkeys(_GEMMA_ARCHS + _LLAMA_STYLE_ARCHS, "model.layers"),
    "gpt2": "transformer.h",
}

# model_type -> name of the normalization submodule, inside one layer, that is
# used as the "residual stream" hook point.
# NOTE(review): the original comment describes these as the final norm of each
# layer block. That matches the name ``post_feedforward_layernorm``; whether
# ``post_attention_layernorm`` / ``ln_2`` sit at the end of the block or in
# front of the MLP depends on the architecture -- confirm before relying on it.
RESIDUAL_HOOK_POINTS = {
    **dict.fromkeys(_GEMMA_ARCHS, "post_feedforward_layernorm"),
    **dict.fromkeys(_LLAMA_STYLE_ARCHS, "post_attention_layernorm"),
    "gpt2": "ln_2",
}

# model_type -> path (relative to one layer) of the FFN down-projection module.
# Its input is z_t, the intermediate activation after the gated nonlinearity.
# Used for CETT computation and H-Neuron causal interventions.
MLP_DOWN_PROJ_POINTS = {
    **dict.fromkeys(_GEMMA_ARCHS + _LLAMA_STYLE_ARCHS, "mlp.down_proj"),
    "gpt2": "mlp.c_proj",
}

# model_type -> path (relative to one layer) of the attention output projection.
# These modules take the concatenated head outputs and project them back to the
# hidden dimension; they are the hook point for head-level patching.
ATTENTION_OUTPUT_POINTS = {
    **dict.fromkeys(_GEMMA_ARCHS + _LLAMA_STYLE_ARCHS, "self_attn.o_proj"),
    "gpt2": "attn.c_proj",
}
|
|
109
|
+
|
|
110
|
+
def _build_layer_mapping(self) -> Dict[int, nn.Module]:
|
|
111
|
+
"""
|
|
112
|
+
Auto-detect transformer layers using HF config model_type.
|
|
113
|
+
|
|
114
|
+
Uses known layer paths for common architectures, with fallback
|
|
115
|
+
to regex-based auto-detection for unknown models.
|
|
116
|
+
"""
|
|
117
|
+
# Try to get model_type from config
|
|
118
|
+
model_type = getattr(self.model.config, "model_type", None)
|
|
119
|
+
layer_path = self.LAYER_PATHS.get(model_type)
|
|
120
|
+
|
|
121
|
+
if layer_path:
|
|
122
|
+
# Try known path first
|
|
123
|
+
layers = self._get_layers_from_path(layer_path)
|
|
124
|
+
if layers:
|
|
125
|
+
return layers
|
|
126
|
+
# If known path fails (e.g., multimodal model), try auto-detect
|
|
127
|
+
|
|
128
|
+
# Fallback: auto-detect using regex
|
|
129
|
+
return self._auto_detect_layers()
|
|
130
|
+
|
|
131
|
+
def _get_layers_from_path(self, layer_path: str) -> Dict[int, nn.Module]:
|
|
132
|
+
"""Get layers from a known path like 'transformer.h' or 'model.layers'."""
|
|
133
|
+
layers = {}
|
|
134
|
+
|
|
135
|
+
for name, module in self.model.named_modules():
|
|
136
|
+
# Match pattern: layer_path.{number}
|
|
137
|
+
if name.startswith(layer_path + "."):
|
|
138
|
+
suffix = name[len(layer_path) + 1 :]
|
|
139
|
+
# Only match direct children (no dots in suffix)
|
|
140
|
+
if "." not in suffix and suffix.isdigit():
|
|
141
|
+
layers[int(suffix)] = module
|
|
142
|
+
|
|
143
|
+
return layers
|
|
144
|
+
|
|
145
|
+
def _auto_detect_layers(self) -> Dict[int, nn.Module]:
|
|
146
|
+
"""Fallback: auto-detect layers using regex pattern matching."""
|
|
147
|
+
import re
|
|
148
|
+
|
|
149
|
+
layers = {}
|
|
150
|
+
layer_priority = {}
|
|
151
|
+
layer_pattern = re.compile(r"^(.+?)\.(\d+)$")
|
|
152
|
+
|
|
153
|
+
for name, module in self.model.named_modules():
|
|
154
|
+
match = layer_pattern.match(name)
|
|
155
|
+
if not match:
|
|
156
|
+
continue
|
|
157
|
+
|
|
158
|
+
prefix = match.group(1)
|
|
159
|
+
layer_idx = int(match.group(2))
|
|
160
|
+
|
|
161
|
+
# Skip sublayers (modules with numbered children)
|
|
162
|
+
has_numbered_children = any(c.isdigit() for c, _ in module.named_children())
|
|
163
|
+
if has_numbered_children:
|
|
164
|
+
continue
|
|
165
|
+
|
|
166
|
+
# Prioritize language model layers
|
|
167
|
+
if "language_model" in prefix:
|
|
168
|
+
priority = 4
|
|
169
|
+
elif "layers" in prefix or "h" in prefix:
|
|
170
|
+
priority = 1 if ("vision" in prefix or "encoder" in prefix) else 3
|
|
171
|
+
else:
|
|
172
|
+
priority = 2
|
|
173
|
+
|
|
174
|
+
if layer_idx not in layer_priority or priority > layer_priority[layer_idx]:
|
|
175
|
+
layers[layer_idx] = module
|
|
176
|
+
layer_priority[layer_idx] = priority
|
|
177
|
+
|
|
178
|
+
return layers
|
|
179
|
+
|
|
180
|
+
@property
|
|
181
|
+
def num_layers(self) -> int:
|
|
182
|
+
"""Number of hookable layers."""
|
|
183
|
+
return len(self._layer_modules)
|
|
184
|
+
|
|
185
|
+
@property
|
|
186
|
+
def available_layers(self) -> List[int]:
|
|
187
|
+
"""List of available layer indices."""
|
|
188
|
+
return sorted(self._layer_modules.keys())
|
|
189
|
+
|
|
190
|
+
def get_layer_module(self, layer_idx: int) -> nn.Module:
|
|
191
|
+
"""Get the module for a specific layer."""
|
|
192
|
+
if layer_idx not in self._layer_modules:
|
|
193
|
+
raise ValueError(f"Layer {layer_idx} not found. Available: {self.available_layers}")
|
|
194
|
+
return self._layer_modules[layer_idx]
|
|
195
|
+
|
|
196
|
+
def get_residual_module(self, layer_idx: int) -> nn.Module:
|
|
197
|
+
"""
|
|
198
|
+
Get the residual stream hook point for a layer.
|
|
199
|
+
|
|
200
|
+
Returns the post-layer normalization module (e.g., ln_2 for GPT-2,
|
|
201
|
+
post_feedforward_layernorm for Gemma3) which is safer for patching
|
|
202
|
+
than the full layer block.
|
|
203
|
+
"""
|
|
204
|
+
layer_module = self.get_layer_module(layer_idx)
|
|
205
|
+
model_type = getattr(self.model.config, "model_type", None)
|
|
206
|
+
residual_name = self.RESIDUAL_HOOK_POINTS.get(model_type)
|
|
207
|
+
|
|
208
|
+
if residual_name:
|
|
209
|
+
# Try to get the specific residual hook point
|
|
210
|
+
if hasattr(layer_module, residual_name):
|
|
211
|
+
return getattr(layer_module, residual_name)
|
|
212
|
+
|
|
213
|
+
# Fallback: try common names
|
|
214
|
+
for name in [
|
|
215
|
+
"post_feedforward_layernorm",
|
|
216
|
+
"post_attention_layernorm",
|
|
217
|
+
"ln_2",
|
|
218
|
+
"layer_norm",
|
|
219
|
+
]:
|
|
220
|
+
if hasattr(layer_module, name):
|
|
221
|
+
return getattr(layer_module, name)
|
|
222
|
+
|
|
223
|
+
# Last resort: return the layer itself
|
|
224
|
+
return layer_module
|
|
225
|
+
|
|
226
|
+
def get_attention_output_module(self, layer_idx: int) -> nn.Module:
|
|
227
|
+
"""
|
|
228
|
+
Get the attention output projection module for a layer.
|
|
229
|
+
|
|
230
|
+
This is where individual head outputs are concatenated and projected.
|
|
231
|
+
Used for head-level patching interventions.
|
|
232
|
+
"""
|
|
233
|
+
layer_module = self.get_layer_module(layer_idx)
|
|
234
|
+
model_type = getattr(self.model.config, "model_type", None)
|
|
235
|
+
attn_path = self.ATTENTION_OUTPUT_POINTS.get(model_type)
|
|
236
|
+
|
|
237
|
+
if attn_path:
|
|
238
|
+
# Navigate nested path like "self_attn.o_proj"
|
|
239
|
+
parts = attn_path.split(".")
|
|
240
|
+
module = layer_module
|
|
241
|
+
for part in parts:
|
|
242
|
+
if hasattr(module, part):
|
|
243
|
+
module = getattr(module, part)
|
|
244
|
+
else:
|
|
245
|
+
break
|
|
246
|
+
else:
|
|
247
|
+
return module
|
|
248
|
+
|
|
249
|
+
# Fallback: try common attention output names
|
|
250
|
+
for attn_name in ["self_attn", "attn", "attention"]:
|
|
251
|
+
if hasattr(layer_module, attn_name):
|
|
252
|
+
attn = getattr(layer_module, attn_name)
|
|
253
|
+
for proj_name in ["o_proj", "c_proj", "out_proj", "dense"]:
|
|
254
|
+
if hasattr(attn, proj_name):
|
|
255
|
+
return getattr(attn, proj_name)
|
|
256
|
+
|
|
257
|
+
raise ValueError(f"Could not find attention output module for layer {layer_idx}")
|
|
258
|
+
|
|
259
|
+
def get_mlp_down_proj_module(self, layer_idx: int) -> nn.Module:
|
|
260
|
+
"""
|
|
261
|
+
Get the FFN down-projection module for a layer.
|
|
262
|
+
|
|
263
|
+
This is the W_down linear layer whose input is z_t (post-SwiGLU intermediate
|
|
264
|
+
activations). Used for CETT computation and H-Neuron causal interventions.
|
|
265
|
+
"""
|
|
266
|
+
layer_module = self.get_layer_module(layer_idx)
|
|
267
|
+
model_type = getattr(self.model.config, "model_type", None)
|
|
268
|
+
path = self.MLP_DOWN_PROJ_POINTS.get(model_type)
|
|
269
|
+
|
|
270
|
+
if path:
|
|
271
|
+
parts = path.split(".")
|
|
272
|
+
module = layer_module
|
|
273
|
+
for part in parts:
|
|
274
|
+
if hasattr(module, part):
|
|
275
|
+
module = getattr(module, part)
|
|
276
|
+
else:
|
|
277
|
+
break
|
|
278
|
+
else:
|
|
279
|
+
return module
|
|
280
|
+
|
|
281
|
+
# Fallback: common MLP down-projection names
|
|
282
|
+
for mlp_name in ["mlp", "feed_forward", "ffn"]:
|
|
283
|
+
if hasattr(layer_module, mlp_name):
|
|
284
|
+
mlp = getattr(layer_module, mlp_name)
|
|
285
|
+
for proj_name in ["down_proj", "c_proj", "w2", "fc2", "dense_4h_to_h"]:
|
|
286
|
+
if hasattr(mlp, proj_name):
|
|
287
|
+
return getattr(mlp, proj_name)
|
|
288
|
+
|
|
289
|
+
raise ValueError(f"Could not find MLP down-projection module for layer {layer_idx}")
|
|
290
|
+
|
|
291
|
+
def register_forward_hook(
|
|
292
|
+
self, layer_idx: int, hook_fn: Callable[[nn.Module, Any, Any], Optional[Any]]
|
|
293
|
+
) -> torch.utils.hooks.RemovableHandle:
|
|
294
|
+
"""
|
|
295
|
+
Register a forward hook on a specific layer.
|
|
296
|
+
|
|
297
|
+
Args:
|
|
298
|
+
layer_idx: Index of the transformer layer
|
|
299
|
+
hook_fn: Hook function with signature (module, input, output) -> output
|
|
300
|
+
|
|
301
|
+
Returns:
|
|
302
|
+
Handle that can be used to remove the hook
|
|
303
|
+
"""
|
|
304
|
+
module = self.get_layer_module(layer_idx)
|
|
305
|
+
handle = module.register_forward_hook(hook_fn)
|
|
306
|
+
self.handles.append(handle)
|
|
307
|
+
return handle
|
|
308
|
+
|
|
309
|
+
def register_cache_hooks(
|
|
310
|
+
self, cache: "ActivationCache", layers: Optional[List[int]] = None, detach: bool = True
|
|
311
|
+
) -> None:
|
|
312
|
+
"""
|
|
313
|
+
Register hooks to cache activations from specified layers.
|
|
314
|
+
|
|
315
|
+
Args:
|
|
316
|
+
cache: ActivationCache to store activations in
|
|
317
|
+
layers: Which layers to cache (None = all)
|
|
318
|
+
detach: Whether to detach tensors from computation graph
|
|
319
|
+
"""
|
|
320
|
+
target_layers = layers if layers is not None else self.available_layers
|
|
321
|
+
|
|
322
|
+
for layer_idx in target_layers:
|
|
323
|
+
|
|
324
|
+
def make_hook(idx: int):
|
|
325
|
+
def hook(module, input, output):
|
|
326
|
+
# output is typically a tuple (hidden_states, ...)
|
|
327
|
+
if isinstance(output, tuple):
|
|
328
|
+
activation = output[0]
|
|
329
|
+
else:
|
|
330
|
+
activation = output
|
|
331
|
+
|
|
332
|
+
if detach:
|
|
333
|
+
activation = activation.detach().clone()
|
|
334
|
+
|
|
335
|
+
cache.store(idx, activation)
|
|
336
|
+
return output
|
|
337
|
+
|
|
338
|
+
return hook
|
|
339
|
+
|
|
340
|
+
self.register_forward_hook(layer_idx, make_hook(layer_idx))
|
|
341
|
+
|
|
342
|
+
def register_patch_hook(
|
|
343
|
+
self,
|
|
344
|
+
layer_idx: int,
|
|
345
|
+
source_activation: torch.Tensor,
|
|
346
|
+
token_positions: Optional[List[int]] = None,
|
|
347
|
+
) -> torch.utils.hooks.RemovableHandle:
|
|
348
|
+
"""
|
|
349
|
+
Register a hook that patches activations with source values.
|
|
350
|
+
|
|
351
|
+
Args:
|
|
352
|
+
layer_idx: Layer to patch
|
|
353
|
+
source_activation: Activation tensor to patch in
|
|
354
|
+
token_positions: Which positions to patch (None = all)
|
|
355
|
+
|
|
356
|
+
Returns:
|
|
357
|
+
Hook handle
|
|
358
|
+
"""
|
|
359
|
+
|
|
360
|
+
def patch_hook(module, input, output):
|
|
361
|
+
if isinstance(output, tuple):
|
|
362
|
+
hidden_states = output[0]
|
|
363
|
+
rest = output[1:]
|
|
364
|
+
else:
|
|
365
|
+
hidden_states = output
|
|
366
|
+
rest = ()
|
|
367
|
+
|
|
368
|
+
# Skip patching during autoregressive decoding (seq_len=1)
|
|
369
|
+
current_seq_len = hidden_states.shape[1]
|
|
370
|
+
if current_seq_len == 1:
|
|
371
|
+
if rest:
|
|
372
|
+
return (hidden_states,) + rest
|
|
373
|
+
return hidden_states
|
|
374
|
+
|
|
375
|
+
# Create new tensor for patching to avoid in-place operations
|
|
376
|
+
# that can break model internals
|
|
377
|
+
source_seq_len = source_activation.shape[1]
|
|
378
|
+
target_seq_len = hidden_states.shape[1]
|
|
379
|
+
|
|
380
|
+
if token_positions is None:
|
|
381
|
+
# Replace with source activations (truncated/padded as needed)
|
|
382
|
+
if target_seq_len <= source_seq_len:
|
|
383
|
+
# Use source directly (truncated if needed)
|
|
384
|
+
patched = source_activation[:, :target_seq_len, :].contiguous()
|
|
385
|
+
else:
|
|
386
|
+
# Need to pad: copy source then keep remaining from original
|
|
387
|
+
patched = torch.cat(
|
|
388
|
+
[source_activation, hidden_states[:, source_seq_len:, :]], dim=1
|
|
389
|
+
)
|
|
390
|
+
else:
|
|
391
|
+
# Patch specific positions by building new tensor
|
|
392
|
+
patched = hidden_states.clone()
|
|
393
|
+
for pos in token_positions:
|
|
394
|
+
if pos < target_seq_len and pos < source_seq_len:
|
|
395
|
+
patched[:, pos : pos + 1, :] = source_activation[:, pos : pos + 1, :]
|
|
396
|
+
|
|
397
|
+
if rest:
|
|
398
|
+
return (patched,) + rest
|
|
399
|
+
return patched
|
|
400
|
+
|
|
401
|
+
return self.register_forward_hook(layer_idx, patch_hook)
|
|
402
|
+
|
|
403
|
+
def register_residual_cache_hooks(
|
|
404
|
+
self, cache: "ActivationCache", layers: Optional[List[int]] = None, detach: bool = True
|
|
405
|
+
) -> None:
|
|
406
|
+
"""
|
|
407
|
+
Register hooks to cache activations from residual stream (post-layer norm).
|
|
408
|
+
|
|
409
|
+
This is safer than caching from the full layer block as it captures
|
|
410
|
+
the clean residual stream without internal layer state.
|
|
411
|
+
"""
|
|
412
|
+
target_layers = layers if layers is not None else self.available_layers
|
|
413
|
+
|
|
414
|
+
for layer_idx in target_layers:
|
|
415
|
+
residual_module = self.get_residual_module(layer_idx)
|
|
416
|
+
|
|
417
|
+
def make_hook(idx: int):
|
|
418
|
+
def hook(module, input, output):
|
|
419
|
+
activation = output
|
|
420
|
+
if detach:
|
|
421
|
+
activation = activation.detach().clone()
|
|
422
|
+
cache.store(idx, activation)
|
|
423
|
+
return output
|
|
424
|
+
|
|
425
|
+
return hook
|
|
426
|
+
|
|
427
|
+
handle = residual_module.register_forward_hook(make_hook(layer_idx))
|
|
428
|
+
self.handles.append(handle)
|
|
429
|
+
|
|
430
|
+
def register_residual_patch_hook(
|
|
431
|
+
self,
|
|
432
|
+
layer_idx: int,
|
|
433
|
+
source_activation: torch.Tensor,
|
|
434
|
+
token_positions: Optional[List[int]] = None,
|
|
435
|
+
) -> torch.utils.hooks.RemovableHandle:
|
|
436
|
+
"""
|
|
437
|
+
Register a patch hook on the residual stream (post-layer norm).
|
|
438
|
+
|
|
439
|
+
This is safer than patching the full layer block as it only modifies
|
|
440
|
+
the output of the normalization layer without affecting internal state.
|
|
441
|
+
"""
|
|
442
|
+
residual_module = self.get_residual_module(layer_idx)
|
|
443
|
+
|
|
444
|
+
def patch_hook(module, input, output):
|
|
445
|
+
# Residual modules typically output a tensor directly (not tuple)
|
|
446
|
+
hidden_states = output
|
|
447
|
+
|
|
448
|
+
# Handle both 2D [batch, hidden] (Mamba) and 3D [batch, seq, hidden] (Transformer)
|
|
449
|
+
if hidden_states.dim() == 2:
|
|
450
|
+
# Mamba: 2D tensor, no sequence dimension
|
|
451
|
+
# Simply replace with source activation if shapes match
|
|
452
|
+
if source_activation.dim() == 2:
|
|
453
|
+
return source_activation.clone()
|
|
454
|
+
elif source_activation.dim() == 3:
|
|
455
|
+
# Source is 3D, take last token
|
|
456
|
+
return source_activation[:, -1, :].clone()
|
|
457
|
+
return hidden_states
|
|
458
|
+
|
|
459
|
+
# 3D Transformer case
|
|
460
|
+
# Skip single-token decoding
|
|
461
|
+
if hidden_states.shape[1] == 1:
|
|
462
|
+
return hidden_states
|
|
463
|
+
|
|
464
|
+
# Match shapes
|
|
465
|
+
target_len = hidden_states.shape[1]
|
|
466
|
+
source_len = source_activation.shape[1] if source_activation.dim() == 3 else 1
|
|
467
|
+
min_len = min(target_len, source_len)
|
|
468
|
+
|
|
469
|
+
if token_positions is None:
|
|
470
|
+
# Patch overlapping positions
|
|
471
|
+
patched = hidden_states.clone()
|
|
472
|
+
if source_activation.dim() == 3:
|
|
473
|
+
patched[:, :min_len, :] = source_activation[:, :min_len, :]
|
|
474
|
+
else:
|
|
475
|
+
# Source is 2D, expand to match
|
|
476
|
+
patched[:, -1, :] = source_activation
|
|
477
|
+
else:
|
|
478
|
+
patched = hidden_states.clone()
|
|
479
|
+
for pos in token_positions:
|
|
480
|
+
if pos < target_len and source_activation.dim() == 3 and pos < source_len:
|
|
481
|
+
patched[:, pos : pos + 1, :] = source_activation[:, pos : pos + 1, :]
|
|
482
|
+
|
|
483
|
+
return patched
|
|
484
|
+
|
|
485
|
+
handle = residual_module.register_forward_hook(patch_hook)
|
|
486
|
+
self.handles.append(handle)
|
|
487
|
+
return handle
|
|
488
|
+
|
|
489
|
+
def register_attention_cache_hooks(
|
|
490
|
+
self, cache: "ActivationCache", layers: Optional[List[int]] = None, detach: bool = True
|
|
491
|
+
) -> None:
|
|
492
|
+
"""
|
|
493
|
+
Register hooks to cache attention output projections for specific layers.
|
|
494
|
+
|
|
495
|
+
This captures the output of the attention projection (o_proj/c_proj),
|
|
496
|
+
which is suitable for head-level patching.
|
|
497
|
+
"""
|
|
498
|
+
target_layers = layers if layers is not None else self.available_layers
|
|
499
|
+
|
|
500
|
+
for layer_idx in target_layers:
|
|
501
|
+
attn_module = self.get_attention_output_module(layer_idx)
|
|
502
|
+
|
|
503
|
+
def make_hook(idx: int):
|
|
504
|
+
def hook(module, input, output):
|
|
505
|
+
activation = output[0] if isinstance(output, tuple) else output
|
|
506
|
+
if detach:
|
|
507
|
+
activation = activation.detach().clone()
|
|
508
|
+
cache.store(idx, activation)
|
|
509
|
+
return output
|
|
510
|
+
|
|
511
|
+
return hook
|
|
512
|
+
|
|
513
|
+
handle = attn_module.register_forward_hook(make_hook(layer_idx))
|
|
514
|
+
self.handles.append(handle)
|
|
515
|
+
|
|
516
|
+
def register_multi_head_patch_hook(
|
|
517
|
+
self,
|
|
518
|
+
layer_idx: int,
|
|
519
|
+
head_indices: List[int],
|
|
520
|
+
source_activation: torch.Tensor,
|
|
521
|
+
head_dim: int,
|
|
522
|
+
) -> torch.utils.hooks.RemovableHandle:
|
|
523
|
+
"""
|
|
524
|
+
Register a patch hook that patches multiple attention heads at once.
|
|
525
|
+
|
|
526
|
+
Args:
|
|
527
|
+
layer_idx: Layer to patch
|
|
528
|
+
head_indices: List of head indices to patch (e.g., [2, 5, 7])
|
|
529
|
+
source_activation: Activation tensor to patch from
|
|
530
|
+
head_dim: Dimension of each head (hidden_size // num_heads)
|
|
531
|
+
|
|
532
|
+
Returns:
|
|
533
|
+
Hook handle
|
|
534
|
+
"""
|
|
535
|
+
attn_module = self.get_attention_output_module(layer_idx)
|
|
536
|
+
|
|
537
|
+
def patch_hook(module, input, output):
|
|
538
|
+
patched = output.clone()
|
|
539
|
+
for head_idx in head_indices:
|
|
540
|
+
h_start = head_idx * head_dim
|
|
541
|
+
h_end = (head_idx + 1) * head_dim
|
|
542
|
+
# Patch last token position only
|
|
543
|
+
patched[:, -1, h_start:h_end] = source_activation[:, -1, h_start:h_end]
|
|
544
|
+
return patched
|
|
545
|
+
|
|
546
|
+
handle = attn_module.register_forward_hook(patch_hook)
|
|
547
|
+
self.handles.append(handle)
|
|
548
|
+
return handle
|
|
549
|
+
|
|
550
|
+
def remove_all_hooks(self) -> None:
|
|
551
|
+
"""Remove all registered hooks."""
|
|
552
|
+
for handle in self.handles:
|
|
553
|
+
handle.remove()
|
|
554
|
+
self.handles.clear()
|
|
555
|
+
|
|
556
|
+
def __del__(self):
|
|
557
|
+
"""Cleanup hooks on deletion."""
|
|
558
|
+
self.remove_all_hooks()
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
"""Intervention types and specifications."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from enum import Enum, auto
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class InterventionType(Enum):
    """Kinds of edit that can be applied to an activation."""

    # Values are consecutive integers starting at 1, in declaration order.
    PATCH = 1  # Replace with activations from another run
    ZERO = 2  # Zero out activations
    NOISE = 3  # Add Gaussian noise
    MEAN_ABLATE = 4  # Replace with mean activation
    SCALE = 5  # Scale activations by factor
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class Intervention:
    """Specification for a single activation intervention."""

    type: InterventionType  # which kind of edit to apply
    layers: List[int]  # layer indices the edit targets
    token_positions: Optional[List[int]] = None  # positions to edit; defaults to None

    # Parameters for specific intervention types
    noise_scale: float = 0.1  # For NOISE type
    scale_factor: float = 0.0  # For SCALE type
    source_cache_key: Optional[str] = None  # For PATCH type

    def __repr__(self) -> str:
        """Return a compact form: type name, layers, and positions when set."""
        # Test against None rather than truthiness so an explicit empty list
        # is shown and stays distinguishable from "not specified".
        if self.token_positions is None:
            pos_str = ""
        else:
            pos_str = f", positions={self.token_positions}"
        return f"Intervention({self.type.name}, layers={self.layers}{pos_str})"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class PatchingExperimentSpec:
    """Everything needed to describe one patching experiment."""

    # The two prompts being contrasted.
    clean_prompt: str
    corrupted_prompt: str
    # Interventions to apply; each instance gets its own list.
    interventions: List[Intervention] = field(default_factory=list)

    # Optional reference answers for the two prompts.
    expected_clean_answer: Optional[str] = None
    expected_corrupted_answer: Optional[str] = None

    # Optional bookkeeping.
    name: Optional[str] = None
    description: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def add_intervention(
        self, type: InterventionType, layers: List[int], **kwargs
    ) -> "PatchingExperimentSpec":
        """
        Append a new ``Intervention`` and return ``self`` for chaining.

        Args:
            type: Kind of intervention to add
            layers: Layer indices the intervention targets
            **kwargs: Forwarded unchanged to ``Intervention``
        """
        new_intervention = Intervention(type=type, layers=layers, **kwargs)
        self.interventions.append(new_intervention)
        return self
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class LayerImportance:
    """Outcome of patching one layer during a layer importance sweep."""

    layer_idx: int
    effect_size: float  # How much patching changed the output
    original_output: str
    patched_output: str
    answer_recovered: bool  # Did patching recover the expected answer?

    @property
    def is_important(self) -> bool:
        """
        Whether this layer appears causally important.

        True when the magnitude of ``effect_size`` is strictly above 0.1;
        otherwise the value of ``answer_recovered`` decides.
        """
        large_effect = abs(self.effect_size) > 0.1
        if large_effect:
            return large_effect
        return self.answer_recovered
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@dataclass
class ThoughtAnchor:
    """A token position identified as important for reasoning."""

    position: int
    token: str
    layer_effects: Dict[int, float]  # layer -> effect size

    @property
    def max_effect(self) -> float:
        """Largest absolute effect over all layers, or 0.0 when there are none."""
        if not self.layer_effects:
            return 0.0
        return max(map(abs, self.layer_effects.values()))
|