cotlab 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. cotlab/__init__.py +3 -0
  2. cotlab/analyse_experiments.py +392 -0
  3. cotlab/analysis/__init__.py +11 -0
  4. cotlab/analysis/cot_parser.py +243 -0
  5. cotlab/analysis/faithfulness_metrics.py +192 -0
  6. cotlab/backends/__init__.py +16 -0
  7. cotlab/backends/base.py +78 -0
  8. cotlab/backends/transformers_backend.py +335 -0
  9. cotlab/backends/vllm_backend.py +227 -0
  10. cotlab/cli.py +83 -0
  11. cotlab/core/__init__.py +34 -0
  12. cotlab/core/base.py +749 -0
  13. cotlab/core/config.py +90 -0
  14. cotlab/core/registry.py +68 -0
  15. cotlab/datasets/__init__.py +45 -0
  16. cotlab/datasets/loaders.py +1889 -0
  17. cotlab/experiment/__init__.py +315 -0
  18. cotlab/experiments/__init__.py +43 -0
  19. cotlab/experiments/activation_compare.py +290 -0
  20. cotlab/experiments/activation_patching.py +1050 -0
  21. cotlab/experiments/attention_analysis.py +885 -0
  22. cotlab/experiments/classification.py +235 -0
  23. cotlab/experiments/composite_shift_detector.py +524 -0
  24. cotlab/experiments/cot_ablation.py +277 -0
  25. cotlab/experiments/cot_faithfulness.py +187 -0
  26. cotlab/experiments/cot_heads.py +208 -0
  27. cotlab/experiments/full_layer_cot.py +232 -0
  28. cotlab/experiments/full_layer_patching.py +225 -0
  29. cotlab/experiments/h_neuron_analysis.py +712 -0
  30. cotlab/experiments/logit_lens.py +439 -0
  31. cotlab/experiments/multi_head_cot.py +220 -0
  32. cotlab/experiments/multi_head_patching.py +229 -0
  33. cotlab/experiments/probing_classifier.py +402 -0
  34. cotlab/experiments/residual_norm_ood.py +413 -0
  35. cotlab/experiments/sae_feature_analysis.py +673 -0
  36. cotlab/experiments/steering_vectors.py +223 -0
  37. cotlab/experiments/sycophancy_heads.py +224 -0
  38. cotlab/logging/__init__.py +5 -0
  39. cotlab/logging/json_logger.py +161 -0
  40. cotlab/main.py +317 -0
  41. cotlab/patching/__init__.py +24 -0
  42. cotlab/patching/cache.py +141 -0
  43. cotlab/patching/hooks.py +558 -0
  44. cotlab/patching/interventions.py +86 -0
  45. cotlab/patching/patcher.py +439 -0
  46. cotlab/patching/sae.py +181 -0
  47. cotlab/prompts/__init__.py +43 -0
  48. cotlab/prompts/cardiology.py +378 -0
  49. cotlab/prompts/histopathology.py +265 -0
  50. cotlab/prompts/length_matched_strategies.py +157 -0
  51. cotlab/prompts/mcq.py +193 -0
  52. cotlab/prompts/neurology.py +353 -0
  53. cotlab/prompts/oncology.py +367 -0
  54. cotlab/prompts/plab.py +162 -0
  55. cotlab/prompts/pubhealthbench.py +82 -0
  56. cotlab/prompts/pubmedqa.py +173 -0
  57. cotlab/prompts/radiology.py +414 -0
  58. cotlab/prompts/strategies.py +939 -0
  59. cotlab/prompts/tcga.py +168 -0
  60. cotlab/runner.py +204 -0
  61. cotlab-0.8.0.dist-info/METADATA +166 -0
  62. cotlab-0.8.0.dist-info/RECORD +65 -0
  63. cotlab-0.8.0.dist-info/WHEEL +4 -0
  64. cotlab-0.8.0.dist-info/entry_points.txt +3 -0
  65. cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,558 @@
1
+ """PyTorch forward hook utilities for activation extraction and patching."""
2
+
3
+ from typing import Any, Callable, Dict, List, Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
class HookManager:
    """
    Manage PyTorch forward hooks for activation extraction and patching.

    This provides a clean interface for:
    - Registering hooks on specific layers
    - Caching activations during forward pass
    - Patching activations with custom values
    - Cleanup of all registered hooks

    Every handle created through this class is tracked in ``self.handles`` and
    released by :meth:`remove_all_hooks` (also invoked when the manager is
    garbage collected). Hook closures never hold a reference to the manager,
    so dropping the manager removes its hooks.

    Example:
        >>> manager = HookManager(model)
        >>> cache = ActivationCache()
        >>> manager.register_cache_hooks(cache, layers=[0, 5, 10])
        >>> output = model(input_ids)  # Activations now in cache
        >>> manager.remove_all_hooks()
    """

    def __init__(self, model: nn.Module):
        """
        Initialize hook manager for a model.

        Args:
            model: The transformer model to hook into. A ``config.model_type``
                attribute is used when present; without it the layers are
                located by name-based auto-detection.
        """
        self.model = model
        self.handles: List[torch.utils.hooks.RemovableHandle] = []
        self._layer_modules = self._build_layer_mapping()

    # Known layer paths by model_type from HF config
    # Focused on Gemma 3, MedGemma, and reasoning models
    LAYER_PATHS = {
        # Gemma 3 / MedGemma family
        "gemma": "model.layers",
        "gemma2": "model.layers",
        "gemma3": "model.layers",
        "gemma3_text": "model.layers",
        # Mistral / Ministral (reasoning models)
        "mistral": "model.layers",
        # Qwen (DeepSeek-R1 distilled)
        "qwen2": "model.layers",
        # Olmo (Think models)
        "olmo": "model.layers",
        "olmo2": "model.layers",
        # Nemotron / Llama-based
        "llama": "model.layers",
        # GPT-2 for testing
        "gpt2": "transformer.h",
    }

    # Safe residual stream hook points (post-layer normalization)
    # These are the output of the final norm in each layer block
    # NOTE(review): for the Mistral/Qwen/Olmo/Llama entries the chosen module is
    # named post_attention_layernorm -- confirm it is the intended hook point.
    RESIDUAL_HOOK_POINTS = {
        # Gemma 3 / MedGemma: hook post_feedforward_layernorm (after entire layer)
        "gemma": "post_feedforward_layernorm",
        "gemma2": "post_feedforward_layernorm",
        "gemma3": "post_feedforward_layernorm",
        "gemma3_text": "post_feedforward_layernorm",
        # Mistral / Qwen / Olmo / Llama: hook post_attention_layernorm
        "mistral": "post_attention_layernorm",
        "qwen2": "post_attention_layernorm",
        "olmo": "post_attention_layernorm",
        "olmo2": "post_attention_layernorm",
        "llama": "post_attention_layernorm",
        # GPT-2 for testing
        "gpt2": "ln_2",
    }

    # FFN down-projection modules — input is z_t (intermediate activations after SwiGLU gate)
    # Used for CETT computation and H-Neuron causal interventions
    MLP_DOWN_PROJ_POINTS = {
        "gemma": "mlp.down_proj",
        "gemma2": "mlp.down_proj",
        "gemma3": "mlp.down_proj",
        "gemma3_text": "mlp.down_proj",
        "mistral": "mlp.down_proj",
        "qwen2": "mlp.down_proj",
        "olmo": "mlp.down_proj",
        "olmo2": "mlp.down_proj",
        "llama": "mlp.down_proj",
        "gpt2": "mlp.c_proj",
    }

    # Attention output projection modules for head-level patching
    # These modules take concatenated head outputs and project back to hidden dim
    ATTENTION_OUTPUT_POINTS = {
        # Gemma 3 / MedGemma: self_attn.o_proj
        "gemma": "self_attn.o_proj",
        "gemma2": "self_attn.o_proj",
        "gemma3": "self_attn.o_proj",
        "gemma3_text": "self_attn.o_proj",
        # Mistral / Qwen / Olmo / Llama: self_attn.o_proj
        "mistral": "self_attn.o_proj",
        "qwen2": "self_attn.o_proj",
        "olmo": "self_attn.o_proj",
        "olmo2": "self_attn.o_proj",
        "llama": "self_attn.o_proj",
        # GPT-2 for testing
        "gpt2": "attn.c_proj",
    }

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @property
    def _model_type(self) -> Optional[str]:
        """Return ``model.config.model_type``, or None when either is missing."""
        config = getattr(self.model, "config", None)
        return getattr(config, "model_type", None)

    @staticmethod
    def _resolve_submodule(root: nn.Module, dotted_path: str) -> Optional[nn.Module]:
        """Follow a dotted attribute path from *root*; None if any part is absent."""
        module = root
        for part in dotted_path.split("."):
            if not hasattr(module, part):
                return None
            module = getattr(module, part)
        return module

    @staticmethod
    def _split_output(output: Any):
        """Split a hook output into ``(hidden_states, rest)``.

        ``rest`` is None for a bare (non-tuple) output and a possibly empty
        tuple otherwise, so the original container shape can be rebuilt.
        """
        if isinstance(output, tuple):
            return output[0], output[1:]
        return output, None

    @staticmethod
    def _join_output(hidden_states: torch.Tensor, rest: Optional[tuple]) -> Any:
        """Inverse of :meth:`_split_output`: re-wrap in a tuple only if it was one."""
        if rest is None:
            return hidden_states
        return (hidden_states,) + rest

    @staticmethod
    def _make_cache_hook(cache: "ActivationCache", idx: int, detach: bool) -> Callable:
        """Build a forward hook that stores a module's activation under *idx*."""

        def hook(module, inputs, output):
            # Tuple outputs carry the hidden states in slot 0.
            activation = output[0] if isinstance(output, tuple) else output
            if detach:
                activation = activation.detach().clone()
            cache.store(idx, activation)
            return output

        return hook

    def _track(
        self, handle: torch.utils.hooks.RemovableHandle
    ) -> torch.utils.hooks.RemovableHandle:
        """Remember *handle* so remove_all_hooks() can release it."""
        self.handles.append(handle)
        return handle

    # ------------------------------------------------------------------
    # Layer discovery
    # ------------------------------------------------------------------

    def _build_layer_mapping(self) -> Dict[int, nn.Module]:
        """
        Auto-detect transformer layers using HF config model_type.

        Uses known layer paths for common architectures, with fallback
        to regex-based auto-detection for unknown models (including models
        that expose no ``config`` at all).
        """
        layer_path = self.LAYER_PATHS.get(self._model_type)

        if layer_path:
            # Try known path first
            layers = self._get_layers_from_path(layer_path)
            if layers:
                return layers
            # If known path fails (e.g., multimodal model), try auto-detect

        # Fallback: auto-detect using regex
        return self._auto_detect_layers()

    def _get_layers_from_path(self, layer_path: str) -> Dict[int, nn.Module]:
        """Get layers from a known path like 'transformer.h' or 'model.layers'."""
        layers: Dict[int, nn.Module] = {}
        prefix = layer_path + "."

        for name, module in self.model.named_modules():
            # Match pattern: layer_path.{number}
            if not name.startswith(prefix):
                continue
            suffix = name[len(prefix):]
            # Only match direct children (isdigit() already rules out dots)
            if suffix.isdigit():
                layers[int(suffix)] = module

        return layers

    def _auto_detect_layers(self) -> Dict[int, nn.Module]:
        """Fallback: auto-detect layers using regex pattern matching.

        Every module named ``<prefix>.<number>`` that has no numbered children
        is a candidate. When several candidates share an index the one with the
        highest priority wins (first seen wins ties).
        """
        import re

        layers: Dict[int, nn.Module] = {}
        best_priority: Dict[int, int] = {}
        layer_pattern = re.compile(r"^(.+?)\.(\d+)$")

        for name, module in self.model.named_modules():
            match = layer_pattern.match(name)
            if not match:
                continue

            prefix = match.group(1)
            layer_idx = int(match.group(2))

            # Skip sublayers (modules with numbered children)
            if any(child.isdigit() for child, _ in module.named_children()):
                continue

            # Prioritize language model layers
            if "language_model" in prefix:
                priority = 4
            elif "layers" in prefix or "h" in prefix.split("."):
                # "h" must be a whole path component (GPT-2 style "transformer.h");
                # a substring test would match any prefix containing the letter.
                priority = 1 if ("vision" in prefix or "encoder" in prefix) else 3
            else:
                priority = 2

            if priority > best_priority.get(layer_idx, 0):
                layers[layer_idx] = module
                best_priority[layer_idx] = priority

        return layers

    @property
    def num_layers(self) -> int:
        """Number of hookable layers."""
        return len(self._layer_modules)

    @property
    def available_layers(self) -> List[int]:
        """List of available layer indices."""
        return sorted(self._layer_modules.keys())

    # ------------------------------------------------------------------
    # Module lookup
    # ------------------------------------------------------------------

    def get_layer_module(self, layer_idx: int) -> nn.Module:
        """Get the module for a specific layer.

        Raises:
            ValueError: If *layer_idx* is not one of :attr:`available_layers`.
        """
        if layer_idx not in self._layer_modules:
            raise ValueError(f"Layer {layer_idx} not found. Available: {self.available_layers}")
        return self._layer_modules[layer_idx]

    def get_residual_module(self, layer_idx: int) -> nn.Module:
        """
        Get the residual stream hook point for a layer.

        Returns the normalization submodule configured in RESIDUAL_HOOK_POINTS
        (e.g., ln_2 for GPT-2, post_feedforward_layernorm for Gemma3), then
        tries a list of common names, and as a last resort returns the layer
        block itself (whose output may be a tuple).
        """
        layer_module = self.get_layer_module(layer_idx)
        residual_name = self.RESIDUAL_HOOK_POINTS.get(self._model_type)

        # Try the specific residual hook point for this architecture
        if residual_name and hasattr(layer_module, residual_name):
            return getattr(layer_module, residual_name)

        # Fallback: try common names
        for name in (
            "post_feedforward_layernorm",
            "post_attention_layernorm",
            "ln_2",
            "layer_norm",
        ):
            if hasattr(layer_module, name):
                return getattr(layer_module, name)

        # Last resort: return the layer itself
        return layer_module

    def get_attention_output_module(self, layer_idx: int) -> nn.Module:
        """
        Get the attention output projection module for a layer.

        This is where individual head outputs are concatenated and projected.
        Used for head-level patching interventions.

        Raises:
            ValueError: If no known projection module exists on the layer.
        """
        layer_module = self.get_layer_module(layer_idx)
        attn_path = self.ATTENTION_OUTPUT_POINTS.get(self._model_type)

        if attn_path:
            # Navigate nested path like "self_attn.o_proj"
            found = self._resolve_submodule(layer_module, attn_path)
            if found is not None:
                return found

        # Fallback: try common attention output names
        for attn_name in ("self_attn", "attn", "attention"):
            if hasattr(layer_module, attn_name):
                attn = getattr(layer_module, attn_name)
                for proj_name in ("o_proj", "c_proj", "out_proj", "dense"):
                    if hasattr(attn, proj_name):
                        return getattr(attn, proj_name)

        raise ValueError(f"Could not find attention output module for layer {layer_idx}")

    def get_mlp_down_proj_module(self, layer_idx: int) -> nn.Module:
        """
        Get the FFN down-projection module for a layer.

        This is the W_down linear layer whose input is z_t (post-SwiGLU intermediate
        activations). Used for CETT computation and H-Neuron causal interventions.

        Raises:
            ValueError: If no known down-projection module exists on the layer.
        """
        layer_module = self.get_layer_module(layer_idx)
        path = self.MLP_DOWN_PROJ_POINTS.get(self._model_type)

        if path:
            found = self._resolve_submodule(layer_module, path)
            if found is not None:
                return found

        # Fallback: common MLP down-projection names
        for mlp_name in ("mlp", "feed_forward", "ffn"):
            if hasattr(layer_module, mlp_name):
                mlp = getattr(layer_module, mlp_name)
                for proj_name in ("down_proj", "c_proj", "w2", "fc2", "dense_4h_to_h"):
                    if hasattr(mlp, proj_name):
                        return getattr(mlp, proj_name)

        raise ValueError(f"Could not find MLP down-projection module for layer {layer_idx}")

    # ------------------------------------------------------------------
    # Hooks on the full layer block
    # ------------------------------------------------------------------

    def register_forward_hook(
        self, layer_idx: int, hook_fn: Callable[[nn.Module, Any, Any], Optional[Any]]
    ) -> torch.utils.hooks.RemovableHandle:
        """
        Register a forward hook on a specific layer.

        Args:
            layer_idx: Index of the transformer layer
            hook_fn: Hook function with signature (module, input, output) -> output

        Returns:
            Handle that can be used to remove the hook
        """
        module = self.get_layer_module(layer_idx)
        return self._track(module.register_forward_hook(hook_fn))

    def register_cache_hooks(
        self, cache: "ActivationCache", layers: Optional[List[int]] = None, detach: bool = True
    ) -> None:
        """
        Register hooks to cache activations from specified layers.

        Args:
            cache: ActivationCache to store activations in
            layers: Which layers to cache (None = all)
            detach: Whether to detach tensors from computation graph
        """
        target_layers = layers if layers is not None else self.available_layers

        for layer_idx in target_layers:
            self.register_forward_hook(
                layer_idx, self._make_cache_hook(cache, layer_idx, detach)
            )

    def register_patch_hook(
        self,
        layer_idx: int,
        source_activation: torch.Tensor,
        token_positions: Optional[List[int]] = None,
    ) -> torch.utils.hooks.RemovableHandle:
        """
        Register a hook that patches activations with source values.

        Args:
            layer_idx: Layer to patch
            source_activation: Activation tensor to patch in; indexed as
                [batch, position, hidden] by this hook
            token_positions: Which positions to patch (None = all). Negative
                values count from the end of each tensor; positions outside
                either tensor are skipped.

        Returns:
            Hook handle
        """
        split = HookManager._split_output
        join = HookManager._join_output

        def patch_hook(module, inputs, output):
            hidden_states, rest = split(output)

            # Skip patching during autoregressive decoding (seq_len=1)
            target_seq_len = hidden_states.shape[1]
            if target_seq_len == 1:
                return output

            # The cached source may live on another device / in another dtype.
            source = source_activation.to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )
            source_seq_len = source.shape[1]

            if token_positions is None:
                # Replace with source activations (truncated/padded as needed)
                if target_seq_len <= source_seq_len:
                    # clone() so later in-place ops in the model cannot write
                    # through to the caller's cached source tensor
                    patched = source[:, :target_seq_len, :].clone(
                        memory_format=torch.contiguous_format
                    )
                else:
                    # Need to pad: copy source then keep remaining from original
                    patched = torch.cat(
                        [source, hidden_states[:, source_seq_len:, :]], dim=1
                    )
            else:
                # Patch specific positions by building new tensor
                patched = hidden_states.clone()
                limit = min(target_seq_len, source_seq_len)
                for pos in token_positions:
                    # Direct indexing (not pos:pos+1) so that -1 addresses the
                    # last token instead of producing the empty slice -1:0.
                    if -limit <= pos < limit:
                        patched[:, pos, :] = source[:, pos, :]

            return join(patched, rest)

        return self.register_forward_hook(layer_idx, patch_hook)

    # ------------------------------------------------------------------
    # Hooks on the residual hook point
    # ------------------------------------------------------------------

    def register_residual_cache_hooks(
        self, cache: "ActivationCache", layers: Optional[List[int]] = None, detach: bool = True
    ) -> None:
        """
        Register hooks to cache activations from residual stream (post-layer norm).

        The hook is attached to the module returned by get_residual_module();
        if that falls back to the layer block, the first element of its tuple
        output is cached.

        Args:
            cache: ActivationCache to store activations in
            layers: Which layers to cache (None = all)
            detach: Whether to detach tensors from computation graph
        """
        target_layers = layers if layers is not None else self.available_layers

        for layer_idx in target_layers:
            residual_module = self.get_residual_module(layer_idx)
            hook = self._make_cache_hook(cache, layer_idx, detach)
            self._track(residual_module.register_forward_hook(hook))

    def register_residual_patch_hook(
        self,
        layer_idx: int,
        source_activation: torch.Tensor,
        token_positions: Optional[List[int]] = None,
    ) -> torch.utils.hooks.RemovableHandle:
        """
        Register a patch hook on the residual stream (post-layer norm).

        The hook is attached to the module returned by get_residual_module().

        Args:
            layer_idx: Layer to patch
            source_activation: 3D [batch, seq, hidden] or 2D [batch, hidden] tensor
            token_positions: Which positions to patch (None = the overlapping
                prefix for a 3D source, the last position for a 2D source).
                Negative values count from the end; only applied with a 3D source.

        Returns:
            Hook handle
        """
        residual_module = self.get_residual_module(layer_idx)
        split = HookManager._split_output
        join = HookManager._join_output

        def patch_hook(module, inputs, output):
            # Usually a bare tensor; a tuple when the hook point is the layer itself
            hidden_states, rest = split(output)
            source = source_activation.to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )

            # Handle both 2D [batch, hidden] (Mamba) and 3D [batch, seq, hidden] (Transformer)
            if hidden_states.dim() == 2:
                # Mamba: 2D tensor, no sequence dimension
                if source.dim() == 2:
                    return join(source.clone(), rest)
                if source.dim() == 3:
                    # Source is 3D, take last token
                    return join(source[:, -1, :].clone(), rest)
                return output

            # 3D Transformer case
            # Skip single-token decoding
            target_len = hidden_states.shape[1]
            if target_len == 1:
                return output

            # Match shapes
            source_is_3d = source.dim() == 3
            source_len = source.shape[1] if source_is_3d else 1
            min_len = min(target_len, source_len)

            patched = hidden_states.clone()
            if token_positions is None:
                if source_is_3d:
                    # Patch overlapping positions
                    patched[:, :min_len, :] = source[:, :min_len, :]
                else:
                    # Source is 2D: write it into the last position
                    patched[:, -1, :] = source
            elif source_is_3d:
                for pos in token_positions:
                    # Direct indexing so that -1 means "last token" (pos:pos+1
                    # would be the empty slice -1:0).
                    if -min_len <= pos < min_len:
                        patched[:, pos, :] = source[:, pos, :]

            return join(patched, rest)

        return self._track(residual_module.register_forward_hook(patch_hook))

    # ------------------------------------------------------------------
    # Hooks on the attention output projection
    # ------------------------------------------------------------------

    def register_attention_cache_hooks(
        self, cache: "ActivationCache", layers: Optional[List[int]] = None, detach: bool = True
    ) -> None:
        """
        Register hooks to cache attention output projections for specific layers.

        This captures the output of the attention projection (o_proj/c_proj),
        which is suitable for head-level patching.

        Args:
            cache: ActivationCache to store activations in
            layers: Which layers to cache (None = all)
            detach: Whether to detach tensors from computation graph
        """
        target_layers = layers if layers is not None else self.available_layers

        for layer_idx in target_layers:
            attn_module = self.get_attention_output_module(layer_idx)
            hook = self._make_cache_hook(cache, layer_idx, detach)
            self._track(attn_module.register_forward_hook(hook))

    def register_multi_head_patch_hook(
        self,
        layer_idx: int,
        head_indices: List[int],
        source_activation: torch.Tensor,
        head_dim: int,
    ) -> torch.utils.hooks.RemovableHandle:
        """
        Register a patch hook that patches multiple attention heads at once.

        For each requested head the slice ``[h * head_dim, (h + 1) * head_dim)``
        of the last token position is overwritten with the same slice of
        *source_activation*.

        NOTE(review): the hook sits on the *output* of the projection module;
        confirm that head_dim-wide slices there correspond to individual heads
        for the target architecture.

        Args:
            layer_idx: Layer to patch
            head_indices: List of head indices to patch (e.g., [2, 5, 7])
            source_activation: Activation tensor to patch from
            head_dim: Dimension of each head (hidden_size // num_heads)

        Returns:
            Hook handle
        """
        attn_module = self.get_attention_output_module(layer_idx)
        split = HookManager._split_output
        join = HookManager._join_output

        def patch_hook(module, inputs, output):
            hidden_states, rest = split(output)
            source = source_activation.to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )
            patched = hidden_states.clone()
            for head_idx in head_indices:
                h_start = head_idx * head_dim
                h_end = h_start + head_dim
                # Patch last token position only
                patched[:, -1, h_start:h_end] = source[:, -1, h_start:h_end]
            return join(patched, rest)

        return self._track(attn_module.register_forward_hook(patch_hook))

    # ------------------------------------------------------------------
    # Cleanup
    # ------------------------------------------------------------------

    def remove_all_hooks(self) -> None:
        """Remove all registered hooks."""
        for handle in self.handles:
            handle.remove()
        self.handles.clear()

    def __del__(self):
        """Cleanup hooks on deletion."""
        # __init__ may not have run (or may have failed) before collection.
        if hasattr(self, "handles"):
            self.remove_all_hooks()
@@ -0,0 +1,86 @@
1
+ """Intervention types and specifications."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from enum import Enum, auto
5
+ from typing import Any, Dict, List, Optional
6
+
7
+
8
class InterventionType(Enum):
    """Enumeration of the supported kinds of activation intervention."""

    # Replace with activations from another run
    PATCH = 1
    # Zero out activations
    ZERO = 2
    # Add Gaussian noise
    NOISE = 3
    # Replace with mean activation
    MEAN_ABLATE = 4
    # Scale activations by factor
    SCALE = 5
16
+
17
+
18
@dataclass
class Intervention:
    """Describe one activation intervention: its kind, target layers and options."""

    type: InterventionType
    layers: List[int]
    # Token positions to restrict the intervention to; None by default
    token_positions: Optional[List[int]] = None

    # Options that only matter for particular intervention types
    noise_scale: float = 0.1  # For NOISE type
    scale_factor: float = 0.0  # For SCALE type
    source_cache_key: Optional[str] = None  # For PATCH type

    def __repr__(self) -> str:
        """Return a compact summary; positions are shown only when non-empty."""
        parts = [self.type.name, f"layers={self.layers}"]
        if self.token_positions:
            parts.append(f"positions={self.token_positions}")
        return "Intervention(" + ", ".join(parts) + ")"
34
+
35
+
36
@dataclass
class PatchingExperimentSpec:
    """Bundle the prompts, interventions and bookkeeping for one patching experiment."""

    # Prompt pair the experiment is defined over
    clean_prompt: str
    corrupted_prompt: str
    # Interventions to apply; each instance gets its own fresh list
    interventions: List[Intervention] = field(default_factory=list)

    # Optional expected answers for the two prompts
    expected_clean_answer: Optional[str] = None
    expected_corrupted_answer: Optional[str] = None

    # Optional descriptive fields; metadata is a fresh dict per instance
    name: Optional[str] = None
    description: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def add_intervention(
        self, type: InterventionType, layers: List[int], **kwargs
    ) -> "PatchingExperimentSpec":
        """Append a new Intervention built from the arguments and return self.

        Extra keyword arguments are forwarded to the Intervention constructor.
        Returning self lets calls be chained (builder pattern).
        """
        new_intervention = Intervention(type=type, layers=layers, **kwargs)
        self.interventions.append(new_intervention)
        return self
57
+
58
+
59
@dataclass
class LayerImportance:
    """Outcome of patching a single layer during a layer importance sweep."""

    layer_idx: int
    # How much patching changed the output
    effect_size: float
    original_output: str
    patched_output: str
    # Did patching recover the expected answer?
    answer_recovered: bool

    @property
    def is_important(self) -> bool:
        """Report whether the layer looks causally important.

        True when the magnitude of ``effect_size`` exceeds 0.1; otherwise the
        value of ``answer_recovered`` decides.
        """
        exceeds_threshold = abs(self.effect_size) > 0.1
        if exceeds_threshold:
            return exceeds_threshold
        return self.answer_recovered
73
+
74
+
75
@dataclass
class ThoughtAnchor:
    """Record a token position flagged as important for reasoning."""

    position: int
    token: str
    # layer -> effect size
    layer_effects: Dict[int, float]

    @property
    def max_effect(self) -> float:
        """Return the largest absolute effect over all layers, or 0.0 if there are none."""
        if not self.layer_effects:
            return 0.0
        return max(map(abs, self.layer_effects.values()))