sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +55 -18
- sae_lens/analysis/hooked_sae_transformer.py +10 -10
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +105 -235
- sae_lens/constants.py +20 -0
- sae_lens/evals.py +34 -31
- sae_lens/{sae_training_runner.py → llm_sae_training_runner.py} +103 -70
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +36 -10
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +70 -59
- sae_lens/saes/jumprelu_sae.py +58 -72
- sae_lens/saes/sae.py +248 -273
- sae_lens/saes/standard_sae.py +75 -57
- sae_lens/saes/topk_sae.py +72 -83
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +105 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +134 -158
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +11 -5
- sae_lens/util.py +47 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc3.dist-info/RECORD +38 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/WHEEL +1 -1
- sae_lens/regsitry.py +0 -34
- sae_lens-6.0.0rc1.dist-info/RECORD +0 -32
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/LICENSE +0 -0
|
@@ -7,11 +7,12 @@ import numpy as np
|
|
|
7
7
|
import torch
|
|
8
8
|
from huggingface_hub import hf_hub_download
|
|
9
9
|
from huggingface_hub.utils import EntryNotFoundError
|
|
10
|
+
from packaging.version import Version
|
|
10
11
|
from safetensors import safe_open
|
|
11
12
|
from safetensors.torch import load_file
|
|
12
13
|
|
|
13
14
|
from sae_lens import logger
|
|
14
|
-
from sae_lens.
|
|
15
|
+
from sae_lens.constants import (
|
|
15
16
|
DTYPE_MAP,
|
|
16
17
|
SAE_CFG_FILENAME,
|
|
17
18
|
SAE_WEIGHTS_FILENAME,
|
|
@@ -22,6 +23,8 @@ from sae_lens.loading.pretrained_saes_directory import (
|
|
|
22
23
|
get_pretrained_saes_directory,
|
|
23
24
|
get_repo_id_and_folder_name,
|
|
24
25
|
)
|
|
26
|
+
from sae_lens.registry import get_sae_class
|
|
27
|
+
from sae_lens.util import filter_valid_dataclass_fields
|
|
25
28
|
|
|
26
29
|
|
|
27
30
|
# loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
|
|
@@ -174,9 +177,22 @@ def get_sae_lens_config_from_disk(
|
|
|
174
177
|
|
|
175
178
|
|
|
176
179
|
def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
|
|
180
|
+
sae_lens_version = cfg_dict.get("sae_lens_version")
|
|
181
|
+
if not sae_lens_version and "metadata" in cfg_dict:
|
|
182
|
+
sae_lens_version = cfg_dict["metadata"].get("sae_lens_version")
|
|
183
|
+
|
|
184
|
+
if not sae_lens_version or Version(sae_lens_version) < Version("6.0.0-rc.0"):
|
|
185
|
+
cfg_dict = handle_pre_6_0_config(cfg_dict)
|
|
186
|
+
return cfg_dict
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
|
|
190
|
+
"""
|
|
191
|
+
Format a config dictionary for a Sparse Autoencoder (SAE) to be compatible with the new 6.0 format.
|
|
192
|
+
"""
|
|
193
|
+
|
|
177
194
|
rename_keys_map = {
|
|
178
195
|
"hook_point": "hook_name",
|
|
179
|
-
"hook_point_layer": "hook_layer",
|
|
180
196
|
"hook_point_head_index": "hook_head_index",
|
|
181
197
|
"activation_fn_str": "activation_fn",
|
|
182
198
|
}
|
|
@@ -202,10 +218,26 @@ def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
|
|
|
202
218
|
else "expected_average_only_in"
|
|
203
219
|
)
|
|
204
220
|
|
|
205
|
-
new_cfg.
|
|
221
|
+
if new_cfg.get("normalize_activations") is None:
|
|
222
|
+
new_cfg["normalize_activations"] = "none"
|
|
223
|
+
|
|
206
224
|
new_cfg.setdefault("device", "cpu")
|
|
207
225
|
|
|
208
|
-
|
|
226
|
+
architecture = new_cfg.get("architecture", "standard")
|
|
227
|
+
|
|
228
|
+
config_class = get_sae_class(architecture)[1]
|
|
229
|
+
|
|
230
|
+
sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
|
|
231
|
+
if architecture == "topk":
|
|
232
|
+
sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
|
|
233
|
+
|
|
234
|
+
# import here to avoid circular import
|
|
235
|
+
from sae_lens.saes.sae import SAEMetadata
|
|
236
|
+
|
|
237
|
+
meta_dict = filter_valid_dataclass_fields(new_cfg, SAEMetadata)
|
|
238
|
+
sae_cfg_dict["metadata"] = meta_dict
|
|
239
|
+
sae_cfg_dict["architecture"] = architecture
|
|
240
|
+
return sae_cfg_dict
|
|
209
241
|
|
|
210
242
|
|
|
211
243
|
def get_connor_rob_hook_z_config_from_hf(
|
|
@@ -229,7 +261,6 @@ def get_connor_rob_hook_z_config_from_hf(
|
|
|
229
261
|
"device": device if device is not None else "cpu",
|
|
230
262
|
"model_name": "gpt2-small",
|
|
231
263
|
"hook_name": old_cfg_dict["act_name"],
|
|
232
|
-
"hook_layer": old_cfg_dict["layer"],
|
|
233
264
|
"hook_head_index": None,
|
|
234
265
|
"activation_fn": "relu",
|
|
235
266
|
"apply_b_dec_to_input": True,
|
|
@@ -378,7 +409,6 @@ def get_gemma_2_config_from_hf(
|
|
|
378
409
|
"dtype": "float32",
|
|
379
410
|
"model_name": model_name,
|
|
380
411
|
"hook_name": hook_name,
|
|
381
|
-
"hook_layer": layer,
|
|
382
412
|
"hook_head_index": None,
|
|
383
413
|
"activation_fn": "relu",
|
|
384
414
|
"finetuning_scaling_factor": False,
|
|
@@ -491,7 +521,6 @@ def get_llama_scope_config_from_hf(
|
|
|
491
521
|
"dtype": "bfloat16",
|
|
492
522
|
"model_name": model_name,
|
|
493
523
|
"hook_name": old_cfg_dict["hook_point_in"],
|
|
494
|
-
"hook_layer": int(old_cfg_dict["hook_point_in"].split(".")[1]),
|
|
495
524
|
"hook_head_index": None,
|
|
496
525
|
"activation_fn": "relu",
|
|
497
526
|
"finetuning_scaling_factor": False,
|
|
@@ -618,7 +647,6 @@ def get_dictionary_learning_config_1_from_hf(
|
|
|
618
647
|
"device": device,
|
|
619
648
|
"model_name": trainer["lm_name"].split("/")[-1],
|
|
620
649
|
"hook_name": hook_point_name,
|
|
621
|
-
"hook_layer": trainer["layer"],
|
|
622
650
|
"hook_head_index": None,
|
|
623
651
|
"activation_fn": activation_fn,
|
|
624
652
|
"activation_fn_kwargs": activation_fn_kwargs,
|
|
@@ -657,7 +685,6 @@ def get_deepseek_r1_config_from_hf(
|
|
|
657
685
|
"context_size": 1024,
|
|
658
686
|
"model_name": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
|
|
659
687
|
"hook_name": f"blocks.{layer}.hook_resid_post",
|
|
660
|
-
"hook_layer": layer,
|
|
661
688
|
"hook_head_index": None,
|
|
662
689
|
"prepend_bos": True,
|
|
663
690
|
"dataset_path": "lmsys/lmsys-chat-1m",
|
|
@@ -816,7 +843,6 @@ def get_llama_scope_r1_distill_config_from_hf(
|
|
|
816
843
|
"device": device,
|
|
817
844
|
"model_name": model_name,
|
|
818
845
|
"hook_name": huggingface_cfg_dict["hook_point_in"],
|
|
819
|
-
"hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
|
|
820
846
|
"hook_head_index": None,
|
|
821
847
|
"activation_fn": "relu",
|
|
822
848
|
"finetuning_scaling_factor": False,
|
sae_lens/registry.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any
|
|
2
|
+
|
|
3
|
+
# avoid circular imports
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
|
|
6
|
+
|
|
7
|
+
SAE_CLASS_REGISTRY: dict[str, tuple["type[SAE[Any]]", "type[SAEConfig]"]] = {}
|
|
8
|
+
SAE_TRAINING_CLASS_REGISTRY: dict[
|
|
9
|
+
str, tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]
|
|
10
|
+
] = {}
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def register_sae_class(
|
|
14
|
+
architecture: str,
|
|
15
|
+
sae_class: "type[SAE[Any]]",
|
|
16
|
+
sae_config_class: "type[SAEConfig]",
|
|
17
|
+
) -> None:
|
|
18
|
+
if architecture in SAE_CLASS_REGISTRY:
|
|
19
|
+
raise ValueError(
|
|
20
|
+
f"SAE class for architecture {architecture} already registered."
|
|
21
|
+
)
|
|
22
|
+
SAE_CLASS_REGISTRY[architecture] = (sae_class, sae_config_class)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def register_sae_training_class(
|
|
26
|
+
architecture: str,
|
|
27
|
+
sae_training_class: "type[TrainingSAE[Any]]",
|
|
28
|
+
sae_training_config_class: "type[TrainingSAEConfig]",
|
|
29
|
+
) -> None:
|
|
30
|
+
if architecture in SAE_TRAINING_CLASS_REGISTRY:
|
|
31
|
+
raise ValueError(
|
|
32
|
+
f"SAE training class for architecture {architecture} already registered."
|
|
33
|
+
)
|
|
34
|
+
SAE_TRAINING_CLASS_REGISTRY[architecture] = (
|
|
35
|
+
sae_training_class,
|
|
36
|
+
sae_training_config_class,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def get_sae_class(
|
|
41
|
+
architecture: str,
|
|
42
|
+
) -> tuple["type[SAE[Any]]", "type[SAEConfig]"]:
|
|
43
|
+
return SAE_CLASS_REGISTRY[architecture]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_sae_training_class(
|
|
47
|
+
architecture: str,
|
|
48
|
+
) -> tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]:
|
|
49
|
+
return SAE_TRAINING_CLASS_REGISTRY[architecture]
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from .gated_sae import (
|
|
2
|
+
GatedSAE,
|
|
3
|
+
GatedSAEConfig,
|
|
4
|
+
GatedTrainingSAE,
|
|
5
|
+
GatedTrainingSAEConfig,
|
|
6
|
+
)
|
|
7
|
+
from .jumprelu_sae import (
|
|
8
|
+
JumpReLUSAE,
|
|
9
|
+
JumpReLUSAEConfig,
|
|
10
|
+
JumpReLUTrainingSAE,
|
|
11
|
+
JumpReLUTrainingSAEConfig,
|
|
12
|
+
)
|
|
13
|
+
from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
|
|
14
|
+
from .standard_sae import (
|
|
15
|
+
StandardSAE,
|
|
16
|
+
StandardSAEConfig,
|
|
17
|
+
StandardTrainingSAE,
|
|
18
|
+
StandardTrainingSAEConfig,
|
|
19
|
+
)
|
|
20
|
+
from .topk_sae import (
|
|
21
|
+
TopKSAE,
|
|
22
|
+
TopKSAEConfig,
|
|
23
|
+
TopKTrainingSAE,
|
|
24
|
+
TopKTrainingSAEConfig,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
__all__ = [
|
|
28
|
+
"SAE",
|
|
29
|
+
"SAEConfig",
|
|
30
|
+
"TrainingSAE",
|
|
31
|
+
"TrainingSAEConfig",
|
|
32
|
+
"StandardSAE",
|
|
33
|
+
"StandardSAEConfig",
|
|
34
|
+
"StandardTrainingSAE",
|
|
35
|
+
"StandardTrainingSAEConfig",
|
|
36
|
+
"GatedSAE",
|
|
37
|
+
"GatedSAEConfig",
|
|
38
|
+
"GatedTrainingSAE",
|
|
39
|
+
"GatedTrainingSAEConfig",
|
|
40
|
+
"JumpReLUSAE",
|
|
41
|
+
"JumpReLUSAEConfig",
|
|
42
|
+
"JumpReLUTrainingSAE",
|
|
43
|
+
"JumpReLUTrainingSAEConfig",
|
|
44
|
+
"TopKSAE",
|
|
45
|
+
"TopKSAEConfig",
|
|
46
|
+
"TopKTrainingSAE",
|
|
47
|
+
"TopKTrainingSAEConfig",
|
|
48
|
+
]
|
sae_lens/saes/gated_sae.py
CHANGED
|
@@ -1,20 +1,36 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
1
2
|
from typing import Any
|
|
2
3
|
|
|
3
4
|
import torch
|
|
4
5
|
from jaxtyping import Float
|
|
5
6
|
from numpy.typing import NDArray
|
|
6
7
|
from torch import nn
|
|
8
|
+
from typing_extensions import override
|
|
7
9
|
|
|
8
10
|
from sae_lens.saes.sae import (
|
|
9
11
|
SAE,
|
|
10
12
|
SAEConfig,
|
|
13
|
+
TrainCoefficientConfig,
|
|
11
14
|
TrainingSAE,
|
|
12
15
|
TrainingSAEConfig,
|
|
13
16
|
TrainStepInput,
|
|
14
17
|
)
|
|
18
|
+
from sae_lens.util import filter_valid_dataclass_fields
|
|
15
19
|
|
|
16
20
|
|
|
17
|
-
|
|
21
|
+
@dataclass
|
|
22
|
+
class GatedSAEConfig(SAEConfig):
|
|
23
|
+
"""
|
|
24
|
+
Configuration class for a GatedSAE.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
@override
|
|
28
|
+
@classmethod
|
|
29
|
+
def architecture(cls) -> str:
|
|
30
|
+
return "gated"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class GatedSAE(SAE[GatedSAEConfig]):
|
|
18
34
|
"""
|
|
19
35
|
GatedSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
|
|
20
36
|
using a gated linear encoder and a standard linear decoder.
|
|
@@ -24,48 +40,15 @@ class GatedSAE(SAE):
|
|
|
24
40
|
b_mag: nn.Parameter
|
|
25
41
|
r_mag: nn.Parameter
|
|
26
42
|
|
|
27
|
-
def __init__(self, cfg:
|
|
43
|
+
def __init__(self, cfg: GatedSAEConfig, use_error_term: bool = False):
|
|
28
44
|
super().__init__(cfg, use_error_term)
|
|
29
45
|
# Ensure b_enc does not exist for the gated architecture
|
|
30
46
|
self.b_enc = None
|
|
31
47
|
|
|
48
|
+
@override
|
|
32
49
|
def initialize_weights(self) -> None:
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
"""
|
|
36
|
-
# Use the same initialization methods and values as in original SAE
|
|
37
|
-
self.W_enc = nn.Parameter(
|
|
38
|
-
torch.nn.init.kaiming_uniform_(
|
|
39
|
-
torch.empty(
|
|
40
|
-
self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
|
|
41
|
-
)
|
|
42
|
-
)
|
|
43
|
-
)
|
|
44
|
-
self.b_gate = nn.Parameter(
|
|
45
|
-
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
|
|
46
|
-
)
|
|
47
|
-
# Ensure r_mag is initialized to zero as in original
|
|
48
|
-
self.r_mag = nn.Parameter(
|
|
49
|
-
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
|
|
50
|
-
)
|
|
51
|
-
self.b_mag = nn.Parameter(
|
|
52
|
-
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
|
|
53
|
-
)
|
|
54
|
-
|
|
55
|
-
# Decoder parameters with same initialization as original
|
|
56
|
-
self.W_dec = nn.Parameter(
|
|
57
|
-
torch.nn.init.kaiming_uniform_(
|
|
58
|
-
torch.empty(
|
|
59
|
-
self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
|
|
60
|
-
)
|
|
61
|
-
)
|
|
62
|
-
)
|
|
63
|
-
self.b_dec = nn.Parameter(
|
|
64
|
-
torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
|
|
65
|
-
)
|
|
66
|
-
|
|
67
|
-
# after defining b_gate, b_mag, etc.:
|
|
68
|
-
self.b_enc = None
|
|
50
|
+
super().initialize_weights()
|
|
51
|
+
_init_weights_gated(self)
|
|
69
52
|
|
|
70
53
|
def encode(
|
|
71
54
|
self, x: Float[torch.Tensor, "... d_in"]
|
|
@@ -101,9 +84,8 @@ class GatedSAE(SAE):
|
|
|
101
84
|
4) If the SAE was reshaping hook_z activations, reshape back.
|
|
102
85
|
"""
|
|
103
86
|
# 1) optional finetuning scaling
|
|
104
|
-
scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
|
|
105
87
|
# 2) linear transform
|
|
106
|
-
sae_out_pre =
|
|
88
|
+
sae_out_pre = feature_acts @ self.W_dec + self.b_dec
|
|
107
89
|
# 3) hooking and normalization
|
|
108
90
|
sae_out_pre = self.hook_sae_recons(sae_out_pre)
|
|
109
91
|
sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
|
|
@@ -129,7 +111,22 @@ class GatedSAE(SAE):
|
|
|
129
111
|
self.W_dec.data *= norm
|
|
130
112
|
|
|
131
113
|
|
|
132
|
-
|
|
114
|
+
@dataclass
|
|
115
|
+
class GatedTrainingSAEConfig(TrainingSAEConfig):
|
|
116
|
+
"""
|
|
117
|
+
Configuration class for training a GatedTrainingSAE.
|
|
118
|
+
"""
|
|
119
|
+
|
|
120
|
+
l1_coefficient: float = 1.0
|
|
121
|
+
l1_warm_up_steps: int = 0
|
|
122
|
+
|
|
123
|
+
@override
|
|
124
|
+
@classmethod
|
|
125
|
+
def architecture(cls) -> str:
|
|
126
|
+
return "gated"
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
|
|
133
130
|
"""
|
|
134
131
|
GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
|
|
135
132
|
It implements:
|
|
@@ -145,7 +142,7 @@ class GatedTrainingSAE(TrainingSAE):
|
|
|
145
142
|
b_mag: nn.Parameter # type: ignore
|
|
146
143
|
r_mag: nn.Parameter # type: ignore
|
|
147
144
|
|
|
148
|
-
def __init__(self, cfg:
|
|
145
|
+
def __init__(self, cfg: GatedTrainingSAEConfig, use_error_term: bool = False):
|
|
149
146
|
if use_error_term:
|
|
150
147
|
raise ValueError(
|
|
151
148
|
"GatedSAE does not support `use_error_term`. Please set `use_error_term=False`."
|
|
@@ -153,22 +150,8 @@ class GatedTrainingSAE(TrainingSAE):
|
|
|
153
150
|
super().__init__(cfg, use_error_term)
|
|
154
151
|
|
|
155
152
|
def initialize_weights(self) -> None:
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
# Additional training-specific logic, e.g. orthogonal init or heuristics:
|
|
160
|
-
if self.cfg.decoder_orthogonal_init:
|
|
161
|
-
self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
|
|
162
|
-
elif self.cfg.decoder_heuristic_init:
|
|
163
|
-
self.W_dec.data = torch.rand(
|
|
164
|
-
self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
|
|
165
|
-
)
|
|
166
|
-
self.initialize_decoder_norm_constant_norm()
|
|
167
|
-
if self.cfg.init_encoder_as_decoder_transpose:
|
|
168
|
-
self.W_enc.data = self.W_dec.data.T.clone().contiguous()
|
|
169
|
-
if self.cfg.normalize_sae_decoder:
|
|
170
|
-
with torch.no_grad():
|
|
171
|
-
self.set_decoder_norm_to_unit_norm()
|
|
153
|
+
super().initialize_weights()
|
|
154
|
+
_init_weights_gated(self)
|
|
172
155
|
|
|
173
156
|
def encode_with_hidden_pre(
|
|
174
157
|
self, x: Float[torch.Tensor, "... d_in"]
|
|
@@ -217,7 +200,7 @@ class GatedTrainingSAE(TrainingSAE):
|
|
|
217
200
|
|
|
218
201
|
# L1-like penalty scaled by W_dec norms
|
|
219
202
|
l1_loss = (
|
|
220
|
-
step_input.
|
|
203
|
+
step_input.coefficients["l1"]
|
|
221
204
|
* torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
|
|
222
205
|
)
|
|
223
206
|
|
|
@@ -245,3 +228,31 @@ class GatedTrainingSAE(TrainingSAE):
|
|
|
245
228
|
"""Initialize decoder with constant norm"""
|
|
246
229
|
self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
|
|
247
230
|
self.W_dec.data *= norm
|
|
231
|
+
|
|
232
|
+
def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
|
|
233
|
+
return {
|
|
234
|
+
"l1": TrainCoefficientConfig(
|
|
235
|
+
value=self.cfg.l1_coefficient,
|
|
236
|
+
warm_up_steps=self.cfg.l1_warm_up_steps,
|
|
237
|
+
),
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
def to_inference_config_dict(self) -> dict[str, Any]:
|
|
241
|
+
return filter_valid_dataclass_fields(
|
|
242
|
+
self.cfg.to_dict(), GatedSAEConfig, ["architecture"]
|
|
243
|
+
)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _init_weights_gated(
|
|
247
|
+
sae: SAE[GatedSAEConfig] | TrainingSAE[GatedTrainingSAEConfig],
|
|
248
|
+
) -> None:
|
|
249
|
+
sae.b_gate = nn.Parameter(
|
|
250
|
+
torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
|
|
251
|
+
)
|
|
252
|
+
# Ensure r_mag is initialized to zero as in original
|
|
253
|
+
sae.r_mag = nn.Parameter(
|
|
254
|
+
torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
|
|
255
|
+
)
|
|
256
|
+
sae.b_mag = nn.Parameter(
|
|
257
|
+
torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
|
|
258
|
+
)
|
sae_lens/saes/jumprelu_sae.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
1
2
|
from typing import Any
|
|
2
3
|
|
|
3
4
|
import numpy as np
|
|
@@ -9,11 +10,13 @@ from typing_extensions import override
|
|
|
9
10
|
from sae_lens.saes.sae import (
|
|
10
11
|
SAE,
|
|
11
12
|
SAEConfig,
|
|
13
|
+
TrainCoefficientConfig,
|
|
12
14
|
TrainingSAE,
|
|
13
15
|
TrainingSAEConfig,
|
|
14
16
|
TrainStepInput,
|
|
15
17
|
TrainStepOutput,
|
|
16
18
|
)
|
|
19
|
+
from sae_lens.util import filter_valid_dataclass_fields
|
|
17
20
|
|
|
18
21
|
|
|
19
22
|
def rectangle(x: torch.Tensor) -> torch.Tensor:
|
|
@@ -85,7 +88,19 @@ class JumpReLU(torch.autograd.Function):
|
|
|
85
88
|
return x_grad, threshold_grad, None
|
|
86
89
|
|
|
87
90
|
|
|
88
|
-
|
|
91
|
+
@dataclass
|
|
92
|
+
class JumpReLUSAEConfig(SAEConfig):
|
|
93
|
+
"""
|
|
94
|
+
Configuration class for a JumpReLUSAE.
|
|
95
|
+
"""
|
|
96
|
+
|
|
97
|
+
@override
|
|
98
|
+
@classmethod
|
|
99
|
+
def architecture(cls) -> str:
|
|
100
|
+
return "jumprelu"
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
|
|
89
104
|
"""
|
|
90
105
|
JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
|
|
91
106
|
using a JumpReLU activation. For each unit, if its pre-activation is
|
|
@@ -104,42 +119,18 @@ class JumpReLUSAE(SAE):
|
|
|
104
119
|
b_enc: nn.Parameter
|
|
105
120
|
threshold: nn.Parameter
|
|
106
121
|
|
|
107
|
-
def __init__(self, cfg:
|
|
122
|
+
def __init__(self, cfg: JumpReLUSAEConfig, use_error_term: bool = False):
|
|
108
123
|
super().__init__(cfg, use_error_term)
|
|
109
124
|
|
|
125
|
+
@override
|
|
110
126
|
def initialize_weights(self) -> None:
|
|
111
|
-
|
|
112
|
-
Initialize encoder and decoder weights, as well as biases.
|
|
113
|
-
Additionally, include a learnable `threshold` parameter that
|
|
114
|
-
determines when units "turn on" for the JumpReLU.
|
|
115
|
-
"""
|
|
116
|
-
# Biases
|
|
117
|
-
self.b_enc = nn.Parameter(
|
|
118
|
-
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
|
|
119
|
-
)
|
|
120
|
-
self.b_dec = nn.Parameter(
|
|
121
|
-
torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
|
|
122
|
-
)
|
|
123
|
-
|
|
124
|
-
# Threshold for JumpReLU
|
|
125
|
-
# You can pick a default initialization (e.g., zeros means unit is off unless hidden_pre > 0)
|
|
126
|
-
# or see the training version for more advanced init with log_threshold, etc.
|
|
127
|
+
super().initialize_weights()
|
|
127
128
|
self.threshold = nn.Parameter(
|
|
128
129
|
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
|
|
129
130
|
)
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
w_enc_data = torch.empty(
|
|
133
|
-
self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
|
|
134
|
-
)
|
|
135
|
-
nn.init.kaiming_uniform_(w_enc_data)
|
|
136
|
-
self.W_enc = nn.Parameter(w_enc_data)
|
|
137
|
-
|
|
138
|
-
w_dec_data = torch.empty(
|
|
139
|
-
self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
|
|
131
|
+
self.b_enc = nn.Parameter(
|
|
132
|
+
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
|
|
140
133
|
)
|
|
141
|
-
nn.init.kaiming_uniform_(w_dec_data)
|
|
142
|
-
self.W_dec = nn.Parameter(w_dec_data)
|
|
143
134
|
|
|
144
135
|
def encode(
|
|
145
136
|
self, x: Float[torch.Tensor, "... d_in"]
|
|
@@ -168,8 +159,7 @@ class JumpReLUSAE(SAE):
|
|
|
168
159
|
Decode the feature activations back to the input space.
|
|
169
160
|
Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
|
|
170
161
|
"""
|
|
171
|
-
|
|
172
|
-
sae_out_pre = scaled_features @ self.W_dec + self.b_dec
|
|
162
|
+
sae_out_pre = feature_acts @ self.W_dec + self.b_dec
|
|
173
163
|
sae_out_pre = self.hook_sae_recons(sae_out_pre)
|
|
174
164
|
sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
|
|
175
165
|
return self.reshape_fn_out(sae_out_pre, self.d_head)
|
|
@@ -195,7 +185,24 @@ class JumpReLUSAE(SAE):
|
|
|
195
185
|
self.threshold.data = current_thresh * W_dec_norms
|
|
196
186
|
|
|
197
187
|
|
|
198
|
-
|
|
188
|
+
@dataclass
|
|
189
|
+
class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
|
|
190
|
+
"""
|
|
191
|
+
Configuration class for training a JumpReLUTrainingSAE.
|
|
192
|
+
"""
|
|
193
|
+
|
|
194
|
+
jumprelu_init_threshold: float = 0.001
|
|
195
|
+
jumprelu_bandwidth: float = 0.001
|
|
196
|
+
l0_coefficient: float = 1.0
|
|
197
|
+
l0_warm_up_steps: int = 0
|
|
198
|
+
|
|
199
|
+
@override
|
|
200
|
+
@classmethod
|
|
201
|
+
def architecture(cls) -> str:
|
|
202
|
+
return "jumprelu"
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
|
|
199
206
|
"""
|
|
200
207
|
JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.
|
|
201
208
|
|
|
@@ -213,7 +220,7 @@ class JumpReLUTrainingSAE(TrainingSAE):
|
|
|
213
220
|
b_enc: nn.Parameter
|
|
214
221
|
log_threshold: nn.Parameter
|
|
215
222
|
|
|
216
|
-
def __init__(self, cfg:
|
|
223
|
+
def __init__(self, cfg: JumpReLUTrainingSAEConfig, use_error_term: bool = False):
|
|
217
224
|
super().__init__(cfg, use_error_term)
|
|
218
225
|
|
|
219
226
|
# We'll store a bandwidth for the training approach, if needed
|
|
@@ -225,51 +232,16 @@ class JumpReLUTrainingSAE(TrainingSAE):
|
|
|
225
232
|
* np.log(cfg.jumprelu_init_threshold)
|
|
226
233
|
)
|
|
227
234
|
|
|
235
|
+
@override
|
|
228
236
|
def initialize_weights(self) -> None:
|
|
229
237
|
"""
|
|
230
238
|
Initialize parameters like the base SAE, but also add log_threshold.
|
|
231
239
|
"""
|
|
240
|
+
super().initialize_weights()
|
|
232
241
|
# Encoder Bias
|
|
233
242
|
self.b_enc = nn.Parameter(
|
|
234
243
|
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
|
|
235
244
|
)
|
|
236
|
-
# Decoder Bias
|
|
237
|
-
self.b_dec = nn.Parameter(
|
|
238
|
-
torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
|
|
239
|
-
)
|
|
240
|
-
# W_enc
|
|
241
|
-
w_enc_data = torch.nn.init.kaiming_uniform_(
|
|
242
|
-
torch.empty(
|
|
243
|
-
self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
|
|
244
|
-
)
|
|
245
|
-
)
|
|
246
|
-
self.W_enc = nn.Parameter(w_enc_data)
|
|
247
|
-
|
|
248
|
-
# W_dec
|
|
249
|
-
w_dec_data = torch.nn.init.kaiming_uniform_(
|
|
250
|
-
torch.empty(
|
|
251
|
-
self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
|
|
252
|
-
)
|
|
253
|
-
)
|
|
254
|
-
self.W_dec = nn.Parameter(w_dec_data)
|
|
255
|
-
|
|
256
|
-
# Optionally apply orthogonal or heuristic init
|
|
257
|
-
if self.cfg.decoder_orthogonal_init:
|
|
258
|
-
self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
|
|
259
|
-
elif self.cfg.decoder_heuristic_init:
|
|
260
|
-
self.W_dec.data = torch.rand(
|
|
261
|
-
self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
|
|
262
|
-
)
|
|
263
|
-
self.initialize_decoder_norm_constant_norm()
|
|
264
|
-
|
|
265
|
-
# Optionally transpose
|
|
266
|
-
if self.cfg.init_encoder_as_decoder_transpose:
|
|
267
|
-
self.W_enc.data = self.W_dec.data.T.clone().contiguous()
|
|
268
|
-
|
|
269
|
-
# Optionally normalize columns of W_dec
|
|
270
|
-
if self.cfg.normalize_sae_decoder:
|
|
271
|
-
with torch.no_grad():
|
|
272
|
-
self.set_decoder_norm_to_unit_norm()
|
|
273
245
|
|
|
274
246
|
@property
|
|
275
247
|
def threshold(self) -> torch.Tensor:
|
|
@@ -305,9 +277,18 @@ class JumpReLUTrainingSAE(TrainingSAE):
|
|
|
305
277
|
) -> dict[str, torch.Tensor]:
|
|
306
278
|
"""Calculate architecture-specific auxiliary loss terms."""
|
|
307
279
|
l0 = torch.sum(Step.apply(hidden_pre, self.threshold, self.bandwidth), dim=-1) # type: ignore
|
|
308
|
-
l0_loss = (step_input.
|
|
280
|
+
l0_loss = (step_input.coefficients["l0"] * l0).mean()
|
|
309
281
|
return {"l0_loss": l0_loss}
|
|
310
282
|
|
|
283
|
+
@override
|
|
284
|
+
def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
|
|
285
|
+
return {
|
|
286
|
+
"l0": TrainCoefficientConfig(
|
|
287
|
+
value=self.cfg.l0_coefficient,
|
|
288
|
+
warm_up_steps=self.cfg.l0_warm_up_steps,
|
|
289
|
+
),
|
|
290
|
+
}
|
|
291
|
+
|
|
311
292
|
@torch.no_grad()
|
|
312
293
|
def fold_W_dec_norm(self):
|
|
313
294
|
"""
|
|
@@ -366,3 +347,8 @@ class JumpReLUTrainingSAE(TrainingSAE):
|
|
|
366
347
|
threshold = state_dict["threshold"]
|
|
367
348
|
del state_dict["threshold"]
|
|
368
349
|
state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
|
|
350
|
+
|
|
351
|
+
def to_inference_config_dict(self) -> dict[str, Any]:
|
|
352
|
+
return filter_valid_dataclass_fields(
|
|
353
|
+
self.cfg.to_dict(), JumpReLUSAEConfig, ["architecture"]
|
|
354
|
+
)
|