cotlab 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. cotlab/__init__.py +3 -0
  2. cotlab/analyse_experiments.py +392 -0
  3. cotlab/analysis/__init__.py +11 -0
  4. cotlab/analysis/cot_parser.py +243 -0
  5. cotlab/analysis/faithfulness_metrics.py +192 -0
  6. cotlab/backends/__init__.py +16 -0
  7. cotlab/backends/base.py +78 -0
  8. cotlab/backends/transformers_backend.py +335 -0
  9. cotlab/backends/vllm_backend.py +227 -0
  10. cotlab/cli.py +83 -0
  11. cotlab/core/__init__.py +34 -0
  12. cotlab/core/base.py +749 -0
  13. cotlab/core/config.py +90 -0
  14. cotlab/core/registry.py +68 -0
  15. cotlab/datasets/__init__.py +45 -0
  16. cotlab/datasets/loaders.py +1889 -0
  17. cotlab/experiment/__init__.py +315 -0
  18. cotlab/experiments/__init__.py +43 -0
  19. cotlab/experiments/activation_compare.py +290 -0
  20. cotlab/experiments/activation_patching.py +1050 -0
  21. cotlab/experiments/attention_analysis.py +885 -0
  22. cotlab/experiments/classification.py +235 -0
  23. cotlab/experiments/composite_shift_detector.py +524 -0
  24. cotlab/experiments/cot_ablation.py +277 -0
  25. cotlab/experiments/cot_faithfulness.py +187 -0
  26. cotlab/experiments/cot_heads.py +208 -0
  27. cotlab/experiments/full_layer_cot.py +232 -0
  28. cotlab/experiments/full_layer_patching.py +225 -0
  29. cotlab/experiments/h_neuron_analysis.py +712 -0
  30. cotlab/experiments/logit_lens.py +439 -0
  31. cotlab/experiments/multi_head_cot.py +220 -0
  32. cotlab/experiments/multi_head_patching.py +229 -0
  33. cotlab/experiments/probing_classifier.py +402 -0
  34. cotlab/experiments/residual_norm_ood.py +413 -0
  35. cotlab/experiments/sae_feature_analysis.py +673 -0
  36. cotlab/experiments/steering_vectors.py +223 -0
  37. cotlab/experiments/sycophancy_heads.py +224 -0
  38. cotlab/logging/__init__.py +5 -0
  39. cotlab/logging/json_logger.py +161 -0
  40. cotlab/main.py +317 -0
  41. cotlab/patching/__init__.py +24 -0
  42. cotlab/patching/cache.py +141 -0
  43. cotlab/patching/hooks.py +558 -0
  44. cotlab/patching/interventions.py +86 -0
  45. cotlab/patching/patcher.py +439 -0
  46. cotlab/patching/sae.py +181 -0
  47. cotlab/prompts/__init__.py +43 -0
  48. cotlab/prompts/cardiology.py +378 -0
  49. cotlab/prompts/histopathology.py +265 -0
  50. cotlab/prompts/length_matched_strategies.py +157 -0
  51. cotlab/prompts/mcq.py +193 -0
  52. cotlab/prompts/neurology.py +353 -0
  53. cotlab/prompts/oncology.py +367 -0
  54. cotlab/prompts/plab.py +162 -0
  55. cotlab/prompts/pubhealthbench.py +82 -0
  56. cotlab/prompts/pubmedqa.py +173 -0
  57. cotlab/prompts/radiology.py +414 -0
  58. cotlab/prompts/strategies.py +939 -0
  59. cotlab/prompts/tcga.py +168 -0
  60. cotlab/runner.py +204 -0
  61. cotlab-0.8.0.dist-info/METADATA +166 -0
  62. cotlab-0.8.0.dist-info/RECORD +65 -0
  63. cotlab-0.8.0.dist-info/WHEEL +4 -0
  64. cotlab-0.8.0.dist-info/entry_points.txt +3 -0
  65. cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1050 @@
1
+ """Activation Patching experiment — causal intervention via residual-stream replacement.
2
+
3
+ Patching modes
4
+ --------------
5
+ ``pairs`` (default — requires PatchingPairsDataset)
6
+ clean = sample.text
7
+ corrupt = sample.metadata["corrupted_prompt"]
8
+ Answers Q: which layers encode the specific diagnosis/fact?
9
+
10
+ ``few_shot_contrast`` (works with ANY dataset)
11
+ clean = few-shot prompt of the sample (prompt_strategy with few_shot=True)
12
+ corrupt = zero-shot prompt of the sample (prompt_strategy with few_shot=False)
13
+ Answers Q: which layers causally drive few-shot's benefit on OOD / non-OOD?
14
+
15
+ ``introspect_contrast`` (works with ANY dataset)
16
+ clean = prompt + introspect instruction
17
+ corrupt = prompt only
18
+ Answers Q: which layers carry the "think deeply" reasoning signal?
19
+
20
+ ``cot_contrast`` (works with ANY dataset)
21
+ clean = full CoT prompt (cot_trigger active, e.g. "Let's think through this step by step:")
22
+ corrupt = zero-shot prompt (cot_trigger stripped — same structure, no reasoning nudge)
23
+ Answers Q: which layers carry the chain-of-thought reasoning signal vs plain answering?
24
+ Use as the default/baseline contrast alongside few_shot_contrast and introspect_contrast.
25
+
26
+ ``token_group_contrast`` (works with ANY dataset)
27
+ Hooks the attention weight matrix at a single target layer and zeros out
28
+ one token group at a time (delimiter / choice / content). Measures how
29
+ much each group's removal shifts the answer logit.
30
+ Answers Q: which token positions does <target_mask_layer> attend to causally?
31
+ (Q1: run at layer 3 — the universal attention bottleneck)
32
+
33
+ Algorithm (logit-recovery metric, one sample, residual patching modes):
34
+ 1. Forward clean → cache per-layer residuals (CPU).
35
+ 2. Forward corrupt → baseline logit at last token.
36
+ 3. For each layer L (strided):
37
+ Re-run corrupt with hook replacing layer L's output with cached clean.
38
+ effect(L) = (logit_patched[clean_tok] - logit_corrupt[clean_tok])
39
+ / (logit_clean[clean_tok] - logit_corrupt[clean_tok] + ε)
40
+ 1 = full recovery, 0 = no effect, negative = made things worse.
41
+
42
+ Algorithm (token_group_contrast):
43
+ 1. Single forward pass with no masking → logit_base.
44
+ 2. For each group G in {delimiter, choice, content}:
45
+ Forward with attention weights at target_mask_layer zeroed for group G.
46
+ importance(G) = |logit_base - logit_masked|
47
+ 3. dominant_group = argmax importance.
48
+
49
+ Memory safety: activations moved to CPU immediately inside each hook.
50
+ """
51
+
52
+ import math
53
+ import re
54
+ from typing import Any, Dict, List, Optional, Set
55
+
56
+ import torch
57
+ from tqdm import tqdm
58
+
59
+ from ..backends.base import InferenceBackend
60
+ from ..core.base import BaseExperiment, ExperimentResult
61
+ from ..core.registry import Registry
62
+ from ..datasets.loaders import BaseDataset
63
+ from ..logging import ExperimentLogger
64
+
65
+
66
+ @Registry.register_experiment("activation_patching")
67
+ class ActivationPatchingExperiment(BaseExperiment):
68
+ """
69
+ Layer-wise causal activation patching with logit-recovery scoring.
70
+
71
+ Supports five patching modes (full list in ``VALID_MODES`` and the module docstring), including:
72
+ - ``pairs`` PatchingPairsDataset clean/corrupt pairs.
73
+ - ``few_shot_contrast`` Any dataset — few-shot (clean) vs zero-shot (corrupt).
74
+ - ``cot_contrast`` Any dataset — CoT prompt (clean) vs zero-shot (corrupt).
75
+ """
76
+
77
+ VALID_MODES = (
78
+ "pairs",
79
+ "few_shot_contrast",
80
+ "introspect_contrast",
81
+ "token_group_contrast",
82
+ "cot_contrast",
83
+ )
84
+
85
+ # Tokens that are purely structural / formatting — not medical content.
86
+ _DELIMITER_STRINGS: Set[str] = {
87
+ "\n",
88
+ ":",
89
+ "#",
90
+ "##",
91
+ "###",
92
+ "Options",
93
+ "Options:",
94
+ "Answer",
95
+ "Answer:",
96
+ "A.",
97
+ "B.",
98
+ "C.",
99
+ "D.",
100
+ "E.",
101
+ "F.",
102
+ "G.",
103
+ "(A)",
104
+ "(B)",
105
+ "(C)",
106
+ "(D)",
107
+ "(E)",
108
+ "(F)",
109
+ "(G)",
110
+ "A)",
111
+ "B)",
112
+ "C)",
113
+ "D)",
114
+ "E)",
115
+ "F)",
116
+ "G)",
117
+ }
118
+
119
def __init__(
    self,
    name: str = "activation_patching",
    description: str = "Layer-wise causal activation patching (logit recovery)",
    patching_mode: str = "pairs",  # one of VALID_MODES
    layer_stride: int = 2,
    num_samples: int = 50,
    max_input_tokens: int = 1024,
    seed: int = 42,
    answer_cue: str = "\n\nAnswer:",
    introspect_instruction: str = (
        "Think deeply about this problem. "
        "Carefully reason through the underlying mechanisms and consider "
        "all relevant factors before committing to your answer."
    ),
    # Legacy fields kept so old YAML configs don't break
    variants: Optional[List[Dict[str, Any]]] = None,
    patching: Optional[Dict[str, Any]] = None,
    # Token-group contrast params
    token_group_contrast_layer: int = 3,
    token_group_mode: str = "all",  # "all" | "delimiter" | "choice" | "content"
    **kwargs,
):
    """Store the experiment configuration.

    Args:
        name: Experiment name, exposed through the ``name`` property.
        description: Human-readable description.
        patching_mode: One of ``VALID_MODES`` ("pairs", "few_shot_contrast",
            "introspect_contrast", "token_group_contrast", "cot_contrast").
        layer_stride: Step over layer indices used by ``_resolve_layers``;
            must be a positive integer.
        num_samples: Number of dataset samples to draw.
        max_input_tokens: Token limit applied by ``_tokenize``.
        seed: Seed passed to ``dataset.sample``.
        answer_cue: Text appended to every prompt by ``_build_prompt``.
        introspect_instruction: Stored for the introspection contrast
            (presumably appended to the clean prompt in that mode).
        variants: Legacy config field; accepted and ignored.
        patching: Optional dict; ``_resolve_head_targets`` reads its
            ``head_indices`` / ``target_heads`` keys. ``None`` becomes ``{}``.
        token_group_contrast_layer: Layer whose attention is masked in
            ``token_group_contrast`` mode (coerced to ``int``).
        token_group_mode: Stored and recorded in the token-group result
            metadata. NOTE(review): the token-group loop in this file probes
            all three primary groups regardless; confirm whether other code
            reads this value.
        **kwargs: Extra config keys; accepted and ignored.

    Raises:
        ValueError: If ``patching_mode`` is not in ``VALID_MODES`` or
            ``layer_stride`` is not a positive integer.
    """
    import operator  # stdlib; imported locally to keep this method self-contained

    if patching_mode not in self.VALID_MODES:
        raise ValueError(
            f"patching_mode must be one of {self.VALID_MODES}, got {patching_mode!r}"
        )
    # Fail fast: a zero stride would otherwise only surface later as an opaque
    # slicing error, and a negative stride silently reverses the layer order.
    try:
        stride_value = operator.index(layer_stride)
    except TypeError:
        stride_value = None
    if stride_value is None or stride_value < 1:
        raise ValueError(f"layer_stride must be a positive integer, got {layer_stride!r}")

    self._name = name
    self.description = description
    self.patching_mode = patching_mode
    self.layer_stride = layer_stride
    self.num_samples = num_samples
    self.max_input_tokens = max_input_tokens
    self.seed = seed
    self.answer_cue = answer_cue
    self.introspect_instruction = introspect_instruction
    self.patching = patching or {}
    self.token_group_contrast_layer = int(token_group_contrast_layer)
    self.token_group_mode = token_group_mode
158
+
159
@property
def name(self) -> str:
    """Experiment name supplied at construction (``_name``)."""
    experiment_name = self._name
    return experiment_name
162
+
163
+ # ------------------------------------------------------------------
164
+ # Helpers
165
+ # ------------------------------------------------------------------
166
+
167
def _resolve_layers(self, backend: InferenceBackend) -> List[int]:
    """Return every ``layer_stride``-th index of ``range(num_layers)``.

    Ordinary slice semantics apply: the sequence starts at layer 0, and a
    stride of 0 raises ``ValueError``.
    """
    layer_count = backend.hook_manager.num_layers
    # Slicing a range gives the same elements as slicing the materialised list.
    return list(range(layer_count)[:: self.layer_stride])
170
+
171
def _resolve_head_targets(self, layers: List[int]) -> Dict[int, List[int]]:
    """Map each patched layer to the attention heads configured in ``self.patching``.

    Two mutually exclusive config keys are understood:

    * ``head_indices`` — one list of heads applied to every entry of ``layers``;
    * ``target_heads`` — a mapping layer -> heads; layers not in ``layers``
      are dropped. Keys and heads are coerced with ``int``.

    Returns ``{}`` when neither key is set.

    Raises:
        ValueError: If both keys are present.
    """
    shared_heads = self.patching.get("head_indices")
    per_layer_heads = self.patching.get("target_heads")

    if shared_heads is not None and per_layer_heads is not None:
        raise ValueError("Use either target_heads or head_indices, not both.")

    if per_layer_heads is not None:
        return {
            int(layer_key): [int(head) for head in list(heads)]
            for layer_key, heads in dict(per_layer_heads).items()
            if int(layer_key) in layers
        }

    if shared_heads is None:
        return {}

    heads_for_all = [int(head) for head in list(shared_heads)]
    return dict.fromkeys(layers, heads_for_all)
197
+
198
def _answer_token_id(self, tokenizer, label) -> Optional[int]:
    """Return the token id whose logit is tracked for ``label``.

    The label is stripped and encoded first with a leading space (the usual
    form after a cue such as "Answer:"), then bare; the first token id of the
    first usable encoding is returned.

    An encoding is skipped when its first token decodes to whitespace only.
    Some SentencePiece tokenizers split " A" into a bare-space piece followed
    by "A", and tracking that space token would make the answer token never
    match any answer-letter token. If every candidate starts with a
    whitespace-only piece, the first id of the first non-empty encoding is
    returned (the previous behaviour).

    Returns:
        The token id, or ``None`` when ``label`` is ``None``, blank, or
        encodes to nothing.
    """
    if label is None:
        return None
    label_str = str(label).strip()
    if not label_str:
        return None

    fallback: Optional[int] = None
    for prefix in (" ", ""):
        ids = tokenizer.encode(prefix + label_str, add_special_tokens=False)
        if not ids:
            continue
        if fallback is None:
            fallback = ids[0]
        # Prefer an encoding whose first piece actually carries label text.
        if tokenizer.decode([ids[0]]).strip():
            return ids[0]
    return fallback
210
+
211
def _answer_letter_token_ids(self, tokenizer) -> List[int]:
    """Return the sorted, de-duplicated token ids that can spell an MCQ letter A-J.

    Each letter is encoded with a leading space, bare, and after a newline
    (no special tokens). The LAST token of each non-empty encoding is kept,
    so a separate whitespace/newline piece does not count as the letter.
    """
    encodings = (
        tokenizer.encode(spacing + letter, add_special_tokens=False)
        for letter in "ABCDEFGHIJ"
        for spacing in (" ", "", "\n")
    )
    return sorted({pieces[-1] for pieces in encodings if pieces})
220
+
221
def _tokenize(self, tokenizer, text: str, device):
    """Tokenize ``text`` into PyTorch tensors on ``device``.

    At most ``max_input_tokens`` tokens are kept. Over-long prompts are
    truncated from the LEFT. Prompts built by ``_build_prompt`` end with
    ``answer_cue``, and the logits throughout this experiment are read at the
    final position, so right-truncation would cut off the cue and score a
    mid-prompt position instead of the answer slot.

    The tokenizer's ``truncation_side`` is set to "left" only for this call
    and restored afterwards, even if tokenization raises. Tokenizers without
    that attribute are called unchanged.
    """
    previous_side = getattr(tokenizer, "truncation_side", None)
    if previous_side is not None:
        tokenizer.truncation_side = "left"
    try:
        encoding = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_input_tokens,
        )
    finally:
        if previous_side is not None:
            tokenizer.truncation_side = previous_side
    return encoding.to(device)
228
+
229
def _forward_with_cache(
    self,
    backend: InferenceBackend,
    tokens,
    target_layers: List[int],
) -> tuple:
    """Run one forward pass, capturing the last-position residual of each target layer.

    Args:
        backend: Backend exposing ``hook_manager`` and ``_model``.
        tokens: Tokenizer output, passed to the model as ``**tokens``.
        target_layers: Layer indices to capture. Indices ``>= num_layers``
            are ignored.

    Returns:
        ``(logits_last, act_cache)``:
            logits_last – [vocab_size] float32 CPU tensor at the last position
                of batch item 0 (upcast for stable arithmetic).
            act_cache – dict ``layer_idx -> tensor[0, -1]`` of that layer's
                residual output, detached and moved to CPU but kept in the
                model's own dtype, not upcast. This is a [hidden] vector
                assuming a (batch, seq, hidden) output — TODO confirm per
                architecture.

    Hooks are registered inside the ``try`` block, so every hook registered
    so far is removed even if registration of a later layer, or the forward
    pass itself, fails.
    """
    act_cache: Dict[int, torch.Tensor] = {}
    hook_manager = backend.hook_manager

    def make_cache_hook(layer_idx: int):
        def hook(module, inp, output):
            tensor = output[0] if isinstance(output, tuple) else output
            with torch.no_grad():
                # Keep the model dtype so the vector can be written back unchanged.
                act_cache[layer_idx] = tensor[0, -1].detach().cpu()
            return output

        return hook

    handles = []
    try:
        for layer_idx in target_layers:
            if layer_idx < hook_manager.num_layers:
                module = hook_manager.get_residual_module(layer_idx)
                handles.append(module.register_forward_hook(make_cache_hook(layer_idx)))
        with torch.no_grad():
            out = backend._model(**tokens)
    finally:
        for handle in handles:
            handle.remove()

    logits_last = out.logits[0, -1].detach().float().cpu()  # float32 for stable arithmetic
    return logits_last, act_cache
269
+
270
def _forward_patched(
    self,
    backend: InferenceBackend,
    tokens,
    patch_layer: int,
    patch_vec: torch.Tensor,  # CPU [hidden]
) -> torch.Tensor:
    """Forward pass with the LAST-position output of ``patch_layer`` replaced by ``patch_vec``.

    ``patch_vec`` is a single-position activation as cached by
    ``_forward_with_cache`` (last token of the clean run). Only the final
    sequence position of the layer output is overwritten; earlier positions
    keep the corrupt run's activations.

    Broadcasting the one vector over every position would erase the corrupt
    context entirely. An ``expand``-ed zero-stride view could also break any
    downstream in-place op, so the output is cloned and written into instead.

    The vector is cast to the dtype and device of the hooked output at hook
    time, so it also works when that layer lives on a different device than
    ``backend.device``.

    Returns:
        [vocab_size] float32 CPU logit vector at the last position of batch item 0.
    """

    def patch_hook(module, inp, output):
        hidden = output[0] if isinstance(output, tuple) else output
        patched = hidden.clone()
        # ``...`` keeps this valid for both (seq, hidden) and (batch, seq, hidden).
        patched[..., -1, :] = patch_vec.to(device=patched.device, dtype=patched.dtype)
        if isinstance(output, tuple):
            return (patched,) + tuple(output[1:])
        return patched

    module = backend.hook_manager.get_residual_module(patch_layer)
    handle = module.register_forward_hook(patch_hook)
    try:
        with torch.no_grad():
            out = backend._model(**tokens)
    finally:
        handle.remove()

    return out.logits[0, -1].detach().float().cpu()
302
+
303
+ # ------------------------------------------------------------------
304
+ # Token-group tagger and attention masking helpers
305
+ # ------------------------------------------------------------------
306
+
307
def _tag_tokens(
    self,
    input_ids: torch.Tensor,  # shape (seq_len,)
    tokenizer,
    metadata: dict,
) -> Dict[str, List[int]]:
    """Classify every token position into structural, answer-choice and content groups.

    Groups
    ------
    delimiter : tokens whose decoded text, raw or stripped, is in
                ``_DELIMITER_STRINGS`` (\\n, A., Options:, ...)
    choice    : non-delimiter tokens from the first answer-option line onward
    content   : every remaining token (question stem and clinical text)

    An answer-option line starts with a newline followed by "Options"
    (case-insensitive, optional colon), an upper-case letter A-G followed by
    "." or ")", or "(A)"-style labels.

    When ``metadata["metamap_phrases"]`` is present and at least one phrase
    occurs in the prompt (MedQA), content tokens overlapping a phrase match
    become ``entity`` and the rest become ``stem``. ``content`` stays the union
    of the two, so callers that only probe delimiter/choice/content still mask
    the real content tokens.

    Returns
    -------
    dict mapping group name -> ascending list of 0-based token positions; the
    keys delimiter, choice, content, stem and entity are always present.
    """
    ids = input_ids.tolist()
    seq_len = len(ids)
    labels = ["content"] * seq_len  # default everything to content

    # ── Pass 1: mark delimiter tokens ─────────────────────────────
    for i, tid in enumerate(ids):
        tok_raw = tokenizer.decode([tid])
        # Match the raw form (catches "\n") or the stripped form (catches " A.").
        if tok_raw in self._DELIMITER_STRINGS or tok_raw.strip() in self._DELIMITER_STRINGS:
            labels[i] = "delimiter"

    # ── Character spans aligned with the full decoded text ────────
    # Span i = [len(decode(ids[:i])), len(decode(ids[:i+1]))). Prefix decoding
    # keeps offsets consistent with ``full_text`` even when single-token decodes
    # drop leading spaces or the full decode cleans up spacing. This costs
    # O(seq_len) decode calls of growing length, which is small next to the
    # model forward passes this feeds. max() keeps spans monotonic.
    full_text = tokenizer.decode(ids)
    spans: List[tuple] = []
    prev_end = 0
    for i in range(seq_len):
        end = max(len(tokenizer.decode(ids[: i + 1])), prev_end)
        spans.append((prev_end, end))
        prev_end = end

    # ── Pass 2: mark answer-choice span ───────────────────────────
    # Letters are matched case-sensitively and must be followed by "." or ")"
    # (or wrapped as "(A)"). A bare "A " also starts ordinary sentences such as
    # "A 45-year-old man ...", and a case-insensitive letter would also match
    # "e." in "e.g."; both used to mislabel the question stem as choices.
    answer_label_re = re.compile(r"\n(?:(?i:options)\s*:?|[A-G][.)]|\([A-G]\))")
    options_start: Optional[int] = None
    match = answer_label_re.search(full_text)
    if match:
        boundary_char = match.start()  # char index of the '\n'
        # First token that starts at or after the boundary.
        options_start = next(
            (i for i, (start, _end) in enumerate(spans) if start >= boundary_char),
            None,
        )

    if options_start is not None:
        for i in range(options_start, seq_len):
            if labels[i] != "delimiter":
                labels[i] = "choice"

    # ── Pass 3 (MedQA only): entity vs stem split ──────────────────
    metamap = metadata.get("metamap_phrases") if metadata else None
    if metamap:
        entity_spans: List[tuple] = []  # (char_start, char_end)
        for phrase in metamap:
            phrase_str = str(phrase).strip()
            if not phrase_str:
                continue
            for m in re.finditer(re.escape(phrase_str), full_text, re.IGNORECASE):
                entity_spans.append((m.start(), m.end()))

        if entity_spans:
            for i, (tok_start, tok_end) in enumerate(spans):
                if labels[i] != "content":
                    continue
                if any(tok_start < ee and tok_end > es for es, ee in entity_spans):
                    labels[i] = "entity"
            # Remaining content tokens become stem.
            labels = ["stem" if lbl == "content" else lbl for lbl in labels]

    # ── Collect positions per group ────────────────────────────────
    groups: Dict[str, List[int]] = {
        g: [] for g in ("delimiter", "choice", "content", "stem", "entity")
    }
    for i, lbl in enumerate(labels):
        groups[lbl].append(i)
        if lbl in ("entity", "stem"):
            # Keep the primary "content" group populated after the split.
            groups["content"].append(i)

    return groups
413
+
414
def _forward_attention_masked(
    self,
    backend: InferenceBackend,
    tokens,
    mask_layer: int,
    zero_positions: List[int],
    answer_tok_id: int,
) -> float:
    """Answer-token logit when ``zero_positions`` are hidden from one layer's attention.

    A forward pre-hook on ``layers[mask_layer].self_attn`` adds a large negative
    bias (-1e4) to the additive attention mask at every key column in
    ``zero_positions``. Those keys then get ~zero weight after softmax for every
    query at that layer; ``output_attentions`` is not needed.

    Mask handling inside the hook:
      * Float additive mask: the bias is added in float32, clamped to the model
        dtype's finite minimum (so float16 cannot overflow to -inf), then cast
        to the model dtype (SDPA requires mask dtype == query dtype).
      * Boolean mask (assumed True = attend, as in torch SDPA): converted to
        additive form first. Adding a float bias to booleans would be
        meaningless.
      * No mask: an explicit causal mask is built. Supplying any mask disables
        an implicit-causal attention path, so a bias-only mask would let tokens
        attend to future positions.

    Assumes ``self_attn`` receives ``attention_mask`` as a keyword argument (as
    with Gemma-style layers) and a 4-D mask. NOTE(review): a 2-D padding mask,
    such as flash-attention's, is not handled.

    Returns:
        The logit for ``answer_tok_id`` at the last position of batch item 0,
        as a Python float. With no in-range positions, or no ``self_attn`` on
        the layer, an unmasked forward pass is used.
    """

    def _answer_logit(model_out) -> float:
        return float(model_out.logits[0, -1, answer_tok_id].detach().float().cpu().item())

    def _plain_forward() -> float:
        with torch.no_grad():
            return _answer_logit(backend._model(**tokens))

    seq_len = tokens["input_ids"].shape[-1]
    valid_pos = sorted({int(p) for p in zero_positions if 0 <= int(p) < seq_len})
    if not valid_pos:
        return _plain_forward()

    layer_mod = backend.hook_manager.get_layer_module(mask_layer)
    attn_mod = getattr(layer_mod, "self_attn", None)
    if attn_mod is None:
        tqdm.write(
            f" [warn] token_group_contrast: no self_attn on layer {mask_layer}, skipping mask"
        )
        return _plain_forward()

    model_dtype = backend._model.dtype
    neg_min = torch.finfo(model_dtype).min

    def _key_penalty(kv_len: int, device) -> torch.Tensor:
        # One value per key column; broadcasting applies it to every query row.
        penalty = torch.zeros(kv_len, dtype=torch.float32, device=device)
        cols = [p for p in valid_pos if p < kv_len]
        if cols:
            penalty[cols] = -1e4
        return penalty

    def _pre_hook(module, args, kwargs):
        existing = kwargs.get("attention_mask")
        if existing is None:
            device = tokens["input_ids"].device
            # Upper triangle (future keys) masked, diagonal and past left at 0.
            additive = torch.full(
                (seq_len, seq_len), neg_min, dtype=torch.float32, device=device
            ).triu(1)
            additive = additive.view(1, 1, seq_len, seq_len)
        elif torch.is_tensor(existing):
            device = existing.device
            if existing.dtype == torch.bool:
                additive = torch.zeros(
                    existing.shape, dtype=torch.float32, device=device
                ).masked_fill(~existing, neg_min)
            else:
                additive = existing.float()
        else:
            raise TypeError(
                f"unsupported attention_mask type {type(existing).__name__} "
                f"at layer {mask_layer}"
            )
        combined = (additive + _key_penalty(additive.shape[-1], device)).clamp(min=neg_min)
        kwargs["attention_mask"] = combined.to(dtype=model_dtype)
        return args, kwargs

    handle = attn_mod.register_forward_pre_hook(_pre_hook, with_kwargs=True)
    try:
        with torch.no_grad():
            out = backend._model(**tokens)
    finally:
        handle.remove()

    return _answer_logit(out)
479
+
480
+ # ------------------------------------------------------------------
481
+ # Statistical correlation helpers
482
+ # ------------------------------------------------------------------
483
+
484
@staticmethod
def _compute_correlations(per_sample_results: List[Dict]) -> Dict[str, Any]:
    """Correlate each token group's importance score with answer correctness.

    Pearson r against a 0/1 label is the point-biserial correlation. The
    two-tailed p-value uses Student's t with df = n - 2, via ``scipy.stats.t``
    when scipy is importable; otherwise it falls back to a normal approximation.

    Only samples with a non-None ``is_correct`` are used, and fewer than three
    such samples yields ``{}``. Candidate groups are the union of the samples'
    ``group_importances`` keys, plus ``entity`` / ``stem`` when any sample has a
    non-None ``entity_importance`` / ``stem_importance``. A group with fewer
    than three non-None scores is omitted.

    Returns:
        ``{group: {"r", "p_value", "n", "mean_importance_correct",
        "mean_importance_incorrect"}}`` with values rounded to 4 places. The
        per-class means are None when that class has no samples.
    """
    scored_rows = [row for row in per_sample_results if row.get("is_correct") is not None]
    if len(scored_rows) < 3:
        return {}

    outcomes = [int(row["is_correct"]) for row in scored_rows]

    group_names: set = set()
    for row in scored_rows:
        group_names.update(row.get("group_importances", {}).keys())
    for extra in ("entity", "stem"):
        if any(row.get(f"{extra}_importance") is not None for row in scored_rows):
            group_names.add(extra)

    try:
        from scipy.stats import t as student_t  # noqa: PLC0415
    except ImportError:
        student_t = None

    def two_tailed_p(t_stat: float, df: int) -> float:
        if student_t is None:
            # Normal approximation fallback.
            return float(2 * (1 - 0.5 * (1 + math.erf(abs(t_stat) / math.sqrt(2)))))
        return float(2 * (1 - student_t.cdf(abs(t_stat), df=df)))

    def importance_of(row: Dict, group: str):
        if group in ("entity", "stem"):
            return row.get(f"{group}_importance")
        return row.get("group_importances", {}).get(group)

    report: Dict[str, Any] = {}
    for group in sorted(group_names):
        pairs = [
            (y, x)
            for y, x in zip(outcomes, (importance_of(row, group) for row in scored_rows))
            if x is not None
        ]
        count = len(pairs)
        if count < 3:
            continue

        ys = [y for y, _ in pairs]
        xs = [x for _, x in pairs]

        mean_x = sum(xs) / count
        mean_y = sum(ys) / count
        cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
        std_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs) + 1e-12)
        std_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys) + 1e-12)
        r = max(-1.0, min(1.0, cov / (std_x * std_y)))  # clamp to [-1, 1]

        p_val = (
            0.0
            if abs(r) >= 1.0 - 1e-9
            else two_tailed_p(r * math.sqrt((count - 2) / (1 - r**2 + 1e-12)), count - 2)
        )

        hits = [x for x, y in zip(xs, ys) if y == 1]
        misses = [x for x, y in zip(xs, ys) if y == 0]

        report[group] = {
            "r": round(r, 4),
            "p_value": round(p_val, 4),
            "n": count,
            "mean_importance_correct": round(sum(hits) / len(hits), 4) if hits else None,
            "mean_importance_incorrect": (
                round(sum(misses) / len(misses), 4) if misses else None
            ),
        }

    return report
577
+
578
+ # ------------------------------------------------------------------
579
+ # Token-group contrast: sample loop
580
+ # ------------------------------------------------------------------
581
+
582
def _run_token_group_contrast(
    self,
    backend: InferenceBackend,
    dataset: BaseDataset,
    prompt_strategy: Any,
    logger: Optional["ExperimentLogger"] = None,
) -> "ExperimentResult":
    """Token-group attention-masking loop (``token_group_contrast`` mode).

    For each sample from ``dataset.sample(num_samples, seed=seed)``:
      1. Build a single prompt (no clean/corrupt split) and tokenize it.
      2. Tag token positions into groups with ``_tag_tokens``.
      3. Run a baseline forward (no masking) to get ``logit_base`` for the
         answer token. When that token is one of the MCQ answer-letter tokens,
         ``is_correct`` is whether it wins the argmax over the letter tokens;
         otherwise it is None, since letter correctness is undefined.
      4. For each primary group, ``_forward_attention_masked`` gives the masked
         logit, and importance(group) = |logit_base - logit_masked|. A group
         whose masked pass raises is recorded as None and excluded from the
         means and from the dominant-group choice, rather than counted as 0.
      5. dominant_group = argmax importance over the groups that succeeded.

    MedQA ``entity`` / ``stem`` importances are logged per sample only.

    Args:
        backend: Backend exposing ``_model``, ``_tokenizer``, ``device``,
            ``model_name`` and ``hook_manager``.
        dataset: Sample source.
        prompt_strategy: Object with ``build_prompt(dict) -> str``.
        logger: Accepted for interface compatibility; not used here.

    Returns:
        ``ExperimentResult`` holding, per group, the mean importance and the
        accuracy when that group dominates, plus the overall dominant group
        (None when no sample was processed) and point-biserial correlations.
    """
    tokenizer = backend._tokenizer
    mask_layer = self.token_group_contrast_layer

    print(f"Model : {backend.model_name}")
    print("Patching mode: token_group_contrast")
    print(f"Mask layer : L{mask_layer}")
    print(f"max_input_tokens: {self.max_input_tokens}")

    samples = dataset.sample(self.num_samples, seed=self.seed)
    n = len(samples)
    print(f"Samples: {n} (each requires 4 forward passes)\n")

    # Primary groups to probe (entity/stem only appear for MedQA).
    PRIMARY_GROUPS = ("delimiter", "choice", "content")

    # Loop-invariant: candidate answer-letter ids depend only on the tokenizer.
    letter_ids = self._answer_letter_token_ids(tokenizer)
    letter_id_set = set(letter_ids)

    per_sample_results: List[Dict] = []
    # group -> importance scores from successful masked passes.
    group_importances: Dict[str, List[float]] = {g: [] for g in PRIMARY_GROUPS}
    # group -> correctness of samples where that group was dominant.
    accuracy_by_dominant: Dict[str, List[bool]] = {g: [] for g in PRIMARY_GROUPS}
    processed = 0

    for sample in tqdm(samples, desc="Token-group contrast"):
        answer_tok_id = self._answer_token_id(tokenizer, sample.label)
        if answer_tok_id is None:
            tqdm.write(f" [skip] sample {sample.idx}: cannot resolve answer token")
            continue

        metadata = sample.metadata or {}
        prompt_str = self._build_prompt(prompt_strategy, sample.text, metadata)
        tokens = self._tokenize(tokenizer, prompt_str, backend.device)
        input_ids = tokens["input_ids"][0]  # (seq_len,)

        try:
            groups = self._tag_tokens(input_ids, tokenizer, metadata)
        except Exception as exc:
            tqdm.write(f" [skip] sample {sample.idx} (tagging): {exc}")
            continue

        # Baseline forward (no masking); correctness comes from the same pass.
        try:
            with torch.no_grad():
                out_base = backend._model(**tokens)
            last_logits = out_base.logits[0, -1].detach().float().cpu()
            logit_base = float(last_logits[answer_tok_id].item())
            is_correct: Optional[bool]
            if answer_tok_id in letter_id_set:
                best_letter_tok = max(letter_ids, key=lambda t: last_logits[t].item())
                is_correct = best_letter_tok == answer_tok_id
            else:
                # The label is not a single answer-letter token, so a letter
                # argmax cannot say whether the model is right.
                is_correct = None
            del out_base, last_logits
        except Exception as exc:
            tqdm.write(f" [skip] sample {sample.idx} (baseline): {exc}")
            torch.cuda.empty_cache()
            continue

        sample_importances: Dict[str, Optional[float]] = {}
        for group in PRIMARY_GROUPS:
            importance = self._masked_importance(
                backend,
                tokens,
                mask_layer,
                groups.get(group, []),
                answer_tok_id,
                logit_base,
                f"sample {sample.idx} group '{group}'",
            )
            if importance is None:
                sample_importances[group] = None
            else:
                sample_importances[group] = round(importance, 4)
                group_importances[group].append(importance)

        scored = {g: v for g, v in sample_importances.items() if v is not None}
        dominant = max(scored, key=lambda g: scored[g]) if scored else None
        if dominant is not None and is_correct is not None:
            accuracy_by_dominant[dominant].append(is_correct)

        # MedQA entity/stem breakdown (logged per sample, not aggregated).
        sub_importances: Dict[str, Optional[float]] = {}
        for sub in ("entity", "stem"):
            value: Optional[float] = None
            if groups.get(sub):
                raw = self._masked_importance(
                    backend,
                    tokens,
                    mask_layer,
                    groups[sub],
                    answer_tok_id,
                    logit_base,
                    f"sample {sample.idx} group '{sub}'",
                )
                value = None if raw is None else round(raw, 4)
            sub_importances[sub] = value

        per_sample_results.append(
            {
                "sample_idx": sample.idx,
                "is_correct": is_correct,
                "logit_base": round(logit_base, 4),
                "dominant_group": dominant,
                "group_importances": sample_importances,
                "token_counts": {g: len(groups.get(g, [])) for g in PRIMARY_GROUPS},
                "entity_importance": sub_importances["entity"],
                "stem_importance": sub_importances["stem"],
            }
        )
        processed += 1

    # ── Aggregate ─────────────────────────────────────────────────
    mean_importance: Dict[str, float] = {
        g: round(sum(v) / len(v), 4) if v else 0.0 for g, v in group_importances.items()
    }
    # With no processed samples every mean is a placeholder 0.0; report no winner.
    dominant_group_overall = (
        max(mean_importance, key=lambda g: mean_importance[g]) if processed else None
    )

    acc_by_dom: Dict[str, Optional[float]] = {
        g: round(sum(hits) / len(hits), 4) if hits else None
        for g, hits in accuracy_by_dominant.items()
    }

    correlations = self._compute_correlations(per_sample_results)

    self._print_token_group_summary(
        mask_layer,
        processed,
        n,
        dominant_group_overall,
        PRIMARY_GROUPS,
        mean_importance,
        acc_by_dom,
        correlations,
    )

    return ExperimentResult(
        experiment_name=self.name,
        model_name=backend.model_name,
        prompt_strategy=(
            prompt_strategy.name if hasattr(prompt_strategy, "name") else "custom"
        ),
        metrics={
            "num_samples": processed,
            "mask_layer": mask_layer,
            "mean_importance_per_group": mean_importance,
            "dominant_group": dominant_group_overall,
            "accuracy_when_dominant": acc_by_dom,
            "point_biserial_correlations": correlations,
        },
        raw_outputs={"per_sample": per_sample_results},
        metadata={
            "mask_layer": mask_layer,
            "token_group_mode": self.token_group_mode,
            "num_samples": processed,
            "seed": self.seed,
            "answer_cue": self.answer_cue,
        },
    )


def _masked_importance(
    self,
    backend,
    tokens,
    mask_layer: int,
    positions: List[int],
    answer_tok_id: int,
    logit_base: float,
    context: str,
) -> Optional[float]:
    """Return |logit_base - masked logit| for one position group, or None if the pass fails.

    Failures are reported with ``tqdm.write`` using ``context`` and treated as
    missing data rather than zero importance, so they do not bias group means.
    The CUDA cache is emptied after every attempt.
    """
    try:
        logit_masked = self._forward_attention_masked(
            backend, tokens, mask_layer, positions, answer_tok_id
        )
    except Exception as exc:  # best-effort: one failed group must not abort the sample
        tqdm.write(f" [skip] {context}: {exc}")
        return None
    finally:
        torch.cuda.empty_cache()
    return abs(logit_base - logit_masked)


def _print_token_group_summary(
    self,
    mask_layer: int,
    processed: int,
    n: int,
    dominant_group: Optional[str],
    primary_groups,
    mean_importance: Dict[str, float],
    acc_by_dom: Dict[str, Optional[float]],
    correlations: Dict[str, Any],
) -> None:
    """Print the human-readable summary table for ``token_group_contrast``."""
    print("\n" + "=" * 70)
    print(f"TOKEN GROUP CONTRAST — L{mask_layer} attention masking")
    print("=" * 70)
    print(f"Processed samples : {processed} / {n}")
    print(f"Dominant group (avg): {dominant_group}")
    print()
    print(f"{'Group':<12} {'Mean Importance':>16} {'Acc when dominant':>18}")
    print("-" * 52)
    for g in primary_groups:
        acc_str = f"{acc_by_dom[g]:.4f}" if acc_by_dom[g] is not None else " n/a "
        print(f"{g:<12} {mean_importance[g]:>16.4f} {acc_str:>18}")

    if correlations:
        print()
        print("Point-biserial correlations (importance score → is_correct):")
        print(
            f" {'Group':<12} {'r':>7} {'p':>8} {'n':>5} {'mean(corr)':>11} {'mean(incorr)':>12}"
        )
        print(" " + "-" * 60)
        for g, c in correlations.items():
            sig = "*" if c["p_value"] < 0.05 else " "
            mc = (
                f"{c['mean_importance_correct']:.4f}"
                if c["mean_importance_correct"] is not None
                else " n/a "
            )
            mi = (
                f"{c['mean_importance_incorrect']:.4f}"
                if c["mean_importance_incorrect"] is not None
                else " n/a "
            )
            print(
                f" {g:<12} {c['r']:>+7.4f} {c['p_value']:>8.4f}{sig} {c['n']:>5} {mc:>11} {mi:>12}"
            )
        print(" (* p<0.05)")

    print("=" * 70)
    print()
    print("Interpretation:")
    print(
        f" Higher importance = removing this group from L{mask_layer} "
        "attention shifts the answer more."
    )
    print(" The dominant group is what the layer causally relies on most.")
    if correlations:
        print(" Positive r = higher importance of this group → more likely correct.")
        print(" Negative r = higher importance of this group → more likely incorrect.")
802
+
803
+ # ------------------------------------------------------------------
804
+ # Prompt builders (used by the main entry point, run(), below)
805
+ # ------------------------------------------------------------------
806
+
807
+ def _build_prompt(self, prompt_strategy: Any, text: str, metadata: dict) -> str:
808
+ return (
809
+ prompt_strategy.build_prompt(
810
+ {
811
+ "text": text,
812
+ "question": text,
813
+ "report": text,
814
+ "metadata": metadata,
815
+ }
816
+ )
817
+ + self.answer_cue
818
+ )
819
+
820
+ def _build_prompt_few_shot(
821
+ self, prompt_strategy: Any, text: str, metadata: dict, few_shot: bool
822
+ ) -> str:
823
+ """Build prompt with few_shot toggled — restores original value afterwards."""
824
+ orig = getattr(prompt_strategy, "few_shot", None)
825
+ try:
826
+ if hasattr(prompt_strategy, "few_shot"):
827
+ prompt_strategy.few_shot = few_shot
828
+ return self._build_prompt(prompt_strategy, text, metadata)
829
+ finally:
830
+ if orig is not None:
831
+ prompt_strategy.few_shot = orig
832
+
833
+ def _build_prompt_introspect(
834
+ self, prompt_strategy: Any, text: str, metadata: dict, introspect: bool
835
+ ) -> str:
836
+ """Build prompt with introspect instruction appended (clean) or omitted (corrupt).
837
+
838
+ clean (introspect=True) → standard prompt + introspect_instruction prepended
839
+ corrupt (introspect=False) → standard prompt only (no instruction)
840
+
841
+ few_shot is kept at whatever the prompt strategy has configured so that
842
+ the only variable between clean and corrupt is the introspect wording.
843
+ """
844
+ base = self._build_prompt(prompt_strategy, text, metadata)
845
+ if introspect:
846
+ # Prepend the instruction before the main prompt body so it sets
847
+ # the reasoning intent from the first token.
848
+ return self.introspect_instruction + "\n\n" + base
849
+ return base
850
+
851
+ def _build_prompt_cot(self, prompt_strategy: Any, text: str, metadata: dict, cot: bool) -> str:
852
+ """Build prompt with CoT trigger active (clean) or stripped (corrupt).
853
+
854
+ clean (cot=True) → full CoT prompt with cot_trigger intact
855
+ corrupt (cot=False) → same prompt with cot_trigger set to "" (zero-shot)
856
+
857
+ Only the cot_trigger attribute is toggled; few_shot and all other
858
+ strategy settings are preserved so CoT is the sole variable.
859
+ """
860
+ orig = getattr(prompt_strategy, "cot_trigger", None)
861
+ try:
862
+ if hasattr(prompt_strategy, "cot_trigger"):
863
+ prompt_strategy.cot_trigger = orig if cot else ""
864
+ return self._build_prompt(prompt_strategy, text, metadata)
865
+ finally:
866
+ if orig is not None:
867
+ prompt_strategy.cot_trigger = orig
868
+
869
def run(
    self,
    backend: InferenceBackend,
    dataset: BaseDataset,
    prompt_strategy: Any,
    logger: Optional[ExperimentLogger] = None,
    **kwargs,
) -> ExperimentResult:
    """Run activation patching experiment.

    Dispatches to the token_group_contrast branch when
    ``patching_mode == 'token_group_contrast'``; otherwise runs the standard
    layer-sweep residual patching:

    1. Build a clean and a corrupted prompt per sample. ``pairs`` uses
       ``metadata["corrupted_prompt"]``; ``few_shot_contrast``,
       ``introspect_contrast`` toggle the corresponding prompt feature; any
       other mode is treated as ``cot_contrast``.
    2. Run the clean prompt once, caching activations at every target layer.
    3. Run the corrupted prompt once as the baseline.
    4. For each target layer, re-run the corrupted prompt with that layer's
       clean activation patched in and score logit recovery on the gold
       answer token::

           effect = (patched - corrupt) / (clean - corrupt), clipped to [-1, 2]

       When ``|clean - corrupt|`` is below a small epsilon the ratio is
       undefined and every layer is scored 0.0 for that sample.

    Samples whose answer token cannot be resolved, whose pairs-mode metadata
    lacks ``corrupted_prompt``, or whose baseline forward passes raise are
    skipped and not counted in ``num_samples``.

    NOTE(review): clean and corrupted prompts generally differ in length;
    how ``_forward_patched`` aligns cached clean activations with corrupted
    positions is defined elsewhere — confirm it matches the intended patch site.

    Args:
        backend: inference backend; ``_tokenizer``, ``device`` and
            ``model_name`` are read here.
        dataset: samples are drawn via ``dataset.sample(num_samples, seed=seed)``.
        prompt_strategy: object whose ``build_prompt`` renders prompt text.
        logger: forwarded only to the token_group_contrast branch.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        ExperimentResult with the mean effect per layer, the top-5 layers by
        mean effect, and per-sample raw outputs.
    """

    tokenizer = backend._tokenizer

    # ── Dispatch to token_group_contrast mode ─────────────────────
    if self.patching_mode == "token_group_contrast":
        return self._run_token_group_contrast(backend, dataset, prompt_strategy, logger)

    # ── Standard residual patching modes ──────────────────────────
    target_layers = self._resolve_layers(backend)

    print(f"Model : {backend.model_name}")
    print(f"Patching mode: {self.patching_mode}")
    print(f"Layers ({len(target_layers)}): {target_layers}")
    print(f"Stride : {self.layer_stride} | max_input_tokens: {self.max_input_tokens}")

    samples = dataset.sample(self.num_samples, seed=self.seed)
    n = len(samples)
    print(f"Samples: {n} (each requires {len(target_layers) + 2} forward passes)\n")

    # Logit gaps smaller than this make the recovery ratio meaningless.
    gap_eps = 1e-6

    # Per-layer effect accumulators
    layer_effects: Dict[int, List[float]] = {lid: [] for lid in target_layers}
    per_sample_results: List[Dict] = []
    processed = 0

    for sample in tqdm(samples, desc="Activation patching"):
        clean_tok_id = self._answer_token_id(tokenizer, sample.label)
        if clean_tok_id is None:
            tqdm.write(f" [skip] sample {sample.idx}: cannot resolve answer token")
            continue

        # Samples may carry metadata=None; normalise once so every mode,
        # including "pairs", can read it without an AttributeError.
        metadata = sample.metadata or {}

        # ── Build clean / corrupted prompt strings based on mode ──────
        if self.patching_mode == "pairs":
            corrupted_prompt = metadata.get("corrupted_prompt")
            if not corrupted_prompt:
                tqdm.write(f" [skip] sample {sample.idx}: no corrupted_prompt in metadata")
                continue
            clean_str = self._build_prompt(prompt_strategy, sample.text, metadata)
            corr_str = self._build_prompt(prompt_strategy, corrupted_prompt, {})
        elif self.patching_mode == "few_shot_contrast":
            # few-shot = clean (more context → better answer representation)
            # zero-shot = corrupted
            clean_str = self._build_prompt_few_shot(
                prompt_strategy, sample.text, metadata, few_shot=True
            )
            corr_str = self._build_prompt_few_shot(
                prompt_strategy, sample.text, metadata, few_shot=False
            )
        elif self.patching_mode == "introspect_contrast":
            # introspect instruction prepended = clean
            # no instruction = corrupted
            clean_str = self._build_prompt_introspect(
                prompt_strategy, sample.text, metadata, introspect=True
            )
            corr_str = self._build_prompt_introspect(
                prompt_strategy, sample.text, metadata, introspect=False
            )
        else:  # cot_contrast (also the fallback for any unrecognised mode)
            # CoT trigger active = clean
            # CoT trigger stripped (zero-shot) = corrupted
            clean_str = self._build_prompt_cot(prompt_strategy, sample.text, metadata, cot=True)
            corr_str = self._build_prompt_cot(prompt_strategy, sample.text, metadata, cot=False)

        clean_tokens = self._tokenize(tokenizer, clean_str, backend.device)
        corr_tokens = self._tokenize(tokenizer, corr_str, backend.device)

        try:
            # Step 1 — clean forward, caching activations at every target layer
            logits_clean, act_cache = self._forward_with_cache(
                backend, clean_tokens, target_layers
            )
            # Step 2 — corrupted baseline: plain forward, no layers cached
            logits_corr, _ = self._forward_with_cache(backend, corr_tokens, [])
        except Exception as exc:
            tqdm.write(f" [skip] sample {sample.idx} (baseline): {type(exc).__name__}: {exc}")
            torch.cuda.empty_cache()
            continue

        clean_logit = float(logits_clean[clean_tok_id].item())
        corr_logit = float(logits_corr[clean_tok_id].item())
        denom = clean_logit - corr_logit  # may be 0 or negative
        # Layer-invariant: with (near-)identical clean/corrupt logits every
        # layer is scored 0.0 for this sample.
        gap_too_small = abs(denom) < gap_eps

        sample_layer_effects: Dict[int, float] = {}

        # Step 3 — patching sweep over layers
        for layer_idx in target_layers:
            if layer_idx not in act_cache:
                continue
            try:
                logits_patch = self._forward_patched(
                    backend, corr_tokens, layer_idx, act_cache[layer_idx]
                )
            except Exception as exc:
                tqdm.write(f" [skip] sample {sample.idx} layer {layer_idx}: {exc}")
                torch.cuda.empty_cache()
                continue

            patch_logit = float(logits_patch[clean_tok_id].item())
            if gap_too_small:
                effect = 0.0
            else:
                effect = (patch_logit - corr_logit) / denom
                # Clip to [-1, 2] to handle outliers
                effect = max(-1.0, min(2.0, effect))
            layer_effects[layer_idx].append(effect)
            sample_layer_effects[layer_idx] = round(effect, 4)
            torch.cuda.empty_cache()

        per_sample_results.append(
            {
                "sample_idx": sample.idx,
                "clean_logit": round(clean_logit, 4),
                "corrupt_logit": round(corr_logit, 4),
                "logit_gap": round(denom, 4),
                "layer_effects": sample_layer_effects,
            }
        )
        processed += 1
        torch.cuda.empty_cache()

    # --- Aggregate --------------------------------------------------
    mean_effects: Dict[int, float] = {}
    for layer_idx in target_layers:
        vals = layer_effects[layer_idx]
        mean_effects[layer_idx] = round(sum(vals) / len(vals), 4) if vals else 0.0

    sorted_by_effect = sorted(mean_effects.items(), key=lambda x: x[1], reverse=True)
    top_5_layers = [lid for lid, _ in sorted_by_effect[:5]]

    # --- Print summary -----------------------------------------------
    print("\n" + "=" * 62)
    print("ACTIVATION PATCHING SUMMARY (logit-recovery effect)")
    print("=" * 62)
    print(f"Processed samples : {processed} / {n}")
    print(f"Top-5 causal layers: {top_5_layers}")
    print()
    print(f"{'Layer':>6} {'Mean Effect':>12} {'N samples':>10}")
    print("-" * 34)
    for layer_idx in target_layers:
        n_val = len(layer_effects[layer_idx])
        print(f"{layer_idx:>6} {mean_effects[layer_idx]:>12.4f} {n_val:>10}")
    print("=" * 62)

    return ExperimentResult(
        experiment_name=self.name,
        model_name=backend.model_name,
        prompt_strategy=(
            prompt_strategy.name if hasattr(prompt_strategy, "name") else "custom"
        ),
        metrics={
            "num_samples": processed,
            "layer_stride": self.layer_stride,
            "mean_effect_per_layer": mean_effects,
            "top_5_causal_layers": top_5_layers,
        },
        raw_outputs={"per_sample": per_sample_results},
        metadata={
            "target_layers": target_layers,
            "layer_stride": self.layer_stride,
            "num_samples": processed,
            "seed": self.seed,
            "answer_cue": self.answer_cue,
        },
    )