sae-lens 5.9.1__py3-none-any.whl → 6.0.0rc1__py3-none-any.whl
This diff shows the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- sae_lens/__init__.py +22 -6
- sae_lens/analysis/hooked_sae_transformer.py +2 -2
- sae_lens/config.py +66 -23
- sae_lens/evals.py +6 -5
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +33 -25
- sae_lens/registry.py +34 -0
- sae_lens/sae_training_runner.py +18 -33
- sae_lens/saes/gated_sae.py +247 -0
- sae_lens/saes/jumprelu_sae.py +368 -0
- sae_lens/saes/sae.py +970 -0
- sae_lens/saes/standard_sae.py +167 -0
- sae_lens/saes/topk_sae.py +305 -0
- sae_lens/training/activations_store.py +2 -2
- sae_lens/training/sae_trainer.py +13 -19
- sae_lens/training/upload_saes_to_huggingface.py +1 -1
- {sae_lens-5.9.1.dist-info → sae_lens-6.0.0rc1.dist-info}/METADATA +2 -2
- sae_lens-6.0.0rc1.dist-info/RECORD +32 -0
- sae_lens/sae.py +0 -747
- sae_lens/training/training_sae.py +0 -705
- sae_lens-5.9.1.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.9.1.dist-info → sae_lens-6.0.0rc1.dist-info}/LICENSE +0 -0
- {sae_lens-5.9.1.dist-info → sae_lens-6.0.0rc1.dist-info}/WHEEL +0 -0
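The headline change in 6.0.0rc1 is a reorganisation of the SAE code: the monolithic sae_lens/sae.py and sae_lens/training/training_sae.py are removed in favour of a sae_lens/saes/ package with one module per architecture (standard, gated, jumprelu, topk), and the toolkit loader package is renamed to loading. Below is a minimal sketch of what the file moves imply for import paths; whether sae_lens/__init__.py (+22 -6) re-exports these names at the top level is an assumption to verify against the new package:

# Hypothetical import-path migration implied by the file moves in this diff.
# Old (5.9.1):
#   from sae_lens.sae import SAE, SAEConfig
#   from sae_lens.toolkit.pretrained_sae_loaders import sae_lens_disk_loader
# New (6.0.0rc1), assuming no top-level re-export:
from sae_lens.saes.sae import SAE, SAEConfig
from sae_lens.loading.pretrained_sae_loaders import sae_lens_disk_loader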
sae_lens/training/training_sae.py
@@ -1,705 +0,0 @@
-"""Most of this is just copied over from Arthur's code and slightly simplified:
-https://github.com/ArthurConmy/sae/blob/main/sae/model.py
-"""
-
-from dataclasses import dataclass, fields
-from typing import Any
-
-import einops
-import numpy as np
-import torch
-from jaxtyping import Float
-from torch import nn
-from typing_extensions import deprecated
-
-from sae_lens import logger
-from sae_lens.config import LanguageModelSAERunnerConfig
-from sae_lens.sae import SAE, SAEConfig
-from sae_lens.toolkit.pretrained_sae_loaders import (
-    PretrainedSaeDiskLoader,
-    handle_config_defaulting,
-    sae_lens_disk_loader,
-)
-
-SPARSITY_PATH = "sparsity.safetensors"
-SAE_WEIGHTS_PATH = "sae_weights.safetensors"
-SAE_CFG_PATH = "cfg.json"
-
-
-def rectangle(x: torch.Tensor) -> torch.Tensor:
-    return ((x > -0.5) & (x < 0.5)).to(x)
-
-
-class Step(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        x: torch.Tensor,
-        threshold: torch.Tensor,
-        bandwidth: float,  # noqa: ARG004
-    ) -> torch.Tensor:
-        return (x > threshold).to(x)
-
-    @staticmethod
-    def setup_context(
-        ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
-    ) -> None:
-        x, threshold, bandwidth = inputs
-        del output
-        ctx.save_for_backward(x, threshold)
-        ctx.bandwidth = bandwidth
-
-    @staticmethod
-    def backward(  # type: ignore[override]
-        ctx: Any, grad_output: torch.Tensor
-    ) -> tuple[None, torch.Tensor, None]:
-        x, threshold = ctx.saved_tensors
-        bandwidth = ctx.bandwidth
-        threshold_grad = torch.sum(
-            -(1.0 / bandwidth) * rectangle((x - threshold) / bandwidth) * grad_output,
-            dim=0,
-        )
-        return None, threshold_grad, None
-
-
-class JumpReLU(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        x: torch.Tensor,
-        threshold: torch.Tensor,
-        bandwidth: float,  # noqa: ARG004
-    ) -> torch.Tensor:
-        return (x * (x > threshold)).to(x)
-
-    @staticmethod
-    def setup_context(
-        ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
-    ) -> None:
-        x, threshold, bandwidth = inputs
-        del output
-        ctx.save_for_backward(x, threshold)
-        ctx.bandwidth = bandwidth
-
-    @staticmethod
-    def backward(  # type: ignore[override]
-        ctx: Any, grad_output: torch.Tensor
-    ) -> tuple[torch.Tensor, torch.Tensor, None]:
-        x, threshold = ctx.saved_tensors
-        bandwidth = ctx.bandwidth
-        x_grad = (x > threshold) * grad_output  # We don't apply STE to x input
-        threshold_grad = torch.sum(
-            -(threshold / bandwidth)
-            * rectangle((x - threshold) / bandwidth)
-            * grad_output,
-            dim=0,
-        )
-        return x_grad, threshold_grad, None
-
-
-@dataclass
-class TrainStepOutput:
-    sae_in: torch.Tensor
-    sae_out: torch.Tensor
-    feature_acts: torch.Tensor
-    hidden_pre: torch.Tensor
-    loss: torch.Tensor  # we need to call backwards on this
-    losses: dict[str, float | torch.Tensor]
-
-
-@dataclass(kw_only=True)
-class TrainingSAEConfig(SAEConfig):
-    # Sparsity Loss Calculations
-    l1_coefficient: float
-    lp_norm: float
-    use_ghost_grads: bool
-    normalize_sae_decoder: bool
-    noise_scale: float
-    decoder_orthogonal_init: bool
-    mse_loss_normalization: str | None
-    jumprelu_init_threshold: float
-    jumprelu_bandwidth: float
-    decoder_heuristic_init: bool
-    init_encoder_as_decoder_transpose: bool
-    scale_sparsity_penalty_by_decoder_norm: bool
-
-    @classmethod
-    def from_sae_runner_config(
-        cls, cfg: LanguageModelSAERunnerConfig
-    ) -> "TrainingSAEConfig":
-        return cls(
-            # base config
-            architecture=cfg.architecture,
-            d_in=cfg.d_in,
-            d_sae=cfg.d_sae,  # type: ignore
-            dtype=cfg.dtype,
-            device=cfg.device,
-            model_name=cfg.model_name,
-            hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
-            hook_head_index=cfg.hook_head_index,
-            activation_fn_str=cfg.activation_fn,
-            activation_fn_kwargs=cfg.activation_fn_kwargs,
-            apply_b_dec_to_input=cfg.apply_b_dec_to_input,
-            finetuning_scaling_factor=cfg.finetuning_method is not None,
-            sae_lens_training_version=cfg.sae_lens_training_version,
-            context_size=cfg.context_size,
-            dataset_path=cfg.dataset_path,
-            prepend_bos=cfg.prepend_bos,
-            seqpos_slice=cfg.seqpos_slice,
-            # Training cfg
-            l1_coefficient=cfg.l1_coefficient,
-            lp_norm=cfg.lp_norm,
-            use_ghost_grads=cfg.use_ghost_grads,
-            normalize_sae_decoder=cfg.normalize_sae_decoder,
-            noise_scale=cfg.noise_scale,
-            decoder_orthogonal_init=cfg.decoder_orthogonal_init,
-            mse_loss_normalization=cfg.mse_loss_normalization,
-            decoder_heuristic_init=cfg.decoder_heuristic_init,
-            init_encoder_as_decoder_transpose=cfg.init_encoder_as_decoder_transpose,
-            scale_sparsity_penalty_by_decoder_norm=cfg.scale_sparsity_penalty_by_decoder_norm,
-            normalize_activations=cfg.normalize_activations,
-            dataset_trust_remote_code=cfg.dataset_trust_remote_code,
-            model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
-            jumprelu_init_threshold=cfg.jumprelu_init_threshold,
-            jumprelu_bandwidth=cfg.jumprelu_bandwidth,
-        )
-
-    @classmethod
-    def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig":
-        # remove any keys that are not in the dataclass
-        # since we sometimes enhance the config with the whole LM runner config
-        valid_field_names = {field.name for field in fields(cls)}
-        valid_config_dict = {
-            key: val for key, val in config_dict.items() if key in valid_field_names
-        }
-
-        # ensure seqpos slice is tuple
-        # ensure that seqpos slices is a tuple
-        # Ensure seqpos_slice is a tuple
-        if "seqpos_slice" in valid_config_dict:
-            if isinstance(valid_config_dict["seqpos_slice"], list):
-                valid_config_dict["seqpos_slice"] = tuple(
-                    valid_config_dict["seqpos_slice"]
-                )
-            elif not isinstance(valid_config_dict["seqpos_slice"], tuple):
-                valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],)
-
-        return TrainingSAEConfig(**valid_config_dict)
-
-    def to_dict(self) -> dict[str, Any]:
-        return {
-            **super().to_dict(),
-            "l1_coefficient": self.l1_coefficient,
-            "lp_norm": self.lp_norm,
-            "use_ghost_grads": self.use_ghost_grads,
-            "normalize_sae_decoder": self.normalize_sae_decoder,
-            "noise_scale": self.noise_scale,
-            "decoder_orthogonal_init": self.decoder_orthogonal_init,
-            "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
-            "mse_loss_normalization": self.mse_loss_normalization,
-            "decoder_heuristic_init": self.decoder_heuristic_init,
-            "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
-            "normalize_activations": self.normalize_activations,
-            "jumprelu_init_threshold": self.jumprelu_init_threshold,
-            "jumprelu_bandwidth": self.jumprelu_bandwidth,
-        }
-
-    # this needs to exist so we can initialize the parent sae cfg without the training specific
-    # parameters. Maybe there's a cleaner way to do this
-    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
-        return {
-            "architecture": self.architecture,
-            "d_in": self.d_in,
-            "d_sae": self.d_sae,
-            "activation_fn_str": self.activation_fn_str,
-            "activation_fn_kwargs": self.activation_fn_kwargs,
-            "apply_b_dec_to_input": self.apply_b_dec_to_input,
-            "dtype": self.dtype,
-            "model_name": self.model_name,
-            "hook_name": self.hook_name,
-            "hook_layer": self.hook_layer,
-            "hook_head_index": self.hook_head_index,
-            "device": self.device,
-            "context_size": self.context_size,
-            "prepend_bos": self.prepend_bos,
-            "finetuning_scaling_factor": self.finetuning_scaling_factor,
-            "normalize_activations": self.normalize_activations,
-            "dataset_path": self.dataset_path,
-            "dataset_trust_remote_code": self.dataset_trust_remote_code,
-            "sae_lens_training_version": self.sae_lens_training_version,
-        }
-
-
-class TrainingSAE(SAE):
-    """
-    A SAE used for training. This class provides a `training_forward_pass` method which calculates
-    losses used for training.
-    """
-
-    cfg: TrainingSAEConfig
-    use_error_term: bool
-    dtype: torch.dtype
-    device: torch.device
-
-    def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
-        base_sae_cfg = SAEConfig.from_dict(cfg.get_base_sae_cfg_dict())
-        super().__init__(base_sae_cfg)
-        self.cfg = cfg  # type: ignore
-
-        if cfg.architecture == "standard" or cfg.architecture == "topk":
-            self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre
-        elif cfg.architecture == "gated":
-            self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_gated
-        elif cfg.architecture == "jumprelu":
-            self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_jumprelu
-            self.bandwidth = cfg.jumprelu_bandwidth
-            self.log_threshold.data = torch.ones(
-                self.cfg.d_sae, dtype=self.dtype, device=self.device
-            ) * np.log(cfg.jumprelu_init_threshold)
-
-        else:
-            raise ValueError(f"Unknown architecture: {cfg.architecture}")
-
-        self.check_cfg_compatibility()
-
-        self.use_error_term = use_error_term
-
-        self.initialize_weights_complex()
-
-        # The training SAE will assume that the activation store handles
-        # reshaping.
-        self.turn_off_forward_pass_hook_z_reshaping()
-
-        self.mse_loss_fn = self._get_mse_loss_fn()
-
-    def initialize_weights_jumprelu(self):
-        # same as the superclass, except we use a log_threshold parameter instead of threshold
-        self.log_threshold = nn.Parameter(
-            torch.empty(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-        self.initialize_weights_basic()
-
-    @property
-    def threshold(self) -> torch.Tensor:
-        if self.cfg.architecture != "jumprelu":
-            raise ValueError("Threshold is only defined for Jumprelu SAEs")
-        return torch.exp(self.log_threshold)
-
-    @classmethod
-    def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAE":
-        return cls(TrainingSAEConfig.from_dict(config_dict))
-
-    def check_cfg_compatibility(self):
-        if self.cfg.architecture != "standard" and self.cfg.use_ghost_grads:
-            raise ValueError(f"{self.cfg.architecture} SAEs do not support ghost grads")
-        if self.cfg.architecture == "gated" and self.use_error_term:
-            raise ValueError("Gated SAEs do not support error terms")
-
-    def encode_standard(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
-        """
-        Calcuate SAE features from inputs
-        """
-        feature_acts, _ = self.encode_with_hidden_pre_fn(x)
-        return feature_acts
-
-    def encode_with_hidden_pre_jumprelu(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
-        sae_in = self.process_sae_in(x)
-
-        hidden_pre = sae_in @ self.W_enc + self.b_enc
-
-        if self.training:
-            hidden_pre = (
-                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
-            )
-
-        threshold = torch.exp(self.log_threshold)
-
-        feature_acts = JumpReLU.apply(hidden_pre, threshold, self.bandwidth)
-
-        return feature_acts, hidden_pre  # type: ignore
-
-    def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
-        sae_in = self.process_sae_in(x)
-
-        # "... d_in, d_in d_sae -> ... d_sae",
-        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
-        hidden_pre_noised = hidden_pre + (
-            torch.randn_like(hidden_pre) * self.cfg.noise_scale * self.training
-        )
-        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
-
-        return feature_acts, hidden_pre_noised
-
-    def encode_with_hidden_pre_gated(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
-        sae_in = self.process_sae_in(x)
-
-        # Gating path with Heaviside step function
-        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
-        active_features = (gating_pre_activation > 0).to(self.dtype)
-
-        # Magnitude path with weight sharing
-        magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
-        # magnitude_pre_activation_noised = magnitude_pre_activation + (
-        #     torch.randn_like(magnitude_pre_activation) * self.cfg.noise_scale * self.training
-        # )
-        feature_magnitudes = self.activation_fn(
-            magnitude_pre_activation
-        )  # magnitude_pre_activation_noised)
-
-        # Return both the gated feature activations and the magnitude pre-activations
-        return (
-            active_features * feature_magnitudes,
-            magnitude_pre_activation,
-        )  # magnitude_pre_activation_noised
-
-    def forward(
-        self,
-        x: Float[torch.Tensor, "... d_in"],
-    ) -> Float[torch.Tensor, "... d_in"]:
-        feature_acts, _ = self.encode_with_hidden_pre_fn(x)
-        return self.decode(feature_acts)
-
-    def training_forward_pass(
-        self,
-        sae_in: torch.Tensor,
-        current_l1_coefficient: float,
-        dead_neuron_mask: torch.Tensor | None = None,
-    ) -> TrainStepOutput:
-        # do a forward pass to get SAE out, but we also need the
-        # hidden pre.
-        feature_acts, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
-        sae_out = self.decode(feature_acts)
-
-        # MSE LOSS
-        per_item_mse_loss = self.mse_loss_fn(sae_out, sae_in)
-        mse_loss = per_item_mse_loss.sum(dim=-1).mean()
-
-        losses: dict[str, float | torch.Tensor] = {}
-
-        if self.cfg.architecture == "gated":
-            # Gated SAE Loss Calculation
-
-            # Shared variables
-            sae_in_centered = (
-                self.reshape_fn_in(sae_in) - self.b_dec * self.cfg.apply_b_dec_to_input
-            )
-            pi_gate = sae_in_centered @ self.W_enc + self.b_gate
-            pi_gate_act = torch.relu(pi_gate)
-
-            # SFN sparsity loss - summed over the feature dimension and averaged over the batch
-            l1_loss = (
-                current_l1_coefficient
-                * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
-            )
-
-            # Auxiliary reconstruction loss - summed over the feature dimension and averaged over the batch
-            via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
-            aux_reconstruction_loss = torch.sum(
-                (via_gate_reconstruction - sae_in) ** 2, dim=-1
-            ).mean()
-            loss = mse_loss + l1_loss + aux_reconstruction_loss
-            losses["auxiliary_reconstruction_loss"] = aux_reconstruction_loss
-            losses["l1_loss"] = l1_loss
-        elif self.cfg.architecture == "jumprelu":
-            threshold = torch.exp(self.log_threshold)
-            l0 = torch.sum(Step.apply(hidden_pre, threshold, self.bandwidth), dim=-1)  # type: ignore
-            l0_loss = (current_l1_coefficient * l0).mean()
-            loss = mse_loss + l0_loss
-            losses["l0_loss"] = l0_loss
-        elif self.cfg.architecture == "topk":
-            topk_loss = self.calculate_topk_aux_loss(
-                sae_in=sae_in,
-                sae_out=sae_out,
-                hidden_pre=hidden_pre,
-                dead_neuron_mask=dead_neuron_mask,
-            )
-            losses["auxiliary_reconstruction_loss"] = topk_loss
-            loss = mse_loss + topk_loss
-        else:
-            # default SAE sparsity loss
-            weighted_feature_acts = feature_acts
-            if self.cfg.scale_sparsity_penalty_by_decoder_norm:
-                weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
-            sparsity = weighted_feature_acts.norm(
-                p=self.cfg.lp_norm, dim=-1
-            )  # sum over the feature dimension
-
-            l1_loss = (current_l1_coefficient * sparsity).mean()
-            loss = mse_loss + l1_loss
-            if (
-                self.cfg.use_ghost_grads
-                and self.training
-                and dead_neuron_mask is not None
-            ):
-                ghost_grad_loss = self.calculate_ghost_grad_loss(
-                    x=sae_in,
-                    sae_out=sae_out,
-                    per_item_mse_loss=per_item_mse_loss,
-                    hidden_pre=hidden_pre,
-                    dead_neuron_mask=dead_neuron_mask,
-                )
-                losses["ghost_grad_loss"] = ghost_grad_loss
-                loss = loss + ghost_grad_loss
-            losses["l1_loss"] = l1_loss
-
-        losses["mse_loss"] = mse_loss
-
-        return TrainStepOutput(
-            sae_in=sae_in,
-            sae_out=sae_out,
-            feature_acts=feature_acts,
-            hidden_pre=hidden_pre,
-            loss=loss,
-            losses=losses,
-        )
-
-    def calculate_topk_aux_loss(
-        self,
-        sae_in: torch.Tensor,
-        sae_out: torch.Tensor,
-        hidden_pre: torch.Tensor,
-        dead_neuron_mask: torch.Tensor | None,
-    ) -> torch.Tensor:
-        # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
-        # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
-        if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
-            return sae_out.new_tensor(0.0)
-        residual = (sae_in - sae_out).detach()
-
-        # Heuristic from Appendix B.1 in the paper
-        k_aux = sae_in.shape[-1] // 2
-
-        # Reduce the scale of the loss if there are a small number of dead latents
-        scale = min(num_dead / k_aux, 1.0)
-        k_aux = min(k_aux, num_dead)
-
-        auxk_acts = _calculate_topk_aux_acts(
-            k_aux=k_aux,
-            hidden_pre=hidden_pre,
-            dead_neuron_mask=dead_neuron_mask,
-        )
-
-        # Encourage the top ~50% of dead latents to predict the residual of the
-        # top k living latents
-        recons = self.decode(auxk_acts)
-        auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
-        return scale * auxk_loss
-
-    def calculate_ghost_grad_loss(
-        self,
-        x: torch.Tensor,
-        sae_out: torch.Tensor,
-        per_item_mse_loss: torch.Tensor,
-        hidden_pre: torch.Tensor,
-        dead_neuron_mask: torch.Tensor,
-    ) -> torch.Tensor:
-        # 1.
-        residual = x - sae_out
-        l2_norm_residual = torch.norm(residual, dim=-1)
-
-        # 2.
-        # ghost grads use an exponentional activation function, ignoring whatever
-        # the activation function is in the SAE. The forward pass uses the dead neurons only.
-        feature_acts_dead_neurons_only = torch.exp(hidden_pre[:, dead_neuron_mask])
-        ghost_out = feature_acts_dead_neurons_only @ self.W_dec[dead_neuron_mask, :]
-        l2_norm_ghost_out = torch.norm(ghost_out, dim=-1)
-        norm_scaling_factor = l2_norm_residual / (1e-6 + l2_norm_ghost_out * 2)
-        ghost_out = ghost_out * norm_scaling_factor[:, None].detach()
-
-        # 3. There is some fairly complex rescaling here to make sure that the loss
-        # is comparable to the original loss. This is because the ghost grads are
-        # only calculated for the dead neurons, so we need to rescale the loss to
-        # make sure that the loss is comparable to the original loss.
-        # There have been methodological improvements that are not implemented here yet
-        # see here: https://www.lesswrong.com/posts/C5KAZQib3bzzpeyrg/full-post-progress-update-1-from-the-gdm-mech-interp-team#Improving_ghost_grads
-        per_item_mse_loss_ghost_resid = self.mse_loss_fn(ghost_out, residual.detach())
-        mse_rescaling_factor = (
-            per_item_mse_loss / (per_item_mse_loss_ghost_resid + 1e-6)
-        ).detach()
-        per_item_mse_loss_ghost_resid = (
-            mse_rescaling_factor * per_item_mse_loss_ghost_resid
-        )
-
-        return per_item_mse_loss_ghost_resid.mean()
-
-    @torch.no_grad()
-    def _get_mse_loss_fn(self) -> Any:
-        def standard_mse_loss_fn(
-            preds: torch.Tensor, target: torch.Tensor
-        ) -> torch.Tensor:
-            return torch.nn.functional.mse_loss(preds, target, reduction="none")
-
-        def batch_norm_mse_loss_fn(
-            preds: torch.Tensor, target: torch.Tensor
-        ) -> torch.Tensor:
-            target_centered = target - target.mean(dim=0, keepdim=True)
-            normalization = target_centered.norm(dim=-1, keepdim=True)
-            return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
-                normalization + 1e-6
-            )
-
-        if self.cfg.mse_loss_normalization == "dense_batch":
-            return batch_norm_mse_loss_fn
-        return standard_mse_loss_fn
-
-    def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
-        if self.cfg.architecture == "jumprelu" and "log_threshold" in state_dict:
-            threshold = torch.exp(state_dict["log_threshold"]).detach().contiguous()
-            del state_dict["log_threshold"]
-            state_dict["threshold"] = threshold
-
-    def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
-        if self.cfg.architecture == "jumprelu" and "threshold" in state_dict:
-            threshold = state_dict["threshold"]
-            del state_dict["threshold"]
-            state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
-
-    @classmethod
-    @deprecated("Use load_from_disk instead")
-    def load_from_pretrained(
-        cls, path: str, device: str = "cpu", dtype: str | None = None
-    ) -> "TrainingSAE":
-        return cls.load_from_disk(path, device, dtype)
-
-    @classmethod
-    def load_from_disk(
-        cls,
-        path: str,
-        device: str = "cpu",
-        dtype: str | None = None,
-        converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
-    ) -> "TrainingSAE":
-        overrides = {"dtype": dtype} if dtype is not None else None
-        cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
-        cfg_dict = handle_config_defaulting(cfg_dict)
-        sae_cfg = TrainingSAEConfig.from_dict(cfg_dict)
-        sae = cls(sae_cfg)
-        sae.process_state_dict_for_loading(state_dict)
-        sae.load_state_dict(state_dict)
-        return sae
-
-    def initialize_weights_complex(self):
-        """ """
-
-        if self.cfg.decoder_orthogonal_init:
-            self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
-
-        elif self.cfg.decoder_heuristic_init:
-            self.W_dec = nn.Parameter(
-                torch.rand(
-                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-                )
-            )
-            self.initialize_decoder_norm_constant_norm()
-
-        # Then we initialize the encoder weights (either as the transpose of decoder or not)
-        if self.cfg.init_encoder_as_decoder_transpose:
-            self.W_enc.data = self.W_dec.data.T.clone().contiguous()
-        else:
-            self.W_enc = nn.Parameter(
-                torch.nn.init.kaiming_uniform_(
-                    torch.empty(
-                        self.cfg.d_in,
-                        self.cfg.d_sae,
-                        dtype=self.dtype,
-                        device=self.device,
-                    )
-                )
-            )
-
-        if self.cfg.normalize_sae_decoder:
-            with torch.no_grad():
-                # Anthropic normalize this to have unit columns
-                self.set_decoder_norm_to_unit_norm()
-
-    @torch.no_grad()
-    def fold_W_dec_norm(self):
-        # need to deal with the jumprelu having a log_threshold in training
-        if self.cfg.architecture == "jumprelu":
-            cur_threshold = self.threshold.clone()
-            W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
-            super().fold_W_dec_norm()
-            self.log_threshold.data = torch.log(cur_threshold * W_dec_norms.squeeze())
-        else:
-            super().fold_W_dec_norm()
-
-    ## Initialization Methods
-    @torch.no_grad()
-    def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
-        out = torch.tensor(origin, dtype=self.dtype, device=self.device)
-        self.b_dec.data = out
-
-    @torch.no_grad()
-    def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
-        previous_b_dec = self.b_dec.clone().cpu()
-        out = all_activations.mean(dim=0)
-
-        previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
-        distances = torch.norm(all_activations - out, dim=-1)
-
-        logger.info("Reinitializing b_dec with mean of activations")
-        logger.debug(
-            f"Previous distances: {previous_distances.median(0).values.mean().item()}"
-        )
-        logger.debug(f"New distances: {distances.median(0).values.mean().item()}")
-
-        self.b_dec.data = out.to(self.dtype).to(self.device)
-
-    ## Training Utils
-    @torch.no_grad()
-    def set_decoder_norm_to_unit_norm(self):
-        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-
-    @torch.no_grad()
-    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
-        """
-        A heuristic proceedure inspired by:
-        https://transformer-circuits.pub/2024/april-update/index.html#training-saes
-        """
-        # TODO: Parameterise this as a function of m and n
-
-        # ensure W_dec norms at unit norm
-        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-        self.W_dec.data *= norm  # will break tests but do this for now.
-
-    @torch.no_grad()
-    def remove_gradient_parallel_to_decoder_directions(self):
-        """
-        Update grads so that they remove the parallel component
-        (d_sae, d_in) shape
-        """
-        assert self.W_dec.grad is not None  # keep pyright happy
-
-        parallel_component = einops.einsum(
-            self.W_dec.grad,
-            self.W_dec.data,
-            "d_sae d_in, d_sae d_in -> d_sae",
-        )
-        self.W_dec.grad -= einops.einsum(
-            parallel_component,
-            self.W_dec.data,
-            "d_sae, d_sae d_in -> d_sae d_in",
-        )
-
-
-def _calculate_topk_aux_acts(
-    k_aux: int,
-    hidden_pre: torch.Tensor,
-    dead_neuron_mask: torch.Tensor,
-) -> torch.Tensor:
-    # Don't include living latents in this loss
-    auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
-    # Top-k dead latents
-    auxk_topk = auxk_latents.topk(k_aux, sorted=False)
-    # Set the activations to zero for all but the top k_aux dead latents
-    auxk_acts = torch.zeros_like(hidden_pre)
-    auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
-    # Set activations to zero for all but top k_aux dead latents
-    return auxk_acts