sae-lens 6.12.3__tar.gz → 6.13.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.12.3 → sae_lens-6.13.1}/PKG-INFO +1 -1
- {sae_lens-6.12.3 → sae_lens-6.13.1}/pyproject.toml +1 -1
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/__init__.py +1 -1
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/evals.py +2 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/pretokenize_runner.py +2 -1
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/sae.py +9 -10
- sae_lens-6.13.1/sae_lens/saes/topk_sae.py +473 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/tokenization_and_batching.py +21 -6
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/sae_trainer.py +7 -5
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/types.py +1 -1
- sae_lens-6.12.3/sae_lens/saes/topk_sae.py +0 -271
- {sae_lens-6.12.3 → sae_lens-6.13.1}/LICENSE +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/README.md +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/config.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/constants.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/load_model.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/registry.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/util.py +0 -0
{sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/evals.py

@@ -466,6 +466,8 @@ def get_sparsity_and_variance_metrics(
         sae_out_scaled = sae.decode(sae_feature_activations).to(
             original_act_scaled.device
         )
+        if sae_feature_activations.is_sparse:
+            sae_feature_activations = sae_feature_activations.to_dense()
         del cache

         sae_out = activation_scaler.unscale(sae_out_scaled)
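Note: this guard densifies sparse TopK activations before the elementwise sparsity/variance math that follows. A minimal sketch of the same pattern, assuming only standard PyTorch sparse-COO behaviour (the tensor and variable names are illustrative, not taken from evals.py):

import torch

feature_acts = torch.tensor([[0.0, 2.0, 0.0], [1.0, 0.0, 3.0]]).to_sparse()

# Downstream elementwise metric code expects a strided (dense) tensor,
# so densify only when the encoder actually returned a sparse tensor.
if feature_acts.is_sparse:
    feature_acts = feature_acts.to_dense()

l0_per_token = (feature_acts != 0).sum(dim=-1)  # tensor([1, 2])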
{sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/pretokenize_runner.py

@@ -1,9 +1,10 @@
 import io
 import json
 import sys
+from collections.abc import Iterator
 from dataclasses import dataclass
 from pathlib import Path
-from typing import
+from typing import Literal, cast

 import torch
 from datasets import Dataset, DatasetDict, load_dataset
{sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/sae.py

@@ -14,7 +14,6 @@ from typing import (
     Generic,
     Literal,
     NamedTuple,
-    Type,
     TypeVar,
 )

@@ -534,7 +533,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     @classmethod
     @deprecated("Use load_from_disk instead")
     def load_from_pretrained(
-        cls:
+        cls: type[T_SAE],
         path: str | Path,
         device: str = "cpu",
         dtype: str | None = None,
@@ -543,7 +542,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

     @classmethod
     def load_from_disk(
-        cls:
+        cls: type[T_SAE],
         path: str | Path,
         device: str = "cpu",
         dtype: str | None = None,
@@ -564,7 +563,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

     @classmethod
     def from_pretrained(
-        cls:
+        cls: type[T_SAE],
         release: str,
         sae_id: str,
         device: str = "cpu",
@@ -585,7 +584,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

     @classmethod
     def from_pretrained_with_cfg_and_sparsity(
-        cls:
+        cls: type[T_SAE],
         release: str,
         sae_id: str,
         device: str = "cpu",
@@ -684,7 +683,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         return sae, cfg_dict, log_sparsities

     @classmethod
-    def from_dict(cls:
+    def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
         """Create an SAE from a config dictionary."""
         sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
         sae_config_cls = cls.get_sae_config_class_for_architecture(
@@ -694,8 +693,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

     @classmethod
     def get_sae_class_for_architecture(
-        cls:
-    ) ->
+        cls: type[T_SAE], architecture: str
+    ) -> type[T_SAE]:
         """Get the SAE class for a given architecture."""
         sae_cls, _ = get_sae_class(architecture)
         if not issubclass(sae_cls, cls):
@@ -1000,8 +999,8 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):

     @classmethod
     def get_sae_class_for_architecture(
-        cls:
-    ) ->
+        cls: type[T_TRAINING_SAE], architecture: str
+    ) -> type[T_TRAINING_SAE]:
         """Get the SAE class for a given architecture."""
         sae_cls, _ = get_sae_training_class(architecture)
         if not issubclass(sae_cls, cls):
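Note: these hunks give the loader classmethods proper generic annotations (`cls: type[T_SAE]` / `type[T_TRAINING_SAE]`), so the return value is typed as the class the method is called on. A hedged usage sketch, assuming the 6.13.1 signatures shown above; the release, SAE id, and path strings are placeholders for illustration only:

from sae_lens import SAE
from sae_lens.saes.topk_sae import TopKSAE

# Called on the base class, the result is typed as SAE.
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",        # placeholder release name
    sae_id="blocks.8.hook_resid_pre",   # placeholder SAE id
    device="cpu",
)

# Called on a subclass, `cls: type[T_SAE]` lets type checkers infer TopKSAE,
# so no cast is needed after loading from disk.
topk_sae = TopKSAE.load_from_disk("path/to/topk_sae", device="cpu")  # placeholder path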
sae_lens-6.13.1/sae_lens/saes/topk_sae.py

@@ -0,0 +1,473 @@
+"""Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
+
+from dataclasses import dataclass
+from typing import Callable
+
+import torch
+from jaxtyping import Float
+from torch import nn
+from transformer_lens.hook_points import HookPoint
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+    _disable_hooks,
+)
+
+
+class SparseHookPoint(HookPoint):
+    """
+    A HookPoint that takes in a sparse tensor.
+    Overrides TransformerLens's HookPoint.
+    """
+
+    def __init__(self, d_sae: int):
+        super().__init__()
+        self.d_sae = d_sae
+
+    @override
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        using_hooks = (
+            self._forward_hooks is not None and len(self._forward_hooks) > 0
+        ) or (self._backward_hooks is not None and len(self._backward_hooks) > 0)
+        if using_hooks and x.is_sparse:
+            return x.to_dense()
+        return x  # if no hooks are being used, use passthrough
+
+
+class TopK(nn.Module):
+    """
+    A simple TopK activation that zeroes out all but the top K elements along the last dimension,
+    and applies ReLU to the top K elements.
+    """
+
+    use_sparse_activations: bool
+
+    def __init__(
+        self,
+        k: int,
+        use_sparse_activations: bool = False,
+    ):
+        super().__init__()
+        self.k = k
+        self.use_sparse_activations = use_sparse_activations
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        1) Select top K elements along the last dimension.
+        2) Apply ReLU.
+        3) Zero out all other entries.
+        """
+        topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1, sorted=False)
+        values = topk_values.relu()
+        if self.use_sparse_activations:
+            # Produce a COO sparse tensor (use sparse matrix multiply in decode)
+            original_shape = x.shape
+
+            # Create indices for all dimensions
+            # For each element in topk_indices, we need to map it back to the original tensor coordinates
+            batch_dims = original_shape[:-1]  # All dimensions except the last one
+            num_batch_elements = torch.prod(torch.tensor(batch_dims)).item()
+
+            # Create batch indices - each batch element repeated k times
+            batch_indices_flat = torch.arange(
+                num_batch_elements, device=x.device
+            ).repeat_interleave(self.k)
+
+            # Convert flat batch indices back to multi-dimensional indices
+            if len(batch_dims) == 1:
+                # 2D case: [batch, features]
+                sparse_indices = torch.stack(
+                    [
+                        batch_indices_flat,
+                        topk_indices.flatten(),
+                    ]
+                )
+            else:
+                # 3D+ case: need to unravel the batch indices
+                batch_indices_multi = []
+                remaining = batch_indices_flat
+                for dim_size in reversed(batch_dims):
+                    batch_indices_multi.append(remaining % dim_size)
+                    remaining = remaining // dim_size
+                batch_indices_multi.reverse()
+
+                sparse_indices = torch.stack(
+                    [
+                        *batch_indices_multi,
+                        topk_indices.flatten(),
+                    ]
+                )
+
+            return torch.sparse_coo_tensor(
+                sparse_indices, values.flatten(), original_shape
+            )
+        result = torch.zeros_like(x)
+        result.scatter_(-1, topk_indices, values)
+        return result
+
+
+@dataclass
+class TopKSAEConfig(SAEConfig):
+    """
+    Configuration class for a TopKSAE.
+    """
+
+    k: int = 100
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "topk"
+
+
+def _sparse_matmul_nd(
+    sparse_tensor: torch.Tensor, dense_matrix: torch.Tensor
+) -> torch.Tensor:
+    """
+    Multiply a sparse tensor of shape [..., d_sae] with a dense matrix of shape [d_sae, d_out]
+    to get a result of shape [..., d_out].
+
+    This function handles sparse tensors with arbitrary batch dimensions by flattening
+    the batch dimensions, performing 2D sparse matrix multiplication, and reshaping back.
+    """
+    original_shape = sparse_tensor.shape
+    batch_dims = original_shape[:-1]
+    d_sae = original_shape[-1]
+    d_out = dense_matrix.shape[-1]
+
+    if sparse_tensor.ndim == 2:
+        # Simple 2D case - use torch.sparse.mm directly
+        # sparse.mm errors with bfloat16 :(
+        with torch.autocast(device_type=sparse_tensor.device.type, enabled=False):
+            return torch.sparse.mm(sparse_tensor, dense_matrix)
+
+    # For 3D+ case, reshape to 2D, multiply, then reshape back
+    batch_size = int(torch.prod(torch.tensor(batch_dims)).item())
+
+    # Ensure tensor is coalesced for efficient access to indices/values
+    if not sparse_tensor.is_coalesced():
+        sparse_tensor = sparse_tensor.coalesce()
+
+    # Get indices and values
+    indices = sparse_tensor.indices()  # [ndim, nnz]
+    values = sparse_tensor.values()  # [nnz]
+
+    # Convert multi-dimensional batch indices to flat indices
+    flat_batch_indices = torch.zeros_like(indices[0])
+    multiplier = 1
+    for i in reversed(range(len(batch_dims))):
+        flat_batch_indices += indices[i] * multiplier
+        multiplier *= batch_dims[i]
+
+    # Create 2D sparse tensor indices [batch_flat, feature]
+    sparse_2d_indices = torch.stack([flat_batch_indices, indices[-1]])
+
+    # Create 2D sparse tensor
+    sparse_2d = torch.sparse_coo_tensor(
+        sparse_2d_indices, values, (batch_size, d_sae)
+    ).coalesce()
+
+    # sparse.mm errors with bfloat16 :(
+    with torch.autocast(device_type=sparse_tensor.device.type, enabled=False):
+        # Do the matrix multiplication
+        result_2d = torch.sparse.mm(sparse_2d, dense_matrix)  # [batch_size, d_out]
+
+    # Reshape back to original batch dimensions
+    result_shape = tuple(batch_dims) + (d_out,)
+    return result_2d.view(result_shape)
+
+
+class TopKSAE(SAE[TopKSAEConfig]):
+    """
+    An inference-only sparse autoencoder using a "topk" activation function.
+    It uses linear encoder and decoder layers, applying the TopK activation
+    to the hidden pre-activation in its encode step.
+    """
+
+    b_enc: nn.Parameter
+
+    def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
+        """
+        Args:
+            cfg: SAEConfig defining model size and behavior.
+            use_error_term: Whether to apply the error-term approach in the forward pass.
+        """
+        super().__init__(cfg, use_error_term)
+
+    @override
+    def initialize_weights(self) -> None:
+        # Initialize encoder weights and bias.
+        super().initialize_weights()
+        _init_weights_topk(self)
+
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        """
+        Converts input x into feature activations.
+        Uses topk activation under the hood.
+        """
+        sae_in = self.process_sae_in(x)
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+        # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
+        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+
+    def decode(
+        self,
+        feature_acts: Float[torch.Tensor, "... d_sae"],
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Reconstructs the input from topk feature activations.
+        Applies optional finetuning scaling, hooking to recons, out normalization,
+        and optional head reshaping.
+        """
+        # Handle sparse tensors using efficient sparse matrix multiplication
+        if feature_acts.is_sparse:
+            sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
+        else:
+            sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+    @override
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return TopK(self.cfg.k, use_sparse_activations=False)
+
+    @override
+    @torch.no_grad()
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError(
+            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
+        )
+
+
+@dataclass
+class TopKTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a TopKTrainingSAE.
+
+    Args:
+        k (int): Number of top features to keep active. Only the top k features
+            with the highest pre-activations will be non-zero. Defaults to 100.
+        use_sparse_activations (bool): Whether to use sparse tensor representations
+            for activations during training. This can reduce memory usage and improve
+            performance when k is small relative to d_sae, but is only worthwhile if
+            using float32 and not using autocast. Defaults to False.
+        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
+            dead neurons to learn useful features. This loss helps prevent neuron death
+            in TopK SAEs by having dead neurons reconstruct the residual error from
+            live neurons. Defaults to 1.0.
+        decoder_init_norm (float | None): Norm to initialize decoder weights to.
+            0.1 corresponds to the "heuristic" initialization from Anthropic's April update.
+            Use None to disable. Inherited from TrainingSAEConfig. Defaults to 0.1.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
+            before encoding. Inherited from SAEConfig. Defaults to True.
+        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
+            Normalization strategy for input activations. Inherited from SAEConfig.
+            Defaults to "none".
+        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
+            (useful for attention head outputs). Inherited from SAEConfig.
+            Defaults to "none".
+        metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.).
+            Inherited from SAEConfig.
+    """
+
+    k: int = 100
+    use_sparse_activations: bool = False
+    aux_loss_coefficient: float = 1.0
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "topk"
+
+
+class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
+    """
+    TopK variant with training functionality. Calculates a topk-related auxiliary loss, etc.
+    """
+
+    b_enc: nn.Parameter
+
+    def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+        self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae)
+        self.setup()
+
+    @override
+    def initialize_weights(self) -> None:
+        super().initialize_weights()
+        _init_weights_topk(self)
+
+    def encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        """
+        Similar to the base training method: calculate pre-activations, then apply TopK.
+        """
+        sae_in = self.process_sae_in(x)
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+
+        # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
+        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+        return feature_acts, hidden_pre
+
+    @override
+    def decode(
+        self,
+        feature_acts: Float[torch.Tensor, "... d_sae"],
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decodes feature activations back into input space,
+        applying optional finetuning scale, hooking, out normalization, etc.
+        """
+        # Handle sparse tensors using efficient sparse matrix multiplication
+        if feature_acts.is_sparse:
+            sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
+        else:
+            sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+    @override
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the SAE."""
+        feature_acts = self.encode(x)
+        sae_out = self.decode(feature_acts)
+
+        if self.use_error_term:
+            with torch.no_grad():
+                # Recompute without hooks for true error term
+                with _disable_hooks(self):
+                    feature_acts_clean = self.encode(x)
+                    x_reconstruct_clean = self.decode(feature_acts_clean)
+                sae_error = self.hook_sae_error(x - x_reconstruct_clean)
+            sae_out = sae_out + sae_error
+
+        return self.hook_sae_output(sae_out)
+
+    @override
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        # Calculate the auxiliary loss for dead neurons
+        topk_loss = self.calculate_topk_aux_loss(
+            sae_in=step_input.sae_in,
+            sae_out=sae_out,
+            hidden_pre=hidden_pre,
+            dead_neuron_mask=step_input.dead_neuron_mask,
+        )
+        return {"auxiliary_reconstruction_loss": topk_loss}
+
+    @override
+    @torch.no_grad()
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError(
+            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
+        )
+
+    @override
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return TopK(self.cfg.k, use_sparse_activations=self.cfg.use_sparse_activations)
+
+    @override
+    def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
+        return {}
+
+    def calculate_topk_aux_loss(
+        self,
+        sae_in: torch.Tensor,
+        sae_out: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        dead_neuron_mask: torch.Tensor | None,
+    ) -> torch.Tensor:
+        """
+        Calculate TopK auxiliary loss.
+
+        This auxiliary loss encourages dead neurons to learn useful features by having
+        them reconstruct the residual error from the live neurons. It's a key part of
+        preventing neuron death in TopK SAEs.
+        """
+        # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
+        # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
+        if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
+            return sae_out.new_tensor(0.0)
+        residual = (sae_in - sae_out).detach()
+
+        # Heuristic from Appendix B.1 in the paper
+        k_aux = sae_in.shape[-1] // 2
+
+        # Reduce the scale of the loss if there are a small number of dead latents
+        scale = min(num_dead / k_aux, 1.0)
+        k_aux = min(k_aux, num_dead)
+
+        auxk_acts = _calculate_topk_aux_acts(
+            k_aux=k_aux,
+            hidden_pre=hidden_pre,
+            dead_neuron_mask=dead_neuron_mask,
+        )
+
+        # Encourage the top ~50% of dead latents to predict the residual of the
+        # top k living latents
+        recons = self.decode(auxk_acts)
+        auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
+        return self.cfg.aux_loss_coefficient * scale * auxk_loss
+
+
+def _calculate_topk_aux_acts(
+    k_aux: int,
+    hidden_pre: torch.Tensor,
+    dead_neuron_mask: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Helper method to calculate activations for the auxiliary loss.
+
+    Args:
+        k_aux: Number of top dead neurons to select
+        hidden_pre: Pre-activation values from encoder
+        dead_neuron_mask: Boolean mask indicating which neurons are dead
+
+    Returns:
+        Tensor with activations for only the top-k dead neurons, zeros elsewhere
+    """
+
+    # Don't include living latents in this loss
+    auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
+    # Top-k dead latents
+    auxk_topk = auxk_latents.topk(k_aux, sorted=False)
+    # Set the activations to zero for all but the top k_aux dead latents
+    auxk_acts = torch.zeros_like(hidden_pre)
+    auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
+    # Set activations to zero for all but top k_aux dead latents
+    return auxk_acts
+
+
+def _init_weights_topk(
+    sae: SAE[TopKSAEConfig] | TrainingSAE[TopKTrainingSAEConfig],
+) -> None:
+    sae.b_enc = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
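Note: the central addition in this file is the `use_sparse_activations` path of `TopK`, which returns a COO sparse tensor so `decode` can fall back to `torch.sparse.mm` via `_sparse_matmul_nd`. A small sketch exercising the module exactly as defined above (assuming the new import path `sae_lens.saes.topk_sae`), checking that the sparse and dense paths agree:

import torch
from sae_lens.saes.topk_sae import TopK

torch.manual_seed(0)
x = torch.randn(2, 16)

dense_topk = TopK(k=4, use_sparse_activations=False)
sparse_topk = TopK(k=4, use_sparse_activations=True)

dense_out = dense_topk(x)    # strided tensor, at most 4 nonzeros per row
sparse_out = sparse_topk(x)  # sparse COO tensor with 2 * 4 stored values

assert not dense_out.is_sparse and sparse_out.is_sparse
assert torch.allclose(sparse_out.to_dense(), dense_out)
assert int((dense_out != 0).sum(dim=-1).max()) <= 4

For the auxiliary loss, the arithmetic is unchanged from the old file: with example numbers, `d_in = 768` gives the heuristic `k_aux = 768 // 2 = 384`, and 96 dead latents give `scale = min(96 / 384, 1.0) = 0.25` before multiplying by `aux_loss_coefficient`.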
{sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/tokenization_and_batching.py

@@ -1,4 +1,4 @@
-from
+from collections.abc import Generator, Iterator

 import torch

@@ -68,7 +68,7 @@ def concat_and_batch_sequences(
 ) -> Generator[torch.Tensor, None, None]:
     """
     Generator to concat token sequences together from the tokens_interator, yielding
-
+    sequences of size `context_size`. Batching across the batch dimension is handled by the caller.

     Args:
         tokens_iterator: An iterator which returns a 1D tensors of tokens
@@ -76,13 +76,28 @@ def concat_and_batch_sequences(
         begin_batch_token_id: If provided, this token will be at position 0 of each batch
         begin_sequence_token_id: If provided, this token will be the first token of each sequence
         sequence_separator_token_id: If provided, this token will be inserted between concatenated sequences
-        disable_concat_sequences: If True, disable concatenating sequences and ignore sequences shorter than context_size
+        disable_concat_sequences: If True, disable concatenating sequences and ignore sequences shorter than context_size (including BOS token if present)
         max_batches: If not provided, the iterator will be run to completion.
     """
     if disable_concat_sequences:
-
-
-
+        if begin_batch_token_id and not begin_sequence_token_id:
+            begin_sequence_token_id = begin_batch_token_id
+        for sequence in tokens_iterator:
+            if (
+                begin_sequence_token_id is not None
+                and sequence[0] != begin_sequence_token_id
+                and len(sequence) >= context_size - 1
+            ):
+                begin_sequence_token_id_tensor = torch.tensor(
+                    [begin_sequence_token_id],
+                    dtype=torch.long,
+                    device=sequence.device,
+                )
+                sequence = torch.cat(
+                    [begin_sequence_token_id_tensor, sequence[: context_size - 1]]
+                )
+            if len(sequence) >= context_size:
+                yield sequence[:context_size]
         return

     batch: torch.Tensor | None = None
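Note: the rewritten `disable_concat_sequences` branch now yields sequences of exactly `context_size` tokens instead of the previous (collapsed) body: a sequence of at least `context_size - 1` tokens that lacks the begin-of-sequence token gets it prepended, and anything still shorter than `context_size` is skipped. A standalone sketch of that logic with toy values (the token ids and `context_size` are made up for illustration, not taken from the library):

import torch

context_size = 5
bos_token_id = 0  # stands in for begin_sequence_token_id / begin_batch_token_id

sequences = [
    torch.tensor([0, 11, 12, 13, 14, 15]),  # has BOS and is long enough -> truncated
    torch.tensor([21, 22, 23, 24]),         # no BOS, len == context_size - 1 -> BOS prepended
    torch.tensor([31, 32]),                 # too short -> skipped
]

kept = []
for sequence in sequences:
    if sequence[0] != bos_token_id and len(sequence) >= context_size - 1:
        bos = torch.tensor([bos_token_id], dtype=torch.long, device=sequence.device)
        sequence = torch.cat([bos, sequence[: context_size - 1]])
    if len(sequence) >= context_size:
        kept.append(sequence[:context_size])

# kept == [tensor([0, 11, 12, 13, 14]), tensor([0, 21, 22, 23, 24])]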
{sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/sae_trainer.py

@@ -253,12 +253,14 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
         )

         with torch.no_grad():
-
+            # calling .bool() should be equivalent to .abs() > 0, and work with coo tensors
+            firing_feats = train_step_output.feature_acts.bool().float()
+            did_fire = firing_feats.sum(-2).bool()
+            if did_fire.is_sparse:
+                did_fire = did_fire.to_dense()
             self.n_forward_passes_since_fired += 1
             self.n_forward_passes_since_fired[did_fire] = 0
-            self.act_freq_scores += (
-                (train_step_output.feature_acts.abs() > 0).float().sum(0)
-            )
+            self.act_freq_scores += firing_feats.sum(0)
             self.n_frac_active_samples += self.cfg.train_batch_size_samples

         # Grad scaler will rescale gradients if autocast is enabled
@@ -310,7 +312,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
         loss = output.loss.item()

         # metrics for currents acts
-        l0 = (
+        l0 = feature_acts.bool().float().sum(-1).to_dense().mean()
         current_learning_rate = self.optimizer.param_groups[0]["lr"]

         per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=-1).squeeze()
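Note: the firing statistics are now computed via `.bool()` so they also work when `feature_acts` is a sparse COO tensor from the TopK sparse path. A minimal check of the equivalence the new code relies on, using plain PyTorch and toy values (the trainer itself keeps the intermediate tensor sparse and only densifies `did_fire`; this sketch densifies earlier for simplicity):

import torch

dense_acts = torch.tensor([[0.0, 1.5, 0.0], [2.0, 0.0, 0.0]])
sparse_acts = dense_acts.to_sparse()  # COO layout, as produced by TopK(use_sparse_activations=True)

firing_dense = (dense_acts.abs() > 0).float()
firing_sparse = sparse_acts.bool().float()  # same firing mask, computed without densifying first

assert torch.equal(firing_sparse.to_dense(), firing_dense)

# Mean L0 over the batch, as in the updated metrics logging.
l0 = firing_sparse.to_dense().sum(-1).mean()
assert l0.item() == 1.0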
sae_lens-6.12.3/sae_lens/saes/topk_sae.py

@@ -1,271 +0,0 @@
-"""Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
-
-from dataclasses import dataclass
-from typing import Callable
-
-import torch
-from jaxtyping import Float
-from torch import nn
-from typing_extensions import override
-
-from sae_lens.saes.sae import (
-    SAE,
-    SAEConfig,
-    TrainCoefficientConfig,
-    TrainingSAE,
-    TrainingSAEConfig,
-    TrainStepInput,
-)
-
-
-class TopK(nn.Module):
-    """
-    A simple TopK activation that zeroes out all but the top K elements along the last dimension,
-    and applies ReLU to the top K elements.
-    """
-
-    b_enc: nn.Parameter
-
-    def __init__(
-        self,
-        k: int,
-    ):
-        super().__init__()
-        self.k = k
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        1) Select top K elements along the last dimension.
-        2) Apply ReLU.
-        3) Zero out all other entries.
-        """
-        topk = torch.topk(x, k=self.k, dim=-1)
-        values = topk.values.relu()
-        result = torch.zeros_like(x)
-        result.scatter_(-1, topk.indices, values)
-        return result
-
-
-@dataclass
-class TopKSAEConfig(SAEConfig):
-    """
-    Configuration class for a TopKSAE.
-    """
-
-    k: int = 100
-
-    @override
-    @classmethod
-    def architecture(cls) -> str:
-        return "topk"
-
-
-class TopKSAE(SAE[TopKSAEConfig]):
-    """
-    An inference-only sparse autoencoder using a "topk" activation function.
-    It uses linear encoder and decoder layers, applying the TopK activation
-    to the hidden pre-activation in its encode step.
-    """
-
-    b_enc: nn.Parameter
-
-    def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
-        """
-        Args:
-            cfg: SAEConfig defining model size and behavior.
-            use_error_term: Whether to apply the error-term approach in the forward pass.
-        """
-        super().__init__(cfg, use_error_term)
-
-    @override
-    def initialize_weights(self) -> None:
-        # Initialize encoder weights and bias.
-        super().initialize_weights()
-        _init_weights_topk(self)
-
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
-        """
-        Converts input x into feature activations.
-        Uses topk activation under the hood.
-        """
-        sae_in = self.process_sae_in(x)
-        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
-        # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
-        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
-
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
-        """
-        Reconstructs the input from topk feature activations.
-        Applies optional finetuning scaling, hooking to recons, out normalization,
-        and optional head reshaping.
-        """
-        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
-        sae_out_pre = self.hook_sae_recons(sae_out_pre)
-        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
-        return self.reshape_fn_out(sae_out_pre, self.d_head)
-
-    @override
-    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
-        return TopK(self.cfg.k)
-
-    @override
-    @torch.no_grad()
-    def fold_W_dec_norm(self) -> None:
-        raise NotImplementedError(
-            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
-        )
-
-
-@dataclass
-class TopKTrainingSAEConfig(TrainingSAEConfig):
-    """
-    Configuration class for training a TopKTrainingSAE.
-    """
-
-    k: int = 100
-    aux_loss_coefficient: float = 1.0
-
-    @override
-    @classmethod
-    def architecture(cls) -> str:
-        return "topk"
-
-
-class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
-    """
-    TopK variant with training functionality. Calculates a topk-related auxiliary loss, etc.
-    """
-
-    b_enc: nn.Parameter
-
-    def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
-        super().__init__(cfg, use_error_term)
-
-    @override
-    def initialize_weights(self) -> None:
-        super().initialize_weights()
-        _init_weights_topk(self)
-
-    def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
-        """
-        Similar to the base training method: calculate pre-activations, then apply TopK.
-        """
-        sae_in = self.process_sae_in(x)
-        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
-
-        # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
-        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
-        return feature_acts, hidden_pre
-
-    @override
-    def calculate_aux_loss(
-        self,
-        step_input: TrainStepInput,
-        feature_acts: torch.Tensor,
-        hidden_pre: torch.Tensor,
-        sae_out: torch.Tensor,
-    ) -> dict[str, torch.Tensor]:
-        # Calculate the auxiliary loss for dead neurons
-        topk_loss = self.calculate_topk_aux_loss(
-            sae_in=step_input.sae_in,
-            sae_out=sae_out,
-            hidden_pre=hidden_pre,
-            dead_neuron_mask=step_input.dead_neuron_mask,
-        )
-        return {"auxiliary_reconstruction_loss": topk_loss}
-
-    @override
-    @torch.no_grad()
-    def fold_W_dec_norm(self) -> None:
-        raise NotImplementedError(
-            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
-        )
-
-    @override
-    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
-        return TopK(self.cfg.k)
-
-    @override
-    def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
-        return {}
-
-    def calculate_topk_aux_loss(
-        self,
-        sae_in: torch.Tensor,
-        sae_out: torch.Tensor,
-        hidden_pre: torch.Tensor,
-        dead_neuron_mask: torch.Tensor | None,
-    ) -> torch.Tensor:
-        """
-        Calculate TopK auxiliary loss.
-
-        This auxiliary loss encourages dead neurons to learn useful features by having
-        them reconstruct the residual error from the live neurons. It's a key part of
-        preventing neuron death in TopK SAEs.
-        """
-        # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
-        # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
-        if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
-            return sae_out.new_tensor(0.0)
-        residual = (sae_in - sae_out).detach()
-
-        # Heuristic from Appendix B.1 in the paper
-        k_aux = sae_in.shape[-1] // 2
-
-        # Reduce the scale of the loss if there are a small number of dead latents
-        scale = min(num_dead / k_aux, 1.0)
-        k_aux = min(k_aux, num_dead)
-
-        auxk_acts = _calculate_topk_aux_acts(
-            k_aux=k_aux,
-            hidden_pre=hidden_pre,
-            dead_neuron_mask=dead_neuron_mask,
-        )
-
-        # Encourage the top ~50% of dead latents to predict the residual of the
-        # top k living latents
-        recons = self.decode(auxk_acts)
-        auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
-        return self.cfg.aux_loss_coefficient * scale * auxk_loss
-
-
-def _calculate_topk_aux_acts(
-    k_aux: int,
-    hidden_pre: torch.Tensor,
-    dead_neuron_mask: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Helper method to calculate activations for the auxiliary loss.
-
-    Args:
-        k_aux: Number of top dead neurons to select
-        hidden_pre: Pre-activation values from encoder
-        dead_neuron_mask: Boolean mask indicating which neurons are dead
-
-    Returns:
-        Tensor with activations for only the top-k dead neurons, zeros elsewhere
-    """
-
-    # Don't include living latents in this loss
-    auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
-    # Top-k dead latents
-    auxk_topk = auxk_latents.topk(k_aux, sorted=False)
-    # Set the activations to zero for all but the top k_aux dead latents
-    auxk_acts = torch.zeros_like(hidden_pre)
-    auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
-    # Set activations to zero for all but top k_aux dead latents
-    return auxk_acts
-
-
-def _init_weights_topk(
-    sae: SAE[TopKSAEConfig] | TrainingSAE[TopKTrainingSAEConfig],
-) -> None:
-    sae.b_enc = nn.Parameter(
-        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
-    )
The remaining 29 files listed above with +0 -0 are unchanged between 6.12.3 and 6.13.1.