sae-lens 6.22.2.tar.gz → 6.23.0.tar.gz
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- {sae_lens-6.22.2 → sae_lens-6.23.0}/PKG-INFO +1 -1
- {sae_lens-6.22.2 → sae_lens-6.23.0}/pyproject.toml +1 -1
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/__init__.py +1 -1
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/loading/pretrained_sae_loaders.py +20 -5
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/pretrained_saes.yaml +78 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/gated_sae.py +2 -2
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/jumprelu_sae.py +4 -4
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/sae.py +1 -1
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/topk_sae.py +1 -1
- {sae_lens-6.22.2 → sae_lens-6.23.0}/LICENSE +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/README.md +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/config.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/temporal_sae.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.23.0}/sae_lens/util.py +0 -0
```diff
--- sae_lens-6.22.2/sae_lens/loading/pretrained_sae_loaders.py
+++ sae_lens-6.23.0/sae_lens/loading/pretrained_sae_loaders.py
@@ -753,10 +753,14 @@ def get_dictionary_learning_config_1_from_hf(
     activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
     activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}
 
+    architecture = "standard"
+    if trainer["dict_class"] == "GatedAutoEncoder":
+        architecture = "gated"
+    elif trainer["dict_class"] == "MatryoshkaBatchTopKSAE":
+        architecture = "jumprelu"
+
     return {
-        "architecture": (
-            "gated" if trainer["dict_class"] == "GatedAutoEncoder" else "standard"
-        ),
+        "architecture": architecture,
         "d_in": trainer["activation_dim"],
         "d_sae": trainer["dict_size"],
         "dtype": "float32",
@@ -905,9 +909,12 @@ def dictionary_learning_sae_huggingface_loader_1(
     )
     encoder = torch.load(encoder_path, map_location="cpu")
 
+    W_enc = encoder["W_enc"] if "W_enc" in encoder else encoder["encoder.weight"].T
+    W_dec = encoder["W_dec"] if "W_dec" in encoder else encoder["decoder.weight"].T
+
     state_dict = {
-        "W_enc":
-        "W_dec":
+        "W_enc": W_enc,
+        "W_dec": W_dec,
         "b_dec": encoder.get(
             "b_dec", encoder.get("bias", encoder.get("decoder_bias", None))
         ),
@@ -915,6 +922,8 @@ def dictionary_learning_sae_huggingface_loader_1(
 
     if "encoder.bias" in encoder:
         state_dict["b_enc"] = encoder["encoder.bias"]
+    if "b_enc" in encoder:
+        state_dict["b_enc"] = encoder["b_enc"]
 
     if "mag_bias" in encoder:
         state_dict["b_mag"] = encoder["mag_bias"]
@@ -923,6 +932,12 @@ def dictionary_learning_sae_huggingface_loader_1(
     if "r_mag" in encoder:
         state_dict["r_mag"] = encoder["r_mag"]
 
+    if "threshold" in encoder:
+        threshold = encoder["threshold"]
+        if threshold.ndim == 0:
+            threshold = torch.full((W_enc.size(1),), threshold)
+        state_dict["threshold"] = threshold
+
     return cfg_dict, state_dict, None
 
 
```
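The loader changes above do three things: map the dictionary_learning `MatryoshkaBatchTopKSAE` class to the `jumprelu` architecture, accept checkpoints that store weights under `W_enc`/`W_dec` as well as the older `encoder.weight`/`decoder.weight` layout, and broadcast a scalar JumpReLU `threshold` to one value per latent. The sketch below is not library code; it re-applies the same fallback and broadcast logic to a synthetic checkpoint dict, so the shapes and values are illustrative assumptions.

```python
import torch

# Toy stand-in for a checkpoint that stores weights under "W_enc"/"W_dec"
# directly and a scalar JumpReLU threshold. Shapes (d_in=4, d_sae=8) are
# made up for illustration, not taken from any real release.
encoder = {
    "W_enc": torch.randn(4, 8),
    "W_dec": torch.randn(8, 4),
    "threshold": torch.tensor(0.05),
}

# Key fallback mirroring the diff: prefer "W_enc"/"W_dec", otherwise transpose
# the older "encoder.weight"/"decoder.weight" layout.
W_enc = encoder["W_enc"] if "W_enc" in encoder else encoder["encoder.weight"].T
W_dec = encoder["W_dec"] if "W_dec" in encoder else encoder["decoder.weight"].T

state_dict = {"W_enc": W_enc, "W_dec": W_dec}

# A scalar threshold is expanded to one value per latent, since the JumpReLU
# SAE expects a threshold of shape (d_sae,). As in the diff, the 0-dim tensor
# is passed directly to torch.full as the fill value.
if "threshold" in encoder:
    threshold = encoder["threshold"]
    if threshold.ndim == 0:
        threshold = torch.full((W_enc.size(1),), threshold)
    state_dict["threshold"] = threshold

print(state_dict["threshold"].shape)  # torch.Size([8])
```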
```diff
--- sae_lens-6.22.2/sae_lens/pretrained_saes.yaml
+++ sae_lens-6.23.0/sae_lens/pretrained_saes.yaml
@@ -14959,3 +14959,81 @@ goodfire-llama-3.1-8b-instruct:
     path: Llama-3.1-8B-Instruct-SAE-l19.pth
     l0: 91
     neuronpedia: llama3.1-8b-it/19-resid-post-gf
+
+saebench_gemma-2-2b_width-2pow12_date-0108:
+  conversion_func: dictionary_learning_1
+  links:
+    model: https://huggingface.co/google/gemma-2-2b
+  model: gemma-2-2b
+  repo_id: adamkarvonen/saebench_gemma-2-2b_width-2pow12_date-0108
+  saes:
+  - id: blocks.12.hook_resid_post__trainer_0
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-4k__trainer_0_step_final
+    path: MatryoshkaBatchTopK_gemma-2-2b__0108/resid_post_layer_12/trainer_0
+  - id: blocks.12.hook_resid_post__trainer_1
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-4k__trainer_1_step_final
+    path: MatryoshkaBatchTopK_gemma-2-2b__0108/resid_post_layer_12/trainer_1
+  - id: blocks.12.hook_resid_post__trainer_2
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-4k__trainer_2_step_final
+    path: MatryoshkaBatchTopK_gemma-2-2b__0108/resid_post_layer_12/trainer_2
+  - id: blocks.12.hook_resid_post__trainer_3
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-4k__trainer_3_step_final
+    path: MatryoshkaBatchTopK_gemma-2-2b__0108/resid_post_layer_12/trainer_3
+  - id: blocks.12.hook_resid_post__trainer_4
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-4k__trainer_4_step_final
+    path: MatryoshkaBatchTopK_gemma-2-2b__0108/resid_post_layer_12/trainer_4
+  - id: blocks.12.hook_resid_post__trainer_5
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-4k__trainer_5_step_final
+    path: MatryoshkaBatchTopK_gemma-2-2b__0108/resid_post_layer_12/trainer_5
+
+saebench_gemma-2-2b_width-2pow14_date-0107:
+  conversion_func: dictionary_learning_1
+  links:
+    model: https://huggingface.co/google/gemma-2-2b
+  model: gemma-2-2b
+  repo_id: canrager/saebench_gemma-2-2b_width-2pow14_date-0107
+  saes:
+  - id: blocks.12.hook_resid_post__trainer_0
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-16k__trainer_0_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow14_date-0107/resid_post_layer_12/trainer_0
+  - id: blocks.12.hook_resid_post__trainer_1
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-16k__trainer_1_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow14_date-0107/resid_post_layer_12/trainer_1
+  - id: blocks.12.hook_resid_post__trainer_2
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-16k__trainer_2_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow14_date-0107/resid_post_layer_12/trainer_2
+  - id: blocks.12.hook_resid_post__trainer_3
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-16k__trainer_3_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow14_date-0107/resid_post_layer_12/trainer_3
+  - id: blocks.12.hook_resid_post__trainer_4
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-16k__trainer_4_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow14_date-0107/resid_post_layer_12/trainer_4
+  - id: blocks.12.hook_resid_post__trainer_5
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-16k__trainer_5_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow14_date-0107/resid_post_layer_12/trainer_5
+
+saebench_gemma-2-2b_width-2pow16_date-0107:
+  conversion_func: dictionary_learning_1
+  links:
+    model: https://huggingface.co/google/gemma-2-2b
+  model: gemma-2-2b
+  repo_id: canrager/saebench_gemma-2-2b_width-2pow16_date-0107
+  saes:
+  - id: blocks.12.hook_resid_post__trainer_0
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-65k__trainer_0_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow16_date-0107/resid_post_layer_12/trainer_0
+  - id: blocks.12.hook_resid_post__trainer_1
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-65k__trainer_1_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow16_date-0107/resid_post_layer_12/trainer_1
+  - id: blocks.12.hook_resid_post__trainer_2
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-65k__trainer_2_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow16_date-0107/resid_post_layer_12/trainer_2
+  - id: blocks.12.hook_resid_post__trainer_3
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-65k__trainer_3_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow16_date-0107/resid_post_layer_12/trainer_3
+  - id: blocks.12.hook_resid_post__trainer_4
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-65k__trainer_4_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow16_date-0107/resid_post_layer_12/trainer_4
+  - id: blocks.12.hook_resid_post__trainer_5
+    neuronpedia: gemma-2-2b/12-sae_bench-matryoshka-res-65k__trainer_5_step_final
+    path: gemma-2-2b_matryoshka_batch_top_k_width-2pow16_date-0107/resid_post_layer_12/trainer_5
```
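Each new YAML entry registers a release name and a set of SAE ids that resolve to files in the listed Hugging Face repos, converted through `dictionary_learning_1`. A rough usage sketch follows; the release and id strings come from the entries above, while the `device` keyword and the v6-style return value of `SAE.from_pretrained` (returning the SAE instance directly) are assumptions about the current API.

```python
from sae_lens import SAE

# Load one of the newly registered SAEBench Matryoshka BatchTopK SAEs by the
# release name and SAE id added to pretrained_saes.yaml above.
sae = SAE.from_pretrained(
    release="saebench_gemma-2-2b_width-2pow14_date-0107",
    sae_id="blocks.12.hook_resid_post__trainer_0",
    device="cpu",  # assumed keyword; switch to "cuda" if available
)

# Per the loader mapping above, MatryoshkaBatchTopKSAE checkpoints are loaded
# as JumpReLU-style SAEs.
print(type(sae).__name__, sae.cfg.d_in, sae.cfg.d_sae)
```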
```diff
--- sae_lens-6.22.2/sae_lens/saes/gated_sae.py
+++ sae_lens-6.23.0/sae_lens/saes/gated_sae.py
@@ -89,7 +89,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Override to handle gated-specific parameters."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
@@ -217,7 +217,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Override to handle gated-specific parameters."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
```
```diff
--- sae_lens-6.22.2/sae_lens/saes/jumprelu_sae.py
+++ sae_lens-6.23.0/sae_lens/saes/jumprelu_sae.py
@@ -167,8 +167,8 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
         # Save the current threshold before calling parent method
         current_thresh = self.threshold.clone()
 
-        # Get W_dec norms that will be used for scaling
-        W_dec_norms = self.W_dec.norm(dim=-1)
+        # Get W_dec norms that will be used for scaling (clamped to avoid division by zero)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8)
 
         # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
         super().fold_W_dec_norm()
@@ -325,8 +325,8 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         # Save the current threshold before we call the parent method
         current_thresh = self.threshold.clone()
 
-        # Get W_dec norms
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        # Get W_dec norms (clamped to avoid division by zero)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
 
         # Call parent implementation to handle W_enc and W_dec adjustment
         super().fold_W_dec_norm()
```
```diff
--- sae_lens-6.22.2/sae_lens/saes/sae.py
+++ sae_lens-6.23.0/sae_lens/saes/sae.py
@@ -484,7 +484,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Fold decoder norms into encoder."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
```
```diff
--- sae_lens-6.22.2/sae_lens/saes/topk_sae.py
+++ sae_lens-6.23.0/sae_lens/saes/topk_sae.py
@@ -531,7 +531,7 @@ def _fold_norm_topk(
     b_enc: torch.Tensor,
     W_dec: torch.Tensor,
 ) -> None:
-    W_dec_norm = W_dec.norm(dim=-1)
+    W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8)
     b_enc.data = b_enc.data * W_dec_norm
     W_dec_norms = W_dec_norm.unsqueeze(1)
     W_dec.data = W_dec.data / W_dec_norms
```
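The remaining changes are the same one-line fix applied to every `fold_W_dec_norm` implementation (gated, JumpReLU, the base `SAE`, and the top-k folding helper): clamp the decoder row norms at `1e-8` before dividing, so a dead latent whose decoder row is all zeros no longer turns the folded weights into NaNs. A standalone toy illustration of the failure mode, not library code:

```python
import torch

# Two latents with a 2-dimensional decoder; the second latent is "dead"
# (its decoder row is all zeros, so its norm is 0).
W_dec = torch.tensor([[3.0, 4.0], [0.0, 0.0]])

unclamped = W_dec / W_dec.norm(dim=-1).unsqueeze(1)
clamped = W_dec / W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)

print(unclamped)  # second row is nan (0 / 0)
print(clamped)    # second row stays 0.0 (0 / 1e-8)
```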
|