sae-lens 5.10.6__py3-none-any.whl → 5.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +1 -1
- sae_lens/config.py +1 -0
- sae_lens/load_model.py +1 -1
- sae_lens/toolkit/pretrained_sae_loaders.py +143 -0
- {sae_lens-5.10.6.dist-info → sae_lens-5.11.0.dist-info}/METADATA +2 -2
- {sae_lens-5.10.6.dist-info → sae_lens-5.11.0.dist-info}/RECORD +8 -8
- {sae_lens-5.10.6.dist-info → sae_lens-5.11.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.10.6.dist-info → sae_lens-5.11.0.dist-info}/WHEEL +0 -0
sae_lens/__init__.py
CHANGED
sae_lens/config.py
CHANGED

@@ -33,6 +33,7 @@ HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
 
 SPARSITY_FILENAME = "sparsity.safetensors"
 SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
+SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
 SAE_CFG_FILENAME = "cfg.json"
 
 
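The new constant sits alongside the existing filename constants: Sparsify-format checkpoints store their weights as sae.safetensors rather than SAE Lens's native sae_weights.safetensors, and the loaders added further down use it to locate those files. A quick sanity check of the two conventions, assuming nothing beyond this diff:

```python
from sae_lens.config import SAE_WEIGHTS_FILENAME, SPARSIFY_WEIGHTS_FILENAME

assert SAE_WEIGHTS_FILENAME == "sae_weights.safetensors"  # native SAE Lens format
assert SPARSIFY_WEIGHTS_FILENAME == "sae.safetensors"     # Sparsify checkpoint format
```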
sae_lens/load_model.py
CHANGED

@@ -159,7 +159,7 @@ class HookedProxyLM(HookedRootModule):
 
         # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
         if hasattr(self.tokenizer, "add_bos_token") and self.tokenizer.add_bos_token:  # type: ignore
-            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)
+            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)  # type: ignore
         return tokens  # type: ignore
 
 
sae_lens/toolkit/pretrained_sae_loaders.py
CHANGED

@@ -15,6 +15,7 @@ from sae_lens.config import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
+    SPARSIFY_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
 )
 from sae_lens.toolkit.pretrained_saes_directory import (
@@ -898,6 +899,146 @@ def llama_scope_r1_distill_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity
 
 
+def get_sparsify_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    cfg_filename = f"{folder_name}/{SAE_CFG_FILENAME}"
+    cfg_path = hf_hub_download(
+        repo_id,
+        filename=cfg_filename,
+        force_download=force_download,
+    )
+    sae_path = Path(cfg_path).parent
+    return get_sparsify_config_from_disk(
+        sae_path, device=device, cfg_overrides=cfg_overrides
+    )
+
+
+def get_sparsify_config_from_disk(
+    path: str | Path,
+    device: str | None = None,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    path = Path(path)
+
+    with open(path / SAE_CFG_FILENAME) as f:
+        old_cfg_dict = json.load(f)
+
+    config_path = path.parent / "config.json"
+    if config_path.exists():
+        with open(config_path) as f:
+            config_dict = json.load(f)
+    else:
+        config_dict = {}
+
+    folder_name = path.name
+    if folder_name == "embed_tokens":
+        hook_name, layer = "hook_embed", 0
+    else:
+        match = re.search(r"layers[._](\d+)", folder_name)
+        if match is None:
+            raise ValueError(f"Unrecognized Sparsify folder: {folder_name}")
+        layer = int(match.group(1))
+        hook_name = f"blocks.{layer}.hook_resid_post"
+
+    cfg_dict: dict[str, Any] = {
+        "architecture": "standard",
+        "d_in": old_cfg_dict["d_in"],
+        "d_sae": old_cfg_dict["d_in"] * old_cfg_dict["expansion_factor"],
+        "dtype": "bfloat16",
+        "device": device or "cpu",
+        "model_name": config_dict.get("model", path.parts[-2]),
+        "hook_name": hook_name,
+        "hook_layer": layer,
+        "hook_head_index": None,
+        "activation_fn_str": "topk",
+        "activation_fn_kwargs": {
+            "k": old_cfg_dict["k"],
+            "signed": old_cfg_dict.get("signed", False),
+        },
+        "apply_b_dec_to_input": not old_cfg_dict.get("normalize_decoder", False),
+        "dataset_path": config_dict.get(
+            "dataset", "togethercomputer/RedPajama-Data-1T-Sample"
+        ),
+        "context_size": config_dict.get("ctx_len", 2048),
+        "finetuning_scaling_factor": False,
+        "sae_lens_training_version": None,
+        "prepend_bos": True,
+        "dataset_trust_remote_code": True,
+        "normalize_activations": "none",
+        "neuronpedia_id": None,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def sparsify_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], None]:
+    weights_filename = f"{folder_name}/{SPARSIFY_WEIGHTS_FILENAME}"
+    sae_path = hf_hub_download(
+        repo_id,
+        filename=weights_filename,
+        force_download=force_download,
+    )
+    cfg_dict, state_dict = sparsify_disk_loader(
+        Path(sae_path).parent, device=device, cfg_overrides=cfg_overrides
+    )
+    return cfg_dict, state_dict, None
+
+
+def sparsify_disk_loader(
+    path: str | Path,
+    device: str = "cpu",
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
+    cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)
+
+    weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
+    state_dict_loaded = load_file(weight_path, device=device)
+
+    dtype = DTYPE_MAP[cfg_dict["dtype"]]
+
+    W_enc = (
+        state_dict_loaded["W_enc"]
+        if "W_enc" in state_dict_loaded
+        else state_dict_loaded["encoder.weight"].T
+    ).to(dtype)
+
+    if "W_dec" in state_dict_loaded:
+        W_dec = state_dict_loaded["W_dec"].T.to(dtype)
+    else:
+        W_dec = state_dict_loaded["decoder.weight"].T.to(dtype)
+
+    if "b_enc" in state_dict_loaded:
+        b_enc = state_dict_loaded["b_enc"].to(dtype)
+    elif "encoder.bias" in state_dict_loaded:
+        b_enc = state_dict_loaded["encoder.bias"].to(dtype)
+    else:
+        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+
+    if "b_dec" in state_dict_loaded:
+        b_dec = state_dict_loaded["b_dec"].to(dtype)
+    elif "decoder.bias" in state_dict_loaded:
+        b_dec = state_dict_loaded["decoder.bias"].to(dtype)
+    else:
+        b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+
+    state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
+    return cfg_dict, state_dict
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
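The config getter above derives the TransformerLens hook point from the Sparsify folder name and fills the rest of the sae_lens config with defaults. A minimal sketch of that derivation on a local folder — the paths and cfg.json values below are illustrative, not taken from any real checkpoint:

```python
import json
from pathlib import Path

from sae_lens.toolkit.pretrained_sae_loaders import get_sparsify_config_from_disk

# Hypothetical local layout: <model>/<layers.N>/cfg.json, as the loader expects.
sae_dir = Path("/tmp/sparsify_demo/pythia-70m/layers.6")
sae_dir.mkdir(parents=True, exist_ok=True)
(sae_dir / "cfg.json").write_text(
    json.dumps({"d_in": 512, "expansion_factor": 32, "k": 32})
)

cfg = get_sparsify_config_from_disk(sae_dir, device="cpu")
assert cfg["hook_name"] == "blocks.6.hook_resid_post"  # from r"layers[._](\d+)"
assert cfg["d_sae"] == 512 * 32                        # d_in * expansion_factor
assert cfg["model_name"] == "pythia-70m"               # parent folder; no config.json present
```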
@@ -906,6 +1047,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
     "deepseek_r1": deepseek_r1_sae_huggingface_loader,
+    "sparsify": sparsify_huggingface_loader,
 }
 
 
@@ -917,4 +1059,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
     "deepseek_r1": get_deepseek_r1_config_from_hf,
+    "sparsify": get_sparsify_config_from_hf,
 }
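With the loader and config getter registered under the "sparsify" key, Sparsify SAEs flow through the same machinery as the other pretrained formats. A hedged sketch of calling the loader directly — the repo id is a placeholder, not taken from this diff:

```python
from sae_lens.toolkit.pretrained_sae_loaders import sparsify_huggingface_loader

# Downloads <folder_name>/sae.safetensors from the (hypothetical) repo, then
# normalizes the weights: encoder.weight/decoder.weight are used when W_enc/W_dec
# are absent, and missing biases become zeros of the right shape.
cfg_dict, state_dict, log_sparsity = sparsify_huggingface_loader(
    repo_id="your-org/your-sparsify-saes",  # placeholder repo id
    folder_name="layers.6",
    device="cpu",
)
print(cfg_dict["hook_name"])  # "blocks.6.hook_resid_post"
print(sorted(state_dict))     # ['W_dec', 'W_enc', 'b_dec', 'b_enc']
print(log_sparsity)           # None: Sparsify checkpoints ship no sparsity log
```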
{sae_lens-5.10.6.dist-info → sae_lens-5.11.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: sae-lens
-Version: 5.10.6
+Version: 5.11.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch

@@ -80,7 +80,7 @@ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page
 
 ## Join the Slack!
 
-Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-
+Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
 
 ## Citation
 
{sae_lens-5.10.6.dist-info → sae_lens-5.11.0.dist-info}/RECORD
CHANGED

@@ -1,18 +1,18 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=A2ttZHoobEQm6YKaCqWrztd6LIDGmXlOvuyfp1aGb_E,1307
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=fkSsvWcTM_d7M3rRM6N2oFpeSvGhj_ENZtqmfWOzZTQ,13717
 sae_lens/analysis/neuronpedia_integration.py,sha256=dFiKRWfuT5iUfTPBPmZydSaNG3VwqZ1asuNbbQv_NCM,18488
 sae_lens/cache_activations_runner.py,sha256=dGK5EHJMHAKDAFyr25fy1COSm-61q-q6kpWENHFMaKk,12561
-sae_lens/config.py,sha256=
+sae_lens/config.py,sha256=y3dgA_lNSpwi_n442dtrQ6RxfFKbnvUKjb7Qe1ZNoA4,33034
 sae_lens/evals.py,sha256=7cuLlT0ZTAhZ7eQbsZEFT-M3oixmaXSCBJtjh9hGnVQ,38527
-sae_lens/load_model.py,sha256=
+sae_lens/load_model.py,sha256=TRxyUpudPCwGzSccQiHxww9OtLiwBBRurvi-HUnfdKg,6787
 sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
 sae_lens/pretrained_saes.yaml,sha256=nhHW1auhyi4GHYrjUnHQqbNVhI5cMJv-HThzbzU1xG0,574145
 sae_lens/sae.py,sha256=8DFVKG72Ml_hVm49YIHJ0zAS6Pbd7O_7wDkQV5kyhxk,27965
 sae_lens/sae_training_runner.py,sha256=tduPN8BGtMatua0bNY-tXGGGxhedMu6F_O9ugDOdRmQ,9004
 sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
 sae_lens/toolkit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/toolkit/pretrained_sae_loaders.py,sha256=
+sae_lens/toolkit/pretrained_sae_loaders.py,sha256=vzDdDjy7EnNpmOcEwwXyd0AzzX4Up2Gdhs8wlogww8M,34840
 sae_lens/toolkit/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activations_store.py,sha256=1H9sb1mAVrlHU5WVci1FFcCVNj1BydezI4zu7_Usm3s,35985

@@ -22,7 +22,7 @@ sae_lens/training/sae_trainer.py,sha256=xenSV0xw06y1_qLhw82_966DmWOp2nydqlrVgJA6
 sae_lens/training/training_sae.py,sha256=0A4x74qUfinLhwaK9RSoWZ7POrGc8kIU5EgBOp4UJtE,27998
 sae_lens/training/upload_saes_to_huggingface.py,sha256=P1K3nxv-IM7JptfLHj5Agiis7A_adn-g_tiq1d8PdaU,4361
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
-sae_lens-5.
-sae_lens-5.
-sae_lens-5.
-sae_lens-5.
+sae_lens-5.11.0.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-5.11.0.dist-info/METADATA,sha256=LPXNsAtjYMtDveFW2aBFrLCjWmmnN7hiFt3WIDOxpBU,5324
+sae_lens-5.11.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+sae_lens-5.11.0.dist-info/RECORD,,
{sae_lens-5.10.6.dist-info → sae_lens-5.11.0.dist-info}/LICENSE
File without changes

{sae_lens-5.10.6.dist-info → sae_lens-5.11.0.dist-info}/WHEEL
File without changes