sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl
This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sae_lens/__init__.py +60 -7
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +16 -14
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +170 -258
- sae_lens/constants.py +21 -0
- sae_lens/evals.py +59 -44
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +52 -4
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +254 -0
- sae_lens/saes/jumprelu_sae.py +348 -0
- sae_lens/saes/sae.py +1076 -0
- sae_lens/saes/standard_sae.py +178 -0
- sae_lens/saes/topk_sae.py +300 -0
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +103 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +155 -177
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +13 -7
- sae_lens/util.py +47 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
- sae_lens-6.0.0.dist-info/RECORD +37 -0
- sae_lens/sae.py +0 -747
- sae_lens/sae_training_runner.py +0 -251
- sae_lens/training/geometric_median.py +0 -101
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.11.0.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
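The file list implies a reorganization of the package layout in 6.0.0: the monolithic `sae_lens/sae.py` and `sae_lens/training/training_sae.py` are removed, per-architecture SAE implementations move into a new `sae_lens/saes/` package, and `sae_lens/toolkit/` is renamed to `sae_lens/loading/`. A rough sketch of the corresponding import paths (class and module names are taken from the files shown in this diff; treat the exact public API as illustrative):

```python
# 5.11.0 layout (modules removed in 6.0.0)
# from sae_lens.sae import SAE
# from sae_lens.toolkit import pretrained_sae_loaders

# 6.0.0 layout, per the file list above
from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
from sae_lens.saes.jumprelu_sae import JumpReLUSAE, JumpReLUTrainingSAE
from sae_lens.loading import pretrained_sae_loaders
```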
sae_lens/saes/jumprelu_sae.py (new file)

@@ -0,0 +1,348 @@

```python
from dataclasses import dataclass
from typing import Any

import numpy as np
import torch
from jaxtyping import Float
from torch import nn
from typing_extensions import override

from sae_lens.saes.sae import (
    SAE,
    SAEConfig,
    TrainCoefficientConfig,
    TrainingSAE,
    TrainingSAEConfig,
    TrainStepInput,
    TrainStepOutput,
)
from sae_lens.util import filter_valid_dataclass_fields


def rectangle(x: torch.Tensor) -> torch.Tensor:
    return ((x > -0.5) & (x < 0.5)).to(x)


class Step(torch.autograd.Function):
    @staticmethod
    def forward(
        x: torch.Tensor,
        threshold: torch.Tensor,
        bandwidth: float,  # noqa: ARG004
    ) -> torch.Tensor:
        return (x > threshold).to(x)

    @staticmethod
    def setup_context(
        ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
    ) -> None:
        x, threshold, bandwidth = inputs
        del output
        ctx.save_for_backward(x, threshold)
        ctx.bandwidth = bandwidth

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: Any, grad_output: torch.Tensor
    ) -> tuple[None, torch.Tensor, None]:
        x, threshold = ctx.saved_tensors
        bandwidth = ctx.bandwidth
        threshold_grad = torch.sum(
            -(1.0 / bandwidth) * rectangle((x - threshold) / bandwidth) * grad_output,
            dim=0,
        )
        return None, threshold_grad, None


class JumpReLU(torch.autograd.Function):
    @staticmethod
    def forward(
        x: torch.Tensor,
        threshold: torch.Tensor,
        bandwidth: float,  # noqa: ARG004
    ) -> torch.Tensor:
        return (x * (x > threshold)).to(x)

    @staticmethod
    def setup_context(
        ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
    ) -> None:
        x, threshold, bandwidth = inputs
        del output
        ctx.save_for_backward(x, threshold)
        ctx.bandwidth = bandwidth

    @staticmethod
    def backward(  # type: ignore[override]
        ctx: Any, grad_output: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, None]:
        x, threshold = ctx.saved_tensors
        bandwidth = ctx.bandwidth
        x_grad = (x > threshold) * grad_output  # We don't apply STE to x input
        threshold_grad = torch.sum(
            -(threshold / bandwidth)
            * rectangle((x - threshold) / bandwidth)
            * grad_output,
            dim=0,
        )
        return x_grad, threshold_grad, None
```
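The two `torch.autograd.Function` subclasses above give the JumpReLU its straight-through-style training gradients: the forward pass is a hard step (`Step`) or a hard-gated identity (`JumpReLU`), while the backward pass routes gradient to `threshold` only inside a `bandwidth`-wide rectangle window around the threshold. A minimal sketch of that behaviour, assuming only the module path shown in this hunk:

```python
import torch

from sae_lens.saes.jumprelu_sae import JumpReLU, Step

x = torch.tensor([[0.2, 0.51, 1.5]], requires_grad=True)
threshold = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)
bandwidth = 0.05

# Forward: units at or below the threshold are zeroed, the rest pass through unchanged.
acts = JumpReLU.apply(x, threshold, bandwidth)  # [[0.0, 0.51, 1.5]]
acts.sum().backward()

print(x.grad)          # hard mask (x > threshold): [[0., 1., 1.]]
print(threshold.grad)  # nonzero only where |x - threshold| < bandwidth / 2 (the 0.51 unit here)

# Step gives the pseudo-differentiable active-unit count used for the L0 penalty below.
l0 = Step.apply(x.detach(), threshold, bandwidth).sum(dim=-1)  # [2.]
```

Note that `x` itself receives the exact (non-smoothed) gradient; only the threshold uses the rectangle-window pseudo-derivative, matching the `# We don't apply STE to x input` comment above.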
```python
@dataclass
class JumpReLUSAEConfig(SAEConfig):
    """
    Configuration class for a JumpReLUSAE.
    """

    @override
    @classmethod
    def architecture(cls) -> str:
        return "jumprelu"


class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
    """
    JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
    using a JumpReLU activation. For each unit, if its pre-activation is
    <= threshold, that unit is zeroed out; otherwise, it follows a user-specified
    activation function (e.g., ReLU etc.).

    It implements:
      - initialize_weights: sets up parameters, including a threshold.
      - encode: computes the feature activations using JumpReLU.
      - decode: reconstructs the input from the feature activations.

    The BaseSAE.forward() method automatically calls encode and decode,
    including any error-term processing if configured.
    """

    b_enc: nn.Parameter
    threshold: nn.Parameter

    def __init__(self, cfg: JumpReLUSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

    @override
    def initialize_weights(self) -> None:
        super().initialize_weights()
        self.threshold = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )
        self.b_enc = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Encode the input tensor into the feature space using JumpReLU.
        The threshold parameter determines which units remain active.
        """
        sae_in = self.process_sae_in(x)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

        # 1) Apply the base "activation_fn" from config (e.g., ReLU).
        base_acts = self.activation_fn(hidden_pre)

        # 2) Zero out any unit whose (hidden_pre <= threshold).
        #    We cast the boolean mask to the same dtype for safe multiplication.
        jump_relu_mask = (hidden_pre > self.threshold).to(base_acts.dtype)

        # 3) Multiply the normally activated units by that mask.
        return self.hook_sae_acts_post(base_acts * jump_relu_mask)

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Decode the feature activations back to the input space.
        Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
        """
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """
        Override to properly handle threshold adjustment with W_dec norms.
        When we scale the encoder weights, we need to scale the threshold
        by the same factor to maintain the same sparsity pattern.
        """
        # Save the current threshold before calling parent method
        current_thresh = self.threshold.clone()

        # Get W_dec norms that will be used for scaling
        W_dec_norms = self.W_dec.norm(dim=-1)

        # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
        super().fold_W_dec_norm()

        # Scale the threshold by the same factor as we scaled b_enc
        # This ensures the same features remain active/inactive after folding
        self.threshold.data = current_thresh * W_dec_norms
```
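A short note on the threshold rescaling in `fold_W_dec_norm` above, assuming (as the inline comments state) that the base implementation multiplies each encoder column and `b_enc` entry by the corresponding decoder row norm while normalizing `W_dec`:

$$
z_i \;\mapsto\; \lVert W_{\mathrm{dec}}[i]\rVert \, z_i
\quad\Longrightarrow\quad
\bigl(z_i > \theta_i\bigr)\ \text{is preserved precisely when}\ \theta_i \;\mapsto\; \lVert W_{\mathrm{dec}}[i]\rVert \, \theta_i ,
$$

where $z = x\,W_{\mathrm{enc}} + b_{\mathrm{enc}}$ is the pre-activation. This is exactly the `current_thresh * W_dec_norms` update, so the set of active features (and hence the L0) is unchanged by folding.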
```python
@dataclass
class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
    """
    Configuration class for training a JumpReLUTrainingSAE.
    """

    jumprelu_init_threshold: float = 0.01
    jumprelu_bandwidth: float = 0.05
    l0_coefficient: float = 1.0
    l0_warm_up_steps: int = 0

    @override
    @classmethod
    def architecture(cls) -> str:
        return "jumprelu"


class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
    """
    JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.

    Similar to the inference-only JumpReLUSAE, but with:
      - A learnable log-threshold parameter (instead of a raw threshold).
      - Forward passes that add noise during training, if configured.
      - A specialized auxiliary loss term for sparsity (L0 or similar).

    Methods of interest include:
      - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
      - encode_with_hidden_pre_jumprelu: runs a forward pass for training, optionally adding noise.
      - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.
    """

    b_enc: nn.Parameter
    log_threshold: nn.Parameter

    def __init__(self, cfg: JumpReLUTrainingSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

        # We'll store a bandwidth for the training approach, if needed
        self.bandwidth = cfg.jumprelu_bandwidth

        # In typical JumpReLU training code, we may track a log_threshold:
        self.log_threshold = nn.Parameter(
            torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
            * np.log(cfg.jumprelu_init_threshold)
        )

    @override
    def initialize_weights(self) -> None:
        """
        Initialize parameters like the base SAE, but also add log_threshold.
        """
        super().initialize_weights()
        # Encoder Bias
        self.b_enc = nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

    @property
    def threshold(self) -> torch.Tensor:
        """
        Returns the parameterized threshold > 0 for each unit.
        threshold = exp(log_threshold).
        """
        return torch.exp(self.log_threshold)

    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        sae_in = self.process_sae_in(x)

        hidden_pre = sae_in @ self.W_enc + self.b_enc
        feature_acts = JumpReLU.apply(hidden_pre, self.threshold, self.bandwidth)

        return feature_acts, hidden_pre  # type: ignore

    @override
    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Calculate architecture-specific auxiliary loss terms."""
        l0 = torch.sum(Step.apply(hidden_pre, self.threshold, self.bandwidth), dim=-1)  # type: ignore
        l0_loss = (step_input.coefficients["l0"] * l0).mean()
        return {"l0_loss": l0_loss}

    @override
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
        return {
            "l0": TrainCoefficientConfig(
                value=self.cfg.l0_coefficient,
                warm_up_steps=self.cfg.l0_warm_up_steps,
            ),
        }

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """
        Override to properly handle threshold adjustment with W_dec norms.
        """
        # Save the current threshold before we call the parent method
        current_thresh = self.threshold.clone()

        # Get W_dec norms
        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)

        # Call parent implementation to handle W_enc and W_dec adjustment
        super().fold_W_dec_norm()

        # Fix: Use squeeze() instead of squeeze(-1) to match old behavior
        self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze())

    def _create_train_step_output(
        self,
        sae_in: torch.Tensor,
        sae_out: torch.Tensor,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        loss: torch.Tensor,
        losses: dict[str, torch.Tensor],
    ) -> TrainStepOutput:
        """
        Helper to produce a TrainStepOutput from the trainer.
        The old code expects a method named _create_train_step_output().
        """
        return TrainStepOutput(
            sae_in=sae_in,
            sae_out=sae_out,
            feature_acts=feature_acts,
            hidden_pre=hidden_pre,
            loss=loss,
            losses=losses,
        )

    @torch.no_grad()
    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
        """Initialize decoder with constant norm"""
        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
        self.W_dec.data *= norm

    def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
        """Convert log_threshold to threshold for saving"""
        if "log_threshold" in state_dict:
            threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
            del state_dict["log_threshold"]
            state_dict["threshold"] = threshold

    def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
        """Convert threshold to log_threshold for loading"""
        if "threshold" in state_dict:
            threshold = state_dict["threshold"]
            del state_dict["threshold"]
            state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()

    def to_inference_config_dict(self) -> dict[str, Any]:
        return filter_valid_dataclass_fields(
            self.cfg.to_dict(), JumpReLUSAEConfig, ["architecture"]
        )
```