sae-lens 5.9.0__py3-none-any.whl → 6.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,247 @@
+ from typing import Any
+
+ import torch
+ from jaxtyping import Float
+ from numpy.typing import NDArray
+ from torch import nn
+
+ from sae_lens.saes.sae import (
+     SAE,
+     SAEConfig,
+     TrainingSAE,
+     TrainingSAEConfig,
+     TrainStepInput,
+ )
+
+
+ class GatedSAE(SAE):
+     """
+     GatedSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
+     using a gated linear encoder and a standard linear decoder.
+     """
+
+     b_gate: nn.Parameter
+     b_mag: nn.Parameter
+     r_mag: nn.Parameter
+
+     def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+         super().__init__(cfg, use_error_term)
+         # Ensure b_enc does not exist for the gated architecture
+         self.b_enc = None
+
+     def initialize_weights(self) -> None:
+         """
+         Initialize weights exactly as in the original SAE class for gated architecture.
+         """
+         # Use the same initialization methods and values as in original SAE
+         self.W_enc = nn.Parameter(
+             torch.nn.init.kaiming_uniform_(
+                 torch.empty(
+                     self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
+                 )
+             )
+         )
+         self.b_gate = nn.Parameter(
+             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+         )
+         # Ensure r_mag is initialized to zero as in original
+         self.r_mag = nn.Parameter(
+             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+         )
+         self.b_mag = nn.Parameter(
+             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+         )
+
+         # Decoder parameters with same initialization as original
+         self.W_dec = nn.Parameter(
+             torch.nn.init.kaiming_uniform_(
+                 torch.empty(
+                     self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+                 )
+             )
+         )
+         self.b_dec = nn.Parameter(
+             torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+         )
+
+         # after defining b_gate, b_mag, etc.:
+         self.b_enc = None
+
+     def encode(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> Float[torch.Tensor, "... d_sae"]:
+         """
+         Encode the input tensor into the feature space using a gated encoder.
+         This must match the original encode_gated implementation from SAE class.
+         """
+         # Preprocess the SAE input (casting type, applying hooks, normalization)
+         sae_in = self.process_sae_in(x)
+
+         # Gating path exactly as in original SAE.encode_gated
+         gating_pre_activation = sae_in @ self.W_enc + self.b_gate
+         active_features = (gating_pre_activation > 0).to(self.dtype)
+
+         # Magnitude path (weight sharing with gated encoder)
+         magnitude_pre_activation = self.hook_sae_acts_pre(
+             sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
+         )
+         feature_magnitudes = self.activation_fn(magnitude_pre_activation)
+
+         # Combine gating and magnitudes
+         return self.hook_sae_acts_post(active_features * feature_magnitudes)
+
+     def decode(
+         self, feature_acts: Float[torch.Tensor, "... d_sae"]
+     ) -> Float[torch.Tensor, "... d_in"]:
+         """
+         Decode the feature activations back into the input space:
+         1) Apply optional finetuning scaling.
+         2) Linear transform plus bias.
+         3) Run any reconstruction hooks and out-normalization if configured.
+         4) If the SAE was reshaping hook_z activations, reshape back.
+         """
+         # 1) optional finetuning scaling
+         scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
+         # 2) linear transform
+         sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+         # 3) hooking and normalization
+         sae_out_pre = self.hook_sae_recons(sae_out_pre)
+         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+         # 4) reshape if needed (hook_z)
+         return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+     @torch.no_grad()
+     def fold_W_dec_norm(self):
+         """Override to handle gated-specific parameters."""
+         W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+         self.W_dec.data = self.W_dec.data / W_dec_norms
+         self.W_enc.data = self.W_enc.data * W_dec_norms.T
+
+         # Gated-specific parameters need special handling
+         self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
+         self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
+         self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
+
+     @torch.no_grad()
+     def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+         """Initialize decoder with constant norm."""
+         self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+         self.W_dec.data *= norm
+
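
For reference, the gated encode/decode math above reduces to the following minimal sketch with plain tensors. The shapes (d_in=4, d_sae=8) are made up, torch.relu stands in for the configured activation_fn, and the hooks and input/output normalization handled by the SAE base class are omitted.

import torch

d_in, d_sae = 4, 8
x = torch.randn(2, d_in)

# Parameter shapes as set up in GatedSAE.initialize_weights (values arbitrary here)
W_enc, W_dec = torch.randn(d_in, d_sae), torch.randn(d_sae, d_in)
b_gate, b_mag, r_mag = torch.zeros(d_sae), torch.zeros(d_sae), torch.zeros(d_sae)
b_dec = torch.zeros(d_in)

# Gating path: a binary mask saying which features are "on"
active = ((x @ W_enc + b_gate) > 0).float()
# Magnitude path: shares W_enc, rescaled per feature by exp(r_mag)
magnitudes = torch.relu(x @ (W_enc * r_mag.exp()) + b_mag)
feature_acts = active * magnitudes

# Decode: plain linear map back to the input space
recon = feature_acts @ W_dec + b_dec
print(feature_acts.shape, recon.shape)  # torch.Size([2, 8]) torch.Size([2, 4])
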
+ class GatedTrainingSAE(TrainingSAE):
+     """
+     GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
+     It implements:
+     - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
+     - encode: calls encode_with_hidden_pre (standard training approach).
+     - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
+     - encode_with_hidden_pre: gating logic + optional noise injection for training.
+     - calculate_aux_loss: includes an auxiliary reconstruction path and gating-based sparsity penalty.
+     - training_forward_pass: calls encode_with_hidden_pre, decode, and sums up MSE + gating losses.
+     """
+
+     b_gate: nn.Parameter  # type: ignore
+     b_mag: nn.Parameter  # type: ignore
+     r_mag: nn.Parameter  # type: ignore
+
+     def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
+         if use_error_term:
+             raise ValueError(
+                 "GatedSAE does not support `use_error_term`. Please set `use_error_term=False`."
+             )
+         super().__init__(cfg, use_error_term)
+
+     def initialize_weights(self) -> None:
+         # Reuse the gating parameter initialization from GatedSAE:
+         GatedSAE.initialize_weights(self)  # type: ignore
+
+         # Additional training-specific logic, e.g. orthogonal init or heuristics:
+         if self.cfg.decoder_orthogonal_init:
+             self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
+         elif self.cfg.decoder_heuristic_init:
+             self.W_dec.data = torch.rand(
+                 self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+             )
+             self.initialize_decoder_norm_constant_norm()
+         if self.cfg.init_encoder_as_decoder_transpose:
+             self.W_enc.data = self.W_dec.data.T.clone().contiguous()
+         if self.cfg.normalize_sae_decoder:
+             with torch.no_grad():
+                 self.set_decoder_norm_to_unit_norm()
+
+     def encode_with_hidden_pre(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+         """
+         Gated forward pass with pre-activation (for training).
+         We also inject noise if self.training is True.
+         """
+         sae_in = self.process_sae_in(x)
+
+         # Gating path
+         gating_pre_activation = sae_in @ self.W_enc + self.b_gate
+         active_features = (gating_pre_activation > 0).to(self.dtype)
+
+         # Magnitude path
+         magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
+         if self.training and self.cfg.noise_scale > 0:
+             magnitude_pre_activation += (
+                 torch.randn_like(magnitude_pre_activation) * self.cfg.noise_scale
+             )
+         magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)
+
+         feature_magnitudes = self.activation_fn(magnitude_pre_activation)
+
+         # Combine gating path and magnitude path
+         feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes)
+
+         # Return both the final feature activations and the pre-activation (for logging or penalty)
+         return feature_acts, magnitude_pre_activation
+
+     def calculate_aux_loss(
+         self,
+         step_input: TrainStepInput,
+         feature_acts: torch.Tensor,
+         hidden_pre: torch.Tensor,
+         sae_out: torch.Tensor,
+     ) -> dict[str, torch.Tensor]:
+         # Re-center the input if apply_b_dec_to_input is set
+         sae_in_centered = step_input.sae_in - (
+             self.b_dec * self.cfg.apply_b_dec_to_input
+         )
+
+         # The gating pre-activation (pi_gate) for the auxiliary path
+         pi_gate = sae_in_centered @ self.W_enc + self.b_gate
+         pi_gate_act = torch.relu(pi_gate)
+
+         # L1-like penalty scaled by W_dec norms
+         l1_loss = (
+             step_input.current_l1_coefficient
+             * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
+         )
+
+         # Aux reconstruction: reconstruct x purely from gating path
+         via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
+         aux_recon_loss = (
+             (via_gate_reconstruction - step_input.sae_in).pow(2).sum(dim=-1).mean()
+         )
+
+         # Return both losses separately
+         return {"l1_loss": l1_loss, "auxiliary_reconstruction_loss": aux_recon_loss}
+
+     def log_histograms(self) -> dict[str, NDArray[Any]]:
+         """Log histograms of the weights and biases."""
+         b_gate_dist = self.b_gate.detach().float().cpu().numpy()
+         b_mag_dist = self.b_mag.detach().float().cpu().numpy()
+         return {
+             **super().log_histograms(),
+             "weights/b_gate": b_gate_dist,
+             "weights/b_mag": b_mag_dist,
+         }
+
+     @torch.no_grad()
+     def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+         """Initialize decoder with constant norm"""
+         self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+         self.W_dec.data *= norm
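
The two terms returned by GatedTrainingSAE.calculate_aux_loss can be written out the same way. This is a hedged sketch with placeholder values: l1_coefficient is arbitrary, torch.relu plays the role of the gate ReLU, and b_dec is subtracted unconditionally here even though the method only re-centers when apply_b_dec_to_input is set.

import torch

d_in, d_sae, l1_coefficient = 4, 8, 0.01
sae_in = torch.randn(2, d_in)
W_enc, W_dec = torch.randn(d_in, d_sae), torch.randn(d_sae, d_in)
b_gate, b_dec = torch.zeros(d_sae), torch.zeros(d_in)

# Sparsity penalty: ReLU of the gate pre-activation, weighted by decoder row norms
pi_gate_act = torch.relu((sae_in - b_dec) @ W_enc + b_gate)
l1_loss = l1_coefficient * torch.sum(pi_gate_act * W_dec.norm(dim=1), dim=-1).mean()

# Auxiliary reconstruction: rebuild the input from the gate path alone
via_gate = pi_gate_act @ W_dec + b_dec
aux_recon_loss = (via_gate - sae_in).pow(2).sum(dim=-1).mean()
print(l1_loss.item(), aux_recon_loss.item())
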
@@ -0,0 +1,368 @@
+ from typing import Any
+
+ import numpy as np
+ import torch
+ from jaxtyping import Float
+ from torch import nn
+ from typing_extensions import override
+
+ from sae_lens.saes.sae import (
+     SAE,
+     SAEConfig,
+     TrainingSAE,
+     TrainingSAEConfig,
+     TrainStepInput,
+     TrainStepOutput,
+ )
+
+
+ def rectangle(x: torch.Tensor) -> torch.Tensor:
+     return ((x > -0.5) & (x < 0.5)).to(x)
+
+
+ class Step(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         x: torch.Tensor,
+         threshold: torch.Tensor,
+         bandwidth: float,  # noqa: ARG004
+     ) -> torch.Tensor:
+         return (x > threshold).to(x)
+
+     @staticmethod
+     def setup_context(
+         ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
+     ) -> None:
+         x, threshold, bandwidth = inputs
+         del output
+         ctx.save_for_backward(x, threshold)
+         ctx.bandwidth = bandwidth
+
+     @staticmethod
+     def backward(  # type: ignore[override]
+         ctx: Any, grad_output: torch.Tensor
+     ) -> tuple[None, torch.Tensor, None]:
+         x, threshold = ctx.saved_tensors
+         bandwidth = ctx.bandwidth
+         threshold_grad = torch.sum(
+             -(1.0 / bandwidth) * rectangle((x - threshold) / bandwidth) * grad_output,
+             dim=0,
+         )
+         return None, threshold_grad, None
+
+
+ class JumpReLU(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         x: torch.Tensor,
+         threshold: torch.Tensor,
+         bandwidth: float,  # noqa: ARG004
+     ) -> torch.Tensor:
+         return (x * (x > threshold)).to(x)
+
+     @staticmethod
+     def setup_context(
+         ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
+     ) -> None:
+         x, threshold, bandwidth = inputs
+         del output
+         ctx.save_for_backward(x, threshold)
+         ctx.bandwidth = bandwidth
+
+     @staticmethod
+     def backward(  # type: ignore[override]
+         ctx: Any, grad_output: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor, None]:
+         x, threshold = ctx.saved_tensors
+         bandwidth = ctx.bandwidth
+         x_grad = (x > threshold) * grad_output  # We don't apply STE to x input
+         threshold_grad = torch.sum(
+             -(threshold / bandwidth)
+             * rectangle((x - threshold) / bandwidth)
+             * grad_output,
+             dim=0,
+         )
+         return x_grad, threshold_grad, None
+
+
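
A small check of the custom gradients defined above, assuming Step and JumpReLU are in scope: the forward pass is a hard gate, but the rectangle-window pseudo-derivative still routes gradient to the threshold for any pre-activation that lands within half a bandwidth of it. The numbers below are arbitrary.

import torch

bandwidth = 0.1
x = torch.tensor([[0.07, 0.50, -0.20]], requires_grad=True)
threshold = torch.tensor([0.1, 0.1, 0.1], requires_grad=True)

y = JumpReLU.apply(x, threshold, bandwidth)  # forward: x * (x > threshold)
y.sum().backward()

# x receives the straight-through mask: 1 where x > threshold, else 0
print(x.grad)          # tensor([[0., 1., 0.]])
# threshold only receives gradient where |x - threshold| < bandwidth / 2,
# scaled by -(threshold / bandwidth); here only the first unit qualifies
print(threshold.grad)  # tensor([-1., 0., 0.])
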
+ class JumpReLUSAE(SAE):
+     """
+     JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
+     using a JumpReLU activation. For each unit, if its pre-activation is
+     <= threshold, that unit is zeroed out; otherwise, it follows a user-specified
+     activation function (e.g., ReLU, tanh-relu, etc.).
+
+     It implements:
+     - initialize_weights: sets up parameters, including a threshold.
+     - encode: computes the feature activations using JumpReLU.
+     - decode: reconstructs the input from the feature activations.
+
+     The BaseSAE.forward() method automatically calls encode and decode,
+     including any error-term processing if configured.
+     """
+
+     b_enc: nn.Parameter
+     threshold: nn.Parameter
+
+     def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+         super().__init__(cfg, use_error_term)
+
+     def initialize_weights(self) -> None:
+         """
+         Initialize encoder and decoder weights, as well as biases.
+         Additionally, include a learnable `threshold` parameter that
+         determines when units "turn on" for the JumpReLU.
+         """
+         # Biases
+         self.b_enc = nn.Parameter(
+             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+         )
+         self.b_dec = nn.Parameter(
+             torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+         )
+
+         # Threshold for JumpReLU
+         # You can pick a default initialization (e.g., zeros means unit is off unless hidden_pre > 0)
+         # or see the training version for more advanced init with log_threshold, etc.
+         self.threshold = nn.Parameter(
+             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+         )
+
+         # Encoder and Decoder weights
+         w_enc_data = torch.empty(
+             self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
+         )
+         nn.init.kaiming_uniform_(w_enc_data)
+         self.W_enc = nn.Parameter(w_enc_data)
+
+         w_dec_data = torch.empty(
+             self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+         )
+         nn.init.kaiming_uniform_(w_dec_data)
+         self.W_dec = nn.Parameter(w_dec_data)
+
+     def encode(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> Float[torch.Tensor, "... d_sae"]:
+         """
+         Encode the input tensor into the feature space using JumpReLU.
+         The threshold parameter determines which units remain active.
+         """
+         sae_in = self.process_sae_in(x)
+         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+
+         # 1) Apply the base "activation_fn" from config (e.g., ReLU, tanh-relu).
+         base_acts = self.activation_fn(hidden_pre)
+
+         # 2) Zero out any unit whose (hidden_pre <= threshold).
+         #    We cast the boolean mask to the same dtype for safe multiplication.
+         jump_relu_mask = (hidden_pre > self.threshold).to(base_acts.dtype)
+
+         # 3) Multiply the normally activated units by that mask.
+         return self.hook_sae_acts_post(base_acts * jump_relu_mask)
+
+     def decode(
+         self, feature_acts: Float[torch.Tensor, "... d_sae"]
+     ) -> Float[torch.Tensor, "... d_in"]:
+         """
+         Decode the feature activations back to the input space.
+         Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
+         """
+         scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
+         sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+         sae_out_pre = self.hook_sae_recons(sae_out_pre)
+         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+         return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+     @torch.no_grad()
+     def fold_W_dec_norm(self):
+         """
+         Override to properly handle threshold adjustment with W_dec norms.
+         When we scale the encoder weights, we need to scale the threshold
+         by the same factor to maintain the same sparsity pattern.
+         """
+         # Save the current threshold before calling parent method
+         current_thresh = self.threshold.clone()
+
+         # Get W_dec norms that will be used for scaling
+         W_dec_norms = self.W_dec.norm(dim=-1)
+
+         # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
+         super().fold_W_dec_norm()
+
+         # Scale the threshold by the same factor as we scaled b_enc
+         # This ensures the same features remain active/inactive after folding
+         self.threshold.data = current_thresh * W_dec_norms
+
+
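
To see why JumpReLUSAE.fold_W_dec_norm rescales the threshold: folding divides each decoder row by its norm and scales the matching encoder column, bias, and threshold by that same norm, which multiplies every pre-activation by the norm and therefore leaves the set of active features and the reconstruction unchanged. A minimal sketch with plain tensors, ReLU standing in for the configured activation and no hooks:

import torch

d_in, d_sae = 4, 8
x = torch.randn(2, d_in)
W_enc, W_dec = torch.randn(d_in, d_sae), torch.randn(d_sae, d_in)
b_enc, threshold = torch.randn(d_sae), torch.rand(d_sae)

def jumprelu_encode(x, W_enc, b_enc, threshold):
    hidden_pre = x @ W_enc + b_enc
    return torch.relu(hidden_pre) * (hidden_pre > threshold)

before = jumprelu_encode(x, W_enc, b_enc, threshold)

# Fold decoder norms: W_dec rows become unit norm; W_enc, b_enc, threshold scale up
norms = W_dec.norm(dim=-1)                # shape (d_sae,)
W_dec_folded = W_dec / norms.unsqueeze(1)
W_enc_folded = W_enc * norms              # scales each encoder column
b_enc_folded = b_enc * norms
threshold_folded = threshold * norms

after = jumprelu_encode(x, W_enc_folded, b_enc_folded, threshold_folded)

# Same sparsity pattern, and the reconstructions agree
assert torch.equal(before > 0, after > 0)
assert torch.allclose(before @ W_dec, after @ W_dec_folded, atol=1e-5)
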
+ class JumpReLUTrainingSAE(TrainingSAE):
+     """
+     JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.
+
+     Similar to the inference-only JumpReLUSAE, but with:
+     - A learnable log-threshold parameter (instead of a raw threshold).
+     - Forward passes that add noise during training, if configured.
+     - A specialized auxiliary loss term for sparsity (L0 or similar).
+
+     Methods of interest include:
+     - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
+     - encode_with_hidden_pre_jumprelu: runs a forward pass for training, optionally adding noise.
+     - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.
+     """
+
+     b_enc: nn.Parameter
+     log_threshold: nn.Parameter
+
+     def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
+         super().__init__(cfg, use_error_term)
+
+         # We'll store a bandwidth for the training approach, if needed
+         self.bandwidth = cfg.jumprelu_bandwidth
+
+         # In typical JumpReLU training code, we may track a log_threshold:
+         self.log_threshold = nn.Parameter(
+             torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+             * np.log(cfg.jumprelu_init_threshold)
+         )
+
+     def initialize_weights(self) -> None:
+         """
+         Initialize parameters like the base SAE, but also add log_threshold.
+         """
+         # Encoder Bias
+         self.b_enc = nn.Parameter(
+             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+         )
+         # Decoder Bias
+         self.b_dec = nn.Parameter(
+             torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+         )
+         # W_enc
+         w_enc_data = torch.nn.init.kaiming_uniform_(
+             torch.empty(
+                 self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
+             )
+         )
+         self.W_enc = nn.Parameter(w_enc_data)
+
+         # W_dec
+         w_dec_data = torch.nn.init.kaiming_uniform_(
+             torch.empty(
+                 self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+             )
+         )
+         self.W_dec = nn.Parameter(w_dec_data)
+
+         # Optionally apply orthogonal or heuristic init
+         if self.cfg.decoder_orthogonal_init:
+             self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
+         elif self.cfg.decoder_heuristic_init:
+             self.W_dec.data = torch.rand(
+                 self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+             )
+             self.initialize_decoder_norm_constant_norm()
+
+         # Optionally transpose
+         if self.cfg.init_encoder_as_decoder_transpose:
+             self.W_enc.data = self.W_dec.data.T.clone().contiguous()
+
+         # Optionally normalize columns of W_dec
+         if self.cfg.normalize_sae_decoder:
+             with torch.no_grad():
+                 self.set_decoder_norm_to_unit_norm()
+
+     @property
+     def threshold(self) -> torch.Tensor:
+         """
+         Returns the parameterized threshold > 0 for each unit.
+         threshold = exp(log_threshold).
+         """
+         return torch.exp(self.log_threshold)
+
+     def encode_with_hidden_pre(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+         sae_in = self.process_sae_in(x)
+
+         hidden_pre = sae_in @ self.W_enc + self.b_enc
+
+         if self.training and self.cfg.noise_scale > 0:
+             hidden_pre = (
+                 hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
+             )
+
+         feature_acts = JumpReLU.apply(hidden_pre, self.threshold, self.bandwidth)
+
+         return feature_acts, hidden_pre  # type: ignore
+
+     @override
+     def calculate_aux_loss(
+         self,
+         step_input: TrainStepInput,
+         feature_acts: torch.Tensor,
+         hidden_pre: torch.Tensor,
+         sae_out: torch.Tensor,
+     ) -> dict[str, torch.Tensor]:
+         """Calculate architecture-specific auxiliary loss terms."""
+         l0 = torch.sum(Step.apply(hidden_pre, self.threshold, self.bandwidth), dim=-1)  # type: ignore
+         l0_loss = (step_input.current_l1_coefficient * l0).mean()
+         return {"l0_loss": l0_loss}
+
+     @torch.no_grad()
+     def fold_W_dec_norm(self):
+         """
+         Override to properly handle threshold adjustment with W_dec norms.
+         """
+         # Save the current threshold before we call the parent method
+         current_thresh = self.threshold.clone()
+
+         # Get W_dec norms
+         W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+
+         # Call parent implementation to handle W_enc and W_dec adjustment
+         super().fold_W_dec_norm()
+
+         # Fix: Use squeeze() instead of squeeze(-1) to match old behavior
+         self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze())
+
+     def _create_train_step_output(
+         self,
+         sae_in: torch.Tensor,
+         sae_out: torch.Tensor,
+         feature_acts: torch.Tensor,
+         hidden_pre: torch.Tensor,
+         loss: torch.Tensor,
+         losses: dict[str, torch.Tensor],
+     ) -> TrainStepOutput:
+         """
+         Helper to produce a TrainStepOutput from the trainer.
+         The old code expects a method named _create_train_step_output().
+         """
+         return TrainStepOutput(
+             sae_in=sae_in,
+             sae_out=sae_out,
+             feature_acts=feature_acts,
+             hidden_pre=hidden_pre,
+             loss=loss,
+             losses=losses,
+         )
+
+     @torch.no_grad()
+     def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+         """Initialize decoder with constant norm"""
+         self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+         self.W_dec.data *= norm
+
+     def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
+         """Convert log_threshold to threshold for saving"""
+         if "log_threshold" in state_dict:
+             threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
+             del state_dict["log_threshold"]
+             state_dict["threshold"] = threshold
+
+     def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
+         """Convert threshold to log_threshold for loading"""
+         if "threshold" in state_dict:
+             threshold = state_dict["threshold"]
+             del state_dict["threshold"]
+             state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
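
Finally, a sketch of the state-dict round trip implemented by the two process_state_dict_* hooks above: checkpoints store a plain threshold (matching the inference-only JumpReLUSAE), while training restores the log_threshold parameterization on load. The dict below is a stand-in for a real state dict.

import torch

train_state = {"log_threshold": torch.log(torch.tensor([0.001, 0.01, 0.1]))}

# Saving: replace log_threshold with threshold = exp(log_threshold)
saved = dict(train_state)
saved["threshold"] = torch.exp(saved.pop("log_threshold")).detach().contiguous()
print(saved["threshold"])  # tensor([0.0010, 0.0100, 0.1000])

# Loading: invert the transform to recover the training parameterization
loaded = dict(saved)
loaded["log_threshold"] = torch.log(loaded.pop("threshold")).detach().contiguous()
assert torch.allclose(loaded["log_threshold"], train_state["log_threshold"])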