sae-lens 6.28.2-py3-none-any.whl → 6.32.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.28.2"
+ __version__ = "6.32.1"

  import logging

@@ -125,6 +125,19 @@ __all__ = [
      "MatchingPursuitTrainingSAEConfig",
  ]

+ # Conditional export for SAETransformerBridge (requires transformer-lens v3+)
+ try:
+     from sae_lens.analysis.compat import has_transformer_bridge
+
+     if has_transformer_bridge():
+         from sae_lens.analysis.sae_transformer_bridge import (  # noqa: F401
+             SAETransformerBridge,
+         )
+
+         __all__.append("SAETransformerBridge")
+ except ImportError:
+     pass
+

  register_sae_class("standard", StandardSAE, StandardSAEConfig)
  register_sae_training_class("standard", StandardTrainingSAE, StandardTrainingSAEConfig)
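
With this change, `SAETransformerBridge` is only re-exported when transformer-lens v3+ is installed. A minimal sketch of how downstream code might guard its use (the guard itself is illustrative, not part of the package):

import sae_lens

# Present only when transformer-lens >= 3 is installed.
if "SAETransformerBridge" in sae_lens.__all__:
    bridge_cls = sae_lens.SAETransformerBridge
else:
    bridge_cls = None  # fall back to HookedSAETransformer on transformer-lens v2
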
@@ -0,0 +1,15 @@
+ from sae_lens.analysis.hooked_sae_transformer import HookedSAETransformer
+
+ __all__ = ["HookedSAETransformer"]
+
+ try:
+     from sae_lens.analysis.compat import has_transformer_bridge
+
+     if has_transformer_bridge():
+         from sae_lens.analysis.sae_transformer_bridge import (  # noqa: F401
+             SAETransformerBridge,
+         )
+
+         __all__.append("SAETransformerBridge")
+ except ImportError:
+     pass
@@ -0,0 +1,16 @@
+ import importlib.metadata
+
+ from packaging.version import parse as parse_version
+
+
+ def get_transformer_lens_version() -> tuple[int, int, int]:
+     """Get transformer-lens version as (major, minor, patch)."""
+     version_str = importlib.metadata.version("transformer-lens")
+     version = parse_version(version_str)
+     return (version.major, version.minor, version.micro)
+
+
+ def has_transformer_bridge() -> bool:
+     """Check if TransformerBridge is available (v3+)."""
+     major, _, _ = get_transformer_lens_version()
+     return major >= 3
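
These helpers gate the conditional export above on the installed transformer-lens major version. A small sketch of the `packaging` behaviour they rely on (the version string is illustrative):

from packaging.version import parse as parse_version

v = parse_version("3.1.0")
# Version exposes integer components, so a plain major-version check suffices:
assert (v.major, v.minor, v.micro) == (3, 1, 0)
assert v.major >= 3  # has_transformer_bridge() would return True here
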
@@ -126,7 +126,7 @@ class HookedSAETransformer(HookedTransformer):
              current_sae.use_error_term = current_sae._original_use_error_term  # type: ignore
              delattr(current_sae, "_original_use_error_term")

-         if prev_sae:
+         if prev_sae is not None:
              set_deep_attr(self, act_name, prev_sae)
              self.acts_to_saes[act_name] = prev_sae
          else:
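
The switch from `if prev_sae:` to `if prev_sae is not None:` restores a previously attached SAE based on presence rather than truthiness. A generic illustration of why the explicit check is safer for arbitrary objects (the `Box` class is purely illustrative, not part of sae-lens):

class Box:
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)


empty = Box([])
assert not empty          # falsy, because __len__ returns 0
assert empty is not None  # yet it is still a real object that should be restored
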
@@ -0,0 +1,348 @@
+ from collections.abc import Callable
+ from contextlib import contextmanager
+ from typing import Any
+
+ import torch
+ from transformer_lens.ActivationCache import ActivationCache
+ from transformer_lens.hook_points import HookPoint
+ from transformer_lens.model_bridge import TransformerBridge
+
+ from sae_lens import logger
+ from sae_lens.analysis.hooked_sae_transformer import set_deep_attr
+ from sae_lens.saes.sae import SAE
+
+ SingleLoss = torch.Tensor  # Type alias for a single element tensor
+ LossPerToken = torch.Tensor
+ Loss = SingleLoss | LossPerToken
+
+
+ class SAETransformerBridge(TransformerBridge):  # type: ignore[misc,no-untyped-call]
+     """A TransformerBridge subclass that supports attaching SAEs.
+
+     .. warning::
+         This class is in **beta**. The API may change in future versions.
+
+     This class provides the same SAE attachment functionality as HookedSAETransformer,
+     but for transformer-lens v3's TransformerBridge instead of HookedTransformer.
+
+     TransformerBridge is a lightweight wrapper around HuggingFace models that provides
+     hook points without the overhead of HookedTransformer's weight processing. This is
+     useful for models not natively supported by HookedTransformer, such as Gemma 3.
+     """
+
+     acts_to_saes: dict[str, SAE[Any]]
+
+     def __init__(self, *args: Any, **kwargs: Any):
+         super().__init__(*args, **kwargs)
+         self.acts_to_saes = {}
+
+     @classmethod
+     def boot_transformers(  # type: ignore[override]
+         cls,
+         model_name: str,
+         **kwargs: Any,
+     ) -> "SAETransformerBridge":
+         """Factory method to boot a model and return SAETransformerBridge instance.
+
+         Args:
+             model_name: The name of the model to load (e.g., "gpt2", "gemma-2-2b")
+             **kwargs: Additional arguments passed to TransformerBridge.boot_transformers
+
+         Returns:
+             SAETransformerBridge instance with the loaded model
+         """
+         # Boot parent TransformerBridge
+         bridge = TransformerBridge.boot_transformers(model_name, **kwargs)
+         # Convert to our class
+         # NOTE: this is super hacky and scary, but I don't know how else to achieve this given TLens' internal code
+         bridge.__class__ = cls
+         bridge.acts_to_saes = {}  # type: ignore[attr-defined]
+         return bridge  # type: ignore[return-value]
+
+     def _resolve_hook_name(self, hook_name: str) -> str:
+         """Resolve alias to actual hook name.
+
+         TransformerBridge supports hook aliases like 'blocks.0.hook_mlp_out'
+         that map to actual paths like 'blocks.0.mlp.hook_out'.
+         """
+         # Combine static and dynamic aliases
+         aliases: dict[str, Any] = {
+             **self.hook_aliases,
+             **self._collect_hook_aliases_from_registry(),
+         }
+         resolved = aliases.get(hook_name, hook_name)
+         # aliases values are always strings, but type checker doesn't know this
+         return resolved if isinstance(resolved, str) else hook_name
+
+     def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None) -> None:
+         """Attaches an SAE to the model.
+
+         WARNING: This SAE will be permanently attached until you remove it with
+         reset_saes. This function will also overwrite any existing SAE attached
+         to the same hook point.
+
+         Args:
+             sae: The SAE to attach to the model
+             use_error_term: If provided, will set the use_error_term attribute of
+                 the SAE to this value. Determines whether the SAE returns input
+                 or reconstruction. Defaults to None.
+         """
+         alias_name = sae.cfg.metadata.hook_name
+         actual_name = self._resolve_hook_name(alias_name)
+
+         # Check if hook exists (either as alias or actual name)
+         if (alias_name not in self.acts_to_saes) and (
+             actual_name not in self._hook_registry
+         ):
+             logger.warning(
+                 f"No hook found for {alias_name}. Skipping. "
+                 f"Check model._hook_registry for available hooks."
+             )
+             return
+
+         if use_error_term is not None:
+             if not hasattr(sae, "_original_use_error_term"):
+                 sae._original_use_error_term = sae.use_error_term  # type: ignore[attr-defined]
+             sae.use_error_term = use_error_term
+
+         # Replace hook and update registry
+         set_deep_attr(self, actual_name, sae)
+         self._hook_registry[actual_name] = sae  # type: ignore[assignment]
+         self.acts_to_saes[alias_name] = sae
+
+     def _reset_sae(self, act_name: str, prev_sae: SAE[Any] | None = None) -> None:
+         """Resets an SAE that was attached to the model.
+
+         By default will remove the SAE from that hook_point.
+         If prev_sae is provided, will replace the current SAE with the provided one.
+         This is mainly used to restore previously attached SAEs after temporarily
+         running with different SAEs (e.g., with run_with_saes).
+
+         Args:
+             act_name: The hook_name of the SAE to reset
+             prev_sae: The SAE to replace the current one with. If None, will just
+                 remove the SAE from this hook point. Defaults to None.
+         """
+         if act_name not in self.acts_to_saes:
+             logger.warning(
+                 f"No SAE is attached to {act_name}. There's nothing to reset."
+             )
+             return
+
+         actual_name = self._resolve_hook_name(act_name)
+         current_sae = self.acts_to_saes[act_name]
+
+         if hasattr(current_sae, "_original_use_error_term"):
+             current_sae.use_error_term = current_sae._original_use_error_term  # type: ignore[attr-defined]
+             delattr(current_sae, "_original_use_error_term")
+
+         if prev_sae is not None:
+             set_deep_attr(self, actual_name, prev_sae)
+             self._hook_registry[actual_name] = prev_sae  # type: ignore[assignment]
+             self.acts_to_saes[act_name] = prev_sae
+         else:
+             new_hook = HookPoint()
+             new_hook.name = actual_name
+             set_deep_attr(self, actual_name, new_hook)
+             self._hook_registry[actual_name] = new_hook
+             del self.acts_to_saes[act_name]
+
+     def reset_saes(
+         self,
+         act_names: str | list[str] | None = None,
+         prev_saes: list[SAE[Any] | None] | None = None,
+     ) -> None:
+         """Reset the SAEs attached to the model.
+
+         If act_names are provided will just reset SAEs attached to those hooks.
+         Otherwise will reset all SAEs attached to the model.
+         Optionally can provide a list of prev_saes to reset to. This is mainly
+         used to restore previously attached SAEs after temporarily running with
+         different SAEs (e.g., with run_with_saes).
+
+         Args:
+             act_names: The act_names of the SAEs to reset. If None, will reset all
+                 SAEs attached to the model. Defaults to None.
+             prev_saes: List of SAEs to replace the current ones with. If None, will
+                 just remove the SAEs. Defaults to None.
+         """
+         if isinstance(act_names, str):
+             act_names = [act_names]
+         elif act_names is None:
+             act_names = list(self.acts_to_saes.keys())
+
+         if prev_saes:
+             if len(act_names) != len(prev_saes):
+                 raise ValueError("act_names and prev_saes must have the same length")
+         else:
+             prev_saes = [None] * len(act_names)  # type: ignore[assignment]
+
+         for act_name, prev_sae in zip(act_names, prev_saes):  # type: ignore[arg-type]
+             self._reset_sae(act_name, prev_sae)
+
+     def run_with_saes(
+         self,
+         *model_args: Any,
+         saes: SAE[Any] | list[SAE[Any]] = [],
+         reset_saes_end: bool = True,
+         use_error_term: bool | None = None,
+         **model_kwargs: Any,
+     ) -> torch.Tensor | Loss | tuple[torch.Tensor, Loss] | None:
+         """Wrapper around forward pass.
+
+         Runs the model with the given SAEs attached for one forward pass, then
+         removes them. By default, will reset all SAEs to original state after.
+
+         Args:
+             *model_args: Positional arguments for the model forward pass
+             saes: The SAEs to be attached for this forward pass
+             reset_saes_end: If True, all SAEs added during this run are removed
+                 at the end, and previously attached SAEs are restored to their
+                 original state. Default is True.
+             use_error_term: If provided, will set the use_error_term attribute
+                 of all SAEs attached during this run to this value. Defaults to None.
+             **model_kwargs: Keyword arguments for the model forward pass
+         """
+         with self.saes(
+             saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
+         ):
+             return self(*model_args, **model_kwargs)
+
+     def run_with_cache_with_saes(
+         self,
+         *model_args: Any,
+         saes: SAE[Any] | list[SAE[Any]] = [],
+         reset_saes_end: bool = True,
+         use_error_term: bool | None = None,
+         return_cache_object: bool = True,
+         remove_batch_dim: bool = False,
+         **kwargs: Any,
+     ) -> tuple[
+         torch.Tensor | Loss | tuple[torch.Tensor, Loss] | None,
+         ActivationCache | dict[str, torch.Tensor],
+     ]:
+         """Wrapper around 'run_with_cache'.
+
+         Attaches given SAEs before running the model with cache and then removes them.
+         By default, will reset all SAEs to original state after.
+
+         Args:
+             *model_args: Positional arguments for the model forward pass
+             saes: The SAEs to be attached for this forward pass
+             reset_saes_end: If True, all SAEs added during this run are removed
+                 at the end, and previously attached SAEs are restored to their
+                 original state. Default is True.
+             use_error_term: If provided, will set the use_error_term attribute
+                 of all SAEs attached during this run to this value. Defaults to None.
+             return_cache_object: If True, returns an ActivationCache object with
+                 useful methods, otherwise returns a dictionary of activations.
+             remove_batch_dim: Whether to remove the batch dimension
+                 (only works for batch_size==1). Defaults to False.
+             **kwargs: Keyword arguments for the model forward pass
+         """
+         with self.saes(
+             saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
+         ):
+             return self.run_with_cache(
+                 *model_args,
+                 return_cache_object=return_cache_object,  # type: ignore[arg-type]
+                 remove_batch_dim=remove_batch_dim,
+                 **kwargs,
+             )  # type: ignore[return-value]
+
+     def run_with_hooks_with_saes(
+         self,
+         *model_args: Any,
+         saes: SAE[Any] | list[SAE[Any]] = [],
+         reset_saes_end: bool = True,
+         fwd_hooks: list[tuple[str | Callable[..., Any], Callable[..., Any]]] = [],
+         bwd_hooks: list[tuple[str | Callable[..., Any], Callable[..., Any]]] = [],
+         reset_hooks_end: bool = True,
+         clear_contexts: bool = False,
+         **model_kwargs: Any,
+     ) -> Any:
+         """Wrapper around 'run_with_hooks'.
+
+         Attaches the given SAEs to the model before running the model with hooks
+         and then removes them. By default, will reset all SAEs to original state after.
+
+         Args:
+             *model_args: Positional arguments for the model forward pass
+             saes: The SAEs to be attached for this forward pass
+             reset_saes_end: If True, all SAEs added during this run are removed
+                 at the end, and previously attached SAEs are restored to their
+                 original state. Default is True.
+             fwd_hooks: List of forward hooks to apply
+             bwd_hooks: List of backward hooks to apply
+             reset_hooks_end: Whether to reset the hooks at the end of the forward
+                 pass. Default is True.
+             clear_contexts: Whether to clear the contexts at the end of the forward
+                 pass. Default is False.
+             **model_kwargs: Keyword arguments for the model forward pass
+         """
+         with self.saes(saes=saes, reset_saes_end=reset_saes_end):
+             return self.run_with_hooks(
+                 *model_args,
+                 fwd_hooks=fwd_hooks,
+                 bwd_hooks=bwd_hooks,
+                 reset_hooks_end=reset_hooks_end,
+                 clear_contexts=clear_contexts,
+                 **model_kwargs,
+             )
+
+     @contextmanager
+     def saes(
+         self,
+         saes: SAE[Any] | list[SAE[Any]] = [],
+         reset_saes_end: bool = True,
+         use_error_term: bool | None = None,
+     ):  # type: ignore[no-untyped-def]
+         """A context manager for adding temporary SAEs to the model.
+
+         By default will keep track of previously attached SAEs, and restore them
+         when the context manager exits.
+
+         Args:
+             saes: SAEs to be attached.
+             reset_saes_end: If True, removes all SAEs added by this context manager
+                 when the context manager exits, returning previously attached SAEs
+                 to their original state.
+             use_error_term: If provided, will set the use_error_term attribute of
+                 all SAEs attached during this run to this value. Defaults to None.
+         """
+         act_names_to_reset: list[str] = []
+         prev_saes: list[SAE[Any] | None] = []
+         if isinstance(saes, SAE):
+             saes = [saes]
+         try:
+             for sae in saes:
+                 act_names_to_reset.append(sae.cfg.metadata.hook_name)
+                 prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
+                 prev_saes.append(prev_sae)
+                 self.add_sae(sae, use_error_term=use_error_term)
+             yield self
+         finally:
+             if reset_saes_end:
+                 self.reset_saes(act_names_to_reset, prev_saes)
+
+     @property
+     def hook_dict(self) -> dict[str, HookPoint]:
+         """Return combined hook registry including SAE internal hooks.
+
+         When SAEs are attached, they replace HookPoint entries in the registry.
+         This property returns both the base hooks and any internal hooks from
+         attached SAEs (like hook_sae_acts_post, hook_sae_input, etc.) with
+         their full path names.
+         """
+         hooks: dict[str, HookPoint] = {}
+
+         for name, hook_or_sae in self._hook_registry.items():
+             if isinstance(hook_or_sae, SAE):
+                 # Include SAE's internal hooks with full path names
+                 for sae_hook_name, sae_hook in hook_or_sae.hook_dict.items():
+                     full_name = f"{name}.{sae_hook_name}"
+                     hooks[full_name] = sae_hook
+             else:
+                 hooks[name] = hook_or_sae
+
+         return hooks
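
Taken together, the class mirrors the HookedSAETransformer workflow on top of TransformerBridge. A hedged usage sketch, assuming transformer-lens v3+ is installed; the model name, SAE loading call, and hook names below are illustrative:

from sae_lens import SAE
from sae_lens.analysis.sae_transformer_bridge import SAETransformerBridge

model = SAETransformerBridge.boot_transformers("gpt2")

# Load an SAE however you normally would; its cfg.metadata.hook_name selects the hook point.
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

# One-off forward pass with the SAE spliced in, then automatically removed afterwards:
logits = model.run_with_saes("Hello world", saes=[sae])

# Or attach temporarily via the context manager and cache SAE-internal activations:
with model.saes(saes=[sae], use_error_term=True):
    logits, cache = model.run_with_cache("Hello world")
    # hook_dict exposes the SAE's hooks under their full path, e.g.
    # "blocks.8.hook_resid_pre.hook_sae_acts_post"
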
sae_lens/config.py CHANGED
@@ -82,6 +82,7 @@ class LoggingConfig:
      log_to_wandb: bool = True
      log_activations_store_to_wandb: bool = False
      log_optimizer_state_to_wandb: bool = False
+     log_weights_to_wandb: bool = True
      wandb_project: str = "sae_lens_training"
      wandb_id: str | None = None
      run_name: str | None = None
@@ -107,7 +108,8 @@ class LoggingConfig:
              type="model",
              metadata=dict(trainer.cfg.__dict__),
          )
-         model_artifact.add_file(str(weights_path))
+         if self.log_weights_to_wandb:
+             model_artifact.add_file(str(weights_path))
          model_artifact.add_file(str(cfg_path))
          wandb.log_artifact(model_artifact, aliases=wandb_aliases)

@@ -557,6 +559,12 @@ class CacheActivationsRunnerConfig:
              context_size=self.context_size,
          )

+         if self.context_size > self.training_tokens:
+             raise ValueError(
+                 f"context_size ({self.context_size}) is greater than training_tokens "
+                 f"({self.training_tokens}). Please reduce context_size or increase training_tokens."
+             )
+
          if self.new_cached_activations_path is None:
              self.new_cached_activations_path = _default_cached_activations_path(  # type: ignore
                  self.dataset_path, self.model_name, self.hook_name, None
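
The new `log_weights_to_wandb` flag gates only the weights-file upload; the config artifact is still added either way, and the added `context_size` check now fails fast at construction time rather than partway through caching. A minimal sketch of a logging config that keeps W&B logging on but skips uploading weights, assuming `LoggingConfig` is constructed directly as a dataclass:

from sae_lens.config import LoggingConfig

logging_cfg = LoggingConfig(
    log_to_wandb=True,
    log_weights_to_wandb=False,  # skip the (potentially large) weights artifact
    wandb_project="sae_lens_training",
)
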
sae_lens/evals.py CHANGED
@@ -335,7 +335,7 @@ def get_downstream_reconstruction_metrics(

      batch_iter = range(n_batches)
      if verbose:
-         batch_iter = tqdm(batch_iter, desc="Reconstruction Batches")
+         batch_iter = tqdm(batch_iter, desc="Reconstruction Batches", leave=False)

      for _ in batch_iter:
          batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
@@ -430,7 +430,7 @@ def get_sparsity_and_variance_metrics(

      batch_iter = range(n_batches)
      if verbose:
-         batch_iter = tqdm(batch_iter, desc="Sparsity and Variance Batches")
+         batch_iter = tqdm(batch_iter, desc="Sparsity and Variance Batches", leave=False)

      for _ in batch_iter:
          batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
@@ -575,6 +575,8 @@ def _infer_gemma_3_raw_cfg_dict(repo_id: str, folder_name: str) -> dict[str, Any
          "model_name": model_name,
          "hf_hook_point_in": hf_hook_point_in,
      }
+     if "transcoder" in folder_name or "clt" in folder_name:
+         cfg["affine_connection"] = "affine" in folder_name
      if hf_hook_point_out is not None:
          cfg["hf_hook_point_out"] = hf_hook_point_out

@@ -614,11 +616,11 @@ def get_gemma_3_config_from_hf(
      if "resid_post" in folder_name:
          hook_name = f"blocks.{layer}.hook_resid_post"
      elif "attn_out" in folder_name:
-         hook_name = f"blocks.{layer}.hook_attn_out"
+         hook_name = f"blocks.{layer}.attn.hook_z"
      elif "mlp_out" in folder_name:
          hook_name = f"blocks.{layer}.hook_mlp_out"
      elif "transcoder" in folder_name or "clt" in folder_name:
-         hook_name = f"blocks.{layer}.ln2.hook_normalized"
+         hook_name = f"blocks.{layer}.hook_mlp_in"
          hook_name_out = f"blocks.{layer}.hook_mlp_out"
      else:
          raise ValueError("Hook name not found in folder_name.")
@@ -643,7 +645,11 @@ def get_gemma_3_config_from_hf(

      architecture = "jumprelu"
      if "transcoder" in folder_name or "clt" in folder_name:
-         architecture = "jumprelu_skip_transcoder"
+         architecture = (
+             "jumprelu_skip_transcoder"
+             if raw_cfg_dict.get("affine_connection", False)
+             else "jumprelu_transcoder"
+         )
      d_out = shapes_dict["w_dec"][-1]

      cfg = {
@@ -660,7 +666,8 @@ def get_gemma_3_config_from_hf(
          "dataset_path": "monology/pile-uncopyrighted",
          "context_size": 1024,
          "apply_b_dec_to_input": False,
-         "normalize_activations": None,
+         "normalize_activations": "none",
+         "reshape_activations": "none",
          "hf_hook_name": raw_cfg_dict.get("hf_hook_point_in"),
      }
      if hook_name_out is not None:
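
The updated Gemma 3 loader maps folder names onto different hook points and architectures than before. A condensed sketch of the resulting mapping (a hypothetical helper written only to summarise the branches above, not part of sae-lens):

def summarize_gemma3_mapping(folder_name: str, layer: int) -> dict[str, str | None]:
    """Mirror the folder-name -> hook/architecture branches in get_gemma_3_config_from_hf."""
    hook_name_out = None
    architecture = "jumprelu"
    if "resid_post" in folder_name:
        hook_name = f"blocks.{layer}.hook_resid_post"
    elif "attn_out" in folder_name:
        hook_name = f"blocks.{layer}.attn.hook_z"
    elif "mlp_out" in folder_name:
        hook_name = f"blocks.{layer}.hook_mlp_out"
    elif "transcoder" in folder_name or "clt" in folder_name:
        hook_name = f"blocks.{layer}.hook_mlp_in"
        hook_name_out = f"blocks.{layer}.hook_mlp_out"
        # An affine skip connection selects the skip-transcoder variant.
        architecture = (
            "jumprelu_skip_transcoder" if "affine" in folder_name else "jumprelu_transcoder"
        )
    else:
        raise ValueError("Hook name not found in folder_name.")
    return {"hook_name": hook_name, "hook_name_out": hook_name_out, "architecture": architecture}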