sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +60 -7
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +16 -14
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +170 -258
- sae_lens/constants.py +21 -0
- sae_lens/evals.py +59 -44
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +52 -4
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +254 -0
- sae_lens/saes/jumprelu_sae.py +348 -0
- sae_lens/saes/sae.py +1076 -0
- sae_lens/saes/standard_sae.py +178 -0
- sae_lens/saes/topk_sae.py +300 -0
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +103 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +155 -177
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +13 -7
- sae_lens/util.py +47 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
- sae_lens-6.0.0.dist-info/RECORD +37 -0
- sae_lens/sae.py +0 -747
- sae_lens/sae_training_runner.py +0 -251
- sae_lens/training/geometric_median.py +0 -101
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.11.0.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
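The file listing above amounts to a package restructuring: the single modules sae_lens/sae.py and sae_lens/training/training_sae.py are replaced by a sae_lens.saes subpackage with one module per architecture, and sae_lens.toolkit becomes sae_lens.loading. Below is a hedged before/after import sketch; the old 5.x paths are inferred from the removed files, and the exact top-level re-exports of the new sae_lens/__init__.py are not shown in this diff section.

# Old 5.11.0 import paths (per the removed/renamed files above):
#   from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
#   from sae_lens.sae import SAE
#
# New 6.0.0 paths (assumes sae-lens 6.0.0 is installed):
from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
from sae_lens.saes import SAE, SAEConfig, TrainingSAE

print(SAE.__module__)  # sae_lens.saes.sae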
sae_lens/load_model.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, Literal, cast
+from typing import Any, Callable, Literal, cast
 
 import torch
 from transformer_lens import HookedTransformer
@@ -77,6 +77,7 @@ class HookedProxyLM(HookedRootModule):
     # copied and modified from base HookedRootModule
     def setup(self):
         self.mod_dict = {}
+        self.named_modules_dict = {}
         self.hook_dict: dict[str, HookPoint] = {}
         for name, module in self.model.named_modules():
             if name == "":
@@ -89,14 +90,21 @@ class HookedProxyLM(HookedRootModule):
 
             self.hook_dict[name] = hook_point
             self.mod_dict[name] = hook_point
+            self.named_modules_dict[name] = module
+
+    def run_with_cache(self, *args: Any, **kwargs: Any):  # type: ignore
+        if "names_filter" in kwargs:
+            # hacky way to make sure that the names_filter is passed to our forward method
+            kwargs["_names_filter"] = kwargs["names_filter"]
+        return super().run_with_cache(*args, **kwargs)
 
     def forward(
         self,
         tokens: torch.Tensor,
         return_type: Literal["both", "logits"] = "logits",
         loss_per_token: bool = False,
-        # TODO: implement real support for stop_at_layer
        stop_at_layer: int | None = None,
+        _names_filter: list[str] | None = None,
        **kwargs: Any,
    ) -> Output | Loss:
        # This is just what's needed for evals, not everything that HookedTransformer has
@@ -107,8 +115,28 @@ class HookedProxyLM(HookedRootModule):
            raise NotImplementedError(
                "Only return_type supported is 'both' or 'logits' to match what's in evals.py and ActivationsStore"
            )
-
-
+
+        stop_hooks = []
+        if stop_at_layer is not None and _names_filter is not None:
+            if return_type != "logits":
+                raise NotImplementedError(
+                    "stop_at_layer is not supported for return_type='both'"
+                )
+            stop_manager = StopManager(_names_filter)
+
+            for hook_name in _names_filter:
+                module = self.named_modules_dict[hook_name]
+                stop_fn = stop_manager.get_stop_hook_fn(hook_name)
+                stop_hooks.append(module.register_forward_hook(stop_fn))
+        try:
+            output = self.model(tokens)
+            logits = _extract_logits_from_output(output)
+        except StopForward:
+            # If we stop early, we don't care about the return output
+            return None  # type: ignore
+        finally:
+            for stop_hook in stop_hooks:
+                stop_hook.remove()
 
        if return_type == "logits":
            return logits
@@ -183,3 +211,23 @@ def get_hook_fn(hook_point: HookPoint):
        return output
 
    return hook_fn
+
+
+class StopForward(Exception):
+    pass
+
+
+class StopManager:
+    def __init__(self, hook_names: list[str]):
+        self.hook_names = hook_names
+        self.total_hook_names = len(set(hook_names))
+        self.called_hook_names = set()
+
+    def get_stop_hook_fn(self, hook_name: str) -> Callable[[Any, Any, Any], Any]:
+        def stop_hook_fn(module: Any, input: Any, output: Any) -> Any:  # noqa: ARG001
+            self.called_hook_names.add(hook_name)
+            if len(self.called_hook_names) == self.total_hook_names:
+                raise StopForward()
+            return output
+
+        return stop_hook_fn
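The early-stopping mechanism added above relies only on standard PyTorch forward hooks, so its behaviour can be shown in isolation. The sketch below is a minimal, self-contained illustration of the same pattern; TinyModel and its layer names are made up for the example and are not part of sae-lens. Once the watched module has fired, a StopForward exception aborts the rest of the forward pass.

# Minimal, self-contained sketch of the StopForward / forward-hook pattern.
import torch
from torch import nn


class StopForward(Exception):
    pass


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(4, 4)
        self.layer2 = nn.Linear(4, 4)  # never reached once we stop at layer1

    def forward(self, x):
        return self.layer2(self.layer1(x))


model = TinyModel()
captured = {}


def stop_hook(module, inputs, output):
    captured["layer1"] = output
    raise StopForward()  # abort the remaining forward pass


handle = model.layer1.register_forward_hook(stop_hook)
try:
    model(torch.randn(2, 4))
except StopForward:
    pass  # expected: we only wanted layer1's output
finally:
    handle.remove()

print(captured["layer1"].shape)  # torch.Size([2, 4])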
sae_lens/{toolkit → loading}/pretrained_sae_loaders.py
CHANGED
@@ -7,22 +7,41 @@ import numpy as np
 import torch
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError
+from packaging.version import Version
 from safetensors import safe_open
 from safetensors.torch import load_file
 
 from sae_lens import logger
-from sae_lens.
+from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
     SPARSIFY_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
 )
-from sae_lens.
+from sae_lens.loading.pretrained_saes_directory import (
     get_config_overrides,
     get_pretrained_saes_directory,
     get_repo_id_and_folder_name,
 )
+from sae_lens.registry import get_sae_class
+from sae_lens.util import filter_valid_dataclass_fields
+
+LLM_METADATA_KEYS = {
+    "model_name",
+    "hook_name",
+    "model_class_name",
+    "hook_head_index",
+    "model_from_pretrained_kwargs",
+    "prepend_bos",
+    "exclude_special_tokens",
+    "neuronpedia_id",
+    "context_size",
+    "seqpos_slice",
+    "dataset_path",
+    "sae_lens_version",
+    "sae_lens_training_version",
+}
 
 
 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
@@ -175,30 +194,69 @@ def get_sae_lens_config_from_disk(
 
 
 def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    sae_lens_version = cfg_dict.get("sae_lens_version")
+    if not sae_lens_version and "metadata" in cfg_dict:
+        sae_lens_version = cfg_dict["metadata"].get("sae_lens_version")
+
+    if not sae_lens_version or Version(sae_lens_version) < Version("6.0.0-rc.0"):
+        cfg_dict = handle_pre_6_0_config(cfg_dict)
+    return cfg_dict
+
+
+def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    """
+    Format a config dictionary for a Sparse Autoencoder (SAE) to be compatible with the new 6.0 format.
+    """
+
+    rename_keys_map = {
+        "hook_point": "hook_name",
+        "hook_point_head_index": "hook_head_index",
+        "activation_fn_str": "activation_fn",
+    }
+    new_cfg = {rename_keys_map.get(k, k): v for k, v in cfg_dict.items()}
+
     # Set default values for backwards compatibility
-
-
-
-
-
-
-
-
-
-
-
+    new_cfg.setdefault("prepend_bos", True)
+    new_cfg.setdefault("dataset_trust_remote_code", True)
+    new_cfg.setdefault("apply_b_dec_to_input", True)
+    new_cfg.setdefault("finetuning_scaling_factor", False)
+    new_cfg.setdefault("sae_lens_training_version", None)
+    new_cfg.setdefault("activation_fn", new_cfg.get("activation_fn", "relu"))
+    new_cfg.setdefault("architecture", "standard")
+    new_cfg.setdefault("neuronpedia_id", None)
+    new_cfg.setdefault(
+        "reshape_activations",
+        "hook_z" if "hook_z" in new_cfg.get("hook_name", "") else "none",
+    )
+
+    if "normalize_activations" in new_cfg and isinstance(
+        new_cfg["normalize_activations"], bool
    ):
        # backwards compatibility
-
+        new_cfg["normalize_activations"] = (
            "none"
-            if not
+            if not new_cfg["normalize_activations"]
            else "expected_average_only_in"
        )
 
-
-
+    if new_cfg.get("normalize_activations") is None:
+        new_cfg["normalize_activations"] = "none"
 
-
+    new_cfg.setdefault("device", "cpu")
+
+    architecture = new_cfg.get("architecture", "standard")
+
+    config_class = get_sae_class(architecture)[1]
+
+    sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
+    if architecture == "topk" and "activation_fn_kwargs" in new_cfg:
+        sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
+
+    sae_cfg_dict["metadata"] = {
+        k: v for k, v in new_cfg.items() if k in LLM_METADATA_KEYS
+    }
+    sae_cfg_dict["architecture"] = architecture
+    return sae_cfg_dict
 
 
 def get_connor_rob_hook_z_config_from_hf(
@@ -222,9 +280,8 @@ def get_connor_rob_hook_z_config_from_hf(
         "device": device if device is not None else "cpu",
         "model_name": "gpt2-small",
         "hook_name": old_cfg_dict["act_name"],
-        "hook_layer": old_cfg_dict["layer"],
         "hook_head_index": None,
-        "
+        "activation_fn": "relu",
         "apply_b_dec_to_input": True,
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
@@ -233,6 +290,7 @@ def get_connor_rob_hook_z_config_from_hf(
         "context_size": 128,
         "normalize_activations": "none",
         "dataset_trust_remote_code": True,
+        "reshape_activations": "hook_z",
         **(cfg_overrides or {}),
     }
 
@@ -371,9 +429,8 @@ def get_gemma_2_config_from_hf(
         "dtype": "float32",
         "model_name": model_name,
         "hook_name": hook_name,
-        "hook_layer": layer,
         "hook_head_index": None,
-        "
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -493,9 +550,8 @@ def get_llama_scope_config_from_hf(
         "dtype": "bfloat16",
         "model_name": model_name,
         "hook_name": old_cfg_dict["hook_point_in"],
-        "hook_layer": int(old_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
-        "
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -607,8 +663,8 @@ def get_dictionary_learning_config_1_from_hf(
 
     hook_point_name = f"blocks.{trainer['layer']}.hook_resid_post"
 
-
-    activation_fn_kwargs = {"k": trainer["k"]} if
+    activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
+    activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}
 
     return {
         "architecture": (
@@ -620,9 +676,8 @@ def get_dictionary_learning_config_1_from_hf(
         "device": device,
         "model_name": trainer["lm_name"].split("/")[-1],
         "hook_name": hook_point_name,
-        "hook_layer": trainer["layer"],
         "hook_head_index": None,
-        "
+        "activation_fn": activation_fn,
         "activation_fn_kwargs": activation_fn_kwargs,
         "apply_b_dec_to_input": True,
         "finetuning_scaling_factor": False,
@@ -659,13 +714,12 @@ def get_deepseek_r1_config_from_hf(
         "context_size": 1024,
         "model_name": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         "hook_name": f"blocks.{layer}.hook_resid_post",
-        "hook_layer": layer,
         "hook_head_index": None,
         "prepend_bos": True,
         "dataset_path": "lmsys/lmsys-chat-1m",
         "dataset_trust_remote_code": True,
         "sae_lens_training_version": None,
-        "
+        "activation_fn": "relu",
         "normalize_activations": "none",
         "device": device,
         "apply_b_dec_to_input": False,
@@ -818,9 +872,8 @@ def get_llama_scope_r1_distill_config_from_hf(
         "device": device,
         "model_name": model_name,
         "hook_name": huggingface_cfg_dict["hook_point_in"],
-        "hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
-        "
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
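handle_pre_6_0_config above captures the migration pattern for old configs: rename legacy keys, back-fill defaults, then split the result into SAE config fields plus a metadata block. The self-contained sketch below mirrors only the rename/setdefault step on a plain dict; the legacy config values are invented for illustration and it does not import sae-lens.

# Minimal sketch of the pre-6.0 config migration pattern (rename + defaults).
legacy_cfg = {
    "hook_point": "blocks.8.hook_resid_pre",
    "hook_point_head_index": None,
    "activation_fn_str": "relu",
    "d_in": 768,
    "d_sae": 24576,
}

rename_keys_map = {
    "hook_point": "hook_name",
    "hook_point_head_index": "hook_head_index",
    "activation_fn_str": "activation_fn",
}
new_cfg = {rename_keys_map.get(k, k): v for k, v in legacy_cfg.items()}

# Back-fill defaults that pre-6.0 configs may be missing.
new_cfg.setdefault("architecture", "standard")
new_cfg.setdefault("normalize_activations", "none")
new_cfg.setdefault("device", "cpu")

print(new_cfg["hook_name"])     # blocks.8.hook_resid_pre
print(new_cfg["architecture"])  # standard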
sae_lens/registry.py
ADDED
@@ -0,0 +1,49 @@
+from typing import TYPE_CHECKING, Any
+
+# avoid circular imports
+if TYPE_CHECKING:
+    from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
+
+SAE_CLASS_REGISTRY: dict[str, tuple["type[SAE[Any]]", "type[SAEConfig]"]] = {}
+SAE_TRAINING_CLASS_REGISTRY: dict[
+    str, tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]
+] = {}
+
+
+def register_sae_class(
+    architecture: str,
+    sae_class: "type[SAE[Any]]",
+    sae_config_class: "type[SAEConfig]",
+) -> None:
+    if architecture in SAE_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE class for architecture {architecture} already registered."
+        )
+    SAE_CLASS_REGISTRY[architecture] = (sae_class, sae_config_class)
+
+
+def register_sae_training_class(
+    architecture: str,
+    sae_training_class: "type[TrainingSAE[Any]]",
+    sae_training_config_class: "type[TrainingSAEConfig]",
+) -> None:
+    if architecture in SAE_TRAINING_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE training class for architecture {architecture} already registered."
+        )
+    SAE_TRAINING_CLASS_REGISTRY[architecture] = (
+        sae_training_class,
+        sae_training_config_class,
+    )
+
+
+def get_sae_class(
+    architecture: str,
+) -> tuple["type[SAE[Any]]", "type[SAEConfig]"]:
+    return SAE_CLASS_REGISTRY[architecture]
+
+
+def get_sae_training_class(
+    architecture: str,
+) -> tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]:
+    return SAE_TRAINING_CLASS_REGISTRY[architecture]
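A hedged usage sketch of the new registry follows. It assumes sae-lens 6.0.0 is installed and that the built-in architectures register themselves when the package is imported, which this diff section does not show explicitly. Lookups are plain dict reads, so an unknown architecture name raises KeyError.

# Hedged sketch: look up the class/config pair for a registered architecture.
from sae_lens.registry import get_sae_class, get_sae_training_class

sae_cls, sae_cfg_cls = get_sae_class("gated")
train_cls, train_cfg_cls = get_sae_training_class("gated")
print(sae_cls.__name__, sae_cfg_cls.__name__)  # expected: GatedSAE GatedSAEConfig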
sae_lens/saes/__init__.py
ADDED
@@ -0,0 +1,48 @@
+from .gated_sae import (
+    GatedSAE,
+    GatedSAEConfig,
+    GatedTrainingSAE,
+    GatedTrainingSAEConfig,
+)
+from .jumprelu_sae import (
+    JumpReLUSAE,
+    JumpReLUSAEConfig,
+    JumpReLUTrainingSAE,
+    JumpReLUTrainingSAEConfig,
+)
+from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
+from .standard_sae import (
+    StandardSAE,
+    StandardSAEConfig,
+    StandardTrainingSAE,
+    StandardTrainingSAEConfig,
+)
+from .topk_sae import (
+    TopKSAE,
+    TopKSAEConfig,
+    TopKTrainingSAE,
+    TopKTrainingSAEConfig,
+)
+
+__all__ = [
+    "SAE",
+    "SAEConfig",
+    "TrainingSAE",
+    "TrainingSAEConfig",
+    "StandardSAE",
+    "StandardSAEConfig",
+    "StandardTrainingSAE",
+    "StandardTrainingSAEConfig",
+    "GatedSAE",
+    "GatedSAEConfig",
+    "GatedTrainingSAE",
+    "GatedTrainingSAEConfig",
+    "JumpReLUSAE",
+    "JumpReLUSAEConfig",
+    "JumpReLUTrainingSAE",
+    "JumpReLUTrainingSAEConfig",
+    "TopKSAE",
+    "TopKSAEConfig",
+    "TopKTrainingSAE",
+    "TopKTrainingSAEConfig",
+]
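For orientation, the short sketch below exercises only what this diff section shows: the new sae_lens.saes package re-exports each architecture's SAE and config classes, and GatedSAEConfig reports its architecture string. It assumes sae-lens 6.0.0 is installed.

# Assumes sae-lens 6.0.0 is installed; uses only names exported above.
from sae_lens.saes import SAE, GatedSAE, GatedSAEConfig

print(GatedSAEConfig.architecture())  # "gated"
print(issubclass(GatedSAE, SAE))      # True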
sae_lens/saes/gated_sae.py
ADDED
@@ -0,0 +1,254 @@
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+from jaxtyping import Float
+from numpy.typing import NDArray
+from torch import nn
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+)
+from sae_lens.util import filter_valid_dataclass_fields
+
+
+@dataclass
+class GatedSAEConfig(SAEConfig):
+    """
+    Configuration class for a GatedSAE.
+    """
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "gated"
+
+
+class GatedSAE(SAE[GatedSAEConfig]):
+    """
+    GatedSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
+    using a gated linear encoder and a standard linear decoder.
+    """
+
+    b_gate: nn.Parameter
+    b_mag: nn.Parameter
+    r_mag: nn.Parameter
+
+    def __init__(self, cfg: GatedSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+        # Ensure b_enc does not exist for the gated architecture
+        self.b_enc = None
+
+    @override
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        _init_weights_gated(self)
+
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        """
+        Encode the input tensor into the feature space using a gated encoder.
+        This must match the original encode_gated implementation from SAE class.
+        """
+        # Preprocess the SAE input (casting type, applying hooks, normalization)
+        sae_in = self.process_sae_in(x)
+
+        # Gating path exactly as in original SAE.encode_gated
+        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
+        active_features = (gating_pre_activation > 0).to(self.dtype)
+
+        # Magnitude path (weight sharing with gated encoder)
+        magnitude_pre_activation = self.hook_sae_acts_pre(
+            sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
+        )
+        feature_magnitudes = self.activation_fn(magnitude_pre_activation)
+
+        # Combine gating and magnitudes
+        return self.hook_sae_acts_post(active_features * feature_magnitudes)
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decode the feature activations back into the input space:
+        1) Apply optional finetuning scaling.
+        2) Linear transform plus bias.
+        3) Run any reconstruction hooks and out-normalization if configured.
+        4) If the SAE was reshaping hook_z activations, reshape back.
+        """
+        # 1) optional finetuning scaling
+        # 2) linear transform
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        # 3) hooking and normalization
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        # 4) reshape if needed (hook_z)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+    @torch.no_grad()
+    def fold_W_dec_norm(self):
+        """Override to handle gated-specific parameters."""
+        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        self.W_dec.data = self.W_dec.data / W_dec_norms
+        self.W_enc.data = self.W_enc.data * W_dec_norms.T
+
+        # Gated-specific parameters need special handling
+        self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
+        self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
+        self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
+
+    @torch.no_grad()
+    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+        """Initialize decoder with constant norm."""
+        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+        self.W_dec.data *= norm
+
+
+@dataclass
+class GatedTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a GatedTrainingSAE.
+    """
+
+    l1_coefficient: float = 1.0
+    l1_warm_up_steps: int = 0
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "gated"
+
+
+class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
+    """
+    GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
+    It implements:
+    - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
+    - encode: calls encode_with_hidden_pre (standard training approach).
+    - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
+    - encode_with_hidden_pre: gating logic + optional noise injection for training.
+    - calculate_aux_loss: includes an auxiliary reconstruction path and gating-based sparsity penalty.
+    - training_forward_pass: calls encode_with_hidden_pre, decode, and sums up MSE + gating losses.
+    """
+
+    b_gate: nn.Parameter  # type: ignore
+    b_mag: nn.Parameter  # type: ignore
+    r_mag: nn.Parameter  # type: ignore
+
+    def __init__(self, cfg: GatedTrainingSAEConfig, use_error_term: bool = False):
+        if use_error_term:
+            raise ValueError(
+                "GatedSAE does not support `use_error_term`. Please set `use_error_term=False`."
+            )
+        super().__init__(cfg, use_error_term)
+
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        _init_weights_gated(self)
+
+    def encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        """
+        Gated forward pass with pre-activation (for training).
+        We also inject noise if self.training is True.
+        """
+        sae_in = self.process_sae_in(x)
+
+        # Gating path
+        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
+        active_features = (gating_pre_activation > 0).to(self.dtype)
+
+        # Magnitude path
+        magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
+        magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)
+
+        feature_magnitudes = self.activation_fn(magnitude_pre_activation)
+
+        # Combine gating path and magnitude path
+        feature_acts = self.hook_sae_acts_post(active_features * feature_magnitudes)
+
+        # Return both the final feature activations and the pre-activation (for logging or penalty)
+        return feature_acts, magnitude_pre_activation
+
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        # Re-center the input if apply_b_dec_to_input is set
+        sae_in_centered = step_input.sae_in - (
+            self.b_dec * self.cfg.apply_b_dec_to_input
+        )
+
+        # The gating pre-activation (pi_gate) for the auxiliary path
+        pi_gate = sae_in_centered @ self.W_enc + self.b_gate
+        pi_gate_act = torch.relu(pi_gate)
+
+        # L1-like penalty scaled by W_dec norms
+        l1_loss = (
+            step_input.coefficients["l1"]
+            * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
+        )
+
+        # Aux reconstruction: reconstruct x purely from gating path
+        via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
+        aux_recon_loss = (
+            (via_gate_reconstruction - step_input.sae_in).pow(2).sum(dim=-1).mean()
+        )
+
+        # Return both losses separately
+        return {"l1_loss": l1_loss, "auxiliary_reconstruction_loss": aux_recon_loss}
+
+    def log_histograms(self) -> dict[str, NDArray[Any]]:
+        """Log histograms of the weights and biases."""
+        b_gate_dist = self.b_gate.detach().float().cpu().numpy()
+        b_mag_dist = self.b_mag.detach().float().cpu().numpy()
+        return {
+            **super().log_histograms(),
+            "weights/b_gate": b_gate_dist,
+            "weights/b_mag": b_mag_dist,
+        }
+
+    @torch.no_grad()
+    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
+        """Initialize decoder with constant norm"""
+        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
+        self.W_dec.data *= norm
+
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+        return {
+            "l1": TrainCoefficientConfig(
+                value=self.cfg.l1_coefficient,
+                warm_up_steps=self.cfg.l1_warm_up_steps,
+            ),
+        }
+
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), GatedSAEConfig, ["architecture"]
+        )
+
+
+def _init_weights_gated(
+    sae: SAE[GatedSAEConfig] | TrainingSAE[GatedTrainingSAEConfig],
+) -> None:
+    sae.b_gate = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
+    # Ensure r_mag is initialized to zero as in original
+    sae.r_mag = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
+    sae.b_mag = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
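The gated encoder above combines two linear paths that share W_enc: a binary gate decides which features are active, and a magnitude path rescaled per feature by exp(r_mag) decides how strong they are. The standalone sketch below reproduces just that arithmetic with plain tensors; the dimensions and random weights are made up for the example, and the SAE hooks and normalization are omitted.

# Standalone illustration of the gated encode/decode arithmetic (no sae-lens dependency).
import torch

d_in, d_sae = 16, 64
x = torch.randn(8, d_in)

W_enc = torch.randn(d_in, d_sae) * 0.1
W_dec = torch.randn(d_sae, d_in) * 0.1
b_gate = torch.zeros(d_sae)
b_mag = torch.zeros(d_sae)
r_mag = torch.zeros(d_sae)
b_dec = torch.zeros(d_in)

# Gating path: decides which features are on.
gate_pre = x @ W_enc + b_gate
active = (gate_pre > 0).float()

# Magnitude path: shares W_enc, rescaled per-feature by exp(r_mag).
mag_pre = x @ (W_enc * r_mag.exp()) + b_mag
magnitudes = torch.relu(mag_pre)

feature_acts = active * magnitudes            # gated feature activations
reconstruction = feature_acts @ W_dec + b_dec  # decoder: linear map plus bias

print(feature_acts.shape, reconstruction.shape)  # torch.Size([8, 64]) torch.Size([8, 16])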