sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +60 -7
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +16 -14
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +170 -258
- sae_lens/constants.py +21 -0
- sae_lens/evals.py +59 -44
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +52 -4
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +254 -0
- sae_lens/saes/jumprelu_sae.py +348 -0
- sae_lens/saes/sae.py +1076 -0
- sae_lens/saes/standard_sae.py +178 -0
- sae_lens/saes/topk_sae.py +300 -0
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +103 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +155 -177
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +13 -7
- sae_lens/util.py +47 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
- sae_lens-6.0.0.dist-info/RECORD +37 -0
- sae_lens/sae.py +0 -747
- sae_lens/sae_training_runner.py +0 -251
- sae_lens/training/geometric_median.py +0 -101
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.11.0.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
from jaxtyping import Float
|
|
7
|
+
from numpy.typing import NDArray
|
|
8
|
+
from torch import nn
|
|
9
|
+
from typing_extensions import override
|
|
10
|
+
|
|
11
|
+
from sae_lens.saes.sae import (
|
|
12
|
+
SAE,
|
|
13
|
+
SAEConfig,
|
|
14
|
+
TrainCoefficientConfig,
|
|
15
|
+
TrainingSAE,
|
|
16
|
+
TrainingSAEConfig,
|
|
17
|
+
TrainStepInput,
|
|
18
|
+
)
|
|
19
|
+
from sae_lens.util import filter_valid_dataclass_fields
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class StandardSAEConfig(SAEConfig):
    """
    Config for the inference-only ``StandardSAE``.

    Declares no fields of its own; everything is inherited from ``SAEConfig``.
    """

    @override
    @classmethod
    def architecture(cls) -> str:
        # Architecture identifier string for this config family.
        return "standard"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class StandardSAE(SAE[StandardSAEConfig]):
    """
    Inference-only sparse autoencoder: affine encoder, ``self.activation_fn``,
    affine decoder.

    Architecture-specific pieces supplied to the ``SAE`` base class:

    - ``initialize_weights``: runs the base initialization, then creates a
      zero encoder bias ``b_enc`` of length ``cfg.d_sae``.
    - ``encode``: maps an input to feature activations.
    - ``decode``: maps feature activations back to the input space.
    """

    b_enc: nn.Parameter

    def __init__(self, cfg: StandardSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

    @override
    def initialize_weights(self) -> None:
        # Let the base class do its part first, then attach the encoder bias.
        super().initialize_weights()
        _init_weights_standard(self)

    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Map ``x`` to feature activations (no noise is added).

        Input preprocessing is delegated to ``process_sae_in``. The
        pre-activations pass through ``hook_sae_acts_pre`` and the activated
        values through ``hook_sae_acts_post``.
        """
        preprocessed = self.process_sae_in(x)
        pre_acts = self.hook_sae_acts_pre(preprocessed @ self.W_enc + self.b_enc)
        post_acts = self.activation_fn(pre_acts)
        return self.hook_sae_acts_post(post_acts)

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Map feature activations back to the input space.

        Pipeline:
        1. Affine decode with ``W_dec`` / ``b_dec``.
        2. ``hook_sae_recons``.
        3. ``run_time_activation_norm_fn_out`` (output-side normalization).
        4. ``reshape_fn_out``, which rearranges back to
           ``(..., n_heads, d_head)`` when hook_z reshaping is enabled.
        """
        recons = self.hook_sae_recons(feature_acts @ self.W_dec + self.b_dec)
        recons = self.run_time_activation_norm_fn_out(recons)
        return self.reshape_fn_out(recons, self.d_head)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
class StandardTrainingSAEConfig(TrainingSAEConfig):
    """
    Training config for ``StandardTrainingSAE``.

    Fields:
        l1_coefficient: value of the "l1" sparsity-penalty coefficient.
        lp_norm: order ``p`` of the norm used by the sparsity penalty.
        l1_warm_up_steps: warm-up steps for the "l1" coefficient.
    """

    # Field order is part of the generated __init__ signature; keep it stable.
    l1_coefficient: float = 1.0
    lp_norm: float = 1.0
    l1_warm_up_steps: int = 0

    @override
    @classmethod
    def architecture(cls) -> str:
        # Same identifier as the inference-side StandardSAEConfig.
        return "standard"
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
    """
    Trainable version of the "standard" SAE architecture.

    Pieces provided by this class:

    - initialize_weights: base initialization plus a zero-initialised encoder
      bias ``b_enc``.
    - get_coefficients: exposes the "l1" sparsity coefficient together with
      its warm-up schedule.
    - encode_with_hidden_pre: returns ``(feature_acts, hidden_pre)``. No noise
      is added.
    - calculate_aux_loss: sparsity penalty. It is the ``cfg.lp_norm``-norm of
      the feature activations, each weighted by the norm of its decoder row.
    - log_histograms / to_inference_config_dict: logging and export helpers.

    ``encode`` and ``decode`` are not defined here; they come from the base
    class.
    """

    b_enc: nn.Parameter

    @override
    def initialize_weights(self) -> None:
        """Run the base initialization, then add a zero encoder bias."""
        super().initialize_weights()
        _init_weights_standard(self)

    @override
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
        """Return the "l1" coefficient and its warm-up step count."""
        return {
            "l1": TrainCoefficientConfig(
                value=self.cfg.l1_coefficient,
                warm_up_steps=self.cfg.l1_warm_up_steps,
            ),
        }

    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        """
        Encode ``x`` and also return the pre-activations.

        Returns:
            ``(feature_acts, hidden_pre)``: the activated features (after
            ``hook_sae_acts_post``) and the pre-activations (after
            ``hook_sae_acts_pre``).
        """
        # Dtype conversion, input hooks and any activation normalization.
        sae_in = self.process_sae_in(x)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)  # type: ignore
        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
        return feature_acts, hidden_pre

    @override
    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """
        Sparsity penalty for the "standard" architecture.

        Returns:
            ``{"l1_loss": mean(coef * ||feature_acts * ||W_dec rows||_2||_p)}``.
            ``p`` is ``cfg.lp_norm``, and ``coef`` is
            ``step_input.coefficients["l1"]``.
        """
        # Weight each latent by its decoder row norm, so the penalty cannot be
        # dodged by shrinking activations while growing decoder rows.
        weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)

        # p-norm over the feature dimension.
        sparsity = weighted_feature_acts.norm(p=self.cfg.lp_norm, dim=-1)
        l1_loss = (step_input.coefficients["l1"] * sparsity).mean()

        return {"l1_loss": l1_loss}

    @override
    def log_histograms(self) -> dict[str, NDArray[np.generic]]:
        """Base-class histograms plus the encoder bias under "weights/b_e"."""
        b_e_dist = self.b_enc.detach().float().cpu().numpy()
        return {
            **super().log_histograms(),
            "weights/b_e": b_e_dist,
        }

    def to_inference_config_dict(self) -> dict[str, Any]:
        """Project the training config onto the fields of ``StandardSAEConfig``."""
        return filter_valid_dataclass_fields(
            self.cfg.to_dict(), StandardSAEConfig, ["architecture"]
        )
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _init_weights_standard(
    sae: SAE[StandardSAEConfig] | TrainingSAE[StandardTrainingSAEConfig],
) -> None:
    """Attach a fresh all-zero encoder bias of length ``cfg.d_sae`` to ``sae``.

    The bias uses the SAE's own dtype and device.
    """
    zero_bias = torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
    sae.b_enc = nn.Parameter(zero_bias)
|
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
"""Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Callable
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from jaxtyping import Float
|
|
8
|
+
from torch import nn
|
|
9
|
+
from typing_extensions import override
|
|
10
|
+
|
|
11
|
+
from sae_lens.saes.sae import (
|
|
12
|
+
SAE,
|
|
13
|
+
SAEConfig,
|
|
14
|
+
TrainCoefficientConfig,
|
|
15
|
+
TrainingSAE,
|
|
16
|
+
TrainingSAEConfig,
|
|
17
|
+
TrainStepInput,
|
|
18
|
+
)
|
|
19
|
+
from sae_lens.util import filter_valid_dataclass_fields
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class TopK(nn.Module):
|
|
23
|
+
"""
|
|
24
|
+
A simple TopK activation that zeroes out all but the top K elements along the last dimension,
|
|
25
|
+
then optionally applies a post-activation function (e.g., ReLU).
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
b_enc: nn.Parameter
|
|
29
|
+
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
k: int,
|
|
33
|
+
postact_fn: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
|
|
34
|
+
):
|
|
35
|
+
super().__init__()
|
|
36
|
+
self.k = k
|
|
37
|
+
self.postact_fn = postact_fn
|
|
38
|
+
|
|
39
|
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
40
|
+
"""
|
|
41
|
+
1) Select top K elements along the last dimension.
|
|
42
|
+
2) Apply post-activation (often ReLU).
|
|
43
|
+
3) Zero out all other entries.
|
|
44
|
+
"""
|
|
45
|
+
topk = torch.topk(x, k=self.k, dim=-1)
|
|
46
|
+
values = self.postact_fn(topk.values)
|
|
47
|
+
result = torch.zeros_like(x)
|
|
48
|
+
result.scatter_(-1, topk.indices, values)
|
|
49
|
+
return result
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
class TopKSAEConfig(SAEConfig):
    """
    Config for the inference-only ``TopKSAE``.

    Fields:
        k: number of latents the TopK activation keeps nonzero along the
            last dimension.
    """

    k: int = 100

    @override
    @classmethod
    def architecture(cls) -> str:
        # Architecture identifier string for this config family.
        return "topk"
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class TopKSAE(SAE[TopKSAEConfig]):
    """
    Inference-only sparse autoencoder with a TopK activation.

    Encoder and decoder are affine (``W_enc``/``b_enc``, ``W_dec``/``b_dec``).
    The activation keeps only the ``cfg.k`` largest pre-activations per vector.
    Folding the decoder norm into the encoder is refused, because rescaling
    latents can change which of them survive the top-k selection.
    """

    b_enc: nn.Parameter

    def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
        """
        Args:
            cfg: TopK SAE configuration (sizes, ``k`` and other options).
            use_error_term: Forwarded to ``SAE.__init__``; selects whether the
                error-term approach is used in the forward pass.
        """
        super().__init__(cfg, use_error_term)

    @override
    def initialize_weights(self) -> None:
        # Base-class initialization first, then the zero encoder bias.
        super().initialize_weights()
        _init_weights_topk(self)

    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Map ``x`` to sparse feature activations using the TopK activation.
        """
        preprocessed = self.process_sae_in(x)
        pre_acts = self.hook_sae_acts_pre(preprocessed @ self.W_enc + self.b_enc)
        # self.activation_fn is expected to be the TopK module built by
        # get_activation_fn below (presumably wired up by the base class).
        post_acts = self.activation_fn(pre_acts)
        return self.hook_sae_acts_post(post_acts)

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Reconstruct the input from TopK feature activations.

        Steps:
        1. Affine decode.
        2. ``hook_sae_recons``.
        3. Output-side normalization (``run_time_activation_norm_fn_out``).
        4. Optional head reshaping (``reshape_fn_out``).
        """
        recons = self.hook_sae_recons(feature_acts @ self.W_dec + self.b_dec)
        recons = self.run_time_activation_norm_fn_out(recons)
        return self.reshape_fn_out(recons, self.d_head)

    @override
    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
        # TopK with the default ReLU post-activation.
        return TopK(self.cfg.k)

    @override
    @torch.no_grad()
    def fold_W_dec_norm(self) -> None:
        raise NotImplementedError(
            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
        )
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@dataclass
class TopKTrainingSAEConfig(TrainingSAEConfig):
    """
    Training config for ``TopKTrainingSAE``.

    Fields:
        k: number of latents the TopK activation keeps nonzero along the
            last dimension.
    """

    k: int = 100

    @override
    @classmethod
    def architecture(cls) -> str:
        # Same identifier as the inference-side TopKSAEConfig.
        return "topk"
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
    """
    Trainable TopK SAE.

    ``encode_with_hidden_pre`` uses the TopK activation (see
    ``get_activation_fn``). The auxiliary loss is the TopK dead-latent
    reconstruction loss (``calculate_topk_aux_loss``).

    No noise is injected, and ``get_coefficients`` returns no training
    coefficients.
    """

    b_enc: nn.Parameter

    def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

    @override
    def initialize_weights(self) -> None:
        """Run the base initialization, then add a zero encoder bias."""
        super().initialize_weights()
        _init_weights_topk(self)

    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        """
        Preprocess the input, compute the pre-activations, then apply TopK.

        Returns:
            ``(feature_acts, hidden_pre)``.
        """
        sae_in = self.process_sae_in(x)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

        # self.activation_fn is expected to be the TopK built by get_activation_fn.
        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
        return feature_acts, hidden_pre

    @override
    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Return the dead-latent loss under "auxiliary_reconstruction_loss"."""
        topk_loss = self.calculate_topk_aux_loss(
            sae_in=step_input.sae_in,
            sae_out=sae_out,
            hidden_pre=hidden_pre,
            dead_neuron_mask=step_input.dead_neuron_mask,
        )
        return {"auxiliary_reconstruction_loss": topk_loss}

    @override
    @torch.no_grad()
    def fold_W_dec_norm(self) -> None:
        raise NotImplementedError(
            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
        )

    @override
    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
        return TopK(self.cfg.k)

    @override
    def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
        return {}

    def calculate_topk_aux_loss(
        self,
        sae_in: torch.Tensor,
        sae_out: torch.Tensor,
        hidden_pre: torch.Tensor,
        dead_neuron_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """
        Calculate the TopK auxiliary loss.

        Dead latents are trained to reconstruct the (detached) residual left by
        the live latents, which helps prevent latent death in TopK SAEs.

        Returns:
            A scalar tensor. It is 0 when there is no mask or no dead latents.
        """
        # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
        # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
        if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
            return sae_out.new_tensor(0.0)
        residual = (sae_in - sae_out).detach()

        # Heuristic from Appendix B.1 in the paper: k_aux = d_in / 2.
        # Clamp to at least 1 so a 1-dimensional input does not divide by zero.
        k_aux = max(sae_in.shape[-1] // 2, 1)

        # Reduce the scale of the loss if there are a small number of dead latents
        scale = min(num_dead / k_aux, 1.0)
        k_aux = min(k_aux, num_dead)

        auxk_acts = _calculate_topk_aux_acts(
            k_aux=k_aux,
            hidden_pre=hidden_pre,
            dead_neuron_mask=dead_neuron_mask,
        )

        # Encourage the top ~50% of dead latents to predict the residual of the
        # top k living latents
        recons = self.decode(auxk_acts)
        auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
        return scale * auxk_loss

    def _calculate_topk_aux_acts(
        self,
        k_aux: int,
        hidden_pre: torch.Tensor,
        dead_neuron_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Activations of the top ``k_aux`` dead latents, with zeros elsewhere.

        Kept for backward compatibility. It delegates to the module-level
        helper of the same name, which is what ``calculate_topk_aux_loss``
        uses, so the two implementations cannot drift apart. Inside a method
        body the bare name resolves to the module global, not to this method.
        """
        return _calculate_topk_aux_acts(
            k_aux=k_aux,
            hidden_pre=hidden_pre,
            dead_neuron_mask=dead_neuron_mask,
        )

    def to_inference_config_dict(self) -> dict[str, Any]:
        """Project the training config onto the fields of ``TopKSAEConfig``."""
        return filter_valid_dataclass_fields(
            self.cfg.to_dict(), TopKSAEConfig, ["architecture"]
        )
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def _calculate_topk_aux_acts(
|
|
280
|
+
k_aux: int,
|
|
281
|
+
hidden_pre: torch.Tensor,
|
|
282
|
+
dead_neuron_mask: torch.Tensor,
|
|
283
|
+
) -> torch.Tensor:
|
|
284
|
+
# Don't include living latents in this loss
|
|
285
|
+
auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
|
|
286
|
+
# Top-k dead latents
|
|
287
|
+
auxk_topk = auxk_latents.topk(k_aux, sorted=False)
|
|
288
|
+
# Set the activations to zero for all but the top k_aux dead latents
|
|
289
|
+
auxk_acts = torch.zeros_like(hidden_pre)
|
|
290
|
+
auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
|
|
291
|
+
# Set activations to zero for all but top k_aux dead latents
|
|
292
|
+
return auxk_acts
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def _init_weights_topk(
    sae: SAE[TopKSAEConfig] | TrainingSAE[TopKTrainingSAEConfig],
) -> None:
    """Attach a fresh all-zero encoder bias of length ``cfg.d_sae`` to ``sae``.

    The bias uses the SAE's own dtype and device.
    """
    zero_bias = torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
    sae.b_enc = nn.Parameter(zero_bias)
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from statistics import mean
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from tqdm.auto import tqdm
|
|
7
|
+
|
|
8
|
+
from sae_lens.training.types import DataProvider
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
|
|
12
|
+
class ActivationScaler:
|
|
13
|
+
scaling_factor: float | None = None
|
|
14
|
+
|
|
15
|
+
def scale(self, acts: torch.Tensor) -> torch.Tensor:
|
|
16
|
+
return acts if self.scaling_factor is None else acts * self.scaling_factor
|
|
17
|
+
|
|
18
|
+
def unscale(self, acts: torch.Tensor) -> torch.Tensor:
|
|
19
|
+
return acts if self.scaling_factor is None else acts / self.scaling_factor
|
|
20
|
+
|
|
21
|
+
def __call__(self, acts: torch.Tensor) -> torch.Tensor:
|
|
22
|
+
return self.scale(acts)
|
|
23
|
+
|
|
24
|
+
@torch.no_grad()
|
|
25
|
+
def _calculate_mean_norm(
|
|
26
|
+
self, data_provider: DataProvider, n_batches_for_norm_estimate: int = int(1e3)
|
|
27
|
+
) -> float:
|
|
28
|
+
norms_per_batch: list[float] = []
|
|
29
|
+
for _ in tqdm(
|
|
30
|
+
range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
|
|
31
|
+
):
|
|
32
|
+
acts = next(data_provider)
|
|
33
|
+
norms_per_batch.append(acts.norm(dim=-1).mean().item())
|
|
34
|
+
return mean(norms_per_batch)
|
|
35
|
+
|
|
36
|
+
def estimate_scaling_factor(
|
|
37
|
+
self,
|
|
38
|
+
d_in: int,
|
|
39
|
+
data_provider: DataProvider,
|
|
40
|
+
n_batches_for_norm_estimate: int = int(1e3),
|
|
41
|
+
):
|
|
42
|
+
mean_norm = self._calculate_mean_norm(
|
|
43
|
+
data_provider, n_batches_for_norm_estimate
|
|
44
|
+
)
|
|
45
|
+
self.scaling_factor = (d_in**0.5) / mean_norm
|
|
46
|
+
|
|
47
|
+
def save(self, file_path: str):
|
|
48
|
+
"""save the state dict to a file in json format"""
|
|
49
|
+
if not file_path.endswith(".json"):
|
|
50
|
+
raise ValueError("file_path must end with .json")
|
|
51
|
+
|
|
52
|
+
with open(file_path, "w") as f:
|
|
53
|
+
json.dump({"scaling_factor": self.scaling_factor}, f)
|