sae-lens 5.9.0__py3-none-any.whl → 6.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +22 -6
- sae_lens/analysis/hooked_sae_transformer.py +2 -2
- sae_lens/config.py +66 -23
- sae_lens/evals.py +6 -5
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +33 -25
- sae_lens/regsitry.py +34 -0
- sae_lens/sae_training_runner.py +18 -33
- sae_lens/saes/gated_sae.py +247 -0
- sae_lens/saes/jumprelu_sae.py +368 -0
- sae_lens/saes/sae.py +970 -0
- sae_lens/saes/standard_sae.py +167 -0
- sae_lens/saes/topk_sae.py +305 -0
- sae_lens/training/activations_store.py +2 -2
- sae_lens/training/sae_trainer.py +13 -19
- sae_lens/training/upload_saes_to_huggingface.py +1 -1
- {sae_lens-5.9.0.dist-info → sae_lens-6.0.0rc1.dist-info}/METADATA +3 -3
- sae_lens-6.0.0rc1.dist-info/RECORD +32 -0
- sae_lens/sae.py +0 -747
- sae_lens/training/training_sae.py +0 -705
- sae_lens-5.9.0.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.9.0.dist-info → sae_lens-6.0.0rc1.dist-info}/LICENSE +0 -0
- {sae_lens-5.9.0.dist-info → sae_lens-6.0.0rc1.dist-info}/WHEEL +0 -0
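
The renames above change downstream import paths: `sae_lens/sae.py` and `sae_lens/training/training_sae.py` are replaced by the new `sae_lens/saes/` package, and `sae_lens/toolkit/` becomes `sae_lens/loading/`. A hedged sketch of the before/after imports, based only on the import lines visible in this diff (whether the top-level `sae_lens` package still re-exports these names is not shown here):

# sae-lens 5.9.0 import paths (removed in this release)
# from sae_lens.sae import SAE
# from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput

# sae-lens 6.0.0rc1 import path, as used by the updated modules in this diff
from sae_lens.saes.sae import SAE, TrainingSAE, TrainStepInput, TrainStepOutput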
sae_lens/saes/standard_sae.py
ADDED
@@ -0,0 +1,167 @@
+import numpy as np
+import torch
+from jaxtyping import Float
+from numpy.typing import NDArray
+from torch import nn
+
+from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainStepInput
+
+
+class StandardSAE(SAE):
+    """
+    StandardSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
+    using a simple linear encoder and decoder.
+
+    It implements the required abstract methods from BaseSAE:
+    - initialize_weights: sets up simple parameter initializations for W_enc, b_enc, W_dec, and b_dec.
+    - encode: computes the feature activations from an input.
+    - decode: reconstructs the input from the feature activations.
+
+    The BaseSAE.forward() method automatically calls encode and decode,
+    including any error-term processing if configured.
+    """
+
+    b_enc: nn.Parameter
+
+    def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+
+    def initialize_weights(self) -> None:
+        # Initialize encoder weights and bias.
+        self.b_enc = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+        self.b_dec = nn.Parameter(
+            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+        )
+
+        # Use Kaiming Uniform for W_enc
+        w_enc_data = torch.empty(
+            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
+        )
+        nn.init.kaiming_uniform_(w_enc_data)
+        self.W_enc = nn.Parameter(w_enc_data)
+
+        # Use Kaiming Uniform for W_dec
+        w_dec_data = torch.empty(
+            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+        )
+        nn.init.kaiming_uniform_(w_dec_data)
+        self.W_dec = nn.Parameter(w_dec_data)
+
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        """
+        Encode the input tensor into the feature space.
+        For inference, no noise is added.
+        """
+        # Preprocess the SAE input (casting type, applying hooks, normalization)
+        sae_in = self.process_sae_in(x)
+        # Compute the pre-activation values
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+        # Apply the activation function (e.g., ReLU, tanh-relu, depending on config)
+        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decode the feature activations back to the input space.
+        Now, if hook_z reshaping is turned on, we reverse the flattening.
+        """
+        # 1) apply finetuning scaling if configured.
+        scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
+        # 2) linear transform
+        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+        # 3) hook reconstruction
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        # 4) optional out-normalization (e.g. constant_norm_rescale or layer_norm)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        # 5) if hook_z is enabled, rearrange back to (..., n_heads, d_head).
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+
+class StandardTrainingSAE(TrainingSAE):
+    """
+    StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture.
+    It implements:
+    - initialize_weights: basic weight initialization for encoder/decoder.
+    - encode: inference encoding (invokes encode_with_hidden_pre).
+    - decode: a simple linear decoder.
+    - encode_with_hidden_pre: computes pre-activations, adds noise when training, and then activates.
+    - calculate_aux_loss: computes a sparsity penalty based on the (optionally scaled) p-norm of feature activations.
+    """
+
+    b_enc: nn.Parameter
+
+    def initialize_weights(self) -> None:
+        # Basic init
+        # In Python MRO, this calls StandardSAE.initialize_weights()
+        StandardSAE.initialize_weights(self)  # type: ignore
+
+        # Complex init logic from original TrainingSAE
+        if self.cfg.decoder_orthogonal_init:
+            self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
+
+        elif self.cfg.decoder_heuristic_init:
+            self.W_dec.data = torch.rand(  # Changed from Parameter to data assignment
+                self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+            )
+            self.initialize_decoder_norm_constant_norm()
+
+        if self.cfg.init_encoder_as_decoder_transpose:
+            self.W_enc.data = self.W_dec.data.T.clone().contiguous()  # type: ignore
+
+        if self.cfg.normalize_sae_decoder:
+            with torch.no_grad():
+                self.set_decoder_norm_to_unit_norm()
+
+    @torch.no_grad()
+    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)  # type: ignore
+        self.W_dec.data *= norm
+
+    def encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        # Process the input (including dtype conversion, hook call, and any activation normalization)
+        sae_in = self.process_sae_in(x)
+        # Compute the pre-activation (and allow for a hook if desired)
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)  # type: ignore
+        # Add noise during training for robustness (scaled by noise_scale from the configuration)
+        if self.training and self.cfg.noise_scale > 0:
+            hidden_pre_noised = (
+                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
+            )
+        else:
+            hidden_pre_noised = hidden_pre
+        # Apply the activation function (and any post-activation hook)
+        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
+        return feature_acts, hidden_pre_noised
+
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        # The "standard" auxiliary loss is a sparsity penalty on the feature activations
+        weighted_feature_acts = feature_acts
+        if self.cfg.scale_sparsity_penalty_by_decoder_norm:
+            weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
+
+        # Compute the p-norm (set by cfg.lp_norm) over the feature dimension
+        sparsity = weighted_feature_acts.norm(p=self.cfg.lp_norm, dim=-1)
+        l1_loss = (step_input.current_l1_coefficient * sparsity).mean()
+
+        return {"l1_loss": l1_loss}
+
+    def log_histograms(self) -> dict[str, NDArray[np.generic]]:
+        """Log histograms of the weights and biases."""
+        b_e_dist = self.b_enc.detach().float().cpu().numpy()
+        return {
+            **super().log_histograms(),
+            "weights/b_e": b_e_dist,
+        }
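
As a rough standalone illustration of what `StandardSAE.encode`/`decode` and the `StandardTrainingSAE` sparsity penalty compute (ignoring hooks, activation normalization, and hook_z reshaping, and with made-up dimensions that are not taken from any real config), a minimal sketch in plain PyTorch:

import torch

d_in, d_sae = 16, 64  # illustrative sizes only
x = torch.randn(8, d_in)

# Parameters shaped as in initialize_weights(): W_enc is (d_in, d_sae), W_dec is (d_sae, d_in)
W_enc = torch.nn.init.kaiming_uniform_(torch.empty(d_in, d_sae))
b_enc = torch.zeros(d_sae)
W_dec = torch.nn.init.kaiming_uniform_(torch.empty(d_sae, d_in))
b_dec = torch.zeros(d_in)

# encode: activation_fn(x @ W_enc + b_enc); ReLU stands in for the configured activation
feature_acts = torch.relu(x @ W_enc + b_enc)
# decode: feature_acts @ W_dec + b_dec
sae_out = feature_acts @ W_dec + b_dec

# "standard" sparsity penalty, here taking the scale_sparsity_penalty_by_decoder_norm branch
l1_coefficient = 1e-3  # stand-in for step_input.current_l1_coefficient
weighted = feature_acts * W_dec.norm(dim=1)
l1_loss = (l1_coefficient * weighted.norm(p=1, dim=-1)).mean()
# the reconstruction (MSE) term is computed elsewhere in the training step, not in calculate_aux_loss
mse_loss = (sae_out - x).pow(2).mean()
print(l1_loss.item(), mse_loss.item())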
sae_lens/saes/topk_sae.py
ADDED
@@ -0,0 +1,305 @@
+"""Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
+
+from typing import Callable
+
+import torch
+from jaxtyping import Float
+from torch import nn
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+)
+
+
+class TopK(nn.Module):
+    """
+    A simple TopK activation that zeroes out all but the top K elements along the last dimension,
+    then optionally applies a post-activation function (e.g., ReLU).
+    """
+
+    b_enc: nn.Parameter
+
+    def __init__(
+        self,
+        k: int,
+        postact_fn: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
+    ):
+        super().__init__()
+        self.k = k
+        self.postact_fn = postact_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        1) Select top K elements along the last dimension.
+        2) Apply post-activation (often ReLU).
+        3) Zero out all other entries.
+        """
+        topk = torch.topk(x, k=self.k, dim=-1)
+        values = self.postact_fn(topk.values)
+        result = torch.zeros_like(x)
+        result.scatter_(-1, topk.indices, values)
+        return result
+
+
+class TopKSAE(SAE):
+    """
+    An inference-only sparse autoencoder using a "topk" activation function.
+    It uses linear encoder and decoder layers, applying the TopK activation
+    to the hidden pre-activation in its encode step.
+    """
+
+    def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+        """
+        Args:
+            cfg: SAEConfig defining model size and behavior.
+            use_error_term: Whether to apply the error-term approach in the forward pass.
+        """
+        super().__init__(cfg, use_error_term)
+
+        if self.cfg.activation_fn != "topk":
+            raise ValueError("TopKSAE must use a TopK activation function.")
+
+    def initialize_weights(self) -> None:
+        """
+        Initializes weights and biases for encoder/decoder similarly to the standard SAE,
+        that is:
+        - b_enc, b_dec are zero-initialized
+        - W_enc, W_dec are Kaiming Uniform
+        """
+        # encoder bias
+        self.b_enc = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+        # decoder bias
+        self.b_dec = nn.Parameter(
+            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+        )
+
+        # encoder weight
+        w_enc_data = torch.empty(
+            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
+        )
+        nn.init.kaiming_uniform_(w_enc_data)
+        self.W_enc = nn.Parameter(w_enc_data)
+
+        # decoder weight
+        w_dec_data = torch.empty(
+            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+        )
+        nn.init.kaiming_uniform_(w_dec_data)
+        self.W_dec = nn.Parameter(w_dec_data)
+
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        """
+        Converts input x into feature activations.
+        Uses topk activation from the config (cfg.activation_fn == "topk")
+        under the hood.
+        """
+        sae_in = self.process_sae_in(x)
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+        # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
+        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Reconstructs the input from topk feature activations.
+        Applies optional finetuning scaling, hooking to recons, out normalization,
+        and optional head reshaping.
+        """
+        scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
+        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+    def _get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        if self.cfg.activation_fn == "topk":
+            if "k" not in self.cfg.activation_fn_kwargs:
+                raise ValueError("TopK activation function requires a k value.")
+            k = self.cfg.activation_fn_kwargs.get(
+                "k", 1
+            )  # Default k to 1 if not provided
+            postact_fn = self.cfg.activation_fn_kwargs.get(
+                "postact_fn", nn.ReLU()
+            )  # Default post-activation to ReLU if not provided
+            return TopK(k, postact_fn)
+        # Otherwise, return the "standard" handling from BaseSAE
+        return super()._get_activation_fn()
+
+
+class TopKTrainingSAE(TrainingSAE):
+    """
+    TopK variant with training functionality. Injects noise during training, optionally
+    calculates a topk-related auxiliary loss, etc.
+    """
+
+    b_enc: nn.Parameter
+
+    def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+
+        if self.cfg.activation_fn != "topk":
+            raise ValueError("TopKSAE must use a TopK activation function.")
+
+    def initialize_weights(self) -> None:
+        """Very similar to TopKSAE, using zero biases + Kaiming Uniform weights."""
+        self.b_enc = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+        self.b_dec = nn.Parameter(
+            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+        )
+
+        w_enc_data = torch.empty(
+            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
+        )
+        nn.init.kaiming_uniform_(w_enc_data)
+        self.W_enc = nn.Parameter(w_enc_data)
+
+        w_dec_data = torch.empty(
+            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+        )
+        nn.init.kaiming_uniform_(w_dec_data)
+        self.W_dec = nn.Parameter(w_dec_data)
+
+    def encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        """
+        Similar to the base training method: cast input, optionally add noise, then apply TopK.
+        """
+        sae_in = self.process_sae_in(x)
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+
+        # Inject noise if training
+        if self.training and self.cfg.noise_scale > 0:
+            hidden_pre_noised = (
+                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
+            )
+        else:
+            hidden_pre_noised = hidden_pre
+
+        # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
+        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
+        return feature_acts, hidden_pre_noised
+
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        # Calculate the auxiliary loss for dead neurons
+        topk_loss = self.calculate_topk_aux_loss(
+            sae_in=step_input.sae_in,
+            sae_out=sae_out,
+            hidden_pre=hidden_pre,
+            dead_neuron_mask=step_input.dead_neuron_mask,
+        )
+        return {"auxiliary_reconstruction_loss": topk_loss}
+
+    def _get_activation_fn(self):
+        if self.cfg.activation_fn == "topk":
+            if "k" not in self.cfg.activation_fn_kwargs:
+                raise ValueError("TopK activation function requires a k value.")
+            k = self.cfg.activation_fn_kwargs.get("k", 1)
+            postact_fn = self.cfg.activation_fn_kwargs.get("postact_fn", nn.ReLU())
+            return TopK(k, postact_fn)
+        return super()._get_activation_fn()
+
+    def calculate_topk_aux_loss(
+        self,
+        sae_in: torch.Tensor,
+        sae_out: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        dead_neuron_mask: torch.Tensor | None,
+    ) -> torch.Tensor:
+        """
+        Calculate TopK auxiliary loss.
+
+        This auxiliary loss encourages dead neurons to learn useful features by having
+        them reconstruct the residual error from the live neurons. It's a key part of
+        preventing neuron death in TopK SAEs.
+        """
+        # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
+        # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
+        if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
+            return sae_out.new_tensor(0.0)
+        residual = (sae_in - sae_out).detach()
+
+        # Heuristic from Appendix B.1 in the paper
+        k_aux = sae_in.shape[-1] // 2
+
+        # Reduce the scale of the loss if there are a small number of dead latents
+        scale = min(num_dead / k_aux, 1.0)
+        k_aux = min(k_aux, num_dead)
+
+        auxk_acts = _calculate_topk_aux_acts(
+            k_aux=k_aux,
+            hidden_pre=hidden_pre,
+            dead_neuron_mask=dead_neuron_mask,
+        )
+
+        # Encourage the top ~50% of dead latents to predict the residual of the
+        # top k living latents
+        recons = self.decode(auxk_acts)
+        auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
+        return scale * auxk_loss
+
+    def _calculate_topk_aux_acts(
+        self,
+        k_aux: int,
+        hidden_pre: torch.Tensor,
+        dead_neuron_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Helper method to calculate activations for the auxiliary loss.
+
+        Args:
+            k_aux: Number of top dead neurons to select
+            hidden_pre: Pre-activation values from encoder
+            dead_neuron_mask: Boolean mask indicating which neurons are dead
+
+        Returns:
+            Tensor with activations for only the top-k dead neurons, zeros elsewhere
+        """
+        # Don't include living latents in this loss (set them to -inf so they won't be selected)
+        auxk_latents = torch.where(
+            dead_neuron_mask[None],
+            hidden_pre,
+            torch.tensor(-float("inf"), device=hidden_pre.device),
+        )
+
+        # Find topk values among dead neurons
+        auxk_topk = auxk_latents.topk(k_aux, dim=-1, sorted=False)
+
+        # Create a tensor of zeros, then place the topk values at their proper indices
+        auxk_acts = torch.zeros_like(hidden_pre)
+        auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
+
+        return auxk_acts
+
+
+def _calculate_topk_aux_acts(
+    k_aux: int,
+    hidden_pre: torch.Tensor,
+    dead_neuron_mask: torch.Tensor,
+) -> torch.Tensor:
+    # Don't include living latents in this loss
+    auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
+    # Top-k dead latents
+    auxk_topk = auxk_latents.topk(k_aux, sorted=False)
+    # Set the activations to zero for all but the top k_aux dead latents
+    auxk_acts = torch.zeros_like(hidden_pre)
+    auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
+    # Set activations to zero for all but top k_aux dead latents
+    return auxk_acts
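
A small self-contained sketch of the TopK activation's forward pass above (top-k selection, post-activation, scatter back into zeros), using arbitrary example sizes rather than any real SAE config:

import torch

k = 3
x = torch.randn(2, 10)  # (batch, d_sae), illustrative sizes only

# 1) select the top-k pre-activations along the last dimension
topk = torch.topk(x, k=k, dim=-1)
# 2) post-activation (ReLU here, matching the default postact_fn)
values = torch.relu(topk.values)
# 3) scatter back into a zero tensor so at most k entries per row are non-zero
result = torch.zeros_like(x)
result.scatter_(-1, topk.indices, values)
print((result != 0).sum(dim=-1))  # at most k non-zero features per row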
sae_lens/training/activations_store.py
CHANGED
@@ -28,7 +28,7 @@ from sae_lens.config import (
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.sae import SAE
+from sae_lens.saes.sae import SAE
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences


@@ -177,7 +177,7 @@ class ActivationsStore:
             dataset_trust_remote_code=sae.cfg.dataset_trust_remote_code,
             dtype=sae.cfg.dtype,
             device=torch.device(device),
-            seqpos_slice=sae.cfg.seqpos_slice,
+            seqpos_slice=sae.cfg.seqpos_slice or (None,),
         )

     def __init__(
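
The `seqpos_slice=sae.cfg.seqpos_slice or (None,)` change falls back to `(None,)` (presumably the "keep all sequence positions" default) whenever the stored config value is `None` or empty. A quick illustration of the `or` fallback, with made-up values:

for stored in (None, (), (0, 128)):
    seqpos_slice = stored or (None,)
    print(stored, "->", seqpos_slice)
# None -> (None,)   () -> (None,)   (0, 128) -> (0, 128)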
sae_lens/training/sae_trainer.py
CHANGED
@@ -11,9 +11,9 @@ from transformer_lens.hook_points import HookedRootModule
 from sae_lens import __version__
 from sae_lens.config import LanguageModelSAERunnerConfig
 from sae_lens.evals import EvalConfig, run_evals
+from sae_lens.saes.sae import TrainingSAE, TrainStepInput, TrainStepOutput
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.optim import L1Scheduler, get_lr_scheduler
-from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput

 # used to map between parameters which are updated during finetuning and the config str.
 FINETUNING_PARAMETERS = {

@@ -186,7 +186,7 @@ class SAETrainer:

         step_output = self._train_step(sae=self.sae, sae_in=layer_acts)

-        if self.cfg.log_to_wandb:
+        if self.cfg.logger.log_to_wandb:
             self._log_train_step(step_output)
             self._run_and_log_evals()


@@ -226,7 +226,7 @@ class SAETrainer:

         # log and then reset the feature sparsity every feature_sampling_window steps
         if (self.n_training_steps + 1) % self.cfg.feature_sampling_window == 0:
-            if self.cfg.log_to_wandb:
+            if self.cfg.logger.log_to_wandb:
                 sparsity_log_dict = self._build_sparsity_log_dict()
                 wandb.log(sparsity_log_dict, step=self.n_training_steps)
             self._reset_running_sparsity_stats()

@@ -235,9 +235,11 @@ class SAETrainer:
         # https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
         with self.autocast_if_enabled:
             train_step_output = self.sae.training_forward_pass(
-
-
-
+                step_input=TrainStepInput(
+                    sae_in=sae_in,
+                    dead_neuron_mask=self.dead_neurons,
+                    current_l1_coefficient=self.current_l1_coefficient,
+                ),
             )

         with torch.no_grad():

@@ -270,7 +272,7 @@ class SAETrainer:

     @torch.no_grad()
     def _log_train_step(self, step_output: TrainStepOutput):
-        if (self.n_training_steps + 1) % self.cfg.wandb_log_frequency == 0:
+        if (self.n_training_steps + 1) % self.cfg.logger.wandb_log_frequency == 0:
             wandb.log(
                 self._build_train_step_log_dict(
                     output=step_output,

@@ -331,7 +333,8 @@ class SAETrainer:
     def _run_and_log_evals(self):
         # record loss frequently, but not all the time.
         if (self.n_training_steps + 1) % (
-            self.cfg.wandb_log_frequency
+            self.cfg.logger.wandb_log_frequency
+            * self.cfg.logger.eval_every_n_wandb_logs
         ) == 0:
             self.sae.eval()
             ignore_tokens = set()

@@ -358,17 +361,8 @@ class SAETrainer:
         # Remove metrics that are not useful for wandb logging
         eval_metrics.pop("metrics/total_tokens_evaluated", None)

-
-
-
-        if self.sae.cfg.architecture == "standard":
-            b_e_dist = self.sae.b_enc.detach().float().cpu().numpy()
-            eval_metrics["weights/b_e"] = wandb.Histogram(b_e_dist)  # type: ignore
-        elif self.sae.cfg.architecture == "gated":
-            b_gate_dist = self.sae.b_gate.detach().float().cpu().numpy()
-            eval_metrics["weights/b_gate"] = wandb.Histogram(b_gate_dist)  # type: ignore
-            b_mag_dist = self.sae.b_mag.detach().float().cpu().numpy()
-            eval_metrics["weights/b_mag"] = wandb.Histogram(b_mag_dist)  # type: ignore
+        for key, value in self.sae.log_histograms().items():
+            eval_metrics[key] = wandb.Histogram(value)  # type: ignore

         wandb.log(
             eval_metrics,
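
The trainer now passes its per-step arguments as a single `step_input=TrainStepInput(...)` bundle instead of separate keyword arguments, and architecture-specific histogram logging moves into each SAE class's `log_histograms()`. A hedged sketch of that grouping using a stand-in dataclass (the real `TrainStepInput` lives in `sae_lens.saes.sae` and may carry more fields than the three visible at this call site):

from dataclasses import dataclass
import torch

@dataclass
class TrainStepInputSketch:  # hypothetical stand-in for sae_lens.saes.sae.TrainStepInput
    sae_in: torch.Tensor
    dead_neuron_mask: torch.Tensor | None
    current_l1_coefficient: float

step_input = TrainStepInputSketch(
    sae_in=torch.randn(4, 16),      # batch of activations fed to the SAE (illustrative shape)
    dead_neuron_mask=None,          # the trainer passes self.dead_neurons here
    current_l1_coefficient=1e-3,    # the trainer passes self.current_l1_coefficient here
)
# train_step_output = sae.training_forward_pass(step_input=step_input)  # as in the new call site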
{sae_lens-5.9.0.dist-info → sae_lens-6.0.0rc1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: sae-lens
-Version: 5.9.0
+Version: 6.0.0rc1
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch

@@ -62,7 +62,7 @@ Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for in

 SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to [safeguard humanity from risks posed by artificial intelligence](https://80000hours.org/problem-profiles/artificial-intelligence/).

-This library is maintained by [Joseph Bloom](https://www.jbloomaus.com/) and [David Chanin](https://github.com/chanind).
+This library is maintained by [Joseph Bloom](https://www.jbloomaus.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).

 ## Loading Pre-trained SAEs.

@@ -89,7 +89,7 @@ Please cite the package as follows:
 ```
 @misc{bloom2024saetrainingcodebase,
     title = {SAELens},
-    author = {Joseph
+    author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
     year = {2024},
     howpublished = {\url{https://github.com/jbloomAus/SAELens}},
 }
sae_lens-6.0.0rc1.dist-info/RECORD
ADDED
@@ -0,0 +1,32 @@
+sae_lens/__init__.py,sha256=ofQyurU7LtxIsg89QFCZe13QsdYpxErRI0x0tiCpB04,2074
+sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sae_lens/analysis/hooked_sae_transformer.py,sha256=RK0mcLhymXdJInXHcagQggxW9Qf4ptePnH7sKXvGGaU,13727
+sae_lens/analysis/neuronpedia_integration.py,sha256=dFiKRWfuT5iUfTPBPmZydSaNG3VwqZ1asuNbbQv_NCM,18488
+sae_lens/cache_activations_runner.py,sha256=dGK5EHJMHAKDAFyr25fy1COSm-61q-q6kpWENHFMaKk,12561
+sae_lens/config.py,sha256=SPjziXrTyOBjObSi-3s0_mza3Z7WH8gd9NT9pVUfosg,34375
+sae_lens/evals.py,sha256=tjDKmkUM4fBbP9LHZuBLCx37ux8Px9CliTMme3Wjt1A,38898
+sae_lens/load_model.py,sha256=tE70sXsyyyGYW7o506O3eiw1MXyyW6DCQojLG49hWYI,6771
+sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sae_lens/loading/pretrained_sae_loaders.py,sha256=NcqyH2KDL8Dg66-hjXsBAq1-IwdLEpYfKwbkHxSQbrg,29961
+sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
+sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
+sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
+sae_lens/regsitry.py,sha256=yCse5NmVH-ZaPET3jW8r7C_py2DL3yoox40GxGzJ0TI,1098
+sae_lens/sae_training_runner.py,sha256=VRNSAIsZLfcQMfZB8qdnK45PUXwoNvJ-rKt9BVYjMMY,8244
+sae_lens/saes/gated_sae.py,sha256=l5ucq7AZHya6ZClWNNE7CionGSf1ms5m1Ah3IoN6SH4,9916
+sae_lens/saes/jumprelu_sae.py,sha256=DRWgY58894cNh_sYAlefObI4rr0Eb6KHu1WuhTCcvB4,13468
+sae_lens/saes/sae.py,sha256=fd7OEsSXbmVii6QoYI_TRti6dwaxAQyrBcKyX7PxERw,36779
+sae_lens/saes/standard_sae.py,sha256=m2eNL_w6ave-_g7F1eQiwI4qbjMwwjzvxp96RN_WVAw,7110
+sae_lens/saes/topk_sae.py,sha256=aBET4F55A4xMIvZ8AazPtyl3oL-9S7krKx78li0uKGk,11370
+sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
+sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sae_lens/training/activations_store.py,sha256=ilJdcnZWfTDus1bdoqIb1wF_7H8_HWLmf8OCGrybmlA,35998
+sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
+sae_lens/training/optim.py,sha256=AImcc-MAaGDLOBP2hJ4alDFCtaqqgm4cc2eBxIxiQAo,5784
+sae_lens/training/sae_trainer.py,sha256=6TkqbzA0fYluRM8ouI_nU9sz-FaP63axxcnDrVfw37E,16279
+sae_lens/training/upload_saes_to_huggingface.py,sha256=tVC-2Txw7-9XttGlKzM0OSqU8CK7HDO9vIzDMqEwAYU,4366
+sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
+sae_lens-6.0.0rc1.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.0.0rc1.dist-info/METADATA,sha256=wHH-VRtquu-FjZEOHdPJi3zYW3ns7MCT1fVerbPEylc,5326
+sae_lens-6.0.0rc1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+sae_lens-6.0.0rc1.dist-info/RECORD,,