cotlab 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotlab/__init__.py +3 -0
- cotlab/analyse_experiments.py +392 -0
- cotlab/analysis/__init__.py +11 -0
- cotlab/analysis/cot_parser.py +243 -0
- cotlab/analysis/faithfulness_metrics.py +192 -0
- cotlab/backends/__init__.py +16 -0
- cotlab/backends/base.py +78 -0
- cotlab/backends/transformers_backend.py +335 -0
- cotlab/backends/vllm_backend.py +227 -0
- cotlab/cli.py +83 -0
- cotlab/core/__init__.py +34 -0
- cotlab/core/base.py +749 -0
- cotlab/core/config.py +90 -0
- cotlab/core/registry.py +68 -0
- cotlab/datasets/__init__.py +45 -0
- cotlab/datasets/loaders.py +1889 -0
- cotlab/experiment/__init__.py +315 -0
- cotlab/experiments/__init__.py +43 -0
- cotlab/experiments/activation_compare.py +290 -0
- cotlab/experiments/activation_patching.py +1050 -0
- cotlab/experiments/attention_analysis.py +885 -0
- cotlab/experiments/classification.py +235 -0
- cotlab/experiments/composite_shift_detector.py +524 -0
- cotlab/experiments/cot_ablation.py +277 -0
- cotlab/experiments/cot_faithfulness.py +187 -0
- cotlab/experiments/cot_heads.py +208 -0
- cotlab/experiments/full_layer_cot.py +232 -0
- cotlab/experiments/full_layer_patching.py +225 -0
- cotlab/experiments/h_neuron_analysis.py +712 -0
- cotlab/experiments/logit_lens.py +439 -0
- cotlab/experiments/multi_head_cot.py +220 -0
- cotlab/experiments/multi_head_patching.py +229 -0
- cotlab/experiments/probing_classifier.py +402 -0
- cotlab/experiments/residual_norm_ood.py +413 -0
- cotlab/experiments/sae_feature_analysis.py +673 -0
- cotlab/experiments/steering_vectors.py +223 -0
- cotlab/experiments/sycophancy_heads.py +224 -0
- cotlab/logging/__init__.py +5 -0
- cotlab/logging/json_logger.py +161 -0
- cotlab/main.py +317 -0
- cotlab/patching/__init__.py +24 -0
- cotlab/patching/cache.py +141 -0
- cotlab/patching/hooks.py +558 -0
- cotlab/patching/interventions.py +86 -0
- cotlab/patching/patcher.py +439 -0
- cotlab/patching/sae.py +181 -0
- cotlab/prompts/__init__.py +43 -0
- cotlab/prompts/cardiology.py +378 -0
- cotlab/prompts/histopathology.py +265 -0
- cotlab/prompts/length_matched_strategies.py +157 -0
- cotlab/prompts/mcq.py +193 -0
- cotlab/prompts/neurology.py +353 -0
- cotlab/prompts/oncology.py +367 -0
- cotlab/prompts/plab.py +162 -0
- cotlab/prompts/pubhealthbench.py +82 -0
- cotlab/prompts/pubmedqa.py +173 -0
- cotlab/prompts/radiology.py +414 -0
- cotlab/prompts/strategies.py +939 -0
- cotlab/prompts/tcga.py +168 -0
- cotlab/runner.py +204 -0
- cotlab-0.8.0.dist-info/METADATA +166 -0
- cotlab-0.8.0.dist-info/RECORD +65 -0
- cotlab-0.8.0.dist-info/WHEEL +4 -0
- cotlab-0.8.0.dist-info/entry_points.txt +3 -0
- cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,439 @@
|
|
|
1
|
+
"""Activation patcher for causal intervention experiments."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Dict, Iterable, List, Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from .cache import ActivationCache
|
|
9
|
+
from .hooks import HookManager
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class PatchingResult:
    """Outcome of one generation run performed with patched activations.

    ``original_answer`` and ``patched_answer`` default to ``None`` and may be
    assigned after construction; ``answer_changed`` compares them.
    """

    # Text produced by the patched (or ablated) run.
    output_text: str
    # Indices of the layers whose activations were replaced.
    patched_layers: List[int]
    # Token positions that were patched, or None when no restriction applied.
    patched_positions: Optional[List[int]]
    original_answer: Optional[str] = None
    patched_answer: Optional[str] = None

    @property
    def answer_changed(self) -> bool:
        """Tell whether the two answers differ after stripping whitespace.

        Returns:
            False when either answer is unset; otherwise True exactly when the
            stripped answers are not equal.
        """
        before, after = self.original_answer, self.patched_answer
        if before is not None and after is not None:
            return before.strip() != after.strip()
        return False
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class ForwardPatchResult:
    """Result from forward-only patching (no generation).

    Holds the logits of the clean, corrupted and patched forward passes for
    one patched layer, plus a scalar effect score computed by the patcher.
    The KL properties index the logits as ``[:, -1, :]``, i.e. they expect
    rank-3 tensors and use the last entry along dimension 1.
    """

    layer_idx: int
    clean_logits: torch.Tensor
    corrupted_logits: torch.Tensor
    patched_logits: torch.Tensor
    effect_score: float  # How much patching recovered clean behavior

    def _kl_reference_vs_patched(self, reference_logits: torch.Tensor) -> float:
        """Return KL(reference || patched) over the last-position distributions.

        The patched log-probabilities come from ``log_softmax`` rather than
        ``softmax(...).log()``: the latter yields ``-inf`` as soon as a
        probability underflows to zero, which turns the divergence into
        ``inf``/NaN for confident predictions.
        """
        reference_probs = torch.softmax(reference_logits[:, -1, :], dim=-1)
        patched_log_probs = torch.log_softmax(self.patched_logits[:, -1, :], dim=-1)
        # kl_div takes log-probabilities as input and probabilities as target.
        return torch.nn.functional.kl_div(
            patched_log_probs, reference_probs, reduction="batchmean"
        ).item()

    @property
    def kl_from_clean(self) -> float:
        """KL divergence of patched from clean distribution."""
        return self._kl_reference_vs_patched(self.clean_logits)

    @property
    def kl_from_corrupted(self) -> float:
        """KL divergence of patched from corrupted distribution."""
        return self._kl_reference_vs_patched(self.corrupted_logits)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class ActivationPatcher:
    """
    Perform activation patching interventions for causal analysis.

    Activation patching is a technique to test causal importance of
    specific model components by replacing activations from one run
    with activations from another run.

    All hooks are registered through the backend's hook manager and are
    removed with ``remove_all_hooks()`` after every run, including when the
    run raises.

    Example:
        >>> patcher = ActivationPatcher(backend)
        >>> clean_out, clean_cache = backend.generate_with_cache(clean_prompt)
        >>> results = patcher.patch_run(
        ...     corrupted_prompt,
        ...     source_cache=clean_cache,
        ...     target_layers=[5, 10, 15]
        ... )
    """

    def __init__(self, backend: "TransformersBackend"):
        """
        Initialize patcher with a backend.

        Args:
            backend: TransformersBackend with hook support

        Raises:
            ValueError: If ``backend.supports_activations`` is falsy.
        """
        if not backend.supports_activations:
            raise ValueError("Backend does not support activation access")

        self.backend = backend

    @property
    def hook_manager(self) -> "HookManager":
        """Hook manager owned by the backend (looked up on every access)."""
        return self.backend.hook_manager

    def patch_run(
        self,
        prompt: str,
        source_cache: "ActivationCache",
        target_layers: List[int],
        token_positions: Optional[List[int]] = None,
        **gen_kwargs,
    ) -> "PatchingResult":
        """
        Run generation while patching activations from source_cache.

        This is the core patching operation: run the model on `prompt`
        but replace activations at `target_layers` with values from
        `source_cache` (typically from a "clean" run).

        Args:
            prompt: Input prompt (typically the "corrupted" version)
            source_cache: Activations from another run to patch in
            target_layers: Which layers to patch
            token_positions: Which token positions to patch (None = all)
            **gen_kwargs: Additional generation arguments

        Returns:
            PatchingResult with output and metadata

        Raises:
            ValueError: If any requested layer is absent from ``source_cache``.
        """
        # Resolve every activation before touching the model, so a missing
        # layer cannot leave hooks for earlier layers installed.
        resolved = []
        for layer_idx in target_layers:
            source_activation = source_cache.get(layer_idx)
            if source_activation is None:
                raise ValueError(f"Layer {layer_idx} not in source cache")
            resolved.append((layer_idx, source_activation))

        try:
            # Registration happens inside the try so that a failure part-way
            # through still reaches the cleanup below.
            for layer_idx, source_activation in resolved:
                self.hook_manager.register_patch_hook(
                    layer_idx, source_activation, token_positions
                )
            output = self.backend.generate(prompt, **gen_kwargs)
        finally:
            self.hook_manager.remove_all_hooks()

        return PatchingResult(
            output_text=output.text,
            patched_layers=target_layers,
            patched_positions=token_positions,
        )

    def sweep_layers(
        self,
        clean_prompt: str,
        corrupted_prompt: str,
        layers: Optional[List[int]] = None,
        **gen_kwargs,
    ) -> Dict[int, "ForwardPatchResult"]:
        """
        Sweep patching across layers to find causal importance.

        For each layer, patches that layer's activations from the clean
        run into the corrupted run and measures the effect.

        Args:
            clean_prompt: Prompt that elicits desired behavior
            corrupted_prompt: Prompt that elicits different behavior
            layers: Which layers to sweep (None = all)
            **gen_kwargs: Accepted for signature compatibility; this method
                only runs forward passes and does not use them.

        Returns:
            Dict mapping layer_idx -> ForwardPatchResult
        """
        # Get clean logits and activations
        clean_logits, clean_cache = self.backend.forward_with_cache(clean_prompt, layers=layers)

        # Get corrupted logits (no patching)
        corrupted_logits, _ = self.backend.forward_with_cache(corrupted_prompt, layers=layers)

        target_layers = layers if layers is not None else clean_cache.layers

        # The baselines do not change between layers: move them to CPU once.
        # Every result therefore references the same two baseline tensors.
        clean_cpu = clean_logits.cpu()
        corrupted_cpu = corrupted_logits.cpu()

        results: Dict[int, "ForwardPatchResult"] = {}
        for layer_idx in target_layers:
            # Run forward pass with patching on this layer
            patched_logits = self._forward_with_patch(corrupted_prompt, clean_cache, layer_idx)

            # Compute effect score (how much patching moves toward clean)
            effect = self._compute_logit_effect(clean_logits, corrupted_logits, patched_logits)

            results[layer_idx] = ForwardPatchResult(
                layer_idx=layer_idx,
                clean_logits=clean_cpu,
                corrupted_logits=corrupted_cpu,
                patched_logits=patched_logits.cpu(),
                effect_score=effect,
            )

        return results

    def sweep_heads(
        self,
        clean_prompt: str,
        corrupted_prompt: str,
        layers: Optional[List[int]] = None,
        head_indices: Optional[Iterable[int]] = None,
        target_heads: Optional[Dict[int, Iterable[int]]] = None,
    ) -> Dict[Tuple[int, int], "ForwardPatchResult"]:
        """
        Sweep patching across attention heads to find causal importance.

        Args:
            clean_prompt: Prompt that elicits desired behavior
            corrupted_prompt: Prompt that elicits different behavior
            layers: Which layers to sweep (None = all)
            head_indices: Heads to patch for all layers (e.g., [0,1,2])
            target_heads: Dict mapping layer -> list of head indices.
                When given, its keys replace ``layers``.

        Returns:
            Dict mapping (layer_idx, head_idx) -> ForwardPatchResult

        Raises:
            ValueError: If neither or both of ``head_indices`` and
                ``target_heads`` are given, or a head index is out of range.
        """
        if head_indices is None and not target_heads:
            raise ValueError("Provide head_indices or target_heads for head patching.")
        if head_indices is not None and target_heads:
            raise ValueError("Use either head_indices or target_heads, not both.")

        # Materialise once: head_indices may be a one-shot iterable, and it is
        # reused for every layer in the sweep.
        shared_heads = list(head_indices) if head_indices is not None else None

        layers_to_sweep = layers
        if target_heads:
            layers_to_sweep = list(target_heads.keys())

        # Get clean logits and attention outputs
        clean_logits, clean_cache = self.backend.forward_with_attention_cache(
            clean_prompt, layers=layers_to_sweep
        )

        # Get corrupted logits (no patching)
        corrupted_logits, _ = self.backend.forward_with_cache(corrupted_prompt, layers=[])

        num_heads, head_dim = self._get_head_info()

        # Baselines are identical for every (layer, head): copy to CPU once.
        clean_cpu = clean_logits.cpu()
        corrupted_cpu = corrupted_logits.cpu()

        results: Dict[Tuple[int, int], "ForwardPatchResult"] = {}
        layers_to_sweep = layers_to_sweep if layers_to_sweep is not None else clean_cache.layers

        for layer_idx in layers_to_sweep:
            heads = list(target_heads.get(layer_idx, [])) if target_heads else shared_heads
            for head_idx in heads:
                if head_idx < 0 or head_idx >= num_heads:
                    raise ValueError(f"Head index {head_idx} out of range (0..{num_heads - 1})")

                patched_logits = self._forward_with_head_patch(
                    corrupted_prompt, clean_cache, layer_idx, [head_idx], head_dim
                )
                effect = self._compute_logit_effect(clean_logits, corrupted_logits, patched_logits)

                results[(layer_idx, head_idx)] = ForwardPatchResult(
                    layer_idx=layer_idx,
                    clean_logits=clean_cpu,
                    corrupted_logits=corrupted_cpu,
                    patched_logits=patched_logits.cpu(),
                    effect_score=effect,
                )

        return results

    def _forward_with_patch(
        self,
        prompt: str,
        source_cache: "ActivationCache",
        layer_idx: int,
    ) -> torch.Tensor:
        """Run a single forward pass with patching at specified layer's residual stream.

        Raises:
            ValueError: If ``layer_idx`` is absent from ``source_cache``.
        """
        source_activation = source_cache.get(layer_idx)
        if source_activation is None:
            raise ValueError(f"Layer {layer_idx} not in source cache")

        try:
            # Use residual patch hook for safer patching
            self.hook_manager.register_residual_patch_hook(layer_idx, source_activation, None)
            # layers=[] : only the logits are needed, nothing has to be cached.
            logits, _ = self.backend.forward_with_cache(prompt, layers=[])
        finally:
            self.hook_manager.remove_all_hooks()

        return logits

    def _forward_with_head_patch(
        self,
        prompt: str,
        source_cache: "ActivationCache",
        layer_idx: int,
        head_indices: List[int],
        head_dim: int,
    ) -> torch.Tensor:
        """Run forward pass with attention head patching at a specific layer.

        Raises:
            ValueError: If ``layer_idx`` is absent from ``source_cache``.
        """
        source_activation = source_cache.get(layer_idx)
        if source_activation is None:
            raise ValueError(f"Layer {layer_idx} not in source cache")

        try:
            self.hook_manager.register_multi_head_patch_hook(
                layer_idx, head_indices, source_activation, head_dim
            )
            logits, _ = self.backend.forward_with_cache(prompt, layers=[])
        finally:
            self.hook_manager.remove_all_hooks()

        return logits

    def _get_head_info(self) -> Tuple[int, int]:
        """Return (num_heads, head_dim) from the model config.

        Several attribute spellings are tried for each quantity. When the
        config has no ``head_dim`` it is derived as ``hidden_size // num_heads``.

        Raises:
            ValueError: If the backend has no model, or the head count or
                head dimension cannot be determined.
        """
        model = getattr(self.backend, "model", None)
        if model is None:
            raise ValueError("Backend has no model; cannot determine head information.")
        cfg = model.config
        # Handle multimodal models with nested text_config
        if hasattr(cfg, "text_config"):
            cfg = cfg.text_config
        num_heads = (
            getattr(cfg, "num_attention_heads", None)
            or getattr(cfg, "n_head", None)
            or getattr(cfg, "num_heads", None)
        )
        head_dim = getattr(cfg, "head_dim", None)
        hidden_size = (
            getattr(cfg, "hidden_size", None)
            or getattr(cfg, "hidden_dim", None)
            or getattr(cfg, "d_model", None)
        )

        if head_dim is None and num_heads and hidden_size:
            head_dim = hidden_size // num_heads

        if not num_heads or not head_dim:
            raise ValueError(
                "Could not determine head information from model config "
                f"(num_heads={num_heads}, head_dim={head_dim})."
            )

        return int(num_heads), int(head_dim)

    def _compute_logit_effect(
        self,
        clean_logits: torch.Tensor,
        corrupted_logits: torch.Tensor,
        patched_logits: torch.Tensor,
    ) -> float:
        """
        Compute how much patching moves logits from corrupted toward clean.

        The score is ``1 - ||patched - clean|| / ||clean - corrupted||`` on the
        last-position logits, clamped to ``[0, 1]``. A patch that leaves the
        corrupted logits untouched therefore scores 0 and a patch that
        reproduces the clean logits scores 1.

        Returns value from 0 (no effect) to 1 (full recovery). Returns 0.0
        when the clean and corrupted logits coincide (nothing to recover).
        """
        # Use last token logits
        clean_last = clean_logits[:, -1, :]
        corrupted_last = corrupted_logits[:, -1, :]
        patched_last = patched_logits[:, -1, :]

        # Size of the gap to close, and how much of it remains after patching.
        clean_to_corrupted = torch.norm(clean_last - corrupted_last).item()
        patched_to_clean = torch.norm(patched_last - clean_last).item()

        if clean_to_corrupted == 0:
            return 0.0

        # Effect = fraction of the clean/corrupted gap that patching closed.
        effect = 1.0 - (patched_to_clean / clean_to_corrupted)
        return max(0.0, min(1.0, effect))

    def sweep_positions(
        self,
        clean_prompt: str,
        corrupted_prompt: str,
        layer_idx: int,
        positions: Optional[List[int]] = None,
        **gen_kwargs,
    ) -> Dict[int, "PatchingResult"]:
        """
        Sweep patching across token positions at a fixed layer.

        Args:
            clean_prompt: Prompt that elicits desired behavior
            corrupted_prompt: Prompt that elicits different behavior
            layer_idx: Which layer to patch
            positions: Which positions to sweep (None = all)
            **gen_kwargs: Generation arguments

        Returns:
            Dict mapping position -> PatchingResult. Each result has
            ``original_answer`` set to the unpatched corrupted output and
            ``patched_answer`` set to its own ``output_text``.

        Raises:
            ValueError: If ``layer_idx`` is absent from the clean cache.
        """
        # Get clean activations
        _, clean_cache = self.backend.generate_with_cache(
            clean_prompt, layers=[layer_idx], **gen_kwargs
        )

        # Get corrupted baseline
        corrupted_output = self.backend.generate(corrupted_prompt, **gen_kwargs)

        # Determine positions to sweep
        if positions is not None:
            target_positions = positions
        else:
            clean_act = clean_cache.get(layer_idx)
            if clean_act is None:
                raise ValueError(f"Layer {layer_idx} not in source cache")
            # assumes dimension 1 of the cached activation is the token axis
            # -- TODO confirm against ActivationCache
            target_positions = list(range(clean_act.shape[1]))

        results = {}

        for pos in target_positions:
            result = self.patch_run(
                corrupted_prompt,
                clean_cache,
                target_layers=[layer_idx],
                token_positions=[pos],
                **gen_kwargs,
            )
            result.original_answer = corrupted_output.text
            result.patched_answer = result.output_text
            results[pos] = result

        return results

    def ablate_layer(
        self, prompt: str, layer_idx: int, ablation_type: str = "zero", **gen_kwargs
    ) -> "PatchingResult":
        """
        Ablate a layer by zeroing or replacing with mean.

        Args:
            prompt: Input prompt
            layer_idx: Which layer to ablate
            ablation_type: "zero" or "mean". Any value other than "zero"
                selects mean ablation.
            **gen_kwargs: Generation arguments

        Returns:
            PatchingResult from ablated run
        """

        def zero_hook(module, input, output):
            # Tuple outputs: only the first element is replaced, the remaining
            # elements are passed through untouched.
            if isinstance(output, tuple):
                hidden_states = output[0]
                rest = output[1:]
                return (torch.zeros_like(hidden_states),) + rest
            return torch.zeros_like(output)

        def mean_hook(module, input, output):
            # The mean is taken over dimensions 0 and 1 of the current output
            # and broadcast back to the original shape.
            if isinstance(output, tuple):
                hidden_states = output[0]
                rest = output[1:]
                mean_val = hidden_states.mean(dim=(0, 1), keepdim=True)
                return (mean_val.expand_as(hidden_states),) + rest
            mean_val = output.mean(dim=(0, 1), keepdim=True)
            return mean_val.expand_as(output)

        hook = zero_hook if ablation_type == "zero" else mean_hook

        try:
            self.hook_manager.register_forward_hook(layer_idx, hook)
            output = self.backend.generate(prompt, **gen_kwargs)
        finally:
            self.hook_manager.remove_all_hooks()

        return PatchingResult(
            output_text=output.text, patched_layers=[layer_idx], patched_positions=None
        )
|
cotlab/patching/sae.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
"""GemmaScope-2 JumpReLU Sparse Autoencoder loader.
|
|
2
|
+
|
|
3
|
+
Loads SAE weights directly from HuggingFace (no SAELens / TransformerLens
|
|
4
|
+
dependency) and exposes a minimal encode() forward pass compatible with the
|
|
5
|
+
CoTLab TransformersBackend hook infrastructure.
|
|
6
|
+
|
|
7
|
+
Architecture (JumpReLU SAE, Lieberum et al. 2024)
|
|
8
|
+
--------------------------------------------------
|
|
9
|
+
h = x - b_dec # centre around decoder bias
|
|
10
|
+
pre_act = h @ w_enc + b_enc # linear projection [..., d_sae]
|
|
11
|
+
features = pre_act * (pre_act > θ) # JumpReLU gate
|
|
12
|
+
|
|
13
|
+
Weight layout in params.safetensors (GemmaScope-2 convention)
|
|
14
|
+
--------------------------------------------------------------
|
|
15
|
+
w_enc float32 [d_model, d_sae]
|
|
16
|
+
b_enc float32 [d_sae]
|
|
17
|
+
w_dec float32 [d_sae, d_model]
|
|
18
|
+
b_dec float32 [d_model]
|
|
19
|
+
threshold float32 [d_sae]
|
|
20
|
+
|
|
21
|
+
HF path pattern
|
|
22
|
+
---------------
|
|
23
|
+
{site}/layer_{N}_width_{width}_l0_{l0_label}/params.safetensors
|
|
24
|
+
e.g. resid_post_all/layer_9_width_16k_l0_small/params.safetensors
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
import re
|
|
28
|
+
from typing import Optional
|
|
29
|
+
|
|
30
|
+
import torch
|
|
31
|
+
import torch.nn as nn
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class GemmaScopeLayer(nn.Module):
    """JumpReLU SAE for a single residual-stream layer.

    The five weight tensors are registered as buffers; ``layer``, ``repo_id``
    and ``source_path`` are kept as plain attributes recording where the
    weights came from.
    """

    def __init__(
        self,
        w_enc: torch.Tensor,
        b_enc: torch.Tensor,
        w_dec: torch.Tensor,
        b_dec: torch.Tensor,
        threshold: torch.Tensor,
        layer: int,
        repo_id: str,
        source_path: str,
    ):
        super().__init__()
        self.register_buffer("w_enc", w_enc)  # [d_model, d_sae]
        self.register_buffer("b_enc", b_enc)  # [d_sae]
        self.register_buffer("w_dec", w_dec)  # [d_sae, d_model]
        self.register_buffer("b_dec", b_dec)  # [d_model]
        self.register_buffer("threshold", threshold)  # [d_sae]
        self.layer = layer
        self.repo_id = repo_id
        self.source_path = source_path

    @property
    def d_model(self) -> int:
        """Input width, read from the first dimension of ``w_enc``."""
        return self.w_enc.shape[0]

    @property
    def d_sae(self) -> int:
        """Feature dictionary width, read from the second dimension of ``w_enc``."""
        return self.w_enc.shape[1]

    # ------------------------------------------------------------------
    # Factory
    # ------------------------------------------------------------------

    @staticmethod
    def _select_params_file(all_files, site, layer, width, l0_label):
        """Pick the ``params.safetensors`` path matching site/layer/width.

        The site must appear as a whole path component, so ``"resid_post"``
        does not match files stored under ``"resid_post_all/"``. Among the
        matches, a path containing ``l0_{l0_label}`` is preferred; otherwise
        the first match is used.

        Returns:
            The chosen repo path, or None when nothing matches.
        """
        layer_tag = f"layer_{layer}_"
        width_tag = f"width_{width}_"
        candidates = [
            path
            for path in all_files
            if site in path.split("/")
            and layer_tag in path
            and width_tag in path
            and path.endswith("params.safetensors")
        ]
        if not candidates:
            return None
        preferred = [path for path in candidates if f"l0_{l0_label}" in path]
        return preferred[0] if preferred else candidates[0]

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str,
        layer: int,
        site: str = "resid_post_all",
        width: str = "16k",
        l0_label: str = "small",
        token: Optional[str] = None,
    ) -> "GemmaScopeLayer":
        """Download and load a GemmaScope-2 SAE from HuggingFace.

        Args:
            repo_id:  HF repo, e.g. ``"google/gemma-scope-2-270m-it"``
            layer:    Residual stream layer index (0-based).
            site:     SAE training site. ``"resid_post_all"`` covers every
                      layer; other options are ``"resid_post"`` (4 depths only),
                      ``"attn_out_all"``, ``"mlp_out_all"``.
            width:    Feature dictionary width: ``"16k"`` or ``"262k"``.
            l0_label: Sparsity label used in the directory name: ``"small"``,
                      ``"medium"``, or ``"big"``. If not found, falls back to
                      any available file for that site/layer/width combination.
            token:    HuggingFace API token (optional; reads HF_TOKEN env var
                      automatically via huggingface_hub).

        Returns:
            A ``GemmaScopeLayer`` whose weights are loaded as float32 and whose
            ``source_path`` records the repo path that was actually used.

        Raises:
            FileNotFoundError: If neither the direct path nor the repo scan
                yields a file for the requested site/layer/width.
        """
        from huggingface_hub import hf_hub_download, list_repo_files  # noqa: PLC0415
        from safetensors import safe_open  # noqa: PLC0415

        # Try the canonical direct path first (avoids full repo listing).
        direct = f"{site}/layer_{layer}_width_{width}_l0_{l0_label}/params.safetensors"
        print(f"  [SAE] Fetching layer={layer} ({direct}) …")

        try:
            local_path = hf_hub_download(repo_id=repo_id, filename=direct, token=token)
            chosen = direct
        except Exception:
            # Deliberately broad: any failure of the direct download falls
            # back to scanning the repo listing for a usable file.
            print("  [SAE] Direct path not found — scanning repo …")
            all_files = list(list_repo_files(repo_id, token=token))
            chosen = cls._select_params_file(all_files, site, layer, width, l0_label)
            if chosen is None:
                # Help the caller by listing every layer index present in the
                # repo, regardless of site or width.
                layer_matches = (
                    re.search(r"layer_(\d+)_", f) for f in all_files if "params.safetensors" in f
                )
                available = sorted({m.group(1) for m in layer_matches if m}, key=int)
                raise FileNotFoundError(
                    f"No SAE found for site={site!r}, layer={layer}, width={width!r} "
                    f"in {repo_id}.\nAvailable layers (any site/width): {available}"
                )
            local_path = hf_hub_download(repo_id=repo_id, filename=chosen, token=token)

        with safe_open(local_path, framework="pt") as f:
            w_enc = f.get_tensor("w_enc").float()  # [d_model, d_sae]
            b_enc = f.get_tensor("b_enc").float()  # [d_sae]
            w_dec = f.get_tensor("w_dec").float()  # [d_sae, d_model]
            b_dec = f.get_tensor("b_dec").float()  # [d_model]
            threshold = f.get_tensor("threshold").float()  # [d_sae]

        print(f"  [SAE] Loaded layer={layer}  d_model={w_enc.shape[0]}  d_sae={w_enc.shape[1]}")
        return cls(
            w_enc=w_enc,
            b_enc=b_enc,
            w_dec=w_dec,
            b_dec=b_dec,
            threshold=threshold,
            layer=layer,
            repo_id=repo_id,
            source_path=chosen,
        )

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    @torch.no_grad()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """JumpReLU encode.

        Computes in float32 and casts the result back to the dtype of ``x``.

        Args:
            x: Residual stream activations, shape ``[..., d_model]``.
               Must be on the same device as the SAE buffers.

        Returns:
            Sparse feature activations, shape ``[..., d_sae]``.
        """
        orig_dtype = x.dtype
        x = x.float()
        h = x - self.b_dec  # centre around the decoder bias
        pre = h @ self.w_enc + self.b_enc
        # JumpReLU gate: keep a pre-activation only where it exceeds its
        # per-feature threshold, zero it otherwise.
        features = pre * (pre > self.threshold).float()
        return features.to(orig_dtype)

    def __repr__(self) -> str:
        return (
            f"GemmaScopeLayer(layer={self.layer}, "
            f"d_model={self.d_model}, d_sae={self.d_sae}, "
            f"repo={self.repo_id})"
        )
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Prompt strategies module."""
|
|
2
|
+
|
|
3
|
+
from .cardiology import CardiologyPromptStrategy
|
|
4
|
+
from .histopathology import HistopathologyPromptStrategy
|
|
5
|
+
from .length_matched_strategies import (
|
|
6
|
+
ChainOfThoughtMatchedStrategy,
|
|
7
|
+
ContrarianMatchedStrategy,
|
|
8
|
+
DirectAnswerMatchedStrategy,
|
|
9
|
+
)
|
|
10
|
+
from .mcq import MCQPromptStrategy
|
|
11
|
+
from .neurology import NeurologyPromptStrategy
|
|
12
|
+
from .oncology import OncologyPromptStrategy
|
|
13
|
+
from .plab import PLABPromptStrategy
|
|
14
|
+
from .pubhealthbench import PubHealthBenchMCQPromptStrategy
|
|
15
|
+
from .radiology import RadiologyPromptStrategy
|
|
16
|
+
from .strategies import (
|
|
17
|
+
ArroganceStrategy,
|
|
18
|
+
ChainOfThoughtStrategy,
|
|
19
|
+
DirectAnswerStrategy,
|
|
20
|
+
NoInstructionStrategy,
|
|
21
|
+
SimplePromptStrategy,
|
|
22
|
+
create_prompt_strategy,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
__all__ = [
|
|
26
|
+
"SimplePromptStrategy",
|
|
27
|
+
"ChainOfThoughtStrategy",
|
|
28
|
+
"DirectAnswerStrategy",
|
|
29
|
+
"ArroganceStrategy",
|
|
30
|
+
"NoInstructionStrategy",
|
|
31
|
+
"CardiologyPromptStrategy",
|
|
32
|
+
"HistopathologyPromptStrategy",
|
|
33
|
+
"MCQPromptStrategy",
|
|
34
|
+
"NeurologyPromptStrategy",
|
|
35
|
+
"OncologyPromptStrategy",
|
|
36
|
+
"PLABPromptStrategy",
|
|
37
|
+
"PubHealthBenchMCQPromptStrategy",
|
|
38
|
+
"RadiologyPromptStrategy",
|
|
39
|
+
"create_prompt_strategy",
|
|
40
|
+
"ContrarianMatchedStrategy",
|
|
41
|
+
"ChainOfThoughtMatchedStrategy",
|
|
42
|
+
"DirectAnswerMatchedStrategy",
|
|
43
|
+
]
|