sae-lens 6.14.2__tar.gz → 6.16.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.14.2 → sae_lens-6.16.0}/PKG-INFO +1 -1
- {sae_lens-6.14.2 → sae_lens-6.16.0}/pyproject.toml +1 -1
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/__init__.py +10 -1
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/llm_sae_training_runner.py +3 -11
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/saes/__init__.py +6 -0
- sae_lens-6.16.0/sae_lens/saes/matryoshka_batchtopk_sae.py +143 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/saes/sae.py +1 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/training/sae_trainer.py +1 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/LICENSE +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/README.md +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/config.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.14.2 → sae_lens-6.16.0}/sae_lens/util.py +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
# ruff: noqa: E402
|
|
2
|
-
__version__ = "6.14.2"
|
|
2
|
+
__version__ = "6.16.0"
|
|
3
3
|
|
|
4
4
|
import logging
|
|
5
5
|
|
|
@@ -19,6 +19,8 @@ from sae_lens.saes import (
|
|
|
19
19
|
JumpReLUTrainingSAEConfig,
|
|
20
20
|
JumpReLUTranscoder,
|
|
21
21
|
JumpReLUTranscoderConfig,
|
|
22
|
+
MatryoshkaBatchTopKTrainingSAE,
|
|
23
|
+
MatryoshkaBatchTopKTrainingSAEConfig,
|
|
22
24
|
SAEConfig,
|
|
23
25
|
SkipTranscoder,
|
|
24
26
|
SkipTranscoderConfig,
|
|
@@ -101,6 +103,8 @@ __all__ = [
|
|
|
101
103
|
"SkipTranscoderConfig",
|
|
102
104
|
"JumpReLUTranscoder",
|
|
103
105
|
"JumpReLUTranscoderConfig",
|
|
106
|
+
"MatryoshkaBatchTopKTrainingSAE",
|
|
107
|
+
"MatryoshkaBatchTopKTrainingSAEConfig",
|
|
104
108
|
]
|
|
105
109
|
|
|
106
110
|
|
|
@@ -115,6 +119,11 @@ register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAE
|
|
|
115
119
|
register_sae_training_class(
|
|
116
120
|
"batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
|
|
117
121
|
)
|
|
122
|
+
register_sae_training_class(
|
|
123
|
+
"matryoshka_batchtopk",
|
|
124
|
+
MatryoshkaBatchTopKTrainingSAE,
|
|
125
|
+
MatryoshkaBatchTopKTrainingSAEConfig,
|
|
126
|
+
)
|
|
118
127
|
register_sae_class("transcoder", Transcoder, TranscoderConfig)
|
|
119
128
|
register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
|
|
120
129
|
register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
|
|
@@ -22,17 +22,13 @@ from sae_lens.constants import (
|
|
|
22
22
|
)
|
|
23
23
|
from sae_lens.evals import EvalConfig, run_evals
|
|
24
24
|
from sae_lens.load_model import load_model
|
|
25
|
-
from sae_lens.
|
|
26
|
-
from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
|
|
27
|
-
from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
|
|
25
|
+
from sae_lens.registry import SAE_TRAINING_CLASS_REGISTRY
|
|
28
26
|
from sae_lens.saes.sae import (
|
|
29
27
|
T_TRAINING_SAE,
|
|
30
28
|
T_TRAINING_SAE_CONFIG,
|
|
31
29
|
TrainingSAE,
|
|
32
30
|
TrainingSAEConfig,
|
|
33
31
|
)
|
|
34
|
-
from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
|
|
35
|
-
from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
|
|
36
32
|
from sae_lens.training.activation_scaler import ActivationScaler
|
|
37
33
|
from sae_lens.training.activations_store import ActivationsStore
|
|
38
34
|
from sae_lens.training.sae_trainer import SAETrainer
|
|
@@ -395,12 +391,8 @@ def _parse_cfg_args(
|
|
|
395
391
|
)
|
|
396
392
|
|
|
397
393
|
# Map architecture to concrete config class
|
|
398
|
-
sae_config_map = {
|
|
399
|
-
|
|
400
|
-
"gated": GatedTrainingSAEConfig,
|
|
401
|
-
"jumprelu": JumpReLUTrainingSAEConfig,
|
|
402
|
-
"topk": TopKTrainingSAEConfig,
|
|
403
|
-
"batchtopk": BatchTopKTrainingSAEConfig,
|
|
394
|
+
sae_config_map: dict[str, type[TrainingSAEConfig]] = {
|
|
395
|
+
name: cfg for name, (_, cfg) in SAE_TRAINING_CLASS_REGISTRY.items()
|
|
404
396
|
}
|
|
405
397
|
|
|
406
398
|
sae_config_type = sae_config_map[architecture]
|
|
@@ -14,6 +14,10 @@ from .jumprelu_sae import (
|
|
|
14
14
|
JumpReLUTrainingSAE,
|
|
15
15
|
JumpReLUTrainingSAEConfig,
|
|
16
16
|
)
|
|
17
|
+
from .matryoshka_batchtopk_sae import (
|
|
18
|
+
MatryoshkaBatchTopKTrainingSAE,
|
|
19
|
+
MatryoshkaBatchTopKTrainingSAEConfig,
|
|
20
|
+
)
|
|
17
21
|
from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
|
|
18
22
|
from .standard_sae import (
|
|
19
23
|
StandardSAE,
|
|
@@ -65,4 +69,6 @@ __all__ = [
|
|
|
65
69
|
"SkipTranscoderConfig",
|
|
66
70
|
"JumpReLUTranscoder",
|
|
67
71
|
"JumpReLUTranscoderConfig",
|
|
72
|
+
"MatryoshkaBatchTopKTrainingSAE",
|
|
73
|
+
"MatryoshkaBatchTopKTrainingSAEConfig",
|
|
68
74
|
]
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from jaxtyping import Float
|
|
6
|
+
from typing_extensions import override
|
|
7
|
+
|
|
8
|
+
from sae_lens.saes.batchtopk_sae import (
|
|
9
|
+
BatchTopKTrainingSAE,
|
|
10
|
+
BatchTopKTrainingSAEConfig,
|
|
11
|
+
)
|
|
12
|
+
from sae_lens.saes.sae import TrainStepInput, TrainStepOutput
|
|
13
|
+
from sae_lens.saes.topk_sae import _sparse_matmul_nd
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class MatryoshkaBatchTopKTrainingSAEConfig(BatchTopKTrainingSAEConfig):
    """
    Configuration class for training a MatryoshkaBatchTopKTrainingSAE.

    [Matryoshka SAEs](https://arxiv.org/pdf/2503.17547) use a series of nested reconstruction
    losses of different widths during training to avoid feature absorption. This also has a
    nice side-effect of encouraging higher-frequency features to be learned in earlier levels.
    However, this SAE has more hyperparameters to tune than standard BatchTopK SAEs, and takes
    longer to train due to requiring multiple forward passes per training step.

    After training, MatryoshkaBatchTopK SAEs are saved as JumpReLU SAEs.

    Args:
        matryoshka_widths (list[int]): The widths of the matryoshka levels. Each level uses
            the first `width` latents of the SAE. The widths must be strictly increasing and
            the smallest width must not be smaller than `k`. If the last width is not equal
            to `d_sae`, a final level of width `d_sae` is appended (with a warning) when the
            SAE is constructed. Defaults to an empty list.
        k (float): The number of features to keep active. Inherited from BatchTopKTrainingSAEConfig.
            Defaults to 100.
        topk_threshold_lr (float): Learning rate for updating the global topk threshold.
            The threshold is updated using an exponential moving average of the minimum
            positive activation value. Defaults to 0.01.
        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
            dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
            Defaults to 1.0.
        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
            Inherited from TopKTrainingSAEConfig. Defaults to True.
        decoder_init_norm (float | None): Norm to initialize decoder weights to.
            Inherited from TrainingSAEConfig. Defaults to 0.1.
        d_in (int): Input dimension (dimensionality of the activations being encoded).
            Inherited from SAEConfig.
        d_sae (int): SAE latent dimension (number of features in the SAE).
            Inherited from SAEConfig.
        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
            Defaults to "float32".
        device (str): Device to place the SAE on. Inherited from SAEConfig.
            Defaults to "cpu".
    """

    # default_factory gives every config instance its own list; the validator run at
    # SAE construction may append a final `d_sae` level to this list in place
    matryoshka_widths: list[int] = field(default_factory=list)

    @override
    @classmethod
    def architecture(cls) -> str:
        # name under which this architecture is registered for training
        return "matryoshka_batchtopk"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
    """
    Matryoshka BatchTopK Training SAE

    Extends the BatchTopK training SAE with nested ("matryoshka") reconstruction losses.
    For every width in `cfg.matryoshka_widths` except the last, the first `width` latents
    are activated and decoded on their own, and the resulting MSE reconstruction loss is
    added to the base training loss. The last level is the base SAE itself.

    As with BatchTopK SAEs, these are saved as JumpReLU SAEs after training.
    """

    cfg: MatryoshkaBatchTopKTrainingSAEConfig  # type: ignore[assignment]

    def __init__(
        self, cfg: MatryoshkaBatchTopKTrainingSAEConfig, use_error_term: bool = False
    ):
        super().__init__(cfg, use_error_term)
        # checks the level widths; may append a final level of width d_sae to cfg
        _validate_matryoshka_config(cfg)

    @override
    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
        """
        Run the base training forward pass, then add one MSE loss per inner matryoshka level.

        Each inner loss is logged in `losses` under `inner_mse_loss_{width}` and added to
        the total `loss` of the returned output.
        """
        output = super().training_forward_pass(step_input)
        pre_acts = output.hidden_pre
        target = step_input.sae_in
        inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)

        # the last (outermost) level is the base SAE, whose loss is already in `output`
        inner_widths = self.cfg.matryoshka_widths[:-1]
        for width in inner_widths:
            # activate and decode only the first `width` latents
            level_acts = self.activation_fn(pre_acts[:, :width])
            level_recon = self._decode_matryoshka_level(
                level_acts, width, inv_W_dec_norm
            )
            per_element = self.mse_loss_fn(level_recon, target)
            level_loss = per_element.sum(dim=-1).mean()
            output.losses[f"inner_mse_loss_{width}"] = level_loss
            output.loss = output.loss + level_loss
        return output

    def _decode_matryoshka_level(
        self,
        feature_acts: Float[torch.Tensor, "... d_sae"],
        width: int,
        inv_W_dec_norm: torch.Tensor,
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Decodes feature activations back into input space for a matryoshka level,
        using only the first `width` rows of the decoder.
        """
        level_decoder = self.W_dec[:width]

        # multiply by the precomputed inverse norm rather than dividing, because
        # division is illegal with sparse tensors
        scaled_acts = (
            feature_acts * inv_W_dec_norm[:width]
            if self.cfg.rescale_acts_by_decoder_norm
            else feature_acts
        )

        # Handle sparse tensors using efficient sparse matrix multiplication
        if scaled_acts.is_sparse:
            projected = _sparse_matmul_nd(scaled_acts, level_decoder)
        else:
            projected = scaled_acts @ level_decoder

        decoded = self.run_time_activation_norm_fn_out(projected + self.b_dec)
        return self.reshape_fn_out(decoded, self.d_head)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _validate_matryoshka_config(cfg: MatryoshkaBatchTopKTrainingSAEConfig) -> None:
    """
    Validate `cfg.matryoshka_widths`, normalizing it in place where needed.

    If the widths are empty or the last width is not `cfg.d_sae`, a warning is emitted
    and a final level of width `cfg.d_sae` is appended to `cfg.matryoshka_widths`.

    Args:
        cfg: The config to validate. Its `matryoshka_widths` list may be mutated.

    Raises:
        ValueError: If the widths are not strictly increasing, or if the smallest
            width is smaller than `cfg.k`.
    """
    widths = cfg.matryoshka_widths

    # `matryoshka_widths` defaults to an empty list, so guard before indexing the last
    # element; an empty list is treated like a missing final level
    if not widths or widths[-1] != cfg.d_sae:
        # warn the users that we will add a final matryoshka level
        warnings.warn(
            "WARNING: The final matryoshka level width is not set to cfg.d_sae. "
            "A final matryoshka level of width=cfg.d_sae will be added."
        )
        widths.append(cfg.d_sae)

    for prev_width, curr_width in zip(widths[:-1], widths[1:]):
        if prev_width >= curr_width:
            raise ValueError("cfg.matryoshka_widths must be strictly increasing.")
    if len(widths) == 1:
        warnings.warn(
            "WARNING: You have only set one matryoshka level. This is equivalent to using a standard BatchTopK SAE and is likely not what you want."
        )
    if widths[0] < cfg.k:
        raise ValueError(
            "The smallest matryoshka level width cannot be smaller than cfg.k."
        )
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|