sae-lens 6.29.1__py3-none-any.whl → 6.33.0__py3-none-any.whl

This diff shows the contents of publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.29.1"
2
+ __version__ = "6.33.0"
3
3
 
4
4
  import logging
5
5
 
@@ -125,6 +125,19 @@ __all__ = [
125
125
  "MatchingPursuitTrainingSAEConfig",
126
126
  ]
127
127
 
128
+ # Conditional export for SAETransformerBridge (requires transformer-lens v3+)
129
+ try:
130
+ from sae_lens.analysis.compat import has_transformer_bridge
131
+
132
+ if has_transformer_bridge():
133
+ from sae_lens.analysis.sae_transformer_bridge import ( # noqa: F401
134
+ SAETransformerBridge,
135
+ )
136
+
137
+ __all__.append("SAETransformerBridge")
138
+ except ImportError:
139
+ pass
140
+
128
141
 
129
142
  register_sae_class("standard", StandardSAE, StandardSAEConfig)
130
143
  register_sae_training_class("standard", StandardTrainingSAE, StandardTrainingSAEConfig)
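
Because the export above is conditional on the installed transformer-lens version, downstream code should feature-detect `SAETransformerBridge` rather than import it unconditionally. A minimal consumer-side sketch, assuming only the names visible in this diff plus the existing top-level `HookedSAETransformer` export:

    import sae_lens

    # SAETransformerBridge is only exported when transformer-lens v3+ is installed.
    ModelClass = getattr(sae_lens, "SAETransformerBridge", None)
    if ModelClass is None:
        # Fall back to the HookedSAETransformer workflow on transformer-lens v2.x.
        from sae_lens import HookedSAETransformer as ModelClass
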
sae_lens/analysis/__init__.py CHANGED
@@ -0,0 +1,15 @@
1
+ from sae_lens.analysis.hooked_sae_transformer import HookedSAETransformer
2
+
3
+ __all__ = ["HookedSAETransformer"]
4
+
5
+ try:
6
+ from sae_lens.analysis.compat import has_transformer_bridge
7
+
8
+ if has_transformer_bridge():
9
+ from sae_lens.analysis.sae_transformer_bridge import ( # noqa: F401
10
+ SAETransformerBridge,
11
+ )
12
+
13
+ __all__.append("SAETransformerBridge")
14
+ except ImportError:
15
+ pass
sae_lens/analysis/compat.py ADDED
@@ -0,0 +1,16 @@
1
+ import importlib.metadata
2
+
3
+ from packaging.version import parse as parse_version
4
+
5
+
6
+ def get_transformer_lens_version() -> tuple[int, int, int]:
7
+ """Get transformer-lens version as (major, minor, patch)."""
8
+ version_str = importlib.metadata.version("transformer-lens")
9
+ version = parse_version(version_str)
10
+ return (version.major, version.minor, version.micro)
11
+
12
+
13
+ def has_transformer_bridge() -> bool:
14
+ """Check if TransformerBridge is available (v3+)."""
15
+ major, _, _ = get_transformer_lens_version()
16
+ return major >= 3
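
The new `compat` helpers gate bridge-related features on the installed transformer-lens major version. A short usage sketch (it assumes only that transformer-lens is installed, which is a hard dependency of sae-lens):

    from sae_lens.analysis.compat import (
        get_transformer_lens_version,
        has_transformer_bridge,
    )

    major, minor, patch = get_transformer_lens_version()
    print(f"transformer-lens {major}.{minor}.{patch}")
    if has_transformer_bridge():
        print("TransformerBridge is available; SAETransformerBridge will be exported")
    else:
        print("transformer-lens v2.x detected; bridge features are disabled")
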
sae_lens/analysis/hooked_sae_transformer.py CHANGED
@@ -126,7 +126,7 @@ class HookedSAETransformer(HookedTransformer):
126
126
  current_sae.use_error_term = current_sae._original_use_error_term # type: ignore
127
127
  delattr(current_sae, "_original_use_error_term")
128
128
 
129
- if prev_sae:
129
+ if prev_sae is not None:
130
130
  set_deep_attr(self, act_name, prev_sae)
131
131
  self.acts_to_saes[act_name] = prev_sae
132
132
  else:
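
The change from `if prev_sae:` to `if prev_sae is not None:` avoids relying on object truthiness, which can be False even when a previous SAE was actually stored and would then wrongly skip the restore branch. A generic illustration (not sae-lens code):

    class FalsyButPresent:
        # Any object that defines __bool__ (or __len__) can evaluate to False.
        def __bool__(self) -> bool:
            return False

    prev = FalsyButPresent()
    print(bool(prev))        # False -> `if prev:` would skip the restore branch
    print(prev is not None)  # True  -> `is not None` only tests for absence
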
sae_lens/analysis/sae_transformer_bridge.py ADDED
@@ -0,0 +1,348 @@
1
+ from collections.abc import Callable
2
+ from contextlib import contextmanager
3
+ from typing import Any
4
+
5
+ import torch
6
+ from transformer_lens.ActivationCache import ActivationCache
7
+ from transformer_lens.hook_points import HookPoint
8
+ from transformer_lens.model_bridge import TransformerBridge
9
+
10
+ from sae_lens import logger
11
+ from sae_lens.analysis.hooked_sae_transformer import set_deep_attr
12
+ from sae_lens.saes.sae import SAE
13
+
14
+ SingleLoss = torch.Tensor # Type alias for a single element tensor
15
+ LossPerToken = torch.Tensor
16
+ Loss = SingleLoss | LossPerToken
17
+
18
+
19
+ class SAETransformerBridge(TransformerBridge): # type: ignore[misc,no-untyped-call]
20
+ """A TransformerBridge subclass that supports attaching SAEs.
21
+
22
+ .. warning::
23
+ This class is in **beta**. The API may change in future versions.
24
+
25
+ This class provides the same SAE attachment functionality as HookedSAETransformer,
26
+ but for transformer-lens v3's TransformerBridge instead of HookedTransformer.
27
+
28
+ TransformerBridge is a lightweight wrapper around HuggingFace models that provides
29
+ hook points without the overhead of HookedTransformer's weight processing. This is
30
+ useful for models not natively supported by HookedTransformer, such as Gemma 3.
31
+ """
32
+
33
+ acts_to_saes: dict[str, SAE[Any]]
34
+
35
+ def __init__(self, *args: Any, **kwargs: Any):
36
+ super().__init__(*args, **kwargs)
37
+ self.acts_to_saes = {}
38
+
39
+ @classmethod
40
+ def boot_transformers( # type: ignore[override]
41
+ cls,
42
+ model_name: str,
43
+ **kwargs: Any,
44
+ ) -> "SAETransformerBridge":
45
+ """Factory method to boot a model and return SAETransformerBridge instance.
46
+
47
+ Args:
48
+ model_name: The name of the model to load (e.g., "gpt2", "gemma-2-2b")
49
+ **kwargs: Additional arguments passed to TransformerBridge.boot_transformers
50
+
51
+ Returns:
52
+ SAETransformerBridge instance with the loaded model
53
+ """
54
+ # Boot parent TransformerBridge
55
+ bridge = TransformerBridge.boot_transformers(model_name, **kwargs)
56
+ # Convert to our class
57
+ # NOTE: this is super hacky and scary, but I don't know how else to achieve this given TLens' internal code
58
+ bridge.__class__ = cls
59
+ bridge.acts_to_saes = {} # type: ignore[attr-defined]
60
+ return bridge # type: ignore[return-value]
61
+
62
+ def _resolve_hook_name(self, hook_name: str) -> str:
63
+ """Resolve alias to actual hook name.
64
+
65
+ TransformerBridge supports hook aliases like 'blocks.0.hook_mlp_out'
66
+ that map to actual paths like 'blocks.0.mlp.hook_out'.
67
+ """
68
+ # Combine static and dynamic aliases
69
+ aliases: dict[str, Any] = {
70
+ **self.hook_aliases,
71
+ **self._collect_hook_aliases_from_registry(),
72
+ }
73
+ resolved = aliases.get(hook_name, hook_name)
74
+ # alias values are always strings, but the type checker doesn't know this
75
+ return resolved if isinstance(resolved, str) else hook_name
76
+
77
+ def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None) -> None:
78
+ """Attaches an SAE to the model.
79
+
80
+ WARNING: This SAE will be permanently attached until you remove it with
81
+ reset_saes. This function will also overwrite any existing SAE attached
82
+ to the same hook point.
83
+
84
+ Args:
85
+ sae: The SAE to attach to the model
86
+ use_error_term: If provided, will set the use_error_term attribute of
87
+ the SAE to this value. Determines whether the SAE returns input
88
+ or reconstruction. Defaults to None.
89
+ """
90
+ alias_name = sae.cfg.metadata.hook_name
91
+ actual_name = self._resolve_hook_name(alias_name)
92
+
93
+ # Check if hook exists (either as alias or actual name)
94
+ if (alias_name not in self.acts_to_saes) and (
95
+ actual_name not in self._hook_registry
96
+ ):
97
+ logger.warning(
98
+ f"No hook found for {alias_name}. Skipping. "
99
+ f"Check model._hook_registry for available hooks."
100
+ )
101
+ return
102
+
103
+ if use_error_term is not None:
104
+ if not hasattr(sae, "_original_use_error_term"):
105
+ sae._original_use_error_term = sae.use_error_term # type: ignore[attr-defined]
106
+ sae.use_error_term = use_error_term
107
+
108
+ # Replace hook and update registry
109
+ set_deep_attr(self, actual_name, sae)
110
+ self._hook_registry[actual_name] = sae # type: ignore[assignment]
111
+ self.acts_to_saes[alias_name] = sae
112
+
113
+ def _reset_sae(self, act_name: str, prev_sae: SAE[Any] | None = None) -> None:
114
+ """Resets an SAE that was attached to the model.
115
+
116
+ By default will remove the SAE from that hook_point.
117
+ If prev_sae is provided, will replace the current SAE with the provided one.
118
+ This is mainly used to restore previously attached SAEs after temporarily
119
+ running with different SAEs (e.g., with run_with_saes).
120
+
121
+ Args:
122
+ act_name: The hook_name of the SAE to reset
123
+ prev_sae: The SAE to replace the current one with. If None, will just
124
+ remove the SAE from this hook point. Defaults to None.
125
+ """
126
+ if act_name not in self.acts_to_saes:
127
+ logger.warning(
128
+ f"No SAE is attached to {act_name}. There's nothing to reset."
129
+ )
130
+ return
131
+
132
+ actual_name = self._resolve_hook_name(act_name)
133
+ current_sae = self.acts_to_saes[act_name]
134
+
135
+ if hasattr(current_sae, "_original_use_error_term"):
136
+ current_sae.use_error_term = current_sae._original_use_error_term # type: ignore[attr-defined]
137
+ delattr(current_sae, "_original_use_error_term")
138
+
139
+ if prev_sae is not None:
140
+ set_deep_attr(self, actual_name, prev_sae)
141
+ self._hook_registry[actual_name] = prev_sae # type: ignore[assignment]
142
+ self.acts_to_saes[act_name] = prev_sae
143
+ else:
144
+ new_hook = HookPoint()
145
+ new_hook.name = actual_name
146
+ set_deep_attr(self, actual_name, new_hook)
147
+ self._hook_registry[actual_name] = new_hook
148
+ del self.acts_to_saes[act_name]
149
+
150
+ def reset_saes(
151
+ self,
152
+ act_names: str | list[str] | None = None,
153
+ prev_saes: list[SAE[Any] | None] | None = None,
154
+ ) -> None:
155
+ """Reset the SAEs attached to the model.
156
+
157
+ If act_names is provided, only the SAEs attached to those hooks are reset.
158
+ Otherwise, all SAEs attached to the model are reset.
159
+ Optionally can provide a list of prev_saes to reset to. This is mainly
160
+ used to restore previously attached SAEs after temporarily running with
161
+ different SAEs (e.g., with run_with_saes).
162
+
163
+ Args:
164
+ act_names: The act_names of the SAEs to reset. If None, will reset all
165
+ SAEs attached to the model. Defaults to None.
166
+ prev_saes: List of SAEs to replace the current ones with. If None, will
167
+ just remove the SAEs. Defaults to None.
168
+ """
169
+ if isinstance(act_names, str):
170
+ act_names = [act_names]
171
+ elif act_names is None:
172
+ act_names = list(self.acts_to_saes.keys())
173
+
174
+ if prev_saes:
175
+ if len(act_names) != len(prev_saes):
176
+ raise ValueError("act_names and prev_saes must have the same length")
177
+ else:
178
+ prev_saes = [None] * len(act_names) # type: ignore[assignment]
179
+
180
+ for act_name, prev_sae in zip(act_names, prev_saes): # type: ignore[arg-type]
181
+ self._reset_sae(act_name, prev_sae)
182
+
183
+ def run_with_saes(
184
+ self,
185
+ *model_args: Any,
186
+ saes: SAE[Any] | list[SAE[Any]] = [],
187
+ reset_saes_end: bool = True,
188
+ use_error_term: bool | None = None,
189
+ **model_kwargs: Any,
190
+ ) -> torch.Tensor | Loss | tuple[torch.Tensor, Loss] | None:
191
+ """Wrapper around forward pass.
192
+
193
+ Runs the model with the given SAEs attached for one forward pass, then
194
+ removes them. By default, will reset all SAEs to original state after.
195
+
196
+ Args:
197
+ *model_args: Positional arguments for the model forward pass
198
+ saes: The SAEs to be attached for this forward pass
199
+ reset_saes_end: If True, all SAEs added during this run are removed
200
+ at the end, and previously attached SAEs are restored to their
201
+ original state. Default is True.
202
+ use_error_term: If provided, will set the use_error_term attribute
203
+ of all SAEs attached during this run to this value. Defaults to None.
204
+ **model_kwargs: Keyword arguments for the model forward pass
205
+ """
206
+ with self.saes(
207
+ saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
208
+ ):
209
+ return self(*model_args, **model_kwargs)
210
+
211
+ def run_with_cache_with_saes(
212
+ self,
213
+ *model_args: Any,
214
+ saes: SAE[Any] | list[SAE[Any]] = [],
215
+ reset_saes_end: bool = True,
216
+ use_error_term: bool | None = None,
217
+ return_cache_object: bool = True,
218
+ remove_batch_dim: bool = False,
219
+ **kwargs: Any,
220
+ ) -> tuple[
221
+ torch.Tensor | Loss | tuple[torch.Tensor, Loss] | None,
222
+ ActivationCache | dict[str, torch.Tensor],
223
+ ]:
224
+ """Wrapper around 'run_with_cache'.
225
+
226
+ Attaches given SAEs before running the model with cache and then removes them.
227
+ By default, will reset all SAEs to original state after.
228
+
229
+ Args:
230
+ *model_args: Positional arguments for the model forward pass
231
+ saes: The SAEs to be attached for this forward pass
232
+ reset_saes_end: If True, all SAEs added during this run are removed
233
+ at the end, and previously attached SAEs are restored to their
234
+ original state. Default is True.
235
+ use_error_term: If provided, will set the use_error_term attribute
236
+ of all SAEs attached during this run to this value. Defaults to None.
237
+ return_cache_object: If True, returns an ActivationCache object with
238
+ useful methods, otherwise returns a dictionary of activations.
239
+ remove_batch_dim: Whether to remove the batch dimension
240
+ (only works for batch_size==1). Defaults to False.
241
+ **kwargs: Keyword arguments for the model forward pass
242
+ """
243
+ with self.saes(
244
+ saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
245
+ ):
246
+ return self.run_with_cache(
247
+ *model_args,
248
+ return_cache_object=return_cache_object, # type: ignore[arg-type]
249
+ remove_batch_dim=remove_batch_dim,
250
+ **kwargs,
251
+ ) # type: ignore[return-value]
252
+
253
+ def run_with_hooks_with_saes(
254
+ self,
255
+ *model_args: Any,
256
+ saes: SAE[Any] | list[SAE[Any]] = [],
257
+ reset_saes_end: bool = True,
258
+ fwd_hooks: list[tuple[str | Callable[..., Any], Callable[..., Any]]] = [],
259
+ bwd_hooks: list[tuple[str | Callable[..., Any], Callable[..., Any]]] = [],
260
+ reset_hooks_end: bool = True,
261
+ clear_contexts: bool = False,
262
+ **model_kwargs: Any,
263
+ ) -> Any:
264
+ """Wrapper around 'run_with_hooks'.
265
+
266
+ Attaches the given SAEs to the model before running the model with hooks
267
+ and then removes them. By default, will reset all SAEs to original state after.
268
+
269
+ Args:
270
+ *model_args: Positional arguments for the model forward pass
271
+ saes: The SAEs to be attached for this forward pass
272
+ reset_saes_end: If True, all SAEs added during this run are removed
273
+ at the end, and previously attached SAEs are restored to their
274
+ original state. Default is True.
275
+ fwd_hooks: List of forward hooks to apply
276
+ bwd_hooks: List of backward hooks to apply
277
+ reset_hooks_end: Whether to reset the hooks at the end of the forward
278
+ pass. Default is True.
279
+ clear_contexts: Whether to clear the contexts at the end of the forward
280
+ pass. Default is False.
281
+ **model_kwargs: Keyword arguments for the model forward pass
282
+ """
283
+ with self.saes(saes=saes, reset_saes_end=reset_saes_end):
284
+ return self.run_with_hooks(
285
+ *model_args,
286
+ fwd_hooks=fwd_hooks,
287
+ bwd_hooks=bwd_hooks,
288
+ reset_hooks_end=reset_hooks_end,
289
+ clear_contexts=clear_contexts,
290
+ **model_kwargs,
291
+ )
292
+
293
+ @contextmanager
294
+ def saes(
295
+ self,
296
+ saes: SAE[Any] | list[SAE[Any]] = [],
297
+ reset_saes_end: bool = True,
298
+ use_error_term: bool | None = None,
299
+ ): # type: ignore[no-untyped-def]
300
+ """A context manager for adding temporary SAEs to the model.
301
+
302
+ By default will keep track of previously attached SAEs, and restore them
303
+ when the context manager exits.
304
+
305
+ Args:
306
+ saes: SAEs to be attached.
307
+ reset_saes_end: If True, removes all SAEs added by this context manager
308
+ when the context manager exits, returning previously attached SAEs
309
+ to their original state.
310
+ use_error_term: If provided, will set the use_error_term attribute of
311
+ all SAEs attached during this run to this value. Defaults to None.
312
+ """
313
+ act_names_to_reset: list[str] = []
314
+ prev_saes: list[SAE[Any] | None] = []
315
+ if isinstance(saes, SAE):
316
+ saes = [saes]
317
+ try:
318
+ for sae in saes:
319
+ act_names_to_reset.append(sae.cfg.metadata.hook_name)
320
+ prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
321
+ prev_saes.append(prev_sae)
322
+ self.add_sae(sae, use_error_term=use_error_term)
323
+ yield self
324
+ finally:
325
+ if reset_saes_end:
326
+ self.reset_saes(act_names_to_reset, prev_saes)
327
+
328
+ @property
329
+ def hook_dict(self) -> dict[str, HookPoint]:
330
+ """Return combined hook registry including SAE internal hooks.
331
+
332
+ When SAEs are attached, they replace HookPoint entries in the registry.
333
+ This property returns both the base hooks and any internal hooks from
334
+ attached SAEs (like hook_sae_acts_post, hook_sae_input, etc.) with
335
+ their full path names.
336
+ """
337
+ hooks: dict[str, HookPoint] = {}
338
+
339
+ for name, hook_or_sae in self._hook_registry.items():
340
+ if isinstance(hook_or_sae, SAE):
341
+ # Include SAE's internal hooks with full path names
342
+ for sae_hook_name, sae_hook in hook_or_sae.hook_dict.items():
343
+ full_name = f"{name}.{sae_hook_name}"
344
+ hooks[full_name] = sae_hook
345
+ else:
346
+ hooks[name] = hook_or_sae
347
+
348
+ return hooks
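
A hedged end-to-end sketch of the class added above. The model name, SAE release, and SAE id are placeholders, and the exact `SAE.from_pretrained` signature and return value are assumed rather than shown in this diff; treat it as an outline, not a verified recipe:

    import torch

    from sae_lens.analysis.sae_transformer_bridge import SAETransformerBridge
    from sae_lens.saes.sae import SAE

    # Placeholder identifiers -- substitute a model and SAE you actually use.
    model = SAETransformerBridge.boot_transformers("gpt2")
    sae = SAE.from_pretrained(  # signature/return assumed; check the sae_lens docs
        release="gpt2-small-res-jb",
        sae_id="blocks.8.hook_resid_pre",
        device="cpu",
    )

    prompt = "The quick brown fox"  # string input assumed, as with HookedTransformer
    with torch.no_grad():
        # Splice the SAE in for one forward pass; reset_saes_end=True (the default)
        # removes it again afterwards.
        logits = model.run_with_saes(prompt, saes=[sae])

        # Same idea, but also capture activations, including the SAE's own hooks
        # (exposed with full path names via the hook_dict property).
        logits, cache = model.run_with_cache_with_saes(prompt, saes=[sae])
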
sae_lens/config.py CHANGED
@@ -82,6 +82,7 @@ class LoggingConfig:
82
82
  log_to_wandb: bool = True
83
83
  log_activations_store_to_wandb: bool = False
84
84
  log_optimizer_state_to_wandb: bool = False
85
+ log_weights_to_wandb: bool = True
85
86
  wandb_project: str = "sae_lens_training"
86
87
  wandb_id: str | None = None
87
88
  run_name: str | None = None
@@ -107,7 +108,8 @@ class LoggingConfig:
107
108
  type="model",
108
109
  metadata=dict(trainer.cfg.__dict__),
109
110
  )
110
- model_artifact.add_file(str(weights_path))
111
+ if self.log_weights_to_wandb:
112
+ model_artifact.add_file(str(weights_path))
111
113
  model_artifact.add_file(str(cfg_path))
112
114
  wandb.log_artifact(model_artifact, aliases=wandb_aliases)
113
115
 
@@ -557,6 +559,12 @@ class CacheActivationsRunnerConfig:
557
559
  context_size=self.context_size,
558
560
  )
559
561
 
562
+ if self.context_size > self.training_tokens:
563
+ raise ValueError(
564
+ f"context_size ({self.context_size}) is greater than training_tokens "
565
+ f"({self.training_tokens}). Please reduce context_size or increase training_tokens."
566
+ )
567
+
560
568
  if self.new_cached_activations_path is None:
561
569
  self.new_cached_activations_path = _default_cached_activations_path( # type: ignore
562
570
  self.dataset_path, self.model_name, self.hook_name, None
sae_lens/evals.py CHANGED
@@ -335,7 +335,7 @@ def get_downstream_reconstruction_metrics(
335
335
 
336
336
  batch_iter = range(n_batches)
337
337
  if verbose:
338
- batch_iter = tqdm(batch_iter, desc="Reconstruction Batches")
338
+ batch_iter = tqdm(batch_iter, desc="Reconstruction Batches", leave=False)
339
339
 
340
340
  for _ in batch_iter:
341
341
  batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
@@ -430,7 +430,7 @@ def get_sparsity_and_variance_metrics(
430
430
 
431
431
  batch_iter = range(n_batches)
432
432
  if verbose:
433
- batch_iter = tqdm(batch_iter, desc="Sparsity and Variance Batches")
433
+ batch_iter = tqdm(batch_iter, desc="Sparsity and Variance Batches", leave=False)
434
434
 
435
435
  for _ in batch_iter:
436
436
  batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
sae_lens/loading/pretrained_sae_loaders.py CHANGED
@@ -575,6 +575,8 @@ def _infer_gemma_3_raw_cfg_dict(repo_id: str, folder_name: str) -> dict[str, Any
575
575
  "model_name": model_name,
576
576
  "hf_hook_point_in": hf_hook_point_in,
577
577
  }
578
+ if "transcoder" in folder_name or "clt" in folder_name:
579
+ cfg["affine_connection"] = "affine" in folder_name
578
580
  if hf_hook_point_out is not None:
579
581
  cfg["hf_hook_point_out"] = hf_hook_point_out
580
582
 
@@ -614,11 +616,11 @@ def get_gemma_3_config_from_hf(
614
616
  if "resid_post" in folder_name:
615
617
  hook_name = f"blocks.{layer}.hook_resid_post"
616
618
  elif "attn_out" in folder_name:
617
- hook_name = f"blocks.{layer}.hook_attn_out"
619
+ hook_name = f"blocks.{layer}.attn.hook_z"
618
620
  elif "mlp_out" in folder_name:
619
621
  hook_name = f"blocks.{layer}.hook_mlp_out"
620
622
  elif "transcoder" in folder_name or "clt" in folder_name:
621
- hook_name = f"blocks.{layer}.ln2.hook_normalized"
623
+ hook_name = f"blocks.{layer}.hook_mlp_in"
622
624
  hook_name_out = f"blocks.{layer}.hook_mlp_out"
623
625
  else:
624
626
  raise ValueError("Hook name not found in folder_name.")
@@ -643,7 +645,11 @@ def get_gemma_3_config_from_hf(
643
645
 
644
646
  architecture = "jumprelu"
645
647
  if "transcoder" in folder_name or "clt" in folder_name:
646
- architecture = "jumprelu_skip_transcoder"
648
+ architecture = (
649
+ "jumprelu_skip_transcoder"
650
+ if raw_cfg_dict.get("affine_connection", False)
651
+ else "jumprelu_transcoder"
652
+ )
647
653
  d_out = shapes_dict["w_dec"][-1]
648
654
 
649
655
  cfg = {
@@ -660,7 +666,8 @@ def get_gemma_3_config_from_hf(
660
666
  "dataset_path": "monology/pile-uncopyrighted",
661
667
  "context_size": 1024,
662
668
  "apply_b_dec_to_input": False,
663
- "normalize_activations": None,
669
+ "normalize_activations": "none",
670
+ "reshape_activations": "none",
664
671
  "hf_hook_name": raw_cfg_dict.get("hf_hook_point_in"),
665
672
  }
666
673
  if hook_name_out is not None:
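
The hunks above change how Gemma Scope 2 folder names map to hook points and architectures: attention SAEs now target `blocks.{layer}.attn.hook_z`, transcoders/CLTs read from `blocks.{layer}.hook_mlp_in` and write to `blocks.{layer}.hook_mlp_out`, and `jumprelu_skip_transcoder` is only chosen when the folder name signals an affine connection. A standalone restatement of that selection logic, for illustration only (not the library function itself):

    def gemma3_hooks_and_architecture(folder_name: str, layer: int) -> dict[str, str]:
        """Mirror the folder-name mapping applied by the updated loader."""
        cfg: dict[str, str] = {"architecture": "jumprelu"}
        if "resid_post" in folder_name:
            cfg["hook_name"] = f"blocks.{layer}.hook_resid_post"
        elif "attn_out" in folder_name:
            cfg["hook_name"] = f"blocks.{layer}.attn.hook_z"
        elif "mlp_out" in folder_name:
            cfg["hook_name"] = f"blocks.{layer}.hook_mlp_out"
        elif "transcoder" in folder_name or "clt" in folder_name:
            cfg["hook_name"] = f"blocks.{layer}.hook_mlp_in"
            cfg["hook_name_out"] = f"blocks.{layer}.hook_mlp_out"
            cfg["architecture"] = (
                "jumprelu_skip_transcoder"
                if "affine" in folder_name
                else "jumprelu_transcoder"
            )
        else:
            raise ValueError("Hook name not found in folder_name.")
        return cfg

    print(gemma3_hooks_and_architecture("transcoder_affine/layer_9_width_16k", layer=9))
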
sae_lens/loading/pretrained_saes_directory.py CHANGED
@@ -57,28 +57,6 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
57
57
  return directory
58
58
 
59
59
 
60
- def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
61
- """
62
- Retrieve the norm_scaling_factor for a specific SAE if it exists.
63
-
64
- Args:
65
- release (str): The release name of the SAE.
66
- sae_id (str): The ID of the specific SAE.
67
-
68
- Returns:
69
- float | None: The norm_scaling_factor if it exists, None otherwise.
70
- """
71
- package = "sae_lens"
72
- yaml_file = files(package).joinpath("pretrained_saes.yaml")
73
- with yaml_file.open("r") as file:
74
- data = yaml.safe_load(file)
75
- if release in data:
76
- for sae_info in data[release]["saes"]:
77
- if sae_info["id"] == sae_id:
78
- return sae_info.get("norm_scaling_factor")
79
- return None
80
-
81
-
82
60
  def get_repo_id_and_folder_name(release: str, sae_id: str) -> tuple[str, str]:
83
61
  saes_directory = get_pretrained_saes_directory()
84
62
  sae_info = saes_directory.get(release, None)
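
With `get_norm_scaling_factor` removed, the directory module is reduced to lookups like the one shown in the surviving context. A small usage sketch; the release key is taken from the YAML entries later in this diff, and `saes_map` is the field referenced elsewhere in the changed `sae.py`:

    from sae_lens.loading.pretrained_saes_directory import (
        get_pretrained_saes_directory,
        get_repo_id_and_folder_name,
    )

    directory = get_pretrained_saes_directory()
    release = "gemma-scope-2-4b-it-res"  # any key of `directory` works
    lookup = directory[release]
    print(list(lookup.saes_map)[:5])  # first few SAE ids in this release

    sae_id = next(iter(lookup.saes_map))
    repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
    print(repo_id, folder_name)
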
sae_lens/pretrained_saes.yaml CHANGED
@@ -4148,6 +4148,7 @@ gemma-scope-2-4b-it-res:
4148
4148
  - id: layer_17_width_16k_l0_medium
4149
4149
  path: resid_post/layer_17_width_16k_l0_medium
4150
4150
  l0: 60
4151
+ neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-16k
4151
4152
  - id: layer_17_width_16k_l0_small
4152
4153
  path: resid_post/layer_17_width_16k_l0_small
4153
4154
  l0: 20
@@ -4166,6 +4167,7 @@ gemma-scope-2-4b-it-res:
4166
4167
  - id: layer_17_width_262k_l0_medium
4167
4168
  path: resid_post/layer_17_width_262k_l0_medium
4168
4169
  l0: 60
4170
+ neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-262k
4169
4171
  - id: layer_17_width_262k_l0_medium_seed_1
4170
4172
  path: resid_post/layer_17_width_262k_l0_medium_seed_1
4171
4173
  l0: 60
@@ -4178,6 +4180,7 @@ gemma-scope-2-4b-it-res:
4178
4180
  - id: layer_17_width_65k_l0_medium
4179
4181
  path: resid_post/layer_17_width_65k_l0_medium
4180
4182
  l0: 60
4183
+ neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-65k
4181
4184
  - id: layer_17_width_65k_l0_small
4182
4185
  path: resid_post/layer_17_width_65k_l0_small
4183
4186
  l0: 20
@@ -4187,6 +4190,7 @@ gemma-scope-2-4b-it-res:
4187
4190
  - id: layer_22_width_16k_l0_medium
4188
4191
  path: resid_post/layer_22_width_16k_l0_medium
4189
4192
  l0: 60
4193
+ neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-16k
4190
4194
  - id: layer_22_width_16k_l0_small
4191
4195
  path: resid_post/layer_22_width_16k_l0_small
4192
4196
  l0: 20
@@ -4205,6 +4209,7 @@ gemma-scope-2-4b-it-res:
4205
4209
  - id: layer_22_width_262k_l0_medium
4206
4210
  path: resid_post/layer_22_width_262k_l0_medium
4207
4211
  l0: 60
4212
+ neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-262k
4208
4213
  - id: layer_22_width_262k_l0_medium_seed_1
4209
4214
  path: resid_post/layer_22_width_262k_l0_medium_seed_1
4210
4215
  l0: 60
@@ -4217,6 +4222,7 @@ gemma-scope-2-4b-it-res:
4217
4222
  - id: layer_22_width_65k_l0_medium
4218
4223
  path: resid_post/layer_22_width_65k_l0_medium
4219
4224
  l0: 60
4225
+ neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-65k
4220
4226
  - id: layer_22_width_65k_l0_small
4221
4227
  path: resid_post/layer_22_width_65k_l0_small
4222
4228
  l0: 20
@@ -4226,6 +4232,7 @@ gemma-scope-2-4b-it-res:
4226
4232
  - id: layer_29_width_16k_l0_medium
4227
4233
  path: resid_post/layer_29_width_16k_l0_medium
4228
4234
  l0: 60
4235
+ neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-16k
4229
4236
  - id: layer_29_width_16k_l0_small
4230
4237
  path: resid_post/layer_29_width_16k_l0_small
4231
4238
  l0: 20
@@ -4244,6 +4251,7 @@ gemma-scope-2-4b-it-res:
4244
4251
  - id: layer_29_width_262k_l0_medium
4245
4252
  path: resid_post/layer_29_width_262k_l0_medium
4246
4253
  l0: 60
4254
+ neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-262k
4247
4255
  - id: layer_29_width_262k_l0_medium_seed_1
4248
4256
  path: resid_post/layer_29_width_262k_l0_medium_seed_1
4249
4257
  l0: 60
@@ -4256,6 +4264,7 @@ gemma-scope-2-4b-it-res:
4256
4264
  - id: layer_29_width_65k_l0_medium
4257
4265
  path: resid_post/layer_29_width_65k_l0_medium
4258
4266
  l0: 60
4267
+ neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-65k
4259
4268
  - id: layer_29_width_65k_l0_small
4260
4269
  path: resid_post/layer_29_width_65k_l0_small
4261
4270
  l0: 20
@@ -4265,6 +4274,7 @@ gemma-scope-2-4b-it-res:
4265
4274
  - id: layer_9_width_16k_l0_medium
4266
4275
  path: resid_post/layer_9_width_16k_l0_medium
4267
4276
  l0: 53
4277
+ neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-16k
4268
4278
  - id: layer_9_width_16k_l0_small
4269
4279
  path: resid_post/layer_9_width_16k_l0_small
4270
4280
  l0: 17
@@ -4283,6 +4293,7 @@ gemma-scope-2-4b-it-res:
4283
4293
  - id: layer_9_width_262k_l0_medium
4284
4294
  path: resid_post/layer_9_width_262k_l0_medium
4285
4295
  l0: 53
4296
+ neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-262k
4286
4297
  - id: layer_9_width_262k_l0_medium_seed_1
4287
4298
  path: resid_post/layer_9_width_262k_l0_medium_seed_1
4288
4299
  l0: 53
@@ -4295,6 +4306,7 @@ gemma-scope-2-4b-it-res:
4295
4306
  - id: layer_9_width_65k_l0_medium
4296
4307
  path: resid_post/layer_9_width_65k_l0_medium
4297
4308
  l0: 53
4309
+ neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-65k
4298
4310
  - id: layer_9_width_65k_l0_small
4299
4311
  path: resid_post/layer_9_width_65k_l0_small
4300
4312
  l0: 17
@@ -14491,6 +14503,7 @@ gemma-scope-2-270m-it-res:
14491
14503
  - id: layer_12_width_16k_l0_medium
14492
14504
  path: resid_post/layer_12_width_16k_l0_medium
14493
14505
  l0: 60
14506
+ neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-16k
14494
14507
  - id: layer_12_width_16k_l0_small
14495
14508
  path: resid_post/layer_12_width_16k_l0_small
14496
14509
  l0: 20
@@ -14509,6 +14522,7 @@ gemma-scope-2-270m-it-res:
14509
14522
  - id: layer_12_width_262k_l0_medium
14510
14523
  path: resid_post/layer_12_width_262k_l0_medium
14511
14524
  l0: 60
14525
+ neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-262k
14512
14526
  - id: layer_12_width_262k_l0_medium_seed_1
14513
14527
  path: resid_post/layer_12_width_262k_l0_medium_seed_1
14514
14528
  l0: 60
@@ -14521,6 +14535,7 @@ gemma-scope-2-270m-it-res:
14521
14535
  - id: layer_12_width_65k_l0_medium
14522
14536
  path: resid_post/layer_12_width_65k_l0_medium
14523
14537
  l0: 60
14538
+ neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-65k
14524
14539
  - id: layer_12_width_65k_l0_small
14525
14540
  path: resid_post/layer_12_width_65k_l0_small
14526
14541
  l0: 20
@@ -14530,6 +14545,7 @@ gemma-scope-2-270m-it-res:
14530
14545
  - id: layer_15_width_16k_l0_medium
14531
14546
  path: resid_post/layer_15_width_16k_l0_medium
14532
14547
  l0: 60
14548
+ neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-16k
14533
14549
  - id: layer_15_width_16k_l0_small
14534
14550
  path: resid_post/layer_15_width_16k_l0_small
14535
14551
  l0: 20
@@ -14548,6 +14564,7 @@ gemma-scope-2-270m-it-res:
14548
14564
  - id: layer_15_width_262k_l0_medium
14549
14565
  path: resid_post/layer_15_width_262k_l0_medium
14550
14566
  l0: 60
14567
+ neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-262k
14551
14568
  - id: layer_15_width_262k_l0_medium_seed_1
14552
14569
  path: resid_post/layer_15_width_262k_l0_medium_seed_1
14553
14570
  l0: 60
@@ -14560,6 +14577,7 @@ gemma-scope-2-270m-it-res:
14560
14577
  - id: layer_15_width_65k_l0_medium
14561
14578
  path: resid_post/layer_15_width_65k_l0_medium
14562
14579
  l0: 60
14580
+ neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-65k
14563
14581
  - id: layer_15_width_65k_l0_small
14564
14582
  path: resid_post/layer_15_width_65k_l0_small
14565
14583
  l0: 20
@@ -14569,6 +14587,7 @@ gemma-scope-2-270m-it-res:
14569
14587
  - id: layer_5_width_16k_l0_medium
14570
14588
  path: resid_post/layer_5_width_16k_l0_medium
14571
14589
  l0: 55
14590
+ neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-16k
14572
14591
  - id: layer_5_width_16k_l0_small
14573
14592
  path: resid_post/layer_5_width_16k_l0_small
14574
14593
  l0: 18
@@ -14587,6 +14606,7 @@ gemma-scope-2-270m-it-res:
14587
14606
  - id: layer_5_width_262k_l0_medium
14588
14607
  path: resid_post/layer_5_width_262k_l0_medium
14589
14608
  l0: 55
14609
+ neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-262k
14590
14610
  - id: layer_5_width_262k_l0_medium_seed_1
14591
14611
  path: resid_post/layer_5_width_262k_l0_medium_seed_1
14592
14612
  l0: 55
@@ -14599,6 +14619,7 @@ gemma-scope-2-270m-it-res:
14599
14619
  - id: layer_5_width_65k_l0_medium
14600
14620
  path: resid_post/layer_5_width_65k_l0_medium
14601
14621
  l0: 55
14622
+ neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-65k
14602
14623
  - id: layer_5_width_65k_l0_small
14603
14624
  path: resid_post/layer_5_width_65k_l0_small
14604
14625
  l0: 18
@@ -14608,6 +14629,7 @@ gemma-scope-2-270m-it-res:
14608
14629
  - id: layer_9_width_16k_l0_medium
14609
14630
  path: resid_post/layer_9_width_16k_l0_medium
14610
14631
  l0: 60
14632
+ neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-16k
14611
14633
  - id: layer_9_width_16k_l0_small
14612
14634
  path: resid_post/layer_9_width_16k_l0_small
14613
14635
  l0: 20
@@ -14626,6 +14648,7 @@ gemma-scope-2-270m-it-res:
14626
14648
  - id: layer_9_width_262k_l0_medium
14627
14649
  path: resid_post/layer_9_width_262k_l0_medium
14628
14650
  l0: 60
14651
+ neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-262k
14629
14652
  - id: layer_9_width_262k_l0_medium_seed_1
14630
14653
  path: resid_post/layer_9_width_262k_l0_medium_seed_1
14631
14654
  l0: 60
@@ -14638,6 +14661,7 @@ gemma-scope-2-270m-it-res:
14638
14661
  - id: layer_9_width_65k_l0_medium
14639
14662
  path: resid_post/layer_9_width_65k_l0_medium
14640
14663
  l0: 60
14664
+ neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-65k
14641
14665
  - id: layer_9_width_65k_l0_small
14642
14666
  path: resid_post/layer_9_width_65k_l0_small
14643
14667
  l0: 20
@@ -18727,6 +18751,7 @@ gemma-scope-2-1b-it-res:
18727
18751
  - id: layer_13_width_16k_l0_medium
18728
18752
  path: resid_post/layer_13_width_16k_l0_medium
18729
18753
  l0: 60
18754
+ neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-16k
18730
18755
  - id: layer_13_width_16k_l0_small
18731
18756
  path: resid_post/layer_13_width_16k_l0_small
18732
18757
  l0: 20
@@ -18745,6 +18770,7 @@ gemma-scope-2-1b-it-res:
18745
18770
  - id: layer_13_width_262k_l0_medium
18746
18771
  path: resid_post/layer_13_width_262k_l0_medium
18747
18772
  l0: 60
18773
+ neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-262k
18748
18774
  - id: layer_13_width_262k_l0_medium_seed_1
18749
18775
  path: resid_post/layer_13_width_262k_l0_medium_seed_1
18750
18776
  l0: 60
@@ -18757,6 +18783,7 @@ gemma-scope-2-1b-it-res:
18757
18783
  - id: layer_13_width_65k_l0_medium
18758
18784
  path: resid_post/layer_13_width_65k_l0_medium
18759
18785
  l0: 60
18786
+ neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-65k
18760
18787
  - id: layer_13_width_65k_l0_small
18761
18788
  path: resid_post/layer_13_width_65k_l0_small
18762
18789
  l0: 20
@@ -18766,6 +18793,7 @@ gemma-scope-2-1b-it-res:
18766
18793
  - id: layer_17_width_16k_l0_medium
18767
18794
  path: resid_post/layer_17_width_16k_l0_medium
18768
18795
  l0: 60
18796
+ neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-16k
18769
18797
  - id: layer_17_width_16k_l0_small
18770
18798
  path: resid_post/layer_17_width_16k_l0_small
18771
18799
  l0: 20
@@ -18784,6 +18812,7 @@ gemma-scope-2-1b-it-res:
18784
18812
  - id: layer_17_width_262k_l0_medium
18785
18813
  path: resid_post/layer_17_width_262k_l0_medium
18786
18814
  l0: 60
18815
+ neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-262k
18787
18816
  - id: layer_17_width_262k_l0_medium_seed_1
18788
18817
  path: resid_post/layer_17_width_262k_l0_medium_seed_1
18789
18818
  l0: 60
@@ -18796,6 +18825,7 @@ gemma-scope-2-1b-it-res:
18796
18825
  - id: layer_17_width_65k_l0_medium
18797
18826
  path: resid_post/layer_17_width_65k_l0_medium
18798
18827
  l0: 60
18828
+ neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-65k
18799
18829
  - id: layer_17_width_65k_l0_small
18800
18830
  path: resid_post/layer_17_width_65k_l0_small
18801
18831
  l0: 20
@@ -18805,6 +18835,7 @@ gemma-scope-2-1b-it-res:
18805
18835
  - id: layer_22_width_16k_l0_medium
18806
18836
  path: resid_post/layer_22_width_16k_l0_medium
18807
18837
  l0: 60
18838
+ neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-16k
18808
18839
  - id: layer_22_width_16k_l0_small
18809
18840
  path: resid_post/layer_22_width_16k_l0_small
18810
18841
  l0: 20
@@ -18823,6 +18854,7 @@ gemma-scope-2-1b-it-res:
18823
18854
  - id: layer_22_width_262k_l0_medium
18824
18855
  path: resid_post/layer_22_width_262k_l0_medium
18825
18856
  l0: 60
18857
+ neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-262k
18826
18858
  - id: layer_22_width_262k_l0_medium_seed_1
18827
18859
  path: resid_post/layer_22_width_262k_l0_medium_seed_1
18828
18860
  l0: 60
@@ -18835,6 +18867,7 @@ gemma-scope-2-1b-it-res:
18835
18867
  - id: layer_22_width_65k_l0_medium
18836
18868
  path: resid_post/layer_22_width_65k_l0_medium
18837
18869
  l0: 60
18870
+ neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-65k
18838
18871
  - id: layer_22_width_65k_l0_small
18839
18872
  path: resid_post/layer_22_width_65k_l0_small
18840
18873
  l0: 20
@@ -18844,6 +18877,7 @@ gemma-scope-2-1b-it-res:
18844
18877
  - id: layer_7_width_16k_l0_medium
18845
18878
  path: resid_post/layer_7_width_16k_l0_medium
18846
18879
  l0: 54
18880
+ neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-16k
18847
18881
  - id: layer_7_width_16k_l0_small
18848
18882
  path: resid_post/layer_7_width_16k_l0_small
18849
18883
  l0: 18
@@ -18862,6 +18896,7 @@ gemma-scope-2-1b-it-res:
18862
18896
  - id: layer_7_width_262k_l0_medium
18863
18897
  path: resid_post/layer_7_width_262k_l0_medium
18864
18898
  l0: 54
18899
+ neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-262k
18865
18900
  - id: layer_7_width_262k_l0_medium_seed_1
18866
18901
  path: resid_post/layer_7_width_262k_l0_medium_seed_1
18867
18902
  l0: 54
@@ -18874,6 +18909,7 @@ gemma-scope-2-1b-it-res:
18874
18909
  - id: layer_7_width_65k_l0_medium
18875
18910
  path: resid_post/layer_7_width_65k_l0_medium
18876
18911
  l0: 54
18912
+ neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-65k
18877
18913
  - id: layer_7_width_65k_l0_small
18878
18914
  path: resid_post/layer_7_width_65k_l0_small
18879
18915
  l0: 18
sae_lens/saes/sae.py CHANGED
@@ -45,7 +45,6 @@ from sae_lens.loading.pretrained_sae_loaders import (
45
45
  )
46
46
  from sae_lens.loading.pretrained_saes_directory import (
47
47
  get_config_overrides,
48
- get_norm_scaling_factor,
49
48
  get_pretrained_saes_directory,
50
49
  get_releases_for_repo_id,
51
50
  get_repo_id_and_folder_name,
@@ -638,24 +637,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
638
637
  stacklevel=2,
639
638
  )
640
639
  elif sae_id not in sae_directory[release].saes_map:
641
- # Handle special cases like Gemma Scope
642
- if (
643
- "gemma-scope" in release
644
- and "canonical" not in release
645
- and f"{release}-canonical" in sae_directory
646
- ):
647
- canonical_ids = list(
648
- sae_directory[release + "-canonical"].saes_map.keys()
649
- )
650
- # Shorten the lengthy string of valid IDs
651
- if len(canonical_ids) > 5:
652
- str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
653
- else:
654
- str_canonical_ids = str(canonical_ids)
655
- value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
656
- else:
657
- value_suffix = ""
658
-
659
640
  valid_ids = list(sae_directory[release].saes_map.keys())
660
641
  # Shorten the lengthy string of valid IDs
661
642
  if len(valid_ids) > 5:
@@ -665,7 +646,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
665
646
 
666
647
  raise ValueError(
667
648
  f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
668
- + value_suffix
669
649
  )
670
650
 
671
651
  conversion_loader = (
@@ -702,17 +682,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
702
682
  sae.process_state_dict_for_loading(state_dict)
703
683
  sae.load_state_dict(state_dict, assign=True)
704
684
 
705
- # Apply normalization if needed
706
- if cfg_dict.get("normalize_activations") == "expected_average_only_in":
707
- norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
708
- if norm_scaling_factor is not None:
709
- sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
710
- cfg_dict["normalize_activations"] = "none"
711
- else:
712
- warnings.warn(
713
- f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
714
- )
715
-
716
685
  # the loaders should already handle the dtype / device conversion
717
686
  # but this is a fallback to guarantee the SAE is on the correct device and dtype
718
687
  return (
sae_lens/saes/temporal_sae.py CHANGED
@@ -4,7 +4,7 @@ TemporalSAE decomposes activations into:
4
4
  1. Predicted codes (from attention over context)
5
5
  2. Novel codes (sparse features of the residual)
6
6
 
7
- See: https://arxiv.org/abs/2410.04185
7
+ See: https://arxiv.org/pdf/2511.01836
8
8
  """
9
9
 
10
10
  import math
sae_lens/synthetic/__init__.py CHANGED
@@ -50,6 +50,13 @@ from sae_lens.synthetic.plotting import (
50
50
  find_best_feature_ordering_from_sae,
51
51
  plot_sae_feature_similarity,
52
52
  )
53
+ from sae_lens.synthetic.stats import (
54
+ CorrelationMatrixStats,
55
+ SuperpositionStats,
56
+ compute_correlation_matrix_stats,
57
+ compute_low_rank_correlation_matrix_stats,
58
+ compute_superposition_stats,
59
+ )
53
60
  from sae_lens.synthetic.training import (
54
61
  SyntheticActivationIterator,
55
62
  train_toy_sae,
@@ -80,6 +87,12 @@ __all__ = [
80
87
  "orthogonal_initializer",
81
88
  "FeatureDictionaryInitializer",
82
89
  "cosine_similarities",
90
+ # Statistics
91
+ "compute_correlation_matrix_stats",
92
+ "compute_low_rank_correlation_matrix_stats",
93
+ "compute_superposition_stats",
94
+ "CorrelationMatrixStats",
95
+ "SuperpositionStats",
83
96
  # Training utilities
84
97
  "SyntheticActivationIterator",
85
98
  "SyntheticDataEvalResult",
sae_lens/synthetic/correlation.py CHANGED
@@ -3,6 +3,7 @@ from typing import NamedTuple
3
3
 
4
4
  import torch
5
5
 
6
+ from sae_lens import logger
6
7
  from sae_lens.util import str_to_dtype
7
8
 
8
9
 
@@ -268,7 +269,7 @@ def generate_random_correlation_matrix(
268
269
  def generate_random_low_rank_correlation_matrix(
269
270
  num_features: int,
270
271
  rank: int,
271
- correlation_scale: float = 0.1,
272
+ correlation_scale: float = 0.075,
272
273
  seed: int | None = None,
273
274
  device: torch.device | str = "cpu",
274
275
  dtype: torch.dtype | str = torch.float32,
@@ -331,20 +332,17 @@ def generate_random_low_rank_correlation_matrix(
331
332
  factor_sq_sum = (factor**2).sum(dim=1)
332
333
  diag_term = 1 - factor_sq_sum
333
334
 
334
- # Ensure diagonal terms are at least _MIN_DIAG for numerical stability
335
- # If any diagonal term is too small, scale down the factor matrix
336
- if torch.any(diag_term < _MIN_DIAG):
337
- # Scale factor so max row norm squared is at most (1 - _MIN_DIAG)
338
- # This ensures all diagonal terms are >= _MIN_DIAG
339
- max_factor_contribution = 1 - _MIN_DIAG
340
- max_sq_sum = factor_sq_sum.max()
341
- scale = torch.sqrt(
342
- torch.tensor(max_factor_contribution, device=device, dtype=dtype)
343
- / max_sq_sum
335
+ # rescale only the offending rows so each diagonal term stays at least _MIN_DIAG
336
+ mask = diag_term < _MIN_DIAG
337
+ factor[mask, :] *= torch.sqrt((1 - _MIN_DIAG) / factor_sq_sum[mask].unsqueeze(1))
338
+ factor_sq_sum = (factor**2).sum(dim=1)
339
+ diag_term = 1 - factor_sq_sum
340
+
341
+ total_rescaled = mask.sum().item()
342
+ if total_rescaled > 0:
343
+ logger.warning(
344
+ f"{total_rescaled} / {num_features} rows were capped. Either reduce the rank or reduce the correlation_scale to avoid rescaling."
344
345
  )
345
- factor = factor * scale
346
- factor_sq_sum = (factor**2).sum(dim=1)
347
- diag_term = 1 - factor_sq_sum
348
346
 
349
347
  return LowRankCorrelationMatrix(
350
348
  correlation_factor=factor, correlation_diag=diag_term
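
The new path caps only the offending rows: each row i with 1 - ||F_i||^2 < _MIN_DIAG is shrunk so its squared norm becomes exactly 1 - _MIN_DIAG, while well-behaved rows are left untouched (the removed code instead rescaled the whole factor by the worst row). A standalone sketch of the per-row cap; the _MIN_DIAG value here is an assumed placeholder, not the library's constant:

    import torch

    _MIN_DIAG = 1e-2  # placeholder; sae-lens defines its own constant

    factor = 0.7 * torch.randn(6, 3)          # low-rank factor F, shape (n, rank)
    factor_sq_sum = (factor**2).sum(dim=1)    # ||F_i||^2 per row
    mask = (1 - factor_sq_sum) < _MIN_DIAG    # rows whose implied diagonal is too small

    # Shrink only the flagged rows so ||F_i||^2 == 1 - _MIN_DIAG afterwards.
    factor[mask] *= torch.sqrt((1 - _MIN_DIAG) / factor_sq_sum[mask]).unsqueeze(1)

    diag_term = 1 - (factor**2).sum(dim=1)
    assert torch.all(diag_term >= _MIN_DIAG - 1e-6)
    print(int(mask.sum()), "rows capped; min diagonal term:", float(diag_term.min()))
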
sae_lens/synthetic/stats.py ADDED
@@ -0,0 +1,205 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+ from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
6
+ from sae_lens.synthetic.feature_dictionary import FeatureDictionary
7
+
8
+
9
+ @dataclass
10
+ class CorrelationMatrixStats:
11
+ """Statistics computed from a correlation matrix."""
12
+
13
+ rms_correlation: float # Root mean square of off-diagonal correlations
14
+ mean_correlation: float # Mean of off-diagonal correlations (not absolute)
15
+ num_features: int
16
+
17
+
18
+ @torch.no_grad()
19
+ def compute_correlation_matrix_stats(
20
+ correlation_matrix: torch.Tensor,
21
+ ) -> CorrelationMatrixStats:
22
+ """Compute correlation statistics from a dense correlation matrix.
23
+
24
+ Args:
25
+ correlation_matrix: Dense correlation matrix of shape (n, n)
26
+
27
+ Returns:
28
+ CorrelationMatrixStats with correlation statistics
29
+ """
30
+ num_features = correlation_matrix.shape[0]
31
+
32
+ # Extract off-diagonal elements
33
+ mask = ~torch.eye(num_features, dtype=torch.bool, device=correlation_matrix.device)
34
+ off_diag = correlation_matrix[mask]
35
+
36
+ rms_correlation = (off_diag**2).mean().sqrt().item()
37
+ mean_correlation = off_diag.mean().item()
38
+
39
+ return CorrelationMatrixStats(
40
+ rms_correlation=rms_correlation,
41
+ mean_correlation=mean_correlation,
42
+ num_features=num_features,
43
+ )
44
+
45
+
46
+ @torch.no_grad()
47
+ def compute_low_rank_correlation_matrix_stats(
48
+ correlation_matrix: LowRankCorrelationMatrix,
49
+ ) -> CorrelationMatrixStats:
50
+ """Compute correlation statistics from a LowRankCorrelationMatrix.
51
+
52
+ The correlation matrix is represented as:
53
+ correlation = factor @ factor.T + diag(diag_term)
54
+
55
+ The off-diagonal elements are simply factor @ factor.T (the diagonal term
56
+ only affects the diagonal).
57
+
58
+ All statistics are computed efficiently in O(n*r²) time and O(r²) memory
59
+ without materializing the full n×n correlation matrix.
60
+
61
+ Args:
62
+ correlation_matrix: Low-rank correlation matrix
63
+
64
+ Returns:
65
+ CorrelationMatrixStats with correlation statistics
66
+ """
67
+
68
+ factor = correlation_matrix.correlation_factor
69
+ num_features = factor.shape[0]
70
+ num_off_diag = num_features * (num_features - 1)
71
+
72
+ # RMS correlation: uses ||F @ F.T||_F² = ||F.T @ F||_F²
73
+ # This avoids computing the (num_features, num_features) matrix
74
+ G = factor.T @ factor # (rank, rank) - small!
75
+ frobenius_sq = (G**2).sum()
76
+ row_norms_sq = (factor**2).sum(dim=1) # ||F[i]||² for each row
77
+ diag_sq_sum = (row_norms_sq**2).sum() # Σᵢ ||F[i]||⁴
78
+ off_diag_sq_sum = frobenius_sq - diag_sq_sum
79
+ rms_correlation = (off_diag_sq_sum / num_off_diag).sqrt().item()
80
+
81
+ # Mean correlation (not absolute): sum(C) = ||col_sums(F)||², trace(C) = Σ||F[i]||²
82
+ col_sums = factor.sum(dim=0) # (rank,)
83
+ sum_all = (col_sums**2).sum() # 1ᵀ C 1
84
+ trace_C = row_norms_sq.sum()
85
+ mean_correlation = ((sum_all - trace_C) / num_off_diag).item()
86
+
87
+ return CorrelationMatrixStats(
88
+ rms_correlation=rms_correlation,
89
+ mean_correlation=mean_correlation,
90
+ num_features=num_features,
91
+ )
92
+
93
+
94
+ @dataclass
95
+ class SuperpositionStats:
96
+ """Statistics measuring superposition in a feature dictionary."""
97
+
98
+ # Per-latent statistics: for each latent, max and percentile of |cos_sim| with others
99
+ max_abs_cos_sims: torch.Tensor # Shape: (num_features,)
100
+ percentile_abs_cos_sims: dict[int, torch.Tensor] # {percentile: (num_features,)}
101
+
102
+ # Summary statistics (means of the per-latent values)
103
+ mean_max_abs_cos_sim: float
104
+ mean_percentile_abs_cos_sim: dict[int, float]
105
+ mean_abs_cos_sim: float # Mean |cos_sim| across all pairs
106
+
107
+ # Metadata
108
+ num_features: int
109
+ hidden_dim: int
110
+
111
+
112
+ @torch.no_grad()
113
+ def compute_superposition_stats(
114
+ feature_dictionary: FeatureDictionary,
115
+ batch_size: int = 1024,
116
+ device: str | torch.device | None = None,
117
+ percentiles: list[int] | None = None,
118
+ ) -> SuperpositionStats:
119
+ """Compute superposition statistics for a feature dictionary.
120
+
121
+ Computes pairwise cosine similarities in batches to handle large dictionaries.
122
+
123
+ For each latent i, computes:
124
+
125
+ - max |cos_sim(i, j)| over all j != i
126
+ - kth percentile of |cos_sim(i, j)| over all j != i (for each k in percentiles)
127
+
128
+ Args:
129
+ feature_dictionary: FeatureDictionary containing the feature vectors
130
+ batch_size: Number of features to process per batch
131
+ device: Device for computation (defaults to feature dictionary's device)
132
+ percentiles: List of percentiles to compute per latent (default: [95, 99])
133
+
134
+ Returns:
135
+ SuperpositionStats with superposition metrics
136
+ """
137
+ if percentiles is None:
138
+ percentiles = [95, 99]
139
+
140
+ feature_vectors = feature_dictionary.feature_vectors
141
+ num_features, hidden_dim = feature_vectors.shape
142
+
143
+ if num_features < 2:
144
+ raise ValueError("Need at least 2 features to compute superposition stats")
145
+ if device is None:
146
+ device = feature_vectors.device
147
+
148
+ # Normalize features to unit norm for cosine similarity
149
+ features_normalized = feature_vectors.to(device).float()
150
+ norms = torch.linalg.norm(features_normalized, dim=1, keepdim=True)
151
+ features_normalized = features_normalized / norms.clamp(min=1e-8)
152
+
153
+ # Track per-latent statistics
154
+ max_abs_cos_sims = torch.zeros(num_features, device=device)
155
+ percentile_abs_cos_sims = {
156
+ p: torch.zeros(num_features, device=device) for p in percentiles
157
+ }
158
+ sum_abs_cos_sim = 0.0
159
+ n_pairs = 0
160
+
161
+ # Process in batches: for each batch of features, compute similarities with all others
162
+ for i in range(0, num_features, batch_size):
163
+ batch_end = min(i + batch_size, num_features)
164
+ batch = features_normalized[i:batch_end] # (batch_size, hidden_dim)
165
+
166
+ # Compute cosine similarities with all features: (batch_size, num_features)
167
+ cos_sims = batch @ features_normalized.T
168
+
169
+ # Absolute cosine similarities
170
+ abs_cos_sims = cos_sims.abs()
171
+
172
+ # Process each latent in the batch
173
+ for j, idx in enumerate(range(i, batch_end)):
174
+ # Get similarities with all other features (exclude self)
175
+ row = abs_cos_sims[j].clone()
176
+ row[idx] = 0.0 # Exclude self for max
177
+ max_abs_cos_sims[idx] = row.max()
178
+
179
+ # For percentiles, exclude self and compute
180
+ other_sims = torch.cat([abs_cos_sims[j, :idx], abs_cos_sims[j, idx + 1 :]])
181
+ for p in percentiles:
182
+ percentile_abs_cos_sims[p][idx] = torch.quantile(other_sims, p / 100.0)
183
+
184
+ # Sum for mean computation (only count pairs once - with features after this one)
185
+ sum_abs_cos_sim += abs_cos_sims[j, idx + 1 :].sum().item()
186
+ n_pairs += num_features - idx - 1
187
+
188
+ # Compute summary statistics
189
+ mean_max_abs_cos_sim = max_abs_cos_sims.mean().item()
190
+ mean_percentile_abs_cos_sim = {
191
+ p: percentile_abs_cos_sims[p].mean().item() for p in percentiles
192
+ }
193
+ mean_abs_cos_sim = sum_abs_cos_sim / n_pairs if n_pairs > 0 else 0.0
194
+
195
+ return SuperpositionStats(
196
+ max_abs_cos_sims=max_abs_cos_sims.cpu(),
197
+ percentile_abs_cos_sims={
198
+ p: v.cpu() for p, v in percentile_abs_cos_sims.items()
199
+ },
200
+ mean_max_abs_cos_sim=mean_max_abs_cos_sim,
201
+ mean_percentile_abs_cos_sim=mean_percentile_abs_cos_sim,
202
+ mean_abs_cos_sim=mean_abs_cos_sim,
203
+ num_features=num_features,
204
+ hidden_dim=hidden_dim,
205
+ )
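
A short sketch exercising the new statistics helpers on synthetic inputs. The dense path uses a hand-built 3x3 correlation matrix; the low-rank path builds a `LowRankCorrelationMatrix` directly, using the two constructor fields visible in the correlation.py hunk above:

    import torch

    from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
    from sae_lens.synthetic.stats import (
        compute_correlation_matrix_stats,
        compute_low_rank_correlation_matrix_stats,
    )

    # Dense path: off-diagonal RMS and mean of a small correlation matrix.
    dense = torch.tensor(
        [[1.0, 0.2, -0.1],
         [0.2, 1.0, 0.3],
         [-0.1, 0.3, 1.0]]
    )
    print(compute_correlation_matrix_stats(dense))

    # Low-rank path: C = F @ F.T + diag(1 - ||F_i||^2), never materialized as n x n.
    factor = 0.2 * torch.randn(1000, 4)
    low_rank = LowRankCorrelationMatrix(
        correlation_factor=factor,
        correlation_diag=1 - (factor**2).sum(dim=1),
    )
    print(compute_low_rank_correlation_matrix_stats(low_rank))
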
sae_lens/training/activation_scaler.py CHANGED
@@ -28,7 +28,9 @@ class ActivationScaler:
28
28
  ) -> float:
29
29
  norms_per_batch: list[float] = []
30
30
  for _ in tqdm(
31
- range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
31
+ range(n_batches_for_norm_estimate),
32
+ desc="Estimating norm scaling factor",
33
+ leave=False,
32
34
  ):
33
35
  acts = next(data_provider)
34
36
  norms_per_batch.append(acts.norm(dim=-1).mean().item())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.29.1
3
+ Version: 6.33.0
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -27,7 +27,7 @@ Requires-Dist: pyyaml (>=6.0.1,<7.0.0)
27
27
  Requires-Dist: safetensors (>=0.4.2,<1.0.0)
28
28
  Requires-Dist: simple-parsing (>=0.1.6,<0.2.0)
29
29
  Requires-Dist: tenacity (>=9.0.0)
30
- Requires-Dist: transformer-lens (>=2.16.1,<3.0.0)
30
+ Requires-Dist: transformer-lens (>=2.16.1)
31
31
  Requires-Dist: transformers (>=4.38.1,<5.0.0)
32
32
  Requires-Dist: typing-extensions (>=4.10.0,<5.0.0)
33
33
  Project-URL: Homepage, https://decoderesearch.github.io/SAELens
@@ -1,18 +1,20 @@
1
- sae_lens/__init__.py,sha256=emqKVNiJwD8YtYhtgHJyAT8YSX1QmruQYuG-J4CStC4,4788
2
- sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
1
+ sae_lens/__init__.py,sha256=gHaxlySzLskrAUg2oUZ3aOpnI3U_AVIHce-agGJL9rI,5168
2
+ sae_lens/analysis/__init__.py,sha256=FZExlMviNwWR7OGUSGRbd0l-yUDGSp80gglI_ivILrY,412
3
+ sae_lens/analysis/compat.py,sha256=cgE3nhFcJTcuhppxbL71VanJS7YqVEOefuneB5eOaPw,538
4
+ sae_lens/analysis/hooked_sae_transformer.py,sha256=LpnjxSAcItqqXA4SJyZuxY4Ki0UOuWV683wg9laYAsY,14050
4
5
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
6
+ sae_lens/analysis/sae_transformer_bridge.py,sha256=xpJRRcB0g47EOQcmNCwMyrJJsbqMsGxVViDrV6C3upU,14916
5
7
  sae_lens/cache_activations_runner.py,sha256=TjqNWIc46Nw09jHWFjzQzgzG5wdu_87Ahe-iFjI5_0Q,13117
6
- sae_lens/config.py,sha256=sseYcRMsAyopj8FICup1RGTXjFxzAithZ2OH7OpQV3Y,30839
8
+ sae_lens/config.py,sha256=V0BXV8rvpbm5YuVukow9FURPpdyE4HSflbdymAo0Ycg,31205
7
9
  sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
8
- sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
10
+ sae_lens/evals.py,sha256=nEZpUfEUN-plw6Mj9GEqm-cU_tb1qrIF9km9ktQ0vVU,39624
9
11
  sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
10
12
  sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
11
13
  sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- sae_lens/loading/pretrained_sae_loaders.py,sha256=hHMlew1u6zVlbzvS9S_SfUPnAG0_OAjjIcjoUTIUZrU,63657
13
- sae_lens/loading/pretrained_saes_directory.py,sha256=1at_aQbD8WFywchQCKuwfP-yvCq_Z2aUYrpKDnSN5Nc,4283
14
+ sae_lens/loading/pretrained_sae_loaders.py,sha256=kshvA0NivOc7B3sL19lHr_zrC_DDfW2T6YWb5j0hgAk,63930
15
+ sae_lens/loading/pretrained_saes_directory.py,sha256=lSnHl77IO5dd7iO21ynCzZNMrzuJAT8Za4W5THNq0qw,3554
14
16
  sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
15
- sae_lens/pretrained_saes.yaml,sha256=Nq43dTcFvDDONTuJ9Me_HQ5nHqr9BdbP5-ZJGXj0TAQ,1509932
17
+ sae_lens/pretrained_saes.yaml,sha256=IVBLLR8_XNllJ1O-kVv9ED4u0u44Yn8UOL9R-f8Idp4,1511936
16
18
  sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
19
  sae_lens/saes/__init__.py,sha256=SBqPaP6Gl5uPFwHlumAZATC4Wd26xKIYLAAAo4MSa5Q,2200
18
20
  sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
@@ -20,24 +22,25 @@ sae_lens/saes/gated_sae.py,sha256=V_2ZNlV4gRD-rX5JSx1xqY7idT8ChfdQ5yxWDdu_6hg,88
20
22
  sae_lens/saes/jumprelu_sae.py,sha256=miiF-xI_yXdV9EkKjwAbU9zSMsx9KtKCz5YdXEzkN8g,13313
21
23
  sae_lens/saes/matching_pursuit_sae.py,sha256=08_G9p1YMLnE5qZVCPp6gll-iG6nHRbMMASf4_bkFt8,13207
22
24
  sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
23
- sae_lens/saes/sae.py,sha256=xRmgiLuaFlDCv8SyLbL-5TwdrWHpNLqSGe8mC1L6WcI,40942
25
+ sae_lens/saes/sae.py,sha256=wkwqzNragj-1189cV52S3_XeRtEgBd2ZNwvL2EsKkWw,39429
24
26
  sae_lens/saes/standard_sae.py,sha256=_hldNZkFPAf9VGrxouR1-tN8T2OEk8IkWBcXoatrC1o,5749
25
- sae_lens/saes/temporal_sae.py,sha256=83Ap4mYGfdN3sKdPF8nKjhdXph3-7E2QuLobqJ_YuoM,13273
27
+ sae_lens/saes/temporal_sae.py,sha256=S44sPddVj2xujA02CC8gT1tG0in7c_CSAhspu9FHbaA,13273
26
28
  sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
27
29
  sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
28
- sae_lens/synthetic/__init__.py,sha256=MtTnGkTfHV2WjkIgs7zZyx10EK9U5fjOHXy69Aq3uKw,3095
30
+ sae_lens/synthetic/__init__.py,sha256=hRRA3xhEQUacGyFbJXkLVYg_8A1bbSYYWlVovb0g4KU,3503
29
31
  sae_lens/synthetic/activation_generator.py,sha256=8L9nwC4jFRv_wg3QN-n1sFwX8w1NqwJMysWaJ41lLlY,15197
30
- sae_lens/synthetic/correlation.py,sha256=tMTLo9fBfDpeXwqhyUgFqnTipj9x2W0t4oEtNxB7AG0,13256
32
+ sae_lens/synthetic/correlation.py,sha256=tD8J9abWfuFtGZrEbbFn4P8FeTcNKF2V5JhBLwDUmkg,13146
31
33
  sae_lens/synthetic/evals.py,sha256=Nhi314ZnRgLfhBj-3tm_zzI-pGyFTcwllDXbIpPFXeU,4584
32
34
  sae_lens/synthetic/feature_dictionary.py,sha256=Nd4xjSTxKMnKilZ3uYi8Gv5SS5D4bv4wHiSL1uGB69E,6933
33
35
  sae_lens/synthetic/firing_probabilities.py,sha256=yclz1pWl5gE1r8LAxFvzQS88Lxwk5-3r8BCX9HLVejA,3370
34
36
  sae_lens/synthetic/hierarchy.py,sha256=nm7nwnTswktVJeKUsRZ0hLOdXcFWGbxnA1b6lefHm-4,33592
35
37
  sae_lens/synthetic/initialization.py,sha256=orMGW-786wRDHIS2W7bEH0HmlVFQ4g2z4bnnwdv5w4s,1386
36
38
  sae_lens/synthetic/plotting.py,sha256=5lFrej1QOkGAcImFNo5-o-8mI_rUVqvEI57KzUQPPtQ,8208
39
+ sae_lens/synthetic/stats.py,sha256=BoDPKDx8pgFF5Ko_IaBRZTczm7-ANUIRjjF5W5Qh3Lk,7441
37
40
  sae_lens/synthetic/training.py,sha256=fHcX2cZ6nDupr71GX0Gk17f1NvQ0SKIVXIA6IuAb2dw,5692
38
41
  sae_lens/tokenization_and_batching.py,sha256=uoHtAs9z3XqG0Fh-iQVYVlrbyB_E3kFFhrKU30BosCo,5438
39
42
  sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
- sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
43
+ sae_lens/training/activation_scaler.py,sha256=SJZzIMX1TGdeN_wT_wqgx2ij6f4p5Dm5lWH6DGNSt5g,2011
41
44
  sae_lens/training/activations_store.py,sha256=kp4-6R4rTJUSt-g-Ifg5B1h7iIe7jZj-XQSKDvDpQMI,32187
42
45
  sae_lens/training/mixing_buffer.py,sha256=1Z-S2CcQXMWGxRZJFnXeZFxbZcALkO_fP6VO37XdJQQ,2519
43
46
  sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
@@ -46,7 +49,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
46
49
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
47
50
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
48
51
  sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
49
- sae_lens-6.29.1.dist-info/METADATA,sha256=0Pp1L3vNiUGzkMox_BdQR6B064tTHFgwAPGJz8FY8UM,6573
50
- sae_lens-6.29.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
51
- sae_lens-6.29.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
52
- sae_lens-6.29.1.dist-info/RECORD,,
52
+ sae_lens-6.33.0.dist-info/METADATA,sha256=X6XqngWTNEsfdaPPWXxtF8Kvdp8fAk8i68sfRtDb2xo,6566
53
+ sae_lens-6.33.0.dist-info/WHEEL,sha256=3ny-bZhpXrU6vSQ1UPG34FoxZBp3lVcvK0LkgUz6VLk,88
54
+ sae_lens-6.33.0.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
55
+ sae_lens-6.33.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 2.2.1
2
+ Generator: poetry-core 2.3.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any