sae-lens 6.13.1.tar.gz → 6.14.1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {sae_lens-6.13.1 → sae_lens-6.14.1}/PKG-INFO +1 -1
  2. {sae_lens-6.13.1 → sae_lens-6.14.1}/pyproject.toml +1 -1
  3. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/batchtopk_sae.py +29 -0
  5. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/topk_sae.py +76 -8
  6. {sae_lens-6.13.1 → sae_lens-6.14.1}/LICENSE +0 -0
  7. {sae_lens-6.13.1 → sae_lens-6.14.1}/README.md +0 -0
  8. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/analysis/__init__.py +0 -0
  9. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  10. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  11. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/cache_activations_runner.py +0 -0
  12. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/config.py +0 -0
  13. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/constants.py +0 -0
  14. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/evals.py +0 -0
  15. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/llm_sae_training_runner.py +0 -0
  16. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/load_model.py +0 -0
  17. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/loading/__init__.py +0 -0
  18. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
  19. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  20. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/pretokenize_runner.py +0 -0
  21. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/pretrained_saes.yaml +0 -0
  22. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/registry.py +0 -0
  23. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/__init__.py +0 -0
  24. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/gated_sae.py +0 -0
  25. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/jumprelu_sae.py +0 -0
  26. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/sae.py +0 -0
  27. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/standard_sae.py +0 -0
  28. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/transcoder.py +0 -0
  29. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/tokenization_and_batching.py +0 -0
  30. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/__init__.py +0 -0
  31. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/activation_scaler.py +0 -0
  32. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/activations_store.py +0 -0
  33. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/mixing_buffer.py +0 -0
  34. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/optim.py +0 -0
  35. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/sae_trainer.py +0 -0
  36. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/types.py +0 -0
  37. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  38. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/tutorial/tsea.py +0 -0
  39. {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/util.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.13.1
+Version: 6.14.1
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.13.1"
+version = "6.14.1"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
sae_lens/__init__.py
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.13.1"
+__version__ = "6.14.1"
 
 import logging
 
sae_lens/saes/batchtopk_sae.py
@@ -35,6 +35,35 @@ class BatchTopK(nn.Module):
 class BatchTopKTrainingSAEConfig(TopKTrainingSAEConfig):
     """
     Configuration class for training a BatchTopKTrainingSAE.
+
+    BatchTopK SAEs maintain k active features on average across the entire batch,
+    rather than enforcing k features per sample like standard TopK SAEs. During training,
+    the SAE learns a global threshold that is updated based on the minimum positive
+    activation value. After training, BatchTopK SAEs are saved as JumpReLU SAEs.
+
+    Args:
+        k (float): Average number of features to keep active across the batch. Unlike
+            standard TopK SAEs where k is an integer per sample, this is a float
+            representing the average number of active features across all samples in
+            the batch. Defaults to 100.
+        topk_threshold_lr (float): Learning rate for updating the global topk threshold.
+            The threshold is updated using an exponential moving average of the minimum
+            positive activation value. Defaults to 0.01.
+        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
+            dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
+            Defaults to 1.0.
+        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
+            Inherited from TopKTrainingSAEConfig. Defaults to True.
+        decoder_init_norm (float | None): Norm to initialize decoder weights to.
+            Inherited from TrainingSAEConfig. Defaults to 0.1.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
     """
 
     k: float = 100  # type: ignore[assignment]
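For orientation, a brief usage sketch of the config documented above. The field names come from the new docstring; the constructor call itself and the example values (d_in, d_sae, k, and so on) are illustrative assumptions, not canonical sae-lens usage.

# Hedged sketch: constructing the BatchTopK training config described in the docstring above.
from sae_lens.saes.batchtopk_sae import BatchTopKTrainingSAEConfig

cfg = BatchTopKTrainingSAEConfig(
    d_in=768,                # dimensionality of the hooked activations (assumed value)
    d_sae=768 * 16,          # number of SAE latents (assumed value)
    k=64.0,                  # average active features per sample, enforced across the batch
    topk_threshold_lr=0.01,  # EMA learning rate for the global topk threshold
)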
sae_lens/saes/topk_sae.py
@@ -1,7 +1,7 @@
 """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
 
 from dataclasses import dataclass
-from typing import Callable
+from typing import Any, Callable
 
 import torch
 from jaxtyping import Float
@@ -118,10 +118,36 @@ class TopK(nn.Module):
 @dataclass
 class TopKSAEConfig(SAEConfig):
     """
-    Configuration class for a TopKSAE.
+    Configuration class for TopKSAE inference.
+
+    Args:
+        k (int): Number of top features to keep active during inference. Only the top k
+            features with the highest pre-activations will be non-zero. Defaults to 100.
+        rescale_acts_by_decoder_norm (bool): Whether to treat the decoder as if it was
+            already normalized. This affects the topk selection by rescaling pre-activations
+            by decoder norms. Requires that the SAE was trained this way. Defaults to False.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
+            before encoding. Inherited from SAEConfig. Defaults to True.
+        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
+            Normalization strategy for input activations. Inherited from SAEConfig.
+            Defaults to "none".
+        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
+            (useful for attention head outputs). Inherited from SAEConfig.
+            Defaults to "none".
+        metadata (SAEMetadata): Metadata about the SAE (model name, hook name, etc.).
+            Inherited from SAEConfig.
     """
 
     k: int = 100
+    rescale_acts_by_decoder_norm: bool = False
 
     @override
     @classmethod
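A minimal sketch of the new inference-side flag, assuming the SAE can be constructed directly from its config (that construction path and the example dimensions are assumptions; the flag semantics and the fold_W_dec_norm behaviour are taken from the diff):

# Hedged sketch: with rescale_acts_by_decoder_norm=True, fold_W_dec_norm() is now permitted.
from sae_lens.saes.topk_sae import TopKSAE, TopKSAEConfig

cfg = TopKSAEConfig(
    d_in=768,                           # assumed value
    d_sae=768 * 16,                     # assumed value
    k=100,
    rescale_acts_by_decoder_norm=True,  # topk is selected on decoder-norm-rescaled pre-activations
)
sae = TopKSAE(cfg)
sae.fold_W_dec_norm()  # folds decoder norms into W_enc/b_enc; still raises when the flag is False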
@@ -218,6 +244,8 @@ class TopKSAE(SAE[TopKSAEConfig]):
         """
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+        if self.cfg.rescale_acts_by_decoder_norm:
+            hidden_pre = hidden_pre * self.W_dec.norm(dim=-1)
         # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
 
@@ -231,6 +259,8 @@ class TopKSAE(SAE[TopKSAEConfig]):
         and optional head reshaping.
         """
         # Handle sparse tensors using efficient sparse matrix multiplication
+        if self.cfg.rescale_acts_by_decoder_norm:
+            feature_acts = feature_acts / self.W_dec.norm(dim=-1)
         if feature_acts.is_sparse:
             sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
         else:
@@ -246,9 +276,11 @@ class TopKSAE(SAE[TopKSAEConfig]):
     @override
     @torch.no_grad()
     def fold_W_dec_norm(self) -> None:
-        raise NotImplementedError(
-            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
-        )
+        if not self.cfg.rescale_acts_by_decoder_norm:
+            raise NotImplementedError(
+                "Folding W_dec_norm is not safe for TopKSAEs when rescale_acts_by_decoder_norm is False, as this may change the topk activations"
+            )
+        _fold_norm_topk(W_dec=self.W_dec, b_enc=self.b_enc, W_enc=self.W_enc)
 
 
 @dataclass
@@ -267,6 +299,9 @@ class TopKTrainingSAEConfig(TrainingSAEConfig):
            dead neurons to learn useful features. This loss helps prevent neuron death
            in TopK SAEs by having dead neurons reconstruct the residual error from
            live neurons. Defaults to 1.0.
+        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
+            This is a good idea since decoder norm can randomly drift during training, and this
+            affects what the topk activations will be. Defaults to True.
         decoder_init_norm (float | None): Norm to initialize decoder weights to.
             0.1 corresponds to the "heuristic" initialization from Anthropic's April update.
             Use None to disable. Inherited from TrainingSAEConfig. Defaults to 0.1.
@@ -293,6 +328,7 @@ class TopKTrainingSAEConfig(TrainingSAEConfig):
     k: int = 100
     use_sparse_activations: bool = False
     aux_loss_coefficient: float = 1.0
+    rescale_acts_by_decoder_norm: bool = True
 
     @override
     @classmethod
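Note that the training config enables the rescaling by default in 6.14.1. A short sketch (constructor usage assumed) of opting out, which keeps the 6.13.x behaviour of selecting topk on the raw pre-activations:

# Hedged sketch: rescale_acts_by_decoder_norm defaults to True for training configs.
from sae_lens.saes.topk_sae import TopKTrainingSAEConfig

cfg = TopKTrainingSAEConfig(
    d_in=768,                            # assumed value
    d_sae=768 * 16,                      # assumed value
    k=100,
    rescale_acts_by_decoder_norm=False,  # opt out to keep the pre-6.14 topk selection
)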
@@ -326,6 +362,9 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
 
+        if self.cfg.rescale_acts_by_decoder_norm:
+            hidden_pre = hidden_pre * self.W_dec.norm(dim=-1)
+
         # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
         return feature_acts, hidden_pre
@@ -340,6 +379,9 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         applying optional finetuning scale, hooking, out normalization, etc.
         """
         # Handle sparse tensors using efficient sparse matrix multiplication
+        if self.cfg.rescale_acts_by_decoder_norm:
+            # need to multiply by the inverse of the norm because division is illegal with sparse tensors
+            feature_acts = feature_acts * (1 / self.W_dec.norm(dim=-1))
         if feature_acts.is_sparse:
             sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
         else:
@@ -385,9 +427,11 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     @override
     @torch.no_grad()
     def fold_W_dec_norm(self) -> None:
-        raise NotImplementedError(
-            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
-        )
+        if not self.cfg.rescale_acts_by_decoder_norm:
+            raise NotImplementedError(
+                "Folding W_dec_norm is not safe for TopKSAEs when rescale_acts_by_decoder_norm is False, as this may change the topk activations"
+            )
+        _fold_norm_topk(W_dec=self.W_dec, b_enc=self.b_enc, W_enc=self.W_enc)
 
     @override
     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
@@ -436,6 +480,18 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
         return self.cfg.aux_loss_coefficient * scale * auxk_loss
 
+    @override
+    def process_state_dict_for_saving_inference(
+        self, state_dict: dict[str, Any]
+    ) -> None:
+        super().process_state_dict_for_saving_inference(state_dict)
+        if self.cfg.rescale_acts_by_decoder_norm:
+            _fold_norm_topk(
+                W_enc=state_dict["W_enc"],
+                b_enc=state_dict["b_enc"],
+                W_dec=state_dict["W_dec"],
+            )
+
 
 
 def _calculate_topk_aux_acts(
@@ -471,3 +527,15 @@ def _init_weights_topk(
     sae.b_enc = nn.Parameter(
         torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
     )
+
+
+def _fold_norm_topk(
+    W_enc: torch.Tensor,
+    b_enc: torch.Tensor,
+    W_dec: torch.Tensor,
+) -> None:
+    W_dec_norm = W_dec.norm(dim=-1)
+    b_enc.data = b_enc.data * W_dec_norm
+    W_dec_norms = W_dec_norm.unsqueeze(1)
+    W_dec.data = W_dec.data / W_dec_norms
+    W_enc.data = W_enc.data * W_dec_norms.T
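To make the invariant behind fold_W_dec_norm concrete, here is a small self-contained check (plain torch, hypothetical shapes, no sae_lens imports) that folding decoder norms into the encoder, as _fold_norm_topk does, leaves a decoder-norm-rescaled TopK reconstruction unchanged:

# Hedged sketch: verifying the norm-folding identity with standalone tensors.
import torch

d_in, d_sae, k = 16, 64, 8
x = torch.randn(4, d_in)
W_enc = torch.randn(d_in, d_sae)
b_enc = torch.randn(d_sae)
W_dec = torch.randn(d_sae, d_in)

def topk_acts(pre: torch.Tensor, k: int) -> torch.Tensor:
    # Keep only the k largest pre-activations per row, zero out the rest.
    vals, idx = pre.topk(k, dim=-1)
    return torch.zeros_like(pre).scatter_(-1, idx, vals)

# Before folding: rescale pre-activations by decoder norms, undo the scaling in decode.
norm = W_dec.norm(dim=-1)
acts = topk_acts((x @ W_enc + b_enc) * norm, k)
out_before = (acts / norm) @ W_dec

# After folding (mirrors _fold_norm_topk): scale W_enc/b_enc up, normalize W_dec rows.
W_enc_f = W_enc * norm
b_enc_f = b_enc * norm
W_dec_f = W_dec / norm.unsqueeze(1)
acts_f = topk_acts(x @ W_enc_f + b_enc_f, k)
out_after = acts_f @ W_dec_f

assert torch.allclose(out_before, out_after, atol=1e-5)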