sae-lens 6.16.0__tar.gz → 6.16.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {sae_lens-6.16.0 → sae_lens-6.16.3}/PKG-INFO +1 -1
  2. {sae_lens-6.16.0 → sae_lens-6.16.3}/pyproject.toml +2 -1
  3. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/saes/batchtopk_sae.py +3 -1
  5. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/saes/matryoshka_batchtopk_sae.py +6 -12
  6. {sae_lens-6.16.0 → sae_lens-6.16.3}/LICENSE +0 -0
  7. {sae_lens-6.16.0 → sae_lens-6.16.3}/README.md +0 -0
  8. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/analysis/__init__.py +0 -0
  9. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  10. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  11. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/cache_activations_runner.py +0 -0
  12. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/config.py +0 -0
  13. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/constants.py +0 -0
  14. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/evals.py +0 -0
  15. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/llm_sae_training_runner.py +0 -0
  16. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/load_model.py +0 -0
  17. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/loading/__init__.py +0 -0
  18. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
  19. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  20. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/pretokenize_runner.py +0 -0
  21. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/pretrained_saes.yaml +0 -0
  22. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/registry.py +0 -0
  23. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/saes/__init__.py +0 -0
  24. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/saes/gated_sae.py +0 -0
  25. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/saes/jumprelu_sae.py +0 -0
  26. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/saes/sae.py +0 -0
  27. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/saes/standard_sae.py +0 -0
  28. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/saes/topk_sae.py +0 -0
  29. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/saes/transcoder.py +0 -0
  30. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/tokenization_and_batching.py +0 -0
  31. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/training/__init__.py +0 -0
  32. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/training/activation_scaler.py +0 -0
  33. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/training/activations_store.py +0 -0
  34. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/training/mixing_buffer.py +0 -0
  35. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/training/optim.py +0 -0
  36. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/training/sae_trainer.py +0 -0
  37. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/training/types.py +0 -0
  38. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  39. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/tutorial/tsea.py +0 -0
  40. {sae_lens-6.16.0 → sae_lens-6.16.3}/sae_lens/util.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.16.0
+Version: 6.16.3
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.16.0"
+version = "6.16.3"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
@@ -55,6 +55,7 @@ ruff = "^0.7.4"
 eai-sparsify = "^1.1.1"
 mike = "^2.0.0"
 trio = "^0.30.0"
+dictionary-learning = "^0.1.0"

 [tool.poetry.extras]
 mamba = ["mamba-lens"]
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.16.0"
+__version__ = "6.16.3"

 import logging

@@ -23,7 +23,9 @@ class BatchTopK(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         acts = x.relu()
         flat_acts = acts.flatten()
-        acts_topk_flat = torch.topk(flat_acts, int(self.k * acts.shape[0]), dim=-1)
+        # Calculate total number of samples across all non-feature dimensions
+        num_samples = acts.shape[:-1].numel()
+        acts_topk_flat = torch.topk(flat_acts, int(self.k * num_samples), dim=-1)
         return (
             torch.zeros_like(flat_acts)
             .scatter(-1, acts_topk_flat.indices, acts_topk_flat.values)
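
Context for the fix above: the old budget int(self.k * acts.shape[0]) counts only the leading batch dimension, so inputs with extra non-feature dimensions (e.g. [batch, seq, d_sae]) kept far too few activations after the top-k. The new acts.shape[:-1].numel() counts every non-feature position. A minimal sketch of the difference, assuming a hypothetical 3D input:

import torch

k = 2
acts = torch.rand(4, 8, 16)  # hypothetical [batch, seq, d_sae] activations

# old budget: only the batch dimension is counted
old_budget = int(k * acts.shape[0])    # 2 * 4 = 8 activations kept in total

# fixed budget: every non-feature position contributes k activations on average
num_samples = acts.shape[:-1].numel()  # 4 * 8 = 32 positions
new_budget = int(k * num_samples)      # 2 * 32 = 64 activations kept in total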
@@ -78,14 +78,11 @@ class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
     @override
     def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
         base_output = super().training_forward_pass(step_input)
-        hidden_pre = base_output.hidden_pre
         inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
         # the outer matryoshka level is the base SAE, so we don't need to add an extra loss for it
         for width in self.cfg.matryoshka_widths[:-1]:
-            inner_hidden_pre = hidden_pre[:, :width]
-            inner_feat_acts = self.activation_fn(inner_hidden_pre)
             inner_reconstruction = self._decode_matryoshka_level(
-                inner_feat_acts, width, inv_W_dec_norm
+                base_output.feature_acts, width, inv_W_dec_norm
             )
             inner_mse_loss = (
                 self.mse_loss_fn(inner_reconstruction, step_input.sae_in)
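
Note that this refactor changes semantics, not just structure: inner levels previously re-applied the activation function to a prefix of hidden_pre, i.e. ran a fresh top-k restricted to that prefix, whereas they now reuse the feature activations the base forward pass already selected. A minimal sketch of the two behaviours, using a hypothetical per-sample top-k stand-in for the real BatchTopK activation:

import torch

def activation_fn(x: torch.Tensor, k: int = 2) -> torch.Tensor:
    # stand-in for BatchTopK: keep the k largest values per sample, zero the rest
    vals, idx = x.relu().topk(k, dim=-1)
    return torch.zeros_like(x).scatter(-1, idx, vals)

hidden_pre = torch.randn(4, 8)  # hypothetical [batch, d_sae] pre-activations
width, k = 4, 2

# before: a fresh top-k restricted to the first `width` features
inner_old = activation_fn(hidden_pre[:, :width], k)

# after: reuse the globally selected features that fall inside the prefix;
# a sample may now keep fewer than k active features at an inner level
inner_new = activation_fn(hidden_pre, k)[:, :width]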
@@ -105,16 +102,17 @@ class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
         """
         Decodes feature activations back into input space for a matryoshka level
         """
+        inner_feature_acts = feature_acts[:, :width]
         # Handle sparse tensors using efficient sparse matrix multiplication
         if self.cfg.rescale_acts_by_decoder_norm:
             # need to multiply by the inverse of the norm because division is illegal with sparse tensors
-            feature_acts = feature_acts * inv_W_dec_norm[:width]
-        if feature_acts.is_sparse:
+            inner_feature_acts = inner_feature_acts * inv_W_dec_norm[:width]
+        if inner_feature_acts.is_sparse:
             sae_out_pre = (
-                _sparse_matmul_nd(feature_acts, self.W_dec[:width]) + self.b_dec
+                _sparse_matmul_nd(inner_feature_acts, self.W_dec[:width]) + self.b_dec
             )
         else:
-            sae_out_pre = feature_acts @ self.W_dec[:width] + self.b_dec
+            sae_out_pre = inner_feature_acts @ self.W_dec[:width] + self.b_dec
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
         return self.reshape_fn_out(sae_out_pre, self.d_head)

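For reference, a self-contained sketch of the dense decode path after this change, with hypothetical shapes (the real method also handles sparse tensors, decoder-norm rescaling, and output reshaping):

import torch

def decode_level(feature_acts: torch.Tensor, W_dec: torch.Tensor,
                 b_dec: torch.Tensor, width: int) -> torch.Tensor:
    # slice once, up front, so the full feature_acts tensor is never mutated
    inner_feature_acts = feature_acts[:, :width]       # [batch, width]
    return inner_feature_acts @ W_dec[:width] + b_dec  # [batch, d_in]

feature_acts = torch.rand(4, 64)  # [batch, d_sae]
W_dec = torch.randn(64, 16)       # [d_sae, d_in]
b_dec = torch.zeros(16)
recon = decode_level(feature_acts, W_dec, b_dec, width=32)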
@@ -137,7 +135,3 @@ def _validate_matryoshka_config(cfg: MatryoshkaBatchTopKTrainingSAEConfig) -> None:
         warnings.warn(
             "WARNING: You have only set one matryoshka level. This is equivalent to using a standard BatchTopK SAE and is likely not what you want."
         )
-    if cfg.matryoshka_widths[0] < cfg.k:
-        raise ValueError(
-            "The smallest matryoshka level width cannot be smaller than cfg.k."
-        )
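
Dropping this check is consistent with the training_forward_pass change above: inner levels no longer run their own top-k over a width-sized prefix, so a level narrower than k can no longer over-request winners. A sketch of the failure mode the check presumably guarded against under the old per-level top-k:

import torch

batch, width, k = 4, 2, 3
prefix = torch.rand(batch, width).flatten()  # batch * width = 8 candidate values
# the old per-level top-k would request k * batch = 12 of those 8 values, and
# torch.topk raises a RuntimeError when asked for more elements than exist;
# after the change no per-level top-k runs, so width < k is safe
# torch.topk(prefix, k * batch)  # would fail: selected index out of range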