sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +50 -16
- sae_lens/analysis/hooked_sae_transformer.py +10 -10
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +2 -1
- sae_lens/config.py +59 -231
- sae_lens/constants.py +18 -0
- sae_lens/evals.py +16 -13
- sae_lens/loading/pretrained_sae_loaders.py +36 -3
- sae_lens/registry.py +49 -0
- sae_lens/sae_training_runner.py +22 -21
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +70 -59
- sae_lens/saes/jumprelu_sae.py +58 -72
- sae_lens/saes/sae.py +250 -272
- sae_lens/saes/standard_sae.py +75 -57
- sae_lens/saes/topk_sae.py +72 -83
- sae_lens/training/activations_store.py +31 -15
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +44 -69
- sae_lens/training/upload_saes_to_huggingface.py +11 -5
- sae_lens/util.py +28 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc2.dist-info/RECORD +35 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/WHEEL +1 -1
- sae_lens/regsitry.py +0 -34
- sae_lens-6.0.0rc1.dist-info/RECORD +0 -32
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/LICENSE +0 -0
sae_lens/registry.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any
|
|
2
|
+
|
|
3
|
+
# avoid circular imports
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
|
|
6
|
+
|
|
7
|
+
SAE_CLASS_REGISTRY: dict[str, tuple["type[SAE[Any]]", "type[SAEConfig]"]] = {}
|
|
8
|
+
SAE_TRAINING_CLASS_REGISTRY: dict[
|
|
9
|
+
str, tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]
|
|
10
|
+
] = {}
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def register_sae_class(
|
|
14
|
+
architecture: str,
|
|
15
|
+
sae_class: "type[SAE[Any]]",
|
|
16
|
+
sae_config_class: "type[SAEConfig]",
|
|
17
|
+
) -> None:
|
|
18
|
+
if architecture in SAE_CLASS_REGISTRY:
|
|
19
|
+
raise ValueError(
|
|
20
|
+
f"SAE class for architecture {architecture} already registered."
|
|
21
|
+
)
|
|
22
|
+
SAE_CLASS_REGISTRY[architecture] = (sae_class, sae_config_class)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def register_sae_training_class(
|
|
26
|
+
architecture: str,
|
|
27
|
+
sae_training_class: "type[TrainingSAE[Any]]",
|
|
28
|
+
sae_training_config_class: "type[TrainingSAEConfig]",
|
|
29
|
+
) -> None:
|
|
30
|
+
if architecture in SAE_TRAINING_CLASS_REGISTRY:
|
|
31
|
+
raise ValueError(
|
|
32
|
+
f"SAE training class for architecture {architecture} already registered."
|
|
33
|
+
)
|
|
34
|
+
SAE_TRAINING_CLASS_REGISTRY[architecture] = (
|
|
35
|
+
sae_training_class,
|
|
36
|
+
sae_training_config_class,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def get_sae_class(
|
|
41
|
+
architecture: str,
|
|
42
|
+
) -> tuple["type[SAE[Any]]", "type[SAEConfig]"]:
|
|
43
|
+
return SAE_CLASS_REGISTRY[architecture]
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def get_sae_training_class(
|
|
47
|
+
architecture: str,
|
|
48
|
+
) -> tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]:
|
|
49
|
+
return SAE_TRAINING_CLASS_REGISTRY[architecture]
|
sae_lens/sae_training_runner.py
CHANGED
|
@@ -7,13 +7,15 @@ from typing import Any, cast
|
|
|
7
7
|
|
|
8
8
|
import torch
|
|
9
9
|
import wandb
|
|
10
|
+
from safetensors.torch import save_file
|
|
10
11
|
from simple_parsing import ArgumentParser
|
|
11
12
|
from transformer_lens.hook_points import HookedRootModule
|
|
12
13
|
|
|
13
14
|
from sae_lens import logger
|
|
14
15
|
from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
|
|
16
|
+
from sae_lens.constants import RUNNER_CFG_FILENAME, SPARSITY_FILENAME
|
|
15
17
|
from sae_lens.load_model import load_model
|
|
16
|
-
from sae_lens.saes.sae import TrainingSAE, TrainingSAEConfig
|
|
18
|
+
from sae_lens.saes.sae import T_TRAINING_SAE_CONFIG, TrainingSAE, TrainingSAEConfig
|
|
17
19
|
from sae_lens.training.activations_store import ActivationsStore
|
|
18
20
|
from sae_lens.training.geometric_median import compute_geometric_median
|
|
19
21
|
from sae_lens.training.sae_trainer import SAETrainer
|
|
@@ -32,17 +34,17 @@ class SAETrainingRunner:
|
|
|
32
34
|
Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
|
|
33
35
|
"""
|
|
34
36
|
|
|
35
|
-
cfg: LanguageModelSAERunnerConfig
|
|
37
|
+
cfg: LanguageModelSAERunnerConfig[Any]
|
|
36
38
|
model: HookedRootModule
|
|
37
|
-
sae: TrainingSAE
|
|
39
|
+
sae: TrainingSAE[Any]
|
|
38
40
|
activations_store: ActivationsStore
|
|
39
41
|
|
|
40
42
|
def __init__(
|
|
41
43
|
self,
|
|
42
|
-
cfg: LanguageModelSAERunnerConfig,
|
|
44
|
+
cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
|
|
43
45
|
override_dataset: HfDataset | None = None,
|
|
44
46
|
override_model: HookedRootModule | None = None,
|
|
45
|
-
override_sae: TrainingSAE | None = None,
|
|
47
|
+
override_sae: TrainingSAE[Any] | None = None,
|
|
46
48
|
):
|
|
47
49
|
if override_dataset is not None:
|
|
48
50
|
logger.warning(
|
|
@@ -141,7 +143,9 @@ class SAETrainingRunner:
|
|
|
141
143
|
backend=backend,
|
|
142
144
|
) # type: ignore
|
|
143
145
|
|
|
144
|
-
def run_trainer_with_interruption_handling(
|
|
146
|
+
def run_trainer_with_interruption_handling(
|
|
147
|
+
self, trainer: SAETrainer[TrainingSAE[TrainingSAEConfig], TrainingSAEConfig]
|
|
148
|
+
):
|
|
145
149
|
try:
|
|
146
150
|
# signal handlers (if preempted)
|
|
147
151
|
signal.signal(signal.SIGINT, interrupt_callback)
|
|
@@ -167,7 +171,7 @@ class SAETrainingRunner:
|
|
|
167
171
|
extract all activations at a certain layer and use for sae b_dec initialization
|
|
168
172
|
"""
|
|
169
173
|
|
|
170
|
-
if self.cfg.b_dec_init_method == "geometric_median":
|
|
174
|
+
if self.cfg.sae.b_dec_init_method == "geometric_median":
|
|
171
175
|
self.activations_store.set_norm_scaling_factor_if_needed()
|
|
172
176
|
layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
|
|
173
177
|
# get geometric median of the activations if we're using those.
|
|
@@ -176,14 +180,14 @@ class SAETrainingRunner:
|
|
|
176
180
|
maxiter=100,
|
|
177
181
|
).median
|
|
178
182
|
self.sae.initialize_b_dec_with_precalculated(median)
|
|
179
|
-
elif self.cfg.b_dec_init_method == "mean":
|
|
183
|
+
elif self.cfg.sae.b_dec_init_method == "mean":
|
|
180
184
|
self.activations_store.set_norm_scaling_factor_if_needed()
|
|
181
185
|
layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
|
|
182
186
|
self.sae.initialize_b_dec_with_mean(layer_acts) # type: ignore
|
|
183
187
|
|
|
184
188
|
@staticmethod
|
|
185
189
|
def save_checkpoint(
|
|
186
|
-
trainer: SAETrainer,
|
|
190
|
+
trainer: SAETrainer[TrainingSAE[Any], Any],
|
|
187
191
|
checkpoint_name: str,
|
|
188
192
|
wandb_aliases: list[str] | None = None,
|
|
189
193
|
) -> None:
|
|
@@ -194,19 +198,14 @@ class SAETrainingRunner:
|
|
|
194
198
|
str(base_path / "activations_store_state.safetensors")
|
|
195
199
|
)
|
|
196
200
|
|
|
197
|
-
|
|
198
|
-
trainer.sae.set_decoder_norm_to_unit_norm()
|
|
201
|
+
weights_path, cfg_path = trainer.sae.save_model(str(base_path))
|
|
199
202
|
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
trainer.log_feature_sparsity,
|
|
203
|
-
)
|
|
203
|
+
sparsity_path = base_path / SPARSITY_FILENAME
|
|
204
|
+
save_file({"sparsity": trainer.log_feature_sparsity}, sparsity_path)
|
|
204
205
|
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
with open(cfg_path, "w") as f:
|
|
209
|
-
json.dump(config, f)
|
|
206
|
+
runner_config = trainer.cfg.to_dict()
|
|
207
|
+
with open(base_path / RUNNER_CFG_FILENAME, "w") as f:
|
|
208
|
+
json.dump(runner_config, f)
|
|
210
209
|
|
|
211
210
|
if trainer.cfg.logger.log_to_wandb:
|
|
212
211
|
trainer.cfg.logger.log(
|
|
@@ -218,7 +217,9 @@ class SAETrainingRunner:
|
|
|
218
217
|
)
|
|
219
218
|
|
|
220
219
|
|
|
221
|
-
def _parse_cfg_args(
|
|
220
|
+
def _parse_cfg_args(
|
|
221
|
+
args: Sequence[str],
|
|
222
|
+
) -> LanguageModelSAERunnerConfig[TrainingSAEConfig]:
|
|
222
223
|
if len(args) == 0:
|
|
223
224
|
args = ["--help"]
|
|
224
225
|
parser = ArgumentParser(exit_on_error=False)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from .gated_sae import (
|
|
2
|
+
GatedSAE,
|
|
3
|
+
GatedSAEConfig,
|
|
4
|
+
GatedTrainingSAE,
|
|
5
|
+
GatedTrainingSAEConfig,
|
|
6
|
+
)
|
|
7
|
+
from .jumprelu_sae import (
|
|
8
|
+
JumpReLUSAE,
|
|
9
|
+
JumpReLUSAEConfig,
|
|
10
|
+
JumpReLUTrainingSAE,
|
|
11
|
+
JumpReLUTrainingSAEConfig,
|
|
12
|
+
)
|
|
13
|
+
from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
|
|
14
|
+
from .standard_sae import (
|
|
15
|
+
StandardSAE,
|
|
16
|
+
StandardSAEConfig,
|
|
17
|
+
StandardTrainingSAE,
|
|
18
|
+
StandardTrainingSAEConfig,
|
|
19
|
+
)
|
|
20
|
+
from .topk_sae import (
|
|
21
|
+
TopKSAE,
|
|
22
|
+
TopKSAEConfig,
|
|
23
|
+
TopKTrainingSAE,
|
|
24
|
+
TopKTrainingSAEConfig,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
__all__ = [
|
|
28
|
+
"SAE",
|
|
29
|
+
"SAEConfig",
|
|
30
|
+
"TrainingSAE",
|
|
31
|
+
"TrainingSAEConfig",
|
|
32
|
+
"StandardSAE",
|
|
33
|
+
"StandardSAEConfig",
|
|
34
|
+
"StandardTrainingSAE",
|
|
35
|
+
"StandardTrainingSAEConfig",
|
|
36
|
+
"GatedSAE",
|
|
37
|
+
"GatedSAEConfig",
|
|
38
|
+
"GatedTrainingSAE",
|
|
39
|
+
"GatedTrainingSAEConfig",
|
|
40
|
+
"JumpReLUSAE",
|
|
41
|
+
"JumpReLUSAEConfig",
|
|
42
|
+
"JumpReLUTrainingSAE",
|
|
43
|
+
"JumpReLUTrainingSAEConfig",
|
|
44
|
+
"TopKSAE",
|
|
45
|
+
"TopKSAEConfig",
|
|
46
|
+
"TopKTrainingSAE",
|
|
47
|
+
"TopKTrainingSAEConfig",
|
|
48
|
+
]
|
sae_lens/saes/gated_sae.py
CHANGED
|
@@ -1,20 +1,36 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
1
2
|
from typing import Any
|
|
2
3
|
|
|
3
4
|
import torch
|
|
4
5
|
from jaxtyping import Float
|
|
5
6
|
from numpy.typing import NDArray
|
|
6
7
|
from torch import nn
|
|
8
|
+
from typing_extensions import override
|
|
7
9
|
|
|
8
10
|
from sae_lens.saes.sae import (
|
|
9
11
|
SAE,
|
|
10
12
|
SAEConfig,
|
|
13
|
+
TrainCoefficientConfig,
|
|
11
14
|
TrainingSAE,
|
|
12
15
|
TrainingSAEConfig,
|
|
13
16
|
TrainStepInput,
|
|
14
17
|
)
|
|
18
|
+
from sae_lens.util import filter_valid_dataclass_fields
|
|
15
19
|
|
|
16
20
|
|
|
17
|
-
|
|
21
|
+
@dataclass
|
|
22
|
+
class GatedSAEConfig(SAEConfig):
|
|
23
|
+
"""
|
|
24
|
+
Configuration class for a GatedSAE.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
@override
|
|
28
|
+
@classmethod
|
|
29
|
+
def architecture(cls) -> str:
|
|
30
|
+
return "gated"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class GatedSAE(SAE[GatedSAEConfig]):
|
|
18
34
|
"""
|
|
19
35
|
GatedSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
|
|
20
36
|
using a gated linear encoder and a standard linear decoder.
|
|
@@ -24,48 +40,15 @@ class GatedSAE(SAE):
|
|
|
24
40
|
b_mag: nn.Parameter
|
|
25
41
|
r_mag: nn.Parameter
|
|
26
42
|
|
|
27
|
-
def __init__(self, cfg:
|
|
43
|
+
def __init__(self, cfg: GatedSAEConfig, use_error_term: bool = False):
|
|
28
44
|
super().__init__(cfg, use_error_term)
|
|
29
45
|
# Ensure b_enc does not exist for the gated architecture
|
|
30
46
|
self.b_enc = None
|
|
31
47
|
|
|
48
|
+
@override
|
|
32
49
|
def initialize_weights(self) -> None:
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
"""
|
|
36
|
-
# Use the same initialization methods and values as in original SAE
|
|
37
|
-
self.W_enc = nn.Parameter(
|
|
38
|
-
torch.nn.init.kaiming_uniform_(
|
|
39
|
-
torch.empty(
|
|
40
|
-
self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
|
|
41
|
-
)
|
|
42
|
-
)
|
|
43
|
-
)
|
|
44
|
-
self.b_gate = nn.Parameter(
|
|
45
|
-
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
|
|
46
|
-
)
|
|
47
|
-
# Ensure r_mag is initialized to zero as in original
|
|
48
|
-
self.r_mag = nn.Parameter(
|
|
49
|
-
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
|
|
50
|
-
)
|
|
51
|
-
self.b_mag = nn.Parameter(
|
|
52
|
-
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
|
|
53
|
-
)
|
|
54
|
-
|
|
55
|
-
# Decoder parameters with same initialization as original
|
|
56
|
-
self.W_dec = nn.Parameter(
|
|
57
|
-
torch.nn.init.kaiming_uniform_(
|
|
58
|
-
torch.empty(
|
|
59
|
-
self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
|
|
60
|
-
)
|
|
61
|
-
)
|
|
62
|
-
)
|
|
63
|
-
self.b_dec = nn.Parameter(
|
|
64
|
-
torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
|
|
65
|
-
)
|
|
66
|
-
|
|
67
|
-
# after defining b_gate, b_mag, etc.:
|
|
68
|
-
self.b_enc = None
|
|
50
|
+
super().initialize_weights()
|
|
51
|
+
_init_weights_gated(self)
|
|
69
52
|
|
|
70
53
|
def encode(
|
|
71
54
|
self, x: Float[torch.Tensor, "... d_in"]
|
|
@@ -101,9 +84,8 @@ class GatedSAE(SAE):
|
|
|
101
84
|
4) If the SAE was reshaping hook_z activations, reshape back.
|
|
102
85
|
"""
|
|
103
86
|
# 1) optional finetuning scaling
|
|
104
|
-
scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
|
|
105
87
|
# 2) linear transform
|
|
106
|
-
sae_out_pre =
|
|
88
|
+
sae_out_pre = feature_acts @ self.W_dec + self.b_dec
|
|
107
89
|
# 3) hooking and normalization
|
|
108
90
|
sae_out_pre = self.hook_sae_recons(sae_out_pre)
|
|
109
91
|
sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
|
|
@@ -129,7 +111,22 @@ class GatedSAE(SAE):
|
|
|
129
111
|
self.W_dec.data *= norm
|
|
130
112
|
|
|
131
113
|
|
|
132
|
-
|
|
114
|
+
@dataclass
|
|
115
|
+
class GatedTrainingSAEConfig(TrainingSAEConfig):
|
|
116
|
+
"""
|
|
117
|
+
Configuration class for training a GatedTrainingSAE.
|
|
118
|
+
"""
|
|
119
|
+
|
|
120
|
+
l1_coefficient: float = 1.0
|
|
121
|
+
l1_warm_up_steps: int = 0
|
|
122
|
+
|
|
123
|
+
@override
|
|
124
|
+
@classmethod
|
|
125
|
+
def architecture(cls) -> str:
|
|
126
|
+
return "gated"
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
|
|
133
130
|
"""
|
|
134
131
|
GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
|
|
135
132
|
It implements:
|
|
@@ -145,7 +142,7 @@ class GatedTrainingSAE(TrainingSAE):
|
|
|
145
142
|
b_mag: nn.Parameter # type: ignore
|
|
146
143
|
r_mag: nn.Parameter # type: ignore
|
|
147
144
|
|
|
148
|
-
def __init__(self, cfg:
|
|
145
|
+
def __init__(self, cfg: GatedTrainingSAEConfig, use_error_term: bool = False):
|
|
149
146
|
if use_error_term:
|
|
150
147
|
raise ValueError(
|
|
151
148
|
"GatedSAE does not support `use_error_term`. Please set `use_error_term=False`."
|
|
@@ -153,22 +150,8 @@ class GatedTrainingSAE(TrainingSAE):
|
|
|
153
150
|
super().__init__(cfg, use_error_term)
|
|
154
151
|
|
|
155
152
|
def initialize_weights(self) -> None:
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
# Additional training-specific logic, e.g. orthogonal init or heuristics:
|
|
160
|
-
if self.cfg.decoder_orthogonal_init:
|
|
161
|
-
self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
|
|
162
|
-
elif self.cfg.decoder_heuristic_init:
|
|
163
|
-
self.W_dec.data = torch.rand(
|
|
164
|
-
self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
|
|
165
|
-
)
|
|
166
|
-
self.initialize_decoder_norm_constant_norm()
|
|
167
|
-
if self.cfg.init_encoder_as_decoder_transpose:
|
|
168
|
-
self.W_enc.data = self.W_dec.data.T.clone().contiguous()
|
|
169
|
-
if self.cfg.normalize_sae_decoder:
|
|
170
|
-
with torch.no_grad():
|
|
171
|
-
self.set_decoder_norm_to_unit_norm()
|
|
153
|
+
super().initialize_weights()
|
|
154
|
+
_init_weights_gated(self)
|
|
172
155
|
|
|
173
156
|
def encode_with_hidden_pre(
|
|
174
157
|
self, x: Float[torch.Tensor, "... d_in"]
|
|
@@ -217,7 +200,7 @@ class GatedTrainingSAE(TrainingSAE):
|
|
|
217
200
|
|
|
218
201
|
# L1-like penalty scaled by W_dec norms
|
|
219
202
|
l1_loss = (
|
|
220
|
-
step_input.
|
|
203
|
+
step_input.coefficients["l1"]
|
|
221
204
|
* torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
|
|
222
205
|
)
|
|
223
206
|
|
|
@@ -245,3 +228,31 @@ class GatedTrainingSAE(TrainingSAE):
|
|
|
245
228
|
"""Initialize decoder with constant norm"""
|
|
246
229
|
self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
|
|
247
230
|
self.W_dec.data *= norm
|
|
231
|
+
|
|
232
|
+
def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
|
|
233
|
+
return {
|
|
234
|
+
"l1": TrainCoefficientConfig(
|
|
235
|
+
value=self.cfg.l1_coefficient,
|
|
236
|
+
warm_up_steps=self.cfg.l1_warm_up_steps,
|
|
237
|
+
),
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
def to_inference_config_dict(self) -> dict[str, Any]:
|
|
241
|
+
return filter_valid_dataclass_fields(
|
|
242
|
+
self.cfg.to_dict(), GatedSAEConfig, ["architecture"]
|
|
243
|
+
)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _init_weights_gated(
|
|
247
|
+
sae: SAE[GatedSAEConfig] | TrainingSAE[GatedTrainingSAEConfig],
|
|
248
|
+
) -> None:
|
|
249
|
+
sae.b_gate = nn.Parameter(
|
|
250
|
+
torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
|
|
251
|
+
)
|
|
252
|
+
# Ensure r_mag is initialized to zero as in original
|
|
253
|
+
sae.r_mag = nn.Parameter(
|
|
254
|
+
torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
|
|
255
|
+
)
|
|
256
|
+
sae.b_mag = nn.Parameter(
|
|
257
|
+
torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
|
|
258
|
+
)
|
sae_lens/saes/jumprelu_sae.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
1
2
|
from typing import Any
|
|
2
3
|
|
|
3
4
|
import numpy as np
|
|
@@ -9,11 +10,13 @@ from typing_extensions import override
|
|
|
9
10
|
from sae_lens.saes.sae import (
|
|
10
11
|
SAE,
|
|
11
12
|
SAEConfig,
|
|
13
|
+
TrainCoefficientConfig,
|
|
12
14
|
TrainingSAE,
|
|
13
15
|
TrainingSAEConfig,
|
|
14
16
|
TrainStepInput,
|
|
15
17
|
TrainStepOutput,
|
|
16
18
|
)
|
|
19
|
+
from sae_lens.util import filter_valid_dataclass_fields
|
|
17
20
|
|
|
18
21
|
|
|
19
22
|
def rectangle(x: torch.Tensor) -> torch.Tensor:
|
|
@@ -85,7 +88,19 @@ class JumpReLU(torch.autograd.Function):
|
|
|
85
88
|
return x_grad, threshold_grad, None
|
|
86
89
|
|
|
87
90
|
|
|
88
|
-
|
|
91
|
+
@dataclass
|
|
92
|
+
class JumpReLUSAEConfig(SAEConfig):
|
|
93
|
+
"""
|
|
94
|
+
Configuration class for a JumpReLUSAE.
|
|
95
|
+
"""
|
|
96
|
+
|
|
97
|
+
@override
|
|
98
|
+
@classmethod
|
|
99
|
+
def architecture(cls) -> str:
|
|
100
|
+
return "jumprelu"
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
|
|
89
104
|
"""
|
|
90
105
|
JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
|
|
91
106
|
using a JumpReLU activation. For each unit, if its pre-activation is
|
|
@@ -104,42 +119,18 @@ class JumpReLUSAE(SAE):
|
|
|
104
119
|
b_enc: nn.Parameter
|
|
105
120
|
threshold: nn.Parameter
|
|
106
121
|
|
|
107
|
-
def __init__(self, cfg:
|
|
122
|
+
def __init__(self, cfg: JumpReLUSAEConfig, use_error_term: bool = False):
|
|
108
123
|
super().__init__(cfg, use_error_term)
|
|
109
124
|
|
|
125
|
+
@override
|
|
110
126
|
def initialize_weights(self) -> None:
|
|
111
|
-
|
|
112
|
-
Initialize encoder and decoder weights, as well as biases.
|
|
113
|
-
Additionally, include a learnable `threshold` parameter that
|
|
114
|
-
determines when units "turn on" for the JumpReLU.
|
|
115
|
-
"""
|
|
116
|
-
# Biases
|
|
117
|
-
self.b_enc = nn.Parameter(
|
|
118
|
-
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
|
|
119
|
-
)
|
|
120
|
-
self.b_dec = nn.Parameter(
|
|
121
|
-
torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
|
|
122
|
-
)
|
|
123
|
-
|
|
124
|
-
# Threshold for JumpReLU
|
|
125
|
-
# You can pick a default initialization (e.g., zeros means unit is off unless hidden_pre > 0)
|
|
126
|
-
# or see the training version for more advanced init with log_threshold, etc.
|
|
127
|
+
super().initialize_weights()
|
|
127
128
|
self.threshold = nn.Parameter(
|
|
128
129
|
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
|
|
129
130
|
)
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
w_enc_data = torch.empty(
|
|
133
|
-
self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
|
|
134
|
-
)
|
|
135
|
-
nn.init.kaiming_uniform_(w_enc_data)
|
|
136
|
-
self.W_enc = nn.Parameter(w_enc_data)
|
|
137
|
-
|
|
138
|
-
w_dec_data = torch.empty(
|
|
139
|
-
self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
|
|
131
|
+
self.b_enc = nn.Parameter(
|
|
132
|
+
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
|
|
140
133
|
)
|
|
141
|
-
nn.init.kaiming_uniform_(w_dec_data)
|
|
142
|
-
self.W_dec = nn.Parameter(w_dec_data)
|
|
143
134
|
|
|
144
135
|
def encode(
|
|
145
136
|
self, x: Float[torch.Tensor, "... d_in"]
|
|
@@ -168,8 +159,7 @@ class JumpReLUSAE(SAE):
|
|
|
168
159
|
Decode the feature activations back to the input space.
|
|
169
160
|
Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
|
|
170
161
|
"""
|
|
171
|
-
|
|
172
|
-
sae_out_pre = scaled_features @ self.W_dec + self.b_dec
|
|
162
|
+
sae_out_pre = feature_acts @ self.W_dec + self.b_dec
|
|
173
163
|
sae_out_pre = self.hook_sae_recons(sae_out_pre)
|
|
174
164
|
sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
|
|
175
165
|
return self.reshape_fn_out(sae_out_pre, self.d_head)
|
|
@@ -195,7 +185,24 @@ class JumpReLUSAE(SAE):
|
|
|
195
185
|
self.threshold.data = current_thresh * W_dec_norms
|
|
196
186
|
|
|
197
187
|
|
|
198
|
-
|
|
188
|
+
@dataclass
|
|
189
|
+
class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
|
|
190
|
+
"""
|
|
191
|
+
Configuration class for training a JumpReLUTrainingSAE.
|
|
192
|
+
"""
|
|
193
|
+
|
|
194
|
+
jumprelu_init_threshold: float = 0.001
|
|
195
|
+
jumprelu_bandwidth: float = 0.001
|
|
196
|
+
l0_coefficient: float = 1.0
|
|
197
|
+
l0_warm_up_steps: int = 0
|
|
198
|
+
|
|
199
|
+
@override
|
|
200
|
+
@classmethod
|
|
201
|
+
def architecture(cls) -> str:
|
|
202
|
+
return "jumprelu"
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
|
|
199
206
|
"""
|
|
200
207
|
JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.
|
|
201
208
|
|
|
@@ -213,7 +220,7 @@ class JumpReLUTrainingSAE(TrainingSAE):
|
|
|
213
220
|
b_enc: nn.Parameter
|
|
214
221
|
log_threshold: nn.Parameter
|
|
215
222
|
|
|
216
|
-
def __init__(self, cfg:
|
|
223
|
+
def __init__(self, cfg: JumpReLUTrainingSAEConfig, use_error_term: bool = False):
|
|
217
224
|
super().__init__(cfg, use_error_term)
|
|
218
225
|
|
|
219
226
|
# We'll store a bandwidth for the training approach, if needed
|
|
@@ -225,51 +232,16 @@ class JumpReLUTrainingSAE(TrainingSAE):
|
|
|
225
232
|
* np.log(cfg.jumprelu_init_threshold)
|
|
226
233
|
)
|
|
227
234
|
|
|
235
|
+
@override
|
|
228
236
|
def initialize_weights(self) -> None:
|
|
229
237
|
"""
|
|
230
238
|
Initialize parameters like the base SAE, but also add log_threshold.
|
|
231
239
|
"""
|
|
240
|
+
super().initialize_weights()
|
|
232
241
|
# Encoder Bias
|
|
233
242
|
self.b_enc = nn.Parameter(
|
|
234
243
|
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
|
|
235
244
|
)
|
|
236
|
-
# Decoder Bias
|
|
237
|
-
self.b_dec = nn.Parameter(
|
|
238
|
-
torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
|
|
239
|
-
)
|
|
240
|
-
# W_enc
|
|
241
|
-
w_enc_data = torch.nn.init.kaiming_uniform_(
|
|
242
|
-
torch.empty(
|
|
243
|
-
self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
|
|
244
|
-
)
|
|
245
|
-
)
|
|
246
|
-
self.W_enc = nn.Parameter(w_enc_data)
|
|
247
|
-
|
|
248
|
-
# W_dec
|
|
249
|
-
w_dec_data = torch.nn.init.kaiming_uniform_(
|
|
250
|
-
torch.empty(
|
|
251
|
-
self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
|
|
252
|
-
)
|
|
253
|
-
)
|
|
254
|
-
self.W_dec = nn.Parameter(w_dec_data)
|
|
255
|
-
|
|
256
|
-
# Optionally apply orthogonal or heuristic init
|
|
257
|
-
if self.cfg.decoder_orthogonal_init:
|
|
258
|
-
self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
|
|
259
|
-
elif self.cfg.decoder_heuristic_init:
|
|
260
|
-
self.W_dec.data = torch.rand(
|
|
261
|
-
self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
|
|
262
|
-
)
|
|
263
|
-
self.initialize_decoder_norm_constant_norm()
|
|
264
|
-
|
|
265
|
-
# Optionally transpose
|
|
266
|
-
if self.cfg.init_encoder_as_decoder_transpose:
|
|
267
|
-
self.W_enc.data = self.W_dec.data.T.clone().contiguous()
|
|
268
|
-
|
|
269
|
-
# Optionally normalize columns of W_dec
|
|
270
|
-
if self.cfg.normalize_sae_decoder:
|
|
271
|
-
with torch.no_grad():
|
|
272
|
-
self.set_decoder_norm_to_unit_norm()
|
|
273
245
|
|
|
274
246
|
@property
|
|
275
247
|
def threshold(self) -> torch.Tensor:
|
|
@@ -305,9 +277,18 @@ class JumpReLUTrainingSAE(TrainingSAE):
|
|
|
305
277
|
) -> dict[str, torch.Tensor]:
|
|
306
278
|
"""Calculate architecture-specific auxiliary loss terms."""
|
|
307
279
|
l0 = torch.sum(Step.apply(hidden_pre, self.threshold, self.bandwidth), dim=-1) # type: ignore
|
|
308
|
-
l0_loss = (step_input.
|
|
280
|
+
l0_loss = (step_input.coefficients["l0"] * l0).mean()
|
|
309
281
|
return {"l0_loss": l0_loss}
|
|
310
282
|
|
|
283
|
+
@override
|
|
284
|
+
def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
|
|
285
|
+
return {
|
|
286
|
+
"l0": TrainCoefficientConfig(
|
|
287
|
+
value=self.cfg.l0_coefficient,
|
|
288
|
+
warm_up_steps=self.cfg.l0_warm_up_steps,
|
|
289
|
+
),
|
|
290
|
+
}
|
|
291
|
+
|
|
311
292
|
@torch.no_grad()
|
|
312
293
|
def fold_W_dec_norm(self):
|
|
313
294
|
"""
|
|
@@ -366,3 +347,8 @@ class JumpReLUTrainingSAE(TrainingSAE):
|
|
|
366
347
|
threshold = state_dict["threshold"]
|
|
367
348
|
del state_dict["threshold"]
|
|
368
349
|
state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
|
|
350
|
+
|
|
351
|
+
def to_inference_config_dict(self) -> dict[str, Any]:
|
|
352
|
+
return filter_valid_dataclass_fields(
|
|
353
|
+
self.cfg.to_dict(), JumpReLUSAEConfig, ["architecture"]
|
|
354
|
+
)
|