sae-lens 5.9.1__py3-none-any.whl → 6.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +22 -6
- sae_lens/analysis/hooked_sae_transformer.py +2 -2
- sae_lens/config.py +66 -23
- sae_lens/evals.py +6 -5
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +33 -25
- sae_lens/registry.py +34 -0
- sae_lens/sae_training_runner.py +18 -33
- sae_lens/saes/gated_sae.py +247 -0
- sae_lens/saes/jumprelu_sae.py +368 -0
- sae_lens/saes/sae.py +970 -0
- sae_lens/saes/standard_sae.py +167 -0
- sae_lens/saes/topk_sae.py +305 -0
- sae_lens/training/activations_store.py +2 -2
- sae_lens/training/sae_trainer.py +13 -19
- sae_lens/training/upload_saes_to_huggingface.py +1 -1
- {sae_lens-5.9.1.dist-info → sae_lens-6.0.0rc1.dist-info}/METADATA +2 -2
- sae_lens-6.0.0rc1.dist-info/RECORD +32 -0
- sae_lens/sae.py +0 -747
- sae_lens/training/training_sae.py +0 -705
- sae_lens-5.9.1.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.9.1.dist-info → sae_lens-6.0.0rc1.dist-info}/LICENSE +0 -0
- {sae_lens-5.9.1.dist-info → sae_lens-6.0.0rc1.dist-info}/WHEEL +0 -0
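The listing above summarizes the main restructuring in this release: the single-file sae.py and training_sae.py modules are removed in favor of the new sae_lens/saes/ package, and the toolkit/ loaders move to loading/. As a quick orientation, here is a minimal sketch of the import paths implied by the file list; the module paths and class names are taken from the added files shown below, and any other exports would need to be checked against the 6.0.0rc1 package itself.

# Sketch of the 6.0.0rc1 layout implied by this diff (illustrative, not diff output).
# Requires sae-lens 6.0.0rc1 to be installed; classes below appear in the added modules.
from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
from sae_lens.saes.gated_sae import GatedSAE, GatedTrainingSAE
from sae_lens.saes.jumprelu_sae import JumpReLUSAE, JumpReLUTrainingSAE

print(issubclass(GatedSAE, SAE), issubclass(JumpReLUTrainingSAE, TrainingSAE))  # True True

The two new SAE architecture modules are shown in full below.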
sae_lens/saes/gated_sae.py (new file)
@@ -0,0 +1,247 @@
+from typing import Any
+
+import torch
+from jaxtyping import Float
+from numpy.typing import NDArray
+from torch import nn
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+)
+
+
+class GatedSAE(SAE):
+    """
+    GatedSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
+    using a gated linear encoder and a standard linear decoder.
+    """
+
+    b_gate: nn.Parameter
+    b_mag: nn.Parameter
+    r_mag: nn.Parameter
+
+    def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+        # Ensure b_enc does not exist for the gated architecture
+        self.b_enc = None
+
+    def initialize_weights(self) -> None:
+        """
+        Initialize weights exactly as in the original SAE class for gated architecture.
+        """
+        # Use the same initialization methods and values as in original SAE
+        self.W_enc = nn.Parameter(
+            torch.nn.init.kaiming_uniform_(
+                torch.empty(
+                    self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
+                )
+            )
+        )
+        self.b_gate = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+        # Ensure r_mag is initialized to zero as in original
+        self.r_mag = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+        self.b_mag = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+
+        # Decoder parameters with same initialization as original
+        self.W_dec = nn.Parameter(
+            torch.nn.init.kaiming_uniform_(
+                torch.empty(
+                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+                )
+            )
+        )
+        self.b_dec = nn.Parameter(
+            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+        )
+
+        # after defining b_gate, b_mag, etc.:
+        self.b_enc = None
+
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        """
+        Encode the input tensor into the feature space using a gated encoder.
+        This must match the original encode_gated implementation from SAE class.
+        """
+        # Preprocess the SAE input (casting type, applying hooks, normalization)
+        sae_in = self.process_sae_in(x)
+
+        # Gating path exactly as in original SAE.encode_gated
+        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
+        active_features = (gating_pre_activation > 0).to(self.dtype)
+
+        # Magnitude path (weight sharing with gated encoder)
+        magnitude_pre_activation = self.hook_sae_acts_pre(
+            sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
+        )
+        feature_magnitudes = self.activation_fn(magnitude_pre_activation)
+
+        # Combine gating and magnitudes
+        return self.hook_sae_acts_post(active_features * feature_magnitudes)
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decode the feature activations back into the input space:
+        1) Apply optional finetuning scaling.
+        2) Linear transform plus bias.
+        3) Run any reconstruction hooks and out-normalization if configured.
+        4) If the SAE was reshaping hook_z activations, reshape back.
+        """
+        # 1) optional finetuning scaling
+        scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
+        # 2) linear transform
+        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+        # 3) hooking and normalization
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        # 4) reshape if needed (hook_z)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+    @torch.no_grad()
+    def fold_W_dec_norm(self):
+        """Override to handle gated-specific parameters."""
+        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        self.W_dec.data = self.W_dec.data / W_dec_norms
+        self.W_enc.data = self.W_enc.data * W_dec_norms.T
+
+        # Gated-specific parameters need special handling
+        self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
+        self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
+        self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
+
+    @torch.no_grad()
+    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+        """Initialize decoder with constant norm."""
+        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+        self.W_dec.data *= norm
+
+
+class GatedTrainingSAE(TrainingSAE):
+    """
+    GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
+    It implements:
+    - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
+    - encode: calls encode_with_hidden_pre (standard training approach).
+    - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
+    - encode_with_hidden_pre: gating logic + optional noise injection for training.
+    - calculate_aux_loss: includes an auxiliary reconstruction path and gating-based sparsity penalty.
+    - training_forward_pass: calls encode_with_hidden_pre, decode, and sums up MSE + gating losses.
+    """
+
+    b_gate: nn.Parameter  # type: ignore
+    b_mag: nn.Parameter  # type: ignore
+    r_mag: nn.Parameter  # type: ignore
+
+    def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
+        if use_error_term:
+            raise ValueError(
+                "GatedSAE does not support `use_error_term`. Please set `use_error_term=False`."
+            )
+        super().__init__(cfg, use_error_term)
+
+    def initialize_weights(self) -> None:
+        # Reuse the gating parameter initialization from GatedSAE:
+        GatedSAE.initialize_weights(self)  # type: ignore
+
+        # Additional training-specific logic, e.g. orthogonal init or heuristics:
+        if self.cfg.decoder_orthogonal_init:
+            self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
+        elif self.cfg.decoder_heuristic_init:
+            self.W_dec.data = torch.rand(
+                self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+            )
+            self.initialize_decoder_norm_constant_norm()
+        if self.cfg.init_encoder_as_decoder_transpose:
+            self.W_enc.data = self.W_dec.data.T.clone().contiguous()
+        if self.cfg.normalize_sae_decoder:
+            with torch.no_grad():
+                self.set_decoder_norm_to_unit_norm()
+
+    def encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        """
+        Gated forward pass with pre-activation (for training).
+        We also inject noise if self.training is True.
+        """
+        sae_in = self.process_sae_in(x)
+
+        # Gating path
+        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
+        active_features = (gating_pre_activation > 0).to(self.dtype)
+
+        # Magnitude path
+        magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
+        if self.training and self.cfg.noise_scale > 0:
+            magnitude_pre_activation += (
+                torch.randn_like(magnitude_pre_activation) * self.cfg.noise_scale
+            )
+        magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)
+
+        feature_magnitudes = self.activation_fn(magnitude_pre_activation)
+
+        # Combine gating path and magnitude path
+        feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes)
+
+        # Return both the final feature activations and the pre-activation (for logging or penalty)
+        return feature_acts, magnitude_pre_activation
+
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        # Re-center the input if apply_b_dec_to_input is set
+        sae_in_centered = step_input.sae_in - (
+            self.b_dec * self.cfg.apply_b_dec_to_input
+        )
+
+        # The gating pre-activation (pi_gate) for the auxiliary path
+        pi_gate = sae_in_centered @ self.W_enc + self.b_gate
+        pi_gate_act = torch.relu(pi_gate)
+
+        # L1-like penalty scaled by W_dec norms
+        l1_loss = (
+            step_input.current_l1_coefficient
+            * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
+        )
+
+        # Aux reconstruction: reconstruct x purely from gating path
+        via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
+        aux_recon_loss = (
+            (via_gate_reconstruction - step_input.sae_in).pow(2).sum(dim=-1).mean()
+        )
+
+        # Return both losses separately
+        return {"l1_loss": l1_loss, "auxiliary_reconstruction_loss": aux_recon_loss}
+
+    def log_histograms(self) -> dict[str, NDArray[Any]]:
+        """Log histograms of the weights and biases."""
+        b_gate_dist = self.b_gate.detach().float().cpu().numpy()
+        b_mag_dist = self.b_mag.detach().float().cpu().numpy()
+        return {
+            **super().log_histograms(),
+            "weights/b_gate": b_gate_dist,
+            "weights/b_mag": b_mag_dist,
+        }
+
+    @torch.no_grad()
+    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+        """Initialize decoder with constant norm"""
+        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+        self.W_dec.data *= norm
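That is the full new gated_sae.py module. Its core is the gated encoder: a binary gate path and a magnitude path that share W_enc. The following is a minimal, self-contained sketch of that combination with made-up toy dimensions and weights; it mirrors GatedSAE.encode above but is not the sae_lens API.

import torch

# Toy sketch of the gated encoder rule from GatedSAE.encode (illustrative only).
d_in, d_sae = 4, 8
x = torch.randn(2, d_in)
W_enc = torch.randn(d_in, d_sae)
b_gate = torch.zeros(d_sae)
b_mag = torch.zeros(d_sae)
r_mag = torch.zeros(d_sae)

gate = (x @ W_enc + b_gate > 0).float()               # binary gate path
mags = torch.relu(x @ (W_enc * r_mag.exp()) + b_mag)  # magnitude path, weights shared with the gate
feature_acts = gate * mags                            # features are zero wherever the gate is closed
print(feature_acts.shape)  # torch.Size([2, 8])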
sae_lens/saes/jumprelu_sae.py (new file)
@@ -0,0 +1,368 @@
+from typing import Any
+
+import numpy as np
+import torch
+from jaxtyping import Float
+from torch import nn
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+    TrainStepOutput,
+)
+
+
+def rectangle(x: torch.Tensor) -> torch.Tensor:
+    return ((x > -0.5) & (x < 0.5)).to(x)
+
+
+class Step(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        x: torch.Tensor,
+        threshold: torch.Tensor,
+        bandwidth: float,  # noqa: ARG004
+    ) -> torch.Tensor:
+        return (x > threshold).to(x)
+
+    @staticmethod
+    def setup_context(
+        ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
+    ) -> None:
+        x, threshold, bandwidth = inputs
+        del output
+        ctx.save_for_backward(x, threshold)
+        ctx.bandwidth = bandwidth
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any, grad_output: torch.Tensor
+    ) -> tuple[None, torch.Tensor, None]:
+        x, threshold = ctx.saved_tensors
+        bandwidth = ctx.bandwidth
+        threshold_grad = torch.sum(
+            -(1.0 / bandwidth) * rectangle((x - threshold) / bandwidth) * grad_output,
+            dim=0,
+        )
+        return None, threshold_grad, None
+
+
+class JumpReLU(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        x: torch.Tensor,
+        threshold: torch.Tensor,
+        bandwidth: float,  # noqa: ARG004
+    ) -> torch.Tensor:
+        return (x * (x > threshold)).to(x)
+
+    @staticmethod
+    def setup_context(
+        ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
+    ) -> None:
+        x, threshold, bandwidth = inputs
+        del output
+        ctx.save_for_backward(x, threshold)
+        ctx.bandwidth = bandwidth
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any, grad_output: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, None]:
+        x, threshold = ctx.saved_tensors
+        bandwidth = ctx.bandwidth
+        x_grad = (x > threshold) * grad_output  # We don't apply STE to x input
+        threshold_grad = torch.sum(
+            -(threshold / bandwidth)
+            * rectangle((x - threshold) / bandwidth)
+            * grad_output,
+            dim=0,
+        )
+        return x_grad, threshold_grad, None
+
+
+class JumpReLUSAE(SAE):
+    """
+    JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
+    using a JumpReLU activation. For each unit, if its pre-activation is
+    <= threshold, that unit is zeroed out; otherwise, it follows a user-specified
+    activation function (e.g., ReLU, tanh-relu, etc.).
+
+    It implements:
+    - initialize_weights: sets up parameters, including a threshold.
+    - encode: computes the feature activations using JumpReLU.
+    - decode: reconstructs the input from the feature activations.
+
+    The BaseSAE.forward() method automatically calls encode and decode,
+    including any error-term processing if configured.
+    """
+
+    b_enc: nn.Parameter
+    threshold: nn.Parameter
+
+    def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+
+    def initialize_weights(self) -> None:
+        """
+        Initialize encoder and decoder weights, as well as biases.
+        Additionally, include a learnable `threshold` parameter that
+        determines when units "turn on" for the JumpReLU.
+        """
+        # Biases
+        self.b_enc = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+        self.b_dec = nn.Parameter(
+            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+        )
+
+        # Threshold for JumpReLU
+        # You can pick a default initialization (e.g., zeros means unit is off unless hidden_pre > 0)
+        # or see the training version for more advanced init with log_threshold, etc.
+        self.threshold = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+
+        # Encoder and Decoder weights
+        w_enc_data = torch.empty(
+            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
+        )
+        nn.init.kaiming_uniform_(w_enc_data)
+        self.W_enc = nn.Parameter(w_enc_data)
+
+        w_dec_data = torch.empty(
+            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+        )
+        nn.init.kaiming_uniform_(w_dec_data)
+        self.W_dec = nn.Parameter(w_dec_data)
+
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        """
+        Encode the input tensor into the feature space using JumpReLU.
+        The threshold parameter determines which units remain active.
+        """
+        sae_in = self.process_sae_in(x)
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+
+        # 1) Apply the base "activation_fn" from config (e.g., ReLU, tanh-relu).
+        base_acts = self.activation_fn(hidden_pre)
+
+        # 2) Zero out any unit whose (hidden_pre <= threshold).
+        # We cast the boolean mask to the same dtype for safe multiplication.
+        jump_relu_mask = (hidden_pre > self.threshold).to(base_acts.dtype)
+
+        # 3) Multiply the normally activated units by that mask.
+        return self.hook_sae_acts_post(base_acts * jump_relu_mask)
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decode the feature activations back to the input space.
+        Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
+        """
+        scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
+        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+    @torch.no_grad()
+    def fold_W_dec_norm(self):
+        """
+        Override to properly handle threshold adjustment with W_dec norms.
+        When we scale the encoder weights, we need to scale the threshold
+        by the same factor to maintain the same sparsity pattern.
+        """
+        # Save the current threshold before calling parent method
+        current_thresh = self.threshold.clone()
+
+        # Get W_dec norms that will be used for scaling
+        W_dec_norms = self.W_dec.norm(dim=-1)
+
+        # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
+        super().fold_W_dec_norm()
+
+        # Scale the threshold by the same factor as we scaled b_enc
+        # This ensures the same features remain active/inactive after folding
+        self.threshold.data = current_thresh * W_dec_norms
+
+
+class JumpReLUTrainingSAE(TrainingSAE):
+    """
+    JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.
+
+    Similar to the inference-only JumpReLUSAE, but with:
+    - A learnable log-threshold parameter (instead of a raw threshold).
+    - Forward passes that add noise during training, if configured.
+    - A specialized auxiliary loss term for sparsity (L0 or similar).
+
+    Methods of interest include:
+    - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
+    - encode_with_hidden_pre_jumprelu: runs a forward pass for training, optionally adding noise.
+    - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.
+    """
+
+    b_enc: nn.Parameter
+    log_threshold: nn.Parameter
+
+    def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+
+        # We'll store a bandwidth for the training approach, if needed
+        self.bandwidth = cfg.jumprelu_bandwidth
+
+        # In typical JumpReLU training code, we may track a log_threshold:
+        self.log_threshold = nn.Parameter(
+            torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+            * np.log(cfg.jumprelu_init_threshold)
+        )
+
+    def initialize_weights(self) -> None:
+        """
+        Initialize parameters like the base SAE, but also add log_threshold.
+        """
+        # Encoder Bias
+        self.b_enc = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+        # Decoder Bias
+        self.b_dec = nn.Parameter(
+            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+        )
+        # W_enc
+        w_enc_data = torch.nn.init.kaiming_uniform_(
+            torch.empty(
+                self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
+            )
+        )
+        self.W_enc = nn.Parameter(w_enc_data)
+
+        # W_dec
+        w_dec_data = torch.nn.init.kaiming_uniform_(
+            torch.empty(
+                self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+            )
+        )
+        self.W_dec = nn.Parameter(w_dec_data)
+
+        # Optionally apply orthogonal or heuristic init
+        if self.cfg.decoder_orthogonal_init:
+            self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
+        elif self.cfg.decoder_heuristic_init:
+            self.W_dec.data = torch.rand(
+                self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+            )
+            self.initialize_decoder_norm_constant_norm()
+
+        # Optionally transpose
+        if self.cfg.init_encoder_as_decoder_transpose:
+            self.W_enc.data = self.W_dec.data.T.clone().contiguous()
+
+        # Optionally normalize columns of W_dec
+        if self.cfg.normalize_sae_decoder:
+            with torch.no_grad():
+                self.set_decoder_norm_to_unit_norm()
+
+    @property
+    def threshold(self) -> torch.Tensor:
+        """
+        Returns the parameterized threshold > 0 for each unit.
+        threshold = exp(log_threshold).
+        """
+        return torch.exp(self.log_threshold)
+
+    def encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        sae_in = self.process_sae_in(x)
+
+        hidden_pre = sae_in @ self.W_enc + self.b_enc
+
+        if self.training and self.cfg.noise_scale > 0:
+            hidden_pre = (
+                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
+            )
+
+        feature_acts = JumpReLU.apply(hidden_pre, self.threshold, self.bandwidth)
+
+        return feature_acts, hidden_pre  # type: ignore
+
+    @override
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        """Calculate architecture-specific auxiliary loss terms."""
+        l0 = torch.sum(Step.apply(hidden_pre, self.threshold, self.bandwidth), dim=-1)  # type: ignore
+        l0_loss = (step_input.current_l1_coefficient * l0).mean()
+        return {"l0_loss": l0_loss}
+
+    @torch.no_grad()
+    def fold_W_dec_norm(self):
+        """
+        Override to properly handle threshold adjustment with W_dec norms.
+        """
+        # Save the current threshold before we call the parent method
+        current_thresh = self.threshold.clone()
+
+        # Get W_dec norms
+        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+
+        # Call parent implementation to handle W_enc and W_dec adjustment
+        super().fold_W_dec_norm()
+
+        # Fix: Use squeeze() instead of squeeze(-1) to match old behavior
+        self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze())
+
+    def _create_train_step_output(
+        self,
+        sae_in: torch.Tensor,
+        sae_out: torch.Tensor,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        loss: torch.Tensor,
+        losses: dict[str, torch.Tensor],
+    ) -> TrainStepOutput:
+        """
+        Helper to produce a TrainStepOutput from the trainer.
+        The old code expects a method named _create_train_step_output().
+        """
+        return TrainStepOutput(
+            sae_in=sae_in,
+            sae_out=sae_out,
+            feature_acts=feature_acts,
+            hidden_pre=hidden_pre,
+            loss=loss,
+            losses=losses,
+        )
+
+    @torch.no_grad()
+    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+        """Initialize decoder with constant norm"""
+        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+        self.W_dec.data *= norm
+
+    def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
+        """Convert log_threshold to threshold for saving"""
+        if "log_threshold" in state_dict:
+            threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
+            del state_dict["log_threshold"]
+            state_dict["threshold"] = threshold
+
+    def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
+        """Convert threshold to log_threshold for loading"""
+        if "threshold" in state_dict:
+            threshold = state_dict["threshold"]
+            del state_dict["threshold"]
+            state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()