sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl

This diff shows the content changes between publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (36)
  1. sae_lens/__init__.py +60 -7
  2. sae_lens/analysis/hooked_sae_transformer.py +12 -12
  3. sae_lens/analysis/neuronpedia_integration.py +16 -14
  4. sae_lens/cache_activations_runner.py +9 -7
  5. sae_lens/config.py +170 -258
  6. sae_lens/constants.py +21 -0
  7. sae_lens/evals.py +59 -44
  8. sae_lens/llm_sae_training_runner.py +377 -0
  9. sae_lens/load_model.py +52 -4
  10. sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
  11. sae_lens/registry.py +49 -0
  12. sae_lens/saes/__init__.py +48 -0
  13. sae_lens/saes/gated_sae.py +254 -0
  14. sae_lens/saes/jumprelu_sae.py +348 -0
  15. sae_lens/saes/sae.py +1076 -0
  16. sae_lens/saes/standard_sae.py +178 -0
  17. sae_lens/saes/topk_sae.py +300 -0
  18. sae_lens/training/activation_scaler.py +53 -0
  19. sae_lens/training/activations_store.py +103 -184
  20. sae_lens/training/mixing_buffer.py +56 -0
  21. sae_lens/training/optim.py +60 -36
  22. sae_lens/training/sae_trainer.py +155 -177
  23. sae_lens/training/types.py +5 -0
  24. sae_lens/training/upload_saes_to_huggingface.py +13 -7
  25. sae_lens/util.py +47 -0
  26. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
  27. sae_lens-6.0.0.dist-info/RECORD +37 -0
  28. sae_lens/sae.py +0 -747
  29. sae_lens/sae_training_runner.py +0 -251
  30. sae_lens/training/geometric_median.py +0 -101
  31. sae_lens/training/training_sae.py +0 -710
  32. sae_lens-5.11.0.dist-info/RECORD +0 -28
  33. /sae_lens/{toolkit → loading}/__init__.py +0 -0
  34. /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
  35. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
  36. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
sae_lens/saes/standard_sae.py
@@ -0,0 +1,178 @@
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+import torch
+from jaxtyping import Float
+from numpy.typing import NDArray
+from torch import nn
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+)
+from sae_lens.util import filter_valid_dataclass_fields
+
+
+@dataclass
+class StandardSAEConfig(SAEConfig):
+    """
+    Configuration class for a StandardSAE.
+    """
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "standard"
+
+
+class StandardSAE(SAE[StandardSAEConfig]):
+    """
+    StandardSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
+    using a simple linear encoder and decoder.
+
+    It implements the required abstract methods from BaseSAE:
+    - initialize_weights: sets up simple parameter initializations for W_enc, b_enc, W_dec, and b_dec.
+    - encode: computes the feature activations from an input.
+    - decode: reconstructs the input from the feature activations.
+
+    The BaseSAE.forward() method automatically calls encode and decode,
+    including any error-term processing if configured.
+    """
+
+    b_enc: nn.Parameter
+
+    def __init__(self, cfg: StandardSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+
+    @override
+    def initialize_weights(self) -> None:
+        # Initialize encoder weights and bias.
+        super().initialize_weights()
+        _init_weights_standard(self)
+
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        """
+        Encode the input tensor into the feature space.
+        For inference, no noise is added.
+        """
+        # Preprocess the SAE input (casting type, applying hooks, normalization)
+        sae_in = self.process_sae_in(x)
+        # Compute the pre-activation values
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+        # Apply the activation function (e.g., ReLU, depending on config)
+        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decode the feature activations back to the input space.
+        If hook_z reshaping is turned on, the flattening is reversed.
+        """
+        # 1) linear transform
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        # 2) hook reconstruction
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        # 3) optional out-normalization (e.g. constant_norm_rescale)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        # 4) if hook_z is enabled, rearrange back to (..., n_heads, d_head)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+
+@dataclass
+class StandardTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a StandardTrainingSAE.
+    """
+
+    l1_coefficient: float = 1.0
+    lp_norm: float = 1.0
+    l1_warm_up_steps: int = 0
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "standard"
+
+
+class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
+    """
+    StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture.
+    It implements:
+    - initialize_weights: basic weight initialization for encoder/decoder.
+    - encode: inference encoding (invokes encode_with_hidden_pre).
+    - decode: a simple linear decoder.
+    - encode_with_hidden_pre: computes pre-activations, adds noise when training, and then activates.
+    - calculate_aux_loss: computes a sparsity penalty based on the (optionally scaled) p-norm of feature activations.
+    """
+
+    b_enc: nn.Parameter
+
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        _init_weights_standard(self)
+
+    @override
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+        return {
+            "l1": TrainCoefficientConfig(
+                value=self.cfg.l1_coefficient,
+                warm_up_steps=self.cfg.l1_warm_up_steps,
+            ),
+        }
+
+    def encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        # Process the input (including dtype conversion, hook call, and any activation normalization)
+        sae_in = self.process_sae_in(x)
+        # Compute the pre-activation (and allow for a hook if desired)
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)  # type: ignore
+        # Apply the activation function (and any post-activation hook)
+        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+        return feature_acts, hidden_pre
+
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        # The "standard" auxiliary loss is a sparsity penalty on the feature activations
+        weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
+
+        # Compute the p-norm (set by cfg.lp_norm) over the feature dimension
+        sparsity = weighted_feature_acts.norm(p=self.cfg.lp_norm, dim=-1)
+        l1_loss = (step_input.coefficients["l1"] * sparsity).mean()
+
+        return {"l1_loss": l1_loss}
+
+    def log_histograms(self) -> dict[str, NDArray[np.generic]]:
+        """Log histograms of the weights and biases."""
+        b_e_dist = self.b_enc.detach().float().cpu().numpy()
+        return {
+            **super().log_histograms(),
+            "weights/b_e": b_e_dist,
+        }
+
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), StandardSAEConfig, ["architecture"]
+        )
+
+
+def _init_weights_standard(
+    sae: SAE[StandardSAEConfig] | TrainingSAE[StandardTrainingSAEConfig],
+) -> None:
+    sae.b_enc = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
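The standard architecture above reduces to a small amount of tensor math. Below is a self-contained plain-PyTorch sketch, not the sae_lens API, of StandardSAE's encode/decode pair and the decoder-norm-weighted sparsity penalty from StandardTrainingSAE.calculate_aux_loss. The shapes, random weights, coefficient values, and the MSE reconstruction term are illustrative assumptions; the reconstruction loss itself lives in the shared TrainingSAE base class, not in this file.

import torch

torch.manual_seed(0)
d_in, d_sae, batch = 64, 256, 8
W_enc = torch.randn(d_in, d_sae) / d_in**0.5
b_enc = torch.zeros(d_sae)
W_dec = torch.randn(d_sae, d_in) / d_sae**0.5
b_dec = torch.zeros(d_in)

x = torch.randn(batch, d_in)

# encode: linear pre-activation, then ReLU (mirrors StandardSAE.encode)
feature_acts = torch.relu(x @ W_enc + b_enc)
# decode: plain linear reconstruction (mirrors StandardSAE.decode)
sae_out = feature_acts @ W_dec + b_dec

# sparsity penalty as in StandardTrainingSAE.calculate_aux_loss:
# weight each feature by its decoder row norm, then take the lp-norm over features
l1_coefficient, lp_norm = 1.0, 1.0  # illustrative values (the config defaults)
weighted = feature_acts * W_dec.norm(dim=1)
sparsity = weighted.norm(p=lp_norm, dim=-1)
l1_loss = (l1_coefficient * sparsity).mean()

# illustrative reconstruction term (not part of this file)
mse_loss = (sae_out - x).pow(2).sum(dim=-1).mean()
print(f"mse={mse_loss.item():.3f}  l1={l1_loss.item():.3f}")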
sae_lens/saes/topk_sae.py
@@ -0,0 +1,300 @@
+"""Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
+
+from dataclasses import dataclass
+from typing import Any, Callable
+
+import torch
+from jaxtyping import Float
+from torch import nn
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+)
+from sae_lens.util import filter_valid_dataclass_fields
+
+
+class TopK(nn.Module):
+    """
+    A simple TopK activation that zeroes out all but the top K elements along the last dimension,
+    then optionally applies a post-activation function (e.g., ReLU).
+    """
+
+    b_enc: nn.Parameter
+
+    def __init__(
+        self,
+        k: int,
+        postact_fn: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
+    ):
+        super().__init__()
+        self.k = k
+        self.postact_fn = postact_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        1) Select top K elements along the last dimension.
+        2) Apply post-activation (often ReLU).
+        3) Zero out all other entries.
+        """
+        topk = torch.topk(x, k=self.k, dim=-1)
+        values = self.postact_fn(topk.values)
+        result = torch.zeros_like(x)
+        result.scatter_(-1, topk.indices, values)
+        return result
+
+
+@dataclass
+class TopKSAEConfig(SAEConfig):
+    """
+    Configuration class for a TopKSAE.
+    """
+
+    k: int = 100
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "topk"
+
+
+class TopKSAE(SAE[TopKSAEConfig]):
+    """
+    An inference-only sparse autoencoder using a "topk" activation function.
+    It uses linear encoder and decoder layers, applying the TopK activation
+    to the hidden pre-activation in its encode step.
+    """
+
+    b_enc: nn.Parameter
+
+    def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
+        """
+        Args:
+            cfg: SAEConfig defining model size and behavior.
+            use_error_term: Whether to apply the error-term approach in the forward pass.
+        """
+        super().__init__(cfg, use_error_term)
+
+    @override
+    def initialize_weights(self) -> None:
+        # Initialize encoder weights and bias.
+        super().initialize_weights()
+        _init_weights_topk(self)
+
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        """
+        Converts input x into feature activations.
+        Uses topk activation under the hood.
+        """
+        sae_in = self.process_sae_in(x)
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+        # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
+        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Reconstructs the input from topk feature activations.
+        Applies optional finetuning scaling, hooking to recons, out normalization,
+        and optional head reshaping.
+        """
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+    @override
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return TopK(self.cfg.k)
+
+    @override
+    @torch.no_grad()
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError(
+            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
+        )
+
+
+@dataclass
+class TopKTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a TopKTrainingSAE.
+    """
+
+    k: int = 100
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "topk"
+
+
+class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
+    """
+    TopK variant with training functionality. Injects noise during training, optionally
+    calculates a topk-related auxiliary loss, etc.
+    """
+
+    b_enc: nn.Parameter
+
+    def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+
+    @override
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        _init_weights_topk(self)
+
+    def encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        """
+        Similar to the base training method: cast input, optionally add noise, then apply TopK.
+        """
+        sae_in = self.process_sae_in(x)
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

+        # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
+        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+        return feature_acts, hidden_pre
+
+    @override
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        # Calculate the auxiliary loss for dead neurons
+        topk_loss = self.calculate_topk_aux_loss(
+            sae_in=step_input.sae_in,
+            sae_out=sae_out,
+            hidden_pre=hidden_pre,
+            dead_neuron_mask=step_input.dead_neuron_mask,
+        )
+        return {"auxiliary_reconstruction_loss": topk_loss}
+
+    @override
+    @torch.no_grad()
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError(
+            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
+        )
+
+    @override
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return TopK(self.cfg.k)
+
+    @override
+    def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
+        return {}
+
+    def calculate_topk_aux_loss(
+        self,
+        sae_in: torch.Tensor,
+        sae_out: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        dead_neuron_mask: torch.Tensor | None,
+    ) -> torch.Tensor:
+        """
+        Calculate TopK auxiliary loss.
+
+        This auxiliary loss encourages dead neurons to learn useful features by having
+        them reconstruct the residual error from the live neurons. It's a key part of
+        preventing neuron death in TopK SAEs.
+        """
+        # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
+        # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
+        if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
+            return sae_out.new_tensor(0.0)
+        residual = (sae_in - sae_out).detach()
+
+        # Heuristic from Appendix B.1 in the paper
+        k_aux = sae_in.shape[-1] // 2
+
+        # Reduce the scale of the loss if there are a small number of dead latents
+        scale = min(num_dead / k_aux, 1.0)
+        k_aux = min(k_aux, num_dead)
+
+        auxk_acts = _calculate_topk_aux_acts(
+            k_aux=k_aux,
+            hidden_pre=hidden_pre,
+            dead_neuron_mask=dead_neuron_mask,
+        )
+
+        # Encourage the top ~50% of dead latents to predict the residual of the
+        # top k living latents
+        recons = self.decode(auxk_acts)
+        auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
+        return scale * auxk_loss
+
+    def _calculate_topk_aux_acts(
+        self,
+        k_aux: int,
+        hidden_pre: torch.Tensor,
+        dead_neuron_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Helper method to calculate activations for the auxiliary loss.
+
+        Args:
+            k_aux: Number of top dead neurons to select
+            hidden_pre: Pre-activation values from encoder
+            dead_neuron_mask: Boolean mask indicating which neurons are dead
+
+        Returns:
+            Tensor with activations for only the top-k dead neurons, zeros elsewhere
+        """
+        # Don't include living latents in this loss (set them to -inf so they won't be selected)
+        auxk_latents = torch.where(
+            dead_neuron_mask[None],
+            hidden_pre,
+            torch.tensor(-float("inf"), device=hidden_pre.device),
+        )
+
+        # Find topk values among dead neurons
+        auxk_topk = auxk_latents.topk(k_aux, dim=-1, sorted=False)
+
+        # Create a tensor of zeros, then place the topk values at their proper indices
+        auxk_acts = torch.zeros_like(hidden_pre)
+        auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
+
+        return auxk_acts
+
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), TopKSAEConfig, ["architecture"]
+        )
+
+
+def _calculate_topk_aux_acts(
+    k_aux: int,
+    hidden_pre: torch.Tensor,
+    dead_neuron_mask: torch.Tensor,
+) -> torch.Tensor:
+    # Don't include living latents in this loss
+    auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
+    # Top-k dead latents
+    auxk_topk = auxk_latents.topk(k_aux, sorted=False)
+    # Set the activations to zero for all but the top k_aux dead latents
+    auxk_acts = torch.zeros_like(hidden_pre)
+    auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
+    return auxk_acts
+
+
+def _init_weights_topk(
+    sae: SAE[TopKSAEConfig] | TrainingSAE[TopKTrainingSAEConfig],
+) -> None:
+    sae.b_enc = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
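The TopK activation above is the core of this architecture: only the k largest pre-activations per example survive, and everything else is zeroed. A self-contained plain-PyTorch sketch of the same computation follows; k, the tensor shapes, and the random input are illustrative assumptions, not package defaults (the config default for k is 100).

import torch

torch.manual_seed(0)
k = 3
x = torch.randn(2, 10)  # (batch, d_sae) pre-activations

# Equivalent to TopK.forward: keep the k largest pre-activations per row,
# pass them through the post-activation (ReLU by default), zero the rest.
topk = torch.topk(x, k=k, dim=-1)
values = torch.relu(topk.values)
result = torch.zeros_like(x)
result.scatter_(-1, topk.indices, values)

# At most k non-zero features per row (ReLU may zero some selected values).
print((result != 0).sum(dim=-1))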
sae_lens/training/activation_scaler.py
@@ -0,0 +1,53 @@
+import json
+from dataclasses import dataclass
+from statistics import mean
+
+import torch
+from tqdm.auto import tqdm
+
+from sae_lens.training.types import DataProvider
+
+
+@dataclass
+class ActivationScaler:
+    scaling_factor: float | None = None
+
+    def scale(self, acts: torch.Tensor) -> torch.Tensor:
+        return acts if self.scaling_factor is None else acts * self.scaling_factor
+
+    def unscale(self, acts: torch.Tensor) -> torch.Tensor:
+        return acts if self.scaling_factor is None else acts / self.scaling_factor
+
+    def __call__(self, acts: torch.Tensor) -> torch.Tensor:
+        return self.scale(acts)
+
+    @torch.no_grad()
+    def _calculate_mean_norm(
+        self, data_provider: DataProvider, n_batches_for_norm_estimate: int = int(1e3)
+    ) -> float:
+        norms_per_batch: list[float] = []
+        for _ in tqdm(
+            range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
+        ):
+            acts = next(data_provider)
+            norms_per_batch.append(acts.norm(dim=-1).mean().item())
+        return mean(norms_per_batch)
+
+    def estimate_scaling_factor(
+        self,
+        d_in: int,
+        data_provider: DataProvider,
+        n_batches_for_norm_estimate: int = int(1e3),
+    ):
+        mean_norm = self._calculate_mean_norm(
+            data_provider, n_batches_for_norm_estimate
+        )
+        self.scaling_factor = (d_in**0.5) / mean_norm
+
+    def save(self, file_path: str):
+        """Save the state dict to a file in JSON format."""
+        if not file_path.endswith(".json"):
+            raise ValueError("file_path must end with .json")
+
+        with open(file_path, "w") as f:
+            json.dump({"scaling_factor": self.scaling_factor}, f)
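ActivationScaler rescales activations so that their mean norm is roughly sqrt(d_in) before they reach the SAE. The following is a minimal usage sketch based only on the methods shown above; the generator standing in for a DataProvider, the batch shape, and the 7.5 norm multiplier are made-up illustrative values (estimate_scaling_factor only ever calls next() on the provider).

import torch
from sae_lens.training.activation_scaler import ActivationScaler

d_in = 512

# A stand-in data provider: any iterator yielding activation tensors works here.
def fake_activations():
    while True:
        yield torch.randn(32, d_in) * 7.5  # batches whose mean norm is far from sqrt(d_in)

scaler = ActivationScaler()
scaler.estimate_scaling_factor(
    d_in=d_in,
    data_provider=fake_activations(),
    n_batches_for_norm_estimate=10,
)

acts = torch.randn(32, d_in) * 7.5
scaled = scaler(acts)               # same as scaler.scale(acts)
print(scaled.norm(dim=-1).mean())   # roughly sqrt(d_in) after scaling
roundtrip = scaler.unscale(scaled)  # recovers the original activations
scaler.save("activation_scaler.json")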