sae-lens 5.10.7__py3-none-any.whl → 6.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +60 -7
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +16 -14
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +170 -257
- sae_lens/constants.py +21 -0
- sae_lens/evals.py +59 -44
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +53 -5
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +228 -32
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +254 -0
- sae_lens/saes/jumprelu_sae.py +348 -0
- sae_lens/saes/sae.py +1076 -0
- sae_lens/saes/standard_sae.py +178 -0
- sae_lens/saes/topk_sae.py +300 -0
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +103 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +155 -177
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +13 -7
- sae_lens/util.py +47 -0
- {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
- sae_lens-6.0.0.dist-info/RECORD +37 -0
- sae_lens/sae.py +0 -747
- sae_lens/sae_training_runner.py +0 -251
- sae_lens/training/geometric_median.py +0 -101
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.10.7.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
sae_lens/load_model.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, Literal, cast
+from typing import Any, Callable, Literal, cast
 
 import torch
 from transformer_lens import HookedTransformer
@@ -77,6 +77,7 @@ class HookedProxyLM(HookedRootModule):
     # copied and modified from base HookedRootModule
     def setup(self):
         self.mod_dict = {}
+        self.named_modules_dict = {}
         self.hook_dict: dict[str, HookPoint] = {}
         for name, module in self.model.named_modules():
             if name == "":
@@ -89,14 +90,21 @@ class HookedProxyLM(HookedRootModule):
 
             self.hook_dict[name] = hook_point
             self.mod_dict[name] = hook_point
+            self.named_modules_dict[name] = module
+
+    def run_with_cache(self, *args: Any, **kwargs: Any):  # type: ignore
+        if "names_filter" in kwargs:
+            # hacky way to make sure that the names_filter is passed to our forward method
+            kwargs["_names_filter"] = kwargs["names_filter"]
+        return super().run_with_cache(*args, **kwargs)
 
     def forward(
         self,
         tokens: torch.Tensor,
         return_type: Literal["both", "logits"] = "logits",
         loss_per_token: bool = False,
-        # TODO: implement real support for stop_at_layer
         stop_at_layer: int | None = None,
+        _names_filter: list[str] | None = None,
         **kwargs: Any,
    ) -> Output | Loss:
         # This is just what's needed for evals, not everything that HookedTransformer has
@@ -107,8 +115,28 @@ class HookedProxyLM(HookedRootModule):
             raise NotImplementedError(
                 "Only return_type supported is 'both' or 'logits' to match what's in evals.py and ActivationsStore"
             )
-
-
+
+        stop_hooks = []
+        if stop_at_layer is not None and _names_filter is not None:
+            if return_type != "logits":
+                raise NotImplementedError(
+                    "stop_at_layer is not supported for return_type='both'"
+                )
+            stop_manager = StopManager(_names_filter)
+
+            for hook_name in _names_filter:
+                module = self.named_modules_dict[hook_name]
+                stop_fn = stop_manager.get_stop_hook_fn(hook_name)
+                stop_hooks.append(module.register_forward_hook(stop_fn))
+        try:
+            output = self.model(tokens)
+            logits = _extract_logits_from_output(output)
+        except StopForward:
+            # If we stop early, we don't care about the return output
+            return None  # type: ignore
+        finally:
+            for stop_hook in stop_hooks:
+                stop_hook.remove()
 
         if return_type == "logits":
             return logits
@@ -159,7 +187,7 @@ class HookedProxyLM(HookedRootModule):
 
         # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
         if hasattr(self.tokenizer, "add_bos_token") and self.tokenizer.add_bos_token:  # type: ignore
-            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)
+            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)  # type: ignore
         return tokens  # type: ignore
 
 
@@ -183,3 +211,23 @@ def get_hook_fn(hook_point: HookPoint):
         return output
 
     return hook_fn
+
+
+class StopForward(Exception):
+    pass
+
+
+class StopManager:
+    def __init__(self, hook_names: list[str]):
+        self.hook_names = hook_names
+        self.total_hook_names = len(set(hook_names))
+        self.called_hook_names = set()
+
+    def get_stop_hook_fn(self, hook_name: str) -> Callable[[Any, Any, Any], Any]:
+        def stop_hook_fn(module: Any, input: Any, output: Any) -> Any:  # noqa: ARG001
+            self.called_hook_names.add(hook_name)
+            if len(self.called_hook_names) == self.total_hook_names:
+                raise StopForward()
+            return output
+
+        return stop_hook_fn
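The run_with_cache / StopManager additions above let HookedProxyLM abort the wrapped Hugging Face model's forward pass as soon as every activation named in names_filter has been captured, instead of running the remaining layers. A minimal standalone sketch of that hook-and-raise pattern on a plain nn.Sequential (illustrative only, not sae-lens code; all names here are made up):

# Sketch: capture one activation and stop the forward pass early.
import torch
from torch import nn


class StopForward(Exception):
    """Raised once the requested activation has been captured."""


def capture_and_stop(model: nn.Module, submodule: nn.Module, x: torch.Tensor) -> torch.Tensor:
    captured = {}

    def hook(module, inputs, output):  # noqa: ARG001
        captured["act"] = output
        raise StopForward()

    handle = submodule.register_forward_hook(hook)
    try:
        model(x)
    except StopForward:
        pass  # expected: we stopped on purpose
    finally:
        handle.remove()
    return captured["act"]


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    act = capture_and_stop(model, model[0], torch.randn(3, 4))
    print(act.shape)  # torch.Size([3, 8]); the layers after model[0] never ran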
sae_lens/{toolkit → loading}/pretrained_sae_loaders.py
CHANGED
@@ -7,21 +7,41 @@ import numpy as np
 import torch
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError
+from packaging.version import Version
 from safetensors import safe_open
 from safetensors.torch import load_file
 
 from sae_lens import logger
-from sae_lens.
+from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
+    SPARSIFY_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
 )
-from sae_lens.
+from sae_lens.loading.pretrained_saes_directory import (
     get_config_overrides,
     get_pretrained_saes_directory,
     get_repo_id_and_folder_name,
 )
+from sae_lens.registry import get_sae_class
+from sae_lens.util import filter_valid_dataclass_fields
+
+LLM_METADATA_KEYS = {
+    "model_name",
+    "hook_name",
+    "model_class_name",
+    "hook_head_index",
+    "model_from_pretrained_kwargs",
+    "prepend_bos",
+    "exclude_special_tokens",
+    "neuronpedia_id",
+    "context_size",
+    "seqpos_slice",
+    "dataset_path",
+    "sae_lens_version",
+    "sae_lens_training_version",
+}
 
 
 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
@@ -174,30 +194,69 @@ def get_sae_lens_config_from_disk(
 
 
 def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    sae_lens_version = cfg_dict.get("sae_lens_version")
+    if not sae_lens_version and "metadata" in cfg_dict:
+        sae_lens_version = cfg_dict["metadata"].get("sae_lens_version")
+
+    if not sae_lens_version or Version(sae_lens_version) < Version("6.0.0-rc.0"):
+        cfg_dict = handle_pre_6_0_config(cfg_dict)
+    return cfg_dict
+
+
+def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    """
+    Format a config dictionary for a Sparse Autoencoder (SAE) to be compatible with the new 6.0 format.
+    """
+
+    rename_keys_map = {
+        "hook_point": "hook_name",
+        "hook_point_head_index": "hook_head_index",
+        "activation_fn_str": "activation_fn",
+    }
+    new_cfg = {rename_keys_map.get(k, k): v for k, v in cfg_dict.items()}
+
     # Set default values for backwards compatibility
-
-
-
-
-
-
-
-
-
-
-
+    new_cfg.setdefault("prepend_bos", True)
+    new_cfg.setdefault("dataset_trust_remote_code", True)
+    new_cfg.setdefault("apply_b_dec_to_input", True)
+    new_cfg.setdefault("finetuning_scaling_factor", False)
+    new_cfg.setdefault("sae_lens_training_version", None)
+    new_cfg.setdefault("activation_fn", new_cfg.get("activation_fn", "relu"))
+    new_cfg.setdefault("architecture", "standard")
+    new_cfg.setdefault("neuronpedia_id", None)
+    new_cfg.setdefault(
+        "reshape_activations",
+        "hook_z" if "hook_z" in new_cfg.get("hook_name", "") else "none",
+    )
+
+    if "normalize_activations" in new_cfg and isinstance(
+        new_cfg["normalize_activations"], bool
     ):
         # backwards compatibility
-
+        new_cfg["normalize_activations"] = (
             "none"
-            if not
+            if not new_cfg["normalize_activations"]
             else "expected_average_only_in"
         )
 
-
-
+    if new_cfg.get("normalize_activations") is None:
+        new_cfg["normalize_activations"] = "none"
 
-
+    new_cfg.setdefault("device", "cpu")
+
+    architecture = new_cfg.get("architecture", "standard")
+
+    config_class = get_sae_class(architecture)[1]
+
+    sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
+    if architecture == "topk" and "activation_fn_kwargs" in new_cfg:
+        sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
+
+    sae_cfg_dict["metadata"] = {
+        k: v for k, v in new_cfg.items() if k in LLM_METADATA_KEYS
+    }
+    sae_cfg_dict["architecture"] = architecture
+    return sae_cfg_dict
 
 
 def get_connor_rob_hook_z_config_from_hf(
@@ -221,9 +280,8 @@ def get_connor_rob_hook_z_config_from_hf(
         "device": device if device is not None else "cpu",
         "model_name": "gpt2-small",
         "hook_name": old_cfg_dict["act_name"],
-        "hook_layer": old_cfg_dict["layer"],
         "hook_head_index": None,
-        "
+        "activation_fn": "relu",
         "apply_b_dec_to_input": True,
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
@@ -232,6 +290,7 @@ def get_connor_rob_hook_z_config_from_hf(
         "context_size": 128,
         "normalize_activations": "none",
         "dataset_trust_remote_code": True,
+        "reshape_activations": "hook_z",
         **(cfg_overrides or {}),
     }
 
@@ -370,9 +429,8 @@ def get_gemma_2_config_from_hf(
         "dtype": "float32",
         "model_name": model_name,
         "hook_name": hook_name,
-        "hook_layer": layer,
         "hook_head_index": None,
-        "
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -492,9 +550,8 @@ def get_llama_scope_config_from_hf(
         "dtype": "bfloat16",
         "model_name": model_name,
         "hook_name": old_cfg_dict["hook_point_in"],
-        "hook_layer": int(old_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
-        "
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -606,8 +663,8 @@ def get_dictionary_learning_config_1_from_hf(
 
     hook_point_name = f"blocks.{trainer['layer']}.hook_resid_post"
 
-
-    activation_fn_kwargs = {"k": trainer["k"]} if
+    activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
+    activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}
 
     return {
         "architecture": (
@@ -619,9 +676,8 @@ def get_dictionary_learning_config_1_from_hf(
         "device": device,
         "model_name": trainer["lm_name"].split("/")[-1],
         "hook_name": hook_point_name,
-        "hook_layer": trainer["layer"],
         "hook_head_index": None,
-        "
+        "activation_fn": activation_fn,
         "activation_fn_kwargs": activation_fn_kwargs,
         "apply_b_dec_to_input": True,
         "finetuning_scaling_factor": False,
@@ -658,13 +714,12 @@ def get_deepseek_r1_config_from_hf(
         "context_size": 1024,
         "model_name": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         "hook_name": f"blocks.{layer}.hook_resid_post",
-        "hook_layer": layer,
         "hook_head_index": None,
         "prepend_bos": True,
         "dataset_path": "lmsys/lmsys-chat-1m",
         "dataset_trust_remote_code": True,
         "sae_lens_training_version": None,
-        "
+        "activation_fn": "relu",
         "normalize_activations": "none",
         "device": device,
         "apply_b_dec_to_input": False,
@@ -817,9 +872,8 @@ def get_llama_scope_r1_distill_config_from_hf(
         "device": device,
         "model_name": model_name,
         "hook_name": huggingface_cfg_dict["hook_point_in"],
-        "hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
-        "
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -898,6 +952,146 @@ def llama_scope_r1_distill_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity
 
 
+def get_sparsify_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    cfg_filename = f"{folder_name}/{SAE_CFG_FILENAME}"
+    cfg_path = hf_hub_download(
+        repo_id,
+        filename=cfg_filename,
+        force_download=force_download,
+    )
+    sae_path = Path(cfg_path).parent
+    return get_sparsify_config_from_disk(
+        sae_path, device=device, cfg_overrides=cfg_overrides
+    )
+
+
+def get_sparsify_config_from_disk(
+    path: str | Path,
+    device: str | None = None,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    path = Path(path)
+
+    with open(path / SAE_CFG_FILENAME) as f:
+        old_cfg_dict = json.load(f)
+
+    config_path = path.parent / "config.json"
+    if config_path.exists():
+        with open(config_path) as f:
+            config_dict = json.load(f)
+    else:
+        config_dict = {}
+
+    folder_name = path.name
+    if folder_name == "embed_tokens":
+        hook_name, layer = "hook_embed", 0
+    else:
+        match = re.search(r"layers[._](\d+)", folder_name)
+        if match is None:
+            raise ValueError(f"Unrecognized Sparsify folder: {folder_name}")
+        layer = int(match.group(1))
+        hook_name = f"blocks.{layer}.hook_resid_post"
+
+    cfg_dict: dict[str, Any] = {
+        "architecture": "standard",
+        "d_in": old_cfg_dict["d_in"],
+        "d_sae": old_cfg_dict["d_in"] * old_cfg_dict["expansion_factor"],
+        "dtype": "bfloat16",
+        "device": device or "cpu",
+        "model_name": config_dict.get("model", path.parts[-2]),
+        "hook_name": hook_name,
+        "hook_layer": layer,
+        "hook_head_index": None,
+        "activation_fn_str": "topk",
+        "activation_fn_kwargs": {
+            "k": old_cfg_dict["k"],
+            "signed": old_cfg_dict.get("signed", False),
+        },
+        "apply_b_dec_to_input": not old_cfg_dict.get("normalize_decoder", False),
+        "dataset_path": config_dict.get(
+            "dataset", "togethercomputer/RedPajama-Data-1T-Sample"
+        ),
+        "context_size": config_dict.get("ctx_len", 2048),
+        "finetuning_scaling_factor": False,
+        "sae_lens_training_version": None,
+        "prepend_bos": True,
+        "dataset_trust_remote_code": True,
+        "normalize_activations": "none",
+        "neuronpedia_id": None,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def sparsify_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], None]:
+    weights_filename = f"{folder_name}/{SPARSIFY_WEIGHTS_FILENAME}"
+    sae_path = hf_hub_download(
+        repo_id,
+        filename=weights_filename,
+        force_download=force_download,
+    )
+    cfg_dict, state_dict = sparsify_disk_loader(
+        Path(sae_path).parent, device=device, cfg_overrides=cfg_overrides
+    )
+    return cfg_dict, state_dict, None
+
+
+def sparsify_disk_loader(
+    path: str | Path,
+    device: str = "cpu",
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
+    cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)
+
+    weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
+    state_dict_loaded = load_file(weight_path, device=device)
+
+    dtype = DTYPE_MAP[cfg_dict["dtype"]]
+
+    W_enc = (
+        state_dict_loaded["W_enc"]
+        if "W_enc" in state_dict_loaded
+        else state_dict_loaded["encoder.weight"].T
+    ).to(dtype)
+
+    if "W_dec" in state_dict_loaded:
+        W_dec = state_dict_loaded["W_dec"].T.to(dtype)
+    else:
+        W_dec = state_dict_loaded["decoder.weight"].T.to(dtype)
+
+    if "b_enc" in state_dict_loaded:
+        b_enc = state_dict_loaded["b_enc"].to(dtype)
+    elif "encoder.bias" in state_dict_loaded:
+        b_enc = state_dict_loaded["encoder.bias"].to(dtype)
+    else:
+        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+
+    if "b_dec" in state_dict_loaded:
+        b_dec = state_dict_loaded["b_dec"].to(dtype)
+    elif "decoder.bias" in state_dict_loaded:
+        b_dec = state_dict_loaded["decoder.bias"].to(dtype)
+    else:
+        b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+
+    state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
+    return cfg_dict, state_dict
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -906,6 +1100,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
     "deepseek_r1": deepseek_r1_sae_huggingface_loader,
+    "sparsify": sparsify_huggingface_loader,
 }
 
 
@@ -917,4 +1112,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
     "deepseek_r1": get_deepseek_r1_config_from_hf,
+    "sparsify": get_sparsify_config_from_hf,
 }
sae_lens/registry.py
ADDED
@@ -0,0 +1,49 @@
+from typing import TYPE_CHECKING, Any
+
+# avoid circular imports
+if TYPE_CHECKING:
+    from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
+
+SAE_CLASS_REGISTRY: dict[str, tuple["type[SAE[Any]]", "type[SAEConfig]"]] = {}
+SAE_TRAINING_CLASS_REGISTRY: dict[
+    str, tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]
+] = {}
+
+
+def register_sae_class(
+    architecture: str,
+    sae_class: "type[SAE[Any]]",
+    sae_config_class: "type[SAEConfig]",
+) -> None:
+    if architecture in SAE_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE class for architecture {architecture} already registered."
+        )
+    SAE_CLASS_REGISTRY[architecture] = (sae_class, sae_config_class)
+
+
+def register_sae_training_class(
+    architecture: str,
+    sae_training_class: "type[TrainingSAE[Any]]",
+    sae_training_config_class: "type[TrainingSAEConfig]",
+) -> None:
+    if architecture in SAE_TRAINING_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE training class for architecture {architecture} already registered."
+        )
+    SAE_TRAINING_CLASS_REGISTRY[architecture] = (
+        sae_training_class,
+        sae_training_config_class,
+    )
+
+
+def get_sae_class(
+    architecture: str,
+) -> tuple["type[SAE[Any]]", "type[SAEConfig]"]:
+    return SAE_CLASS_REGISTRY[architecture]
+
+
+def get_sae_training_class(
+    architecture: str,
+) -> tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]:
+    return SAE_TRAINING_CLASS_REGISTRY[architecture]
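The registry maps architecture names to (SAE class, config class) pairs so that loaders, such as the get_sae_class(architecture)[1] call in handle_pre_6_0_config above, can resolve classes without importing them directly. A hedged usage sketch; MySAE and MySAEConfig are hypothetical placeholders, and the sae_lens.saes imports follow the exports listed below:

# Sketch: register and look up a custom architecture.
from sae_lens.registry import get_sae_class, register_sae_class
from sae_lens.saes import StandardSAE, StandardSAEConfig


class MySAEConfig(StandardSAEConfig):
    """Hypothetical config for a custom architecture."""


class MySAE(StandardSAE):
    """Hypothetical SAE variant reusing the standard implementation."""


# registering the same name twice raises ValueError
register_sae_class("my_custom_sae", MySAE, MySAEConfig)

sae_cls, cfg_cls = get_sae_class("my_custom_sae")
assert sae_cls is MySAE and cfg_cls is MySAEConfig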
sae_lens/saes/__init__.py
ADDED
@@ -0,0 +1,48 @@
+from .gated_sae import (
+    GatedSAE,
+    GatedSAEConfig,
+    GatedTrainingSAE,
+    GatedTrainingSAEConfig,
+)
+from .jumprelu_sae import (
+    JumpReLUSAE,
+    JumpReLUSAEConfig,
+    JumpReLUTrainingSAE,
+    JumpReLUTrainingSAEConfig,
+)
+from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
+from .standard_sae import (
+    StandardSAE,
+    StandardSAEConfig,
+    StandardTrainingSAE,
+    StandardTrainingSAEConfig,
+)
+from .topk_sae import (
+    TopKSAE,
+    TopKSAEConfig,
+    TopKTrainingSAE,
+    TopKTrainingSAEConfig,
+)
+
+__all__ = [
+    "SAE",
+    "SAEConfig",
+    "TrainingSAE",
+    "TrainingSAEConfig",
+    "StandardSAE",
+    "StandardSAEConfig",
+    "StandardTrainingSAE",
+    "StandardTrainingSAEConfig",
+    "GatedSAE",
+    "GatedSAEConfig",
+    "GatedTrainingSAE",
+    "GatedTrainingSAEConfig",
+    "JumpReLUSAE",
+    "JumpReLUSAEConfig",
+    "JumpReLUTrainingSAE",
+    "JumpReLUTrainingSAEConfig",
+    "TopKSAE",
+    "TopKSAEConfig",
+    "TopKTrainingSAE",
+    "TopKTrainingSAEConfig",
+]
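For orientation, a hedged sketch of the resulting import surface: classes that previously lived in sae_lens/sae.py and sae_lens/training/training_sae.py (both removed in this release) are now re-exported per architecture from sae_lens.saes. Assumes sae_lens 6.0.0 is installed; the names come from the __all__ above:

from sae_lens.saes import (
    SAE,
    GatedTrainingSAE,
    JumpReLUSAE,
    StandardSAEConfig,
    TopKTrainingSAEConfig,
)

print(SAE, GatedTrainingSAE, JumpReLUSAE, StandardSAEConfig, TopKTrainingSAEConfig)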