sae-lens 5.10.7__py3-none-any.whl → 6.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. sae_lens/__init__.py +60 -7
  2. sae_lens/analysis/hooked_sae_transformer.py +12 -12
  3. sae_lens/analysis/neuronpedia_integration.py +16 -14
  4. sae_lens/cache_activations_runner.py +9 -7
  5. sae_lens/config.py +170 -257
  6. sae_lens/constants.py +21 -0
  7. sae_lens/evals.py +59 -44
  8. sae_lens/llm_sae_training_runner.py +377 -0
  9. sae_lens/load_model.py +53 -5
  10. sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +228 -32
  11. sae_lens/registry.py +49 -0
  12. sae_lens/saes/__init__.py +48 -0
  13. sae_lens/saes/gated_sae.py +254 -0
  14. sae_lens/saes/jumprelu_sae.py +348 -0
  15. sae_lens/saes/sae.py +1076 -0
  16. sae_lens/saes/standard_sae.py +178 -0
  17. sae_lens/saes/topk_sae.py +300 -0
  18. sae_lens/training/activation_scaler.py +53 -0
  19. sae_lens/training/activations_store.py +103 -184
  20. sae_lens/training/mixing_buffer.py +56 -0
  21. sae_lens/training/optim.py +60 -36
  22. sae_lens/training/sae_trainer.py +155 -177
  23. sae_lens/training/types.py +5 -0
  24. sae_lens/training/upload_saes_to_huggingface.py +13 -7
  25. sae_lens/util.py +47 -0
  26. {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
  27. sae_lens-6.0.0.dist-info/RECORD +37 -0
  28. sae_lens/sae.py +0 -747
  29. sae_lens/sae_training_runner.py +0 -251
  30. sae_lens/training/geometric_median.py +0 -101
  31. sae_lens/training/training_sae.py +0 -710
  32. sae_lens-5.10.7.dist-info/RECORD +0 -28
  33. /sae_lens/{toolkit → loading}/__init__.py +0 -0
  34. /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
  35. {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
  36. {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
sae_lens/load_model.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, Literal, cast
+from typing import Any, Callable, Literal, cast
 
 import torch
 from transformer_lens import HookedTransformer
@@ -77,6 +77,7 @@ class HookedProxyLM(HookedRootModule):
     # copied and modified from base HookedRootModule
     def setup(self):
         self.mod_dict = {}
+        self.named_modules_dict = {}
         self.hook_dict: dict[str, HookPoint] = {}
         for name, module in self.model.named_modules():
             if name == "":
@@ -89,14 +90,21 @@ class HookedProxyLM(HookedRootModule):
 
             self.hook_dict[name] = hook_point
             self.mod_dict[name] = hook_point
+            self.named_modules_dict[name] = module
+
+    def run_with_cache(self, *args: Any, **kwargs: Any):  # type: ignore
+        if "names_filter" in kwargs:
+            # hacky way to make sure that the names_filter is passed to our forward method
+            kwargs["_names_filter"] = kwargs["names_filter"]
+        return super().run_with_cache(*args, **kwargs)
 
     def forward(
         self,
         tokens: torch.Tensor,
         return_type: Literal["both", "logits"] = "logits",
         loss_per_token: bool = False,
-        # TODO: implement real support for stop_at_layer
         stop_at_layer: int | None = None,
+        _names_filter: list[str] | None = None,
         **kwargs: Any,
     ) -> Output | Loss:
         # This is just what's needed for evals, not everything that HookedTransformer has
@@ -107,8 +115,28 @@ class HookedProxyLM(HookedRootModule):
             raise NotImplementedError(
                 "Only return_type supported is 'both' or 'logits' to match what's in evals.py and ActivationsStore"
             )
-        output = self.model(tokens)
-        logits = _extract_logits_from_output(output)
+
+        stop_hooks = []
+        if stop_at_layer is not None and _names_filter is not None:
+            if return_type != "logits":
+                raise NotImplementedError(
+                    "stop_at_layer is not supported for return_type='both'"
+                )
+            stop_manager = StopManager(_names_filter)
+
+            for hook_name in _names_filter:
+                module = self.named_modules_dict[hook_name]
+                stop_fn = stop_manager.get_stop_hook_fn(hook_name)
+                stop_hooks.append(module.register_forward_hook(stop_fn))
+        try:
+            output = self.model(tokens)
+            logits = _extract_logits_from_output(output)
+        except StopForward:
+            # If we stop early, we don't care about the return output
+            return None  # type: ignore
+        finally:
+            for stop_hook in stop_hooks:
+                stop_hook.remove()
 
         if return_type == "logits":
             return logits
@@ -159,7 +187,7 @@ class HookedProxyLM(HookedRootModule):
 
         # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
         if hasattr(self.tokenizer, "add_bos_token") and self.tokenizer.add_bos_token:  # type: ignore
-            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)
+            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)  # type: ignore
         return tokens  # type: ignore
 
 
@@ -183,3 +211,23 @@ def get_hook_fn(hook_point: HookPoint):
         return output
 
     return hook_fn
+
+
+class StopForward(Exception):
+    pass
+
+
+class StopManager:
+    def __init__(self, hook_names: list[str]):
+        self.hook_names = hook_names
+        self.total_hook_names = len(set(hook_names))
+        self.called_hook_names = set()
+
+    def get_stop_hook_fn(self, hook_name: str) -> Callable[[Any, Any, Any], Any]:
+        def stop_hook_fn(module: Any, input: Any, output: Any) -> Any:  # noqa: ARG001
+            self.called_hook_names.add(hook_name)
+            if len(self.called_hook_names) == self.total_hook_names:
+                raise StopForward()
+            return output
+
+        return stop_hook_fn
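
Note on the change above: run_with_cache forwards names_filter into forward as _names_filter, and StopManager registers one forward hook per requested module so the wrapped HuggingFace model can be aborted as soon as every requested activation has been computed. The snippet below is a standalone sketch of that early-stopping pattern in plain PyTorch (not the sae_lens API): a forward hook raises an exception after capturing a layer's output, so later layers never run.

import torch
import torch.nn as nn

class StopForward(Exception):
    pass

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4))
captured = {}

def stop_hook(module, inputs, output):  # fires once layer 0 has produced its output
    captured["layer0"] = output
    raise StopForward()

handle = model[0].register_forward_hook(stop_hook)
try:
    model(torch.randn(1, 4))
except StopForward:
    pass  # layers 1 and 2 were never executed
finally:
    handle.remove()  # mirrors the stop_hook.remove() cleanup in the finally block above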
sae_lens/{toolkit → loading}/pretrained_sae_loaders.py RENAMED
@@ -7,21 +7,41 @@ import numpy as np
 import torch
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError
+from packaging.version import Version
 from safetensors import safe_open
 from safetensors.torch import load_file
 
 from sae_lens import logger
-from sae_lens.config import (
+from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
+    SPARSIFY_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
 )
-from sae_lens.toolkit.pretrained_saes_directory import (
+from sae_lens.loading.pretrained_saes_directory import (
     get_config_overrides,
     get_pretrained_saes_directory,
     get_repo_id_and_folder_name,
 )
+from sae_lens.registry import get_sae_class
+from sae_lens.util import filter_valid_dataclass_fields
+
+LLM_METADATA_KEYS = {
+    "model_name",
+    "hook_name",
+    "model_class_name",
+    "hook_head_index",
+    "model_from_pretrained_kwargs",
+    "prepend_bos",
+    "exclude_special_tokens",
+    "neuronpedia_id",
+    "context_size",
+    "seqpos_slice",
+    "dataset_path",
+    "sae_lens_version",
+    "sae_lens_training_version",
+}
 
 
 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
@@ -174,30 +194,69 @@ def get_sae_lens_config_from_disk(
 
 
 def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    sae_lens_version = cfg_dict.get("sae_lens_version")
+    if not sae_lens_version and "metadata" in cfg_dict:
+        sae_lens_version = cfg_dict["metadata"].get("sae_lens_version")
+
+    if not sae_lens_version or Version(sae_lens_version) < Version("6.0.0-rc.0"):
+        cfg_dict = handle_pre_6_0_config(cfg_dict)
+    return cfg_dict
+
+
+def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    """
+    Format a config dictionary for a Sparse Autoencoder (SAE) to be compatible with the new 6.0 format.
+    """
+
+    rename_keys_map = {
+        "hook_point": "hook_name",
+        "hook_point_head_index": "hook_head_index",
+        "activation_fn_str": "activation_fn",
+    }
+    new_cfg = {rename_keys_map.get(k, k): v for k, v in cfg_dict.items()}
+
     # Set default values for backwards compatibility
-    cfg_dict.setdefault("prepend_bos", True)
-    cfg_dict.setdefault("dataset_trust_remote_code", True)
-    cfg_dict.setdefault("apply_b_dec_to_input", True)
-    cfg_dict.setdefault("finetuning_scaling_factor", False)
-    cfg_dict.setdefault("sae_lens_training_version", None)
-    cfg_dict.setdefault("activation_fn_str", cfg_dict.get("activation_fn", "relu"))
-    cfg_dict.setdefault("architecture", "standard")
-    cfg_dict.setdefault("neuronpedia_id", None)
-
-    if "normalize_activations" in cfg_dict and isinstance(
-        cfg_dict["normalize_activations"], bool
+    new_cfg.setdefault("prepend_bos", True)
+    new_cfg.setdefault("dataset_trust_remote_code", True)
+    new_cfg.setdefault("apply_b_dec_to_input", True)
+    new_cfg.setdefault("finetuning_scaling_factor", False)
+    new_cfg.setdefault("sae_lens_training_version", None)
+    new_cfg.setdefault("activation_fn", new_cfg.get("activation_fn", "relu"))
+    new_cfg.setdefault("architecture", "standard")
+    new_cfg.setdefault("neuronpedia_id", None)
+    new_cfg.setdefault(
+        "reshape_activations",
+        "hook_z" if "hook_z" in new_cfg.get("hook_name", "") else "none",
+    )
+
+    if "normalize_activations" in new_cfg and isinstance(
+        new_cfg["normalize_activations"], bool
     ):
         # backwards compatibility
-        cfg_dict["normalize_activations"] = (
+        new_cfg["normalize_activations"] = (
             "none"
-            if not cfg_dict["normalize_activations"]
+            if not new_cfg["normalize_activations"]
             else "expected_average_only_in"
        )
 
-    cfg_dict.setdefault("normalize_activations", "none")
-    cfg_dict.setdefault("device", "cpu")
+    if new_cfg.get("normalize_activations") is None:
+        new_cfg["normalize_activations"] = "none"
 
-    return cfg_dict
+    new_cfg.setdefault("device", "cpu")
+
+    architecture = new_cfg.get("architecture", "standard")
+
+    config_class = get_sae_class(architecture)[1]
+
+    sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
+    if architecture == "topk" and "activation_fn_kwargs" in new_cfg:
+        sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
+
+    sae_cfg_dict["metadata"] = {
+        k: v for k, v in new_cfg.items() if k in LLM_METADATA_KEYS
+    }
+    sae_cfg_dict["architecture"] = architecture
+    return sae_cfg_dict
 
 
 def get_connor_rob_hook_z_config_from_hf(
@@ -221,9 +280,8 @@ def get_connor_rob_hook_z_config_from_hf(
         "device": device if device is not None else "cpu",
         "model_name": "gpt2-small",
         "hook_name": old_cfg_dict["act_name"],
-        "hook_layer": old_cfg_dict["layer"],
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "apply_b_dec_to_input": True,
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
@@ -232,6 +290,7 @@
         "context_size": 128,
         "normalize_activations": "none",
         "dataset_trust_remote_code": True,
+        "reshape_activations": "hook_z",
         **(cfg_overrides or {}),
     }
 
@@ -370,9 +429,8 @@ def get_gemma_2_config_from_hf(
         "dtype": "float32",
         "model_name": model_name,
         "hook_name": hook_name,
-        "hook_layer": layer,
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -492,9 +550,8 @@ def get_llama_scope_config_from_hf(
         "dtype": "bfloat16",
         "model_name": model_name,
         "hook_name": old_cfg_dict["hook_point_in"],
-        "hook_layer": int(old_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -606,8 +663,8 @@ def get_dictionary_learning_config_1_from_hf(
 
     hook_point_name = f"blocks.{trainer['layer']}.hook_resid_post"
 
-    activation_fn_str = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
-    activation_fn_kwargs = {"k": trainer["k"]} if activation_fn_str == "topk" else {}
+    activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
+    activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}
 
     return {
         "architecture": (
@@ -619,9 +676,8 @@
         "device": device,
         "model_name": trainer["lm_name"].split("/")[-1],
         "hook_name": hook_point_name,
-        "hook_layer": trainer["layer"],
         "hook_head_index": None,
-        "activation_fn_str": activation_fn_str,
+        "activation_fn": activation_fn,
         "activation_fn_kwargs": activation_fn_kwargs,
         "apply_b_dec_to_input": True,
         "finetuning_scaling_factor": False,
@@ -658,13 +714,12 @@ def get_deepseek_r1_config_from_hf(
         "context_size": 1024,
         "model_name": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         "hook_name": f"blocks.{layer}.hook_resid_post",
-        "hook_layer": layer,
         "hook_head_index": None,
         "prepend_bos": True,
         "dataset_path": "lmsys/lmsys-chat-1m",
         "dataset_trust_remote_code": True,
         "sae_lens_training_version": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "normalize_activations": "none",
         "device": device,
         "apply_b_dec_to_input": False,
@@ -817,9 +872,8 @@ def get_llama_scope_r1_distill_config_from_hf(
         "device": device,
         "model_name": model_name,
         "hook_name": huggingface_cfg_dict["hook_point_in"],
-        "hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -898,6 +952,146 @@ def llama_scope_r1_distill_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity
 
 
+def get_sparsify_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    cfg_filename = f"{folder_name}/{SAE_CFG_FILENAME}"
+    cfg_path = hf_hub_download(
+        repo_id,
+        filename=cfg_filename,
+        force_download=force_download,
+    )
+    sae_path = Path(cfg_path).parent
+    return get_sparsify_config_from_disk(
+        sae_path, device=device, cfg_overrides=cfg_overrides
+    )
+
+
+def get_sparsify_config_from_disk(
+    path: str | Path,
+    device: str | None = None,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    path = Path(path)
+
+    with open(path / SAE_CFG_FILENAME) as f:
+        old_cfg_dict = json.load(f)
+
+    config_path = path.parent / "config.json"
+    if config_path.exists():
+        with open(config_path) as f:
+            config_dict = json.load(f)
+    else:
+        config_dict = {}
+
+    folder_name = path.name
+    if folder_name == "embed_tokens":
+        hook_name, layer = "hook_embed", 0
+    else:
+        match = re.search(r"layers[._](\d+)", folder_name)
+        if match is None:
+            raise ValueError(f"Unrecognized Sparsify folder: {folder_name}")
+        layer = int(match.group(1))
+        hook_name = f"blocks.{layer}.hook_resid_post"
+
+    cfg_dict: dict[str, Any] = {
+        "architecture": "standard",
+        "d_in": old_cfg_dict["d_in"],
+        "d_sae": old_cfg_dict["d_in"] * old_cfg_dict["expansion_factor"],
+        "dtype": "bfloat16",
+        "device": device or "cpu",
+        "model_name": config_dict.get("model", path.parts[-2]),
+        "hook_name": hook_name,
+        "hook_layer": layer,
+        "hook_head_index": None,
+        "activation_fn_str": "topk",
+        "activation_fn_kwargs": {
+            "k": old_cfg_dict["k"],
+            "signed": old_cfg_dict.get("signed", False),
+        },
+        "apply_b_dec_to_input": not old_cfg_dict.get("normalize_decoder", False),
+        "dataset_path": config_dict.get(
+            "dataset", "togethercomputer/RedPajama-Data-1T-Sample"
+        ),
+        "context_size": config_dict.get("ctx_len", 2048),
+        "finetuning_scaling_factor": False,
+        "sae_lens_training_version": None,
+        "prepend_bos": True,
+        "dataset_trust_remote_code": True,
+        "normalize_activations": "none",
+        "neuronpedia_id": None,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def sparsify_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], None]:
+    weights_filename = f"{folder_name}/{SPARSIFY_WEIGHTS_FILENAME}"
+    sae_path = hf_hub_download(
+        repo_id,
+        filename=weights_filename,
+        force_download=force_download,
+    )
+    cfg_dict, state_dict = sparsify_disk_loader(
+        Path(sae_path).parent, device=device, cfg_overrides=cfg_overrides
+    )
+    return cfg_dict, state_dict, None
+
+
+def sparsify_disk_loader(
+    path: str | Path,
+    device: str = "cpu",
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
+    cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)
+
+    weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
+    state_dict_loaded = load_file(weight_path, device=device)
+
+    dtype = DTYPE_MAP[cfg_dict["dtype"]]
+
+    W_enc = (
+        state_dict_loaded["W_enc"]
+        if "W_enc" in state_dict_loaded
+        else state_dict_loaded["encoder.weight"].T
+    ).to(dtype)
+
+    if "W_dec" in state_dict_loaded:
+        W_dec = state_dict_loaded["W_dec"].T.to(dtype)
+    else:
+        W_dec = state_dict_loaded["decoder.weight"].T.to(dtype)
+
+    if "b_enc" in state_dict_loaded:
+        b_enc = state_dict_loaded["b_enc"].to(dtype)
+    elif "encoder.bias" in state_dict_loaded:
+        b_enc = state_dict_loaded["encoder.bias"].to(dtype)
+    else:
+        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+
+    if "b_dec" in state_dict_loaded:
+        b_dec = state_dict_loaded["b_dec"].to(dtype)
+    elif "decoder.bias" in state_dict_loaded:
+        b_dec = state_dict_loaded["decoder.bias"].to(dtype)
+    else:
+        b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+
+    state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
+    return cfg_dict, state_dict
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -906,6 +1100,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
     "deepseek_r1": deepseek_r1_sae_huggingface_loader,
+    "sparsify": sparsify_huggingface_loader,
 }
 
 
@@ -917,4 +1112,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
     "deepseek_r1": get_deepseek_r1_config_from_hf,
+    "sparsify": get_sparsify_config_from_hf,
 }
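
Note on handle_pre_6_0_config above: legacy keys are renamed (hook_point → hook_name, hook_point_head_index → hook_head_index, activation_fn_str → activation_fn), defaults are filled in, the result is filtered down to the fields of the architecture's config dataclass, and LLM-related keys are collected into a nested "metadata" dict. A minimal sketch of the upgrade path, using a hypothetical pre-6.0 config and assuming the "standard" architecture registers itself when sae_lens is imported:

from sae_lens.loading.pretrained_sae_loaders import handle_config_defaulting

# Hypothetical pre-6.0 config dict (values are illustrative only).
old_cfg = {
    "architecture": "standard",
    "d_in": 768,
    "d_sae": 24576,
    "hook_point": "blocks.8.hook_resid_pre",  # legacy key, renamed to "hook_name"
    "activation_fn_str": "relu",              # legacy key, renamed to "activation_fn"
    "model_name": "gpt2",                     # LLM field, expected to land under "metadata"
}

# No "sae_lens_version" key, so this should take the pre-6.0 upgrade path.
new_cfg = handle_config_defaulting(old_cfg)
print(new_cfg["architecture"], new_cfg["metadata"]["hook_name"])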
sae_lens/registry.py ADDED
@@ -0,0 +1,49 @@
+from typing import TYPE_CHECKING, Any
+
+# avoid circular imports
+if TYPE_CHECKING:
+    from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
+
+SAE_CLASS_REGISTRY: dict[str, tuple["type[SAE[Any]]", "type[SAEConfig]"]] = {}
+SAE_TRAINING_CLASS_REGISTRY: dict[
+    str, tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]
+] = {}
+
+
+def register_sae_class(
+    architecture: str,
+    sae_class: "type[SAE[Any]]",
+    sae_config_class: "type[SAEConfig]",
+) -> None:
+    if architecture in SAE_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE class for architecture {architecture} already registered."
+        )
+    SAE_CLASS_REGISTRY[architecture] = (sae_class, sae_config_class)
+
+
+def register_sae_training_class(
+    architecture: str,
+    sae_training_class: "type[TrainingSAE[Any]]",
+    sae_training_config_class: "type[TrainingSAEConfig]",
+) -> None:
+    if architecture in SAE_TRAINING_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE training class for architecture {architecture} already registered."
+        )
+    SAE_TRAINING_CLASS_REGISTRY[architecture] = (
+        sae_training_class,
+        sae_training_config_class,
+    )
+
+
+def get_sae_class(
+    architecture: str,
+) -> tuple["type[SAE[Any]]", "type[SAEConfig]"]:
+    return SAE_CLASS_REGISTRY[architecture]
+
+
+def get_sae_training_class(
+    architecture: str,
+) -> tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]:
+    return SAE_TRAINING_CLASS_REGISTRY[architecture]
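
The registry above is the seam between the generic loading code (for example handle_pre_6_0_config calling get_sae_class) and the concrete architectures in sae_lens.saes. A hedged usage sketch, assuming each architecture module calls register_sae_class under its architecture string (e.g. "standard") when sae_lens.saes is imported:

import sae_lens.saes  # assumption: importing the subpackage triggers the register_sae_class calls
from sae_lens.registry import get_sae_class, get_sae_training_class

sae_class, sae_config_class = get_sae_class("standard")
training_class, training_config_class = get_sae_training_class("standard")
print(sae_class.__name__, training_class.__name__)  # expected: StandardSAE StandardTrainingSAE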
sae_lens/saes/__init__.py ADDED
@@ -0,0 +1,48 @@
+from .gated_sae import (
+    GatedSAE,
+    GatedSAEConfig,
+    GatedTrainingSAE,
+    GatedTrainingSAEConfig,
+)
+from .jumprelu_sae import (
+    JumpReLUSAE,
+    JumpReLUSAEConfig,
+    JumpReLUTrainingSAE,
+    JumpReLUTrainingSAEConfig,
+)
+from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
+from .standard_sae import (
+    StandardSAE,
+    StandardSAEConfig,
+    StandardTrainingSAE,
+    StandardTrainingSAEConfig,
+)
+from .topk_sae import (
+    TopKSAE,
+    TopKSAEConfig,
+    TopKTrainingSAE,
+    TopKTrainingSAEConfig,
+)
+
+__all__ = [
+    "SAE",
+    "SAEConfig",
+    "TrainingSAE",
+    "TrainingSAEConfig",
+    "StandardSAE",
+    "StandardSAEConfig",
+    "StandardTrainingSAE",
+    "StandardTrainingSAEConfig",
+    "GatedSAE",
+    "GatedSAEConfig",
+    "GatedTrainingSAE",
+    "GatedTrainingSAEConfig",
+    "JumpReLUSAE",
+    "JumpReLUSAEConfig",
+    "JumpReLUTrainingSAE",
+    "JumpReLUTrainingSAEConfig",
+    "TopKSAE",
+    "TopKSAEConfig",
+    "TopKTrainingSAE",
+    "TopKTrainingSAEConfig",
+]
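
With this split, the monolithic sae_lens/sae.py and sae_lens/training/training_sae.py modules are removed (see the deletions in the file list above) and the per-architecture classes live under sae_lens.saes. A hedged import sketch for downstream code (the top-level sae_lens package likely re-exports these as well, given the sae_lens/__init__.py changes):

# Each architecture exposes an inference class, its training-time counterpart, and their configs:
from sae_lens.saes import (
    StandardSAE,
    StandardSAEConfig,
    StandardTrainingSAE,
    StandardTrainingSAEConfig,
)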