cotlab 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65):
  1. cotlab/__init__.py +3 -0
  2. cotlab/analyse_experiments.py +392 -0
  3. cotlab/analysis/__init__.py +11 -0
  4. cotlab/analysis/cot_parser.py +243 -0
  5. cotlab/analysis/faithfulness_metrics.py +192 -0
  6. cotlab/backends/__init__.py +16 -0
  7. cotlab/backends/base.py +78 -0
  8. cotlab/backends/transformers_backend.py +335 -0
  9. cotlab/backends/vllm_backend.py +227 -0
  10. cotlab/cli.py +83 -0
  11. cotlab/core/__init__.py +34 -0
  12. cotlab/core/base.py +749 -0
  13. cotlab/core/config.py +90 -0
  14. cotlab/core/registry.py +68 -0
  15. cotlab/datasets/__init__.py +45 -0
  16. cotlab/datasets/loaders.py +1889 -0
  17. cotlab/experiment/__init__.py +315 -0
  18. cotlab/experiments/__init__.py +43 -0
  19. cotlab/experiments/activation_compare.py +290 -0
  20. cotlab/experiments/activation_patching.py +1050 -0
  21. cotlab/experiments/attention_analysis.py +885 -0
  22. cotlab/experiments/classification.py +235 -0
  23. cotlab/experiments/composite_shift_detector.py +524 -0
  24. cotlab/experiments/cot_ablation.py +277 -0
  25. cotlab/experiments/cot_faithfulness.py +187 -0
  26. cotlab/experiments/cot_heads.py +208 -0
  27. cotlab/experiments/full_layer_cot.py +232 -0
  28. cotlab/experiments/full_layer_patching.py +225 -0
  29. cotlab/experiments/h_neuron_analysis.py +712 -0
  30. cotlab/experiments/logit_lens.py +439 -0
  31. cotlab/experiments/multi_head_cot.py +220 -0
  32. cotlab/experiments/multi_head_patching.py +229 -0
  33. cotlab/experiments/probing_classifier.py +402 -0
  34. cotlab/experiments/residual_norm_ood.py +413 -0
  35. cotlab/experiments/sae_feature_analysis.py +673 -0
  36. cotlab/experiments/steering_vectors.py +223 -0
  37. cotlab/experiments/sycophancy_heads.py +224 -0
  38. cotlab/logging/__init__.py +5 -0
  39. cotlab/logging/json_logger.py +161 -0
  40. cotlab/main.py +317 -0
  41. cotlab/patching/__init__.py +24 -0
  42. cotlab/patching/cache.py +141 -0
  43. cotlab/patching/hooks.py +558 -0
  44. cotlab/patching/interventions.py +86 -0
  45. cotlab/patching/patcher.py +439 -0
  46. cotlab/patching/sae.py +181 -0
  47. cotlab/prompts/__init__.py +43 -0
  48. cotlab/prompts/cardiology.py +378 -0
  49. cotlab/prompts/histopathology.py +265 -0
  50. cotlab/prompts/length_matched_strategies.py +157 -0
  51. cotlab/prompts/mcq.py +193 -0
  52. cotlab/prompts/neurology.py +353 -0
  53. cotlab/prompts/oncology.py +367 -0
  54. cotlab/prompts/plab.py +162 -0
  55. cotlab/prompts/pubhealthbench.py +82 -0
  56. cotlab/prompts/pubmedqa.py +173 -0
  57. cotlab/prompts/radiology.py +414 -0
  58. cotlab/prompts/strategies.py +939 -0
  59. cotlab/prompts/tcga.py +168 -0
  60. cotlab/runner.py +204 -0
  61. cotlab-0.8.0.dist-info/METADATA +166 -0
  62. cotlab-0.8.0.dist-info/RECORD +65 -0
  63. cotlab-0.8.0.dist-info/WHEEL +4 -0
  64. cotlab-0.8.0.dist-info/entry_points.txt +3 -0
  65. cotlab-0.8.0.dist-info/licenses/LICENSE +21 -0
cotlab/patching/patcher.py ADDED
@@ -0,0 +1,439 @@
1
+ """Activation patcher for causal intervention experiments."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Iterable, List, Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from .cache import ActivationCache
9
+ from .hooks import HookManager
10
+
11
+
12
@dataclass
class PatchingResult:
    """Outcome of one patched (or ablated) generation run.

    ``original_answer`` / ``patched_answer`` are optional and are filled in by
    callers that also ran an unpatched baseline.
    """

    output_text: str
    patched_layers: List[int]
    patched_positions: Optional[List[int]]
    original_answer: Optional[str] = None
    patched_answer: Optional[str] = None

    @property
    def answer_changed(self) -> bool:
        """True when both answers are present and differ after stripping whitespace.

        Returns False whenever either answer is missing.
        """
        pair = (self.original_answer, self.patched_answer)
        if any(answer is None for answer in pair):
            return False
        before, after = (answer.strip() for answer in pair)
        return before != after
28
+
29
+
30
@dataclass
class ForwardPatchResult:
    """Result from forward-only patching (no generation).

    The KL properties read the final position via ``[:, -1, :]`` indexing, so
    the logits tensors must be 3-D with the position axis second (presumably
    ``[batch, seq, vocab]`` — confirm against the backend).
    """

    layer_idx: int  # Layer whose activations were patched
    clean_logits: torch.Tensor
    corrupted_logits: torch.Tensor
    patched_logits: torch.Tensor
    effect_score: float  # How much patching recovered clean behavior

    @staticmethod
    def _last_token_kl(target_logits: torch.Tensor, input_logits: torch.Tensor) -> float:
        """Return KL(target || input) over the vocabulary at the last position.

        The input distribution is built with ``log_softmax`` rather than
        ``softmax(...).log()``. The latter underflows to ``-inf`` for very
        unlikely tokens and turns the divergence into ``inf``/``nan``. Both
        sides are upcast to float32 so half-precision logits behave the same.
        ``reduction="batchmean"`` divides the summed divergence by the batch size.
        """
        input_log_probs = torch.log_softmax(input_logits[:, -1, :].float(), dim=-1)
        target_probs = torch.softmax(target_logits[:, -1, :].float(), dim=-1)
        return torch.nn.functional.kl_div(
            input_log_probs, target_probs, reduction="batchmean"
        ).item()

    @property
    def kl_from_clean(self) -> float:
        """KL divergence of patched from clean distribution, i.e. KL(clean || patched)."""
        return self._last_token_kl(self.clean_logits, self.patched_logits)

    @property
    def kl_from_corrupted(self) -> float:
        """KL divergence of patched from corrupted distribution, i.e. KL(corrupted || patched)."""
        return self._last_token_kl(self.corrupted_logits, self.patched_logits)
57
+
58
+
59
class ActivationPatcher:
    """
    Perform activation patching interventions for causal analysis.

    Activation patching is a technique to test causal importance of
    specific model components by replacing activations from one run
    with activations from another run.

    Hook lifetime: every method that registers hooks does so inside a ``try``
    block and calls ``hook_manager.remove_all_hooks()`` in ``finally``.
    NOTE(review): ``remove_all_hooks`` presumably clears every hook on the
    manager, including hooks registered by other code — confirm before mixing
    this patcher with other hook users on the same backend.

    Example:
        >>> patcher = ActivationPatcher(backend)
        >>> clean_out, clean_cache = backend.generate_with_cache(clean_prompt)
        >>> results = patcher.patch_run(
        ...     corrupted_prompt,
        ...     source_cache=clean_cache,
        ...     target_layers=[5, 10, 15]
        ... )
    """

    def __init__(self, backend: "TransformersBackend"):
        """
        Initialize patcher with a backend.

        Args:
            backend: TransformersBackend with hook support (must expose
                ``supports_activations`` and ``hook_manager``).

        Raises:
            ValueError: If ``backend.supports_activations`` is falsy.
        """
        if not backend.supports_activations:
            raise ValueError("Backend does not support activation access")

        self.backend = backend

    @property
    def hook_manager(self) -> "HookManager":
        """The backend's hook manager, used to register and remove hooks."""
        return self.backend.hook_manager

    def patch_run(
        self,
        prompt: str,
        source_cache: "ActivationCache",
        target_layers: List[int],
        token_positions: Optional[List[int]] = None,
        **gen_kwargs,
    ) -> "PatchingResult":
        """
        Run generation while patching activations from source_cache.

        This is the core patching operation: run the model on `prompt`
        but replace activations at `target_layers` with values from
        `source_cache` (typically from a "clean" run).

        Args:
            prompt: Input prompt (typically the "corrupted" version)
            source_cache: Activations from another run to patch in
            target_layers: Which layers to patch
            token_positions: Which token positions to patch (None = all)
            **gen_kwargs: Additional generation arguments

        Returns:
            PatchingResult with output and metadata

        Raises:
            ValueError: If any target layer is missing from ``source_cache``.
                This is checked before any hook is registered.
        """
        # Resolve every source activation before touching the hook manager.
        # Otherwise a missing later layer would leave the earlier hooks
        # registered, and they would silently patch every subsequent run.
        layer_activations = []
        for layer_idx in target_layers:
            source_activation = source_cache.get(layer_idx)
            if source_activation is None:
                raise ValueError(f"Layer {layer_idx} not in source cache")
            layer_activations.append((layer_idx, source_activation))

        try:
            for layer_idx, source_activation in layer_activations:
                self.hook_manager.register_patch_hook(
                    layer_idx, source_activation, token_positions
                )
            output = self.backend.generate(prompt, **gen_kwargs)
        finally:
            self.hook_manager.remove_all_hooks()

        return PatchingResult(
            output_text=output.text,
            # Copy so later mutation of the caller's list does not alter the result.
            patched_layers=list(target_layers),
            patched_positions=token_positions,
        )

    def sweep_layers(
        self,
        clean_prompt: str,
        corrupted_prompt: str,
        layers: Optional[List[int]] = None,
        **gen_kwargs,
    ) -> Dict[int, "ForwardPatchResult"]:
        """
        Sweep patching across layers to find causal importance.

        For each layer, patches that layer's residual-stream activations from
        the clean run into the corrupted run (single forward pass, no
        generation) and measures the effect on the last-token logits.

        Args:
            clean_prompt: Prompt that elicits desired behavior
            corrupted_prompt: Prompt that elicits different behavior
            layers: Which layers to sweep (None = every layer in the clean cache)
            **gen_kwargs: Accepted for API compatibility; unused because only
                forward passes are run.

        Returns:
            Dict mapping layer_idx -> ForwardPatchResult
        """
        # Get clean logits and activations
        clean_logits, clean_cache = self.backend.forward_with_cache(clean_prompt, layers=layers)

        # Get corrupted logits (no patching)
        corrupted_logits, _ = self.backend.forward_with_cache(corrupted_prompt, layers=layers)

        target_layers = layers if layers is not None else clean_cache.layers

        # The baselines are identical for every layer; move them to CPU once.
        clean_cpu = clean_logits.cpu()
        corrupted_cpu = corrupted_logits.cpu()

        results = {}
        for layer_idx in target_layers:
            patched_logits = self._forward_with_patch(corrupted_prompt, clean_cache, layer_idx)
            effect = self._compute_logit_effect(clean_logits, corrupted_logits, patched_logits)
            results[layer_idx] = ForwardPatchResult(
                layer_idx=layer_idx,
                clean_logits=clean_cpu,
                corrupted_logits=corrupted_cpu,
                patched_logits=patched_logits.cpu(),
                effect_score=effect,
            )

        return results

    def sweep_heads(
        self,
        clean_prompt: str,
        corrupted_prompt: str,
        layers: Optional[List[int]] = None,
        head_indices: Optional[Iterable[int]] = None,
        target_heads: Optional[Dict[int, Iterable[int]]] = None,
    ) -> Dict[Tuple[int, int], "ForwardPatchResult"]:
        """
        Sweep patching across attention heads to find causal importance.

        Each (layer, head) pair is patched individually from the clean
        attention cache into a forward pass on the corrupted prompt.

        Args:
            clean_prompt: Prompt that elicits desired behavior
            corrupted_prompt: Prompt that elicits different behavior
            layers: Which layers to sweep (None = all). Ignored when
                ``target_heads`` is given.
            head_indices: Heads to patch for all layers (e.g., [0,1,2]). Any
                iterable is accepted, including one-shot iterators.
            target_heads: Dict mapping layer -> list of head indices

        Returns:
            Dict mapping (layer_idx, head_idx) -> ForwardPatchResult

        Raises:
            ValueError: If neither or both of ``head_indices`` / ``target_heads``
                are given, if any head index is out of range, or if head info
                cannot be read from the model config. All of these are checked
                before any forward pass runs.
        """
        if head_indices is None and not target_heads:
            raise ValueError("Provide head_indices or target_heads for head patching.")
        if head_indices is not None and target_heads:
            raise ValueError("Use either head_indices or target_heads, not both.")

        num_heads, head_dim = self._get_head_info()

        # Materialise the head lists once. Re-listing a one-shot iterator per
        # layer would leave every layer after the first with no heads.
        if target_heads:
            heads_per_layer = {layer: list(heads) for layer, heads in target_heads.items()}
            shared_heads = []
            layers_to_sweep = list(heads_per_layer)
            requested = [head for heads in heads_per_layer.values() for head in heads]
        else:
            heads_per_layer = None
            shared_heads = list(head_indices)
            layers_to_sweep = layers
            requested = shared_heads

        # Validate up front instead of failing midway through an expensive sweep.
        for head_idx in requested:
            if head_idx < 0 or head_idx >= num_heads:
                raise ValueError(f"Head index {head_idx} out of range (0..{num_heads - 1})")

        # Get clean logits and attention outputs
        clean_logits, clean_cache = self.backend.forward_with_attention_cache(
            clean_prompt, layers=layers_to_sweep
        )

        # Get corrupted logits (no patching)
        corrupted_logits, _ = self.backend.forward_with_cache(corrupted_prompt, layers=[])

        if layers_to_sweep is None:
            layers_to_sweep = clean_cache.layers

        clean_cpu = clean_logits.cpu()
        corrupted_cpu = corrupted_logits.cpu()

        results = {}
        for layer_idx in layers_to_sweep:
            heads = shared_heads if heads_per_layer is None else heads_per_layer[layer_idx]
            for head_idx in heads:
                patched_logits = self._forward_with_head_patch(
                    corrupted_prompt, clean_cache, layer_idx, [head_idx], head_dim
                )
                effect = self._compute_logit_effect(clean_logits, corrupted_logits, patched_logits)
                results[(layer_idx, head_idx)] = ForwardPatchResult(
                    layer_idx=layer_idx,
                    clean_logits=clean_cpu,
                    corrupted_logits=corrupted_cpu,
                    patched_logits=patched_logits.cpu(),
                    effect_score=effect,
                )

        return results

    def _forward_with_patch(
        self,
        prompt: str,
        source_cache: "ActivationCache",
        layer_idx: int,
    ) -> torch.Tensor:
        """Run a single forward pass with patching at specified layer's residual stream.

        Raises:
            ValueError: If ``layer_idx`` is missing from ``source_cache``.
        """
        source_activation = source_cache.get(layer_idx)
        if source_activation is None:
            raise ValueError(f"Layer {layer_idx} not in source cache")

        try:
            # Residual-stream patch across all positions (positions=None).
            # Registration is inside ``try`` so a partial registration is also cleaned up.
            self.hook_manager.register_residual_patch_hook(layer_idx, source_activation, None)
            logits, _ = self.backend.forward_with_cache(prompt, layers=[])
        finally:
            self.hook_manager.remove_all_hooks()

        return logits

    def _forward_with_head_patch(
        self,
        prompt: str,
        source_cache: "ActivationCache",
        layer_idx: int,
        head_indices: List[int],
        head_dim: int,
    ) -> torch.Tensor:
        """Run forward pass with attention head patching at a specific layer.

        Raises:
            ValueError: If ``layer_idx`` is missing from ``source_cache``.
        """
        source_activation = source_cache.get(layer_idx)
        if source_activation is None:
            raise ValueError(f"Layer {layer_idx} not in source cache")

        try:
            self.hook_manager.register_multi_head_patch_hook(
                layer_idx, head_indices, source_activation, head_dim
            )
            logits, _ = self.backend.forward_with_cache(prompt, layers=[])
        finally:
            self.hook_manager.remove_all_hooks()

        return logits

    def _get_head_info(self) -> Tuple[int, int]:
        """Return (num_heads, head_dim) from the model config.

        ``head_dim`` falls back to ``hidden_size // num_heads`` when the config
        does not define it.

        Raises:
            ValueError: If the backend has no model, or the values cannot be
                determined from its config.
        """
        model = getattr(self.backend, "model", None)
        if model is None:
            raise ValueError("Backend has no model; cannot read attention head configuration.")
        cfg = model.config

        # Multimodal models nest the language-model settings under text_config.
        # Guard against the attribute existing but being None.
        text_cfg = getattr(cfg, "text_config", None)
        if text_cfg is not None:
            cfg = text_cfg

        num_heads = (
            getattr(cfg, "num_attention_heads", None)
            or getattr(cfg, "n_head", None)
            or getattr(cfg, "num_heads", None)
        )
        head_dim = getattr(cfg, "head_dim", None)
        hidden_size = (
            getattr(cfg, "hidden_size", None)
            or getattr(cfg, "hidden_dim", None)
            or getattr(cfg, "d_model", None)
        )

        if head_dim is None and num_heads and hidden_size:
            head_dim = hidden_size // num_heads

        if not num_heads or not head_dim:
            raise ValueError(
                "Could not determine head information from model config "
                f"(num_heads={num_heads}, head_dim={head_dim})."
            )

        return int(num_heads), int(head_dim)

    def _compute_logit_effect(
        self,
        clean_logits: torch.Tensor,
        corrupted_logits: torch.Tensor,
        patched_logits: torch.Tensor,
    ) -> float:
        """
        Compute how much patching moves logits from corrupted toward clean.

        Uses last-token logits and Euclidean distances:
        ``effect = 1 - ||patched - clean|| / ||clean - corrupted||``, clamped to
        [0, 1]. This yields 0 when patched equals corrupted (no effect) and 1
        when patched equals clean (full recovery). Returns 0.0 when the clean
        and corrupted logits coincide.
        """
        clean_last = clean_logits[:, -1, :].float()
        corrupted_last = corrupted_logits[:, -1, :].float()
        patched_last = patched_logits[:, -1, :].float()

        clean_to_corrupted = torch.norm(clean_last - corrupted_last).item()
        if clean_to_corrupted == 0:
            return 0.0

        # Distance remaining to the clean logits. Measuring against *corrupted*
        # instead would invert the score (no effect -> 1, full recovery -> 0).
        patched_to_clean = torch.norm(patched_last - clean_last).item()

        effect = 1.0 - (patched_to_clean / clean_to_corrupted)
        return max(0.0, min(1.0, effect))

    def sweep_positions(
        self,
        clean_prompt: str,
        corrupted_prompt: str,
        layer_idx: int,
        positions: Optional[List[int]] = None,
        **gen_kwargs,
    ) -> Dict[int, "PatchingResult"]:
        """
        Sweep patching across token positions at a fixed layer.

        Args:
            clean_prompt: Prompt that elicits desired behavior
            corrupted_prompt: Prompt that elicits different behavior
            layer_idx: Which layer to patch
            positions: Which positions to sweep (None = every position of the
                cached clean activation)
            **gen_kwargs: Generation arguments

        Returns:
            Dict mapping position -> PatchingResult, with ``original_answer``
            set to the unpatched corrupted output.

        Raises:
            ValueError: If ``layer_idx`` is missing from the clean cache.
        """
        # Get clean activations
        _, clean_cache = self.backend.generate_with_cache(
            clean_prompt, layers=[layer_idx], **gen_kwargs
        )

        # Determine positions to sweep (checked before the corrupted baseline
        # run so a missing layer fails fast with a clear error).
        target_positions = positions
        if target_positions is None:
            clean_act = clean_cache.get(layer_idx)
            if clean_act is None:
                raise ValueError(f"Layer {layer_idx} not in source cache")
            # Assumes dim 1 of the cached activation is the token axis — TODO confirm.
            # NOTE(review): positions come from the clean run and may exceed the
            # corrupted prompt's length if the two prompts tokenize differently.
            target_positions = list(range(clean_act.shape[1]))

        # Get corrupted baseline
        corrupted_output = self.backend.generate(corrupted_prompt, **gen_kwargs)

        results = {}
        for pos in target_positions:
            result = self.patch_run(
                corrupted_prompt,
                clean_cache,
                target_layers=[layer_idx],
                token_positions=[pos],
                **gen_kwargs,
            )
            result.original_answer = corrupted_output.text
            result.patched_answer = result.output_text
            results[pos] = result

        return results

    def ablate_layer(
        self, prompt: str, layer_idx: int, ablation_type: str = "zero", **gen_kwargs
    ) -> "PatchingResult":
        """
        Ablate a layer by zeroing or replacing with mean.

        Args:
            prompt: Input prompt
            layer_idx: Which layer to ablate
            ablation_type: "zero" or "mean"
            **gen_kwargs: Generation arguments

        Returns:
            PatchingResult from ablated run (``patched_positions`` is None)

        Raises:
            ValueError: If ``ablation_type`` is neither ``"zero"`` nor ``"mean"``.
        """
        if ablation_type not in ("zero", "mean"):
            raise ValueError(
                f"Unknown ablation_type {ablation_type!r}; expected 'zero' or 'mean'."
            )

        def zero_hook(_module, _inputs, output):
            # For tuple outputs, replace only the first element (hidden states) and keep the rest.
            if isinstance(output, tuple):
                return (torch.zeros_like(output[0]),) + output[1:]
            return torch.zeros_like(output)

        def mean_hook(_module, _inputs, output):
            # Replace every element with the mean over dims 0 and 1 of this call's output.
            # NOTE(review): if generate() decodes incrementally with a KV cache, later
            # calls see only the newest token, so this "mean" degenerates to that token
            # itself (batch size 1). Confirm whether a prompt-level mean should be cached.
            hidden_states = output[0] if isinstance(output, tuple) else output
            mean_val = hidden_states.mean(dim=(0, 1), keepdim=True).expand_as(hidden_states)
            if isinstance(output, tuple):
                return (mean_val,) + output[1:]
            return mean_val

        hook = zero_hook if ablation_type == "zero" else mean_hook

        try:
            self.hook_manager.register_forward_hook(layer_idx, hook)
            output = self.backend.generate(prompt, **gen_kwargs)
        finally:
            self.hook_manager.remove_all_hooks()

        return PatchingResult(
            output_text=output.text, patched_layers=[layer_idx], patched_positions=None
        )
cotlab/patching/sae.py ADDED
@@ -0,0 +1,181 @@
1
+ """GemmaScope-2 JumpReLU Sparse Autoencoder loader.
2
+
3
+ Loads SAE weights directly from HuggingFace (no SAELens / TransformerLens
4
+ dependency) and exposes a minimal encode() forward pass compatible with the
5
+ CoTLab TransformersBackend hook infrastructure.
6
+
7
+ Architecture (JumpReLU SAE, Lieberum et al. 2024)
8
+ --------------------------------------------------
9
+ h = x - b_dec # centre around decoder bias
10
+ pre_act = h @ w_enc + b_enc # linear projection [..., d_sae]
11
+ features = pre_act * (pre_act > θ) # JumpReLU gate
12
+
13
+ Weight layout in params.safetensors (GemmaScope-2 convention)
14
+ --------------------------------------------------------------
15
+ w_enc float32 [d_model, d_sae]
16
+ b_enc float32 [d_sae]
17
+ w_dec float32 [d_sae, d_model]
18
+ b_dec float32 [d_model]
19
+ threshold float32 [d_sae]
20
+
21
+ HF path pattern
22
+ ---------------
23
+ {site}/layer_{N}_width_{width}_l0_{l0_label}/params.safetensors
24
+ e.g. resid_post_all/layer_9_width_16k_l0_small/params.safetensors
25
+ """
26
+
27
+ import re
28
+ from typing import Optional
29
+
30
+ import torch
31
+ import torch.nn as nn
32
+
33
+
34
class GemmaScopeLayer(nn.Module):
    """JumpReLU SAE for a single residual-stream layer.

    All weights are registered as buffers (not parameters), so they follow
    ``.to(device)`` / ``state_dict`` but are not returned by ``.parameters()``.
    """

    def __init__(
        self,
        w_enc: torch.Tensor,
        b_enc: torch.Tensor,
        w_dec: torch.Tensor,
        b_dec: torch.Tensor,
        threshold: torch.Tensor,
        layer: int,
        repo_id: str,
        source_path: str,
    ):
        super().__init__()
        self.register_buffer("w_enc", w_enc)  # [d_model, d_sae]
        self.register_buffer("b_enc", b_enc)  # [d_sae]
        self.register_buffer("w_dec", w_dec)  # [d_sae, d_model]
        self.register_buffer("b_dec", b_dec)  # [d_model]
        self.register_buffer("threshold", threshold)  # [d_sae]
        self.layer = layer
        self.repo_id = repo_id
        self.source_path = source_path  # Repo-relative path of the loaded file

    @property
    def d_model(self) -> int:
        """Residual-stream width (rows of ``w_enc``)."""
        return self.w_enc.shape[0]

    @property
    def d_sae(self) -> int:
        """Number of SAE features (columns of ``w_enc``)."""
        return self.w_enc.shape[1]

    # ------------------------------------------------------------------
    # Factory
    # ------------------------------------------------------------------

    @staticmethod
    def _match_sae_files(all_files, site: str, layer: int, width: str) -> list:
        """Return sorted ``params.safetensors`` paths for ``site``/``layer``/``width``.

        ``site`` must appear as whole path segment(s). A plain substring test
        would let ``"resid_post"`` also match ``"resid_post_all/..."`` and load
        an SAE trained on a different site. The trailing ``_`` in the layer tag
        keeps ``layer_1_`` from matching ``layer_11_``. Sorting makes the
        fallback choice independent of the repo listing order.
        """
        site_segment = f"/{site.strip('/')}/"
        layer_tag = f"layer_{layer}_"
        width_tag = f"width_{width}_"
        return sorted(
            f
            for f in all_files
            if f.endswith("params.safetensors")
            and site_segment in f"/{f}"
            and layer_tag in f
            and width_tag in f
        )

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str,
        layer: int,
        site: str = "resid_post_all",
        width: str = "16k",
        l0_label: str = "small",
        token: Optional[str] = None,
    ) -> "GemmaScopeLayer":
        """Download and load a GemmaScope-2 SAE from HuggingFace.

        Args:
            repo_id: HF repo, e.g. ``"google/gemma-scope-2-270m-it"``
            layer: Residual stream layer index (0-based).
            site: SAE training site (top-level directory in the repo).
                ``"resid_post_all"`` covers every layer; other options are
                ``"resid_post"`` (4 depths only), ``"attn_out_all"``,
                ``"mlp_out_all"``.
            width: Feature dictionary width: ``"16k"`` or ``"262k"``.
            l0_label: Sparsity label used in the directory name: ``"small"``,
                ``"medium"``, or ``"big"``. If not found, falls back to the
                first (sorted) file available for that site/layer/width.
            token: HuggingFace API token (optional; reads HF_TOKEN env var
                automatically via huggingface_hub).

        Raises:
            FileNotFoundError: If no SAE matches ``site``/``layer``/``width``.
        """
        from huggingface_hub import hf_hub_download, list_repo_files  # noqa: PLC0415
        from safetensors import safe_open  # noqa: PLC0415

        # Try the canonical direct path first (avoids full repo listing).
        direct = f"{site}/layer_{layer}_width_{width}_l0_{l0_label}/params.safetensors"
        print(f" [SAE] Fetching layer={layer} ({direct}) …")

        try:
            local_path = hf_hub_download(repo_id=repo_id, filename=direct, token=token)
            chosen = direct
        except Exception:
            # Best-effort fallback: any download failure triggers a repo scan.
            # Genuine auth/network errors will typically resurface in list_repo_files.
            print(" [SAE] Direct path not found — scanning repo …")
            all_files = list(list_repo_files(repo_id, token=token))
            candidates = cls._match_sae_files(all_files, site, layer, width)
            if not candidates:
                layer_matches = (
                    re.search(r"layer_(\d+)_", f) for f in all_files if "params.safetensors" in f
                )
                available = sorted({m.group(1) for m in layer_matches if m}, key=int)
                raise FileNotFoundError(
                    f"No SAE found for site={site!r}, layer={layer}, width={width!r} "
                    f"in {repo_id}.\nAvailable layers (any site/width): {available}"
                )
            preferred = [f for f in candidates if f"l0_{l0_label}" in f]
            chosen = (preferred or candidates)[0]
            local_path = hf_hub_download(repo_id=repo_id, filename=chosen, token=token)

        with safe_open(local_path, framework="pt") as f:
            w_enc = f.get_tensor("w_enc").float()  # [d_model, d_sae]
            b_enc = f.get_tensor("b_enc").float()  # [d_sae]
            w_dec = f.get_tensor("w_dec").float()  # [d_sae, d_model]
            b_dec = f.get_tensor("b_dec").float()  # [d_model]
            threshold = f.get_tensor("threshold").float()  # [d_sae]

        print(f" [SAE] Loaded layer={layer} d_model={w_enc.shape[0]} d_sae={w_enc.shape[1]}")
        return cls(
            w_enc=w_enc,
            b_enc=b_enc,
            w_dec=w_dec,
            b_dec=b_dec,
            threshold=threshold,
            layer=layer,
            repo_id=repo_id,
            source_path=chosen,
        )

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------

    @torch.no_grad()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """JumpReLU encode.

        Computation runs in float32 and the result is cast back to ``x``'s dtype.

        Args:
            x: Residual stream activations, shape ``[..., d_model]``.
               Must be on the same device as the SAE buffers.

        Returns:
            Sparse feature activations, shape ``[..., d_sae]``. A feature is
            kept only where its pre-activation strictly exceeds its threshold.
        """
        orig_dtype = x.dtype
        x = x.float()
        h = x - self.b_dec
        pre = h @ self.w_enc + self.b_enc
        features = pre * (pre > self.threshold).float()
        return features.to(orig_dtype)

    def __repr__(self) -> str:
        return (
            f"GemmaScopeLayer(layer={self.layer}, "
            f"d_model={self.d_model}, d_sae={self.d_sae}, "
            f"repo={self.repo_id})"
        )
cotlab/prompts/__init__.py ADDED
@@ -0,0 +1,43 @@
1
+ """Prompt strategies module."""
2
+
3
+ from .cardiology import CardiologyPromptStrategy
4
+ from .histopathology import HistopathologyPromptStrategy
5
+ from .length_matched_strategies import (
6
+ ChainOfThoughtMatchedStrategy,
7
+ ContrarianMatchedStrategy,
8
+ DirectAnswerMatchedStrategy,
9
+ )
10
+ from .mcq import MCQPromptStrategy
11
+ from .neurology import NeurologyPromptStrategy
12
+ from .oncology import OncologyPromptStrategy
13
+ from .plab import PLABPromptStrategy
14
+ from .pubhealthbench import PubHealthBenchMCQPromptStrategy
15
+ from .radiology import RadiologyPromptStrategy
16
+ from .strategies import (
17
+ ArroganceStrategy,
18
+ ChainOfThoughtStrategy,
19
+ DirectAnswerStrategy,
20
+ NoInstructionStrategy,
21
+ SimplePromptStrategy,
22
+ create_prompt_strategy,
23
+ )
24
+
25
# Public API of the prompts package: the re-exported strategy classes plus the
# factory function. Order is kept stable for `from cotlab.prompts import *`.
__all__ = [
    # General-purpose prompting strategies (from .strategies)
    "SimplePromptStrategy",
    "ChainOfThoughtStrategy",
    "DirectAnswerStrategy",
    "ArroganceStrategy",
    "NoInstructionStrategy",
    # Domain- and dataset-specific strategies
    "CardiologyPromptStrategy",
    "HistopathologyPromptStrategy",
    "MCQPromptStrategy",
    "NeurologyPromptStrategy",
    "OncologyPromptStrategy",
    "PLABPromptStrategy",
    "PubHealthBenchMCQPromptStrategy",
    "RadiologyPromptStrategy",
    # Factory (from .strategies)
    "create_prompt_strategy",
    # Length-matched variants (from .length_matched_strategies)
    "ContrarianMatchedStrategy",
    "ChainOfThoughtMatchedStrategy",
    "DirectAnswerMatchedStrategy",
]