sae-lens 5.9.1__py3-none-any.whl → 6.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/saes/standard_sae.py (new file)
@@ -0,0 +1,167 @@
+ import numpy as np
+ import torch
+ from jaxtyping import Float
+ from numpy.typing import NDArray
+ from torch import nn
+
+ from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainStepInput
+
+
+ class StandardSAE(SAE):
+     """
+     StandardSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
+     using a simple linear encoder and decoder.
+
+     It implements the required abstract methods from BaseSAE:
+     - initialize_weights: sets up simple parameter initializations for W_enc, b_enc, W_dec, and b_dec.
+     - encode: computes the feature activations from an input.
+     - decode: reconstructs the input from the feature activations.
+
+     The BaseSAE.forward() method automatically calls encode and decode,
+     including any error-term processing if configured.
+     """
+
+     b_enc: nn.Parameter
+
+     def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+         super().__init__(cfg, use_error_term)
+
+     def initialize_weights(self) -> None:
+         # Initialize encoder and decoder biases to zero.
+         self.b_enc = nn.Parameter(
+             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+         )
+         self.b_dec = nn.Parameter(
+             torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+         )
+
+         # Use Kaiming Uniform for W_enc
+         w_enc_data = torch.empty(
+             self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
+         )
+         nn.init.kaiming_uniform_(w_enc_data)
+         self.W_enc = nn.Parameter(w_enc_data)
+
+         # Use Kaiming Uniform for W_dec
+         w_dec_data = torch.empty(
+             self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+         )
+         nn.init.kaiming_uniform_(w_dec_data)
+         self.W_dec = nn.Parameter(w_dec_data)
+
+     def encode(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> Float[torch.Tensor, "... d_sae"]:
+         """
+         Encode the input tensor into the feature space.
+         For inference, no noise is added.
+         """
+         # Preprocess the SAE input (casting type, applying hooks, normalization)
+         sae_in = self.process_sae_in(x)
+         # Compute the pre-activation values
+         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+         # Apply the activation function (e.g., ReLU, tanh-relu, depending on config)
+         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+
+     def decode(
+         self, feature_acts: Float[torch.Tensor, "... d_sae"]
+     ) -> Float[torch.Tensor, "... d_in"]:
+         """
+         Decode the feature activations back to the input space.
+         If hook_z reshaping is enabled, the flattening is reversed at the end.
+         """
+         # 1) apply finetuning scaling if configured
+         scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
+         # 2) linear transform
+         sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+         # 3) hook the reconstruction
+         sae_out_pre = self.hook_sae_recons(sae_out_pre)
+         # 4) optional out-normalization (e.g. constant_norm_rescale or layer_norm)
+         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+         # 5) if hook_z is enabled, rearrange back to (..., n_heads, d_head)
+         return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+
+ class StandardTrainingSAE(TrainingSAE):
+     """
+     StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture.
+     It implements:
+     - initialize_weights: basic weight initialization for encoder/decoder.
+     - encode: inference encoding (invokes encode_with_hidden_pre).
+     - decode: a simple linear decoder.
+     - encode_with_hidden_pre: computes pre-activations, adds noise when training, and then activates.
+     - calculate_aux_loss: computes a sparsity penalty based on the (optionally scaled) p-norm of feature activations.
+     """
+
+     b_enc: nn.Parameter
+
+     def initialize_weights(self) -> None:
+         # Basic init: reuse the simple initialization from StandardSAE
+         # (called explicitly rather than through the MRO).
+         StandardSAE.initialize_weights(self)  # type: ignore
+
+         # Complex init logic from the original TrainingSAE
+         if self.cfg.decoder_orthogonal_init:
+             self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
+
+         elif self.cfg.decoder_heuristic_init:
+             self.W_dec.data = torch.rand(  # Changed from Parameter to data assignment
+                 self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+             )
+             self.initialize_decoder_norm_constant_norm()
+
+         if self.cfg.init_encoder_as_decoder_transpose:
+             self.W_enc.data = self.W_dec.data.T.clone().contiguous()  # type: ignore
+
+         if self.cfg.normalize_sae_decoder:
+             with torch.no_grad():
+                 self.set_decoder_norm_to_unit_norm()
+
+     @torch.no_grad()
+     def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+         self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)  # type: ignore
+         self.W_dec.data *= norm
+
+     def encode_with_hidden_pre(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+         # Process the input (including dtype conversion, hook call, and any activation normalization)
+         sae_in = self.process_sae_in(x)
+         # Compute the pre-activation (and allow for a hook if desired)
+         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)  # type: ignore
+         # Add noise during training for robustness (scaled by noise_scale from the configuration)
+         if self.training and self.cfg.noise_scale > 0:
+             hidden_pre_noised = (
+                 hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
+             )
+         else:
+             hidden_pre_noised = hidden_pre
+         # Apply the activation function (and any post-activation hook)
+         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
+         return feature_acts, hidden_pre_noised
+
+     def calculate_aux_loss(
+         self,
+         step_input: TrainStepInput,
+         feature_acts: torch.Tensor,
+         hidden_pre: torch.Tensor,
+         sae_out: torch.Tensor,
+     ) -> dict[str, torch.Tensor]:
+         # The "standard" auxiliary loss is a sparsity penalty on the feature activations
+         weighted_feature_acts = feature_acts
+         if self.cfg.scale_sparsity_penalty_by_decoder_norm:
+             weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
+
+         # Compute the p-norm (set by cfg.lp_norm) over the feature dimension
+         sparsity = weighted_feature_acts.norm(p=self.cfg.lp_norm, dim=-1)
+         l1_loss = (step_input.current_l1_coefficient * sparsity).mean()
+
+         return {"l1_loss": l1_loss}
+
+     def log_histograms(self) -> dict[str, NDArray[np.generic]]:
+         """Log histograms of the weights and biases."""
+         b_e_dist = self.b_enc.detach().float().cpu().numpy()
+         return {
+             **super().log_histograms(),
+             "weights/b_e": b_e_dist,
+         }
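
For orientation, here is a minimal standalone sketch of the computation the new StandardSAE/StandardTrainingSAE classes implement: an affine encoder, a pointwise nonlinearity, an affine decoder, and an L1-style sparsity penalty. This is plain PyTorch for illustration, not the sae-lens API; the sizes and the coefficient are made-up values:

    import torch

    torch.manual_seed(0)
    d_in, d_sae, batch = 16, 64, 8
    W_enc = torch.randn(d_in, d_sae) * 0.1
    b_enc = torch.zeros(d_sae)
    W_dec = torch.randn(d_sae, d_in) * 0.1
    b_dec = torch.zeros(d_in)

    x = torch.randn(batch, d_in)
    feature_acts = torch.relu(x @ W_enc + b_enc)  # encode()
    sae_out = feature_acts @ W_dec + b_dec        # decode()
    mse_loss = (sae_out - x).pow(2).sum(dim=-1).mean()

    # Sparsity penalty as in calculate_aux_loss: the p-norm of the feature
    # activations (optionally weighted by decoder row norms), times the
    # current L1 coefficient. lp_norm=1 and the coefficient are assumptions.
    l1_coefficient = 5.0
    weighted = feature_acts * W_dec.norm(dim=1)
    l1_loss = (l1_coefficient * weighted.norm(p=1, dim=-1)).mean()
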
sae_lens/saes/topk_sae.py (new file)
@@ -0,0 +1,305 @@
+ """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
+
+ from typing import Callable
+
+ import torch
+ from jaxtyping import Float
+ from torch import nn
+
+ from sae_lens.saes.sae import (
+     SAE,
+     SAEConfig,
+     TrainingSAE,
+     TrainingSAEConfig,
+     TrainStepInput,
+ )
+
+
+ class TopK(nn.Module):
+     """
+     A simple TopK activation that zeroes out all but the top K elements along the last dimension,
+     then optionally applies a post-activation function (e.g., ReLU).
+     """
+
+     b_enc: nn.Parameter
+
+     def __init__(
+         self,
+         k: int,
+         postact_fn: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
+     ):
+         super().__init__()
+         self.k = k
+         self.postact_fn = postact_fn
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         1) Select top K elements along the last dimension.
+         2) Apply post-activation (often ReLU).
+         3) Zero out all other entries.
+         """
+         topk = torch.topk(x, k=self.k, dim=-1)
+         values = self.postact_fn(topk.values)
+         result = torch.zeros_like(x)
+         result.scatter_(-1, topk.indices, values)
+         return result
+
+
+ class TopKSAE(SAE):
+     """
+     An inference-only sparse autoencoder using a "topk" activation function.
+     It uses linear encoder and decoder layers, applying the TopK activation
+     to the hidden pre-activation in its encode step.
+     """
+
+     def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+         """
+         Args:
+             cfg: SAEConfig defining model size and behavior.
+             use_error_term: Whether to apply the error-term approach in the forward pass.
+         """
+         super().__init__(cfg, use_error_term)
+
+         if self.cfg.activation_fn != "topk":
+             raise ValueError("TopKSAE must use a TopK activation function.")
+
+     def initialize_weights(self) -> None:
+         """
+         Initializes weights and biases for encoder/decoder similarly to the standard SAE,
+         that is:
+         - b_enc, b_dec are zero-initialized
+         - W_enc, W_dec are Kaiming Uniform
+         """
+         # encoder bias
+         self.b_enc = nn.Parameter(
+             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+         )
+         # decoder bias
+         self.b_dec = nn.Parameter(
+             torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+         )
+
+         # encoder weight
+         w_enc_data = torch.empty(
+             self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
+         )
+         nn.init.kaiming_uniform_(w_enc_data)
+         self.W_enc = nn.Parameter(w_enc_data)
+
+         # decoder weight
+         w_dec_data = torch.empty(
+             self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+         )
+         nn.init.kaiming_uniform_(w_dec_data)
+         self.W_dec = nn.Parameter(w_dec_data)
+
+     def encode(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> Float[torch.Tensor, "... d_sae"]:
+         """
+         Converts input x into feature activations.
+         Uses topk activation from the config (cfg.activation_fn == "topk")
+         under the hood.
+         """
+         sae_in = self.process_sae_in(x)
+         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+         # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
+         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+
+     def decode(
+         self, feature_acts: Float[torch.Tensor, "... d_sae"]
+     ) -> Float[torch.Tensor, "... d_in"]:
+         """
+         Reconstructs the input from topk feature activations.
+         Applies optional finetuning scaling, the reconstruction hook, output normalization,
+         and optional head reshaping.
+         """
+         scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
+         sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+         sae_out_pre = self.hook_sae_recons(sae_out_pre)
+         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+         return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+     def _get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+         if self.cfg.activation_fn == "topk":
+             if "k" not in self.cfg.activation_fn_kwargs:
+                 raise ValueError("TopK activation function requires a k value.")
+             k = self.cfg.activation_fn_kwargs.get(
+                 "k", 1
+             )  # k is guaranteed present by the check above
+             postact_fn = self.cfg.activation_fn_kwargs.get(
+                 "postact_fn", nn.ReLU()
+             )  # Default post-activation to ReLU if not provided
+             return TopK(k, postact_fn)
+         # Otherwise, return the "standard" handling from BaseSAE
+         return super()._get_activation_fn()
+
+
+ class TopKTrainingSAE(TrainingSAE):
+     """
+     TopK variant with training functionality. Injects noise during training, optionally
+     calculates a topk-related auxiliary loss, etc.
+     """
+
+     b_enc: nn.Parameter
+
+     def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
+         super().__init__(cfg, use_error_term)
+
+         if self.cfg.activation_fn != "topk":
+             raise ValueError("TopKTrainingSAE must use a TopK activation function.")
+
+     def initialize_weights(self) -> None:
+         """Very similar to TopKSAE: zero-initialized biases + Kaiming Uniform weights."""
+         self.b_enc = nn.Parameter(
+             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+         )
+         self.b_dec = nn.Parameter(
+             torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+         )
+
+         w_enc_data = torch.empty(
+             self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
+         )
+         nn.init.kaiming_uniform_(w_enc_data)
+         self.W_enc = nn.Parameter(w_enc_data)
+
+         w_dec_data = torch.empty(
+             self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+         )
+         nn.init.kaiming_uniform_(w_dec_data)
+         self.W_dec = nn.Parameter(w_dec_data)
+
+     def encode_with_hidden_pre(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+         """
+         Similar to the base training method: cast input, optionally add noise, then apply TopK.
+         """
+         sae_in = self.process_sae_in(x)
+         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+
+         # Inject noise if training
+         if self.training and self.cfg.noise_scale > 0:
+             hidden_pre_noised = (
+                 hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
+             )
+         else:
+             hidden_pre_noised = hidden_pre
+
+         # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
+         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
+         return feature_acts, hidden_pre_noised
+
+     def calculate_aux_loss(
+         self,
+         step_input: TrainStepInput,
+         feature_acts: torch.Tensor,
+         hidden_pre: torch.Tensor,
+         sae_out: torch.Tensor,
+     ) -> dict[str, torch.Tensor]:
+         # Calculate the auxiliary loss for dead neurons
+         topk_loss = self.calculate_topk_aux_loss(
+             sae_in=step_input.sae_in,
+             sae_out=sae_out,
+             hidden_pre=hidden_pre,
+             dead_neuron_mask=step_input.dead_neuron_mask,
+         )
+         return {"auxiliary_reconstruction_loss": topk_loss}
+
+     def _get_activation_fn(self):
+         if self.cfg.activation_fn == "topk":
+             if "k" not in self.cfg.activation_fn_kwargs:
+                 raise ValueError("TopK activation function requires a k value.")
+             k = self.cfg.activation_fn_kwargs.get("k", 1)
+             postact_fn = self.cfg.activation_fn_kwargs.get("postact_fn", nn.ReLU())
+             return TopK(k, postact_fn)
+         return super()._get_activation_fn()
+
+     def calculate_topk_aux_loss(
+         self,
+         sae_in: torch.Tensor,
+         sae_out: torch.Tensor,
+         hidden_pre: torch.Tensor,
+         dead_neuron_mask: torch.Tensor | None,
+     ) -> torch.Tensor:
+         """
+         Calculate TopK auxiliary loss.
+
+         This auxiliary loss encourages dead neurons to learn useful features by having
+         them reconstruct the residual error from the live neurons. It's a key part of
+         preventing neuron death in TopK SAEs.
+         """
+         # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
+         # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
+         if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
+             return sae_out.new_tensor(0.0)
+         residual = (sae_in - sae_out).detach()
+
+         # Heuristic from Appendix B.1 in the paper
+         k_aux = sae_in.shape[-1] // 2
+
+         # Reduce the scale of the loss if there are a small number of dead latents
+         scale = min(num_dead / k_aux, 1.0)
+         k_aux = min(k_aux, num_dead)
+
+         auxk_acts = _calculate_topk_aux_acts(
+             k_aux=k_aux,
+             hidden_pre=hidden_pre,
+             dead_neuron_mask=dead_neuron_mask,
+         )
+
+         # Encourage the top ~50% of dead latents to predict the residual of the
+         # top k living latents
+         recons = self.decode(auxk_acts)
+         auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
+         return scale * auxk_loss
+
+     def _calculate_topk_aux_acts(
+         self,
+         k_aux: int,
+         hidden_pre: torch.Tensor,
+         dead_neuron_mask: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Helper method to calculate activations for the auxiliary loss.
+
+         Args:
+             k_aux: Number of top dead neurons to select
+             hidden_pre: Pre-activation values from encoder
+             dead_neuron_mask: Boolean mask indicating which neurons are dead
+
+         Returns:
+             Tensor with activations for only the top-k dead neurons, zeros elsewhere
+         """
+         # Don't include living latents in this loss (set them to -inf so they won't be selected)
+         auxk_latents = torch.where(
+             dead_neuron_mask[None],
+             hidden_pre,
+             torch.tensor(-float("inf"), device=hidden_pre.device),
+         )
+
+         # Find topk values among dead neurons
+         auxk_topk = auxk_latents.topk(k_aux, dim=-1, sorted=False)
+
+         # Create a tensor of zeros, then place the topk values at their proper indices
+         auxk_acts = torch.zeros_like(hidden_pre)
+         auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
+
+         return auxk_acts
+
+
+ def _calculate_topk_aux_acts(
+     k_aux: int,
+     hidden_pre: torch.Tensor,
+     dead_neuron_mask: torch.Tensor,
+ ) -> torch.Tensor:
+     # Don't include living latents in this loss
+     auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
+     # Top-k dead latents
+     auxk_topk = auxk_latents.topk(k_aux, sorted=False)
+     # Set the activations to zero for all but the top k_aux dead latents
+     auxk_acts = torch.zeros_like(hidden_pre)
+     auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
+     # Return the sparse activations for the auxiliary reconstruction
+     return auxk_acts
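
To make the TopK activation concrete, a small worked example with invented inputs. Only the k largest pre-activations per row survive; everything else is zeroed, which makes the architecture k-sparse by construction. The auxiliary loss above reuses the same machinery over dead latents only, with k_aux = d_in // 2 and a scale of min(num_dead / k_aux, 1.0):

    import torch
    from torch import nn

    # Mirrors TopK.forward above: keep the k largest entries along the last
    # dimension, apply the post-activation, zero out the rest.
    x = torch.tensor([[0.5, -1.0, 2.0, 0.1, 1.5]])
    topk = torch.topk(x, k=2, dim=-1)
    values = nn.ReLU()(topk.values)
    result = torch.zeros_like(x)
    result.scatter_(-1, topk.indices, values)
    print(result)  # tensor([[0.0000, 0.0000, 2.0000, 0.0000, 1.5000]])
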
sae_lens/training/activations_store.py
@@ -28,7 +28,7 @@ from sae_lens.config import (
      HfDataset,
      LanguageModelSAERunnerConfig,
  )
- from sae_lens.sae import SAE
+ from sae_lens.saes.sae import SAE
  from sae_lens.tokenization_and_batching import concat_and_batch_sequences
 
 
@@ -177,7 +177,7 @@ class ActivationsStore:
              dataset_trust_remote_code=sae.cfg.dataset_trust_remote_code,
              dtype=sae.cfg.dtype,
              device=torch.device(device),
-             seqpos_slice=sae.cfg.seqpos_slice,
+             seqpos_slice=sae.cfg.seqpos_slice or (None,),
          )
 
      def __init__(
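
The `or (None,)` fallback appears intended to guard against a falsy `seqpos_slice` on older SAE configs: `(None,)` expands to `slice(None)` when the store builds its sequence-position slice, i.e. "keep every position". (This reading is inferred from the change; the diff itself only shows the fallback.)
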
sae_lens/training/sae_trainer.py
@@ -11,9 +11,9 @@ from transformer_lens.hook_points import HookedRootModule
  from sae_lens import __version__
  from sae_lens.config import LanguageModelSAERunnerConfig
  from sae_lens.evals import EvalConfig, run_evals
+ from sae_lens.saes.sae import TrainingSAE, TrainStepInput, TrainStepOutput
  from sae_lens.training.activations_store import ActivationsStore
  from sae_lens.training.optim import L1Scheduler, get_lr_scheduler
- from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput
 
  # used to map between parameters which are updated during finetuning and the config str.
  FINETUNING_PARAMETERS = {
@@ -186,7 +186,7 @@ class SAETrainer:
 
              step_output = self._train_step(sae=self.sae, sae_in=layer_acts)
 
-             if self.cfg.log_to_wandb:
+             if self.cfg.logger.log_to_wandb:
                  self._log_train_step(step_output)
                  self._run_and_log_evals()
 
@@ -226,7 +226,7 @@ class SAETrainer:
 
          # log and then reset the feature sparsity every feature_sampling_window steps
          if (self.n_training_steps + 1) % self.cfg.feature_sampling_window == 0:
-             if self.cfg.log_to_wandb:
+             if self.cfg.logger.log_to_wandb:
                  sparsity_log_dict = self._build_sparsity_log_dict()
                  wandb.log(sparsity_log_dict, step=self.n_training_steps)
              self._reset_running_sparsity_stats()
@@ -235,9 +235,11 @@ class SAETrainer:
          # https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
          with self.autocast_if_enabled:
              train_step_output = self.sae.training_forward_pass(
-                 sae_in=sae_in,
-                 dead_neuron_mask=self.dead_neurons,
-                 current_l1_coefficient=self.current_l1_coefficient,
+                 step_input=TrainStepInput(
+                     sae_in=sae_in,
+                     dead_neuron_mask=self.dead_neurons,
+                     current_l1_coefficient=self.current_l1_coefficient,
+                 ),
              )
 
          with torch.no_grad():
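
Callers of training_forward_pass now bundle what were three loose keyword arguments into a single TrainStepInput. A sketch of the new call shape (the wrapper function and its argument names are invented for illustration; the TrainStepInput fields come from the hunk above):

    import torch
    from sae_lens.saes.sae import TrainingSAE, TrainStepInput

    def run_train_step(sae: TrainingSAE, acts: torch.Tensor,
                       dead_mask: torch.Tensor | None, l1_coeff: float):
        # 5.9.x: sae.training_forward_pass(sae_in=..., dead_neuron_mask=...,
        #                                  current_l1_coefficient=...)
        # 6.0.0rc1: the same values travel inside one TrainStepInput.
        step_input = TrainStepInput(
            sae_in=acts,
            dead_neuron_mask=dead_mask,
            current_l1_coefficient=l1_coeff,
        )
        return sae.training_forward_pass(step_input=step_input)
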
@@ -270,7 +272,7 @@ class SAETrainer:
 
      @torch.no_grad()
      def _log_train_step(self, step_output: TrainStepOutput):
-         if (self.n_training_steps + 1) % self.cfg.wandb_log_frequency == 0:
+         if (self.n_training_steps + 1) % self.cfg.logger.wandb_log_frequency == 0:
              wandb.log(
                  self._build_train_step_log_dict(
                      output=step_output,
@@ -331,7 +333,8 @@ class SAETrainer:
      def _run_and_log_evals(self):
          # record loss frequently, but not all the time.
          if (self.n_training_steps + 1) % (
-             self.cfg.wandb_log_frequency * self.cfg.eval_every_n_wandb_logs
+             self.cfg.logger.wandb_log_frequency
+             * self.cfg.logger.eval_every_n_wandb_logs
          ) == 0:
              self.sae.eval()
              ignore_tokens = set()
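
As this and the surrounding hunks show, the wandb settings have moved from the top level of the runner config onto a nested logger object: cfg.log_to_wandb, cfg.wandb_log_frequency, and cfg.eval_every_n_wandb_logs become cfg.logger.log_to_wandb, cfg.logger.wandb_log_frequency, and cfg.logger.eval_every_n_wandb_logs.
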
@@ -358,17 +361,8 @@ class SAETrainer:
          # Remove metrics that are not useful for wandb logging
          eval_metrics.pop("metrics/total_tokens_evaluated", None)
 
-         W_dec_norm_dist = self.sae.W_dec.detach().float().norm(dim=1).cpu().numpy()
-         eval_metrics["weights/W_dec_norms"] = wandb.Histogram(W_dec_norm_dist)  # type: ignore
-
-         if self.sae.cfg.architecture == "standard":
-             b_e_dist = self.sae.b_enc.detach().float().cpu().numpy()
-             eval_metrics["weights/b_e"] = wandb.Histogram(b_e_dist)  # type: ignore
-         elif self.sae.cfg.architecture == "gated":
-             b_gate_dist = self.sae.b_gate.detach().float().cpu().numpy()
-             eval_metrics["weights/b_gate"] = wandb.Histogram(b_gate_dist)  # type: ignore
-             b_mag_dist = self.sae.b_mag.detach().float().cpu().numpy()
-             eval_metrics["weights/b_mag"] = wandb.Histogram(b_mag_dist)  # type: ignore
+         for key, value in self.sae.log_histograms().items():
+             eval_metrics[key] = wandb.Histogram(value)  # type: ignore
 
          wandb.log(
              eval_metrics,
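
Per-architecture histogram logging in the trainer is replaced by a log_histograms() hook on the SAE itself (see StandardTrainingSAE.log_histograms in the first hunk, which merges its b_enc histogram into super().log_histograms()). A hypothetical subclass showing how a custom SAE would plug into the new loop; the class name and the extra "weights/b_dec" key are invented for illustration:

    import numpy as np
    from numpy.typing import NDArray

    from sae_lens.saes.standard_sae import StandardTrainingSAE

    class MyTrainingSAE(StandardTrainingSAE):
        def log_histograms(self) -> dict[str, NDArray[np.generic]]:
            # Contribute an extra histogram on top of whatever the parent logs;
            # SAETrainer wraps each array in wandb.Histogram automatically.
            return {
                **super().log_histograms(),
                "weights/b_dec": self.b_dec.detach().float().cpu().numpy(),
            }
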
sae_lens/training/upload_saes_to_huggingface.py
@@ -14,7 +14,7 @@ from sae_lens.config import (
      SAE_WEIGHTS_FILENAME,
      SPARSITY_FILENAME,
  )
- from sae_lens.sae import SAE
+ from sae_lens.saes.sae import SAE
 
 
  def upload_saes_to_huggingface(
sae_lens-6.0.0rc1.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: sae-lens
- Version: 5.9.1
+ Version: 6.0.0rc1
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
  License: MIT
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -89,7 +89,7 @@ Please cite the package as follows:
  ```
  @misc{bloom2024saetrainingcodebase,
      title = {SAELens},
-     author = {Joseph Bloom, Curt Tigges, Anthony Duong and David Chanin},
+     author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
      year = {2024},
      howpublished = {\url{https://github.com/jbloomAus/SAELens}},
  }
sae_lens-6.0.0rc1.dist-info/RECORD (new file)
@@ -0,0 +1,32 @@
+ sae_lens/__init__.py,sha256=ofQyurU7LtxIsg89QFCZe13QsdYpxErRI0x0tiCpB04,2074
+ sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ sae_lens/analysis/hooked_sae_transformer.py,sha256=RK0mcLhymXdJInXHcagQggxW9Qf4ptePnH7sKXvGGaU,13727
+ sae_lens/analysis/neuronpedia_integration.py,sha256=dFiKRWfuT5iUfTPBPmZydSaNG3VwqZ1asuNbbQv_NCM,18488
+ sae_lens/cache_activations_runner.py,sha256=dGK5EHJMHAKDAFyr25fy1COSm-61q-q6kpWENHFMaKk,12561
+ sae_lens/config.py,sha256=SPjziXrTyOBjObSi-3s0_mza3Z7WH8gd9NT9pVUfosg,34375
+ sae_lens/evals.py,sha256=tjDKmkUM4fBbP9LHZuBLCx37ux8Px9CliTMme3Wjt1A,38898
+ sae_lens/load_model.py,sha256=tE70sXsyyyGYW7o506O3eiw1MXyyW6DCQojLG49hWYI,6771
+ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ sae_lens/loading/pretrained_sae_loaders.py,sha256=NcqyH2KDL8Dg66-hjXsBAq1-IwdLEpYfKwbkHxSQbrg,29961
+ sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
+ sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
+ sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
+ sae_lens/regsitry.py,sha256=yCse5NmVH-ZaPET3jW8r7C_py2DL3yoox40GxGzJ0TI,1098
+ sae_lens/sae_training_runner.py,sha256=VRNSAIsZLfcQMfZB8qdnK45PUXwoNvJ-rKt9BVYjMMY,8244
+ sae_lens/saes/gated_sae.py,sha256=l5ucq7AZHya6ZClWNNE7CionGSf1ms5m1Ah3IoN6SH4,9916
+ sae_lens/saes/jumprelu_sae.py,sha256=DRWgY58894cNh_sYAlefObI4rr0Eb6KHu1WuhTCcvB4,13468
+ sae_lens/saes/sae.py,sha256=fd7OEsSXbmVii6QoYI_TRti6dwaxAQyrBcKyX7PxERw,36779
+ sae_lens/saes/standard_sae.py,sha256=m2eNL_w6ave-_g7F1eQiwI4qbjMwwjzvxp96RN_WVAw,7110
+ sae_lens/saes/topk_sae.py,sha256=aBET4F55A4xMIvZ8AazPtyl3oL-9S7krKx78li0uKGk,11370
+ sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
+ sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ sae_lens/training/activations_store.py,sha256=ilJdcnZWfTDus1bdoqIb1wF_7H8_HWLmf8OCGrybmlA,35998
+ sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
+ sae_lens/training/optim.py,sha256=AImcc-MAaGDLOBP2hJ4alDFCtaqqgm4cc2eBxIxiQAo,5784
+ sae_lens/training/sae_trainer.py,sha256=6TkqbzA0fYluRM8ouI_nU9sz-FaP63axxcnDrVfw37E,16279
+ sae_lens/training/upload_saes_to_huggingface.py,sha256=tVC-2Txw7-9XttGlKzM0OSqU8CK7HDO9vIzDMqEwAYU,4366
+ sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
+ sae_lens-6.0.0rc1.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+ sae_lens-6.0.0rc1.dist-info/METADATA,sha256=wHH-VRtquu-FjZEOHdPJi3zYW3ns7MCT1fVerbPEylc,5326
+ sae_lens-6.0.0rc1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+ sae_lens-6.0.0rc1.dist-info/RECORD,,