sae-lens 5.10.3__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
@@ -0,0 +1,258 @@
+ from dataclasses import dataclass
+ from typing import Any
+
+ import torch
+ from jaxtyping import Float
+ from numpy.typing import NDArray
+ from torch import nn
+ from typing_extensions import override
+
+ from sae_lens.saes.sae import (
+     SAE,
+     SAEConfig,
+     TrainCoefficientConfig,
+     TrainingSAE,
+     TrainingSAEConfig,
+     TrainStepInput,
+ )
+ from sae_lens.util import filter_valid_dataclass_fields
+
+
+ @dataclass
+ class GatedSAEConfig(SAEConfig):
+     """
+     Configuration class for a GatedSAE.
+     """
+
+     @override
+     @classmethod
+     def architecture(cls) -> str:
+         return "gated"
+
+
+ class GatedSAE(SAE[GatedSAEConfig]):
+     """
+     GatedSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
+     using a gated linear encoder and a standard linear decoder.
+     """
+
+     b_gate: nn.Parameter
+     b_mag: nn.Parameter
+     r_mag: nn.Parameter
+
+     def __init__(self, cfg: GatedSAEConfig, use_error_term: bool = False):
+         super().__init__(cfg, use_error_term)
+         # Ensure b_enc does not exist for the gated architecture
+         self.b_enc = None
+
+     @override
+     def initialize_weights(self) -> None:
+         super().initialize_weights()
+         _init_weights_gated(self)
+
+     def encode(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> Float[torch.Tensor, "... d_sae"]:
+         """
+         Encode the input tensor into the feature space using a gated encoder.
+         This must match the original encode_gated implementation from SAE class.
+         """
+         # Preprocess the SAE input (casting type, applying hooks, normalization)
+         sae_in = self.process_sae_in(x)
+
+         # Gating path exactly as in original SAE.encode_gated
+         gating_pre_activation = sae_in @ self.W_enc + self.b_gate
+         active_features = (gating_pre_activation > 0).to(self.dtype)
+
+         # Magnitude path (weight sharing with gated encoder)
+         magnitude_pre_activation = self.hook_sae_acts_pre(
+             sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
+         )
+         feature_magnitudes = self.activation_fn(magnitude_pre_activation)
+
+         # Combine gating and magnitudes
+         return self.hook_sae_acts_post(active_features * feature_magnitudes)
+
+     def decode(
+         self, feature_acts: Float[torch.Tensor, "... d_sae"]
+     ) -> Float[torch.Tensor, "... d_in"]:
+         """
+         Decode the feature activations back into the input space:
+         1) Apply optional finetuning scaling.
+         2) Linear transform plus bias.
+         3) Run any reconstruction hooks and out-normalization if configured.
+         4) If the SAE was reshaping hook_z activations, reshape back.
+         """
+         # 1) optional finetuning scaling
+         # 2) linear transform
+         sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+         # 3) hooking and normalization
+         sae_out_pre = self.hook_sae_recons(sae_out_pre)
+         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+         # 4) reshape if needed (hook_z)
+         return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+     @torch.no_grad()
+     def fold_W_dec_norm(self):
+         """Override to handle gated-specific parameters."""
+         W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+         self.W_dec.data = self.W_dec.data / W_dec_norms
+         self.W_enc.data = self.W_enc.data * W_dec_norms.T
+
+         # Gated-specific parameters need special handling
+         self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
+         self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
+         self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
+
+     @torch.no_grad()
+     def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+         """Initialize decoder with constant norm."""
+         self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+         self.W_dec.data *= norm
+
+
+ @dataclass
+ class GatedTrainingSAEConfig(TrainingSAEConfig):
+     """
+     Configuration class for training a GatedTrainingSAE.
+     """
+
+     l1_coefficient: float = 1.0
+     l1_warm_up_steps: int = 0
+
+     @override
+     @classmethod
+     def architecture(cls) -> str:
+         return "gated"
+
+
+ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
+     """
+     GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
+     It implements:
+     - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
+     - encode: calls encode_with_hidden_pre (standard training approach).
+     - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
+     - encode_with_hidden_pre: gating logic + optional noise injection for training.
+     - calculate_aux_loss: includes an auxiliary reconstruction path and gating-based sparsity penalty.
+     - training_forward_pass: calls encode_with_hidden_pre, decode, and sums up MSE + gating losses.
+     """
+
+     b_gate: nn.Parameter  # type: ignore
+     b_mag: nn.Parameter  # type: ignore
+     r_mag: nn.Parameter  # type: ignore
+
+     def __init__(self, cfg: GatedTrainingSAEConfig, use_error_term: bool = False):
+         if use_error_term:
+             raise ValueError(
+                 "GatedSAE does not support `use_error_term`. Please set `use_error_term=False`."
+             )
+         super().__init__(cfg, use_error_term)
+
+     def initialize_weights(self) -> None:
+         super().initialize_weights()
+         _init_weights_gated(self)
+
+     def encode_with_hidden_pre(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+         """
+         Gated forward pass with pre-activation (for training).
+         We also inject noise if self.training is True.
+         """
+         sae_in = self.process_sae_in(x)
+
+         # Gating path
+         gating_pre_activation = sae_in @ self.W_enc + self.b_gate
+         active_features = (gating_pre_activation > 0).to(self.dtype)
+
+         # Magnitude path
+         magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
+         if self.training and self.cfg.noise_scale > 0:
+             magnitude_pre_activation += (
+                 torch.randn_like(magnitude_pre_activation) * self.cfg.noise_scale
+             )
+         magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)
+
+         feature_magnitudes = self.activation_fn(magnitude_pre_activation)
+
+         # Combine gating path and magnitude path
+         feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes)
+
+         # Return both the final feature activations and the pre-activation (for logging or penalty)
+         return feature_acts, magnitude_pre_activation
+
+     def calculate_aux_loss(
+         self,
+         step_input: TrainStepInput,
+         feature_acts: torch.Tensor,
+         hidden_pre: torch.Tensor,
+         sae_out: torch.Tensor,
+     ) -> dict[str, torch.Tensor]:
+         # Re-center the input if apply_b_dec_to_input is set
+         sae_in_centered = step_input.sae_in - (
+             self.b_dec * self.cfg.apply_b_dec_to_input
+         )
+
+         # The gating pre-activation (pi_gate) for the auxiliary path
+         pi_gate = sae_in_centered @ self.W_enc + self.b_gate
+         pi_gate_act = torch.relu(pi_gate)
+
+         # L1-like penalty scaled by W_dec norms
+         l1_loss = (
+             step_input.coefficients["l1"]
+             * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
+         )
+
+         # Aux reconstruction: reconstruct x purely from gating path
+         via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
+         aux_recon_loss = (
+             (via_gate_reconstruction - step_input.sae_in).pow(2).sum(dim=-1).mean()
+         )
+
+         # Return both losses separately
+         return {"l1_loss": l1_loss, "auxiliary_reconstruction_loss": aux_recon_loss}
+
+     def log_histograms(self) -> dict[str, NDArray[Any]]:
+         """Log histograms of the weights and biases."""
+         b_gate_dist = self.b_gate.detach().float().cpu().numpy()
+         b_mag_dist = self.b_mag.detach().float().cpu().numpy()
+         return {
+             **super().log_histograms(),
+             "weights/b_gate": b_gate_dist,
+             "weights/b_mag": b_mag_dist,
+         }
+
+     @torch.no_grad()
+     def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+         """Initialize decoder with constant norm"""
+         self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+         self.W_dec.data *= norm
+
+     def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+         return {
+             "l1": TrainCoefficientConfig(
+                 value=self.cfg.l1_coefficient,
+                 warm_up_steps=self.cfg.l1_warm_up_steps,
+             ),
+         }
+
+     def to_inference_config_dict(self) -> dict[str, Any]:
+         return filter_valid_dataclass_fields(
+             self.cfg.to_dict(), GatedSAEConfig, ["architecture"]
+         )
+
+
+ def _init_weights_gated(
+     sae: SAE[GatedSAEConfig] | TrainingSAE[GatedTrainingSAEConfig],
+ ) -> None:
+     sae.b_gate = nn.Parameter(
+         torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+     )
+     # Ensure r_mag is initialized to zero as in original
+     sae.r_mag = nn.Parameter(
+         torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+     )
+     sae.b_mag = nn.Parameter(
+         torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+     )
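
For orientation, the gated encoder above runs two paths that share `W_enc`: a binary gate computed from `W_enc` and `b_gate`, and a magnitude path computed from `W_enc * exp(r_mag)` and `b_mag`; their elementwise product gives the feature activations. Below is a minimal standalone sketch of that arithmetic with plain tensors (local names only, ReLU assumed for `activation_fn`, and none of the package's hooks or input preprocessing):

import torch

d_in, d_sae = 4, 8
x = torch.randn(2, d_in)

# Local stand-ins for the parameters set up in _init_weights_gated above.
W_enc = torch.randn(d_in, d_sae)
W_dec = torch.randn(d_sae, d_in)
b_gate = torch.zeros(d_sae)
b_mag = torch.zeros(d_sae)
r_mag = torch.zeros(d_sae)
b_dec = torch.zeros(d_in)

# Gating path: a hard 0/1 mask of which features fire.
gate = ((x @ W_enc + b_gate) > 0).float()
# Magnitude path: shares W_enc, rescaled per-feature by exp(r_mag).
mag = torch.relu(x @ (W_enc * r_mag.exp()) + b_mag)
feature_acts = gate * mag

# Decode is a plain linear map back to the input space.
recon = feature_acts @ W_dec + b_dec
print(feature_acts.shape, recon.shape)  # torch.Size([2, 8]) torch.Size([2, 4])

With the zero initializations in `_init_weights_gated`, the two paths coincide at the start of training: `exp(r_mag)` is 1 and both biases are zero, so the encoder initially behaves like a plain ReLU encoder with no encoder bias.
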
@@ -0,0 +1,354 @@
+ from dataclasses import dataclass
+ from typing import Any
+
+ import numpy as np
+ import torch
+ from jaxtyping import Float
+ from torch import nn
+ from typing_extensions import override
+
+ from sae_lens.saes.sae import (
+     SAE,
+     SAEConfig,
+     TrainCoefficientConfig,
+     TrainingSAE,
+     TrainingSAEConfig,
+     TrainStepInput,
+     TrainStepOutput,
+ )
+ from sae_lens.util import filter_valid_dataclass_fields
+
+
+ def rectangle(x: torch.Tensor) -> torch.Tensor:
+     return ((x > -0.5) & (x < 0.5)).to(x)
+
+
+ class Step(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         x: torch.Tensor,
+         threshold: torch.Tensor,
+         bandwidth: float,  # noqa: ARG004
+     ) -> torch.Tensor:
+         return (x > threshold).to(x)
+
+     @staticmethod
+     def setup_context(
+         ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
+     ) -> None:
+         x, threshold, bandwidth = inputs
+         del output
+         ctx.save_for_backward(x, threshold)
+         ctx.bandwidth = bandwidth
+
+     @staticmethod
+     def backward(  # type: ignore[override]
+         ctx: Any, grad_output: torch.Tensor
+     ) -> tuple[None, torch.Tensor, None]:
+         x, threshold = ctx.saved_tensors
+         bandwidth = ctx.bandwidth
+         threshold_grad = torch.sum(
+             -(1.0 / bandwidth) * rectangle((x - threshold) / bandwidth) * grad_output,
+             dim=0,
+         )
+         return None, threshold_grad, None
+
+
+ class JumpReLU(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         x: torch.Tensor,
+         threshold: torch.Tensor,
+         bandwidth: float,  # noqa: ARG004
+     ) -> torch.Tensor:
+         return (x * (x > threshold)).to(x)
+
+     @staticmethod
+     def setup_context(
+         ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
+     ) -> None:
+         x, threshold, bandwidth = inputs
+         del output
+         ctx.save_for_backward(x, threshold)
+         ctx.bandwidth = bandwidth
+
+     @staticmethod
+     def backward(  # type: ignore[override]
+         ctx: Any, grad_output: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor, None]:
+         x, threshold = ctx.saved_tensors
+         bandwidth = ctx.bandwidth
+         x_grad = (x > threshold) * grad_output  # We don't apply STE to x input
+         threshold_grad = torch.sum(
+             -(threshold / bandwidth)
+             * rectangle((x - threshold) / bandwidth)
+             * grad_output,
+             dim=0,
+         )
+         return x_grad, threshold_grad, None
+
+
+ @dataclass
+ class JumpReLUSAEConfig(SAEConfig):
+     """
+     Configuration class for a JumpReLUSAE.
+     """
+
+     @override
+     @classmethod
+     def architecture(cls) -> str:
+         return "jumprelu"
+
+
+ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
+     """
+     JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
+     using a JumpReLU activation. For each unit, if its pre-activation is
+     <= threshold, that unit is zeroed out; otherwise, it follows a user-specified
+     activation function (e.g., ReLU, tanh-relu, etc.).
+
+     It implements:
+     - initialize_weights: sets up parameters, including a threshold.
+     - encode: computes the feature activations using JumpReLU.
+     - decode: reconstructs the input from the feature activations.
+
+     The BaseSAE.forward() method automatically calls encode and decode,
+     including any error-term processing if configured.
+     """
+
+     b_enc: nn.Parameter
+     threshold: nn.Parameter
+
+     def __init__(self, cfg: JumpReLUSAEConfig, use_error_term: bool = False):
+         super().__init__(cfg, use_error_term)
+
+     @override
+     def initialize_weights(self) -> None:
+         super().initialize_weights()
+         self.threshold = nn.Parameter(
+             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+         )
+         self.b_enc = nn.Parameter(
+             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+         )
+
+     def encode(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> Float[torch.Tensor, "... d_sae"]:
+         """
+         Encode the input tensor into the feature space using JumpReLU.
+         The threshold parameter determines which units remain active.
+         """
+         sae_in = self.process_sae_in(x)
+         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+
+         # 1) Apply the base "activation_fn" from config (e.g., ReLU, tanh-relu).
+         base_acts = self.activation_fn(hidden_pre)
+
+         # 2) Zero out any unit whose (hidden_pre <= threshold).
+         # We cast the boolean mask to the same dtype for safe multiplication.
+         jump_relu_mask = (hidden_pre > self.threshold).to(base_acts.dtype)
+
+         # 3) Multiply the normally activated units by that mask.
+         return self.hook_sae_acts_post(base_acts * jump_relu_mask)
+
+     def decode(
+         self, feature_acts: Float[torch.Tensor, "... d_sae"]
+     ) -> Float[torch.Tensor, "... d_in"]:
+         """
+         Decode the feature activations back to the input space.
+         Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
+         """
+         sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+         sae_out_pre = self.hook_sae_recons(sae_out_pre)
+         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+         return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+     @torch.no_grad()
+     def fold_W_dec_norm(self):
+         """
+         Override to properly handle threshold adjustment with W_dec norms.
+         When we scale the encoder weights, we need to scale the threshold
+         by the same factor to maintain the same sparsity pattern.
+         """
+         # Save the current threshold before calling parent method
+         current_thresh = self.threshold.clone()
+
+         # Get W_dec norms that will be used for scaling
+         W_dec_norms = self.W_dec.norm(dim=-1)
+
+         # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
+         super().fold_W_dec_norm()
+
+         # Scale the threshold by the same factor as we scaled b_enc
+         # This ensures the same features remain active/inactive after folding
+         self.threshold.data = current_thresh * W_dec_norms
+
+
+ @dataclass
+ class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
+     """
+     Configuration class for training a JumpReLUTrainingSAE.
+     """
+
+     jumprelu_init_threshold: float = 0.001
+     jumprelu_bandwidth: float = 0.001
+     l0_coefficient: float = 1.0
+     l0_warm_up_steps: int = 0
+
+     @override
+     @classmethod
+     def architecture(cls) -> str:
+         return "jumprelu"
+
+
+ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
+     """
+     JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.
+
+     Similar to the inference-only JumpReLUSAE, but with:
+     - A learnable log-threshold parameter (instead of a raw threshold).
+     - Forward passes that add noise during training, if configured.
+     - A specialized auxiliary loss term for sparsity (L0 or similar).
+
+     Methods of interest include:
+     - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
+     - encode_with_hidden_pre_jumprelu: runs a forward pass for training, optionally adding noise.
+     - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.
+     """
+
+     b_enc: nn.Parameter
+     log_threshold: nn.Parameter
+
+     def __init__(self, cfg: JumpReLUTrainingSAEConfig, use_error_term: bool = False):
+         super().__init__(cfg, use_error_term)
+
+         # We'll store a bandwidth for the training approach, if needed
+         self.bandwidth = cfg.jumprelu_bandwidth
+
+         # In typical JumpReLU training code, we may track a log_threshold:
+         self.log_threshold = nn.Parameter(
+             torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+             * np.log(cfg.jumprelu_init_threshold)
+         )
+
+     @override
+     def initialize_weights(self) -> None:
+         """
+         Initialize parameters like the base SAE, but also add log_threshold.
+         """
+         super().initialize_weights()
+         # Encoder Bias
+         self.b_enc = nn.Parameter(
+             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+         )
+
+     @property
+     def threshold(self) -> torch.Tensor:
+         """
+         Returns the parameterized threshold > 0 for each unit.
+         threshold = exp(log_threshold).
+         """
+         return torch.exp(self.log_threshold)
+
+     def encode_with_hidden_pre(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+         sae_in = self.process_sae_in(x)
+
+         hidden_pre = sae_in @ self.W_enc + self.b_enc
+
+         if self.training and self.cfg.noise_scale > 0:
+             hidden_pre = (
+                 hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
+             )
+
+         feature_acts = JumpReLU.apply(hidden_pre, self.threshold, self.bandwidth)
+
+         return feature_acts, hidden_pre  # type: ignore
+
+     @override
+     def calculate_aux_loss(
+         self,
+         step_input: TrainStepInput,
+         feature_acts: torch.Tensor,
+         hidden_pre: torch.Tensor,
+         sae_out: torch.Tensor,
+     ) -> dict[str, torch.Tensor]:
+         """Calculate architecture-specific auxiliary loss terms."""
+         l0 = torch.sum(Step.apply(hidden_pre, self.threshold, self.bandwidth), dim=-1)  # type: ignore
+         l0_loss = (step_input.coefficients["l0"] * l0).mean()
+         return {"l0_loss": l0_loss}
+
+     @override
+     def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+         return {
+             "l0": TrainCoefficientConfig(
+                 value=self.cfg.l0_coefficient,
+                 warm_up_steps=self.cfg.l0_warm_up_steps,
+             ),
+         }
+
+     @torch.no_grad()
+     def fold_W_dec_norm(self):
+         """
+         Override to properly handle threshold adjustment with W_dec norms.
+         """
+         # Save the current threshold before we call the parent method
+         current_thresh = self.threshold.clone()
+
+         # Get W_dec norms
+         W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+
+         # Call parent implementation to handle W_enc and W_dec adjustment
+         super().fold_W_dec_norm()
+
+         # Fix: Use squeeze() instead of squeeze(-1) to match old behavior
+         self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze())
+
+     def _create_train_step_output(
+         self,
+         sae_in: torch.Tensor,
+         sae_out: torch.Tensor,
+         feature_acts: torch.Tensor,
+         hidden_pre: torch.Tensor,
+         loss: torch.Tensor,
+         losses: dict[str, torch.Tensor],
+     ) -> TrainStepOutput:
+         """
+         Helper to produce a TrainStepOutput from the trainer.
+         The old code expects a method named _create_train_step_output().
+         """
+         return TrainStepOutput(
+             sae_in=sae_in,
+             sae_out=sae_out,
+             feature_acts=feature_acts,
+             hidden_pre=hidden_pre,
+             loss=loss,
+             losses=losses,
+         )
+
+     @torch.no_grad()
+     def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+         """Initialize decoder with constant norm"""
+         self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+         self.W_dec.data *= norm
+
+     def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
+         """Convert log_threshold to threshold for saving"""
+         if "log_threshold" in state_dict:
+             threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
+             del state_dict["log_threshold"]
+             state_dict["threshold"] = threshold
+
+     def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
+         """Convert threshold to log_threshold for loading"""
+         if "threshold" in state_dict:
+             threshold = state_dict["threshold"]
+             del state_dict["threshold"]
+             state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
+
+     def to_inference_config_dict(self) -> dict[str, Any]:
+         return filter_valid_dataclass_fields(
+             self.cfg.to_dict(), JumpReLUSAEConfig, ["architecture"]
+         )
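
The `Step` and `JumpReLU` autograd functions above implement straight-through-style pseudo-gradients: the forward pass is a hard threshold, whose true derivative with respect to `threshold` is zero almost everywhere, so the custom backward instead routes gradient through a `rectangle` window of width `bandwidth` centered on the threshold. This is what lets the L0 penalty in `calculate_aux_loss` train `log_threshold`. A minimal standalone sketch (local copies of `rectangle` and `Step`, written in the older ctx-in-forward style for brevity; the numbers are only illustrative):

import torch


def rectangle(x: torch.Tensor) -> torch.Tensor:
    return ((x > -0.5) & (x < 0.5)).to(x)


class Step(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, threshold, bandwidth):
        ctx.save_for_backward(x, threshold)
        ctx.bandwidth = bandwidth
        return (x > threshold).to(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, threshold = ctx.saved_tensors
        # Pseudo-derivative: only points within bandwidth/2 of the threshold
        # pass gradient back to the threshold parameter.
        threshold_grad = torch.sum(
            -(1.0 / ctx.bandwidth)
            * rectangle((x - threshold) / ctx.bandwidth)
            * grad_output,
            dim=0,
        )
        return None, threshold_grad, None


bandwidth = 0.001
hidden_pre = torch.tensor([[0.0008, 0.5, -0.2], [0.0012, 0.0, 0.3]])
log_threshold = torch.log(torch.tensor([0.001, 0.001, 0.001])).requires_grad_()

l0 = Step.apply(hidden_pre, log_threshold.exp(), bandwidth).sum(dim=-1)
l0.mean().backward()

print(l0)  # tensor([1., 2.]) -- hard per-example L0 counts from the forward pass
# Only unit 0 has pre-activations inside the bandwidth window around its
# threshold, so it is the only entry with a nonzero gradient (approx. -1.0).
print(log_threshold.grad)

The same mechanism appears in `JumpReLU.backward`, except that the input `x` also receives a (non-smoothed) gradient and the threshold pseudo-derivative is scaled by `threshold` rather than by 1.
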