sae-lens 6.13.1__tar.gz → 6.14.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.13.1 → sae_lens-6.14.1}/PKG-INFO +1 -1
- {sae_lens-6.13.1 → sae_lens-6.14.1}/pyproject.toml +1 -1
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/__init__.py +1 -1
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/batchtopk_sae.py +29 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/topk_sae.py +76 -8
- {sae_lens-6.13.1 → sae_lens-6.14.1}/LICENSE +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/README.md +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/config.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/constants.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/evals.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/load_model.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/registry.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/types.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/util.py +0 -0
{sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/batchtopk_sae.py

@@ -35,6 +35,35 @@ class BatchTopK(nn.Module):
 class BatchTopKTrainingSAEConfig(TopKTrainingSAEConfig):
     """
     Configuration class for training a BatchTopKTrainingSAE.
+
+    BatchTopK SAEs maintain k active features on average across the entire batch,
+    rather than enforcing k features per sample like standard TopK SAEs. During training,
+    the SAE learns a global threshold that is updated based on the minimum positive
+    activation value. After training, BatchTopK SAEs are saved as JumpReLU SAEs.
+
+    Args:
+        k (float): Average number of features to keep active across the batch. Unlike
+            standard TopK SAEs where k is an integer per sample, this is a float
+            representing the average number of active features across all samples in
+            the batch. Defaults to 100.
+        topk_threshold_lr (float): Learning rate for updating the global topk threshold.
+            The threshold is updated using an exponential moving average of the minimum
+            positive activation value. Defaults to 0.01.
+        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
+            dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
+            Defaults to 1.0.
+        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
+            Inherited from TopKTrainingSAEConfig. Defaults to True.
+        decoder_init_norm (float | None): Norm to initialize decoder weights to.
+            Inherited from TrainingSAEConfig. Defaults to 0.1.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
     """
 
     k: float = 100 # type: ignore[assignment]
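The docstring above frames BatchTopK as keeping k features active on average across the whole batch rather than exactly k per sample. A minimal sketch of that selection rule, under the assumption that pre-activations are simply pooled across the batch before the top-k cut, is shown below; it is illustrative only and is not the package's BatchTopK module (which also maintains the global threshold the docstring mentions).

```python
import torch


def batch_topk(hidden_pre: torch.Tensor, k: float) -> torch.Tensor:
    """Keep the top (k * batch_size) pre-activations across the whole batch.

    Hypothetical helper for illustration; individual samples may end up with
    more or fewer than k active features, but the batch averages out to ~k.
    """
    batch_size = hidden_pre.shape[0]
    n_keep = int(k * batch_size)  # k is an average per sample, so scale by batch size
    acts = torch.relu(hidden_pre).flatten()
    mask = torch.zeros_like(acts)
    mask[torch.topk(acts, k=min(n_keep, acts.numel())).indices] = 1.0
    return (acts * mask).reshape(hidden_pre.shape)


acts = batch_topk(torch.randn(8, 512), k=4.0)
print((acts > 0).float().sum(dim=-1))  # per-sample counts vary, mean is ~4
```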
{sae_lens-6.13.1 → sae_lens-6.14.1}/sae_lens/saes/topk_sae.py

@@ -1,7 +1,7 @@
 """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
 
 from dataclasses import dataclass
-from typing import Callable
+from typing import Any, Callable
 
 import torch
 from jaxtyping import Float
@@ -118,10 +118,36 @@ class TopK(nn.Module):
 @dataclass
 class TopKSAEConfig(SAEConfig):
     """
-    Configuration class for
+    Configuration class for TopKSAE inference.
+
+    Args:
+        k (int): Number of top features to keep active during inference. Only the top k
+            features with the highest pre-activations will be non-zero. Defaults to 100.
+        rescale_acts_by_decoder_norm (bool): Whether to treat the decoder as if it was
+            already normalized. This affects the topk selection by rescaling pre-activations
+            by decoder norms. Requires that the SAE was trained this way. Defaults to False.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
+            before encoding. Inherited from SAEConfig. Defaults to True.
+        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
+            Normalization strategy for input activations. Inherited from SAEConfig.
+            Defaults to "none".
+        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
+            (useful for attention head outputs). Inherited from SAEConfig.
+            Defaults to "none".
+        metadata (SAEMetadata): Metadata about the SAE (model name, hook name, etc.).
+            Inherited from SAEConfig.
     """
 
     k: int = 100
+    rescale_acts_by_decoder_norm: bool = False
 
     @override
     @classmethod
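For orientation, here is a hedged construction example using the fields documented above. The module path matches this file (sae_lens/saes/topk_sae.py); the dimension values are invented, and rescale_acts_by_decoder_norm should only be enabled for SAEs trained with decoder-norm rescaling, per the docstring.

```python
from sae_lens.saes.topk_sae import TopKSAEConfig

# Illustrative values; d_in/d_sae here are arbitrary.
cfg = TopKSAEConfig(
    d_in=768,
    d_sae=16384,
    k=100,
    rescale_acts_by_decoder_norm=False,  # new in 6.14.x, defaults to False for inference
)
```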
@@ -218,6 +244,8 @@ class TopKSAE(SAE[TopKSAEConfig]):
         """
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+        if self.cfg.rescale_acts_by_decoder_norm:
+            hidden_pre = hidden_pre * self.W_dec.norm(dim=-1)
         # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
 
@@ -231,6 +259,8 @@ class TopKSAE(SAE[TopKSAEConfig]):
         and optional head reshaping.
         """
         # Handle sparse tensors using efficient sparse matrix multiplication
+        if self.cfg.rescale_acts_by_decoder_norm:
+            feature_acts = feature_acts / self.W_dec.norm(dim=-1)
         if feature_acts.is_sparse:
             sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
         else:
@@ -246,9 +276,11 @@ class TopKSAE(SAE[TopKSAEConfig]):
     @override
     @torch.no_grad()
     def fold_W_dec_norm(self) -> None:
-
-
-
+        if not self.cfg.rescale_acts_by_decoder_norm:
+            raise NotImplementedError(
+                "Folding W_dec_norm is not safe for TopKSAEs when rescale_acts_by_decoder_norm is False, as this may change the topk activations"
+            )
+        _fold_norm_topk(W_dec=self.W_dec, b_enc=self.b_enc, W_enc=self.W_enc)
 
 
 @dataclass
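The new guard in fold_W_dec_norm exists because rescaling pre-activations by per-feature decoder norms can change which features survive the top-k cut. A tiny, self-contained illustration of that effect (the numbers are invented):

```python
import torch

hidden_pre = torch.tensor([3.0, 2.5, 0.1])   # raw pre-activations for 3 features
w_dec_norms = torch.tensor([0.2, 1.5, 1.0])  # per-feature decoder norms

# Top-1 on the raw pre-activations picks feature 0 ...
print(torch.topk(hidden_pre, k=1).indices)                # tensor([0])

# ... but after scaling by the decoder norms (which is what folding them into
# the encoder effectively does), feature 1 wins, so the sparse code changes.
print(torch.topk(hidden_pre * w_dec_norms, k=1).indices)  # tensor([1])
```

With rescale_acts_by_decoder_norm=True the encode path already applies this scaling, so folding the norms leaves the selected features unchanged and is allowed.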
@@ -267,6 +299,9 @@ class TopKTrainingSAEConfig(TrainingSAEConfig):
             dead neurons to learn useful features. This loss helps prevent neuron death
             in TopK SAEs by having dead neurons reconstruct the residual error from
             live neurons. Defaults to 1.0.
+        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
+            This is a good idea since decoder norm can randomly drift during training, and this
+            affects what the topk activations will be. Defaults to True.
         decoder_init_norm (float | None): Norm to initialize decoder weights to.
             0.1 corresponds to the "heuristic" initialization from Anthropic's April update.
             Use None to disable. Inherited from TrainingSAEConfig. Defaults to 0.1.
@@ -293,6 +328,7 @@ class TopKTrainingSAEConfig(TrainingSAEConfig):
     k: int = 100
     use_sparse_activations: bool = False
     aux_loss_coefficient: float = 1.0
+    rescale_acts_by_decoder_norm: bool = True
 
     @override
     @classmethod
@@ -326,6 +362,9 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
 
+        if self.cfg.rescale_acts_by_decoder_norm:
+            hidden_pre = hidden_pre * self.W_dec.norm(dim=-1)
+
         # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
         return feature_acts, hidden_pre
@@ -340,6 +379,9 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         applying optional finetuning scale, hooking, out normalization, etc.
         """
         # Handle sparse tensors using efficient sparse matrix multiplication
+        if self.cfg.rescale_acts_by_decoder_norm:
+            # need to multiply by the inverse of the norm because division is illegal with sparse tensors
+            feature_acts = feature_acts * (1 / self.W_dec.norm(dim=-1))
         if feature_acts.is_sparse:
             sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
         else:
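The comment added in this hunk notes that element-wise division is not available for sparse tensors, which is why the training decode path multiplies by the reciprocal of the decoder norms instead of dividing. For dense tensors the two forms agree up to floating-point error, as this small standalone check (not part of the package) shows:

```python
import torch

feature_acts = torch.rand(4, 8)
w_dec_norms = torch.rand(8) + 0.1  # keep norms away from zero in this toy check

divided = feature_acts / w_dec_norms
via_reciprocal = feature_acts * (1 / w_dec_norms)
assert torch.allclose(divided, via_reciprocal)
```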
@@ -385,9 +427,11 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     @override
     @torch.no_grad()
     def fold_W_dec_norm(self) -> None:
-
-
-
+        if not self.cfg.rescale_acts_by_decoder_norm:
+            raise NotImplementedError(
+                "Folding W_dec_norm is not safe for TopKSAEs when rescale_acts_by_decoder_norm is False, as this may change the topk activations"
+            )
+        _fold_norm_topk(W_dec=self.W_dec, b_enc=self.b_enc, W_enc=self.W_enc)
 
     @override
     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
@@ -436,6 +480,18 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
         return self.cfg.aux_loss_coefficient * scale * auxk_loss
 
+    @override
+    def process_state_dict_for_saving_inference(
+        self, state_dict: dict[str, Any]
+    ) -> None:
+        super().process_state_dict_for_saving_inference(state_dict)
+        if self.cfg.rescale_acts_by_decoder_norm:
+            _fold_norm_topk(
+                W_enc=state_dict["W_enc"],
+                b_enc=state_dict["b_enc"],
+                W_dec=state_dict["W_dec"],
+            )
+
 
 def _calculate_topk_aux_acts(
     k_aux: int,
@@ -471,3 +527,15 @@ def _init_weights_topk(
     sae.b_enc = nn.Parameter(
         torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
     )
+
+
+def _fold_norm_topk(
+    W_enc: torch.Tensor,
+    b_enc: torch.Tensor,
+    W_dec: torch.Tensor,
+) -> None:
+    W_dec_norm = W_dec.norm(dim=-1)
+    b_enc.data = b_enc.data * W_dec_norm
+    W_dec_norms = W_dec_norm.unsqueeze(1)
+    W_dec.data = W_dec.data / W_dec_norms
+    W_enc.data = W_enc.data * W_dec_norms.T
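As a sanity check on _fold_norm_topk, the sketch below verifies that folding the decoder norms into W_enc and b_enc reproduces the rescaled pre-activations used when rescale_acts_by_decoder_norm=True, so the inference SAE saved via process_state_dict_for_saving_inference selects the same top-k features. The shape convention (W_enc: (d_in, d_sae), W_dec: (d_sae, d_in)) is inferred from the matmuls in the hunks above; the script itself is not part of the package.

```python
import torch

torch.manual_seed(0)
d_in, d_sae = 16, 64
W_enc = torch.randn(d_in, d_sae)
b_enc = torch.randn(d_sae)
W_dec = torch.randn(d_sae, d_in)
x = torch.randn(4, d_in)

# Training-time encode with rescale_acts_by_decoder_norm=True:
# pre-activations are scaled by per-feature decoder norms before TopK.
hidden_pre = (x @ W_enc + b_enc) * W_dec.norm(dim=-1)

# Apply the same transformation as _fold_norm_topk to copies of the weights.
W_dec_norms = W_dec.norm(dim=-1).unsqueeze(1)
b_enc_folded = b_enc * W_dec_norms.squeeze(1)
W_dec_folded = W_dec / W_dec_norms
W_enc_folded = W_enc * W_dec_norms.T

# A plain encode with the folded weights (no rescaling) matches, so the same
# features land in the top k after saving.
hidden_pre_folded = x @ W_enc_folded + b_enc_folded
assert torch.allclose(hidden_pre, hidden_pre_folded, atol=1e-5)

# The folded decoder rows are unit norm, matching the "decoder treated as
# already normalized" assumption behind the rescaling.
assert torch.allclose(W_dec_folded.norm(dim=-1), torch.ones(d_sae), atol=1e-5)
```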