sae-lens 6.22.2.tar.gz → 6.23.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {sae_lens-6.22.2 → sae_lens-6.23.0}/PKG-INFO +1 -1
  2. {sae_lens-6.22.2 → sae_lens-6.23.0}/pyproject.toml +1 -1
  3. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/loading/pretrained_sae_loaders.py +20 -5
  5. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/pretrained_saes.yaml +78 -0
  6. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/gated_sae.py +2 -2
  7. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/jumprelu_sae.py +4 -4
  8. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/sae.py +1 -1
  9. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/topk_sae.py +1 -1
  10. {sae_lens-6.22.2 → sae_lens-6.23.0}/LICENSE +0 -0
  11. {sae_lens-6.22.2 → sae_lens-6.23.0}/README.md +0 -0
  12. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/analysis/__init__.py +0 -0
  13. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  14. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  15. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/cache_activations_runner.py +0 -0
  16. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/config.py +0 -0
  17. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/constants.py +0 -0
  18. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/evals.py +0 -0
  19. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/llm_sae_training_runner.py +0 -0
  20. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/load_model.py +0 -0
  21. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/loading/__init__.py +0 -0
  22. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  23. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/pretokenize_runner.py +0 -0
  24. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/registry.py +0 -0
  25. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/__init__.py +0 -0
  26. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  27. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
  28. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/standard_sae.py +0 -0
  29. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/temporal_sae.py +0 -0
  30. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/transcoder.py +0 -0
  31. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/tokenization_and_batching.py +0 -0
  32. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/__init__.py +0 -0
  33. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/activation_scaler.py +0 -0
  34. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/activations_store.py +0 -0
  35. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/mixing_buffer.py +0 -0
  36. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/optim.py +0 -0
  37. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/sae_trainer.py +0 -0
  38. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/types.py +0 -0
  39. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  40. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/tutorial/tsea.py +0 -0
  41. {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/util.py +0 -0

{sae_lens-6.22.2 → sae_lens-6.23.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.22.2
+Version: 6.23.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE

{sae_lens-6.22.2 → sae_lens-6.23.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.22.2"
+version = "6.23.0"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"

{sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/__init__.py
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.22.2"
+__version__ = "6.23.0"
 
 import logging
 

{sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/loading/pretrained_sae_loaders.py
@@ -753,10 +753,14 @@ def get_dictionary_learning_config_1_from_hf(
     activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
     activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}
 
+    architecture = "standard"
+    if trainer["dict_class"] == "GatedAutoEncoder":
+        architecture = "gated"
+    elif trainer["dict_class"] == "MatryoshkaBatchTopKSAE":
+        architecture = "jumprelu"
+
     return {
-        "architecture": (
-            "gated" if trainer["dict_class"] == "GatedAutoEncoder" else "standard"
-        ),
+        "architecture": architecture,
         "d_in": trainer["activation_dim"],
         "d_sae": trainer["dict_size"],
         "dtype": "float32",
@@ -905,9 +909,12 @@ def dictionary_learning_sae_huggingface_loader_1(
     )
     encoder = torch.load(encoder_path, map_location="cpu")
 
+    W_enc = encoder["W_enc"] if "W_enc" in encoder else encoder["encoder.weight"].T
+    W_dec = encoder["W_dec"] if "W_dec" in encoder else encoder["decoder.weight"].T
+
     state_dict = {
-        "W_enc": encoder["encoder.weight"].T,
-        "W_dec": encoder["decoder.weight"].T,
+        "W_enc": W_enc,
+        "W_dec": W_dec,
         "b_dec": encoder.get(
             "b_dec", encoder.get("bias", encoder.get("decoder_bias", None))
         ),
@@ -915,6 +922,8 @@ def dictionary_learning_sae_huggingface_loader_1(
 
     if "encoder.bias" in encoder:
         state_dict["b_enc"] = encoder["encoder.bias"]
+    if "b_enc" in encoder:
+        state_dict["b_enc"] = encoder["b_enc"]
 
     if "mag_bias" in encoder:
         state_dict["b_mag"] = encoder["mag_bias"]
@@ -923,6 +932,12 @@ def dictionary_learning_sae_huggingface_loader_1(
     if "r_mag" in encoder:
         state_dict["r_mag"] = encoder["r_mag"]
 
+    if "threshold" in encoder:
+        threshold = encoder["threshold"]
+        if threshold.ndim == 0:
+            threshold = torch.full((W_enc.size(1),), threshold)
+        state_dict["threshold"] = threshold
+
     return cfg_dict, state_dict, None
 
 
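Taken together, the loader changes above let dictionary_learning_sae_huggingface_loader_1 accept checkpoints that store weights directly under W_enc / W_dec (and b_enc), in addition to the older encoder.weight / decoder.weight layout, and they broadcast a scalar JumpReLU-style threshold to one value per latent. A small standalone sketch of that threshold expansion, with assumed shapes (d_in=512, d_sae=4096) and not the library's code:

import torch

# assumed checkpoint contents: W_enc of shape (d_in, d_sae) plus a scalar threshold
encoder = {"W_enc": torch.randn(512, 4096), "threshold": torch.tensor(0.03)}

threshold = encoder["threshold"]
if threshold.ndim == 0:
    # broadcast the single learned threshold to the (d_sae,) shape
    # that JumpReLU SAEs expect (one threshold per latent)
    threshold = torch.full((encoder["W_enc"].size(1),), threshold)

print(threshold.shape)  # torch.Size([4096])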

{sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/pretrained_saes.yaml
@@ -14959,3 +14959,81 @@ goodfire-llama-3.1-8b-instruct:
     path: Llama-3.1-8B-Instruct-SAE-l19.pth
     l0: 91
     neuronpedia: llama3.1-8b-it/19-resid-post-gf
+
+saebench_gemma-2-2b_width-2pow12_date-0108:
+  conversion_func: dictionary_learning_1
+  links:
+    model: https://huggingface.co/google/gemma-2-2b
+  model: gemma-2-2b
+  repo_id: adamkarvonen/saebench_gemma-2-2b_width-2pow12_date-0108
+  saes:
+  - id: blocks.12.hook_resid_post__trainer_0
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-4k__trainer_0_step_final
+    path: MatryoshkaBatchTopK_gemma-2-2b__0108/resid_post_layer_12/trainer_0
+  - id: blocks.12.hook_resid_post__trainer_1
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-4k__trainer_1_step_final
+    path: MatryoshkaBatchTopK_gemma-2-2b__0108/resid_post_layer_12/trainer_1
+  - id: blocks.12.hook_resid_post__trainer_2
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-4k__trainer_2_step_final
+    path: MatryoshkaBatchTopK_gemma-2-2b__0108/resid_post_layer_12/trainer_2
+  - id: blocks.12.hook_resid_post__trainer_3
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-4k__trainer_3_step_final
+    path: MatryoshkaBatchTopK_gemma-2-2b__0108/resid_post_layer_12/trainer_3
+  - id: blocks.12.hook_resid_post__trainer_4
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-4k__trainer_4_step_final
+    path: MatryoshkaBatchTopK_gemma-2-2b__0108/resid_post_layer_12/trainer_4
+  - id: blocks.12.hook_resid_post__trainer_5
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-4k__trainer_5_step_final
+    path: MatryoshkaBatchTopK_gemma-2-2b__0108/resid_post_layer_12/trainer_5
+
+saebench_gemma-2-2b_width-2pow14_date-0107:
+  conversion_func: dictionary_learning_1
+  links:
+    model: https://huggingface.co/google/gemma-2-2b
+  model: gemma-2-2b
+  repo_id: canrager/saebench_gemma-2-2b_width-2pow14_date-0107
+  saes:
+  - id: blocks.12.hook_resid_post__trainer_0
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-16k__trainer_0_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow14_date-0107/resid_post_layer_12/trainer_0
+  - id: blocks.12.hook_resid_post__trainer_1
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-16k__trainer_1_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow14_date-0107/resid_post_layer_12/trainer_1
+  - id: blocks.12.hook_resid_post__trainer_2
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-16k__trainer_2_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow14_date-0107/resid_post_layer_12/trainer_2
+  - id: blocks.12.hook_resid_post__trainer_3
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-16k__trainer_3_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow14_date-0107/resid_post_layer_12/trainer_3
+  - id: blocks.12.hook_resid_post__trainer_4
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-16k__trainer_4_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow14_date-0107/resid_post_layer_12/trainer_4
+  - id: blocks.12.hook_resid_post__trainer_5
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-16k__trainer_5_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow14_date-0107/resid_post_layer_12/trainer_5
+
+saebench_gemma-2-2b_width-2pow16_date-0107:
+  conversion_func: dictionary_learning_1
+  links:
+    model: https://huggingface.co/google/gemma-2-2b
+  model: gemma-2-2b
+  repo_id: canrager/saebench_gemma-2-2b_width-2pow16_date-0107
+  saes:
+  - id: blocks.12.hook_resid_post__trainer_0
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-65k__trainer_0_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow16_date-0107/resid_post_layer_12/trainer_0
+  - id: blocks.12.hook_resid_post__trainer_1
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-65k__trainer_1_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow16_date-0107/resid_post_layer_12/trainer_1
+  - id: blocks.12.hook_resid_post__trainer_2
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-65k__trainer_2_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow16_date-0107/resid_post_layer_12/trainer_2
+  - id: blocks.12.hook_resid_post__trainer_3
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-65k__trainer_3_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow16_date-0107/resid_post_layer_12/trainer_3
+  - id: blocks.12.hook_resid_post__trainer_4
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-65k__trainer_4_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow16_date-0107/resid_post_layer_12/trainer_4
+  - id: blocks.12.hook_resid_post__trainer_5
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-65k__trainer_5_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow16_date-0107/resid_post_layer_12/trainer_5
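
The three new pretrained_saes.yaml entries above register the SAEBench matryoshka batch-top-k suites for gemma-2-2b at widths 4k, 16k, and 65k (layer 12 residual stream, six trainers each). A hedged usage sketch, assuming the usual SAE.from_pretrained(release, sae_id) entry point, which in recent 6.x releases returns the SAE instance directly:

from sae_lens import SAE

# release and sae_id strings are taken verbatim from the YAML entries above
sae = SAE.from_pretrained(
    release="saebench_gemma-2-2b_width-2pow12_date-0108",
    sae_id="blocks.12.hook_resid_post__trainer_0",
    device="cpu",
)
print(sae.cfg.d_sae)  # expected 4096 for the width-2pow12 release
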
{sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/gated_sae.py
@@ -89,7 +89,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Override to handle gated-specific parameters."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
@@ -217,7 +217,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Override to handle gated-specific parameters."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 

{sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/jumprelu_sae.py
@@ -167,8 +167,8 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
         # Save the current threshold before calling parent method
         current_thresh = self.threshold.clone()
 
-        # Get W_dec norms that will be used for scaling
-        W_dec_norms = self.W_dec.norm(dim=-1)
+        # Get W_dec norms that will be used for scaling (clamped to avoid division by zero)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8)
 
         # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
         super().fold_W_dec_norm()
@@ -325,8 +325,8 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         # Save the current threshold before we call the parent method
         current_thresh = self.threshold.clone()
 
-        # Get W_dec norms
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        # Get W_dec norms (clamped to avoid division by zero)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
 
         # Call parent implementation to handle W_enc and W_dec adjustment
         super().fold_W_dec_norm()

{sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/sae.py
@@ -484,7 +484,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Fold decoder norms into encoder."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 

{sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/topk_sae.py
@@ -531,7 +531,7 @@ def _fold_norm_topk(
     b_enc: torch.Tensor,
     W_dec: torch.Tensor,
 ) -> None:
-    W_dec_norm = W_dec.norm(dim=-1)
+    W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8)
     b_enc.data = b_enc.data * W_dec_norm
     W_dec_norms = W_dec_norm.unsqueeze(1)
    W_dec.data = W_dec.data / W_dec_norms
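
All of the fold_W_dec_norm changes in this release are the same one-line fix: clamp the decoder row norms at 1e-8 before dividing, so a latent whose decoder row is exactly zero no longer turns the folded weights into NaN/Inf. A standalone illustration of the failure mode (not library code):

import torch

W_dec = torch.zeros(3, 8)      # degenerate decoder: rows 1 and 2 have zero norm
W_dec[0] = torch.randn(8)      # only latent 0 is non-degenerate

norms_old = W_dec.norm(dim=-1).unsqueeze(1)                   # contains zeros
norms_new = W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)   # zeros become 1e-8

print(torch.isnan(W_dec / norms_old).any())  # tensor(True): 0/0 gives NaN
print(torch.isnan(W_dec / norms_new).any())  # tensor(False): 0/1e-8 gives 0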