sae-lens 6.15.0__py3-none-any.whl → 6.22.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +6 -1
- sae_lens/analysis/hooked_sae_transformer.py +4 -13
- sae_lens/cache_activations_runner.py +3 -4
- sae_lens/config.py +39 -2
- sae_lens/constants.py +1 -0
- sae_lens/llm_sae_training_runner.py +9 -4
- sae_lens/loading/pretrained_sae_loaders.py +188 -0
- sae_lens/loading/pretrained_saes_directory.py +5 -3
- sae_lens/pretrained_saes.yaml +77 -1
- sae_lens/saes/__init__.py +3 -0
- sae_lens/saes/batchtopk_sae.py +3 -1
- sae_lens/saes/gated_sae.py +4 -9
- sae_lens/saes/jumprelu_sae.py +4 -9
- sae_lens/saes/matryoshka_batchtopk_sae.py +8 -15
- sae_lens/saes/sae.py +19 -31
- sae_lens/saes/standard_sae.py +4 -9
- sae_lens/saes/temporal_sae.py +365 -0
- sae_lens/saes/topk_sae.py +7 -10
- sae_lens/training/activation_scaler.py +7 -0
- sae_lens/training/activations_store.py +49 -7
- sae_lens/training/optim.py +11 -0
- sae_lens/training/sae_trainer.py +50 -11
- {sae_lens-6.15.0.dist-info → sae_lens-6.22.1.dist-info}/METADATA +16 -16
- sae_lens-6.22.1.dist-info/RECORD +41 -0
- sae_lens-6.15.0.dist-info/RECORD +0 -40
- {sae_lens-6.15.0.dist-info → sae_lens-6.22.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.15.0.dist-info → sae_lens-6.22.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/saes/matryoshka_batchtopk_sae.py
CHANGED

@@ -2,7 +2,6 @@ import warnings
 from dataclasses import dataclass, field
 
 import torch
-from jaxtyping import Float
 from typing_extensions import override
 
 from sae_lens.saes.batchtopk_sae import (
@@ -78,14 +77,11 @@ class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
     @override
     def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
         base_output = super().training_forward_pass(step_input)
-        hidden_pre = base_output.hidden_pre
         inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
         # the outer matryoshka level is the base SAE, so we don't need to add an extra loss for it
         for width in self.cfg.matryoshka_widths[:-1]:
-            inner_hidden_pre = hidden_pre[:, :width]
-            inner_feat_acts = self.activation_fn(inner_hidden_pre)
             inner_reconstruction = self._decode_matryoshka_level(
-                inner_feat_acts, width, inv_W_dec_norm
+                base_output.feature_acts, width, inv_W_dec_norm
             )
             inner_mse_loss = (
                 self.mse_loss_fn(inner_reconstruction, step_input.sae_in)
@@ -98,23 +94,24 @@ class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
 
     def _decode_matryoshka_level(
        self,
-        feature_acts: Float[torch.Tensor, "... d_sae"],
+        feature_acts: torch.Tensor,
         width: int,
         inv_W_dec_norm: torch.Tensor,
-    ) -> Float[torch.Tensor, "... d_in"]:
+    ) -> torch.Tensor:
         """
         Decodes feature activations back into input space for a matryoshka level
         """
+        inner_feature_acts = feature_acts[:, :width]
         # Handle sparse tensors using efficient sparse matrix multiplication
         if self.cfg.rescale_acts_by_decoder_norm:
             # need to multiply by the inverse of the norm because division is illegal with sparse tensors
-            feature_acts = feature_acts * inv_W_dec_norm[:width]
-        if feature_acts.is_sparse:
+            inner_feature_acts = inner_feature_acts * inv_W_dec_norm[:width]
+        if inner_feature_acts.is_sparse:
             sae_out_pre = (
-                _sparse_matmul_nd(feature_acts, self.W_dec[:width]) + self.b_dec
+                _sparse_matmul_nd(inner_feature_acts, self.W_dec[:width]) + self.b_dec
             )
         else:
-            sae_out_pre = feature_acts @ self.W_dec[:width] + self.b_dec
+            sae_out_pre = inner_feature_acts @ self.W_dec[:width] + self.b_dec
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
         return self.reshape_fn_out(sae_out_pre, self.d_head)
 
@@ -137,7 +134,3 @@ def _validate_matryoshka_config(cfg: MatryoshkaBatchTopKTrainingSAEConfig) -> None:
         warnings.warn(
             "WARNING: You have only set one matryoshka level. This is equivalent to using a standard BatchTopK SAE and is likely not what you want."
         )
-    if cfg.matryoshka_widths[0] < cfg.k:
-        raise ValueError(
-            "The smallest matryoshka level width cannot be smaller than cfg.k."
-        )
sae_lens/saes/sae.py
CHANGED
@@ -19,9 +19,8 @@ from typing import (
 
 import einops
 import torch
-from jaxtyping import Float
 from numpy.typing import NDArray
-from safetensors.torch import save_file
+from safetensors.torch import load_file, save_file
 from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
 from typing_extensions import deprecated, overload, override
@@ -155,9 +154,9 @@ class SAEConfig(ABC):
     dtype: str = "float32"
     device: str = "cpu"
     apply_b_dec_to_input: bool = True
-    normalize_activations: Literal[
-        "none", "expected_average_only_in", "layer_norm"
-    ] = "none"  # none, expected_average_only_in (Anthropic April Update)
+    normalize_activations: Literal["none", "expected_average_only_in", "layer_norm"] = (
+        "none"  # none, expected_average_only_in (Anthropic April Update)
+    )
     reshape_activations: Literal["none", "hook_z"] = "none"
     metadata: SAEMetadata = field(default_factory=SAEMetadata)
 
@@ -217,6 +216,7 @@ class TrainStepInput:
     sae_in: torch.Tensor
     coefficients: dict[str, float]
     dead_neuron_mask: torch.Tensor | None
+    n_training_steps: int
 
 
 class TrainCoefficientConfig(NamedTuple):
@@ -308,6 +308,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
             self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
             self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+
         elif self.cfg.normalize_activations == "layer_norm":
             # we need to scale the norm of the input and store the scaling factor
             def run_time_activation_ln_in(
@@ -349,16 +350,12 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         self.W_enc = nn.Parameter(w_enc_data)
 
     @abstractmethod
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """Encode input tensor to feature space."""
         pass
 
     @abstractmethod
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """Decode feature activations back to input space."""
         pass
 
@@ -448,26 +445,15 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
         return super().to(*args, **kwargs)
 
-    def process_sae_in(
-        self, sae_in: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_in"]:
-        # print(f"Input shape to process_sae_in: {sae_in.shape}")
-        # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
-        # print(f"self.b_dec shape: {self.b_dec.shape}")
-        # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")
-
+    def process_sae_in(self, sae_in: torch.Tensor) -> torch.Tensor:
         sae_in = sae_in.to(self.dtype)
-
-        # print(f"Shape before reshape_fn_in: {sae_in.shape}")
         sae_in = self.reshape_fn_in(sae_in)
-        # print(f"Shape after reshape_fn_in: {sae_in.shape}")
 
         sae_in = self.hook_sae_input(sae_in)
         sae_in = self.run_time_activation_norm_fn_in(sae_in)
 
         # Here's where the error happens
         bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
-        # print(f"Bias term shape: {bias_term.shape}")
 
         return sae_in - bias_term
 
@@ -866,14 +852,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
 
     @abstractmethod
     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Encode with access to pre-activation values for training."""
         ...
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         For inference, just encode without returning hidden_pre.
         (training_forward_pass calls encode_with_hidden_pre).
@@ -881,9 +865,7 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
         feature_acts, _ = self.encode_with_hidden_pre(x)
         return feature_acts
 
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decodes feature activations back into input space,
         applying optional finetuning scale, hooking, out normalization, etc.
@@ -1017,6 +999,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
     ) -> type[TrainingSAEConfig]:
         return get_sae_training_class(architecture)[1]
 
+    def load_weights_from_checkpoint(self, checkpoint_path: Path | str) -> None:
+        checkpoint_path = Path(checkpoint_path)
+        state_dict = load_file(checkpoint_path / SAE_WEIGHTS_FILENAME)
+        self.process_state_dict_for_loading(state_dict)
+        self.load_state_dict(state_dict)
+
 
 _blank_hook = nn.Identity()
 
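
Two training-facing additions stand out above: TrainStepInput now carries an n_training_steps field, and TrainingSAE gains load_weights_from_checkpoint, which reads the SAE_WEIGHTS_FILENAME safetensors file from a checkpoint directory and loads it as a state dict. A hedged usage sketch (the checkpoint path and the d_in/d_sae values are illustrative, assuming those are the required config fields; StandardTrainingSAE/StandardTrainingSAEConfig are the concrete classes from standard_sae.py below):

from pathlib import Path

from sae_lens.saes.standard_sae import StandardTrainingSAE, StandardTrainingSAEConfig

# Hypothetical directory produced by an earlier training run; it is expected to
# contain the SAE_WEIGHTS_FILENAME safetensors file referenced in the diff above.
checkpoint_dir = Path("checkpoints/my_run/final")

sae = StandardTrainingSAE(StandardTrainingSAEConfig(d_in=768, d_sae=16384))
sae.load_weights_from_checkpoint(checkpoint_dir)

Code that builds TrainStepInput directly will also need to pass the current step count for the new n_training_steps field.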
sae_lens/saes/standard_sae.py
CHANGED
@@ -2,7 +2,6 @@ from dataclasses import dataclass
 
 import numpy as np
 import torch
-from jaxtyping import Float
 from numpy.typing import NDArray
 from torch import nn
 from typing_extensions import override
@@ -54,9 +53,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
         super().initialize_weights()
         _init_weights_standard(self)
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Encode the input tensor into the feature space.
         """
@@ -67,9 +64,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
         # Apply the activation function (e.g., ReLU, depending on config)
         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
 
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decode the feature activations back to the input space.
         Now, if hook_z reshaping is turned on, we reverse the flattening.
@@ -127,8 +122,8 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
         }
 
     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # Process the input (including dtype conversion, hook call, and any activation normalization)
         sae_in = self.process_sae_in(x)
         # Compute the pre-activation (and allow for a hook if desired)
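
With the jaxtyping annotations removed, encode/decode on the standard SAE now take and return plain torch.Tensor; an inference round trip is unchanged in practice. A small sketch (the d_in/d_sae values are illustrative and assume they are the required config fields):

import torch

from sae_lens.saes.standard_sae import StandardSAE, StandardSAEConfig

sae = StandardSAE(StandardSAEConfig(d_in=512, d_sae=4096))
x = torch.randn(8, 512)                    # (batch, d_in) activations
feature_acts = sae.encode(x)               # (batch, d_sae)
reconstruction = sae.decode(feature_acts)  # (batch, d_in)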
sae_lens/saes/temporal_sae.py
ADDED

@@ -0,0 +1,365 @@
+"""TemporalSAE: A Sparse Autoencoder with temporal attention mechanism.
+
+TemporalSAE decomposes activations into:
+1. Predicted codes (from attention over context)
+2. Novel codes (sparse features of the residual)
+
+See: https://arxiv.org/abs/2410.04185
+"""
+
+import math
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from typing_extensions import override
+
+from sae_lens import logger
+from sae_lens.saes.sae import SAE, SAEConfig
+
+
+def get_attention(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
+    """Compute causal attention weights."""
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1))
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+    temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+    attn_bias.to(query.dtype)
+
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    return torch.softmax(attn_weight, dim=-1)
+
+
+class ManualAttention(nn.Module):
+    """Manual attention implementation for TemporalSAE."""
+
+    def __init__(
+        self,
+        dimin: int,
+        n_heads: int = 4,
+        bottleneck_factor: int = 64,
+        bias_k: bool = True,
+        bias_q: bool = True,
+        bias_v: bool = True,
+        bias_o: bool = True,
+    ):
+        super().__init__()
+        assert dimin % (bottleneck_factor * n_heads) == 0
+
+        self.n_heads = n_heads
+        self.n_embds = dimin // bottleneck_factor
+        self.dimin = dimin
+
+        # Key, query, value projections
+        self.k_ctx = nn.Linear(dimin, self.n_embds, bias=bias_k)
+        self.q_target = nn.Linear(dimin, self.n_embds, bias=bias_q)
+        self.v_ctx = nn.Linear(dimin, dimin, bias=bias_v)
+        self.c_proj = nn.Linear(dimin, dimin, bias=bias_o)
+
+        # Normalize to match scale with representations
+        with torch.no_grad():
+            scaling = 1 / math.sqrt(self.n_embds // self.n_heads)
+            self.k_ctx.weight.copy_(
+                scaling
+                * self.k_ctx.weight
+                / (1e-6 + torch.linalg.norm(self.k_ctx.weight, dim=1, keepdim=True))
+            )
+            self.q_target.weight.copy_(
+                scaling
+                * self.q_target.weight
+                / (1e-6 + torch.linalg.norm(self.q_target.weight, dim=1, keepdim=True))
+            )
+
+            scaling = 1 / math.sqrt(self.dimin // self.n_heads)
+            self.v_ctx.weight.copy_(
+                scaling
+                * self.v_ctx.weight
+                / (1e-6 + torch.linalg.norm(self.v_ctx.weight, dim=1, keepdim=True))
+            )
+
+            scaling = 1 / math.sqrt(self.dimin)
+            self.c_proj.weight.copy_(
+                scaling
+                * self.c_proj.weight
+                / (1e-6 + torch.linalg.norm(self.c_proj.weight, dim=1, keepdim=True))
+            )
+
+    def forward(
+        self, x_ctx: torch.Tensor, x_target: torch.Tensor, get_attn_map: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Compute projective attention output."""
+        k = self.k_ctx(x_ctx)
+        v = self.v_ctx(x_ctx)
+        q = self.q_target(x_target)
+
+        # Split into heads
+        B, T, _ = x_ctx.size()
+        k = k.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
+        q = q.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
+        v = v.view(B, T, self.n_heads, self.dimin // self.n_heads).transpose(1, 2)
+
+        # Attention map (optional)
+        attn_map = None
+        if get_attn_map:
+            attn_map = get_attention(query=q, key=k)
+
+        # Scaled dot-product attention
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=None, dropout_p=0, is_causal=True
+        )
+
+        # Reshape and project
+        d_target = self.c_proj(
+            attn_output.transpose(1, 2).contiguous().view(B, T, self.dimin)
+        )
+
+        return d_target, attn_map
+
+
+@dataclass
+class TemporalSAEConfig(SAEConfig):
+    """Configuration for TemporalSAE inference.
+
+    Args:
+        d_in: Input dimension (dimensionality of the activations being encoded)
+        d_sae: SAE latent dimension (number of features)
+        n_heads: Number of attention heads in temporal attention
+        n_attn_layers: Number of attention layers
+        bottleneck_factor: Bottleneck factor for attention dimension
+        sae_diff_type: Type of SAE for novel codes ('relu' or 'topk')
+        kval_topk: K value for top-k sparsity (if sae_diff_type='topk')
+        tied_weights: Whether to tie encoder and decoder weights
+        activation_normalization_factor: Scalar factor for rescaling activations (used with normalize_activations='constant_scalar_rescale')
+    """
+
+    n_heads: int = 8
+    n_attn_layers: int = 1
+    bottleneck_factor: int = 64
+    sae_diff_type: Literal["relu", "topk"] = "topk"
+    kval_topk: int | None = None
+    tied_weights: bool = True
+    activation_normalization_factor: float = 1.0
+
+    def __post_init__(self):
+        # Call parent's __post_init__ first, but allow constant_scalar_rescale
+        if self.normalize_activations not in [
+            "none",
+            "expected_average_only_in",
+            "constant_norm_rescale",
+            "constant_scalar_rescale",  # Temporal SAEs support this
+            "layer_norm",
+        ]:
+            raise ValueError(
+                f"normalize_activations must be none, expected_average_only_in, layer_norm, constant_norm_rescale, or constant_scalar_rescale. Got {self.normalize_activations}"
+            )
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "temporal"
+
+
+class TemporalSAE(SAE[TemporalSAEConfig]):
+    """TemporalSAE: Sparse Autoencoder with temporal attention.
+
+    This SAE decomposes each activation x_t into:
+    - x_pred: Information aggregated from context {x_0, ..., x_{t-1}}
+    - x_novel: Novel information at position t (encoded sparsely)
+
+    The forward pass:
+    1. Uses attention layers to predict x_t from context
+    2. Encodes the residual (novel part) with a sparse SAE
+    3. Combines both for reconstruction
+    """
+
+    # Custom parameters (in addition to W_enc, W_dec, b_dec from base)
+    attn_layers: nn.ModuleList  # Attention layers
+    eps: float
+    lam: float
+
+    def __init__(self, cfg: TemporalSAEConfig, use_error_term: bool = False):
+        # Call parent init first
+        super().__init__(cfg, use_error_term)
+
+        # Initialize attention layers after parent init and move to correct device
+        self.attn_layers = nn.ModuleList(
+            [
+                ManualAttention(
+                    dimin=cfg.d_sae,
+                    n_heads=cfg.n_heads,
+                    bottleneck_factor=cfg.bottleneck_factor,
+                    bias_k=True,
+                    bias_q=True,
+                    bias_v=True,
+                    bias_o=True,
+                ).to(device=self.device, dtype=self.dtype)
+                for _ in range(cfg.n_attn_layers)
+            ]
+        )
+
+        self.eps = 1e-6
+        self.lam = 1 / (4 * self.cfg.d_in)
+
+    @override
+    def _setup_activation_normalization(self):
+        """Set up activation normalization functions for TemporalSAE.
+
+        Overrides the base implementation to handle constant_scalar_rescale
+        using the temporal-specific activation_normalization_factor.
+        """
+        if self.cfg.normalize_activations == "constant_scalar_rescale":
+            # Handle constant scalar rescaling for temporal SAEs
+            def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
+                return x * self.cfg.activation_normalization_factor
+
+            def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
+                return x / self.cfg.activation_normalization_factor
+
+            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
+            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+        else:
+            # Delegate to parent for all other normalization types
+            super()._setup_activation_normalization()
+
+    @override
+    def initialize_weights(self) -> None:
+        """Initialize TemporalSAE weights."""
+        # Initialize D (decoder) and b (bias)
+        self.W_dec = nn.Parameter(
+            torch.randn(
+                (self.cfg.d_sae, self.cfg.d_in), dtype=self.dtype, device=self.device
+            )
+        )
+        self.b_dec = nn.Parameter(
+            torch.zeros((self.cfg.d_in), dtype=self.dtype, device=self.device)
+        )
+
+        # Initialize E (encoder) if not tied
+        if not self.cfg.tied_weights:
+            self.W_enc = nn.Parameter(
+                torch.randn(
+                    (self.cfg.d_in, self.cfg.d_sae),
+                    dtype=self.dtype,
+                    device=self.device,
+                )
+            )
+
+    def encode_with_predictions(
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Encode input to novel codes only.
+
+        Returns only the sparse novel codes (not predicted codes).
+        This is the main feature representation for TemporalSAE.
+        """
+        # Process input through SAELens preprocessing
+        x = self.process_sae_in(x)
+
+        B, L, _ = x.shape
+
+        if self.cfg.tied_weights:  # noqa: SIM108
+            W_enc = self.W_dec.T
+        else:
+            W_enc = self.W_enc
+
+        # Compute predicted codes using attention
+        x_residual = x
+        z_pred = torch.zeros((B, L, self.cfg.d_sae), device=x.device, dtype=x.dtype)
+
+        for attn_layer in self.attn_layers:
+            # Encode input to latent space
+            z_input = F.relu(torch.matmul(x_residual * self.lam, W_enc))
+
+            # Shift context (causal masking)
+            z_ctx = torch.cat(
+                (torch.zeros_like(z_input[:, :1, :]), z_input[:, :-1, :].clone()), dim=1
+            )
+
+            # Apply attention to get predicted codes
+            z_pred_, _ = attn_layer(z_ctx, z_input, get_attn_map=False)
+            z_pred_ = F.relu(z_pred_)
+
+            # Project predicted codes back to input space
+            Dz_pred_ = torch.matmul(z_pred_, self.W_dec)
+            Dz_norm_ = Dz_pred_.norm(dim=-1, keepdim=True) + self.eps
+
+            # Compute projection scale
+            proj_scale = (Dz_pred_ * x_residual).sum(
+                dim=-1, keepdim=True
+            ) / Dz_norm_.pow(2)
+
+            # Accumulate predicted codes
+            z_pred = z_pred + (z_pred_ * proj_scale)
+
+            # Remove prediction from residual
+            x_residual = x_residual - proj_scale * Dz_pred_
+
+        # Encode residual (novel part) with sparse SAE
+        z_novel = F.relu(torch.matmul(x_residual * self.lam, W_enc))
+        if self.cfg.sae_diff_type == "topk":
+            kval = self.cfg.kval_topk
+            if kval is not None:
+                _, topk_indices = torch.topk(z_novel, kval, dim=-1)
+                mask = torch.zeros_like(z_novel)
+                mask.scatter_(-1, topk_indices, 1)
+                z_novel = z_novel * mask
+
+        # Return only novel codes (these are the interpretable features)
+        return z_novel, z_pred
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        return self.encode_with_predictions(x)[0]
+
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
+        """Decode novel codes to reconstruction.
+
+        Note: This only decodes the novel codes. For full reconstruction,
+        use forward() which includes predicted codes.
+        """
+        # Decode novel codes
+        sae_out = torch.matmul(feature_acts, self.W_dec)
+        sae_out = sae_out + self.b_dec
+
+        # Apply hook
+        sae_out = self.hook_sae_recons(sae_out)
+
+        # Apply output activation normalization (reverses input normalization)
+        sae_out = self.run_time_activation_norm_fn_out(sae_out)
+
+        # Add bias (already removed in process_sae_in)
+        logger.warning(
+            "NOTE this only decodes x_novel. The x_pred is missing, so we're not reconstructing the full x."
+        )
+        return sae_out
+
+    @override
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Full forward pass through TemporalSAE.
+
+        Returns complete reconstruction (predicted + novel).
+        """
+        # Encode
+        z_novel, z_pred = self.encode_with_predictions(x)
+
+        # Decode the sum of predicted and novel codes.
+        x_recons = torch.matmul(z_novel + z_pred, self.W_dec) + self.b_dec
+
+        # Apply output activation normalization (reverses input normalization)
+        x_recons = self.run_time_activation_norm_fn_out(x_recons)
+
+        return self.hook_sae_output(x_recons)
+
+    @override
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError("Folding W_dec_norm is not supported for TemporalSAE")
+
+    @override
+    @torch.no_grad()
+    def fold_activation_norm_scaling_factor(self, scaling_factor: float) -> None:
+        raise NotImplementedError(
+            "Folding activation norm scaling factor is not supported for TemporalSAE"
+        )
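
A hedged usage sketch for the new TemporalSAE (sizes and kval_topk are illustrative; note from the code above that ManualAttention asserts d_sae is divisible by bottleneck_factor * n_heads, 64 * 8 = 512 with the defaults, and that encode()/decode() cover only the novel codes, so forward() is used for the full reconstruction):

import torch

from sae_lens.saes.temporal_sae import TemporalSAE, TemporalSAEConfig

cfg = TemporalSAEConfig(d_in=768, d_sae=8192, kval_topk=32)  # 8192 % 512 == 0
sae = TemporalSAE(cfg)

x = torch.randn(2, 128, 768)                      # (batch, seq_len, d_in) activations
z_novel, z_pred = sae.encode_with_predictions(x)  # sparse novel codes + predicted codes
x_recons = sae(x)                                 # full reconstruction: predicted + novel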
sae_lens/saes/topk_sae.py
CHANGED
@@ -4,7 +4,6 @@ from dataclasses import dataclass
 from typing import Any, Callable
 
 import torch
-from jaxtyping import Float
 from torch import nn
 from transformer_lens.hook_points import HookPoint
 from typing_extensions import override
@@ -235,9 +234,7 @@ class TopKSAE(SAE[TopKSAEConfig]):
         super().initialize_weights()
         _init_weights_topk(self)
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Converts input x into feature activations.
         Uses topk activation under the hood.
@@ -251,8 +248,8 @@ class TopKSAE(SAE[TopKSAEConfig]):
 
     def decode(
         self,
-        feature_acts: Float[torch.Tensor, "... d_sae"],
-    ) -> Float[torch.Tensor, "... d_in"]:
+        feature_acts: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Reconstructs the input from topk feature activations.
         Applies optional finetuning scaling, hooking to recons, out normalization,
@@ -354,8 +351,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         _init_weights_topk(self)
 
     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Similar to the base training method: calculate pre-activations, then apply TopK.
         """
@@ -372,8 +369,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     @override
     def decode(
         self,
-        feature_acts: Float[torch.Tensor, "... d_sae"],
-    ) -> Float[torch.Tensor, "... d_in"]:
+        feature_acts: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Decodes feature activations back into input space,
         applying optional finetuning scale, hooking, out normalization, etc.
sae_lens/training/activation_scaler.py
CHANGED

@@ -1,5 +1,6 @@
 import json
 from dataclasses import dataclass
+from pathlib import Path
 from statistics import mean
 
 import torch
@@ -51,3 +52,9 @@ class ActivationScaler:
 
         with open(file_path, "w") as f:
             json.dump({"scaling_factor": self.scaling_factor}, f)
+
+    def load(self, file_path: str | Path):
+        """load the state dict from a file in json format"""
+        with open(file_path) as f:
+            data = json.load(f)
+        self.scaling_factor = data["scaling_factor"]