sae-lens 6.14.1__py3-none-any.whl → 6.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +15 -1
- sae_lens/analysis/hooked_sae_transformer.py +4 -13
- sae_lens/cache_activations_runner.py +3 -4
- sae_lens/config.py +39 -2
- sae_lens/constants.py +1 -0
- sae_lens/evals.py +18 -14
- sae_lens/llm_sae_training_runner.py +17 -18
- sae_lens/loading/pretrained_sae_loaders.py +188 -0
- sae_lens/loading/pretrained_saes_directory.py +5 -3
- sae_lens/pretrained_saes.yaml +77 -1
- sae_lens/saes/__init__.py +9 -0
- sae_lens/saes/batchtopk_sae.py +3 -1
- sae_lens/saes/gated_sae.py +4 -9
- sae_lens/saes/jumprelu_sae.py +4 -9
- sae_lens/saes/matryoshka_batchtopk_sae.py +136 -0
- sae_lens/saes/sae.py +19 -31
- sae_lens/saes/standard_sae.py +4 -9
- sae_lens/saes/temporal_sae.py +365 -0
- sae_lens/saes/topk_sae.py +7 -10
- sae_lens/training/activation_scaler.py +7 -0
- sae_lens/training/activations_store.py +54 -34
- sae_lens/training/optim.py +11 -0
- sae_lens/training/sae_trainer.py +50 -11
- sae_lens/util.py +27 -0
- {sae_lens-6.14.1.dist-info → sae_lens-6.22.1.dist-info}/METADATA +16 -16
- sae_lens-6.22.1.dist-info/RECORD +41 -0
- sae_lens-6.14.1.dist-info/RECORD +0 -39
- {sae_lens-6.14.1.dist-info → sae_lens-6.22.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.14.1.dist-info → sae_lens-6.22.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/pretrained_saes.yaml
CHANGED
@@ -1,3 +1,35 @@
+temporal-sae-gemma-2-2b:
+  conversion_func: temporal
+  model: gemma-2-2b
+  repo_id: canrager/temporalSAEs
+  config_overrides:
+    model_name: gemma-2-2b
+    hook_name: blocks.12.hook_resid_post
+    dataset_path: monology/pile-uncopyrighted
+  saes:
+  - id: blocks.12.hook_resid_post
+    l0: 192
+    norm_scaling_factor: 0.00666666667
+    path: gemma-2-2B/layer_12/temporal
+    neuronpedia: gemma-2-2b/12-temporal-res
+temporal-sae-llama-3.1-8b:
+  conversion_func: temporal
+  model: meta-llama/Llama-3.1-8B
+  repo_id: canrager/temporalSAEs
+  config_overrides:
+    model_name: meta-llama/Llama-3.1-8B
+    dataset_path: monology/pile-uncopyrighted
+  saes:
+  - id: blocks.15.hook_resid_post
+    l0: 256
+    norm_scaling_factor: 0.029
+    path: llama-3.1-8B/layer_15/temporal
+    neuronpedia: llama3.1-8b/15-temporal-res
+  - id: blocks.26.hook_resid_post
+    l0: 256
+    norm_scaling_factor: 0.029
+    path: llama-3.1-8B/layer_26/temporal
+    neuronpedia: llama3.1-8b/26-temporal-res
 deepseek-r1-distill-llama-8b-qresearch:
   conversion_func: deepseek_r1
   model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
@@ -14882,4 +14914,48 @@ qwen2.5-7b-instruct-andyrdt:
     neuronpedia: qwen2.5-7b-it/23-resid-post-aa
   - id: resid_post_layer_27_trainer_1
     path: resid_post_layer_27/trainer_1
-    neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+    neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+
+gpt-oss-20b-andyrdt:
+  conversion_func: dictionary_learning_1
+  model: openai/gpt-oss-20b
+  repo_id: andyrdt/saes-gpt-oss-20b
+  saes:
+  - id: resid_post_layer_3_trainer_0
+    path: resid_post_layer_3/trainer_0
+    neuronpedia: gpt-oss-20b/3-resid-post-aa
+  - id: resid_post_layer_7_trainer_0
+    path: resid_post_layer_7/trainer_0
+    neuronpedia: gpt-oss-20b/7-resid-post-aa
+  - id: resid_post_layer_11_trainer_0
+    path: resid_post_layer_11/trainer_0
+    neuronpedia: gpt-oss-20b/11-resid-post-aa
+  - id: resid_post_layer_15_trainer_0
+    path: resid_post_layer_15/trainer_0
+    neuronpedia: gpt-oss-20b/15-resid-post-aa
+  - id: resid_post_layer_19_trainer_0
+    path: resid_post_layer_19/trainer_0
+    neuronpedia: gpt-oss-20b/19-resid-post-aa
+  - id: resid_post_layer_23_trainer_0
+    path: resid_post_layer_23/trainer_0
+    neuronpedia: gpt-oss-20b/23-resid-post-aa
+
+goodfire-llama-3.3-70b-instruct:
+  conversion_func: goodfire
+  model: meta-llama/Llama-3.3-70B-Instruct
+  repo_id: Goodfire/Llama-3.3-70B-Instruct-SAE-l50
+  saes:
+  - id: layer_50
+    path: Llama-3.3-70B-Instruct-SAE-l50.pt
+    l0: 121
+    neuronpedia: llama3.3-70b-it/50-resid-post-gf
+
+goodfire-llama-3.1-8b-instruct:
+  conversion_func: goodfire
+  model: meta-llama/Llama-3.1-8B-Instruct
+  repo_id: Goodfire/Llama-3.1-8B-Instruct-SAE-l19
+  saes:
+  - id: layer_19
+    path: Llama-3.1-8B-Instruct-SAE-l19.pth
+    l0: 91
+    neuronpedia: llama3.1-8b-it/19-resid-post-gf
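For orientation only (not part of the diff): the new registry entries above are addressed by release name plus SAE id. A minimal sketch of loading one of them, assuming sae-lens 6.22.1 is installed, the standard SAE.from_pretrained loader, and network access to the Hugging Face repos listed above:

    from sae_lens import SAE

    # "temporal-sae-gemma-2-2b" is the release key added above; the sae_id is the
    # `id` listed under its `saes:` block.
    sae = SAE.from_pretrained(
        release="temporal-sae-gemma-2-2b",
        sae_id="blocks.12.hook_resid_post",
        device="cpu",
    )
    print(sae.cfg.d_in, sae.cfg.d_sae)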
sae_lens/saes/__init__.py
CHANGED
@@ -14,6 +14,10 @@ from .jumprelu_sae import (
     JumpReLUTrainingSAE,
     JumpReLUTrainingSAEConfig,
 )
+from .matryoshka_batchtopk_sae import (
+    MatryoshkaBatchTopKTrainingSAE,
+    MatryoshkaBatchTopKTrainingSAEConfig,
+)
 from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
 from .standard_sae import (
     StandardSAE,
@@ -21,6 +25,7 @@ from .standard_sae import (
     StandardTrainingSAE,
     StandardTrainingSAEConfig,
 )
+from .temporal_sae import TemporalSAE, TemporalSAEConfig
 from .topk_sae import (
     TopKSAE,
     TopKSAEConfig,
@@ -65,4 +70,8 @@ __all__ = [
     "SkipTranscoderConfig",
     "JumpReLUTranscoder",
     "JumpReLUTranscoderConfig",
+    "MatryoshkaBatchTopKTrainingSAE",
+    "MatryoshkaBatchTopKTrainingSAEConfig",
+    "TemporalSAE",
+    "TemporalSAEConfig",
 ]
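For reference, a minimal sketch of importing the names newly exported by this change (TemporalSAE and TemporalSAEConfig come from the new temporal_sae.py module listed in the file summary above):

    # New exports added to sae_lens/saes/__init__.py in 6.22.1.
    from sae_lens.saes import (
        MatryoshkaBatchTopKTrainingSAE,
        MatryoshkaBatchTopKTrainingSAEConfig,
        TemporalSAE,
        TemporalSAEConfig,
    )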
sae_lens/saes/batchtopk_sae.py
CHANGED
@@ -23,7 +23,9 @@ class BatchTopK(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         acts = x.relu()
         flat_acts = acts.flatten()
-
+        # Calculate total number of samples across all non-feature dimensions
+        num_samples = acts.shape[:-1].numel()
+        acts_topk_flat = torch.topk(flat_acts, int(self.k * num_samples), dim=-1)
         return (
             torch.zeros_like(flat_acts)
             .scatter(-1, acts_topk_flat.indices, acts_topk_flat.values)
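To illustrate the new sample count (a standalone sketch, not sae_lens code): the number of kept activations is now k times the product of all non-feature dimensions, so a [batch, seq, d_sae] tensor keeps k features per row on average rather than scaling only with the leading batch dimension.

    import torch

    acts = torch.rand(8, 128, 4096)          # [batch, seq, d_sae], already non-negative
    k = 32
    num_samples = acts.shape[:-1].numel()    # 8 * 128 = 1024 rows
    flat = acts.flatten()
    topk = torch.topk(flat, int(k * num_samples), dim=-1)
    mask = torch.zeros_like(flat).scatter(-1, topk.indices, topk.values)
    print(mask.count_nonzero().item() / num_samples)  # 32 active features per row on average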
sae_lens/saes/gated_sae.py
CHANGED
@@ -2,7 +2,6 @@ from dataclasses import dataclass
 from typing import Any
 
 import torch
-from jaxtyping import Float
 from numpy.typing import NDArray
 from torch import nn
 from typing_extensions import override
@@ -49,9 +48,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
         super().initialize_weights()
         _init_weights_gated(self)
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Encode the input tensor into the feature space using a gated encoder.
         This must match the original encode_gated implementation from SAE class.
@@ -72,9 +69,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
         # Combine gating and magnitudes
         return self.hook_sae_acts_post(active_features * feature_magnitudes)
 
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decode the feature activations back into the input space:
         1) Apply optional finetuning scaling.
@@ -147,8 +142,8 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
         _init_weights_gated(self)
 
     def encode_with_hidden_pre(
-        self, x:
-    ) -> tuple[
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Gated forward pass with pre-activation (for training).
         """
sae_lens/saes/jumprelu_sae.py
CHANGED
@@ -3,7 +3,6 @@ from typing import Any, Literal
 
 import numpy as np
 import torch
-from jaxtyping import Float
 from torch import nn
 from typing_extensions import override
 
@@ -130,9 +129,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
         )
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Encode the input tensor into the feature space using JumpReLU.
         The threshold parameter determines which units remain active.
@@ -150,9 +147,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
         # 3) Multiply the normally activated units by that mask.
         return self.hook_sae_acts_post(base_acts * jump_relu_mask)
 
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decode the feature activations back to the input space.
         Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
@@ -265,8 +260,8 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         return torch.exp(self.log_threshold)
 
     def encode_with_hidden_pre(
-        self, x:
-    ) -> tuple[
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         sae_in = self.process_sae_in(x)
 
         hidden_pre = sae_in @ self.W_enc + self.b_enc
sae_lens/saes/matryoshka_batchtopk_sae.py
ADDED
@@ -0,0 +1,136 @@
+import warnings
+from dataclasses import dataclass, field
+
+import torch
+from typing_extensions import override
+
+from sae_lens.saes.batchtopk_sae import (
+    BatchTopKTrainingSAE,
+    BatchTopKTrainingSAEConfig,
+)
+from sae_lens.saes.sae import TrainStepInput, TrainStepOutput
+from sae_lens.saes.topk_sae import _sparse_matmul_nd
+
+
+@dataclass
+class MatryoshkaBatchTopKTrainingSAEConfig(BatchTopKTrainingSAEConfig):
+    """
+    Configuration class for training a MatryoshkaBatchTopKTrainingSAE.
+
+    [Matryoshka SAEs](https://arxiv.org/pdf/2503.17547) use a series of nested reconstruction
+    losses of different widths during training to avoid feature absorption. This also has a
+    nice side-effect of encouraging higher-frequency features to be learned in earlier levels.
+    However, this SAE has more hyperparameters to tune than standard BatchTopK SAEs, and takes
+    longer to train due to requiring multiple forward passes per training step.
+
+    After training, MatryoshkaBatchTopK SAEs are saved as JumpReLU SAEs.
+
+    Args:
+        matryoshka_widths (list[int]): The widths of the matryoshka levels. Defaults to an empty list.
+        k (float): The number of features to keep active. Inherited from BatchTopKTrainingSAEConfig.
+            Defaults to 100.
+        topk_threshold_lr (float): Learning rate for updating the global topk threshold.
+            The threshold is updated using an exponential moving average of the minimum
+            positive activation value. Defaults to 0.01.
+        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
+            dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
+            Defaults to 1.0.
+        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
+            Inherited from TopKTrainingSAEConfig. Defaults to True.
+        decoder_init_norm (float | None): Norm to initialize decoder weights to.
+            Inherited from TrainingSAEConfig. Defaults to 0.1.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+    """
+
+    matryoshka_widths: list[int] = field(default_factory=list)
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "matryoshka_batchtopk"
+
+
+class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
+    """
+    Global Batch TopK Training SAE
+
+    This SAE will maintain the k on average across the batch, rather than enforcing the k per-sample as in standard TopK.
+
+    BatchTopK SAEs are saved as JumpReLU SAEs after training.
+    """
+
+    cfg: MatryoshkaBatchTopKTrainingSAEConfig  # type: ignore[assignment]
+
+    def __init__(
+        self, cfg: MatryoshkaBatchTopKTrainingSAEConfig, use_error_term: bool = False
+    ):
+        super().__init__(cfg, use_error_term)
+        _validate_matryoshka_config(cfg)
+
+    @override
+    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
+        base_output = super().training_forward_pass(step_input)
+        inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
+        # the outer matryoshka level is the base SAE, so we don't need to add an extra loss for it
+        for width in self.cfg.matryoshka_widths[:-1]:
+            inner_reconstruction = self._decode_matryoshka_level(
+                base_output.feature_acts, width, inv_W_dec_norm
+            )
+            inner_mse_loss = (
+                self.mse_loss_fn(inner_reconstruction, step_input.sae_in)
+                .sum(dim=-1)
+                .mean()
+            )
+            base_output.losses[f"inner_mse_loss_{width}"] = inner_mse_loss
+            base_output.loss = base_output.loss + inner_mse_loss
+        return base_output
+
+    def _decode_matryoshka_level(
+        self,
+        feature_acts: torch.Tensor,
+        width: int,
+        inv_W_dec_norm: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Decodes feature activations back into input space for a matryoshka level
+        """
+        inner_feature_acts = feature_acts[:, :width]
+        # Handle sparse tensors using efficient sparse matrix multiplication
+        if self.cfg.rescale_acts_by_decoder_norm:
+            # need to multiply by the inverse of the norm because division is illegal with sparse tensors
+            inner_feature_acts = inner_feature_acts * inv_W_dec_norm[:width]
+        if inner_feature_acts.is_sparse:
+            sae_out_pre = (
+                _sparse_matmul_nd(inner_feature_acts, self.W_dec[:width]) + self.b_dec
+            )
+        else:
+            sae_out_pre = inner_feature_acts @ self.W_dec[:width] + self.b_dec
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+
+def _validate_matryoshka_config(cfg: MatryoshkaBatchTopKTrainingSAEConfig) -> None:
+    if cfg.matryoshka_widths[-1] != cfg.d_sae:
+        # warn the users that we will add a final matryoshka level
+        warnings.warn(
+            "WARNING: The final matryoshka level width is not set to cfg.d_sae. "
+            "A final matryoshka level of width=cfg.d_sae will be added."
+        )
+        cfg.matryoshka_widths.append(cfg.d_sae)
+
+    for prev_width, curr_width in zip(
+        cfg.matryoshka_widths[:-1], cfg.matryoshka_widths[1:]
+    ):
+        if prev_width >= curr_width:
+            raise ValueError("cfg.matryoshka_widths must be strictly increasing.")
+    if len(cfg.matryoshka_widths) == 1:
+        warnings.warn(
+            "WARNING: You have only set one matryoshka level. This is equivalent to using a standard BatchTopK SAE and is likely not what you want."
+        )
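To make the nested levels concrete, a hedged sketch of configuring and instantiating the new class (the width, k, and dimension values below are purely illustrative):

    from sae_lens.saes import (
        MatryoshkaBatchTopKTrainingSAE,
        MatryoshkaBatchTopKTrainingSAEConfig,
    )

    cfg = MatryoshkaBatchTopKTrainingSAEConfig(
        d_in=768,
        d_sae=16384,
        k=64,
        # Must be strictly increasing; if the last width != d_sae,
        # _validate_matryoshka_config warns and appends d_sae as a final level.
        matryoshka_widths=[2048, 4096, 8192, 16384],
    )
    sae = MatryoshkaBatchTopKTrainingSAE(cfg)

Each width smaller than d_sae adds an extra inner_mse_loss_{width} term to the training loss, as in training_forward_pass above.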
sae_lens/saes/sae.py
CHANGED
@@ -19,9 +19,8 @@ from typing import (
 
 import einops
 import torch
-from jaxtyping import Float
 from numpy.typing import NDArray
-from safetensors.torch import save_file
+from safetensors.torch import load_file, save_file
 from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
 from typing_extensions import deprecated, overload, override
@@ -155,9 +154,9 @@ class SAEConfig(ABC):
     dtype: str = "float32"
     device: str = "cpu"
     apply_b_dec_to_input: bool = True
-    normalize_activations: Literal[
-        "none",
-
+    normalize_activations: Literal["none", "expected_average_only_in", "layer_norm"] = (
+        "none"  # none, expected_average_only_in (Anthropic April Update)
+    )
     reshape_activations: Literal["none", "hook_z"] = "none"
     metadata: SAEMetadata = field(default_factory=SAEMetadata)
 
@@ -217,6 +216,7 @@ class TrainStepInput:
     sae_in: torch.Tensor
     coefficients: dict[str, float]
     dead_neuron_mask: torch.Tensor | None
+    n_training_steps: int
 
 
 class TrainCoefficientConfig(NamedTuple):
@@ -308,6 +308,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
             self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
             self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+
         elif self.cfg.normalize_activations == "layer_norm":
             # we need to scale the norm of the input and store the scaling factor
             def run_time_activation_ln_in(
@@ -349,16 +350,12 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         self.W_enc = nn.Parameter(w_enc_data)
 
     @abstractmethod
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """Encode input tensor to feature space."""
         pass
 
     @abstractmethod
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """Decode feature activations back to input space."""
         pass
 
@@ -448,26 +445,15 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
         return super().to(*args, **kwargs)
 
-    def process_sae_in(
-        self, sae_in: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_in"]:
-        # print(f"Input shape to process_sae_in: {sae_in.shape}")
-        # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
-        # print(f"self.b_dec shape: {self.b_dec.shape}")
-        # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")
-
+    def process_sae_in(self, sae_in: torch.Tensor) -> torch.Tensor:
         sae_in = sae_in.to(self.dtype)
-
-        # print(f"Shape before reshape_fn_in: {sae_in.shape}")
         sae_in = self.reshape_fn_in(sae_in)
-        # print(f"Shape after reshape_fn_in: {sae_in.shape}")
 
         sae_in = self.hook_sae_input(sae_in)
         sae_in = self.run_time_activation_norm_fn_in(sae_in)
 
         # Here's where the error happens
         bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
-        # print(f"Bias term shape: {bias_term.shape}")
 
         return sae_in - bias_term
 
@@ -866,14 +852,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
 
     @abstractmethod
     def encode_with_hidden_pre(
-        self, x:
-    ) -> tuple[
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Encode with access to pre-activation values for training."""
         ...
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         For inference, just encode without returning hidden_pre.
         (training_forward_pass calls encode_with_hidden_pre).
@@ -881,9 +865,7 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
         feature_acts, _ = self.encode_with_hidden_pre(x)
         return feature_acts
 
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decodes feature activations back into input space,
         applying optional finetuning scale, hooking, out normalization, etc.
@@ -1017,6 +999,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
     ) -> type[TrainingSAEConfig]:
         return get_sae_training_class(architecture)[1]
 
+    def load_weights_from_checkpoint(self, checkpoint_path: Path | str) -> None:
+        checkpoint_path = Path(checkpoint_path)
+        state_dict = load_file(checkpoint_path / SAE_WEIGHTS_FILENAME)
+        self.process_state_dict_for_loading(state_dict)
+        self.load_state_dict(state_dict)
+
 
 _blank_hook = nn.Identity()
 
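As a usage note for the new TrainingSAE.load_weights_from_checkpoint helper: per the hunk above it reads the safetensors file named by SAE_WEIGHTS_FILENAME from a checkpoint directory and loads it into the module. A minimal sketch, assuming a StandardTrainingSAE can be constructed with just d_in and d_sae and using a hypothetical checkpoint path:

    from sae_lens.saes import StandardTrainingSAE, StandardTrainingSAEConfig

    sae = StandardTrainingSAE(StandardTrainingSAEConfig(d_in=768, d_sae=16384))
    # Hypothetical path: a directory written by a previous training run that
    # contains the sae_weights safetensors file.
    sae.load_weights_from_checkpoint("checkpoints/my_run/final")

The new n_training_steps field on TrainStepInput is presumably populated by the trainer (see the sae_trainer.py changes in the file summary) rather than by user code.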
sae_lens/saes/standard_sae.py
CHANGED
@@ -2,7 +2,6 @@ from dataclasses import dataclass
 
 import numpy as np
 import torch
-from jaxtyping import Float
 from numpy.typing import NDArray
 from torch import nn
 from typing_extensions import override
@@ -54,9 +53,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
         super().initialize_weights()
         _init_weights_standard(self)
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Encode the input tensor into the feature space.
         """
@@ -67,9 +64,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
         # Apply the activation function (e.g., ReLU, depending on config)
         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
 
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decode the feature activations back to the input space.
         Now, if hook_z reshaping is turned on, we reverse the flattening.
@@ -127,8 +122,8 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
         }
 
     def encode_with_hidden_pre(
-        self, x:
-    ) -> tuple[
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # Process the input (including dtype conversion, hook call, and any activation normalization)
         sae_in = self.process_sae_in(x)
         # Compute the pre-activation (and allow for a hook if desired)
|