sae-lens 6.22.2.tar.gz → 6.22.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {sae_lens-6.22.2 → sae_lens-6.22.3}/PKG-INFO +1 -1
  2. {sae_lens-6.22.2 → sae_lens-6.22.3}/pyproject.toml +1 -1
  3. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/gated_sae.py +2 -2
  5. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/jumprelu_sae.py +4 -4
  6. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/sae.py +1 -1
  7. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/topk_sae.py +1 -1
  8. {sae_lens-6.22.2 → sae_lens-6.22.3}/LICENSE +0 -0
  9. {sae_lens-6.22.2 → sae_lens-6.22.3}/README.md +0 -0
  10. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/analysis/__init__.py +0 -0
  11. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  12. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  13. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/cache_activations_runner.py +0 -0
  14. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/config.py +0 -0
  15. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/constants.py +0 -0
  16. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/evals.py +0 -0
  17. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/llm_sae_training_runner.py +0 -0
  18. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/load_model.py +0 -0
  19. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/loading/__init__.py +0 -0
  20. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
  21. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  22. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/pretokenize_runner.py +0 -0
  23. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/pretrained_saes.yaml +0 -0
  24. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/registry.py +0 -0
  25. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/__init__.py +0 -0
  26. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/batchtopk_sae.py +0 -0
  27. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
  28. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/standard_sae.py +0 -0
  29. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/temporal_sae.py +0 -0
  30. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/transcoder.py +0 -0
  31. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/tokenization_and_batching.py +0 -0
  32. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/__init__.py +0 -0
  33. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/activation_scaler.py +0 -0
  34. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/activations_store.py +0 -0
  35. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/mixing_buffer.py +0 -0
  36. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/optim.py +0 -0
  37. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/sae_trainer.py +0 -0
  38. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/types.py +0 -0
  39. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  40. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/tutorial/tsea.py +0 -0
  41. {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/util.py +0 -0
--- sae_lens-6.22.2/PKG-INFO
+++ sae_lens-6.22.3/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.22.2
+Version: 6.22.3
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
--- sae_lens-6.22.2/pyproject.toml
+++ sae_lens-6.22.3/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.22.2"
+version = "6.22.3"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
--- sae_lens-6.22.2/sae_lens/__init__.py
+++ sae_lens-6.22.3/sae_lens/__init__.py
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.22.2"
+__version__ = "6.22.3"
 
 import logging
 
--- sae_lens-6.22.2/sae_lens/saes/gated_sae.py
+++ sae_lens-6.22.3/sae_lens/saes/gated_sae.py
@@ -89,7 +89,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Override to handle gated-specific parameters."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
@@ -217,7 +217,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Override to handle gated-specific parameters."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
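Every substantive hunk in this release makes the same one-line change: decoder-row norms are clamped to a small positive floor before being divided out. The sketch below (plain PyTorch, not the sae-lens API itself) shows why the floor matters: a "dead" latent whose decoder row is all zeros has norm 0, so dividing by the raw norms fills that row with NaN, which then propagates into the encoder via the `* W_dec_norms.T` step.

import torch

d_sae, d_in = 4, 8
W_dec = torch.randn(d_sae, d_in)
W_dec[2] = 0.0  # a dead latent: zero decoder row, so norm is 0

raw_norms = W_dec.norm(dim=-1).unsqueeze(1)
print((W_dec / raw_norms).isnan().any())  # tensor(True): 0/0 -> NaN

safe_norms = W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
print((W_dec / safe_norms).isnan().any())  # tensor(False): dead row stays zero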
--- sae_lens-6.22.2/sae_lens/saes/jumprelu_sae.py
+++ sae_lens-6.22.3/sae_lens/saes/jumprelu_sae.py
@@ -167,8 +167,8 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
         # Save the current threshold before calling parent method
         current_thresh = self.threshold.clone()
 
-        # Get W_dec norms that will be used for scaling
-        W_dec_norms = self.W_dec.norm(dim=-1)
+        # Get W_dec norms that will be used for scaling (clamped to avoid division by zero)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8)
 
         # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
         super().fold_W_dec_norm()
@@ -325,8 +325,8 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         # Save the current threshold before we call the parent method
         current_thresh = self.threshold.clone()
 
-        # Get W_dec norms
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        # Get W_dec norms (clamped to avoid division by zero)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
 
         # Call parent implementation to handle W_enc and W_dec adjustment
         super().fold_W_dec_norm()
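The JumpReLU variants also save `self.threshold` before the parent fold; the code following these hunks (not shown) presumably rescales it by the same norms. The invariant that makes this correct is elementary: for strictly positive per-latent scales, scaling activations and thresholds together leaves the gating decision unchanged, and the clamp is exactly what keeps the scales strictly positive. A minimal check of that invariant (illustrative values, not sae-lens code):

import torch

acts = torch.tensor([0.5, 1.5, 0.2])
thresh = torch.tensor([1.0, 1.0, 1.0])
n = torch.tensor([2.0, 0.5, 3.0])  # stand-in for the clamped W_dec norms

# Scaling both sides by positive n preserves which latents fire.
assert torch.equal(acts > thresh, (acts * n) > (thresh * n))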
--- sae_lens-6.22.2/sae_lens/saes/sae.py
+++ sae_lens-6.22.3/sae_lens/saes/sae.py
@@ -484,7 +484,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
    @torch.no_grad()
    def fold_W_dec_norm(self):
        """Fold decoder norms into encoder."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
        self.W_dec.data = self.W_dec.data / W_dec_norms
        self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
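This base-class change is the one every SAE variant inherits. For intuition, here is a simplified sketch (a toy ReLU autoencoder, not the sae-lens SAE class, which also handles biases and metadata) of why the fold is output-preserving: ReLU is positively homogeneous per latent, so multiplying encoder columns and the encoder bias by the norms while dividing decoder rows by them cancels exactly.

import torch

d_in, d_sae = 8, 16
x = torch.randn(3, d_in)
W_enc, b_enc = torch.randn(d_in, d_sae), torch.randn(d_sae)
W_dec = torch.randn(d_sae, d_in)

out = torch.relu(x @ W_enc + b_enc) @ W_dec

# Fold: scale encoder columns and bias up by n, scale decoder rows down by n.
n = W_dec.norm(dim=-1).clamp(min=1e-8)
out_folded = torch.relu(x @ (W_enc * n) + b_enc * n) @ (W_dec / n.unsqueeze(1))

assert torch.allclose(out, out_folded, atol=1e-5)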
--- sae_lens-6.22.2/sae_lens/saes/topk_sae.py
+++ sae_lens-6.22.3/sae_lens/saes/topk_sae.py
@@ -531,7 +531,7 @@ def _fold_norm_topk(
     b_enc: torch.Tensor,
     W_dec: torch.Tensor,
 ) -> None:
-    W_dec_norm = W_dec.norm(dim=-1)
+    W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8)
     b_enc.data = b_enc.data * W_dec_norm
     W_dec_norms = W_dec_norm.unsqueeze(1)
     W_dec.data = W_dec.data / W_dec_norms
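For the TopK path the fold operates on raw tensors, and the encoder bias is scaled up by the same norms the decoder rows are divided by. A hypothetical standalone function mirroring the shape of the hunk above (names are illustrative, not the private sae-lens helper) shows the clamp keeping a dead latent finite:

import torch

def fold_norm_topk_sketch(b_enc: torch.Tensor, W_dec: torch.Tensor) -> None:
    # Clamp guards dead latents whose decoder rows have zero norm.
    W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8)
    b_enc.data = b_enc.data * W_dec_norm
    W_dec.data = W_dec.data / W_dec_norm.unsqueeze(1)

b_enc, W_dec = torch.randn(16), torch.randn(16, 8)
W_dec[0] = 0.0  # dead latent: with the clamp, its row stays zero, not inf
fold_norm_topk_sketch(b_enc, W_dec)
assert W_dec.isfinite().all()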