sae-lens 5.10.7__py3-none-any.whl → 5.11.0__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sae_lens/__init__.py +1 -1
- sae_lens/config.py +1 -0
- sae_lens/load_model.py +1 -1
- sae_lens/toolkit/pretrained_sae_loaders.py +143 -0
- {sae_lens-5.10.7.dist-info → sae_lens-5.11.0.dist-info}/METADATA +1 -1
- {sae_lens-5.10.7.dist-info → sae_lens-5.11.0.dist-info}/RECORD +8 -8
- {sae_lens-5.10.7.dist-info → sae_lens-5.11.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.10.7.dist-info → sae_lens-5.11.0.dist-info}/WHEEL +0 -0
sae_lens/__init__.py
CHANGED
sae_lens/config.py
CHANGED
```diff
@@ -33,6 +33,7 @@ HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
 
 SPARSITY_FILENAME = "sparsity.safetensors"
 SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
+SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
 SAE_CFG_FILENAME = "cfg.json"
 
 
```
sae_lens/load_model.py
CHANGED
```diff
@@ -159,7 +159,7 @@ class HookedProxyLM(HookedRootModule):
 
         # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
         if hasattr(self.tokenizer, "add_bos_token") and self.tokenizer.add_bos_token:  # type: ignore
-            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)
+            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)  # type: ignore
         return tokens  # type: ignore
 
 
```
sae_lens/toolkit/pretrained_sae_loaders.py
CHANGED

```diff
@@ -15,6 +15,7 @@ from sae_lens.config import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
+    SPARSIFY_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
 )
 from sae_lens.toolkit.pretrained_saes_directory import (
```
```diff
@@ -898,6 +899,146 @@ def llama_scope_r1_distill_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity
 
 
+def get_sparsify_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    cfg_filename = f"{folder_name}/{SAE_CFG_FILENAME}"
+    cfg_path = hf_hub_download(
+        repo_id,
+        filename=cfg_filename,
+        force_download=force_download,
+    )
+    sae_path = Path(cfg_path).parent
+    return get_sparsify_config_from_disk(
+        sae_path, device=device, cfg_overrides=cfg_overrides
+    )
+
+
+def get_sparsify_config_from_disk(
+    path: str | Path,
+    device: str | None = None,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    path = Path(path)
+
+    with open(path / SAE_CFG_FILENAME) as f:
+        old_cfg_dict = json.load(f)
+
+    config_path = path.parent / "config.json"
+    if config_path.exists():
+        with open(config_path) as f:
+            config_dict = json.load(f)
+    else:
+        config_dict = {}
+
+    folder_name = path.name
+    if folder_name == "embed_tokens":
+        hook_name, layer = "hook_embed", 0
+    else:
+        match = re.search(r"layers[._](\d+)", folder_name)
+        if match is None:
+            raise ValueError(f"Unrecognized Sparsify folder: {folder_name}")
+        layer = int(match.group(1))
+        hook_name = f"blocks.{layer}.hook_resid_post"
+
+    cfg_dict: dict[str, Any] = {
+        "architecture": "standard",
+        "d_in": old_cfg_dict["d_in"],
+        "d_sae": old_cfg_dict["d_in"] * old_cfg_dict["expansion_factor"],
+        "dtype": "bfloat16",
+        "device": device or "cpu",
+        "model_name": config_dict.get("model", path.parts[-2]),
+        "hook_name": hook_name,
+        "hook_layer": layer,
+        "hook_head_index": None,
+        "activation_fn_str": "topk",
+        "activation_fn_kwargs": {
+            "k": old_cfg_dict["k"],
+            "signed": old_cfg_dict.get("signed", False),
+        },
+        "apply_b_dec_to_input": not old_cfg_dict.get("normalize_decoder", False),
+        "dataset_path": config_dict.get(
+            "dataset", "togethercomputer/RedPajama-Data-1T-Sample"
+        ),
+        "context_size": config_dict.get("ctx_len", 2048),
+        "finetuning_scaling_factor": False,
+        "sae_lens_training_version": None,
+        "prepend_bos": True,
+        "dataset_trust_remote_code": True,
+        "normalize_activations": "none",
+        "neuronpedia_id": None,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def sparsify_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], None]:
+    weights_filename = f"{folder_name}/{SPARSIFY_WEIGHTS_FILENAME}"
+    sae_path = hf_hub_download(
+        repo_id,
+        filename=weights_filename,
+        force_download=force_download,
+    )
+    cfg_dict, state_dict = sparsify_disk_loader(
+        Path(sae_path).parent, device=device, cfg_overrides=cfg_overrides
+    )
+    return cfg_dict, state_dict, None
+
+
+def sparsify_disk_loader(
+    path: str | Path,
+    device: str = "cpu",
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
+    cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)
+
+    weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
+    state_dict_loaded = load_file(weight_path, device=device)
+
+    dtype = DTYPE_MAP[cfg_dict["dtype"]]
+
+    W_enc = (
+        state_dict_loaded["W_enc"]
+        if "W_enc" in state_dict_loaded
+        else state_dict_loaded["encoder.weight"].T
+    ).to(dtype)
+
+    if "W_dec" in state_dict_loaded:
+        W_dec = state_dict_loaded["W_dec"].T.to(dtype)
+    else:
+        W_dec = state_dict_loaded["decoder.weight"].T.to(dtype)
+
+    if "b_enc" in state_dict_loaded:
+        b_enc = state_dict_loaded["b_enc"].to(dtype)
+    elif "encoder.bias" in state_dict_loaded:
+        b_enc = state_dict_loaded["encoder.bias"].to(dtype)
+    else:
+        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+
+    if "b_dec" in state_dict_loaded:
+        b_dec = state_dict_loaded["b_dec"].to(dtype)
+    elif "decoder.bias" in state_dict_loaded:
+        b_dec = state_dict_loaded["decoder.bias"].to(dtype)
+    else:
+        b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+
+    state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
+    return cfg_dict, state_dict
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
```
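The disk loader above accepts both native sae_lens tensor names (`W_enc`, `b_enc`, ...) and Sparsify's `encoder.weight`/`decoder.weight` layout, falling back to zero biases when none are stored. A minimal sketch of that key/shape translation, using random tensors instead of a real checkpoint; the shapes assume Sparsify saves standard `nn.Linear`-style weights (i.e. `encoder.weight` is `(d_sae, d_in)`), which is an assumption here, while the tensor names come from the diff above:

```python
# Sketch of the key/shape translation performed by sparsify_disk_loader.
import torch

d_in, d_sae = 768, 768 * 32
# A Sparsify-style checkpoint: Linear(d_in, d_sae).weight has shape (d_sae, d_in).
sparsify_ckpt = {
    "encoder.weight": torch.randn(d_sae, d_in),
    "encoder.bias": torch.zeros(d_sae),
}

# sae_lens expects W_enc as (d_in, d_sae), hence the transpose in the loader.
W_enc = sparsify_ckpt["encoder.weight"].T
b_enc = sparsify_ckpt["encoder.bias"]
assert W_enc.shape == (d_in, d_sae)
assert b_enc.shape == (d_sae,)
```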
```diff
@@ -906,6 +1047,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
     "deepseek_r1": deepseek_r1_sae_huggingface_loader,
+    "sparsify": sparsify_huggingface_loader,
 }
 
 
```
```diff
@@ -917,4 +1059,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
     "deepseek_r1": get_deepseek_r1_config_from_hf,
+    "sparsify": get_sparsify_config_from_hf,
 }
```
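With the loader and config getter both registered under the `"sparsify"` key, SAEs in the Sparsify checkpoint layout can be pulled straight from the Hugging Face Hub. A minimal usage sketch of the new loader; the `repo_id` and `folder_name` values below are illustrative placeholders, not real directory entries, and assume a repo with a per-hook subfolder (e.g. `layers.<n>`) containing `cfg.json` and `sae.safetensors`, with a model-level `config.json` one level up:

```python
# Hypothetical direct use of the new Sparsify loader added in this release.
from sae_lens.toolkit.pretrained_sae_loaders import sparsify_huggingface_loader

cfg_dict, state_dict, _ = sparsify_huggingface_loader(
    repo_id="someuser/sparsify-pythia-160m",  # placeholder repo id
    folder_name="layers.6",                   # placeholder per-hook subfolder
    device="cpu",
)

print(cfg_dict["hook_name"])      # -> blocks.6.hook_resid_post
print(state_dict["W_enc"].shape)  # -> (d_in, d_sae)
```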
{sae_lens-5.10.7.dist-info → sae_lens-5.11.0.dist-info}/RECORD
CHANGED

```diff
@@ -1,18 +1,18 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=A2ttZHoobEQm6YKaCqWrztd6LIDGmXlOvuyfp1aGb_E,1307
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=fkSsvWcTM_d7M3rRM6N2oFpeSvGhj_ENZtqmfWOzZTQ,13717
 sae_lens/analysis/neuronpedia_integration.py,sha256=dFiKRWfuT5iUfTPBPmZydSaNG3VwqZ1asuNbbQv_NCM,18488
 sae_lens/cache_activations_runner.py,sha256=dGK5EHJMHAKDAFyr25fy1COSm-61q-q6kpWENHFMaKk,12561
-sae_lens/config.py,sha256=
+sae_lens/config.py,sha256=y3dgA_lNSpwi_n442dtrQ6RxfFKbnvUKjb7Qe1ZNoA4,33034
 sae_lens/evals.py,sha256=7cuLlT0ZTAhZ7eQbsZEFT-M3oixmaXSCBJtjh9hGnVQ,38527
-sae_lens/load_model.py,sha256=
+sae_lens/load_model.py,sha256=TRxyUpudPCwGzSccQiHxww9OtLiwBBRurvi-HUnfdKg,6787
 sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
 sae_lens/pretrained_saes.yaml,sha256=nhHW1auhyi4GHYrjUnHQqbNVhI5cMJv-HThzbzU1xG0,574145
 sae_lens/sae.py,sha256=8DFVKG72Ml_hVm49YIHJ0zAS6Pbd7O_7wDkQV5kyhxk,27965
 sae_lens/sae_training_runner.py,sha256=tduPN8BGtMatua0bNY-tXGGGxhedMu6F_O9ugDOdRmQ,9004
 sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
 sae_lens/toolkit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/toolkit/pretrained_sae_loaders.py,sha256=
+sae_lens/toolkit/pretrained_sae_loaders.py,sha256=vzDdDjy7EnNpmOcEwwXyd0AzzX4Up2Gdhs8wlogww8M,34840
 sae_lens/toolkit/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activations_store.py,sha256=1H9sb1mAVrlHU5WVci1FFcCVNj1BydezI4zu7_Usm3s,35985
@@ -22,7 +22,7 @@ sae_lens/training/sae_trainer.py,sha256=xenSV0xw06y1_qLhw82_966DmWOp2nydqlrVgJA6
 sae_lens/training/training_sae.py,sha256=0A4x74qUfinLhwaK9RSoWZ7POrGc8kIU5EgBOp4UJtE,27998
 sae_lens/training/upload_saes_to_huggingface.py,sha256=P1K3nxv-IM7JptfLHj5Agiis7A_adn-g_tiq1d8PdaU,4361
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
-sae_lens-5.
-sae_lens-5.
-sae_lens-5.
-sae_lens-5.
+sae_lens-5.11.0.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-5.11.0.dist-info/METADATA,sha256=LPXNsAtjYMtDveFW2aBFrLCjWmmnN7hiFt3WIDOxpBU,5324
+sae_lens-5.11.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+sae_lens-5.11.0.dist-info/RECORD,,
```
{sae_lens-5.10.7.dist-info → sae_lens-5.11.0.dist-info}/LICENSE
File without changes

{sae_lens-5.10.7.dist-info → sae_lens-5.11.0.dist-info}/WHEEL
File without changes