sae-lens 6.12.1__py3-none-any.whl → 6.21.0__py3-none-any.whl
- sae_lens/__init__.py +15 -1
- sae_lens/cache_activations_runner.py +1 -1
- sae_lens/config.py +39 -2
- sae_lens/constants.py +1 -0
- sae_lens/evals.py +20 -14
- sae_lens/llm_sae_training_runner.py +17 -18
- sae_lens/loading/pretrained_sae_loaders.py +194 -0
- sae_lens/loading/pretrained_saes_directory.py +5 -3
- sae_lens/pretokenize_runner.py +2 -1
- sae_lens/pretrained_saes.yaml +75 -1
- sae_lens/saes/__init__.py +9 -0
- sae_lens/saes/batchtopk_sae.py +32 -1
- sae_lens/saes/matryoshka_batchtopk_sae.py +137 -0
- sae_lens/saes/sae.py +22 -24
- sae_lens/saes/temporal_sae.py +372 -0
- sae_lens/saes/topk_sae.py +287 -17
- sae_lens/tokenization_and_batching.py +21 -6
- sae_lens/training/activation_scaler.py +7 -0
- sae_lens/training/activations_store.py +52 -31
- sae_lens/training/optim.py +11 -0
- sae_lens/training/sae_trainer.py +57 -16
- sae_lens/training/types.py +1 -1
- sae_lens/util.py +27 -0
- {sae_lens-6.12.1.dist-info → sae_lens-6.21.0.dist-info}/METADATA +19 -17
- sae_lens-6.21.0.dist-info/RECORD +41 -0
- {sae_lens-6.12.1.dist-info → sae_lens-6.21.0.dist-info}/WHEEL +1 -1
- sae_lens-6.12.1.dist-info/RECORD +0 -39
- {sae_lens-6.12.1.dist-info → sae_lens-6.21.0.dist-info/licenses}/LICENSE +0 -0
sae_lens/saes/topk_sae.py
CHANGED
@@ -1,11 +1,12 @@
 """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
 
 from dataclasses import dataclass
-from typing import Callable
+from typing import Any, Callable
 
 import torch
 from jaxtyping import Float
 from torch import nn
+from transformer_lens.hook_points import HookPoint
 from typing_extensions import override
 
 from sae_lens.saes.sae import (
@@ -15,44 +16,138 @@ from sae_lens.saes.sae import (
     TrainingSAE,
     TrainingSAEConfig,
     TrainStepInput,
+    _disable_hooks,
 )
 
 
+class SparseHookPoint(HookPoint):
+    """
+    A HookPoint that takes in a sparse tensor.
+    Overrides TransformerLens's HookPoint.
+    """
+
+    def __init__(self, d_sae: int):
+        super().__init__()
+        self.d_sae = d_sae
+
+    @override
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        using_hooks = (
+            self._forward_hooks is not None and len(self._forward_hooks) > 0
+        ) or (self._backward_hooks is not None and len(self._backward_hooks) > 0)
+        if using_hooks and x.is_sparse:
+            return x.to_dense()
+        return x  # if no hooks are being used, use passthrough
+
+
 class TopK(nn.Module):
     """
     A simple TopK activation that zeroes out all but the top K elements along the last dimension,
     and applies ReLU to the top K elements.
     """
 
-
+    use_sparse_activations: bool
 
     def __init__(
         self,
         k: int,
+        use_sparse_activations: bool = False,
     ):
         super().__init__()
         self.k = k
+        self.use_sparse_activations = use_sparse_activations
 
-    def forward(
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
         """
         1) Select top K elements along the last dimension.
         2) Apply ReLU.
         3) Zero out all other entries.
         """
-
-        values =
+        topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1, sorted=False)
+        values = topk_values.relu()
+        if self.use_sparse_activations:
+            # Produce a COO sparse tensor (use sparse matrix multiply in decode)
+            original_shape = x.shape
+
+            # Create indices for all dimensions
+            # For each element in topk_indices, we need to map it back to the original tensor coordinates
+            batch_dims = original_shape[:-1]  # All dimensions except the last one
+            num_batch_elements = torch.prod(torch.tensor(batch_dims)).item()
+
+            # Create batch indices - each batch element repeated k times
+            batch_indices_flat = torch.arange(
+                num_batch_elements, device=x.device
+            ).repeat_interleave(self.k)
+
+            # Convert flat batch indices back to multi-dimensional indices
+            if len(batch_dims) == 1:
+                # 2D case: [batch, features]
+                sparse_indices = torch.stack(
+                    [
+                        batch_indices_flat,
+                        topk_indices.flatten(),
+                    ]
+                )
+            else:
+                # 3D+ case: need to unravel the batch indices
+                batch_indices_multi = []
+                remaining = batch_indices_flat
+                for dim_size in reversed(batch_dims):
+                    batch_indices_multi.append(remaining % dim_size)
+                    remaining = remaining // dim_size
+                batch_indices_multi.reverse()
+
+                sparse_indices = torch.stack(
+                    [
+                        *batch_indices_multi,
+                        topk_indices.flatten(),
+                    ]
+                )
+
+            return torch.sparse_coo_tensor(
+                sparse_indices, values.flatten(), original_shape
+            )
         result = torch.zeros_like(x)
-        result.scatter_(-1,
+        result.scatter_(-1, topk_indices, values)
        return result
 
 
 @dataclass
 class TopKSAEConfig(SAEConfig):
     """
-    Configuration class for
+    Configuration class for TopKSAE inference.
+
+    Args:
+        k (int): Number of top features to keep active during inference. Only the top k
+            features with the highest pre-activations will be non-zero. Defaults to 100.
+        rescale_acts_by_decoder_norm (bool): Whether to treat the decoder as if it was
+            already normalized. This affects the topk selection by rescaling pre-activations
+            by decoder norms. Requires that the SAE was trained this way. Defaults to False.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
+            before encoding. Inherited from SAEConfig. Defaults to True.
+        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
+            Normalization strategy for input activations. Inherited from SAEConfig.
+            Defaults to "none".
+        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
+            (useful for attention head outputs). Inherited from SAEConfig.
+            Defaults to "none".
+        metadata (SAEMetadata): Metadata about the SAE (model name, hook name, etc.).
+            Inherited from SAEConfig.
     """
 
     k: int = 100
+    rescale_acts_by_decoder_norm: bool = False
 
     @override
     @classmethod
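Note on the sparse path added to `TopK.forward` above: when `use_sparse_activations` is enabled, each (flattened) batch element's top-k feature indices become COO coordinates and the ReLU'd top-k values become the stored entries. A minimal standalone sketch of the 2D branch, checked against the dense scatter path; `topk_to_sparse_2d` is a hypothetical helper for illustration, not part of the package:

```python
import torch


def topk_to_sparse_2d(x: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical standalone version of the 2D sparse branch above."""
    topk_values, topk_indices = torch.topk(x, k=k, dim=-1, sorted=False)
    values = topk_values.relu()
    # Each row index is repeated k times and paired with that row's top-k columns.
    row_indices = torch.arange(x.shape[0], device=x.device).repeat_interleave(k)
    indices = torch.stack([row_indices, topk_indices.flatten()])
    return torch.sparse_coo_tensor(indices, values.flatten(), x.shape)


x = torch.randn(4, 16)
sparse_acts = topk_to_sparse_2d(x, k=3)

# Dense reference: scatter the ReLU'd top-k values back into a zero tensor.
top = torch.topk(x, k=3, dim=-1)
dense_acts = torch.zeros_like(x).scatter_(-1, top.indices, top.values.relu())
assert torch.allclose(sparse_acts.to_dense(), dense_acts)
```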
@@ -60,6 +155,63 @@ class TopKSAEConfig(SAEConfig):
         return "topk"
 
 
+def _sparse_matmul_nd(
+    sparse_tensor: torch.Tensor, dense_matrix: torch.Tensor
+) -> torch.Tensor:
+    """
+    Multiply a sparse tensor of shape [..., d_sae] with a dense matrix of shape [d_sae, d_out]
+    to get a result of shape [..., d_out].
+
+    This function handles sparse tensors with arbitrary batch dimensions by flattening
+    the batch dimensions, performing 2D sparse matrix multiplication, and reshaping back.
+    """
+    original_shape = sparse_tensor.shape
+    batch_dims = original_shape[:-1]
+    d_sae = original_shape[-1]
+    d_out = dense_matrix.shape[-1]
+
+    if sparse_tensor.ndim == 2:
+        # Simple 2D case - use torch.sparse.mm directly
+        # sparse.mm errors with bfloat16 :(
+        with torch.autocast(device_type=sparse_tensor.device.type, enabled=False):
+            return torch.sparse.mm(sparse_tensor, dense_matrix)
+
+    # For 3D+ case, reshape to 2D, multiply, then reshape back
+    batch_size = int(torch.prod(torch.tensor(batch_dims)).item())
+
+    # Ensure tensor is coalesced for efficient access to indices/values
+    if not sparse_tensor.is_coalesced():
+        sparse_tensor = sparse_tensor.coalesce()
+
+    # Get indices and values
+    indices = sparse_tensor.indices()  # [ndim, nnz]
+    values = sparse_tensor.values()  # [nnz]
+
+    # Convert multi-dimensional batch indices to flat indices
+    flat_batch_indices = torch.zeros_like(indices[0])
+    multiplier = 1
+    for i in reversed(range(len(batch_dims))):
+        flat_batch_indices += indices[i] * multiplier
+        multiplier *= batch_dims[i]
+
+    # Create 2D sparse tensor indices [batch_flat, feature]
+    sparse_2d_indices = torch.stack([flat_batch_indices, indices[-1]])
+
+    # Create 2D sparse tensor
+    sparse_2d = torch.sparse_coo_tensor(
+        sparse_2d_indices, values, (batch_size, d_sae)
+    ).coalesce()
+
+    # sparse.mm errors with bfloat16 :(
+    with torch.autocast(device_type=sparse_tensor.device.type, enabled=False):
+        # Do the matrix multiplication
+        result_2d = torch.sparse.mm(sparse_2d, dense_matrix)  # [batch_size, d_out]
+
+    # Reshape back to original batch dimensions
+    result_shape = tuple(batch_dims) + (d_out,)
+    return result_2d.view(result_shape)
+
+
 class TopKSAE(SAE[TopKSAEConfig]):
     """
     An inference-only sparse autoencoder using a "topk" activation function.
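`_sparse_matmul_nd` (added above) flattens all leading batch dimensions into a single row axis, multiplies with `torch.sparse.mm`, and reshapes back. A hedged sanity check of that flatten-multiply-reshape route against a plain dense matmul, using only stock PyTorch:

```python
import torch

batch, seq, d_sae, d_in = 2, 3, 8, 4
dense_acts = torch.randn(batch, seq, d_sae)
dense_acts[dense_acts < 1.0] = 0.0  # make the activations mostly zero
W_dec = torch.randn(d_sae, d_in)

sparse_acts = dense_acts.to_sparse().coalesce()
idx, vals = sparse_acts.indices(), sparse_acts.values()

# Flatten (batch, seq) into a single row index, mirroring _sparse_matmul_nd.
flat_rows = idx[0] * seq + idx[1]
sparse_2d = torch.sparse_coo_tensor(
    torch.stack([flat_rows, idx[2]]), vals, (batch * seq, d_sae)
).coalesce()

out = torch.sparse.mm(sparse_2d, W_dec).view(batch, seq, d_in)
assert torch.allclose(out, dense_acts @ W_dec, atol=1e-5)
```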
@@ -92,42 +244,91 @@ class TopKSAE(SAE[TopKSAEConfig]):
         """
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+        if self.cfg.rescale_acts_by_decoder_norm:
+            hidden_pre = hidden_pre * self.W_dec.norm(dim=-1)
         # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
 
     def decode(
-        self,
+        self,
+        feature_acts: Float[torch.Tensor, "... d_sae"],
     ) -> Float[torch.Tensor, "... d_in"]:
         """
         Reconstructs the input from topk feature activations.
         Applies optional finetuning scaling, hooking to recons, out normalization,
         and optional head reshaping.
         """
-
+        # Handle sparse tensors using efficient sparse matrix multiplication
+        if self.cfg.rescale_acts_by_decoder_norm:
+            feature_acts = feature_acts / self.W_dec.norm(dim=-1)
+        if feature_acts.is_sparse:
+            sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
+        else:
+            sae_out_pre = feature_acts @ self.W_dec + self.b_dec
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
         return self.reshape_fn_out(sae_out_pre, self.d_head)
 
     @override
     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
-        return TopK(self.cfg.k)
+        return TopK(self.cfg.k, use_sparse_activations=False)
 
     @override
     @torch.no_grad()
     def fold_W_dec_norm(self) -> None:
-
-
-
+        if not self.cfg.rescale_acts_by_decoder_norm:
+            raise NotImplementedError(
+                "Folding W_dec_norm is not safe for TopKSAEs when rescale_acts_by_decoder_norm is False, as this may change the topk activations"
+            )
+        _fold_norm_topk(W_dec=self.W_dec, b_enc=self.b_enc, W_enc=self.W_enc)
 
 
 @dataclass
 class TopKTrainingSAEConfig(TrainingSAEConfig):
     """
     Configuration class for training a TopKTrainingSAE.
+
+    Args:
+        k (int): Number of top features to keep active. Only the top k features
+            with the highest pre-activations will be non-zero. Defaults to 100.
+        use_sparse_activations (bool): Whether to use sparse tensor representations
+            for activations during training. This can reduce memory usage and improve
+            performance when k is small relative to d_sae, but is only worthwhile if
+            using float32 and not using autocast. Defaults to False.
+        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
+            dead neurons to learn useful features. This loss helps prevent neuron death
+            in TopK SAEs by having dead neurons reconstruct the residual error from
+            live neurons. Defaults to 1.0.
+        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
+            This is a good idea since decoder norm can randomly drift during training, and this
+            affects what the topk activations will be. Defaults to True.
+        decoder_init_norm (float | None): Norm to initialize decoder weights to.
+            0.1 corresponds to the "heuristic" initialization from Anthropic's April update.
+            Use None to disable. Inherited from TrainingSAEConfig. Defaults to 0.1.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
+            before encoding. Inherited from SAEConfig. Defaults to True.
+        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
+            Normalization strategy for input activations. Inherited from SAEConfig.
+            Defaults to "none".
+        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
+            (useful for attention head outputs). Inherited from SAEConfig.
+            Defaults to "none".
+        metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.).
+            Inherited from SAEConfig.
     """
 
     k: int = 100
+    use_sparse_activations: bool = False
     aux_loss_coefficient: float = 1.0
+    rescale_acts_by_decoder_norm: bool = True
 
     @override
     @classmethod
@@ -144,6 +345,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
 
     def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
         super().__init__(cfg, use_error_term)
+        self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae)
+        self.setup()
 
     @override
     def initialize_weights(self) -> None:
@@ -159,10 +362,51 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
 
+        if self.cfg.rescale_acts_by_decoder_norm:
+            hidden_pre = hidden_pre * self.W_dec.norm(dim=-1)
+
         # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
         return feature_acts, hidden_pre
 
+    @override
+    def decode(
+        self,
+        feature_acts: Float[torch.Tensor, "... d_sae"],
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decodes feature activations back into input space,
+        applying optional finetuning scale, hooking, out normalization, etc.
+        """
+        # Handle sparse tensors using efficient sparse matrix multiplication
+        if self.cfg.rescale_acts_by_decoder_norm:
+            # need to multiply by the inverse of the norm because division is illegal with sparse tensors
+            feature_acts = feature_acts * (1 / self.W_dec.norm(dim=-1))
+        if feature_acts.is_sparse:
+            sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
+        else:
+            sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+    @override
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the SAE."""
+        feature_acts = self.encode(x)
+        sae_out = self.decode(feature_acts)
+
+        if self.use_error_term:
+            with torch.no_grad():
+                # Recompute without hooks for true error term
+                with _disable_hooks(self):
+                    feature_acts_clean = self.encode(x)
+                    x_reconstruct_clean = self.decode(feature_acts_clean)
+                sae_error = self.hook_sae_error(x - x_reconstruct_clean)
+            sae_out = sae_out + sae_error
+
+        return self.hook_sae_output(sae_out)
+
     @override
     def calculate_aux_loss(
         self,
@@ -183,13 +427,15 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     @override
     @torch.no_grad()
     def fold_W_dec_norm(self) -> None:
-
-
-
+        if not self.cfg.rescale_acts_by_decoder_norm:
+            raise NotImplementedError(
+                "Folding W_dec_norm is not safe for TopKSAEs when rescale_acts_by_decoder_norm is False, as this may change the topk activations"
+            )
+        _fold_norm_topk(W_dec=self.W_dec, b_enc=self.b_enc, W_enc=self.W_enc)
 
     @override
     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
-        return TopK(self.cfg.k)
+        return TopK(self.cfg.k, use_sparse_activations=self.cfg.use_sparse_activations)
 
     @override
     def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
@@ -234,6 +480,18 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
         return self.cfg.aux_loss_coefficient * scale * auxk_loss
 
+    @override
+    def process_state_dict_for_saving_inference(
+        self, state_dict: dict[str, Any]
+    ) -> None:
+        super().process_state_dict_for_saving_inference(state_dict)
+        if self.cfg.rescale_acts_by_decoder_norm:
+            _fold_norm_topk(
+                W_enc=state_dict["W_enc"],
+                b_enc=state_dict["b_enc"],
+                W_dec=state_dict["W_dec"],
+            )
+
 
 def _calculate_topk_aux_acts(
     k_aux: int,
@@ -269,3 +527,15 @@ def _init_weights_topk(
     sae.b_enc = nn.Parameter(
         torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
     )
+
+
+def _fold_norm_topk(
+    W_enc: torch.Tensor,
+    b_enc: torch.Tensor,
+    W_dec: torch.Tensor,
+) -> None:
+    W_dec_norm = W_dec.norm(dim=-1)
+    b_enc.data = b_enc.data * W_dec_norm
+    W_dec_norms = W_dec_norm.unsqueeze(1)
+    W_dec.data = W_dec.data / W_dec_norms
+    W_enc.data = W_enc.data * W_dec_norms.T
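`_fold_norm_topk` moves each decoder row's norm into the matching encoder column and encoder bias entry, leaving a unit-norm decoder. A short sketch (assuming the standard shapes `W_enc: [d_in, d_sae]`, `W_dec: [d_sae, d_in]`) of why the folded encoder reproduces the `rescale_acts_by_decoder_norm` pre-activations:

```python
import torch

# Hedged sketch: folding decoder row norms into the encoder reproduces the
# rescaled pre-activations while making the decoder rows unit-norm.
d_in, d_sae = 8, 16
x = torch.randn(3, d_in)
W_enc, b_enc = torch.randn(d_in, d_sae), torch.randn(d_sae)
W_dec = torch.randn(d_sae, d_in)

norms = W_dec.norm(dim=-1)
rescaled_pre = (x @ W_enc + b_enc) * norms   # what encode() computes before TopK

# The fold, as in _fold_norm_topk:
W_enc_folded = W_enc * norms.unsqueeze(0)    # scale each encoder column
b_enc_folded = b_enc * norms
W_dec_folded = W_dec / norms.unsqueeze(1)    # unit-norm decoder rows

folded_pre = x @ W_enc_folded + b_enc_folded
assert torch.allclose(rescaled_pre, folded_pre, atol=1e-5)
assert torch.allclose(W_dec_folded.norm(dim=-1), torch.ones(d_sae), atol=1e-5)
```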
sae_lens/tokenization_and_batching.py
CHANGED
@@ -1,4 +1,4 @@
-from
+from collections.abc import Generator, Iterator
 
 import torch
 
@@ -68,7 +68,7 @@ def concat_and_batch_sequences(
 ) -> Generator[torch.Tensor, None, None]:
     """
     Generator to concat token sequences together from the tokens_interator, yielding
-
+    sequences of size `context_size`. Batching across the batch dimension is handled by the caller.
 
     Args:
         tokens_iterator: An iterator which returns a 1D tensors of tokens
@@ -76,13 +76,28 @@ def concat_and_batch_sequences(
         begin_batch_token_id: If provided, this token will be at position 0 of each batch
         begin_sequence_token_id: If provided, this token will be the first token of each sequence
         sequence_separator_token_id: If provided, this token will be inserted between concatenated sequences
-        disable_concat_sequences: If True, disable concatenating sequences and ignore sequences shorter than context_size
+        disable_concat_sequences: If True, disable concatenating sequences and ignore sequences shorter than context_size (including BOS token if present)
         max_batches: If not provided, the iterator will be run to completion.
     """
     if disable_concat_sequences:
-
-
-
+        if begin_batch_token_id and not begin_sequence_token_id:
+            begin_sequence_token_id = begin_batch_token_id
+        for sequence in tokens_iterator:
+            if (
+                begin_sequence_token_id is not None
+                and sequence[0] != begin_sequence_token_id
+                and len(sequence) >= context_size - 1
+            ):
+                begin_sequence_token_id_tensor = torch.tensor(
+                    [begin_sequence_token_id],
+                    dtype=torch.long,
+                    device=sequence.device,
+                )
+                sequence = torch.cat(
+                    [begin_sequence_token_id_tensor, sequence[: context_size - 1]]
+                )
+            if len(sequence) >= context_size:
+                yield sequence[:context_size]
         return
 
     batch: torch.Tensor | None = None
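With `disable_concat_sequences=True`, sequences are no longer dropped just because they lack a leading BOS token: the configured begin-of-sequence token is prepended when missing (if the sequence is long enough) and the result is truncated to `context_size`. A hedged sketch of that branch in isolation; `batch_without_concat` is an illustrative stand-in, not the package function:

```python
import torch


def batch_without_concat(sequences, context_size: int, bos_id: int):
    """Illustrative stand-in for the disable_concat_sequences branch above."""
    for seq in sequences:
        if seq[0] != bos_id and len(seq) >= context_size - 1:
            # Prepend BOS and keep room for it within context_size.
            seq = torch.cat([torch.tensor([bos_id]), seq[: context_size - 1]])
        if len(seq) >= context_size:
            yield seq[:context_size]


seqs = [torch.arange(10, 16), torch.arange(10, 13)]  # lengths 6 and 3
out = list(batch_without_concat(iter(seqs), context_size=5, bos_id=1))
# The 6-token sequence becomes [1, 10, 11, 12, 13]; the 3-token one is dropped.
assert [t.tolist() for t in out] == [[1, 10, 11, 12, 13]]
```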
sae_lens/training/activation_scaler.py
CHANGED
@@ -1,5 +1,6 @@
 import json
 from dataclasses import dataclass
+from pathlib import Path
 from statistics import mean
 
 import torch
@@ -51,3 +52,9 @@ class ActivationScaler:
 
         with open(file_path, "w") as f:
             json.dump({"scaling_factor": self.scaling_factor}, f)
+
+    def load(self, file_path: str | Path):
+        """load the state dict from a file in json format"""
+        with open(file_path) as f:
+            data = json.load(f)
+        self.scaling_factor = data["scaling_factor"]
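The new `ActivationScaler.load` reads back the same `{"scaling_factor": ...}` JSON layout that the class writes, so a checkpointed scaling factor can be restored directly. A small usage sketch, assuming `ActivationScaler` is constructible with its defaults:

```python
import json
from pathlib import Path

from sae_lens.training.activation_scaler import ActivationScaler

# Write the JSON layout shown in the diff, then read it back via load().
path = Path("activation_scaler.json")
path.write_text(json.dumps({"scaling_factor": 2.5}))

scaler = ActivationScaler()  # assumed default-constructible
scaler.load(path)            # load() accepts str | Path
assert scaler.scaling_factor == 2.5
```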
sae_lens/training/activations_store.py
CHANGED
@@ -4,6 +4,7 @@ import json
 import os
 import warnings
 from collections.abc import Generator, Iterator, Sequence
+from pathlib import Path
 from typing import Any, Literal, cast
 
 import datasets
@@ -13,8 +14,8 @@ from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import HfHubHTTPError
 from jaxtyping import Float, Int
 from requests import HTTPError
-from safetensors.torch import save_file
-from tqdm import tqdm
+from safetensors.torch import load_file, save_file
+from tqdm.auto import tqdm
 from transformer_lens.hook_points import HookedRootModule
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
@@ -24,12 +25,15 @@ from sae_lens.config import (
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.constants import DTYPE_MAP
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, DTYPE_MAP
 from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
 from sae_lens.training.mixing_buffer import mixing_buffer
-from sae_lens.util import
+from sae_lens.util import (
+    extract_stop_at_layer_from_tlens_hook_name,
+    get_special_token_ids,
+)
 
 
 # TODO: Make an activation store config class to be consistent with the rest of the code.
@@ -113,7 +117,7 @@ class ActivationsStore:
         if exclude_special_tokens is False:
             exclude_special_tokens = None
         if exclude_special_tokens is True:
-            exclude_special_tokens =
+            exclude_special_tokens = get_special_token_ids(model.tokenizer)  # type: ignore
         if exclude_special_tokens is not None:
             exclude_special_tokens = torch.tensor(
                 exclude_special_tokens, dtype=torch.long, device=device
@@ -315,7 +319,7 @@ class ActivationsStore:
             )
         else:
             warnings.warn(
-                "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://
+                "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://decoderesearch.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
             )
 
         self.iterable_sequences = self._iterate_tokenized_sequences()
@@ -726,6 +730,48 @@ class ActivationsStore:
         """save the state dict to a file in safetensors format"""
         save_file(self.state_dict(), file_path)
 
+    def save_to_checkpoint(self, checkpoint_path: str | Path):
+        """Save the state dict to a checkpoint path"""
+        self.save(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+    def load_from_checkpoint(self, checkpoint_path: str | Path):
+        """Load the state dict from a checkpoint path"""
+        self.load(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+    def load(self, file_path: str):
+        """Load the state dict from a file in safetensors format"""
+
+        state_dict = load_file(file_path)
+
+        if "n_dataset_processed" in state_dict:
+            target_n_dataset_processed = state_dict["n_dataset_processed"].item()
+
+            # Only fast-forward if needed
+
+            if target_n_dataset_processed > self.n_dataset_processed:
+                logger.info(
+                    "Fast-forwarding through dataset samples to match checkpoint position"
+                )
+                samples_to_skip = target_n_dataset_processed - self.n_dataset_processed
+
+                pbar = tqdm(
+                    total=samples_to_skip,
+                    desc="Fast-forwarding through dataset",
+                    leave=False,
+                )
+                while target_n_dataset_processed > self.n_dataset_processed:
+                    start = self.n_dataset_processed
+                    try:
+                        # Just consume and ignore the values to fast-forward
+                        next(self.iterable_sequences)
+                    except StopIteration:
+                        logger.warning(
+                            "Dataset exhausted during fast-forward. Resetting dataset."
+                        )
+                        self.iterable_sequences = self._iterate_tokenized_sequences()
+                    pbar.update(self.n_dataset_processed - start)
+                pbar.close()
+
 
 def validate_pretokenized_dataset_tokenizer(
     dataset_path: str, model_tokenizer: PreTrainedTokenizerBase
@@ -763,31 +809,6 @@ def _get_model_device(model: HookedRootModule) -> torch.device:
     return next(model.parameters()).device  # type: ignore
 
 
-def _get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
-    """Get all special token IDs from a tokenizer."""
-    special_tokens = set()
-
-    # Get special tokens from tokenizer attributes
-    for attr in dir(tokenizer):
-        if attr.endswith("_token_id"):
-            token_id = getattr(tokenizer, attr)
-            if token_id is not None:
-                special_tokens.add(token_id)
-
-    # Get any additional special tokens from the tokenizer's special tokens map
-    if hasattr(tokenizer, "special_tokens_map"):
-        for token in tokenizer.special_tokens_map.values():
-            if isinstance(token, str):
-                token_id = tokenizer.convert_tokens_to_ids(token)  # type: ignore
-                special_tokens.add(token_id)
-            elif isinstance(token, list):
-                for t in token:
-                    token_id = tokenizer.convert_tokens_to_ids(t)  # type: ignore
-                    special_tokens.add(token_id)
-
-    return list(special_tokens)
-
-
 def _filter_buffer_acts(
     buffer: tuple[torch.Tensor, torch.Tensor | None],
     exclude_tokens: torch.Tensor | None,
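The new checkpoint helpers write and read a single safetensors state file named by `ACTIVATIONS_STORE_STATE_FILENAME` inside a checkpoint directory, and `load` fast-forwards the token iterator to the saved `n_dataset_processed` position. A hedged sketch of how a training script might wire them up; the store itself must already be constructed against the same dataset and is passed in here rather than built:

```python
from pathlib import Path

from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME
from sae_lens.training.activations_store import ActivationsStore


def checkpoint_store(store: ActivationsStore, checkpoint_path: str | Path) -> None:
    """Write the store's state file into a checkpoint directory."""
    path = Path(checkpoint_path)
    path.mkdir(parents=True, exist_ok=True)
    store.save_to_checkpoint(path)
    assert (path / ACTIVATIONS_STORE_STATE_FILENAME).exists()


def resume_store(store: ActivationsStore, checkpoint_path: str | Path) -> None:
    """Fast-forward a freshly built store to the checkpointed dataset position."""
    store.load_from_checkpoint(Path(checkpoint_path))
```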
sae_lens/training/optim.py
CHANGED
@@ -2,6 +2,8 @@
 Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
 """
 
+from typing import Any
+
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
 
@@ -150,3 +152,12 @@ class CoefficientScheduler:
     def value(self) -> float:
         """Returns the current scalar value."""
        return self.current_value
+
+    def state_dict(self) -> dict[str, Any]:
+        return {
+            "current_step": self.current_step,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]):
+        for k in state_dict:
+            setattr(self, k, state_dict[k])