sae-lens 5.10.7__py3-none-any.whl → 6.0.0__py3-none-any.whl
- sae_lens/__init__.py +60 -7
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +16 -14
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +170 -257
- sae_lens/constants.py +21 -0
- sae_lens/evals.py +59 -44
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +53 -5
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +228 -32
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +254 -0
- sae_lens/saes/jumprelu_sae.py +348 -0
- sae_lens/saes/sae.py +1076 -0
- sae_lens/saes/standard_sae.py +178 -0
- sae_lens/saes/topk_sae.py +300 -0
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +103 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +155 -177
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +13 -7
- sae_lens/util.py +47 -0
- {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
- sae_lens-6.0.0.dist-info/RECORD +37 -0
- sae_lens/sae.py +0 -747
- sae_lens/sae_training_runner.py +0 -251
- sae_lens/training/geometric_median.py +0 -101
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.10.7.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
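In 6.0.0 the monolithic `sae_lens/sae.py` and `sae_lens/training/training_sae.py` are removed in favour of per-architecture modules under `sae_lens/saes/` (standard, gated, jumprelu, topk) plus a new `sae_lens/registry.py`. For orientation, a minimal usage sketch against the new layout follows; the `d_in`/`d_sae` field names are assumed from the base `SAEConfig` and the values are hypothetical, not verified against the 6.0.0 API:

    import torch
    from sae_lens.saes.gated_sae import GatedSAE, GatedSAEConfig  # new module path per the file list above

    # Hypothetical config values; field names assumed from the base SAEConfig.
    cfg = GatedSAEConfig(d_in=768, d_sae=8 * 768)
    sae = GatedSAE(cfg)

    feature_acts = sae.encode(torch.randn(4, 768))   # (batch, d_sae) feature activations
    reconstruction = sae.decode(feature_acts)        # (batch, d_in) reconstruction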
sae_lens/saes/gated_sae.py (new file)
@@ -0,0 +1,254 @@
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+from jaxtyping import Float
+from numpy.typing import NDArray
+from torch import nn
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+)
+from sae_lens.util import filter_valid_dataclass_fields
+
+
+@dataclass
+class GatedSAEConfig(SAEConfig):
+    """
+    Configuration class for a GatedSAE.
+    """
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "gated"
+
+
+class GatedSAE(SAE[GatedSAEConfig]):
+    """
+    GatedSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
+    using a gated linear encoder and a standard linear decoder.
+    """
+
+    b_gate: nn.Parameter
+    b_mag: nn.Parameter
+    r_mag: nn.Parameter
+
+    def __init__(self, cfg: GatedSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+        # Ensure b_enc does not exist for the gated architecture
+        self.b_enc = None
+
+    @override
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        _init_weights_gated(self)
+
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        """
+        Encode the input tensor into the feature space using a gated encoder.
+        This must match the original encode_gated implementation from SAE class.
+        """
+        # Preprocess the SAE input (casting type, applying hooks, normalization)
+        sae_in = self.process_sae_in(x)
+
+        # Gating path exactly as in original SAE.encode_gated
+        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
+        active_features = (gating_pre_activation > 0).to(self.dtype)
+
+        # Magnitude path (weight sharing with gated encoder)
+        magnitude_pre_activation = self.hook_sae_acts_pre(
+            sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
+        )
+        feature_magnitudes = self.activation_fn(magnitude_pre_activation)
+
+        # Combine gating and magnitudes
+        return self.hook_sae_acts_post(active_features * feature_magnitudes)
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decode the feature activations back into the input space:
+        1) Apply optional finetuning scaling.
+        2) Linear transform plus bias.
+        3) Run any reconstruction hooks and out-normalization if configured.
+        4) If the SAE was reshaping hook_z activations, reshape back.
+        """
+        # 1) optional finetuning scaling
+        # 2) linear transform
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        # 3) hooking and normalization
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        # 4) reshape if needed (hook_z)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+    @torch.no_grad()
+    def fold_W_dec_norm(self):
+        """Override to handle gated-specific parameters."""
+        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        self.W_dec.data = self.W_dec.data / W_dec_norms
+        self.W_enc.data = self.W_enc.data * W_dec_norms.T
+
+        # Gated-specific parameters need special handling
+        self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
+        self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
+        self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
+
+    @torch.no_grad()
+    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+        """Initialize decoder with constant norm."""
+        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+        self.W_dec.data *= norm
+
+
+@dataclass
+class GatedTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a GatedTrainingSAE.
+    """
+
+    l1_coefficient: float = 1.0
+    l1_warm_up_steps: int = 0
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "gated"
+
+
+class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
+    """
+    GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
+    It implements:
+    - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
+    - encode: calls encode_with_hidden_pre (standard training approach).
+    - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
+    - encode_with_hidden_pre: gating logic + optional noise injection for training.
+    - calculate_aux_loss: includes an auxiliary reconstruction path and gating-based sparsity penalty.
+    - training_forward_pass: calls encode_with_hidden_pre, decode, and sums up MSE + gating losses.
+    """
+
+    b_gate: nn.Parameter  # type: ignore
+    b_mag: nn.Parameter  # type: ignore
+    r_mag: nn.Parameter  # type: ignore
+
+    def __init__(self, cfg: GatedTrainingSAEConfig, use_error_term: bool = False):
+        if use_error_term:
+            raise ValueError(
+                "GatedSAE does not support `use_error_term`. Please set `use_error_term=False`."
+            )
+        super().__init__(cfg, use_error_term)
+
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        _init_weights_gated(self)
+
+    def encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        """
+        Gated forward pass with pre-activation (for training).
+        We also inject noise if self.training is True.
+        """
+        sae_in = self.process_sae_in(x)
+
+        # Gating path
+        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
+        active_features = (gating_pre_activation > 0).to(self.dtype)
+
+        # Magnitude path
+        magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
+        magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)
+
+        feature_magnitudes = self.activation_fn(magnitude_pre_activation)
+
+        # Combine gating path and magnitude path
+        feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes)
+
+        # Return both the final feature activations and the pre-activation (for logging or penalty)
+        return feature_acts, magnitude_pre_activation
+
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        # Re-center the input if apply_b_dec_to_input is set
+        sae_in_centered = step_input.sae_in - (
+            self.b_dec * self.cfg.apply_b_dec_to_input
+        )
+
+        # The gating pre-activation (pi_gate) for the auxiliary path
+        pi_gate = sae_in_centered @ self.W_enc + self.b_gate
+        pi_gate_act = torch.relu(pi_gate)
+
+        # L1-like penalty scaled by W_dec norms
+        l1_loss = (
+            step_input.coefficients["l1"]
+            * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
+        )
+
+        # Aux reconstruction: reconstruct x purely from gating path
+        via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
+        aux_recon_loss = (
+            (via_gate_reconstruction - step_input.sae_in).pow(2).sum(dim=-1).mean()
+        )
+
+        # Return both losses separately
+        return {"l1_loss": l1_loss, "auxiliary_reconstruction_loss": aux_recon_loss}
+
+    def log_histograms(self) -> dict[str, NDArray[Any]]:
+        """Log histograms of the weights and biases."""
+        b_gate_dist = self.b_gate.detach().float().cpu().numpy()
+        b_mag_dist = self.b_mag.detach().float().cpu().numpy()
+        return {
+            **super().log_histograms(),
+            "weights/b_gate": b_gate_dist,
+            "weights/b_mag": b_mag_dist,
+        }
+
+    @torch.no_grad()
+    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+        """Initialize decoder with constant norm"""
+        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+        self.W_dec.data *= norm
+
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+        return {
+            "l1": TrainCoefficientConfig(
+                value=self.cfg.l1_coefficient,
+                warm_up_steps=self.cfg.l1_warm_up_steps,
+            ),
+        }
+
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), GatedSAEConfig, ["architecture"]
+        )
+
+
+def _init_weights_gated(
+    sae: SAE[GatedSAEConfig] | TrainingSAE[GatedTrainingSAEConfig],
+) -> None:
+    sae.b_gate = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
+    # Ensure r_mag is initialized to zero as in original
+    sae.r_mag = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
+    sae.b_mag = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
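The gated encoder above is two affine maps sharing `W_enc`: a gating path that produces a hard 0/1 mask, and a magnitude path rescaled per unit by `exp(r_mag)`. A standalone sketch of the same math on plain tensors (shapes and values are illustrative only, not taken from the package):

    import torch

    d_in, d_sae = 8, 16
    x = torch.randn(2, d_in)
    W_enc, W_dec = torch.randn(d_in, d_sae), torch.randn(d_sae, d_in)
    b_gate, b_mag, r_mag = torch.zeros(d_sae), torch.zeros(d_sae), torch.zeros(d_sae)
    b_dec = torch.zeros(d_in)

    # Gating path: decides which units fire at all.
    active = ((x @ W_enc + b_gate) > 0).float()
    # Magnitude path: shares W_enc, rescaled per unit by exp(r_mag).
    magnitudes = torch.relu(x @ (W_enc * r_mag.exp()) + b_mag)
    feature_acts = active * magnitudes              # what GatedSAE.encode returns
    reconstruction = feature_acts @ W_dec + b_dec   # what GatedSAE.decode returns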
sae_lens/saes/jumprelu_sae.py (new file)
@@ -0,0 +1,348 @@
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+import torch
+from jaxtyping import Float
+from torch import nn
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+    TrainStepOutput,
+)
+from sae_lens.util import filter_valid_dataclass_fields
+
+
+def rectangle(x: torch.Tensor) -> torch.Tensor:
+    return ((x > -0.5) & (x < 0.5)).to(x)
+
+
+class Step(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        x: torch.Tensor,
+        threshold: torch.Tensor,
+        bandwidth: float,  # noqa: ARG004
+    ) -> torch.Tensor:
+        return (x > threshold).to(x)
+
+    @staticmethod
+    def setup_context(
+        ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
+    ) -> None:
+        x, threshold, bandwidth = inputs
+        del output
+        ctx.save_for_backward(x, threshold)
+        ctx.bandwidth = bandwidth
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any, grad_output: torch.Tensor
+    ) -> tuple[None, torch.Tensor, None]:
+        x, threshold = ctx.saved_tensors
+        bandwidth = ctx.bandwidth
+        threshold_grad = torch.sum(
+            -(1.0 / bandwidth) * rectangle((x - threshold) / bandwidth) * grad_output,
+            dim=0,
+        )
+        return None, threshold_grad, None
+
+
+class JumpReLU(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        x: torch.Tensor,
+        threshold: torch.Tensor,
+        bandwidth: float,  # noqa: ARG004
+    ) -> torch.Tensor:
+        return (x * (x > threshold)).to(x)
+
+    @staticmethod
+    def setup_context(
+        ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
+    ) -> None:
+        x, threshold, bandwidth = inputs
+        del output
+        ctx.save_for_backward(x, threshold)
+        ctx.bandwidth = bandwidth
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any, grad_output: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, None]:
+        x, threshold = ctx.saved_tensors
+        bandwidth = ctx.bandwidth
+        x_grad = (x > threshold) * grad_output  # We don't apply STE to x input
+        threshold_grad = torch.sum(
+            -(threshold / bandwidth)
+            * rectangle((x - threshold) / bandwidth)
+            * grad_output,
+            dim=0,
+        )
+        return x_grad, threshold_grad, None
+
+
+@dataclass
+class JumpReLUSAEConfig(SAEConfig):
+    """
+    Configuration class for a JumpReLUSAE.
+    """
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "jumprelu"
+
+
+class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
+    """
+    JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
+    using a JumpReLU activation. For each unit, if its pre-activation is
+    <= threshold, that unit is zeroed out; otherwise, it follows a user-specified
+    activation function (e.g., ReLU etc.).
+
+    It implements:
+    - initialize_weights: sets up parameters, including a threshold.
+    - encode: computes the feature activations using JumpReLU.
+    - decode: reconstructs the input from the feature activations.
+
+    The BaseSAE.forward() method automatically calls encode and decode,
+    including any error-term processing if configured.
+    """
+
+    b_enc: nn.Parameter
+    threshold: nn.Parameter
+
+    def __init__(self, cfg: JumpReLUSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+
+    @override
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        self.threshold = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+        self.b_enc = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        """
+        Encode the input tensor into the feature space using JumpReLU.
+        The threshold parameter determines which units remain active.
+        """
+        sae_in = self.process_sae_in(x)
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+
+        # 1) Apply the base "activation_fn" from config (e.g., ReLU).
+        base_acts = self.activation_fn(hidden_pre)
+
+        # 2) Zero out any unit whose (hidden_pre <= threshold).
+        # We cast the boolean mask to the same dtype for safe multiplication.
+        jump_relu_mask = (hidden_pre > self.threshold).to(base_acts.dtype)
+
+        # 3) Multiply the normally activated units by that mask.
+        return self.hook_sae_acts_post(base_acts * jump_relu_mask)
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decode the feature activations back to the input space.
+        Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
+        """
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+    @torch.no_grad()
+    def fold_W_dec_norm(self):
+        """
+        Override to properly handle threshold adjustment with W_dec norms.
+        When we scale the encoder weights, we need to scale the threshold
+        by the same factor to maintain the same sparsity pattern.
+        """
+        # Save the current threshold before calling parent method
+        current_thresh = self.threshold.clone()
+
+        # Get W_dec norms that will be used for scaling
+        W_dec_norms = self.W_dec.norm(dim=-1)
+
+        # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
+        super().fold_W_dec_norm()
+
+        # Scale the threshold by the same factor as we scaled b_enc
+        # This ensures the same features remain active/inactive after folding
+        self.threshold.data = current_thresh * W_dec_norms
+
+
+@dataclass
+class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a JumpReLUTrainingSAE.
+    """
+
+    jumprelu_init_threshold: float = 0.01
+    jumprelu_bandwidth: float = 0.05
+    l0_coefficient: float = 1.0
+    l0_warm_up_steps: int = 0
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "jumprelu"
+
+
+class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
+    """
+    JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.
+
+    Similar to the inference-only JumpReLUSAE, but with:
+    - A learnable log-threshold parameter (instead of a raw threshold).
+    - Forward passes that add noise during training, if configured.
+    - A specialized auxiliary loss term for sparsity (L0 or similar).
+
+    Methods of interest include:
+    - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
+    - encode_with_hidden_pre_jumprelu: runs a forward pass for training, optionally adding noise.
+    - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.
+    """
+
+    b_enc: nn.Parameter
+    log_threshold: nn.Parameter
+
+    def __init__(self, cfg: JumpReLUTrainingSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+
+        # We'll store a bandwidth for the training approach, if needed
+        self.bandwidth = cfg.jumprelu_bandwidth
+
+        # In typical JumpReLU training code, we may track a log_threshold:
+        self.log_threshold = nn.Parameter(
+            torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+            * np.log(cfg.jumprelu_init_threshold)
+        )
+
+    @override
+    def initialize_weights(self) -> None:
+        """
+        Initialize parameters like the base SAE, but also add log_threshold.
+        """
+        super().initialize_weights()
+        # Encoder Bias
+        self.b_enc = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+
+    @property
+    def threshold(self) -> torch.Tensor:
+        """
+        Returns the parameterized threshold > 0 for each unit.
+        threshold = exp(log_threshold).
+        """
+        return torch.exp(self.log_threshold)
+
+    def encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        sae_in = self.process_sae_in(x)
+
+        hidden_pre = sae_in @ self.W_enc + self.b_enc
+        feature_acts = JumpReLU.apply(hidden_pre, self.threshold, self.bandwidth)
+
+        return feature_acts, hidden_pre  # type: ignore
+
+    @override
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        """Calculate architecture-specific auxiliary loss terms."""
+        l0 = torch.sum(Step.apply(hidden_pre, self.threshold, self.bandwidth), dim=-1)  # type: ignore
+        l0_loss = (step_input.coefficients["l0"] * l0).mean()
+        return {"l0_loss": l0_loss}
+
+    @override
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+        return {
+            "l0": TrainCoefficientConfig(
+                value=self.cfg.l0_coefficient,
+                warm_up_steps=self.cfg.l0_warm_up_steps,
+            ),
+        }
+
+    @torch.no_grad()
+    def fold_W_dec_norm(self):
+        """
+        Override to properly handle threshold adjustment with W_dec norms.
+        """
+        # Save the current threshold before we call the parent method
+        current_thresh = self.threshold.clone()
+
+        # Get W_dec norms
+        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+
+        # Call parent implementation to handle W_enc and W_dec adjustment
+        super().fold_W_dec_norm()
+
+        # Fix: Use squeeze() instead of squeeze(-1) to match old behavior
+        self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze())
+
+    def _create_train_step_output(
+        self,
+        sae_in: torch.Tensor,
+        sae_out: torch.Tensor,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        loss: torch.Tensor,
+        losses: dict[str, torch.Tensor],
+    ) -> TrainStepOutput:
+        """
+        Helper to produce a TrainStepOutput from the trainer.
+        The old code expects a method named _create_train_step_output().
+        """
+        return TrainStepOutput(
+            sae_in=sae_in,
+            sae_out=sae_out,
+            feature_acts=feature_acts,
+            hidden_pre=hidden_pre,
+            loss=loss,
+            losses=losses,
+        )
+
+    @torch.no_grad()
+    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+        """Initialize decoder with constant norm"""
+        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+        self.W_dec.data *= norm
+
+    def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
+        """Convert log_threshold to threshold for saving"""
+        if "log_threshold" in state_dict:
+            threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
+            del state_dict["log_threshold"]
+            state_dict["threshold"] = threshold
+
+    def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
+        """Convert threshold to log_threshold for loading"""
+        if "threshold" in state_dict:
+            threshold = state_dict["threshold"]
+            del state_dict["threshold"]
+            state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
+
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), JumpReLUSAEConfig, ["architecture"]
+        )
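The `Step` and `JumpReLU` autograd functions above implement a straight-through-style estimator: the forward pass applies a hard threshold, while the backward pass routes a pseudo-gradient to `threshold` only where a pre-activation lands within the `bandwidth` window (via the `rectangle` kernel). A small sketch of that behaviour, assuming sae-lens 6.0.0 is installed so the class can be imported as defined above:

    import torch
    from sae_lens.saes.jumprelu_sae import JumpReLU  # defined in the diff above

    x = torch.tensor([[0.49, 1.50, -0.30]], requires_grad=True)
    threshold = torch.full((3,), 0.5, requires_grad=True)
    bandwidth = 0.05

    out = JumpReLU.apply(x, threshold, bandwidth)
    out.sum().backward()

    print(out)             # tensor([[0.0, 1.5, 0.0]]): only values above the threshold pass through
    print(x.grad)          # hard mask: 1 where x > threshold, else 0 (no STE applied to x)
    print(threshold.grad)  # nonzero only for the first unit, where |x - threshold| < bandwidth / 2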