sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,11 +7,12 @@ import numpy as np
 import torch
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError
+from packaging.version import Version
 from safetensors import safe_open
 from safetensors.torch import load_file

 from sae_lens import logger
-from sae_lens.config import (
+from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
@@ -22,6 +23,8 @@ from sae_lens.loading.pretrained_saes_directory import (
     get_pretrained_saes_directory,
     get_repo_id_and_folder_name,
 )
+from sae_lens.registry import get_sae_class
+from sae_lens.util import filter_valid_dataclass_fields


 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
@@ -174,9 +177,22 @@ def get_sae_lens_config_from_disk(


 def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    sae_lens_version = cfg_dict.get("sae_lens_version")
+    if not sae_lens_version and "metadata" in cfg_dict:
+        sae_lens_version = cfg_dict["metadata"].get("sae_lens_version")
+
+    if not sae_lens_version or Version(sae_lens_version) < Version("6.0.0-rc.0"):
+        cfg_dict = handle_pre_6_0_config(cfg_dict)
+    return cfg_dict
+
+
+def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    """
+    Format a config dictionary for a Sparse Autoencoder (SAE) to be compatible with the new 6.0 format.
+    """
+
     rename_keys_map = {
         "hook_point": "hook_name",
-        "hook_point_layer": "hook_layer",
         "hook_point_head_index": "hook_head_index",
         "activation_fn_str": "activation_fn",
     }
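
The version gate above uses packaging's PEP 440 ordering, in which release candidates sort below the final release but above everything earlier; configs with no recorded version at all also fall through to the legacy path. A quick sketch of how the comparison behaves (plain `packaging` usage, nothing sae-lens specific):

```python
from packaging.version import Version

cutoff = Version("6.0.0-rc.0")           # normalizes to 6.0.0rc0
assert Version("5.10.0") < cutoff        # older configs take the legacy path
assert not Version("6.0.0rc1") < cutoff  # rc1 and later count as new-format
assert not Version("6.0.0") < cutoff     # final releases as well
```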
@@ -202,10 +218,26 @@ def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
         else "expected_average_only_in"
     )

-    new_cfg.setdefault("normalize_activations", "none")
+    if new_cfg.get("normalize_activations") is None:
+        new_cfg["normalize_activations"] = "none"
+
     new_cfg.setdefault("device", "cpu")

-    return new_cfg
+    architecture = new_cfg.get("architecture", "standard")
+
+    config_class = get_sae_class(architecture)[1]
+
+    sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
+    if architecture == "topk":
+        sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
+
+    # import here to avoid circular import
+    from sae_lens.saes.sae import SAEMetadata
+
+    meta_dict = filter_valid_dataclass_fields(new_cfg, SAEMetadata)
+    sae_cfg_dict["metadata"] = meta_dict
+    sae_cfg_dict["architecture"] = architecture
+    return sae_cfg_dict


 def get_connor_rob_hook_z_config_from_hf(
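
To make the key-renaming step concrete, here is a tiny self-contained sketch of the same dict-rename technique; only the rename map itself comes from the diff, the example values and the surrounding code are illustrative rather than the function's actual implementation:

```python
rename_keys_map = {
    "hook_point": "hook_name",
    "hook_point_head_index": "hook_head_index",
    "activation_fn_str": "activation_fn",
}

legacy = {"hook_point": "blocks.0.hook_resid_pre", "activation_fn_str": "relu"}
renamed = {rename_keys_map.get(k, k): v for k, v in legacy.items()}
assert renamed == {"hook_name": "blocks.0.hook_resid_pre", "activation_fn": "relu"}
```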
@@ -229,7 +261,6 @@ def get_connor_rob_hook_z_config_from_hf(
         "device": device if device is not None else "cpu",
         "model_name": "gpt2-small",
         "hook_name": old_cfg_dict["act_name"],
-        "hook_layer": old_cfg_dict["layer"],
         "hook_head_index": None,
         "activation_fn": "relu",
         "apply_b_dec_to_input": True,
@@ -378,7 +409,6 @@ def get_gemma_2_config_from_hf(
         "dtype": "float32",
         "model_name": model_name,
         "hook_name": hook_name,
-        "hook_layer": layer,
         "hook_head_index": None,
         "activation_fn": "relu",
         "finetuning_scaling_factor": False,
@@ -491,7 +521,6 @@ def get_llama_scope_config_from_hf(
         "dtype": "bfloat16",
         "model_name": model_name,
         "hook_name": old_cfg_dict["hook_point_in"],
-        "hook_layer": int(old_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
         "activation_fn": "relu",
         "finetuning_scaling_factor": False,
@@ -618,7 +647,6 @@ def get_dictionary_learning_config_1_from_hf(
         "device": device,
         "model_name": trainer["lm_name"].split("/")[-1],
         "hook_name": hook_point_name,
-        "hook_layer": trainer["layer"],
         "hook_head_index": None,
         "activation_fn": activation_fn,
         "activation_fn_kwargs": activation_fn_kwargs,
@@ -657,7 +685,6 @@ def get_deepseek_r1_config_from_hf(
         "context_size": 1024,
         "model_name": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         "hook_name": f"blocks.{layer}.hook_resid_post",
-        "hook_layer": layer,
         "hook_head_index": None,
         "prepend_bos": True,
         "dataset_path": "lmsys/lmsys-chat-1m",
@@ -816,7 +843,6 @@ def get_llama_scope_r1_distill_config_from_hf(
         "device": device,
         "model_name": model_name,
         "hook_name": huggingface_cfg_dict["hook_point_in"],
-        "hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
         "activation_fn": "relu",
         "finetuning_scaling_factor": False,
sae_lens/registry.py ADDED
@@ -0,0 +1,49 @@
+from typing import TYPE_CHECKING, Any
+
+# avoid circular imports
+if TYPE_CHECKING:
+    from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
+
+SAE_CLASS_REGISTRY: dict[str, tuple["type[SAE[Any]]", "type[SAEConfig]"]] = {}
+SAE_TRAINING_CLASS_REGISTRY: dict[
+    str, tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]
+] = {}
+
+
+def register_sae_class(
+    architecture: str,
+    sae_class: "type[SAE[Any]]",
+    sae_config_class: "type[SAEConfig]",
+) -> None:
+    if architecture in SAE_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE class for architecture {architecture} already registered."
+        )
+    SAE_CLASS_REGISTRY[architecture] = (sae_class, sae_config_class)
+
+
+def register_sae_training_class(
+    architecture: str,
+    sae_training_class: "type[TrainingSAE[Any]]",
+    sae_training_config_class: "type[TrainingSAEConfig]",
+) -> None:
+    if architecture in SAE_TRAINING_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE training class for architecture {architecture} already registered."
+        )
+    SAE_TRAINING_CLASS_REGISTRY[architecture] = (
+        sae_training_class,
+        sae_training_config_class,
+    )
+
+
+def get_sae_class(
+    architecture: str,
+) -> tuple["type[SAE[Any]]", "type[SAEConfig]"]:
+    return SAE_CLASS_REGISTRY[architecture]
+
+
+def get_sae_training_class(
+    architecture: str,
+) -> tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]:
+    return SAE_TRAINING_CLASS_REGISTRY[architecture]
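
A minimal sketch of the registry contract defined above. `MyArchSAE` and `MyArchSAEConfig` are hypothetical placeholder classes used only for illustration; the built-in architectures are presumably registered by the library itself when their modules are imported, so registering them again would raise.

```python
from sae_lens.registry import get_sae_class, register_sae_class
from sae_lens.saes.sae import SAE, SAEConfig


class MyArchSAEConfig(SAEConfig):  # hypothetical config subclass
    ...


class MyArchSAE(SAE[MyArchSAEConfig]):  # hypothetical SAE subclass
    ...


register_sae_class("my_arch", MyArchSAE, MyArchSAEConfig)
sae_cls, cfg_cls = get_sae_class("my_arch")
assert sae_cls is MyArchSAE and cfg_cls is MyArchSAEConfig
# Registering the same architecture name a second time raises ValueError.
```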
@@ -0,0 +1,48 @@
+from .gated_sae import (
+    GatedSAE,
+    GatedSAEConfig,
+    GatedTrainingSAE,
+    GatedTrainingSAEConfig,
+)
+from .jumprelu_sae import (
+    JumpReLUSAE,
+    JumpReLUSAEConfig,
+    JumpReLUTrainingSAE,
+    JumpReLUTrainingSAEConfig,
+)
+from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
+from .standard_sae import (
+    StandardSAE,
+    StandardSAEConfig,
+    StandardTrainingSAE,
+    StandardTrainingSAEConfig,
+)
+from .topk_sae import (
+    TopKSAE,
+    TopKSAEConfig,
+    TopKTrainingSAE,
+    TopKTrainingSAEConfig,
+)
+
+__all__ = [
+    "SAE",
+    "SAEConfig",
+    "TrainingSAE",
+    "TrainingSAEConfig",
+    "StandardSAE",
+    "StandardSAEConfig",
+    "StandardTrainingSAE",
+    "StandardTrainingSAEConfig",
+    "GatedSAE",
+    "GatedSAEConfig",
+    "GatedTrainingSAE",
+    "GatedTrainingSAEConfig",
+    "JumpReLUSAE",
+    "JumpReLUSAEConfig",
+    "JumpReLUTrainingSAE",
+    "JumpReLUTrainingSAEConfig",
+    "TopKSAE",
+    "TopKSAEConfig",
+    "TopKTrainingSAE",
+    "TopKTrainingSAEConfig",
+]
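
Assuming this new `__init__` lives in the `sae_lens.saes` package (its relative imports resolve against the `sae_lens.saes.*` modules referenced elsewhere in this diff), the concrete classes and their configs become importable directly from the package, e.g.:

```python
from sae_lens.saes import (
    GatedSAE,
    GatedTrainingSAEConfig,
    JumpReLUSAEConfig,
    TopKTrainingSAE,
)
```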
@@ -1,20 +1,36 @@
+from dataclasses import dataclass
 from typing import Any

 import torch
 from jaxtyping import Float
 from numpy.typing import NDArray
 from torch import nn
+from typing_extensions import override

 from sae_lens.saes.sae import (
     SAE,
     SAEConfig,
+    TrainCoefficientConfig,
     TrainingSAE,
     TrainingSAEConfig,
     TrainStepInput,
 )
+from sae_lens.util import filter_valid_dataclass_fields


-class GatedSAE(SAE):
+@dataclass
+class GatedSAEConfig(SAEConfig):
+    """
+    Configuration class for a GatedSAE.
+    """
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "gated"
+
+
+class GatedSAE(SAE[GatedSAEConfig]):
     """
     GatedSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
     using a gated linear encoder and a standard linear decoder.
@@ -24,48 +40,15 @@ class GatedSAE(SAE):
     b_mag: nn.Parameter
     r_mag: nn.Parameter

-    def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+    def __init__(self, cfg: GatedSAEConfig, use_error_term: bool = False):
         super().__init__(cfg, use_error_term)
         # Ensure b_enc does not exist for the gated architecture
         self.b_enc = None

+    @override
     def initialize_weights(self) -> None:
-        """
-        Initialize weights exactly as in the original SAE class for gated architecture.
-        """
-        # Use the same initialization methods and values as in original SAE
-        self.W_enc = nn.Parameter(
-            torch.nn.init.kaiming_uniform_(
-                torch.empty(
-                    self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-                )
-            )
-        )
-        self.b_gate = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-        # Ensure r_mag is initialized to zero as in original
-        self.r_mag = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-        self.b_mag = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-
-        # Decoder parameters with same initialization as original
-        self.W_dec = nn.Parameter(
-            torch.nn.init.kaiming_uniform_(
-                torch.empty(
-                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-                )
-            )
-        )
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
-
-        # after defining b_gate, b_mag, etc.:
-        self.b_enc = None
+        super().initialize_weights()
+        _init_weights_gated(self)

     def encode(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -101,9 +84,8 @@ class GatedSAE(SAE):
         4) If the SAE was reshaping hook_z activations, reshape back.
         """
         # 1) optional finetuning scaling
-        scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
         # 2) linear transform
-        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
         # 3) hooking and normalization
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
@@ -129,7 +111,22 @@ class GatedSAE(SAE):
         self.W_dec.data *= norm


-class GatedTrainingSAE(TrainingSAE):
+@dataclass
+class GatedTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a GatedTrainingSAE.
+    """
+
+    l1_coefficient: float = 1.0
+    l1_warm_up_steps: int = 0
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "gated"
+
+
+class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     """
     GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
     It implements:
@@ -145,7 +142,7 @@ class GatedTrainingSAE(TrainingSAE):
     b_mag: nn.Parameter  # type: ignore
     r_mag: nn.Parameter  # type: ignore

-    def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
+    def __init__(self, cfg: GatedTrainingSAEConfig, use_error_term: bool = False):
         if use_error_term:
             raise ValueError(
                 "GatedSAE does not support `use_error_term`. Please set `use_error_term=False`."
@@ -153,22 +150,8 @@ class GatedTrainingSAE(TrainingSAE):
         super().__init__(cfg, use_error_term)

     def initialize_weights(self) -> None:
-        # Reuse the gating parameter initialization from GatedSAE:
-        GatedSAE.initialize_weights(self)  # type: ignore
-
-        # Additional training-specific logic, e.g. orthogonal init or heuristics:
-        if self.cfg.decoder_orthogonal_init:
-            self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
-        elif self.cfg.decoder_heuristic_init:
-            self.W_dec.data = torch.rand(
-                self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-            )
-            self.initialize_decoder_norm_constant_norm()
-        if self.cfg.init_encoder_as_decoder_transpose:
-            self.W_enc.data = self.W_dec.data.T.clone().contiguous()
-        if self.cfg.normalize_sae_decoder:
-            with torch.no_grad():
-                self.set_decoder_norm_to_unit_norm()
+        super().initialize_weights()
+        _init_weights_gated(self)

     def encode_with_hidden_pre(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -217,7 +200,7 @@ class GatedTrainingSAE(TrainingSAE):

         # L1-like penalty scaled by W_dec norms
         l1_loss = (
-            step_input.current_l1_coefficient
+            step_input.coefficients["l1"]
             * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
         )

@@ -245,3 +228,31 @@ class GatedTrainingSAE(TrainingSAE):
         """Initialize decoder with constant norm"""
         self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
         self.W_dec.data *= norm
+
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+        return {
+            "l1": TrainCoefficientConfig(
+                value=self.cfg.l1_coefficient,
+                warm_up_steps=self.cfg.l1_warm_up_steps,
+            ),
+        }
+
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), GatedSAEConfig, ["architecture"]
+        )
+
+
+def _init_weights_gated(
+    sae: SAE[GatedSAEConfig] | TrainingSAE[GatedTrainingSAEConfig],
+) -> None:
+    sae.b_gate = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
+    # Ensure r_mag is initialized to zero as in original
+    sae.r_mag = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
+    sae.b_mag = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
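
The gated architecture now declares its sparsity coefficient through `get_coefficients`, and the loss reads it back as `step_input.coefficients["l1"]` instead of the old `current_l1_coefficient`. As a rough illustration of the warm-up semantics implied by `TrainCoefficientConfig(value=..., warm_up_steps=...)` (the trainer's actual schedule is not part of this diff; a linear ramp is an assumption):

```python
def warmed_up_coefficient(value: float, warm_up_steps: int, step: int) -> float:
    """Linearly ramp a coefficient from 0 up to `value` over `warm_up_steps` steps."""
    if warm_up_steps <= 0 or step >= warm_up_steps:
        return value
    return value * (step + 1) / warm_up_steps


assert warmed_up_coefficient(5.0, warm_up_steps=0, step=0) == 5.0
assert warmed_up_coefficient(5.0, warm_up_steps=10, step=4) == 2.5
```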
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from typing import Any

 import numpy as np
@@ -9,11 +10,13 @@ from typing_extensions import override
 from sae_lens.saes.sae import (
     SAE,
     SAEConfig,
+    TrainCoefficientConfig,
     TrainingSAE,
     TrainingSAEConfig,
     TrainStepInput,
     TrainStepOutput,
 )
+from sae_lens.util import filter_valid_dataclass_fields


 def rectangle(x: torch.Tensor) -> torch.Tensor:
@@ -85,7 +88,19 @@ class JumpReLU(torch.autograd.Function):
         return x_grad, threshold_grad, None


-class JumpReLUSAE(SAE):
+@dataclass
+class JumpReLUSAEConfig(SAEConfig):
+    """
+    Configuration class for a JumpReLUSAE.
+    """
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "jumprelu"
+
+
+class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
     """
     JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
     using a JumpReLU activation. For each unit, if its pre-activation is
@@ -104,42 +119,18 @@ class JumpReLUSAE(SAE):
     b_enc: nn.Parameter
     threshold: nn.Parameter

-    def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+    def __init__(self, cfg: JumpReLUSAEConfig, use_error_term: bool = False):
         super().__init__(cfg, use_error_term)

+    @override
     def initialize_weights(self) -> None:
-        """
-        Initialize encoder and decoder weights, as well as biases.
-        Additionally, include a learnable `threshold` parameter that
-        determines when units "turn on" for the JumpReLU.
-        """
-        # Biases
-        self.b_enc = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
-
-        # Threshold for JumpReLU
-        # You can pick a default initialization (e.g., zeros means unit is off unless hidden_pre > 0)
-        # or see the training version for more advanced init with log_threshold, etc.
+        super().initialize_weights()
         self.threshold = nn.Parameter(
             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
         )
-
-        # Encoder and Decoder weights
-        w_enc_data = torch.empty(
-            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_enc_data)
-        self.W_enc = nn.Parameter(w_enc_data)
-
-        w_dec_data = torch.empty(
-            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+        self.b_enc = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
         )
-        nn.init.kaiming_uniform_(w_dec_data)
-        self.W_dec = nn.Parameter(w_dec_data)

     def encode(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -168,8 +159,7 @@ class JumpReLUSAE(SAE):
         Decode the feature activations back to the input space.
         Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
         """
-        scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
-        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
         return self.reshape_fn_out(sae_out_pre, self.d_head)
@@ -195,7 +185,24 @@ class JumpReLUSAE(SAE):
         self.threshold.data = current_thresh * W_dec_norms


-class JumpReLUTrainingSAE(TrainingSAE):
+@dataclass
+class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a JumpReLUTrainingSAE.
+    """
+
+    jumprelu_init_threshold: float = 0.001
+    jumprelu_bandwidth: float = 0.001
+    l0_coefficient: float = 1.0
+    l0_warm_up_steps: int = 0
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "jumprelu"
+
+
+class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
     """
     JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.

@@ -213,7 +220,7 @@ class JumpReLUTrainingSAE(TrainingSAE):
     b_enc: nn.Parameter
     log_threshold: nn.Parameter

-    def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
+    def __init__(self, cfg: JumpReLUTrainingSAEConfig, use_error_term: bool = False):
         super().__init__(cfg, use_error_term)

         # We'll store a bandwidth for the training approach, if needed
@@ -225,51 +232,16 @@ class JumpReLUTrainingSAE(TrainingSAE):
             * np.log(cfg.jumprelu_init_threshold)
         )

+    @override
     def initialize_weights(self) -> None:
         """
         Initialize parameters like the base SAE, but also add log_threshold.
         """
+        super().initialize_weights()
         # Encoder Bias
         self.b_enc = nn.Parameter(
             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
         )
-        # Decoder Bias
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
-        # W_enc
-        w_enc_data = torch.nn.init.kaiming_uniform_(
-            torch.empty(
-                self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-            )
-        )
-        self.W_enc = nn.Parameter(w_enc_data)
-
-        # W_dec
-        w_dec_data = torch.nn.init.kaiming_uniform_(
-            torch.empty(
-                self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-            )
-        )
-        self.W_dec = nn.Parameter(w_dec_data)
-
-        # Optionally apply orthogonal or heuristic init
-        if self.cfg.decoder_orthogonal_init:
-            self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
-        elif self.cfg.decoder_heuristic_init:
-            self.W_dec.data = torch.rand(
-                self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-            )
-            self.initialize_decoder_norm_constant_norm()
-
-        # Optionally transpose
-        if self.cfg.init_encoder_as_decoder_transpose:
-            self.W_enc.data = self.W_dec.data.T.clone().contiguous()
-
-        # Optionally normalize columns of W_dec
-        if self.cfg.normalize_sae_decoder:
-            with torch.no_grad():
-                self.set_decoder_norm_to_unit_norm()

     @property
     def threshold(self) -> torch.Tensor:
@@ -305,9 +277,18 @@ class JumpReLUTrainingSAE(TrainingSAE):
     ) -> dict[str, torch.Tensor]:
         """Calculate architecture-specific auxiliary loss terms."""
         l0 = torch.sum(Step.apply(hidden_pre, self.threshold, self.bandwidth), dim=-1)  # type: ignore
-        l0_loss = (step_input.current_l1_coefficient * l0).mean()
+        l0_loss = (step_input.coefficients["l0"] * l0).mean()
         return {"l0_loss": l0_loss}

+    @override
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+        return {
+            "l0": TrainCoefficientConfig(
+                value=self.cfg.l0_coefficient,
+                warm_up_steps=self.cfg.l0_warm_up_steps,
+            ),
+        }
+
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """
@@ -366,3 +347,8 @@ class JumpReLUTrainingSAE(TrainingSAE):
         threshold = state_dict["threshold"]
         del state_dict["threshold"]
         state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
+
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), JumpReLUSAEConfig, ["architecture"]
+        )
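
Both new `to_inference_config_dict` overrides lean on `sae_lens.util.filter_valid_dataclass_fields`, whose exact signature is not shown in this diff. The underlying technique is keeping only the keys that correspond to fields of the target dataclass, as in this stdlib-only sketch; `ToyInferenceConfig` and its fields are illustrative, not sae-lens classes:

```python
from dataclasses import dataclass, fields


@dataclass
class ToyInferenceConfig:
    d_in: int = 0
    d_sae: int = 0


training_dict = {"d_in": 768, "d_sae": 16384, "l0_coefficient": 1.0, "l0_warm_up_steps": 100}
valid_fields = {f.name for f in fields(ToyInferenceConfig)}
inference_dict = {k: v for k, v in training_dict.items() if k in valid_fields}
assert inference_dict == {"d_in": 768, "d_sae": 16384}
```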