cotlab 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cotlab/__init__.py +3 -0
- cotlab/analyse_experiments.py +392 -0
- cotlab/analysis/__init__.py +11 -0
- cotlab/analysis/cot_parser.py +243 -0
- cotlab/analysis/faithfulness_metrics.py +192 -0
- cotlab/backends/__init__.py +16 -0
- cotlab/backends/base.py +78 -0
- cotlab/backends/transformers_backend.py +335 -0
- cotlab/backends/vllm_backend.py +227 -0
- cotlab/cli.py +83 -0
- cotlab/core/__init__.py +34 -0
- cotlab/core/base.py +749 -0
- cotlab/core/config.py +90 -0
- cotlab/core/registry.py +68 -0
- cotlab/datasets/__init__.py +45 -0
- cotlab/datasets/loaders.py +1889 -0
- cotlab/experiment/__init__.py +315 -0
- cotlab/experiments/__init__.py +43 -0
- cotlab/experiments/activation_compare.py +290 -0
- cotlab/experiments/activation_patching.py +1050 -0
- cotlab/experiments/attention_analysis.py +885 -0
- cotlab/experiments/classification.py +235 -0
- cotlab/experiments/composite_shift_detector.py +524 -0
- cotlab/experiments/cot_ablation.py +277 -0
- cotlab/experiments/cot_faithfulness.py +187 -0
- cotlab/experiments/cot_heads.py +208 -0
- cotlab/experiments/full_layer_cot.py +232 -0
- cotlab/experiments/full_layer_patching.py +225 -0
- cotlab/experiments/h_neuron_analysis.py +712 -0
- cotlab/experiments/logit_lens.py +439 -0
- cotlab/experiments/multi_head_cot.py +220 -0
- cotlab/experiments/multi_head_patching.py +229 -0
- cotlab/experiments/probing_classifier.py +402 -0
- cotlab/experiments/residual_norm_ood.py +413 -0
- cotlab/experiments/sae_feature_analysis.py +673 -0
- cotlab/experiments/steering_vectors.py +223 -0
- cotlab/experiments/sycophancy_heads.py +224 -0
- cotlab/logging/__init__.py +5 -0
- cotlab/logging/json_logger.py +161 -0
- cotlab/main.py +317 -0
- cotlab/patching/__init__.py +24 -0
- cotlab/patching/cache.py +141 -0
- cotlab/patching/hooks.py +558 -0
- cotlab/patching/interventions.py +86 -0
- cotlab/patching/patcher.py +439 -0
- cotlab/patching/sae.py +181 -0
- cotlab/prompts/__init__.py +43 -0
- cotlab/prompts/cardiology.py +378 -0
- cotlab/prompts/histopathology.py +265 -0
- cotlab/prompts/length_matched_strategies.py +157 -0
- cotlab/prompts/mcq.py +193 -0
- cotlab/prompts/neurology.py +353 -0
- cotlab/prompts/oncology.py +367 -0
- cotlab/prompts/plab.py +162 -0
- cotlab/prompts/pubhealthbench.py +82 -0
- cotlab/prompts/pubmedqa.py +173 -0
- cotlab/prompts/radiology.py +414 -0
- cotlab/prompts/strategies.py +939 -0
- cotlab/prompts/tcga.py +168 -0
- cotlab/runner.py +204 -0
- cotlab-0.8.0.dist-info/METADATA +166 -0
- cotlab-0.8.0.dist-info/RECORD +65 -0
- cotlab-0.8.0.dist-info/WHEEL +4 -0
- cotlab-0.8.0.dist-info/entry_points.txt +3 -0
- cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1050 @@
|
|
|
1
|
+
"""Activation Patching experiment — causal intervention via residual-stream replacement.
|
|
2
|
+
|
|
3
|
+
Patching modes
|
|
4
|
+
--------------
|
|
5
|
+
``pairs`` (default — requires PatchingPairsDataset)
|
|
6
|
+
clean = sample.text
|
|
7
|
+
corrupt = sample.metadata["corrupted_prompt"]
|
|
8
|
+
Answers Q: which layers encode the specific diagnosis/fact?
|
|
9
|
+
|
|
10
|
+
``few_shot_contrast`` (works with ANY dataset)
|
|
11
|
+
clean = few-shot prompt of the sample (prompt_strategy with few_shot=True)
|
|
12
|
+
corrupt = zero-shot prompt of the sample (prompt_strategy with few_shot=False)
|
|
13
|
+
Answers Q: which layers causally drive few-shot's benefit on OOD / non-OOD?
|
|
14
|
+
|
|
15
|
+
``introspect_contrast`` (works with ANY dataset)
|
|
16
|
+
clean = prompt + introspect instruction
|
|
17
|
+
corrupt = prompt only
|
|
18
|
+
Answers Q: which layers carry the "think deeply" reasoning signal?
|
|
19
|
+
|
|
20
|
+
``cot_contrast`` (works with ANY dataset)
|
|
21
|
+
clean = full CoT prompt (cot_trigger active, e.g. "Let's think through this step by step:")
|
|
22
|
+
corrupt = zero-shot prompt (cot_trigger stripped — same structure, no reasoning nudge)
|
|
23
|
+
Answers Q: which layers carry the chain-of-thought reasoning signal vs plain answering?
|
|
24
|
+
Use as the default/baseline contrast alongside few_shot_contrast and introspect_contrast.
|
|
25
|
+
|
|
26
|
+
``token_group_contrast`` (works with ANY dataset)
|
|
27
|
+
Hooks the attention weight matrix at a single target layer and zeros out
|
|
28
|
+
one token group at a time (delimiter / choice / content). Measures how
|
|
29
|
+
much each group's removal shifts the answer logit.
|
|
30
|
+
Answers Q: which token positions does <target_mask_layer> attend to causally?
|
|
31
|
+
(Q1: run at layer 3 — the universal attention bottleneck)
|
|
32
|
+
|
|
33
|
+
Algorithm (logit-recovery metric, one sample, residual patching modes):
|
|
34
|
+
1. Forward clean → cache per-layer residuals (CPU).
|
|
35
|
+
2. Forward corrupt → baseline logit at last token.
|
|
36
|
+
3. For each layer L (strided):
|
|
37
|
+
Re-run corrupt with hook replacing layer L's output with cached clean.
|
|
38
|
+
effect(L) = (logit_patched[clean_tok] - logit_corrupt[clean_tok])
|
|
39
|
+
/ (logit_clean[clean_tok] - logit_corrupt[clean_tok] + ε)
|
|
40
|
+
1 = full recovery, 0 = no effect, negative = made things worse.
|
|
41
|
+
|
|
42
|
+
Algorithm (token_group_contrast):
|
|
43
|
+
1. Single forward pass with no masking → logit_base.
|
|
44
|
+
2. For each group G in {delimiter, choice, content}:
|
|
45
|
+
Forward with attention weights at target_mask_layer zeroed for group G.
|
|
46
|
+
importance(G) = |logit_base - logit_masked|
|
|
47
|
+
3. dominant_group = argmax importance.
|
|
48
|
+
|
|
49
|
+
Memory safety: activations moved to CPU immediately inside each hook.
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
import math
|
|
53
|
+
import re
|
|
54
|
+
from typing import Any, Dict, List, Optional, Set
|
|
55
|
+
|
|
56
|
+
import torch
|
|
57
|
+
from tqdm import tqdm
|
|
58
|
+
|
|
59
|
+
from ..backends.base import InferenceBackend
|
|
60
|
+
from ..core.base import BaseExperiment, ExperimentResult
|
|
61
|
+
from ..core.registry import Registry
|
|
62
|
+
from ..datasets.loaders import BaseDataset
|
|
63
|
+
from ..logging import ExperimentLogger
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@Registry.register_experiment("activation_patching")
|
|
67
|
+
class ActivationPatchingExperiment(BaseExperiment):
|
|
68
|
+
"""
|
|
69
|
+
Layer-wise causal activation patching with logit-recovery scoring.
|
|
70
|
+
|
|
71
|
+
Supports five patching modes (see ``VALID_MODES``); the most commonly used are:
|
|
72
|
+
- ``pairs`` PatchingPairsDataset clean/corrupt pairs.
|
|
73
|
+
- ``few_shot_contrast`` Any dataset — few-shot (clean) vs zero-shot (corrupt).
|
|
74
|
+
- ``cot_contrast`` Any dataset — CoT prompt (clean) vs zero-shot (corrupt).
|
|
75
|
+
"""
|
|
76
|
+
|
|
77
|
+
VALID_MODES = (
|
|
78
|
+
"pairs",
|
|
79
|
+
"few_shot_contrast",
|
|
80
|
+
"introspect_contrast",
|
|
81
|
+
"token_group_contrast",
|
|
82
|
+
"cot_contrast",
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
# Tokens that are purely structural / formatting — not medical content.
|
|
86
|
+
_DELIMITER_STRINGS: Set[str] = {
|
|
87
|
+
"\n",
|
|
88
|
+
":",
|
|
89
|
+
"#",
|
|
90
|
+
"##",
|
|
91
|
+
"###",
|
|
92
|
+
"Options",
|
|
93
|
+
"Options:",
|
|
94
|
+
"Answer",
|
|
95
|
+
"Answer:",
|
|
96
|
+
"A.",
|
|
97
|
+
"B.",
|
|
98
|
+
"C.",
|
|
99
|
+
"D.",
|
|
100
|
+
"E.",
|
|
101
|
+
"F.",
|
|
102
|
+
"G.",
|
|
103
|
+
"(A)",
|
|
104
|
+
"(B)",
|
|
105
|
+
"(C)",
|
|
106
|
+
"(D)",
|
|
107
|
+
"(E)",
|
|
108
|
+
"(F)",
|
|
109
|
+
"(G)",
|
|
110
|
+
"A)",
|
|
111
|
+
"B)",
|
|
112
|
+
"C)",
|
|
113
|
+
"D)",
|
|
114
|
+
"E)",
|
|
115
|
+
"F)",
|
|
116
|
+
"G)",
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
def __init__(
    self,
    name: str = "activation_patching",
    description: str = "Layer-wise causal activation patching (logit recovery)",
    patching_mode: str = "pairs",  # must be one of VALID_MODES
    layer_stride: int = 2,
    num_samples: int = 50,
    max_input_tokens: int = 1024,
    seed: int = 42,
    answer_cue: str = "\n\nAnswer:",
    introspect_instruction: str = (
        "Think deeply about this problem. "
        "Carefully reason through the underlying mechanisms and consider "
        "all relevant factors before committing to your answer."
    ),
    # Legacy fields kept so old YAML configs don't break
    variants: Optional[List[Dict[str, Any]]] = None,
    patching: Optional[Dict[str, Any]] = None,
    # Token-group contrast params
    token_group_contrast_layer: int = 3,
    token_group_mode: str = "all",  # "all" | "delimiter" | "choice" | "content"
    **kwargs,
):
    """Store the experiment configuration after validating the patching mode.

    Args:
        name: Experiment name, exposed through the ``name`` property.
        description: Human-readable description.
        patching_mode: Contrast to run; must be an entry of ``VALID_MODES``.
        layer_stride: Step used when selecting layers to patch.
        num_samples: Number of dataset samples to draw.
        max_input_tokens: Truncation length used when tokenizing prompts.
        seed: Sampling seed.
        answer_cue: Answer-cue string, stored as given.
        introspect_instruction: Instruction text, stored as given.
        variants: Accepted for backward compatibility; not stored here.
        patching: Optional patching options; ``None``/empty becomes ``{}``.
        token_group_contrast_layer: Layer index for token-group masking
            (coerced with ``int``).
        token_group_mode: Token-group selector, stored as given.
        **kwargs: Extra keyword arguments are accepted and ignored here.

    Raises:
        ValueError: If ``patching_mode`` is not in ``VALID_MODES``.
    """
    mode_is_known = patching_mode in self.VALID_MODES
    if not mode_is_known:
        raise ValueError(
            f"patching_mode must be one of {self.VALID_MODES}, got {patching_mode!r}"
        )

    # Identity.
    self._name, self.description = name, description

    # Run configuration.
    self.patching_mode = patching_mode
    self.layer_stride = layer_stride
    self.num_samples = num_samples
    self.max_input_tokens = max_input_tokens
    self.seed = seed

    # Prompt fragments.
    self.answer_cue = answer_cue
    self.introspect_instruction = introspect_instruction

    # Optional patching options: always keep a dict so `.get` works later.
    self.patching = patching if patching else {}

    # Token-group contrast settings.
    self.token_group_contrast_layer = int(token_group_contrast_layer)
    self.token_group_mode = token_group_mode
|
|
158
|
+
|
|
159
|
+
@property
def name(self) -> str:
    """Return the experiment name stored by ``__init__``."""
    experiment_name = self._name
    return experiment_name
|
|
162
|
+
|
|
163
|
+
# ------------------------------------------------------------------
|
|
164
|
+
# Helpers
|
|
165
|
+
# ------------------------------------------------------------------
|
|
166
|
+
|
|
167
|
+
def _resolve_layers(self, backend: "InferenceBackend") -> List[int]:
    """Return every ``layer_stride``-th layer index, starting at layer 0.

    Args:
        backend: Backend whose ``hook_manager.num_layers`` gives the layer count.

    Returns:
        Ascending layer indices ``0, stride, 2 * stride, ...`` below
        ``num_layers``.

    Raises:
        ValueError: If ``self.layer_stride`` is smaller than 1. A zero stride
            used to fail with an opaque "slice step cannot be zero" and a
            negative stride silently reversed the layer order.
    """
    stride = self.layer_stride
    if stride < 1:
        raise ValueError(f"layer_stride must be >= 1, got {stride!r}")
    return list(range(0, backend.hook_manager.num_layers, stride))
|
|
170
|
+
|
|
171
|
+
def _resolve_head_targets(self, layers: List[int]) -> Dict[int, List[int]]:
    """Resolve the optional head-target mapping from the ``patching`` config.

    Supported configs (mutually exclusive):
    - ``patching.head_indices``: list of heads to apply to all ``layers``
    - ``patching.target_heads``: mapping layer -> list of heads

    Args:
        layers: Layer indices that are being patched.

    Returns:
        Mapping ``layer index -> list of head indices``. With
        ``target_heads`` only layers present in ``layers`` are kept; with
        ``head_indices`` every layer gets its own independent copy of the
        list. Empty dict when neither option is configured.

    Raises:
        ValueError: If both options are configured at once.
    """
    head_indices = self.patching.get("head_indices")
    target_heads = self.patching.get("target_heads")

    if head_indices is not None and target_heads is not None:
        raise ValueError("Use either target_heads or head_indices, not both.")

    if target_heads is not None:
        wanted = set(layers)  # O(1) membership instead of scanning the list
        resolved: Dict[int, List[int]] = {}
        for layer_key, heads in dict(target_heads).items():
            # Keys may arrive as strings (e.g. from YAML); normalise to int.
            layer_idx = int(layer_key)
            if layer_idx in wanted:
                resolved[layer_idx] = [int(h) for h in heads]
        return resolved

    if head_indices is not None:
        head_list = [int(h) for h in head_indices]
        # Copy per layer: sharing one list object would let a mutation of
        # one layer's heads leak into every other layer.
        return {layer_idx: list(head_list) for layer_idx in layers}

    return {}
|
|
197
|
+
|
|
198
|
+
def _answer_token_id(self, tokenizer, label) -> Optional[int]:
    """Return the first token id of ``label`` — the logit that gets tracked.

    The label is stringified and stripped. It is encoded first with a
    leading space and, only if that yields no ids, without one.

    Returns:
        The first id of the first non-empty encoding, or ``None`` when the
        label is ``None``, blank, or cannot be encoded.
    """
    text = "" if label is None else str(label).strip()
    if not text:
        return None

    # Space-prefixed form first; bare form is the fallback.
    for candidate in (" " + text, text):
        token_ids = tokenizer.encode(candidate, add_special_tokens=False)
        if token_ids:
            return token_ids[0]
    return None
|
|
210
|
+
|
|
211
|
+
def _answer_letter_token_ids(self, tokenizer) -> List[int]:
    """Collect plausible token ids for the MCQ answer letters A-J.

    Each letter is encoded with a leading space, bare, and with a leading
    newline. The last id of every non-empty encoding is kept.

    Returns:
        Sorted list of the distinct ids collected.
    """
    encodings = (
        tokenizer.encode(prefix + letter, add_special_tokens=False)
        for letter in "ABCDEFGHIJ"
        for prefix in (" ", "", "\n")
    )
    # The final id is the one carrying the letter when a prefix is split off.
    letter_ids = {encoding[-1] for encoding in encodings if encoding}
    return sorted(letter_ids)
|
|
220
|
+
|
|
221
|
+
def _tokenize(self, tokenizer, text: str, device):
    """Tokenize ``text`` into "pt" tensors and move them to ``device``.

    The input is truncated to ``self.max_input_tokens``.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=self.max_input_tokens,
    )
    return encoded.to(device)
|
|
228
|
+
|
|
229
|
+
def _forward_with_cache(
    self,
    backend: "InferenceBackend",
    tokens,
    target_layers: List[int],
) -> tuple:
    """Run one forward pass, caching each target layer's last-token output.

    Forward hooks are attached to the residual module of every layer in
    ``target_layers`` whose index is below ``hook_manager.num_layers``.
    They are always removed again, even if registration or the forward
    pass fails part-way.

    Args:
        backend: Backend exposing ``_model`` and ``hook_manager``.
        tokens: Keyword arguments forwarded to the model.
        target_layers: Layer indices to cache.

    Returns:
        logits_last – float32 CPU tensor, ``out.logits[0, -1]``
        act_cache   – dict[layer_idx → CPU tensor ``output[0, -1]``], kept
                      in the dtype the layer produced (not upcast), so it
                      stays compatible with the model when re-injected
    """
    act_cache: Dict[int, torch.Tensor] = {}
    hook_manager = backend.hook_manager

    def make_cache_hook(layer_idx: int):
        def hook(module, inp, output):
            # Some modules return a tuple; the first element is the tensor
            # cached here.
            tensor = output[0] if isinstance(output, tuple) else output
            with torch.no_grad():
                # First batch item, last position; moved to CPU right away
                # so cached activations do not accumulate on the accelerator.
                act_cache[layer_idx] = tensor[0, -1].detach().cpu()
            return output

        return hook

    handles = []
    try:
        # Register inside the try so a failure on a later layer cannot
        # leave hooks from earlier layers attached to the model.
        for layer_idx in target_layers:
            if layer_idx >= hook_manager.num_layers:
                continue
            module = hook_manager.get_residual_module(layer_idx)
            handles.append(module.register_forward_hook(make_cache_hook(layer_idx)))

        with torch.no_grad():
            out = backend._model(**tokens)
    finally:
        for handle in handles:
            handle.remove()

    logits_last = out.logits[0, -1].detach().float().cpu()  # float32 for stable arithmetic
    return logits_last, act_cache
|
|
269
|
+
|
|
270
|
+
def _forward_patched(
    self,
    backend: "InferenceBackend",
    tokens,
    patch_layer: int,
    patch_vec: torch.Tensor,  # CPU [hidden]
) -> torch.Tensor:
    """Forward pass with layer ``patch_layer``'s output replaced by ``patch_vec``.

    The vector is cast to the dtype of the model's first parameter and moved
    to ``backend.device``. It is then broadcast (``expand_as``) over the whole
    output tensor of the hooked module, so every position receives the same
    vector. For tuple outputs only element 0 is replaced.

    Returns:
        float32 CPU tensor ``out.logits[0, -1]``.
    """
    model = backend._model
    # Match the model's parameter dtype before injection to avoid a dtype
    # mismatch further down the network.
    param_dtype = next(model.parameters()).dtype
    injected = patch_vec.to(dtype=param_dtype, device=backend.device)

    def _broadcast_like(reference):
        # [hidden] -> [1, 1, hidden] -> shape of the tensor being replaced.
        return injected.unsqueeze(0).unsqueeze(0).expand_as(reference)

    def patch_hook(module, inp, output):
        if not isinstance(output, tuple):
            return _broadcast_like(output)
        return (_broadcast_like(output[0]),) + tuple(output[1:])

    hooked_module = backend.hook_manager.get_residual_module(patch_layer)
    handle = hooked_module.register_forward_hook(patch_hook)
    try:
        with torch.no_grad():
            result = model(**tokens)
    finally:
        handle.remove()
        del injected

    return result.logits[0, -1].detach().float().cpu()
|
|
302
|
+
|
|
303
|
+
# ------------------------------------------------------------------
|
|
304
|
+
# Token-group tagger and attention masking helpers
|
|
305
|
+
# ------------------------------------------------------------------
|
|
306
|
+
|
|
307
|
+
def _tag_tokens(
    self,
    input_ids: torch.Tensor,  # shape (seq_len,)
    tokenizer,
    metadata: dict,
) -> Dict[str, List[int]]:
    """Classify every token position into a token group.

    Groups
    ------
    delimiter : structural tokens (\\n, A., Options:, …)
    choice    : answer-option text (the words after A. / B. / …)
    content   : everything else (question stem + clinical entities)

    When ``metadata`` carries a non-empty ``metamap_phrases`` list, the
    ``content`` positions are further split into ``entity`` (tokens that
    overlap a phrase match) and ``stem`` (the remaining content tokens).

    Token character offsets are approximated by summing the lengths of the
    individually decoded tokens, so they may drift from the jointly decoded
    text for tokenizers whose per-token and joint decodings differ.

    Args:
        input_ids: 1-D tensor of token ids.
        tokenizer: Object providing ``decode(list_of_ids) -> str``.
        metadata: Sample metadata (may be empty or ``None``).

    Returns:
        dict mapping group name -> ascending list of 0-based token
        positions. The keys ``delimiter``, ``choice``, ``content``, ``stem``
        and ``entity`` are always present (possibly empty).
    """
    token_ids = input_ids.tolist()
    seq_len = len(token_ids)

    # Decode once and reuse: per-token text feeds the delimiter test and the
    # character-offset table, the joint text feeds both regex searches.
    pieces = [tokenizer.decode([tid]) for tid in token_ids]
    full_text = tokenizer.decode(token_ids)

    char_spans: List[tuple] = []  # (char_start, char_end) per token
    offset = 0
    for piece in pieces:
        char_spans.append((offset, offset + len(piece)))
        offset += len(piece)

    # ── Pass 1: mark delimiter tokens ─────────────────────────────
    # Match against stripped form OR raw form (catches "\n", which strips
    # to the empty string).
    delimiters = self._DELIMITER_STRINGS
    labels = [
        "delimiter" if (piece.strip() in delimiters or piece in delimiters) else "content"
        for piece in pieces
    ]

    # ── Pass 2: mark answer-choice span ───────────────────────────
    # The options block starts at the first newline followed by an answer
    # label pattern. Searching the joint text handles tokenizers that split
    # 'A)' into ['A', ')'] etc.
    # NOTE(review): with IGNORECASE and the `[A-G]\s` alternative this also
    # matches an ordinary line such as "\nA 45-year-old ..." — confirm that
    # prompts never start a stem line with a single letter A-G.
    answer_label_re = re.compile(
        r"\n(?:Options\s*:?|(?:[A-G][.)\s]|\([A-G]\)))", re.IGNORECASE
    )
    match = answer_label_re.search(full_text)
    if match:
        boundary_char = match.start()  # char index of the '\n'
        # First token that starts at or after the boundary.
        options_start = next(
            (i for i, (tok_start, _) in enumerate(char_spans) if tok_start >= boundary_char),
            None,
        )
        if options_start is not None:
            for i in range(options_start, seq_len):
                if labels[i] != "delimiter":
                    labels[i] = "choice"

    # ── Pass 3 (metamap phrases only): entity vs stem split ───────
    metamap = metadata.get("metamap_phrases") if metadata else None
    if metamap:
        entity_spans: List[tuple] = []  # (char_start, char_end)
        for phrase in metamap:
            phrase_str = str(phrase).strip()
            if not phrase_str:
                continue
            entity_spans.extend(
                m.span() for m in re.finditer(re.escape(phrase_str), full_text, re.IGNORECASE)
            )

        # Map character spans back to token positions (approximate).
        for i, (tok_start, tok_end) in enumerate(char_spans):
            if labels[i] != "content":
                continue
            if any(tok_start < ee and tok_end > es for es, ee in entity_spans):
                labels[i] = "entity"

        # Remaining "content" tokens become "stem".
        labels = ["stem" if label == "content" else label for label in labels]

    # ── Collect positions per group ────────────────────────────────
    groups: Dict[str, List[int]] = {}
    for i, lbl in enumerate(labels):
        groups.setdefault(lbl, []).append(i)

    # Always expose all five group keys (even if empty).
    for g in ("delimiter", "choice", "content", "stem", "entity"):
        groups.setdefault(g, [])

    return groups
|
|
413
|
+
|
|
414
|
+
def _forward_attention_masked(
    self,
    backend: "InferenceBackend",
    tokens,
    mask_layer: int,
    zero_positions: List[int],
    answer_tok_id: int,
) -> float:
    """Forward pass suppressing ``zero_positions`` at ``mask_layer``'s attention.

    Strategy: register a pre-forward hook (with kwargs) on the target
    layer's ``self_attn`` module. The hook adds a large negative value
    (-1e4) to the ``attention_mask`` keyword argument at the key columns to
    suppress, for every query row, so those keys receive ~zero weight after
    softmax. No ``output_attentions`` flag is required.

    Boolean masks (``True`` = may attend) are first converted to an additive
    mask in the model dtype. Adding the bias to a boolean tensor would be a
    logical OR, which re-enables the very columns that should be suppressed
    and discards the original masking once cast to float.

    If ``zero_positions`` is empty, or the layer has no ``self_attn``
    attribute, a plain unmasked forward pass is run instead.

    Args:
        backend: Backend exposing ``_model`` and ``hook_manager``.
        tokens: Model keyword arguments; must contain ``input_ids``.
        mask_layer: Index of the layer whose attention is masked.
        zero_positions: Key positions to suppress; entries ``>= seq_len``
            are ignored.
        answer_tok_id: Vocabulary index of the logit to return.

    Returns:
        The logit for ``answer_tok_id`` at the last position of batch item 0,
        as a Python float.
    """
    model = backend._model

    def _answer_logit() -> float:
        with torch.no_grad():
            out = model(**tokens)
        return float(out.logits[0, -1, answer_tok_id].detach().float().cpu().item())

    if not zero_positions:
        return _answer_logit()

    input_ids = tokens["input_ids"]
    seq_len = input_ids.shape[-1]
    device = input_ids.device

    # (1, 1, seq_len, seq_len) additive bias: -1e4 at every key column in
    # zero_positions, 0.0 elsewhere.
    bias = torch.zeros(1, 1, seq_len, seq_len, dtype=torch.float32, device=device)
    valid_pos = [p for p in zero_positions if p < seq_len]
    if valid_pos:
        bias[:, :, :, valid_pos] = -1e4

    # The attention kernel expects the mask dtype to match the query dtype
    # (e.g. bfloat16), so the final mask is always cast to the model dtype.
    model_dtype = model.dtype

    def _pre_hook(module, args, kwargs):
        existing = kwargs.get("attention_mask")
        if existing is None:
            # NOTE(review): the bias alone carries no causal structure; if
            # the attention module relies on an implicit causal flag when
            # no mask is passed, future positions may become visible here.
            kwargs["attention_mask"] = bias.to(dtype=model_dtype, device=device)
        elif existing.dtype == torch.bool:
            floor = torch.finfo(model_dtype).min
            additive = torch.zeros(existing.shape, dtype=model_dtype, device=existing.device)
            additive = additive.masked_fill(~existing, floor)
            combined = additive + bias.to(dtype=model_dtype, device=existing.device)
            # Keep values finite so low-precision dtypes do not overflow.
            kwargs["attention_mask"] = combined.clamp(min=floor)
        else:
            # Add the bias in the mask's own dtype, then cast for the kernel.
            kwargs["attention_mask"] = (
                existing + bias.to(dtype=existing.dtype, device=existing.device)
            ).to(dtype=model_dtype)
        return args, kwargs

    layer_mod = backend.hook_manager.get_layer_module(mask_layer)
    attn_mod = getattr(layer_mod, "self_attn", None)
    if attn_mod is None:
        tqdm.write(
            f" [warn] token_group_contrast: no self_attn on layer {mask_layer}, skipping mask"
        )
        return _answer_logit()

    handle = attn_mod.register_forward_pre_hook(_pre_hook, with_kwargs=True)
    try:
        return _answer_logit()
    finally:
        handle.remove()
|
|
479
|
+
|
|
480
|
+
# ------------------------------------------------------------------
|
|
481
|
+
# Statistical correlation helpers
|
|
482
|
+
# ------------------------------------------------------------------
|
|
483
|
+
|
|
484
|
+
@staticmethod
def _compute_correlations(per_sample_results: List[Dict]) -> Dict[str, Any]:
    """Correlate each token group's importance score with sample correctness.

    With a binary variable the point-biserial coefficient equals Pearson's
    r, so plain Pearson r is computed between the importance score and the
    0/1 ``is_correct`` label. The two-tailed p-value comes from the
    t-distribution with ``n - 2`` degrees of freedom. scipy's CDF is used
    when it can be imported, otherwise a normal approximation.

    Samples whose ``is_correct`` is ``None`` are ignored. Fewer than three
    usable samples overall yields ``{}``, and a group with fewer than three
    scored samples is skipped. ``entity`` and ``stem`` scores are read from
    the per-sample ``entity_importance`` / ``stem_importance`` fields; all
    other groups are read from ``group_importances``.

    Returns a dict keyed by group name, each with:
        r                         – correlation coefficient (4 decimals)
        p_value                   – two-tailed p-value (4 decimals)
        n                         – number of samples used
        mean_importance_correct   – mean importance over correct samples
        mean_importance_incorrect – mean importance over incorrect samples
    """
    scored = [rec for rec in per_sample_results if rec.get("is_correct") is not None]
    if len(scored) < 3:
        return {}

    outcomes = [int(rec["is_correct"]) for rec in scored]

    # Every group name that appears in any usable sample.
    group_names: set = set()
    for rec in scored:
        group_names.update(rec.get("group_importances", {}).keys())
    for split_name in ("entity", "stem"):
        if any(rec.get(f"{split_name}_importance") is not None for rec in scored):
            group_names.add(split_name)

    # Resolve the t-distribution CDF once; None selects the fallback.
    try:
        from scipy.stats import t as _student_t  # noqa: PLC0415
    except ImportError:
        t_cdf = None
    else:
        t_cdf = _student_t.cdf

    def _two_tailed_p(t_stat: float, df: int) -> float:
        magnitude = abs(t_stat)
        if t_cdf is None:
            # Normal approximation fallback.
            return float(2 * (1 - 0.5 * (1 + math.erf(magnitude / math.sqrt(2)))))
        return float(2 * (1 - t_cdf(magnitude, df=df)))

    def _rounded_mean(values):
        if not values:
            return None
        return round(sum(values) / len(values), 4)

    summary: Dict[str, Any] = {}
    for group in sorted(group_names):
        if group in ("entity", "stem"):
            raw_scores = [rec.get(f"{group}_importance") for rec in scored]
        else:
            raw_scores = [rec.get("group_importances", {}).get(group) for rec in scored]

        usable = [(y, x) for y, x in zip(outcomes, raw_scores) if x is not None]
        n_g = len(usable)
        if n_g < 3:
            continue

        ys = [pair[0] for pair in usable]
        xs = [pair[1] for pair in usable]

        mean_x = sum(xs) / n_g
        mean_y = sum(ys) / n_g
        cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
        # The small epsilon keeps the division defined for constant inputs.
        std_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs) + 1e-12)
        std_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys) + 1e-12)
        r = cov / (std_x * std_y)
        r = max(-1.0, min(1.0, r))  # clamp to [-1, 1]

        if abs(r) < 1.0 - 1e-9:
            t_stat = r * math.sqrt((n_g - 2) / (1 - r**2 + 1e-12))
            p_val = _two_tailed_p(t_stat, n_g - 2)
        else:
            # (Near-)perfect correlation: the t statistic diverges.
            p_val = 0.0

        summary[group] = {
            "r": round(r, 4),
            "p_value": round(p_val, 4),
            "n": n_g,
            "mean_importance_correct": _rounded_mean(
                [x for x, y in zip(xs, ys) if y == 1]
            ),
            "mean_importance_incorrect": _rounded_mean(
                [x for x, y in zip(xs, ys) if y == 0]
            ),
        }

    return summary
|
|
577
|
+
|
|
578
|
+
# ------------------------------------------------------------------
|
|
579
|
+
# Token-group contrast: sample loop
|
|
580
|
+
# ------------------------------------------------------------------
|
|
581
|
+
|
|
582
|
+
def _run_token_group_contrast(
    self,
    backend: InferenceBackend,
    dataset: BaseDataset,
    prompt_strategy: Any,
    logger: Optional["ExperimentLogger"] = None,
) -> "ExperimentResult":
    """Token-group attention masking loop.

    For each sample:
    1. Build a single prompt (standard, no clean/corrupt split).
    2. Tokenize and tag token positions into groups.
    3. Run baseline forward (no masking) → logit_base + is_correct.
    4. For each group, run _forward_attention_masked → logit_masked.
    5. importance(group) = |logit_base - logit_masked|.
    6. dominant_group = argmax importance.

    Args:
        backend: Inference backend; its ``_tokenizer``, ``_model``,
            ``device`` and ``model_name`` attributes are used directly.
        dataset: Dataset providing ``sample(num_samples, seed=...)``.
        prompt_strategy: Object with a ``build_prompt`` method (and
            optionally a ``name`` attribute used for reporting).
        logger: Accepted for signature parity with ``run``; not used here.

    Returns:
        An ``ExperimentResult`` with per-group mean importances, the overall
        dominant group, accuracy-when-dominant, point-biserial correlations
        and the per-sample records.
    """
    tokenizer = backend._tokenizer
    mask_layer = self.token_group_contrast_layer

    print(f"Model : {backend.model_name}")
    print("Patching mode: token_group_contrast")
    print(f"Mask layer : L{mask_layer}")
    print(f"max_input_tokens: {self.max_input_tokens}")

    samples = dataset.sample(self.num_samples, seed=self.seed)
    n = len(samples)
    print(f"Samples: {n} (each requires 4 forward passes)\n")

    # Primary groups to probe (entity/stem only appear for MedQA).
    PRIMARY_GROUPS = ("delimiter", "choice", "content")
    # Optional groups: logged per sample but not aggregated.
    BONUS_GROUPS = ("entity", "stem")

    per_sample_results: List[Dict] = []
    # Accumulator: group → list of importance scores across samples.
    group_importances: Dict[str, List[float]] = {g: [] for g in PRIMARY_GROUPS}
    # Per-group: correctness of every sample for which the group was dominant.
    accuracy_by_dominant: Dict[str, List[bool]] = {g: [] for g in PRIMARY_GROUPS}
    processed = 0

    for sample in tqdm(samples, desc="Token-group contrast"):
        answer_tok_id = self._answer_token_id(tokenizer, sample.label)
        if answer_tok_id is None:
            tqdm.write(f" [skip] sample {sample.idx}: cannot resolve answer token")
            continue

        metadata = sample.metadata or {}
        prompt_str = self._build_prompt(prompt_strategy, sample.text, metadata)
        tokens = self._tokenize(tokenizer, prompt_str, backend.device)
        input_ids = tokens["input_ids"][0]  # (seq_len,)

        # Tag tokens into groups.
        try:
            groups = self._tag_tokens(input_ids, tokenizer, metadata)
        except Exception as exc:
            tqdm.write(f" [skip] sample {sample.idx} (tagging): {exc}")
            continue

        # Baseline forward (no masking) — also derive is_correct in one pass.
        try:
            with torch.no_grad():
                out_base = backend._model(**tokens)
            last_logits = out_base.logits[0, -1].detach().float().cpu()
            logit_base = float(last_logits[answer_tok_id].item())
            letter_ids = self._answer_letter_token_ids(tokenizer)
            if letter_ids:
                best_letter_tok = max(letter_ids, key=lambda t: last_logits[t].item())
                is_correct = best_letter_tok == answer_tok_id
            else:
                # No candidate letters to compare against: count as incorrect.
                is_correct = False
            del out_base, last_logits
        except Exception as exc:
            tqdm.write(f" [skip] sample {sample.idx} (baseline): {exc}")
            torch.cuda.empty_cache()
            continue

        # Masked forward passes per group.
        sample_importances: Dict[str, float] = {}
        for group in PRIMARY_GROUPS:
            zero_pos = groups.get(group, [])
            try:
                logit_masked = self._forward_attention_masked(
                    backend, tokens, mask_layer, zero_pos, answer_tok_id
                )
                importance = abs(logit_base - logit_masked)
            except Exception as exc:
                tqdm.write(f" [skip] sample {sample.idx} group '{group}': {exc}")
                # A failed pass is recorded as zero importance so that every
                # sample contributes a value for every primary group.
                importance = 0.0
            finally:
                torch.cuda.empty_cache()

            sample_importances[group] = round(importance, 4)
            group_importances[group].append(importance)

        # Dominant group for this sample (is_correct is always a bool here).
        dominant = max(sample_importances, key=sample_importances.get)
        accuracy_by_dominant[dominant].append(is_correct)

        # MedQA entity/stem breakdown (bonus — logged but not aggregated).
        bonus_importance: Dict[str, Optional[float]] = {g: None for g in BONUS_GROUPS}
        for bonus in BONUS_GROUPS:
            positions = groups.get(bonus)
            if not positions:
                continue
            try:
                logit_bonus = self._forward_attention_masked(
                    backend, tokens, mask_layer, positions, answer_tok_id
                )
                bonus_importance[bonus] = round(abs(logit_base - logit_bonus), 4)
            except Exception as exc:
                # Best-effort: keep the sample, but report the failure instead
                # of swallowing it silently.
                tqdm.write(f" [skip] sample {sample.idx} group '{bonus}': {exc}")
            finally:
                torch.cuda.empty_cache()

        per_sample_results.append(
            {
                "sample_idx": sample.idx,
                "is_correct": is_correct,
                "logit_base": round(logit_base, 4),
                "dominant_group": dominant,
                "group_importances": sample_importances,
                "token_counts": {g: len(groups.get(g, [])) for g in PRIMARY_GROUPS},
                "entity_importance": bonus_importance["entity"],
                "stem_importance": bonus_importance["stem"],
            }
        )
        processed += 1

    # ── Aggregate ─────────────────────────────────────────────────
    mean_importance: Dict[str, float] = {
        g: round(sum(v) / len(v), 4) if v else 0.0 for g, v in group_importances.items()
    }
    dominant_group_overall = max(mean_importance, key=mean_importance.get)

    acc_by_dom: Dict[str, Optional[float]] = {
        g: round(sum(hits) / len(hits), 4) if hits else None
        for g, hits in accuracy_by_dominant.items()
    }

    correlations = self._compute_correlations(per_sample_results)

    # ── Print summary ──────────────────────────────────────────────
    print("\n" + "=" * 70)
    print(f"TOKEN GROUP CONTRAST — L{mask_layer} attention masking")
    print("=" * 70)
    print(f"Processed samples : {processed} / {n}")
    print(f"Dominant group (avg): {dominant_group_overall}")
    print()
    print(f"{'Group':<12} {'Mean Importance':>16} {'Acc when dominant':>18}")
    print("-" * 52)
    for g in PRIMARY_GROUPS:
        acc_str = f"{acc_by_dom[g]:.4f}" if acc_by_dom[g] is not None else " n/a "
        print(f"{g:<12} {mean_importance[g]:>16.4f} {acc_str:>18}")

    if correlations:
        print()
        print("Point-biserial correlations (importance score → is_correct):")
        print(
            f" {'Group':<12} {'r':>7} {'p':>8} {'n':>5} {'mean(corr)':>11} {'mean(incorr)':>12}"
        )
        print(" " + "-" * 60)
        for g, c in correlations.items():
            # Only p < 0.05 is flagged; everything else gets a blank marker.
            sig = "*" if c["p_value"] < 0.05 else " "
            mc = (
                f"{c['mean_importance_correct']:.4f}"
                if c["mean_importance_correct"] is not None
                else " n/a "
            )
            mi = (
                f"{c['mean_importance_incorrect']:.4f}"
                if c["mean_importance_incorrect"] is not None
                else " n/a "
            )
            print(
                f" {g:<12} {c['r']:>+7.4f} {c['p_value']:>8.4f}{sig} {c['n']:>5} {mc:>11} {mi:>12}"
            )
        print(" (* p<0.05)")

    print("=" * 70)
    print()
    print("Interpretation:")
    # Single f-string so the layer reads "L12", matching the header above
    # (separate print arguments would render it as "L 12").
    print(
        f" Higher importance = removing this group from L{mask_layer} "
        "attention shifts the answer more."
    )
    print(" The dominant group is what the layer causally relies on most.")
    if correlations:
        print(" Positive r = higher importance of this group → more likely correct.")
        print(" Negative r = higher importance of this group → more likely incorrect.")

    return ExperimentResult(
        experiment_name=self.name,
        model_name=backend.model_name,
        prompt_strategy=(
            prompt_strategy.name if hasattr(prompt_strategy, "name") else "custom"
        ),
        metrics={
            "num_samples": processed,
            "mask_layer": mask_layer,
            "mean_importance_per_group": mean_importance,
            "dominant_group": dominant_group_overall,
            "accuracy_when_dominant": acc_by_dom,
            "point_biserial_correlations": correlations,
        },
        raw_outputs={"per_sample": per_sample_results},
        metadata={
            "mask_layer": mask_layer,
            "token_group_mode": self.token_group_mode,
            "num_samples": processed,
            "seed": self.seed,
            "answer_cue": self.answer_cue,
        },
    )
|
|
802
|
+
|
|
803
|
+
# ------------------------------------------------------------------
|
|
804
|
+
# Prompt builders and main entry point
|
|
805
|
+
# ------------------------------------------------------------------
|
|
806
|
+
|
|
807
|
+
def _build_prompt(self, prompt_strategy: Any, text: str, metadata: dict) -> str:
    """Render the strategy's prompt for *text* and append ``self.answer_cue``.

    The same text is supplied under the ``text``, ``question`` and ``report``
    keys, alongside ``metadata``, so the strategy can read whichever key it
    expects.
    """
    fields = dict.fromkeys(("text", "question", "report"), text)
    fields["metadata"] = metadata
    rendered = prompt_strategy.build_prompt(fields)
    return rendered + self.answer_cue
|
|
819
|
+
|
|
820
|
+
def _build_prompt_few_shot(
    self, prompt_strategy: Any, text: str, metadata: dict, few_shot: bool
) -> str:
    """Build prompt with few_shot toggled — restores original value afterwards.

    If the strategy has no ``few_shot`` attribute it is left untouched and the
    standard prompt is returned. Otherwise the attribute is set to *few_shot*
    for the duration of the build and then restored unconditionally, even
    when its original value was ``None`` or the build raised.
    """
    if not hasattr(prompt_strategy, "few_shot"):
        return self._build_prompt(prompt_strategy, text, metadata)

    original = prompt_strategy.few_shot
    prompt_strategy.few_shot = few_shot
    try:
        return self._build_prompt(prompt_strategy, text, metadata)
    finally:
        # Restore on presence of the attribute, not on its value being non-None.
        prompt_strategy.few_shot = original
|
|
832
|
+
|
|
833
|
+
def _build_prompt_introspect(
    self, prompt_strategy: Any, text: str, metadata: dict, introspect: bool
) -> str:
    """Build prompt with the introspect instruction prepended (clean) or omitted (corrupt).

    clean (introspect=True) → introspect_instruction, a blank line, then the standard prompt
    corrupt (introspect=False) → standard prompt only (no instruction)

    The prompt strategy itself is not modified, so few_shot and every other
    strategy setting stay as configured; the introspect wording is the only
    difference between the two variants.
    """
    standard = self._build_prompt(prompt_strategy, text, metadata)
    if not introspect:
        return standard
    # The instruction goes first so it sets the reasoning intent from the
    # first token of the prompt.
    return "\n\n".join((self.introspect_instruction, standard))
|
|
850
|
+
|
|
851
|
+
def _build_prompt_cot(self, prompt_strategy: Any, text: str, metadata: dict, cot: bool) -> str:
    """Build prompt with CoT trigger active (clean) or stripped (corrupt).

    clean (cot=True) → full CoT prompt with cot_trigger intact
    corrupt (cot=False) → same prompt with cot_trigger set to "" (zero-shot)

    Only the cot_trigger attribute is toggled; few_shot and all other
    strategy settings are preserved so CoT is the sole variable. The
    attribute is restored unconditionally after the build, even when its
    original value was ``None`` or the build raised. Strategies without a
    ``cot_trigger`` attribute are left untouched.
    """
    if cot or not hasattr(prompt_strategy, "cot_trigger"):
        # Nothing to strip: build with the strategy exactly as configured.
        return self._build_prompt(prompt_strategy, text, metadata)

    original = prompt_strategy.cot_trigger
    prompt_strategy.cot_trigger = ""
    try:
        return self._build_prompt(prompt_strategy, text, metadata)
    finally:
        # Restore on presence of the attribute, not on its value being non-None.
        prompt_strategy.cot_trigger = original
|
|
868
|
+
|
|
869
|
+
def run(
    self,
    backend: InferenceBackend,
    dataset: BaseDataset,
    prompt_strategy: Any,
    logger: Optional[ExperimentLogger] = None,
    **kwargs,
) -> ExperimentResult:
    """Run activation patching experiment.

    Dispatches to the token_group_contrast branch when
    ``patching_mode == 'token_group_contrast'``, otherwise runs the
    standard layer-sweep residual patching.

    For the layer sweep, each sample yields a clean and a corrupted prompt
    (chosen by ``patching_mode``). The clean run's activations are cached,
    then each target layer is patched into the corrupted run and the effect
    is measured as ``(patch_logit - corr_logit) / (clean_logit - corr_logit)``
    on the answer token, clipped to [-1, 2].

    Args:
        backend: Inference backend; ``_tokenizer``, ``device`` and
            ``model_name`` are read here.
        dataset: Dataset providing ``sample(num_samples, seed=...)``.
        prompt_strategy: Object with a ``build_prompt`` method (and
            optionally a ``name`` attribute used for reporting).
        logger: Forwarded to the token-group branch; unused by the sweep.
        **kwargs: Accepted and ignored.

    Returns:
        An ``ExperimentResult`` with per-layer mean effects, the five layers
        with the largest mean effect, and the per-sample records.
    """

    tokenizer = backend._tokenizer

    # ── Dispatch to token_group_contrast mode ─────────────────────
    if self.patching_mode == "token_group_contrast":
        return self._run_token_group_contrast(backend, dataset, prompt_strategy, logger)

    # ── Standard residual patching modes ──────────────────────────
    target_layers = self._resolve_layers(backend)

    print(f"Model : {backend.model_name}")
    print(f"Patching mode: {self.patching_mode}")
    print(f"Layers ({len(target_layers)}): {target_layers}")
    print(f"Stride : {self.layer_stride} | max_input_tokens: {self.max_input_tokens}")

    samples = dataset.sample(self.num_samples, seed=self.seed)
    n = len(samples)
    print(f"Samples: {n} (each requires {len(target_layers) + 2} forward passes)\n")

    # Per-layer effect accumulators
    layer_effects: Dict[int, List[float]] = {lid: [] for lid in target_layers}
    per_sample_results: List[Dict] = []
    processed = 0
    # Logit gaps smaller than this are treated as zero (no measurable effect).
    eps = 1e-6

    for sample in tqdm(samples, desc="Activation patching"):
        clean_tok_id = self._answer_token_id(tokenizer, sample.label)
        if clean_tok_id is None:
            tqdm.write(f" [skip] sample {sample.idx}: cannot resolve answer token")
            continue

        # Metadata may be None; normalise once so every mode is safe.
        metadata = sample.metadata or {}

        # ── Build clean / corrupted prompt strings based on mode ──────
        if self.patching_mode == "pairs":
            corrupted_prompt = metadata.get("corrupted_prompt")
            if not corrupted_prompt:
                tqdm.write(f" [skip] sample {sample.idx}: no corrupted_prompt in metadata")
                continue
            clean_str = self._build_prompt(prompt_strategy, sample.text, metadata)
            corr_str = self._build_prompt(prompt_strategy, corrupted_prompt, {})
        elif self.patching_mode == "few_shot_contrast":
            # few-shot = clean (more context → better answer representation)
            # zero-shot = corrupted
            clean_str = self._build_prompt_few_shot(
                prompt_strategy, sample.text, metadata, few_shot=True
            )
            corr_str = self._build_prompt_few_shot(
                prompt_strategy, sample.text, metadata, few_shot=False
            )
        elif self.patching_mode == "introspect_contrast":
            # introspect instruction prepended = clean
            # no instruction = corrupted
            clean_str = self._build_prompt_introspect(
                prompt_strategy, sample.text, metadata, introspect=True
            )
            corr_str = self._build_prompt_introspect(
                prompt_strategy, sample.text, metadata, introspect=False
            )
        else:  # cot_contrast
            # CoT trigger active = clean
            # CoT trigger stripped (zero-shot) = corrupted
            clean_str = self._build_prompt_cot(
                prompt_strategy, sample.text, metadata, cot=True
            )
            corr_str = self._build_prompt_cot(
                prompt_strategy, sample.text, metadata, cot=False
            )

        clean_tokens = self._tokenize(tokenizer, clean_str, backend.device)
        corr_tokens = self._tokenize(tokenizer, corr_str, backend.device)

        try:
            # Step 1 — clean forward, cache activations
            logits_clean, act_cache = self._forward_with_cache(
                backend, clean_tokens, target_layers
            )
            # Step 2 — corrupted baseline (no layers cached for this run)
            logits_corr, _ = self._forward_with_cache(backend, corr_tokens, [])
        except Exception as exc:
            tqdm.write(f" [skip] sample {sample.idx} (baseline): {type(exc).__name__}: {exc}")
            torch.cuda.empty_cache()
            continue

        clean_logit = float(logits_clean[clean_tok_id].item())
        corr_logit = float(logits_corr[clean_tok_id].item())
        denom = clean_logit - corr_logit  # may be 0 or negative

        sample_layer_effects: Dict[int, float] = {}

        # Step 3 — patching sweep over layers
        for layer_idx in target_layers:
            if layer_idx not in act_cache:
                continue
            try:
                logits_patch = self._forward_patched(
                    backend, corr_tokens, layer_idx, act_cache[layer_idx]
                )
            except Exception as exc:
                tqdm.write(f" [skip] sample {sample.idx} layer {layer_idx}: {exc}")
                torch.cuda.empty_cache()
                continue

            patch_logit = float(logits_patch[clean_tok_id].item())
            if abs(denom) < eps:
                effect = 0.0
            else:
                effect = (patch_logit - corr_logit) / denom
                # Clip to [-1, 2] to handle outliers
                effect = max(-1.0, min(2.0, effect))
            layer_effects[layer_idx].append(effect)
            sample_layer_effects[layer_idx] = round(effect, 4)
            torch.cuda.empty_cache()

        per_sample_results.append(
            {
                "sample_idx": sample.idx,
                "clean_logit": round(clean_logit, 4),
                "corrupt_logit": round(corr_logit, 4),
                "logit_gap": round(denom, 4),
                "layer_effects": sample_layer_effects,
            }
        )
        processed += 1
        torch.cuda.empty_cache()

    # --- Aggregate --------------------------------------------------
    mean_effects: Dict[int, float] = {}
    for layer_idx in target_layers:
        vals = layer_effects[layer_idx]
        mean_effects[layer_idx] = round(sum(vals) / len(vals), 4) if vals else 0.0

    sorted_by_effect = sorted(mean_effects.items(), key=lambda x: x[1], reverse=True)
    top_5_layers = [lid for lid, _ in sorted_by_effect[:5]]

    # --- Print summary -----------------------------------------------
    print("\n" + "=" * 62)
    print("ACTIVATION PATCHING SUMMARY (logit-recovery effect)")
    print("=" * 62)
    print(f"Processed samples : {processed} / {n}")
    print(f"Top-5 causal layers: {top_5_layers}")
    print()
    print(f"{'Layer':>6} {'Mean Effect':>12} {'N samples':>10}")
    print("-" * 34)
    for layer_idx in target_layers:
        n_val = len(layer_effects[layer_idx])
        print(f"{layer_idx:>6} {mean_effects[layer_idx]:>12.4f} {n_val:>10}")
    print("=" * 62)

    return ExperimentResult(
        experiment_name=self.name,
        model_name=backend.model_name,
        prompt_strategy=(
            prompt_strategy.name if hasattr(prompt_strategy, "name") else "custom"
        ),
        metrics={
            "num_samples": processed,
            "layer_stride": self.layer_stride,
            "mean_effect_per_layer": mean_effects,
            "top_5_causal_layers": top_5_layers,
        },
        raw_outputs={"per_sample": per_sample_results},
        metadata={
            "target_layers": target_layers,
            "layer_stride": self.layer_stride,
            "num_samples": processed,
            "seed": self.seed,
            "answer_cue": self.answer_cue,
        },
    )
|