sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. sae_lens/__init__.py +60 -7
  2. sae_lens/analysis/hooked_sae_transformer.py +12 -12
  3. sae_lens/analysis/neuronpedia_integration.py +16 -14
  4. sae_lens/cache_activations_runner.py +9 -7
  5. sae_lens/config.py +170 -258
  6. sae_lens/constants.py +21 -0
  7. sae_lens/evals.py +59 -44
  8. sae_lens/llm_sae_training_runner.py +377 -0
  9. sae_lens/load_model.py +52 -4
  10. sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
  11. sae_lens/registry.py +49 -0
  12. sae_lens/saes/__init__.py +48 -0
  13. sae_lens/saes/gated_sae.py +254 -0
  14. sae_lens/saes/jumprelu_sae.py +348 -0
  15. sae_lens/saes/sae.py +1076 -0
  16. sae_lens/saes/standard_sae.py +178 -0
  17. sae_lens/saes/topk_sae.py +300 -0
  18. sae_lens/training/activation_scaler.py +53 -0
  19. sae_lens/training/activations_store.py +103 -184
  20. sae_lens/training/mixing_buffer.py +56 -0
  21. sae_lens/training/optim.py +60 -36
  22. sae_lens/training/sae_trainer.py +155 -177
  23. sae_lens/training/types.py +5 -0
  24. sae_lens/training/upload_saes_to_huggingface.py +13 -7
  25. sae_lens/util.py +47 -0
  26. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
  27. sae_lens-6.0.0.dist-info/RECORD +37 -0
  28. sae_lens/sae.py +0 -747
  29. sae_lens/sae_training_runner.py +0 -251
  30. sae_lens/training/geometric_median.py +0 -101
  31. sae_lens/training/training_sae.py +0 -710
  32. sae_lens-5.11.0.dist-info/RECORD +0 -28
  33. /sae_lens/{toolkit → loading}/__init__.py +0 -0
  34. /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
  35. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
  36. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
sae_lens/load_model.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, Literal, cast
+from typing import Any, Callable, Literal, cast
 
 import torch
 from transformer_lens import HookedTransformer
@@ -77,6 +77,7 @@ class HookedProxyLM(HookedRootModule):
     # copied and modified from base HookedRootModule
     def setup(self):
         self.mod_dict = {}
+        self.named_modules_dict = {}
         self.hook_dict: dict[str, HookPoint] = {}
         for name, module in self.model.named_modules():
             if name == "":
@@ -89,14 +90,21 @@ class HookedProxyLM(HookedRootModule):
 
             self.hook_dict[name] = hook_point
             self.mod_dict[name] = hook_point
+            self.named_modules_dict[name] = module
+
+    def run_with_cache(self, *args: Any, **kwargs: Any):  # type: ignore
+        if "names_filter" in kwargs:
+            # hacky way to make sure that the names_filter is passed to our forward method
+            kwargs["_names_filter"] = kwargs["names_filter"]
+        return super().run_with_cache(*args, **kwargs)
 
     def forward(
         self,
         tokens: torch.Tensor,
         return_type: Literal["both", "logits"] = "logits",
         loss_per_token: bool = False,
-        # TODO: implement real support for stop_at_layer
        stop_at_layer: int | None = None,
+        _names_filter: list[str] | None = None,
         **kwargs: Any,
     ) -> Output | Loss:
         # This is just what's needed for evals, not everything that HookedTransformer has
@@ -107,8 +115,28 @@ class HookedProxyLM(HookedRootModule):
             raise NotImplementedError(
                 "Only return_type supported is 'both' or 'logits' to match what's in evals.py and ActivationsStore"
             )
-        output = self.model(tokens)
-        logits = _extract_logits_from_output(output)
+
+        stop_hooks = []
+        if stop_at_layer is not None and _names_filter is not None:
+            if return_type != "logits":
+                raise NotImplementedError(
+                    "stop_at_layer is not supported for return_type='both'"
+                )
+            stop_manager = StopManager(_names_filter)
+
+            for hook_name in _names_filter:
+                module = self.named_modules_dict[hook_name]
+                stop_fn = stop_manager.get_stop_hook_fn(hook_name)
+                stop_hooks.append(module.register_forward_hook(stop_fn))
+        try:
+            output = self.model(tokens)
+            logits = _extract_logits_from_output(output)
+        except StopForward:
+            # If we stop early, we don't care about the return output
+            return None  # type: ignore
+        finally:
+            for stop_hook in stop_hooks:
+                stop_hook.remove()
 
         if return_type == "logits":
             return logits
@@ -183,3 +211,23 @@ def get_hook_fn(hook_point: HookPoint):
         return output
 
     return hook_fn
+
+
+class StopForward(Exception):
+    pass
+
+
+class StopManager:
+    def __init__(self, hook_names: list[str]):
+        self.hook_names = hook_names
+        self.total_hook_names = len(set(hook_names))
+        self.called_hook_names = set()
+
+    def get_stop_hook_fn(self, hook_name: str) -> Callable[[Any, Any, Any], Any]:
+        def stop_hook_fn(module: Any, input: Any, output: Any) -> Any:  # noqa: ARG001
+            self.called_hook_names.add(hook_name)
+            if len(self.called_hook_names) == self.total_hook_names:
+                raise StopForward()
+            return output
+
+        return stop_hook_fn
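
The run_with_cache/StopManager additions above let HookedProxyLM abort a forward pass as soon as every requested hook point has produced its output. The same pattern can be sketched with plain PyTorch forward hooks; the toy model and run_until helper below are illustrative only, not part of sae-lens:

import torch
from torch import nn


class StopForward(Exception):
    """Raised once every watched module has produced its output."""


def run_until(model: nn.Module, watch: list[str], x: torch.Tensor) -> dict[str, torch.Tensor]:
    # Capture the outputs of the watched submodules, then abort the forward
    # pass early, mirroring the StopManager logic added to HookedProxyLM.
    captured: dict[str, torch.Tensor] = {}
    handles = []
    modules = dict(model.named_modules())

    def make_hook(name: str):
        def hook(module: nn.Module, inputs: tuple, output: torch.Tensor):
            captured[name] = output
            if len(captured) == len(set(watch)):
                raise StopForward()
            return output
        return hook

    for name in watch:
        handles.append(modules[name].register_forward_hook(make_hook(name)))
    try:
        model(x)
    except StopForward:
        pass  # early exit is expected; later layers never run
    finally:
        for h in handles:
            h.remove()
    return captured


# Toy usage: stop after the first linear layer of a small MLP.
mlp = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
acts = run_until(mlp, watch=["0"], x=torch.randn(2, 8))
print(acts["0"].shape)  # torch.Size([2, 16])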
sae_lens/{toolkit → loading}/pretrained_sae_loaders.py RENAMED
@@ -7,22 +7,41 @@ import numpy as np
 import torch
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError
+from packaging.version import Version
 from safetensors import safe_open
 from safetensors.torch import load_file
 
 from sae_lens import logger
-from sae_lens.config import (
+from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
     SPARSIFY_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
 )
-from sae_lens.toolkit.pretrained_saes_directory import (
+from sae_lens.loading.pretrained_saes_directory import (
     get_config_overrides,
     get_pretrained_saes_directory,
     get_repo_id_and_folder_name,
 )
+from sae_lens.registry import get_sae_class
+from sae_lens.util import filter_valid_dataclass_fields
+
+LLM_METADATA_KEYS = {
+    "model_name",
+    "hook_name",
+    "model_class_name",
+    "hook_head_index",
+    "model_from_pretrained_kwargs",
+    "prepend_bos",
+    "exclude_special_tokens",
+    "neuronpedia_id",
+    "context_size",
+    "seqpos_slice",
+    "dataset_path",
+    "sae_lens_version",
+    "sae_lens_training_version",
+}
 
 
 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
@@ -175,30 +194,69 @@ def get_sae_lens_config_from_disk(
 
 
 def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    sae_lens_version = cfg_dict.get("sae_lens_version")
+    if not sae_lens_version and "metadata" in cfg_dict:
+        sae_lens_version = cfg_dict["metadata"].get("sae_lens_version")
+
+    if not sae_lens_version or Version(sae_lens_version) < Version("6.0.0-rc.0"):
+        cfg_dict = handle_pre_6_0_config(cfg_dict)
+    return cfg_dict
+
+
+def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    """
+    Format a config dictionary for a Sparse Autoencoder (SAE) to be compatible with the new 6.0 format.
+    """
+
+    rename_keys_map = {
+        "hook_point": "hook_name",
+        "hook_point_head_index": "hook_head_index",
+        "activation_fn_str": "activation_fn",
+    }
+    new_cfg = {rename_keys_map.get(k, k): v for k, v in cfg_dict.items()}
+
     # Set default values for backwards compatibility
-    cfg_dict.setdefault("prepend_bos", True)
-    cfg_dict.setdefault("dataset_trust_remote_code", True)
-    cfg_dict.setdefault("apply_b_dec_to_input", True)
-    cfg_dict.setdefault("finetuning_scaling_factor", False)
-    cfg_dict.setdefault("sae_lens_training_version", None)
-    cfg_dict.setdefault("activation_fn_str", cfg_dict.get("activation_fn", "relu"))
-    cfg_dict.setdefault("architecture", "standard")
-    cfg_dict.setdefault("neuronpedia_id", None)
-
-    if "normalize_activations" in cfg_dict and isinstance(
-        cfg_dict["normalize_activations"], bool
+    new_cfg.setdefault("prepend_bos", True)
+    new_cfg.setdefault("dataset_trust_remote_code", True)
+    new_cfg.setdefault("apply_b_dec_to_input", True)
+    new_cfg.setdefault("finetuning_scaling_factor", False)
+    new_cfg.setdefault("sae_lens_training_version", None)
+    new_cfg.setdefault("activation_fn", new_cfg.get("activation_fn", "relu"))
+    new_cfg.setdefault("architecture", "standard")
+    new_cfg.setdefault("neuronpedia_id", None)
+    new_cfg.setdefault(
+        "reshape_activations",
+        "hook_z" if "hook_z" in new_cfg.get("hook_name", "") else "none",
+    )
+
+    if "normalize_activations" in new_cfg and isinstance(
+        new_cfg["normalize_activations"], bool
     ):
         # backwards compatibility
-        cfg_dict["normalize_activations"] = (
+        new_cfg["normalize_activations"] = (
             "none"
-            if not cfg_dict["normalize_activations"]
+            if not new_cfg["normalize_activations"]
             else "expected_average_only_in"
         )
 
-    cfg_dict.setdefault("normalize_activations", "none")
-    cfg_dict.setdefault("device", "cpu")
+    if new_cfg.get("normalize_activations") is None:
+        new_cfg["normalize_activations"] = "none"
 
-    return cfg_dict
+    new_cfg.setdefault("device", "cpu")
+
+    architecture = new_cfg.get("architecture", "standard")
+
+    config_class = get_sae_class(architecture)[1]
+
+    sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
+    if architecture == "topk" and "activation_fn_kwargs" in new_cfg:
+        sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
+
+    sae_cfg_dict["metadata"] = {
+        k: v for k, v in new_cfg.items() if k in LLM_METADATA_KEYS
+    }
+    sae_cfg_dict["architecture"] = architecture
+    return sae_cfg_dict
 
 
 def get_connor_rob_hook_z_config_from_hf(
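
For pre-6.0 configs, handle_pre_6_0_config renames legacy keys, gathers LLM-related keys under a metadata dict (per LLM_METADATA_KEYS), and, for topk SAEs, lifts k out of activation_fn_kwargs. A small sketch of the renaming step with illustrative values; which top-level fields survive depends on the architecture's registered config class:

rename_keys_map = {
    "hook_point": "hook_name",
    "hook_point_head_index": "hook_head_index",
    "activation_fn_str": "activation_fn",
}

# Hypothetical pre-6.0 config (values illustrative)
old_cfg = {"hook_point": "blocks.8.hook_resid_pre", "activation_fn_str": "relu", "d_in": 768}

renamed = {rename_keys_map.get(k, k): v for k, v in old_cfg.items()}
# {'hook_name': 'blocks.8.hook_resid_pre', 'activation_fn': 'relu', 'd_in': 768}
# Keys listed in LLM_METADATA_KEYS (model_name, hook_name, context_size, ...) are
# then collected under a "metadata" sub-dict before the SAE config is built.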
@@ -222,9 +280,8 @@ def get_connor_rob_hook_z_config_from_hf(
         "device": device if device is not None else "cpu",
         "model_name": "gpt2-small",
         "hook_name": old_cfg_dict["act_name"],
-        "hook_layer": old_cfg_dict["layer"],
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "apply_b_dec_to_input": True,
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
@@ -233,6 +290,7 @@
         "context_size": 128,
         "normalize_activations": "none",
         "dataset_trust_remote_code": True,
+        "reshape_activations": "hook_z",
         **(cfg_overrides or {}),
     }
 
@@ -371,9 +429,8 @@
         "dtype": "float32",
         "model_name": model_name,
         "hook_name": hook_name,
-        "hook_layer": layer,
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -493,9 +550,8 @@
         "dtype": "bfloat16",
         "model_name": model_name,
         "hook_name": old_cfg_dict["hook_point_in"],
-        "hook_layer": int(old_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -607,8 +663,8 @@
 
     hook_point_name = f"blocks.{trainer['layer']}.hook_resid_post"
 
-    activation_fn_str = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
-    activation_fn_kwargs = {"k": trainer["k"]} if activation_fn_str == "topk" else {}
+    activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
+    activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}
 
     return {
         "architecture": (
@@ -620,9 +676,8 @@
         "device": device,
         "model_name": trainer["lm_name"].split("/")[-1],
         "hook_name": hook_point_name,
-        "hook_layer": trainer["layer"],
         "hook_head_index": None,
-        "activation_fn_str": activation_fn_str,
+        "activation_fn": activation_fn,
         "activation_fn_kwargs": activation_fn_kwargs,
         "apply_b_dec_to_input": True,
         "finetuning_scaling_factor": False,
@@ -659,13 +714,12 @@
         "context_size": 1024,
         "model_name": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         "hook_name": f"blocks.{layer}.hook_resid_post",
-        "hook_layer": layer,
         "hook_head_index": None,
         "prepend_bos": True,
         "dataset_path": "lmsys/lmsys-chat-1m",
         "dataset_trust_remote_code": True,
         "sae_lens_training_version": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "normalize_activations": "none",
         "device": device,
         "apply_b_dec_to_input": False,
@@ -818,9 +872,8 @@
         "device": device,
         "model_name": model_name,
         "hook_name": huggingface_cfg_dict["hook_point_in"],
-        "hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
sae_lens/registry.py ADDED
@@ -0,0 +1,49 @@
+from typing import TYPE_CHECKING, Any
+
+# avoid circular imports
+if TYPE_CHECKING:
+    from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
+
+SAE_CLASS_REGISTRY: dict[str, tuple["type[SAE[Any]]", "type[SAEConfig]"]] = {}
+SAE_TRAINING_CLASS_REGISTRY: dict[
+    str, tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]
+] = {}
+
+
+def register_sae_class(
+    architecture: str,
+    sae_class: "type[SAE[Any]]",
+    sae_config_class: "type[SAEConfig]",
+) -> None:
+    if architecture in SAE_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE class for architecture {architecture} already registered."
+        )
+    SAE_CLASS_REGISTRY[architecture] = (sae_class, sae_config_class)
+
+
+def register_sae_training_class(
+    architecture: str,
+    sae_training_class: "type[TrainingSAE[Any]]",
+    sae_training_config_class: "type[TrainingSAEConfig]",
+) -> None:
+    if architecture in SAE_TRAINING_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE training class for architecture {architecture} already registered."
+        )
+    SAE_TRAINING_CLASS_REGISTRY[architecture] = (
+        sae_training_class,
+        sae_training_config_class,
+    )
+
+
+def get_sae_class(
+    architecture: str,
+) -> tuple["type[SAE[Any]]", "type[SAEConfig]"]:
+    return SAE_CLASS_REGISTRY[architecture]
+
+
+def get_sae_training_class(
+    architecture: str,
+) -> tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]:
+    return SAE_TRAINING_CLASS_REGISTRY[architecture]
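
The registry keys SAE implementations by architecture name and hands back (class, config class) tuples. A minimal usage sketch, assuming that importing sae_lens registers the built-in architectures via the new sae_lens.saes modules:

import sae_lens  # assumption: importing the package registers the built-in architectures

from sae_lens.registry import get_sae_class, get_sae_training_class

# Each getter returns a (class, config_class) tuple keyed by architecture name.
GatedSAE, GatedSAEConfig = get_sae_class("gated")
GatedTrainingSAE, GatedTrainingSAEConfig = get_sae_training_class("gated")

# Registering the same architecture name twice raises ValueError, so custom
# architectures need a unique name:
# register_sae_class("my_arch", MyArchSAE, MyArchSAEConfig)  # hypothetical classes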
sae_lens/saes/__init__.py ADDED
@@ -0,0 +1,48 @@
+from .gated_sae import (
+    GatedSAE,
+    GatedSAEConfig,
+    GatedTrainingSAE,
+    GatedTrainingSAEConfig,
+)
+from .jumprelu_sae import (
+    JumpReLUSAE,
+    JumpReLUSAEConfig,
+    JumpReLUTrainingSAE,
+    JumpReLUTrainingSAEConfig,
+)
+from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
+from .standard_sae import (
+    StandardSAE,
+    StandardSAEConfig,
+    StandardTrainingSAE,
+    StandardTrainingSAEConfig,
+)
+from .topk_sae import (
+    TopKSAE,
+    TopKSAEConfig,
+    TopKTrainingSAE,
+    TopKTrainingSAEConfig,
+)
+
+__all__ = [
+    "SAE",
+    "SAEConfig",
+    "TrainingSAE",
+    "TrainingSAEConfig",
+    "StandardSAE",
+    "StandardSAEConfig",
+    "StandardTrainingSAE",
+    "StandardTrainingSAEConfig",
+    "GatedSAE",
+    "GatedSAEConfig",
+    "GatedTrainingSAE",
+    "GatedTrainingSAEConfig",
+    "JumpReLUSAE",
+    "JumpReLUSAEConfig",
+    "JumpReLUTrainingSAE",
+    "JumpReLUTrainingSAEConfig",
+    "TopKSAE",
+    "TopKSAEConfig",
+    "TopKTrainingSAE",
+    "TopKTrainingSAEConfig",
+]
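
These re-exports take over from the removed sae_lens/sae.py and sae_lens/training/training_sae.py, so the per-architecture classes are imported from the new subpackage:

# New-style imports in 6.0:
from sae_lens.saes import GatedSAE, JumpReLUSAE, StandardSAE, TopKSAE
from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig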
sae_lens/saes/gated_sae.py ADDED
@@ -0,0 +1,254 @@
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+from jaxtyping import Float
+from numpy.typing import NDArray
+from torch import nn
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+)
+from sae_lens.util import filter_valid_dataclass_fields
+
+
+@dataclass
+class GatedSAEConfig(SAEConfig):
+    """
+    Configuration class for a GatedSAE.
+    """
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "gated"
+
+
+class GatedSAE(SAE[GatedSAEConfig]):
+    """
+    GatedSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
+    using a gated linear encoder and a standard linear decoder.
+    """
+
+    b_gate: nn.Parameter
+    b_mag: nn.Parameter
+    r_mag: nn.Parameter
+
+    def __init__(self, cfg: GatedSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+        # Ensure b_enc does not exist for the gated architecture
+        self.b_enc = None
+
+    @override
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        _init_weights_gated(self)
+
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        """
+        Encode the input tensor into the feature space using a gated encoder.
+        This must match the original encode_gated implementation from SAE class.
+        """
+        # Preprocess the SAE input (casting type, applying hooks, normalization)
+        sae_in = self.process_sae_in(x)
+
+        # Gating path exactly as in original SAE.encode_gated
+        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
+        active_features = (gating_pre_activation > 0).to(self.dtype)
+
+        # Magnitude path (weight sharing with gated encoder)
+        magnitude_pre_activation = self.hook_sae_acts_pre(
+            sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
+        )
+        feature_magnitudes = self.activation_fn(magnitude_pre_activation)
+
+        # Combine gating and magnitudes
+        return self.hook_sae_acts_post(active_features * feature_magnitudes)
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decode the feature activations back into the input space:
+          1) Apply optional finetuning scaling.
+          2) Linear transform plus bias.
+          3) Run any reconstruction hooks and out-normalization if configured.
+          4) If the SAE was reshaping hook_z activations, reshape back.
+        """
+        # 1) optional finetuning scaling
+        # 2) linear transform
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        # 3) hooking and normalization
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        # 4) reshape if needed (hook_z)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+    @torch.no_grad()
+    def fold_W_dec_norm(self):
+        """Override to handle gated-specific parameters."""
+        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        self.W_dec.data = self.W_dec.data / W_dec_norms
+        self.W_enc.data = self.W_enc.data * W_dec_norms.T
+
+        # Gated-specific parameters need special handling
+        self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
+        self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
+        self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
+
+    @torch.no_grad()
+    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+        """Initialize decoder with constant norm."""
+        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+        self.W_dec.data *= norm
+
+
+@dataclass
+class GatedTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a GatedTrainingSAE.
+    """
+
+    l1_coefficient: float = 1.0
+    l1_warm_up_steps: int = 0
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "gated"
+
+
+class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
+    """
+    GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
+    It implements:
+      - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
+      - encode: calls encode_with_hidden_pre (standard training approach).
+      - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
+      - encode_with_hidden_pre: gating logic + optional noise injection for training.
+      - calculate_aux_loss: includes an auxiliary reconstruction path and gating-based sparsity penalty.
+      - training_forward_pass: calls encode_with_hidden_pre, decode, and sums up MSE + gating losses.
+    """
+
+    b_gate: nn.Parameter  # type: ignore
+    b_mag: nn.Parameter  # type: ignore
+    r_mag: nn.Parameter  # type: ignore
+
+    def __init__(self, cfg: GatedTrainingSAEConfig, use_error_term: bool = False):
+        if use_error_term:
+            raise ValueError(
+                "GatedSAE does not support `use_error_term`. Please set `use_error_term=False`."
+            )
+        super().__init__(cfg, use_error_term)
+
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        _init_weights_gated(self)
+
+    def encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        """
+        Gated forward pass with pre-activation (for training).
+        We also inject noise if self.training is True.
+        """
+        sae_in = self.process_sae_in(x)
+
+        # Gating path
+        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
+        active_features = (gating_pre_activation > 0).to(self.dtype)
+
+        # Magnitude path
+        magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
+        magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)
+
+        feature_magnitudes = self.activation_fn(magnitude_pre_activation)
+
+        # Combine gating path and magnitude path
+        feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes)
+
+        # Return both the final feature activations and the pre-activation (for logging or penalty)
+        return feature_acts, magnitude_pre_activation
+
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        # Re-center the input if apply_b_dec_to_input is set
+        sae_in_centered = step_input.sae_in - (
+            self.b_dec * self.cfg.apply_b_dec_to_input
+        )
+
+        # The gating pre-activation (pi_gate) for the auxiliary path
+        pi_gate = sae_in_centered @ self.W_enc + self.b_gate
+        pi_gate_act = torch.relu(pi_gate)
+
+        # L1-like penalty scaled by W_dec norms
+        l1_loss = (
+            step_input.coefficients["l1"]
+            * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
+        )
+
+        # Aux reconstruction: reconstruct x purely from gating path
+        via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
+        aux_recon_loss = (
+            (via_gate_reconstruction - step_input.sae_in).pow(2).sum(dim=-1).mean()
+        )
+
+        # Return both losses separately
+        return {"l1_loss": l1_loss, "auxiliary_reconstruction_loss": aux_recon_loss}
+
+    def log_histograms(self) -> dict[str, NDArray[Any]]:
+        """Log histograms of the weights and biases."""
+        b_gate_dist = self.b_gate.detach().float().cpu().numpy()
+        b_mag_dist = self.b_mag.detach().float().cpu().numpy()
+        return {
+            **super().log_histograms(),
+            "weights/b_gate": b_gate_dist,
+            "weights/b_mag": b_mag_dist,
+        }
+
+    @torch.no_grad()
+    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+        """Initialize decoder with constant norm"""
+        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+        self.W_dec.data *= norm
+
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+        return {
+            "l1": TrainCoefficientConfig(
+                value=self.cfg.l1_coefficient,
+                warm_up_steps=self.cfg.l1_warm_up_steps,
+            ),
+        }
+
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), GatedSAEConfig, ["architecture"]
+        )
+
+
+def _init_weights_gated(
+    sae: SAE[GatedSAEConfig] | TrainingSAE[GatedTrainingSAEConfig],
+) -> None:
+    sae.b_gate = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
+    # Ensure r_mag is initialized to zero as in original
+    sae.r_mag = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
+    sae.b_mag = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
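
A rough sketch of instantiating the new gated classes. GatedSAEConfig inherits its fields from SAEConfig in sae_lens/saes/sae.py, which is not part of this excerpt, so the d_in/d_sae constructor arguments below are assumptions about that base class:

import torch

from sae_lens.saes import GatedSAE, GatedSAEConfig

# Assumed base-class fields; consult sae_lens/saes/sae.py for the full SAEConfig signature.
cfg = GatedSAEConfig(d_in=768, d_sae=24576)
sae = GatedSAE(cfg)

x = torch.randn(4, 768)
feature_acts = sae.encode(x)       # gate path (W_enc, b_gate) * magnitude path (r_mag, b_mag)
reconstruction = sae.decode(feature_acts)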