sae-lens 5.10.3__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl
This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
- sae_lens/__init__.py +56 -6
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +2 -1
- sae_lens/config.py +121 -252
- sae_lens/constants.py +18 -0
- sae_lens/evals.py +32 -17
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +68 -36
- sae_lens/pretrained_saes.yaml +0 -12
- sae_lens/registry.py +49 -0
- sae_lens/sae_training_runner.py +40 -54
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +258 -0
- sae_lens/saes/jumprelu_sae.py +354 -0
- sae_lens/saes/sae.py +948 -0
- sae_lens/saes/standard_sae.py +185 -0
- sae_lens/saes/topk_sae.py +294 -0
- sae_lens/training/activations_store.py +32 -16
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +55 -86
- sae_lens/training/upload_saes_to_huggingface.py +12 -6
- sae_lens/util.py +28 -0
- {sae_lens-5.10.3.dist-info → sae_lens-6.0.0rc2.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc2.dist-info/RECORD +35 -0
- sae_lens/sae.py +0 -747
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.10.3.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.10.3.dist-info → sae_lens-6.0.0rc2.dist-info}/LICENSE +0 -0
- {sae_lens-5.10.3.dist-info → sae_lens-6.0.0rc2.dist-info}/WHEEL +0 -0
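
For downstream code, the file list already implies the main import-path changes in 6.0.0rc2: the toolkit package is renamed to loading, the monolithic sae_lens/sae.py and sae_lens/training/training_sae.py are replaced by per-architecture modules under sae_lens/saes/, and new registry.py and constants.py modules appear. A rough migration sketch based only on the paths above and the class names visible in the hunks below (whether these names are also re-exported from the package root is not shown in this diff):

# Renamed package: sae_lens.toolkit -> sae_lens.loading
from sae_lens.loading import pretrained_sae_loaders, pretrained_saes_directory

# SAE implementations now live in one module per architecture under sae_lens.saes
from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
from sae_lens.saes.standard_sae import StandardSAE, StandardTrainingSAE
from sae_lens.saes.topk_sae import TopKSAE, TopKTrainingSAE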

sae_lens/saes/standard_sae.py (new file)
@@ -0,0 +1,185 @@
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+import torch
+from jaxtyping import Float
+from numpy.typing import NDArray
+from torch import nn
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+)
+from sae_lens.util import filter_valid_dataclass_fields
+
+
+@dataclass
+class StandardSAEConfig(SAEConfig):
+    """
+    Configuration class for a StandardSAE.
+    """
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "standard"
+
+
+class StandardSAE(SAE[StandardSAEConfig]):
+    """
+    StandardSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
+    using a simple linear encoder and decoder.
+
+    It implements the required abstract methods from BaseSAE:
+    - initialize_weights: sets up simple parameter initializations for W_enc, b_enc, W_dec, and b_dec.
+    - encode: computes the feature activations from an input.
+    - decode: reconstructs the input from the feature activations.
+
+    The BaseSAE.forward() method automatically calls encode and decode,
+    including any error-term processing if configured.
+    """
+
+    b_enc: nn.Parameter
+
+    def __init__(self, cfg: StandardSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+
+    @override
+    def initialize_weights(self) -> None:
+        # Initialize encoder weights and bias.
+        super().initialize_weights()
+        _init_weights_standard(self)
+
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        """
+        Encode the input tensor into the feature space.
+        For inference, no noise is added.
+        """
+        # Preprocess the SAE input (casting type, applying hooks, normalization)
+        sae_in = self.process_sae_in(x)
+        # Compute the pre-activation values
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+        # Apply the activation function (e.g., ReLU, tanh-relu, depending on config)
+        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decode the feature activations back to the input space.
+        Now, if hook_z reshaping is turned on, we reverse the flattening.
+        """
+        # 1) linear transform
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        # 2) hook reconstruction
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        # 4) optional out-normalization (e.g. constant_norm_rescale or layer_norm)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        # 5) if hook_z is enabled, rearrange back to (..., n_heads, d_head).
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+
+@dataclass
+class StandardTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a StandardTrainingSAE.
+    """
+
+    l1_coefficient: float = 1.0
+    lp_norm: float = 1.0
+    l1_warm_up_steps: int = 0
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "standard"
+
+
+class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
+    """
+    StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture.
+    It implements:
+    - initialize_weights: basic weight initialization for encoder/decoder.
+    - encode: inference encoding (invokes encode_with_hidden_pre).
+    - decode: a simple linear decoder.
+    - encode_with_hidden_pre: computes pre-activations, adds noise when training, and then activates.
+    - calculate_aux_loss: computes a sparsity penalty based on the (optionally scaled) p-norm of feature activations.
+    """
+
+    b_enc: nn.Parameter
+
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        _init_weights_standard(self)
+
+    @override
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+        return {
+            "l1": TrainCoefficientConfig(
+                value=self.cfg.l1_coefficient,
+                warm_up_steps=self.cfg.l1_warm_up_steps,
+            ),
+        }
+
+    def encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        # Process the input (including dtype conversion, hook call, and any activation normalization)
+        sae_in = self.process_sae_in(x)
+        # Compute the pre-activation (and allow for a hook if desired)
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)  # type: ignore
+        # Add noise during training for robustness (scaled by noise_scale from the configuration)
+        if self.training and self.cfg.noise_scale > 0:
+            hidden_pre_noised = (
+                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
+            )
+        else:
+            hidden_pre_noised = hidden_pre
+        # Apply the activation function (and any post-activation hook)
+        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
+        return feature_acts, hidden_pre_noised
+
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        # The "standard" auxiliary loss is a sparsity penalty on the feature activations
+        weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
+
+        # Compute the p-norm (set by cfg.lp_norm) over the feature dimension
+        sparsity = weighted_feature_acts.norm(p=self.cfg.lp_norm, dim=-1)
+        l1_loss = (step_input.coefficients["l1"] * sparsity).mean()
+
+        return {"l1_loss": l1_loss}
+
+    def log_histograms(self) -> dict[str, NDArray[np.generic]]:
+        """Log histograms of the weights and biases."""
+        b_e_dist = self.b_enc.detach().float().cpu().numpy()
+        return {
+            **super().log_histograms(),
+            "weights/b_e": b_e_dist,
+        }
+
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), StandardSAEConfig, ["architecture"]
+        )
+
+
+def _init_weights_standard(
+    sae: SAE[StandardSAEConfig] | TrainingSAE[StandardTrainingSAEConfig],
+) -> None:
+    sae.b_enc = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
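
As the docstring above notes, StandardTrainingSAE's auxiliary loss is a sparsity penalty on the feature activations, weighted by the decoder row norms so the penalty tracks each latent's actual contribution to the reconstruction. A standalone sketch of that computation (shapes and values here are illustrative, not part of the package):

import torch

batch, d_sae, d_in = 4, 16, 8
feature_acts = torch.rand(batch, d_sae)   # post-activation features
W_dec = torch.randn(d_sae, d_in)          # decoder weights, one row per latent
l1_coefficient, lp_norm = 1.0, 1.0

# Weight each latent by its decoder row norm, take the p-norm over the latent
# dimension, then average over the batch -- mirroring calculate_aux_loss above.
weighted_feature_acts = feature_acts * W_dec.norm(dim=1)
sparsity = weighted_feature_acts.norm(p=lp_norm, dim=-1)
l1_loss = (l1_coefficient * sparsity).mean()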

sae_lens/saes/topk_sae.py (new file)
@@ -0,0 +1,294 @@
+"""Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
+
+from dataclasses import dataclass
+from typing import Any, Callable
+
+import torch
+from jaxtyping import Float
+from torch import nn
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+)
+from sae_lens.util import filter_valid_dataclass_fields
+
+
+class TopK(nn.Module):
+    """
+    A simple TopK activation that zeroes out all but the top K elements along the last dimension,
+    then optionally applies a post-activation function (e.g., ReLU).
+    """
+
+    b_enc: nn.Parameter
+
+    def __init__(
+        self,
+        k: int,
+        postact_fn: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
+    ):
+        super().__init__()
+        self.k = k
+        self.postact_fn = postact_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        1) Select top K elements along the last dimension.
+        2) Apply post-activation (often ReLU).
+        3) Zero out all other entries.
+        """
+        topk = torch.topk(x, k=self.k, dim=-1)
+        values = self.postact_fn(topk.values)
+        result = torch.zeros_like(x)
+        result.scatter_(-1, topk.indices, values)
+        return result
+
+
+@dataclass
+class TopKSAEConfig(SAEConfig):
+    """
+    Configuration class for a TopKSAE.
+    """
+
+    k: int = 100
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "topk"
+
+
+class TopKSAE(SAE[TopKSAEConfig]):
+    """
+    An inference-only sparse autoencoder using a "topk" activation function.
+    It uses linear encoder and decoder layers, applying the TopK activation
+    to the hidden pre-activation in its encode step.
+    """
+
+    b_enc: nn.Parameter
+
+    def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
+        """
+        Args:
+            cfg: SAEConfig defining model size and behavior.
+            use_error_term: Whether to apply the error-term approach in the forward pass.
+        """
+        super().__init__(cfg, use_error_term)
+
+    @override
+    def initialize_weights(self) -> None:
+        # Initialize encoder weights and bias.
+        super().initialize_weights()
+        _init_weights_topk(self)
+
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        """
+        Converts input x into feature activations.
+        Uses topk activation from the config (cfg.activation_fn == "topk")
+        under the hood.
+        """
+        sae_in = self.process_sae_in(x)
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+        # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
+        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Reconstructs the input from topk feature activations.
+        Applies optional finetuning scaling, hooking to recons, out normalization,
+        and optional head reshaping.
+        """
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+    @override
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return TopK(self.cfg.k)
+
+
+@dataclass
+class TopKTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a TopKTrainingSAE.
+    """
+
+    k: int = 100
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "topk"
+
+
+class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
+    """
+    TopK variant with training functionality. Injects noise during training, optionally
+    calculates a topk-related auxiliary loss, etc.
+    """
+
+    b_enc: nn.Parameter
+
+    def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+
+    @override
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        _init_weights_topk(self)
+
+    def encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        """
+        Similar to the base training method: cast input, optionally add noise, then apply TopK.
+        """
+        sae_in = self.process_sae_in(x)
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+
+        # Inject noise if training
+        if self.training and self.cfg.noise_scale > 0:
+            hidden_pre_noised = (
+                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
+            )
+        else:
+            hidden_pre_noised = hidden_pre
+
+        # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
+        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
+        return feature_acts, hidden_pre_noised
+
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        # Calculate the auxiliary loss for dead neurons
+        topk_loss = self.calculate_topk_aux_loss(
+            sae_in=step_input.sae_in,
+            sae_out=sae_out,
+            hidden_pre=hidden_pre,
+            dead_neuron_mask=step_input.dead_neuron_mask,
+        )
+        return {"auxiliary_reconstruction_loss": topk_loss}
+
+    @override
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return TopK(self.cfg.k)
+
+    @override
+    def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
+        return {}
+
+    def calculate_topk_aux_loss(
+        self,
+        sae_in: torch.Tensor,
+        sae_out: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        dead_neuron_mask: torch.Tensor | None,
+    ) -> torch.Tensor:
+        """
+        Calculate TopK auxiliary loss.
+
+        This auxiliary loss encourages dead neurons to learn useful features by having
+        them reconstruct the residual error from the live neurons. It's a key part of
+        preventing neuron death in TopK SAEs.
+        """
+        # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
+        # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
+        if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
+            return sae_out.new_tensor(0.0)
+        residual = (sae_in - sae_out).detach()
+
+        # Heuristic from Appendix B.1 in the paper
+        k_aux = sae_in.shape[-1] // 2
+
+        # Reduce the scale of the loss if there are a small number of dead latents
+        scale = min(num_dead / k_aux, 1.0)
+        k_aux = min(k_aux, num_dead)
+
+        auxk_acts = _calculate_topk_aux_acts(
+            k_aux=k_aux,
+            hidden_pre=hidden_pre,
+            dead_neuron_mask=dead_neuron_mask,
+        )
+
+        # Encourage the top ~50% of dead latents to predict the residual of the
+        # top k living latents
+        recons = self.decode(auxk_acts)
+        auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
+        return scale * auxk_loss
+
+    def _calculate_topk_aux_acts(
+        self,
+        k_aux: int,
+        hidden_pre: torch.Tensor,
+        dead_neuron_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Helper method to calculate activations for the auxiliary loss.
+
+        Args:
+            k_aux: Number of top dead neurons to select
+            hidden_pre: Pre-activation values from encoder
+            dead_neuron_mask: Boolean mask indicating which neurons are dead
+
+        Returns:
+            Tensor with activations for only the top-k dead neurons, zeros elsewhere
+        """
+        # Don't include living latents in this loss (set them to -inf so they won't be selected)
+        auxk_latents = torch.where(
+            dead_neuron_mask[None],
+            hidden_pre,
+            torch.tensor(-float("inf"), device=hidden_pre.device),
+        )
+
+        # Find topk values among dead neurons
+        auxk_topk = auxk_latents.topk(k_aux, dim=-1, sorted=False)
+
+        # Create a tensor of zeros, then place the topk values at their proper indices
+        auxk_acts = torch.zeros_like(hidden_pre)
+        auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
+
+        return auxk_acts
+
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), TopKSAEConfig, ["architecture"]
+        )
+
+
+def _calculate_topk_aux_acts(
+    k_aux: int,
+    hidden_pre: torch.Tensor,
+    dead_neuron_mask: torch.Tensor,
+) -> torch.Tensor:
+    # Don't include living latents in this loss
+    auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
+    # Top-k dead latents
+    auxk_topk = auxk_latents.topk(k_aux, sorted=False)
+    # Set the activations to zero for all but the top k_aux dead latents
+    auxk_acts = torch.zeros_like(hidden_pre)
+    auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
+    # Set activations to zero for all but top k_aux dead latents
+    return auxk_acts
+
+
+def _init_weights_topk(
+    sae: SAE[TopKSAEConfig] | TrainingSAE[TopKTrainingSAEConfig],
+) -> None:
+    sae.b_enc = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )

sae_lens/training/activations_store.py
@@ -23,12 +23,12 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from sae_lens import logger
 from sae_lens.config import (
-    DTYPE_MAP,
     CacheActivationsRunnerConfig,
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.
+from sae_lens.constants import DTYPE_MAP
+from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
 
 
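
The hunk above moves DTYPE_MAP out of sae_lens.config and into the new sae_lens.constants module. Code that imported it from the old location needs a one-line change (only the import path is shown in this diff):

# sae-lens 5.10.3:
#   from sae_lens.config import DTYPE_MAP
# sae-lens 6.0.0rc2:
from sae_lens.constants import DTYPE_MAP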
@@ -91,7 +91,8 @@ class ActivationsStore:
     def from_config(
         cls,
         model: HookedRootModule,
-        cfg: LanguageModelSAERunnerConfig
+        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]
+        | CacheActivationsRunnerConfig,
         override_dataset: HfDataset | None = None,
     ) -> ActivationsStore:
         if isinstance(cfg, CacheActivationsRunnerConfig):
@@ -128,13 +129,15 @@ class ActivationsStore:
             hook_layer=cfg.hook_layer,
             hook_head_index=cfg.hook_head_index,
             context_size=cfg.context_size,
-            d_in=cfg.d_in
+            d_in=cfg.d_in
+            if isinstance(cfg, CacheActivationsRunnerConfig)
+            else cfg.sae.d_in,
             n_batches_in_buffer=cfg.n_batches_in_buffer,
             total_training_tokens=cfg.training_tokens,
             store_batch_size_prompts=cfg.store_batch_size_prompts,
             train_batch_size_tokens=cfg.train_batch_size_tokens,
             prepend_bos=cfg.prepend_bos,
-            normalize_activations=cfg.normalize_activations,
+            normalize_activations=cfg.sae.normalize_activations,
             device=device,
             dtype=cfg.dtype,
             cached_activations_path=cached_activations_path,
@@ -149,9 +152,10 @@ class ActivationsStore:
     def from_sae(
         cls,
         model: HookedRootModule,
-        sae: SAE,
+        sae: SAE[T_SAE_CONFIG],
+        dataset: HfDataset | str,
+        dataset_trust_remote_code: bool = False,
         context_size: int | None = None,
-        dataset: HfDataset | str | None = None,
         streaming: bool = True,
         store_batch_size_prompts: int = 8,
         n_batches_in_buffer: int = 8,
@@ -159,25 +163,37 @@ class ActivationsStore:
         total_tokens: int = 10**9,
         device: str = "cpu",
     ) -> ActivationsStore:
+        if sae.cfg.metadata.hook_name is None:
+            raise ValueError("hook_name is required")
+        if sae.cfg.metadata.hook_layer is None:
+            raise ValueError("hook_layer is required")
+        if sae.cfg.metadata.hook_head_index is None:
+            raise ValueError("hook_head_index is required")
+        if sae.cfg.metadata.context_size is None:
+            raise ValueError("context_size is required")
+        if sae.cfg.metadata.prepend_bos is None:
+            raise ValueError("prepend_bos is required")
         return cls(
             model=model,
-            dataset=
+            dataset=dataset,
             d_in=sae.cfg.d_in,
-            hook_name=sae.cfg.hook_name,
-            hook_layer=sae.cfg.hook_layer,
-            hook_head_index=sae.cfg.hook_head_index,
-            context_size=sae.cfg.context_size
-
+            hook_name=sae.cfg.metadata.hook_name,
+            hook_layer=sae.cfg.metadata.hook_layer,
+            hook_head_index=sae.cfg.metadata.hook_head_index,
+            context_size=sae.cfg.metadata.context_size
+            if context_size is None
+            else context_size,
+            prepend_bos=sae.cfg.metadata.prepend_bos,
             streaming=streaming,
             store_batch_size_prompts=store_batch_size_prompts,
             train_batch_size_tokens=train_batch_size_tokens,
             n_batches_in_buffer=n_batches_in_buffer,
             total_training_tokens=total_tokens,
             normalize_activations=sae.cfg.normalize_activations,
-            dataset_trust_remote_code=
+            dataset_trust_remote_code=dataset_trust_remote_code,
             dtype=sae.cfg.dtype,
             device=torch.device(device),
-            seqpos_slice=sae.cfg.seqpos_slice,
+            seqpos_slice=sae.cfg.metadata.seqpos_slice or (None,),
         )
 
     def __init__(
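
The from_sae changes above make dataset a required argument (it was optional in 5.10.3) and read hook_name, hook_layer, hook_head_index, context_size and prepend_bos from sae.cfg.metadata, raising a ValueError when any of them is unset. A hedged sketch of a 6.0.0rc2-style call, wrapped in a helper so it stays self-contained; the dataset id is only an example:

from transformer_lens.hook_points import HookedRootModule

from sae_lens.saes.sae import SAE
from sae_lens.training.activations_store import ActivationsStore


def make_store(model: HookedRootModule, sae: SAE) -> ActivationsStore:
    # sae.cfg.metadata must carry hook_name, hook_layer, hook_head_index,
    # context_size and prepend_bos, otherwise from_sae raises ValueError.
    return ActivationsStore.from_sae(
        model=model,
        sae=sae,
        dataset="monology/pile-uncopyrighted",  # now required; any HF dataset or name
        context_size=128,
    )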
@@ -448,7 +464,7 @@ class ActivationsStore:
         ):
             # temporalily set estimated_norm_scaling_factor to 1.0 so the dataloader works
             self.estimated_norm_scaling_factor = 1.0
-            acts = self.next_batch()[
+            acts = self.next_batch()[0]
             self.estimated_norm_scaling_factor = None
             norms_per_batch.append(acts.norm(dim=-1).mean().item())
             mean_norm = np.mean(norms_per_batch)