sae-lens 6.12.1__py3-none-any.whl → 6.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +15 -1
- sae_lens/cache_activations_runner.py +1 -1
- sae_lens/config.py +39 -2
- sae_lens/constants.py +1 -0
- sae_lens/evals.py +20 -14
- sae_lens/llm_sae_training_runner.py +17 -18
- sae_lens/loading/pretrained_sae_loaders.py +194 -0
- sae_lens/loading/pretrained_saes_directory.py +5 -3
- sae_lens/pretokenize_runner.py +2 -1
- sae_lens/pretrained_saes.yaml +75 -1
- sae_lens/saes/__init__.py +9 -0
- sae_lens/saes/batchtopk_sae.py +32 -1
- sae_lens/saes/matryoshka_batchtopk_sae.py +137 -0
- sae_lens/saes/sae.py +22 -24
- sae_lens/saes/temporal_sae.py +372 -0
- sae_lens/saes/topk_sae.py +287 -17
- sae_lens/tokenization_and_batching.py +21 -6
- sae_lens/training/activation_scaler.py +7 -0
- sae_lens/training/activations_store.py +52 -31
- sae_lens/training/optim.py +11 -0
- sae_lens/training/sae_trainer.py +57 -16
- sae_lens/training/types.py +1 -1
- sae_lens/util.py +27 -0
- {sae_lens-6.12.1.dist-info → sae_lens-6.21.0.dist-info}/METADATA +19 -17
- sae_lens-6.21.0.dist-info/RECORD +41 -0
- {sae_lens-6.12.1.dist-info → sae_lens-6.21.0.dist-info}/WHEEL +1 -1
- sae_lens-6.12.1.dist-info/RECORD +0 -39
- {sae_lens-6.12.1.dist-info → sae_lens-6.21.0.dist-info/licenses}/LICENSE +0 -0
sae_lens/pretrained_saes.yaml
CHANGED
@@ -1,3 +1,35 @@
+temporal-sae-gemma-2-2b:
+  conversion_func: temporal
+  model: gemma-2-2b
+  repo_id: canrager/temporalSAEs
+  config_overrides:
+    model_name: gemma-2-2b
+    hook_name: blocks.12.hook_resid_post
+    dataset_path: monology/pile-uncopyrighted
+  saes:
+  - id: blocks.12.hook_resid_post
+    l0: 192
+    norm_scaling_factor: 0.00666666667
+    path: gemma-2-2B/layer_12/temporal
+    neuronpedia: gemma-2-2b/12-temporal-res
+temporal-sae-llama-3.1-8b:
+  conversion_func: temporal
+  model: meta-llama/Llama-3.1-8B
+  repo_id: canrager/temporalSAEs
+  config_overrides:
+    model_name: meta-llama/Llama-3.1-8B
+    dataset_path: monology/pile-uncopyrighted
+  saes:
+  - id: blocks.15.hook_resid_post
+    l0: 256
+    norm_scaling_factor: 0.029
+    path: llama-3.1-8B/layer_15/temporal
+    neuronpedia: llama3.1-8b/15-temporal-res
+  - id: blocks.26.hook_resid_post
+    l0: 256
+    norm_scaling_factor: 0.029
+    path: llama-3.1-8B/layer_26/temporal
+    neuronpedia: llama3.1-8b/26-temporal-res
 deepseek-r1-distill-llama-8b-qresearch:
   conversion_func: deepseek_r1
   model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
@@ -14882,4 +14914,46 @@ qwen2.5-7b-instruct-andyrdt:
     neuronpedia: qwen2.5-7b-it/23-resid-post-aa
   - id: resid_post_layer_27_trainer_1
     path: resid_post_layer_27/trainer_1
-    neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+    neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+
+gpt-oss-20b-andyrdt:
+  conversion_func: dictionary_learning_1
+  model: openai/gpt-oss-20b
+  repo_id: andyrdt/saes-gpt-oss-20b
+  saes:
+  - id: resid_post_layer_3_trainer_0
+    path: resid_post_layer_3/trainer_0
+    neuronpedia: gpt-oss-20b/3-resid-post-aa
+  - id: resid_post_layer_7_trainer_0
+    path: resid_post_layer_7/trainer_0
+    neuronpedia: gpt-oss-20b/7-resid-post-aa
+  - id: resid_post_layer_11_trainer_0
+    path: resid_post_layer_11/trainer_0
+    neuronpedia: gpt-oss-20b/11-resid-post-aa
+  - id: resid_post_layer_15_trainer_0
+    path: resid_post_layer_15/trainer_0
+    neuronpedia: gpt-oss-20b/15-resid-post-aa
+  - id: resid_post_layer_19_trainer_0
+    path: resid_post_layer_19/trainer_0
+    neuronpedia: gpt-oss-20b/19-resid-post-aa
+  - id: resid_post_layer_23_trainer_0
+    path: resid_post_layer_23/trainer_0
+    neuronpedia: gpt-oss-20b/23-resid-post-aa
+
+goodfire-llama-3.3-70b-instruct:
+  conversion_func: goodfire
+  model: meta-llama/Llama-3.3-70B-Instruct
+  repo_id: Goodfire/Llama-3.3-70B-Instruct-SAE-l50
+  saes:
+  - id: layer_50
+    path: Llama-3.3-70B-Instruct-SAE-l50.pt
+    l0: 121
+
+goodfire-llama-3.1-8b-instruct:
+  conversion_func: goodfire
+  model: meta-llama/Llama-3.1-8B-Instruct
+  repo_id: Goodfire/Llama-3.1-8B-Instruct-SAE-l19
+  saes:
+  - id: layer_19
+    path: Llama-3.1-8B-Instruct-SAE-l19.pth
+    l0: 91
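
The new registry entries above are loaded through the usual release/sae_id lookup. A minimal illustrative sketch (the release and sae_id strings come from the entries above; the device choice and printed fields are assumptions, not part of this diff):

from sae_lens import SAE

# Hypothetical usage: load one of the newly registered temporal SAEs by the
# release name and sae id defined in pretrained_saes.yaml above.
sae = SAE.from_pretrained(
    release="temporal-sae-gemma-2-2b",
    sae_id="blocks.12.hook_resid_post",
    device="cpu",
)
print(sae.cfg.d_in, sae.cfg.d_sae)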
sae_lens/saes/__init__.py
CHANGED
@@ -14,6 +14,10 @@ from .jumprelu_sae import (
     JumpReLUTrainingSAE,
     JumpReLUTrainingSAEConfig,
 )
+from .matryoshka_batchtopk_sae import (
+    MatryoshkaBatchTopKTrainingSAE,
+    MatryoshkaBatchTopKTrainingSAEConfig,
+)
 from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
 from .standard_sae import (
     StandardSAE,
@@ -21,6 +25,7 @@ from .standard_sae import (
     StandardTrainingSAE,
     StandardTrainingSAEConfig,
 )
+from .temporal_sae import TemporalSAE, TemporalSAEConfig
 from .topk_sae import (
     TopKSAE,
     TopKSAEConfig,
@@ -65,4 +70,8 @@ __all__ = [
     "SkipTranscoderConfig",
     "JumpReLUTranscoder",
     "JumpReLUTranscoderConfig",
+    "MatryoshkaBatchTopKTrainingSAE",
+    "MatryoshkaBatchTopKTrainingSAEConfig",
+    "TemporalSAE",
+    "TemporalSAEConfig",
 ]
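
With these exports in place, the new classes can be imported from sae_lens.saes. An illustrative import mirroring the __all__ additions above (whether they are also re-exported at the top-level sae_lens package is not shown here):

from sae_lens.saes import (
    MatryoshkaBatchTopKTrainingSAE,
    MatryoshkaBatchTopKTrainingSAEConfig,
    TemporalSAE,
    TemporalSAEConfig,
)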
sae_lens/saes/batchtopk_sae.py
CHANGED
@@ -23,7 +23,9 @@ class BatchTopK(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         acts = x.relu()
         flat_acts = acts.flatten()
-
+        # Calculate total number of samples across all non-feature dimensions
+        num_samples = acts.shape[:-1].numel()
+        acts_topk_flat = torch.topk(flat_acts, int(self.k * num_samples), dim=-1)
         return (
             torch.zeros_like(flat_acts)
             .scatter(-1, acts_topk_flat.indices, acts_topk_flat.values)
@@ -35,6 +37,35 @@
 class BatchTopKTrainingSAEConfig(TopKTrainingSAEConfig):
     """
     Configuration class for training a BatchTopKTrainingSAE.
+
+    BatchTopK SAEs maintain k active features on average across the entire batch,
+    rather than enforcing k features per sample like standard TopK SAEs. During training,
+    the SAE learns a global threshold that is updated based on the minimum positive
+    activation value. After training, BatchTopK SAEs are saved as JumpReLU SAEs.
+
+    Args:
+        k (float): Average number of features to keep active across the batch. Unlike
+            standard TopK SAEs where k is an integer per sample, this is a float
+            representing the average number of active features across all samples in
+            the batch. Defaults to 100.
+        topk_threshold_lr (float): Learning rate for updating the global topk threshold.
+            The threshold is updated using an exponential moving average of the minimum
+            positive activation value. Defaults to 0.01.
+        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
+            dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
+            Defaults to 1.0.
+        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
+            Inherited from TopKTrainingSAEConfig. Defaults to True.
+        decoder_init_norm (float | None): Norm to initialize decoder weights to.
+            Inherited from TrainingSAEConfig. Defaults to 0.1.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
     """
 
     k: float = 100  # type: ignore[assignment]
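
The new forward logic multiplies k by the number of samples and takes a single top-k over the flattened activations, so k is enforced only on average across the batch rather than per sample. An illustrative standalone sketch of that selection (toy values, not part of the package):

import torch

acts = torch.tensor([[0.5, 0.0, 0.2], [0.9, 0.1, 0.0]])  # 2 samples, 3 features
k = 1.5  # average number of active features per sample
num_samples = acts.shape[:-1].numel()  # 2 samples
flat = acts.flatten()
topk = torch.topk(flat, int(k * num_samples), dim=-1)  # keep 3 activations in total
batch_topk = (
    torch.zeros_like(flat).scatter(-1, topk.indices, topk.values).reshape(acts.shape)
)
# Here the first sample keeps 2 activations and the second keeps 1: the per-sample
# count varies, but the batch-wide total equals k * num_samples.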
sae_lens/saes/matryoshka_batchtopk_sae.py
ADDED
@@ -0,0 +1,137 @@
+import warnings
+from dataclasses import dataclass, field
+
+import torch
+from jaxtyping import Float
+from typing_extensions import override
+
+from sae_lens.saes.batchtopk_sae import (
+    BatchTopKTrainingSAE,
+    BatchTopKTrainingSAEConfig,
+)
+from sae_lens.saes.sae import TrainStepInput, TrainStepOutput
+from sae_lens.saes.topk_sae import _sparse_matmul_nd
+
+
+@dataclass
+class MatryoshkaBatchTopKTrainingSAEConfig(BatchTopKTrainingSAEConfig):
+    """
+    Configuration class for training a MatryoshkaBatchTopKTrainingSAE.
+
+    [Matryoshka SAEs](https://arxiv.org/pdf/2503.17547) use a series of nested reconstruction
+    losses of different widths during training to avoid feature absorption. This also has a
+    nice side-effect of encouraging higher-frequency features to be learned in earlier levels.
+    However, this SAE has more hyperparameters to tune than standard BatchTopK SAEs, and takes
+    longer to train due to requiring multiple forward passes per training step.
+
+    After training, MatryoshkaBatchTopK SAEs are saved as JumpReLU SAEs.
+
+    Args:
+        matryoshka_widths (list[int]): The widths of the matryoshka levels. Defaults to an empty list.
+        k (float): The number of features to keep active. Inherited from BatchTopKTrainingSAEConfig.
+            Defaults to 100.
+        topk_threshold_lr (float): Learning rate for updating the global topk threshold.
+            The threshold is updated using an exponential moving average of the minimum
+            positive activation value. Defaults to 0.01.
+        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
+            dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
+            Defaults to 1.0.
+        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
+            Inherited from TopKTrainingSAEConfig. Defaults to True.
+        decoder_init_norm (float | None): Norm to initialize decoder weights to.
+            Inherited from TrainingSAEConfig. Defaults to 0.1.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+    """
+
+    matryoshka_widths: list[int] = field(default_factory=list)
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "matryoshka_batchtopk"
+
+
+class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
+    """
+    Global Batch TopK Training SAE
+
+    This SAE will maintain the k on average across the batch, rather than enforcing the k per-sample as in standard TopK.
+
+    BatchTopK SAEs are saved as JumpReLU SAEs after training.
+    """
+
+    cfg: MatryoshkaBatchTopKTrainingSAEConfig  # type: ignore[assignment]
+
+    def __init__(
+        self, cfg: MatryoshkaBatchTopKTrainingSAEConfig, use_error_term: bool = False
+    ):
+        super().__init__(cfg, use_error_term)
+        _validate_matryoshka_config(cfg)
+
+    @override
+    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
+        base_output = super().training_forward_pass(step_input)
+        inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
+        # the outer matryoshka level is the base SAE, so we don't need to add an extra loss for it
+        for width in self.cfg.matryoshka_widths[:-1]:
+            inner_reconstruction = self._decode_matryoshka_level(
+                base_output.feature_acts, width, inv_W_dec_norm
+            )
+            inner_mse_loss = (
+                self.mse_loss_fn(inner_reconstruction, step_input.sae_in)
+                .sum(dim=-1)
+                .mean()
+            )
+            base_output.losses[f"inner_mse_loss_{width}"] = inner_mse_loss
+            base_output.loss = base_output.loss + inner_mse_loss
+        return base_output
+
+    def _decode_matryoshka_level(
+        self,
+        feature_acts: Float[torch.Tensor, "... d_sae"],
+        width: int,
+        inv_W_dec_norm: torch.Tensor,
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decodes feature activations back into input space for a matryoshka level
+        """
+        inner_feature_acts = feature_acts[:, :width]
+        # Handle sparse tensors using efficient sparse matrix multiplication
+        if self.cfg.rescale_acts_by_decoder_norm:
+            # need to multiply by the inverse of the norm because division is illegal with sparse tensors
+            inner_feature_acts = inner_feature_acts * inv_W_dec_norm[:width]
+        if inner_feature_acts.is_sparse:
+            sae_out_pre = (
+                _sparse_matmul_nd(inner_feature_acts, self.W_dec[:width]) + self.b_dec
+            )
+        else:
+            sae_out_pre = inner_feature_acts @ self.W_dec[:width] + self.b_dec
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+
+def _validate_matryoshka_config(cfg: MatryoshkaBatchTopKTrainingSAEConfig) -> None:
+    if cfg.matryoshka_widths[-1] != cfg.d_sae:
+        # warn the users that we will add a final matryoshka level
+        warnings.warn(
+            "WARNING: The final matryoshka level width is not set to cfg.d_sae. "
+            "A final matryoshka level of width=cfg.d_sae will be added."
+        )
+        cfg.matryoshka_widths.append(cfg.d_sae)
+
+    for prev_width, curr_width in zip(
+        cfg.matryoshka_widths[:-1], cfg.matryoshka_widths[1:]
+    ):
+        if prev_width >= curr_width:
+            raise ValueError("cfg.matryoshka_widths must be strictly increasing.")
+    if len(cfg.matryoshka_widths) == 1:
+        warnings.warn(
+            "WARNING: You have only set one matryoshka level. This is equivalent to using a standard BatchTopK SAE and is likely not what you want."
+        )
sae_lens/saes/sae.py
CHANGED
@@ -14,7 +14,6 @@ from typing import (
     Generic,
     Literal,
     NamedTuple,
-    Type,
     TypeVar,
 )
 
@@ -22,7 +21,7 @@ import einops
 import torch
 from jaxtyping import Float
 from numpy.typing import NDArray
-from safetensors.torch import save_file
+from safetensors.torch import load_file, save_file
 from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
 from typing_extensions import deprecated, overload, override
@@ -156,9 +155,9 @@ class SAEConfig(ABC):
     dtype: str = "float32"
     device: str = "cpu"
     apply_b_dec_to_input: bool = True
-    normalize_activations: Literal[
-        "none",
-
+    normalize_activations: Literal["none", "expected_average_only_in", "layer_norm"] = (
+        "none"  # none, expected_average_only_in (Anthropic April Update)
+    )
     reshape_activations: Literal["none", "hook_z"] = "none"
     metadata: SAEMetadata = field(default_factory=SAEMetadata)
 
@@ -218,6 +217,7 @@ class TrainStepInput:
     sae_in: torch.Tensor
     coefficients: dict[str, float]
     dead_neuron_mask: torch.Tensor | None
+    n_training_steps: int
 
 
 class TrainCoefficientConfig(NamedTuple):
@@ -245,7 +245,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
         self.cfg = cfg
 
-        if cfg.metadata and cfg.metadata:
+        if cfg.metadata and cfg.metadata.model_from_pretrained_kwargs:
             warnings.warn(
                 "\nThis SAE has non-empty model_from_pretrained_kwargs. "
                 "\nFor optimal performance, load the model like so:\n"
@@ -309,6 +309,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
             self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
             self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+
         elif self.cfg.normalize_activations == "layer_norm":
             # we need to scale the norm of the input and store the scaling factor
             def run_time_activation_ln_in(
@@ -452,23 +453,14 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     def process_sae_in(
         self, sae_in: Float[torch.Tensor, "... d_in"]
     ) -> Float[torch.Tensor, "... d_in"]:
-        # print(f"Input shape to process_sae_in: {sae_in.shape}")
-        # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
-        # print(f"self.b_dec shape: {self.b_dec.shape}")
-        # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")
-
         sae_in = sae_in.to(self.dtype)
-
-        # print(f"Shape before reshape_fn_in: {sae_in.shape}")
         sae_in = self.reshape_fn_in(sae_in)
-        # print(f"Shape after reshape_fn_in: {sae_in.shape}")
 
         sae_in = self.hook_sae_input(sae_in)
         sae_in = self.run_time_activation_norm_fn_in(sae_in)
 
         # Here's where the error happens
         bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
-        # print(f"Bias term shape: {bias_term.shape}")
 
         return sae_in - bias_term
 
@@ -534,7 +526,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     @classmethod
     @deprecated("Use load_from_disk instead")
     def load_from_pretrained(
-        cls:
+        cls: type[T_SAE],
         path: str | Path,
         device: str = "cpu",
         dtype: str | None = None,
@@ -543,7 +535,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
     @classmethod
     def load_from_disk(
-        cls:
+        cls: type[T_SAE],
         path: str | Path,
         device: str = "cpu",
         dtype: str | None = None,
@@ -564,7 +556,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
     @classmethod
     def from_pretrained(
-        cls:
+        cls: type[T_SAE],
         release: str,
         sae_id: str,
         device: str = "cpu",
@@ -585,7 +577,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
     @classmethod
     def from_pretrained_with_cfg_and_sparsity(
-        cls:
+        cls: type[T_SAE],
         release: str,
         sae_id: str,
         device: str = "cpu",
@@ -684,7 +676,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         return sae, cfg_dict, log_sparsities
 
     @classmethod
-    def from_dict(cls:
+    def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
         """Create an SAE from a config dictionary."""
         sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
         sae_config_cls = cls.get_sae_config_class_for_architecture(
@@ -694,8 +686,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
     @classmethod
     def get_sae_class_for_architecture(
-        cls:
-    ) ->
+        cls: type[T_SAE], architecture: str
+    ) -> type[T_SAE]:
         """Get the SAE class for a given architecture."""
         sae_cls, _ = get_sae_class(architecture)
         if not issubclass(sae_cls, cls):
@@ -1000,8 +992,8 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
 
     @classmethod
     def get_sae_class_for_architecture(
-        cls:
-    ) ->
+        cls: type[T_TRAINING_SAE], architecture: str
+    ) -> type[T_TRAINING_SAE]:
         """Get the SAE class for a given architecture."""
         sae_cls, _ = get_sae_training_class(architecture)
         if not issubclass(sae_cls, cls):
@@ -1018,6 +1010,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
     ) -> type[TrainingSAEConfig]:
         return get_sae_training_class(architecture)[1]
 
+    def load_weights_from_checkpoint(self, checkpoint_path: Path | str) -> None:
+        checkpoint_path = Path(checkpoint_path)
+        state_dict = load_file(checkpoint_path / SAE_WEIGHTS_FILENAME)
+        self.process_state_dict_for_loading(state_dict)
+        self.load_state_dict(state_dict)
+
 
 _blank_hook = nn.Identity()
 
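
The new TrainingSAE.load_weights_from_checkpoint helper reads the weights file named by SAE_WEIGHTS_FILENAME from a checkpoint directory and loads it into an existing training SAE. An illustrative sketch (the checkpoint path and the config values below are hypothetical, and the weights filename constant lives elsewhere in the package, not in this diff):

from sae_lens.saes import StandardTrainingSAE, StandardTrainingSAEConfig

# Hypothetical checkpoint directory written by an earlier training run; it must
# contain the weights file that SAE_WEIGHTS_FILENAME refers to.
checkpoint_dir = "checkpoints/my_run/final"

sae = StandardTrainingSAE(StandardTrainingSAEConfig(d_in=768, d_sae=16384))
sae.load_weights_from_checkpoint(checkpoint_dir)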