sae-lens 5.10.6__py3-none-any.whl → 5.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "5.10.6"
2
+ __version__ = "5.11.0"
3
3
 
4
4
  import logging
5
5
 
sae_lens/config.py CHANGED
@@ -33,6 +33,7 @@ HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
33
33
 
34
34
  SPARSITY_FILENAME = "sparsity.safetensors"
35
35
  SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
36
+ SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
36
37
  SAE_CFG_FILENAME = "cfg.json"
37
38
 
38
39
 
sae_lens/load_model.py CHANGED
@@ -159,7 +159,7 @@ class HookedProxyLM(HookedRootModule):
159
159
 
160
160
  # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
161
161
  if hasattr(self.tokenizer, "add_bos_token") and self.tokenizer.add_bos_token: # type: ignore
162
- tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)
162
+ tokens = get_tokens_with_bos_removed(self.tokenizer, tokens) # type: ignore
163
163
  return tokens # type: ignore
164
164
 
165
165
 
@@ -15,6 +15,7 @@ from sae_lens.config import (
15
15
  DTYPE_MAP,
16
16
  SAE_CFG_FILENAME,
17
17
  SAE_WEIGHTS_FILENAME,
18
+ SPARSIFY_WEIGHTS_FILENAME,
18
19
  SPARSITY_FILENAME,
19
20
  )
20
21
  from sae_lens.toolkit.pretrained_saes_directory import (
@@ -898,6 +899,146 @@ def llama_scope_r1_distill_sae_huggingface_loader(
898
899
  return cfg_dict, state_dict, log_sparsity
899
900
 
900
901
 
902
def get_sparsify_config_from_hf(
    repo_id: str,
    folder_name: str,
    device: str,
    force_download: bool = False,
    cfg_overrides: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Build an SAE Lens config dict for a Sparsify SAE hosted on the Hugging Face Hub.

    Downloads ``<folder_name>/cfg.json`` from ``repo_id`` and hands the local
    folder that now contains it to ``get_sparsify_config_from_disk``, which does
    the actual translation.

    Args:
        repo_id: Hugging Face repository holding the Sparsify run.
        folder_name: Hookpoint sub-folder inside the repo (e.g. ``"layers.6"``).
        device: Device string recorded in the returned config.
        force_download: Re-download the file even if it is already cached.
        cfg_overrides: Entries that replace keys of the generated config.

    Returns:
        The SAE Lens config dictionary.

    Note:
        Only the per-folder config file is fetched here. The disk helper also
        looks for a run-level ``config.json`` in the parent directory (for model
        name, dataset and context length). That file is not downloaded by this
        function, so unless it is already present locally those fields fall
        back to the helper's defaults.
    """
    local_cfg_file = Path(
        hf_hub_download(
            repo_id,
            filename=f"{folder_name}/{SAE_CFG_FILENAME}",
            force_download=force_download,
        )
    )
    hookpoint_dir = local_cfg_file.parent
    return get_sparsify_config_from_disk(
        hookpoint_dir,
        device=device,
        cfg_overrides=cfg_overrides,
    )
919
+
920
+
921
+ def get_sparsify_config_from_disk(
922
+ path: str | Path,
923
+ device: str | None = None,
924
+ cfg_overrides: dict[str, Any] | None = None,
925
+ ) -> dict[str, Any]:
926
+ path = Path(path)
927
+
928
+ with open(path / SAE_CFG_FILENAME) as f:
929
+ old_cfg_dict = json.load(f)
930
+
931
+ config_path = path.parent / "config.json"
932
+ if config_path.exists():
933
+ with open(config_path) as f:
934
+ config_dict = json.load(f)
935
+ else:
936
+ config_dict = {}
937
+
938
+ folder_name = path.name
939
+ if folder_name == "embed_tokens":
940
+ hook_name, layer = "hook_embed", 0
941
+ else:
942
+ match = re.search(r"layers[._](\d+)", folder_name)
943
+ if match is None:
944
+ raise ValueError(f"Unrecognized Sparsify folder: {folder_name}")
945
+ layer = int(match.group(1))
946
+ hook_name = f"blocks.{layer}.hook_resid_post"
947
+
948
+ cfg_dict: dict[str, Any] = {
949
+ "architecture": "standard",
950
+ "d_in": old_cfg_dict["d_in"],
951
+ "d_sae": old_cfg_dict["d_in"] * old_cfg_dict["expansion_factor"],
952
+ "dtype": "bfloat16",
953
+ "device": device or "cpu",
954
+ "model_name": config_dict.get("model", path.parts[-2]),
955
+ "hook_name": hook_name,
956
+ "hook_layer": layer,
957
+ "hook_head_index": None,
958
+ "activation_fn_str": "topk",
959
+ "activation_fn_kwargs": {
960
+ "k": old_cfg_dict["k"],
961
+ "signed": old_cfg_dict.get("signed", False),
962
+ },
963
+ "apply_b_dec_to_input": not old_cfg_dict.get("normalize_decoder", False),
964
+ "dataset_path": config_dict.get(
965
+ "dataset", "togethercomputer/RedPajama-Data-1T-Sample"
966
+ ),
967
+ "context_size": config_dict.get("ctx_len", 2048),
968
+ "finetuning_scaling_factor": False,
969
+ "sae_lens_training_version": None,
970
+ "prepend_bos": True,
971
+ "dataset_trust_remote_code": True,
972
+ "normalize_activations": "none",
973
+ "neuronpedia_id": None,
974
+ }
975
+
976
+ if cfg_overrides:
977
+ cfg_dict.update(cfg_overrides)
978
+
979
+ return cfg_dict
980
+
981
+
982
def sparsify_huggingface_loader(
    repo_id: str,
    folder_name: str,
    device: str = "cpu",
    force_download: bool = False,
    cfg_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor], None]:
    """Load a Sparsify SAE from the Hugging Face Hub into SAE Lens format.

    Fetches both the hookpoint config and the weights for ``folder_name``, then
    delegates to ``sparsify_disk_loader`` on the local folder holding them.

    Args:
        repo_id: Hugging Face repository holding the Sparsify run.
        folder_name: Hookpoint sub-folder inside the repo (e.g. ``"layers.6"``).
        device: Device the tensors are loaded onto.
        force_download: Re-download files even if they are already cached.
        cfg_overrides: Entries that replace keys of the generated config.

    Returns:
        ``(cfg_dict, state_dict, None)``. No sparsity statistics are loaded, so
        the third element is always ``None``.
    """
    # sparsify_disk_loader reads cfg.json *and* the weights from the same
    # folder, so both must be fetched; downloading only the weights leaves
    # cfg.json missing when the cache is fresh.
    hf_hub_download(
        repo_id,
        filename=f"{folder_name}/{SAE_CFG_FILENAME}",
        force_download=force_download,
    )
    weights_path = hf_hub_download(
        repo_id,
        filename=f"{folder_name}/{SPARSIFY_WEIGHTS_FILENAME}",
        force_download=force_download,
    )
    cfg_dict, state_dict = sparsify_disk_loader(
        Path(weights_path).parent, device=device, cfg_overrides=cfg_overrides
    )
    return cfg_dict, state_dict, None
999
+
1000
+
1001
def sparsify_disk_loader(
    path: str | Path,
    device: str = "cpu",
    cfg_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    """Load a Sparsify hookpoint folder from disk as an SAE Lens config + state dict.

    The config comes from ``get_sparsify_config_from_disk``. The weights come
    from the Sparsify weights file in the same folder. Tensors are renamed to
    SAE Lens keys and cast to the config dtype.

    Target layouts (the SAE Lens standard SAE convention, defined outside this
    file):
    ``W_enc`` (d_in, d_sae), ``b_enc`` (d_sae,), ``W_dec`` (d_sae, d_in),
    ``b_dec`` (d_in,).

    Missing biases are filled with zeros.

    Args:
        path: Hookpoint folder containing the config and weights files.
        device: Device the tensors are loaded onto.
        cfg_overrides: Entries that replace keys of the generated config.

    Returns:
        ``(cfg_dict, state_dict)`` with keys ``W_enc``, ``b_enc``, ``W_dec``, ``b_dec``.
    """
    cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)

    weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
    state_dict_loaded = load_file(weight_path, device=device)

    dtype = DTYPE_MAP[cfg_dict["dtype"]]
    d_in = cfg_dict["d_in"]

    # A "W_enc" tensor is taken as already (d_in, d_sae). nn.Linear's
    # encoder.weight is (d_sae, d_in) and therefore needs transposing.
    W_enc = (
        state_dict_loaded["W_enc"]
        if "W_enc" in state_dict_loaded
        else state_dict_loaded["encoder.weight"].T
    ).to(dtype)

    if "W_dec" in state_dict_loaded:
        # Sparsify initialises W_dec as a copy of encoder.weight, i.e.
        # (num_latents, d_in), which is already the (d_sae, d_in) target layout.
        # Only transpose a tensor that is unambiguously stored as (d_in, d_sae).
        W_dec = state_dict_loaded["W_dec"]
        if W_dec.shape[-1] != d_in and W_dec.shape[0] == d_in:
            W_dec = W_dec.T
        W_dec = W_dec.to(dtype)
    else:
        # nn.Linear(d_sae, d_in).weight is (d_in, d_sae).
        W_dec = state_dict_loaded["decoder.weight"].T.to(dtype)

    def _bias(keys: tuple[str, ...], size: int) -> torch.Tensor:
        # First present key wins; zeros when the file stores no bias.
        for key in keys:
            if key in state_dict_loaded:
                return state_dict_loaded[key].to(dtype)
        return torch.zeros(size, dtype=dtype, device=device)

    b_enc = _bias(("b_enc", "encoder.bias"), cfg_dict["d_sae"])
    b_dec = _bias(("b_dec", "decoder.bias"), d_in)

    state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
    return cfg_dict, state_dict
1040
+
1041
+
901
1042
  NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
902
1043
  "sae_lens": sae_lens_huggingface_loader,
903
1044
  "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -906,6 +1047,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
906
1047
  "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
907
1048
  "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
908
1049
  "deepseek_r1": deepseek_r1_sae_huggingface_loader,
1050
+ "sparsify": sparsify_huggingface_loader,
909
1051
  }
910
1052
 
911
1053
 
@@ -917,4 +1059,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
917
1059
  "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
918
1060
  "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
919
1061
  "deepseek_r1": get_deepseek_r1_config_from_hf,
1062
+ "sparsify": get_sparsify_config_from_hf,
920
1063
  }
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sae-lens
3
- Version: 5.10.6
3
+ Version: 5.11.0
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -80,7 +80,7 @@ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page
80
80
 
81
81
  ## Join the Slack!
82
82
 
83
- Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-35oqtxb2t-yKBlqTL570ycNJisIFX2gw) for support!
83
+ Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
84
84
 
85
85
  ## Citation
86
86
 
@@ -1,18 +1,18 @@
1
- sae_lens/__init__.py,sha256=cjKprSd1OuHNJWbXnKEnoV_v1GPUkEud2uLAc_Xjn04,1307
1
+ sae_lens/__init__.py,sha256=A2ttZHoobEQm6YKaCqWrztd6LIDGmXlOvuyfp1aGb_E,1307
2
2
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  sae_lens/analysis/hooked_sae_transformer.py,sha256=fkSsvWcTM_d7M3rRM6N2oFpeSvGhj_ENZtqmfWOzZTQ,13717
4
4
  sae_lens/analysis/neuronpedia_integration.py,sha256=dFiKRWfuT5iUfTPBPmZydSaNG3VwqZ1asuNbbQv_NCM,18488
5
5
  sae_lens/cache_activations_runner.py,sha256=dGK5EHJMHAKDAFyr25fy1COSm-61q-q6kpWENHFMaKk,12561
6
- sae_lens/config.py,sha256=6SP10H4U91u6UDeN1F9Lb4p1lpTv7ZKKL29-WPpyRr8,32988
6
+ sae_lens/config.py,sha256=y3dgA_lNSpwi_n442dtrQ6RxfFKbnvUKjb7Qe1ZNoA4,33034
7
7
  sae_lens/evals.py,sha256=7cuLlT0ZTAhZ7eQbsZEFT-M3oixmaXSCBJtjh9hGnVQ,38527
8
- sae_lens/load_model.py,sha256=tE70sXsyyyGYW7o506O3eiw1MXyyW6DCQojLG49hWYI,6771
8
+ sae_lens/load_model.py,sha256=TRxyUpudPCwGzSccQiHxww9OtLiwBBRurvi-HUnfdKg,6787
9
9
  sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
10
10
  sae_lens/pretrained_saes.yaml,sha256=nhHW1auhyi4GHYrjUnHQqbNVhI5cMJv-HThzbzU1xG0,574145
11
11
  sae_lens/sae.py,sha256=8DFVKG72Ml_hVm49YIHJ0zAS6Pbd7O_7wDkQV5kyhxk,27965
12
12
  sae_lens/sae_training_runner.py,sha256=tduPN8BGtMatua0bNY-tXGGGxhedMu6F_O9ugDOdRmQ,9004
13
13
  sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
14
14
  sae_lens/toolkit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
- sae_lens/toolkit/pretrained_sae_loaders.py,sha256=ANMLti9n8ipf6cBJiPFs6Cln0ug41EkoA0EhyH4AtVY,30194
15
+ sae_lens/toolkit/pretrained_sae_loaders.py,sha256=vzDdDjy7EnNpmOcEwwXyd0AzzX4Up2Gdhs8wlogww8M,34840
16
16
  sae_lens/toolkit/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
17
17
  sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
18
  sae_lens/training/activations_store.py,sha256=1H9sb1mAVrlHU5WVci1FFcCVNj1BydezI4zu7_Usm3s,35985
@@ -22,7 +22,7 @@ sae_lens/training/sae_trainer.py,sha256=xenSV0xw06y1_qLhw82_966DmWOp2nydqlrVgJA6
22
22
  sae_lens/training/training_sae.py,sha256=0A4x74qUfinLhwaK9RSoWZ7POrGc8kIU5EgBOp4UJtE,27998
23
23
  sae_lens/training/upload_saes_to_huggingface.py,sha256=P1K3nxv-IM7JptfLHj5Agiis7A_adn-g_tiq1d8PdaU,4361
24
24
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
25
- sae_lens-5.10.6.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
26
- sae_lens-5.10.6.dist-info/METADATA,sha256=3qzaaX6pfLggxkWUcf2_5EbviynWxXVIB3I_qik3tNI,5324
27
- sae_lens-5.10.6.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
28
- sae_lens-5.10.6.dist-info/RECORD,,
25
+ sae_lens-5.11.0.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
26
+ sae_lens-5.11.0.dist-info/METADATA,sha256=LPXNsAtjYMtDveFW2aBFrLCjWmmnN7hiFt3WIDOxpBU,5324
27
+ sae_lens-5.11.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
28
+ sae_lens-5.11.0.dist-info/RECORD,,