sae-lens 6.15.0__py3-none-any.whl → 6.24.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sae_lens/__init__.py +13 -1
- sae_lens/analysis/hooked_sae_transformer.py +4 -13
- sae_lens/cache_activations_runner.py +3 -4
- sae_lens/config.py +39 -2
- sae_lens/constants.py +1 -0
- sae_lens/llm_sae_training_runner.py +9 -4
- sae_lens/loading/pretrained_sae_loaders.py +430 -24
- sae_lens/loading/pretrained_saes_directory.py +5 -3
- sae_lens/pretokenize_runner.py +3 -3
- sae_lens/pretrained_saes.yaml +26977 -65
- sae_lens/saes/__init__.py +7 -0
- sae_lens/saes/batchtopk_sae.py +3 -1
- sae_lens/saes/gated_sae.py +6 -11
- sae_lens/saes/jumprelu_sae.py +8 -13
- sae_lens/saes/matryoshka_batchtopk_sae.py +8 -15
- sae_lens/saes/sae.py +20 -32
- sae_lens/saes/standard_sae.py +4 -9
- sae_lens/saes/temporal_sae.py +365 -0
- sae_lens/saes/topk_sae.py +8 -11
- sae_lens/saes/transcoder.py +41 -0
- sae_lens/training/activation_scaler.py +7 -0
- sae_lens/training/activations_store.py +54 -12
- sae_lens/training/optim.py +11 -0
- sae_lens/training/sae_trainer.py +50 -11
- {sae_lens-6.15.0.dist-info → sae_lens-6.24.1.dist-info}/METADATA +16 -16
- sae_lens-6.24.1.dist-info/RECORD +41 -0
- sae_lens-6.15.0.dist-info/RECORD +0 -40
- {sae_lens-6.15.0.dist-info → sae_lens-6.24.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.15.0.dist-info → sae_lens-6.24.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/saes/temporal_sae.py
ADDED
@@ -0,0 +1,365 @@
+"""TemporalSAE: A Sparse Autoencoder with temporal attention mechanism.
+
+TemporalSAE decomposes activations into:
+1. Predicted codes (from attention over context)
+2. Novel codes (sparse features of the residual)
+
+See: https://arxiv.org/abs/2410.04185
+"""
+
+import math
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from typing_extensions import override
+
+from sae_lens import logger
+from sae_lens.saes.sae import SAE, SAEConfig
+
+
+def get_attention(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
+    """Compute causal attention weights."""
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1))
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+    temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+    attn_bias.to(query.dtype)
+
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    return torch.softmax(attn_weight, dim=-1)
+
+
+class ManualAttention(nn.Module):
+    """Manual attention implementation for TemporalSAE."""
+
+    def __init__(
+        self,
+        dimin: int,
+        n_heads: int = 4,
+        bottleneck_factor: int = 64,
+        bias_k: bool = True,
+        bias_q: bool = True,
+        bias_v: bool = True,
+        bias_o: bool = True,
+    ):
+        super().__init__()
+        assert dimin % (bottleneck_factor * n_heads) == 0
+
+        self.n_heads = n_heads
+        self.n_embds = dimin // bottleneck_factor
+        self.dimin = dimin
+
+        # Key, query, value projections
+        self.k_ctx = nn.Linear(dimin, self.n_embds, bias=bias_k)
+        self.q_target = nn.Linear(dimin, self.n_embds, bias=bias_q)
+        self.v_ctx = nn.Linear(dimin, dimin, bias=bias_v)
+        self.c_proj = nn.Linear(dimin, dimin, bias=bias_o)
+
+        # Normalize to match scale with representations
+        with torch.no_grad():
+            scaling = 1 / math.sqrt(self.n_embds // self.n_heads)
+            self.k_ctx.weight.copy_(
+                scaling
+                * self.k_ctx.weight
+                / (1e-6 + torch.linalg.norm(self.k_ctx.weight, dim=1, keepdim=True))
+            )
+            self.q_target.weight.copy_(
+                scaling
+                * self.q_target.weight
+                / (1e-6 + torch.linalg.norm(self.q_target.weight, dim=1, keepdim=True))
+            )
+
+            scaling = 1 / math.sqrt(self.dimin // self.n_heads)
+            self.v_ctx.weight.copy_(
+                scaling
+                * self.v_ctx.weight
+                / (1e-6 + torch.linalg.norm(self.v_ctx.weight, dim=1, keepdim=True))
+            )
+
+            scaling = 1 / math.sqrt(self.dimin)
+            self.c_proj.weight.copy_(
+                scaling
+                * self.c_proj.weight
+                / (1e-6 + torch.linalg.norm(self.c_proj.weight, dim=1, keepdim=True))
+            )
+
+    def forward(
+        self, x_ctx: torch.Tensor, x_target: torch.Tensor, get_attn_map: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Compute projective attention output."""
+        k = self.k_ctx(x_ctx)
+        v = self.v_ctx(x_ctx)
+        q = self.q_target(x_target)
+
+        # Split into heads
+        B, T, _ = x_ctx.size()
+        k = k.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
+        q = q.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
+        v = v.view(B, T, self.n_heads, self.dimin // self.n_heads).transpose(1, 2)
+
+        # Attention map (optional)
+        attn_map = None
+        if get_attn_map:
+            attn_map = get_attention(query=q, key=k)
+
+        # Scaled dot-product attention
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=None, dropout_p=0, is_causal=True
+        )
+
+        # Reshape and project
+        d_target = self.c_proj(
+            attn_output.transpose(1, 2).contiguous().view(B, T, self.dimin)
+        )
+
+        return d_target, attn_map
+
+
+@dataclass
+class TemporalSAEConfig(SAEConfig):
+    """Configuration for TemporalSAE inference.
+
+    Args:
+        d_in: Input dimension (dimensionality of the activations being encoded)
+        d_sae: SAE latent dimension (number of features)
+        n_heads: Number of attention heads in temporal attention
+        n_attn_layers: Number of attention layers
+        bottleneck_factor: Bottleneck factor for attention dimension
+        sae_diff_type: Type of SAE for novel codes ('relu' or 'topk')
+        kval_topk: K value for top-k sparsity (if sae_diff_type='topk')
+        tied_weights: Whether to tie encoder and decoder weights
+        activation_normalization_factor: Scalar factor for rescaling activations (used with normalize_activations='constant_scalar_rescale')
+    """
+
+    n_heads: int = 8
+    n_attn_layers: int = 1
+    bottleneck_factor: int = 64
+    sae_diff_type: Literal["relu", "topk"] = "topk"
+    kval_topk: int | None = None
+    tied_weights: bool = True
+    activation_normalization_factor: float = 1.0
+
+    def __post_init__(self):
+        # Call parent's __post_init__ first, but allow constant_scalar_rescale
+        if self.normalize_activations not in [
+            "none",
+            "expected_average_only_in",
+            "constant_norm_rescale",
+            "constant_scalar_rescale",  # Temporal SAEs support this
+            "layer_norm",
+        ]:
+            raise ValueError(
+                f"normalize_activations must be none, expected_average_only_in, layer_norm, constant_norm_rescale, or constant_scalar_rescale. Got {self.normalize_activations}"
+            )
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "temporal"
+
+
+class TemporalSAE(SAE[TemporalSAEConfig]):
+    """TemporalSAE: Sparse Autoencoder with temporal attention.
+
+    This SAE decomposes each activation x_t into:
+    - x_pred: Information aggregated from context {x_0, ..., x_{t-1}}
+    - x_novel: Novel information at position t (encoded sparsely)
+
+    The forward pass:
+    1. Uses attention layers to predict x_t from context
+    2. Encodes the residual (novel part) with a sparse SAE
+    3. Combines both for reconstruction
+    """
+
+    # Custom parameters (in addition to W_enc, W_dec, b_dec from base)
+    attn_layers: nn.ModuleList  # Attention layers
+    eps: float
+    lam: float
+
+    def __init__(self, cfg: TemporalSAEConfig, use_error_term: bool = False):
+        # Call parent init first
+        super().__init__(cfg, use_error_term)
+
+        # Initialize attention layers after parent init and move to correct device
+        self.attn_layers = nn.ModuleList(
+            [
+                ManualAttention(
+                    dimin=cfg.d_sae,
+                    n_heads=cfg.n_heads,
+                    bottleneck_factor=cfg.bottleneck_factor,
+                    bias_k=True,
+                    bias_q=True,
+                    bias_v=True,
+                    bias_o=True,
+                ).to(device=self.device, dtype=self.dtype)
+                for _ in range(cfg.n_attn_layers)
+            ]
+        )
+
+        self.eps = 1e-6
+        self.lam = 1 / (4 * self.cfg.d_in)
+
+    @override
+    def _setup_activation_normalization(self):
+        """Set up activation normalization functions for TemporalSAE.
+
+        Overrides the base implementation to handle constant_scalar_rescale
+        using the temporal-specific activation_normalization_factor.
+        """
+        if self.cfg.normalize_activations == "constant_scalar_rescale":
+            # Handle constant scalar rescaling for temporal SAEs
+            def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
+                return x * self.cfg.activation_normalization_factor
+
+            def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
+                return x / self.cfg.activation_normalization_factor
+
+            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
+            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+        else:
+            # Delegate to parent for all other normalization types
+            super()._setup_activation_normalization()
+
+    @override
+    def initialize_weights(self) -> None:
+        """Initialize TemporalSAE weights."""
+        # Initialize D (decoder) and b (bias)
+        self.W_dec = nn.Parameter(
+            torch.randn(
+                (self.cfg.d_sae, self.cfg.d_in), dtype=self.dtype, device=self.device
+            )
+        )
+        self.b_dec = nn.Parameter(
+            torch.zeros((self.cfg.d_in), dtype=self.dtype, device=self.device)
+        )
+
+        # Initialize E (encoder) if not tied
+        if not self.cfg.tied_weights:
+            self.W_enc = nn.Parameter(
+                torch.randn(
+                    (self.cfg.d_in, self.cfg.d_sae),
+                    dtype=self.dtype,
+                    device=self.device,
+                )
+            )
+
+    def encode_with_predictions(
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Encode input to novel codes only.
+
+        Returns only the sparse novel codes (not predicted codes).
+        This is the main feature representation for TemporalSAE.
+        """
+        # Process input through SAELens preprocessing
+        x = self.process_sae_in(x)
+
+        B, L, _ = x.shape
+
+        if self.cfg.tied_weights:  # noqa: SIM108
+            W_enc = self.W_dec.T
+        else:
+            W_enc = self.W_enc
+
+        # Compute predicted codes using attention
+        x_residual = x
+        z_pred = torch.zeros((B, L, self.cfg.d_sae), device=x.device, dtype=x.dtype)
+
+        for attn_layer in self.attn_layers:
+            # Encode input to latent space
+            z_input = F.relu(torch.matmul(x_residual * self.lam, W_enc))
+
+            # Shift context (causal masking)
+            z_ctx = torch.cat(
+                (torch.zeros_like(z_input[:, :1, :]), z_input[:, :-1, :].clone()), dim=1
+            )
+
+            # Apply attention to get predicted codes
+            z_pred_, _ = attn_layer(z_ctx, z_input, get_attn_map=False)
+            z_pred_ = F.relu(z_pred_)
+
+            # Project predicted codes back to input space
+            Dz_pred_ = torch.matmul(z_pred_, self.W_dec)
+            Dz_norm_ = Dz_pred_.norm(dim=-1, keepdim=True) + self.eps
+
+            # Compute projection scale
+            proj_scale = (Dz_pred_ * x_residual).sum(
+                dim=-1, keepdim=True
+            ) / Dz_norm_.pow(2)
+
+            # Accumulate predicted codes
+            z_pred = z_pred + (z_pred_ * proj_scale)
+
+            # Remove prediction from residual
+            x_residual = x_residual - proj_scale * Dz_pred_
+
+        # Encode residual (novel part) with sparse SAE
+        z_novel = F.relu(torch.matmul(x_residual * self.lam, W_enc))
+        if self.cfg.sae_diff_type == "topk":
+            kval = self.cfg.kval_topk
+            if kval is not None:
+                _, topk_indices = torch.topk(z_novel, kval, dim=-1)
+                mask = torch.zeros_like(z_novel)
+                mask.scatter_(-1, topk_indices, 1)
+                z_novel = z_novel * mask
+
+        # Return only novel codes (these are the interpretable features)
+        return z_novel, z_pred
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        return self.encode_with_predictions(x)[0]
+
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
+        """Decode novel codes to reconstruction.
+
+        Note: This only decodes the novel codes. For full reconstruction,
+        use forward() which includes predicted codes.
+        """
+        # Decode novel codes
+        sae_out = torch.matmul(feature_acts, self.W_dec)
+        sae_out = sae_out + self.b_dec
+
+        # Apply hook
+        sae_out = self.hook_sae_recons(sae_out)
+
+        # Apply output activation normalization (reverses input normalization)
+        sae_out = self.run_time_activation_norm_fn_out(sae_out)
+
+        # Add bias (already removed in process_sae_in)
+        logger.warning(
+            "NOTE this only decodes x_novel. The x_pred is missing, so we're not reconstructing the full x."
+        )
+        return sae_out
+
+    @override
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Full forward pass through TemporalSAE.
+
+        Returns complete reconstruction (predicted + novel).
+        """
+        # Encode
+        z_novel, z_pred = self.encode_with_predictions(x)
+
+        # Decode the sum of predicted and novel codes.
+        x_recons = torch.matmul(z_novel + z_pred, self.W_dec) + self.b_dec
+
+        # Apply output activation normalization (reverses input normalization)
+        x_recons = self.run_time_activation_norm_fn_out(x_recons)
+
+        return self.hook_sae_output(x_recons)
+
+    @override
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError("Folding W_dec_norm is not supported for TemporalSAE")
+
+    @override
+    @torch.no_grad()
+    def fold_activation_norm_scaling_factor(self, scaling_factor: float) -> None:
+        raise NotImplementedError(
+            "Folding activation norm scaling factor is not supported for TemporalSAE"
+        )
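A minimal usage sketch for the new architecture follows; it is not part of the package. The dimensions are illustrative, `d_in`/`d_sae` are inherited `SAEConfig` fields per the config docstring above, and the import path simply mirrors the new module's location.

```python
# Illustrative only: dimensions are made up and the SAE is randomly initialized.
import torch

from sae_lens.saes.temporal_sae import TemporalSAE, TemporalSAEConfig

cfg = TemporalSAEConfig(
    d_in=2304,        # width of the hooked activations (inherited SAEConfig field)
    d_sae=16384,      # dictionary size; must be divisible by bottleneck_factor * n_heads
    n_heads=8,
    n_attn_layers=1,
    bottleneck_factor=64,
    sae_diff_type="topk",
    kval_topk=64,
    tied_weights=True,
)
sae = TemporalSAE(cfg)

# TemporalSAE expects (batch, seq_len, d_in) activations, since the attention
# layers aggregate information across the sequence dimension.
acts = torch.randn(2, 128, cfg.d_in)

z_novel, z_pred = sae.encode_with_predictions(acts)  # sparse novel codes + predicted codes
features = sae.encode(acts)                          # novel codes only
recon = sae(acts)                                    # full reconstruction (predicted + novel)
```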
sae_lens/saes/topk_sae.py
CHANGED
@@ -4,7 +4,6 @@ from dataclasses import dataclass
 from typing import Any, Callable
 
 import torch
-from jaxtyping import Float
 from torch import nn
 from transformer_lens.hook_points import HookPoint
 from typing_extensions import override
@@ -235,9 +234,7 @@ class TopKSAE(SAE[TopKSAEConfig]):
         super().initialize_weights()
         _init_weights_topk(self)
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Converts input x into feature activations.
         Uses topk activation under the hood.
@@ -251,8 +248,8 @@ class TopKSAE(SAE[TopKSAEConfig]):
 
     def decode(
         self,
-        feature_acts:
-    ) ->
+        feature_acts: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Reconstructs the input from topk feature activations.
         Applies optional finetuning scaling, hooking to recons, out normalization,
@@ -354,8 +351,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         _init_weights_topk(self)
 
     def encode_with_hidden_pre(
-        self, x:
-    ) -> tuple[
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Similar to the base training method: calculate pre-activations, then apply TopK.
         """
@@ -372,8 +369,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     @override
     def decode(
         self,
-        feature_acts:
-    ) ->
+        feature_acts: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Decodes feature activations back into input space,
         applying optional finetuning scale, hooking, out normalization, etc.
@@ -534,7 +531,7 @@ def _fold_norm_topk(
     b_enc: torch.Tensor,
     W_dec: torch.Tensor,
 ) -> None:
-    W_dec_norm = W_dec.norm(dim=-1)
+    W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8)
     b_enc.data = b_enc.data * W_dec_norm
     W_dec_norms = W_dec_norm.unsqueeze(1)
    W_dec.data = W_dec.data / W_dec_norms
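Besides dropping the jaxtyping annotations, the only behavioral change here is in `_fold_norm_topk`: decoder row norms are clamped before being folded into `b_enc` and divided out of `W_dec`, so a dead latent with a (near-)zero decoder row no longer produces NaN or Inf. A standalone sketch of the same arithmetic on toy tensors (not the library helper itself):

```python
import torch

W_dec = torch.randn(8, 16)
W_dec[0] = 0.0                                     # dead latent: zero decoder row
b_enc = torch.randn(8)

W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8)    # the new clamp avoids dividing by zero
b_enc = b_enc * W_dec_norm                         # fold the norm into the encoder bias
W_dec = W_dec / W_dec_norm.unsqueeze(1)            # leave decoder rows (near) unit norm

assert torch.isfinite(b_enc).all() and torch.isfinite(W_dec).all()
```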
sae_lens/saes/transcoder.py
CHANGED
@@ -368,3 +368,44 @@ class JumpReLUTranscoder(Transcoder):
     def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUTranscoder":
         cfg = JumpReLUTranscoderConfig.from_dict(config_dict)
         return cls(cfg)
+
+
+@dataclass
+class JumpReLUSkipTranscoderConfig(JumpReLUTranscoderConfig):
+    """Configuration for JumpReLU transcoder."""
+
+    @classmethod
+    def architecture(cls) -> str:
+        """Return the architecture name for this config."""
+        return "jumprelu_skip_transcoder"
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUSkipTranscoderConfig":
+        """Create a JumpReLUSkipTranscoderConfig from a dictionary."""
+        # Filter to only include valid dataclass fields
+        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)
+
+        # Create the config instance
+        res = cls(**filtered_config_dict)
+
+        # Handle metadata if present
+        if "metadata" in config_dict:
+            res.metadata = SAEMetadata(**config_dict["metadata"])
+
+        return res
+
+
+class JumpReLUSkipTranscoder(JumpReLUTranscoder, SkipTranscoder):
+    """
+    A transcoder with a learnable skip connection and JumpReLU activation function.
+    """
+
+    cfg: JumpReLUSkipTranscoderConfig  # type: ignore[assignment]
+
+    def __init__(self, cfg: JumpReLUSkipTranscoderConfig):
+        super().__init__(cfg)
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUSkipTranscoder":
+        cfg = JumpReLUSkipTranscoderConfig.from_dict(config_dict)
+        return cls(cfg)
sae_lens/training/activation_scaler.py
CHANGED
@@ -1,5 +1,6 @@
 import json
 from dataclasses import dataclass
+from pathlib import Path
 from statistics import mean
 
 import torch
@@ -51,3 +52,9 @@ class ActivationScaler:
 
         with open(file_path, "w") as f:
             json.dump({"scaling_factor": self.scaling_factor}, f)
+
+    def load(self, file_path: str | Path):
+        """load the state dict from a file in json format"""
+        with open(file_path) as f:
+            data = json.load(f)
+        self.scaling_factor = data["scaling_factor"]
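The new `ActivationScaler.load` is the read-side counterpart of the JSON dump shown above. A small sketch of reading a scaling factor back in; it assumes `ActivationScaler` is default-constructible, which this hunk does not show.

```python
import json

from sae_lens.training.activation_scaler import ActivationScaler

# Write a file in the same {"scaling_factor": ...} format the class itself saves.
with open("activation_scaler.json", "w") as f:
    json.dump({"scaling_factor": 1.73}, f)

scaler = ActivationScaler()        # assumption: no required constructor arguments
scaler.load("activation_scaler.json")
assert scaler.scaling_factor == 1.73
```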
sae_lens/training/activations_store.py
CHANGED
@@ -4,6 +4,7 @@ import json
 import os
 import warnings
 from collections.abc import Generator, Iterator, Sequence
+from pathlib import Path
 from typing import Any, Literal, cast
 
 import datasets
@@ -11,10 +12,9 @@ import torch
 from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import HfHubHTTPError
-from jaxtyping import Float, Int
 from requests import HTTPError
-from safetensors.torch import save_file
-from tqdm import tqdm
+from safetensors.torch import load_file, save_file
+from tqdm.auto import tqdm
 from transformer_lens.hook_points import HookedRootModule
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
@@ -24,7 +24,7 @@ from sae_lens.config import (
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.constants import DTYPE_MAP
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, DTYPE_MAP
 from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
@@ -166,9 +166,11 @@ class ActivationsStore:
         disable_concat_sequences: bool = False,
         sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
     ) -> ActivationsStore:
+        if context_size is None:
+            context_size = sae.cfg.metadata.context_size
         if sae.cfg.metadata.hook_name is None:
             raise ValueError("hook_name is required")
-        if
+        if context_size is None:
             raise ValueError("context_size is required")
         if sae.cfg.metadata.prepend_bos is None:
             raise ValueError("prepend_bos is required")
@@ -178,9 +180,7 @@ class ActivationsStore:
             d_in=sae.cfg.d_in,
             hook_name=sae.cfg.metadata.hook_name,
             hook_head_index=sae.cfg.metadata.hook_head_index,
-            context_size=
-            if context_size is None
-            else context_size,
+            context_size=context_size,
             prepend_bos=sae.cfg.metadata.prepend_bos,
             streaming=streaming,
             store_batch_size_prompts=store_batch_size_prompts,
@@ -230,7 +230,7 @@ class ActivationsStore:
             load_dataset(
                 dataset,
                 split="train",
-                streaming=streaming,
+                streaming=streaming,  # type: ignore
                 trust_remote_code=dataset_trust_remote_code,  # type: ignore
             )
             if isinstance(dataset, str)
@@ -318,7 +318,7 @@ class ActivationsStore:
             )
         else:
             warnings.warn(
-                "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://
+                "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://decoderesearch.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
             )
 
         self.iterable_sequences = self._iterate_tokenized_sequences()
@@ -541,8 +541,8 @@ class ActivationsStore:
         d_in: int,
         raise_on_epoch_end: bool,
     ) -> tuple[
-
-
+        torch.Tensor,
+        torch.Tensor | None,
     ]:
         """
         Loads `total_size` activations from `cached_activation_dataset`
@@ -729,6 +729,48 @@ class ActivationsStore:
         """save the state dict to a file in safetensors format"""
         save_file(self.state_dict(), file_path)
 
+    def save_to_checkpoint(self, checkpoint_path: str | Path):
+        """Save the state dict to a checkpoint path"""
+        self.save(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+    def load_from_checkpoint(self, checkpoint_path: str | Path):
+        """Load the state dict from a checkpoint path"""
+        self.load(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+    def load(self, file_path: str):
+        """Load the state dict from a file in safetensors format"""
+
+        state_dict = load_file(file_path)
+
+        if "n_dataset_processed" in state_dict:
+            target_n_dataset_processed = state_dict["n_dataset_processed"].item()
+
+            # Only fast-forward if needed
+
+            if target_n_dataset_processed > self.n_dataset_processed:
+                logger.info(
+                    "Fast-forwarding through dataset samples to match checkpoint position"
+                )
+                samples_to_skip = target_n_dataset_processed - self.n_dataset_processed
+
+                pbar = tqdm(
+                    total=samples_to_skip,
+                    desc="Fast-forwarding through dataset",
+                    leave=False,
+                )
+                while target_n_dataset_processed > self.n_dataset_processed:
+                    start = self.n_dataset_processed
+                    try:
+                        # Just consume and ignore the values to fast-forward
+                        next(self.iterable_sequences)
+                    except StopIteration:
+                        logger.warning(
+                            "Dataset exhausted during fast-forward. Resetting dataset."
+                        )
+                        self.iterable_sequences = self._iterate_tokenized_sequences()
+                    pbar.update(self.n_dataset_processed - start)
+                pbar.close()
+
 
 def validate_pretokenized_dataset_tokenizer(
     dataset_path: str, model_tokenizer: PreTrainedTokenizerBase
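The new `save_to_checkpoint` / `load_from_checkpoint` methods write and read a safetensors state file (named by `ACTIVATIONS_STORE_STATE_FILENAME`) inside a checkpoint directory, and `load()` fast-forwards the dataset iterator to the saved `n_dataset_processed` position. A small wrapper sketch, assuming `store` is an `ActivationsStore` constructed elsewhere (its constructors are unchanged by this diff):

```python
from pathlib import Path

from sae_lens.training.activations_store import ActivationsStore


def checkpoint_store(store: ActivationsStore, checkpoint_dir: str | Path) -> None:
    """Persist the store's position so a later run can resume from it."""
    path = Path(checkpoint_dir)
    path.mkdir(parents=True, exist_ok=True)  # save_to_checkpoint writes a file inside this dir
    store.save_to_checkpoint(path)


def resume_store(store: ActivationsStore, checkpoint_dir: str | Path) -> None:
    """Reload a saved position; load() replays the dataset iterator to catch up."""
    store.load_from_checkpoint(Path(checkpoint_dir))
```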
sae_lens/training/optim.py
CHANGED
@@ -2,6 +2,8 @@
 Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
 """
 
+from typing import Any
+
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
 
@@ -150,3 +152,12 @@ class CoefficientScheduler:
     def value(self) -> float:
         """Returns the current scalar value."""
         return self.current_value
+
+    def state_dict(self) -> dict[str, Any]:
+        return {
+            "current_step": self.current_step,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]):
+        for k in state_dict:
+            setattr(self, k, state_dict[k])
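Finally, `CoefficientScheduler` gains a minimal `state_dict()` / `load_state_dict()` pair that carries only `current_step`; the rest of the schedule has to come from re-creating the scheduler with the same arguments. A short sketch exercising just the new methods (the scheduler's constructor is not part of this diff):

```python
from typing import Any

from sae_lens.training.optim import CoefficientScheduler


def save_scheduler_state(scheduler: CoefficientScheduler) -> dict[str, Any]:
    # Returns {"current_step": <int>}, the only field the scheduler serializes.
    return scheduler.state_dict()


def restore_scheduler_state(scheduler: CoefficientScheduler, state: dict[str, Any]) -> None:
    # load_state_dict simply setattr()s each key back onto the scheduler object.
    scheduler.load_state_dict(state)
```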