sae-lens 6.26.0__tar.gz → 6.28.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.26.0 → sae_lens-6.28.0}/PKG-INFO +3 -1
- {sae_lens-6.26.0 → sae_lens-6.28.0}/README.md +2 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/pyproject.toml +3 -1
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/__init__.py +3 -1
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/cache_activations_runner.py +12 -5
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/config.py +2 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/loading/pretrained_sae_loaders.py +2 -1
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/pretrained_saes.yaml +144 -144
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/gated_sae.py +1 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/jumprelu_sae.py +3 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/standard_sae.py +2 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/temporal_sae.py +1 -0
- sae_lens-6.28.0/sae_lens/synthetic/__init__.py +89 -0
- sae_lens-6.28.0/sae_lens/synthetic/activation_generator.py +215 -0
- sae_lens-6.28.0/sae_lens/synthetic/correlation.py +170 -0
- sae_lens-6.28.0/sae_lens/synthetic/evals.py +141 -0
- sae_lens-6.28.0/sae_lens/synthetic/feature_dictionary.py +138 -0
- sae_lens-6.28.0/sae_lens/synthetic/firing_probabilities.py +104 -0
- sae_lens-6.28.0/sae_lens/synthetic/hierarchy.py +335 -0
- sae_lens-6.28.0/sae_lens/synthetic/initialization.py +40 -0
- sae_lens-6.28.0/sae_lens/synthetic/plotting.py +230 -0
- sae_lens-6.28.0/sae_lens/synthetic/training.py +145 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/activations_store.py +51 -91
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/mixing_buffer.py +14 -5
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/sae_trainer.py +1 -1
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/util.py +26 -1
- {sae_lens-6.26.0 → sae_lens-6.28.0}/LICENSE +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/matching_pursuit_sae.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/tokenization_and_batching.py +1 -1
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/tutorial/tsea.py +0 -0
{sae_lens-6.26.0 → sae_lens-6.28.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.26.0
+Version: 6.28.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -77,6 +77,8 @@ The new v6 update is a major refactor to SAELens and changes the way training co
   [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
 - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
   [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
+- [Training SAEs on Synthetic Data](tutorials/training_saes_on_synthetic_data.ipynb)
+  [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_saes_on_synthetic_data.ipynb)

 ## Join the Slack!

{sae_lens-6.26.0 → sae_lens-6.28.0}/README.md

@@ -41,6 +41,8 @@ The new v6 update is a major refactor to SAELens and changes the way training co
   [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
 - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
   [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
+- [Training SAEs on Synthetic Data](tutorials/training_saes_on_synthetic_data.ipynb)
+  [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_saes_on_synthetic_data.ipynb)

 ## Join the Slack!

{sae_lens-6.26.0 → sae_lens-6.28.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.26.0"
+version = "6.28.0"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
@@ -51,12 +51,14 @@ mkdocs-redirects = "^1.2.1"
 mkdocs-section-index = "^0.3.9"
 mkdocstrings = "^0.25.2"
 mkdocstrings-python = "^1.10.9"
+beautifulsoup4 = "^4.12.0"
 tabulate = "^0.9.0"
 ruff = "^0.7.4"
 eai-sparsify = "^1.1.1"
 mike = "^2.0.0"
 trio = "^0.30.0"
 dictionary-learning = "^0.1.0"
+kaleido = "^1.2.0"

 [tool.poetry.extras]
 mamba = ["mamba-lens"]
{sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/__init__.py

@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.26.0"
+__version__ = "6.28.0"

 import logging

@@ -63,6 +63,7 @@ from .loading.pretrained_sae_loaders import (
 from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
 from .registry import register_sae_class, register_sae_training_class
 from .training.activations_store import ActivationsStore
+from .training.sae_trainer import SAETrainer
 from .training.upload_saes_to_huggingface import upload_saes_to_huggingface

 __all__ = [
@@ -102,6 +103,7 @@ __all__ = [
     "JumpReLUTrainingSAE",
     "JumpReLUTrainingSAEConfig",
     "SAETrainingRunner",
+    "SAETrainer",
     "LoggingConfig",
     "BatchTopKTrainingSAE",
     "BatchTopKTrainingSAEConfig",
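With this change, `SAETrainer` is re-exported from the package root alongside `SAETrainingRunner`. A quick illustrative check of the new import surface (not taken from the release's docs):

```python
# Both import paths refer to the same class in 6.28.0: the root-level export is
# new, the submodule path already existed.
from sae_lens import SAETrainer
from sae_lens.training.sae_trainer import SAETrainer as SAETrainerDirect

assert SAETrainer is SAETrainerDirect
```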
{sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/cache_activations_runner.py

@@ -263,14 +263,21 @@ class CacheActivationsRunner:

         for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
             try:
-                …
-                …
-                )
-                …
+                # Accumulate n_batches_in_buffer batches into one shard
+                buffers: list[tuple[torch.Tensor, torch.Tensor | None]] = []
+                for _ in range(self.cfg.n_batches_in_buffer):
+                    buffers.append(self.activations_store.get_raw_llm_batch())
+                # Concatenate all batches
+                acts = torch.cat([b[0] for b in buffers], dim=0)
+                token_ids: torch.Tensor | None = None
+                if buffers[0][1] is not None:
+                    # All batches have token_ids if the first one does
+                    token_ids = torch.cat([b[1] for b in buffers], dim=0)  # type: ignore[arg-type]
+                shard = self._create_shard((acts, token_ids))
                 shard.save_to_disk(
                     f"{tmp_cached_activation_path}/shard_{i:05d}", num_shards=1
                 )
-                del …
+                del buffers, acts, token_ids, shard
             except StopIteration:
                 logger.warning(
                     f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.cfg.n_buffers} batches."
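The new loop body accumulates `n_batches_in_buffer` raw LLM batches and concatenates them before writing a shard. A self-contained sketch of that concatenation pattern is below; the function and variable names are illustrative, not part of the library.

```python
import torch


def concat_batches(
    batches: list[tuple[torch.Tensor, torch.Tensor | None]],
) -> tuple[torch.Tensor, torch.Tensor | None]:
    # Concatenate activations along the batch dimension.
    acts = torch.cat([acts for acts, _ in batches], dim=0)
    token_ids = None
    if batches[0][1] is not None:
        # Assume every batch carries token_ids if the first one does.
        token_ids = torch.cat([ids for _, ids in batches], dim=0)  # type: ignore[arg-type]
    return acts, token_ids


batches = [(torch.randn(4, 8), torch.randint(0, 100, (4,))) for _ in range(3)]
acts, token_ids = concat_batches(batches)
print(acts.shape, None if token_ids is None else token_ids.shape)  # (12, 8) (12,)
```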
{sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/config.py

@@ -148,6 +148,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
         disable_concat_sequences (bool): Whether to disable concatenating sequences and ignore sequences shorter than the context size. If True, disables concatenating and ignores short sequences.
         sequence_separator_token (int | Literal["bos", "eos", "sep"] | None): If not `None`, this token will be placed between sentences in a batch to act as a separator. By default, this is the `<bos>` token.
+        activations_mixing_fraction (float): Fraction of the activation buffer to keep for mixing with new activations (default 0.5). Higher values mean more temporal shuffling but slower throughput. If 0, activations are served in order without shuffling (no temporal mixing).
         device (str): The device to use. Usually "cuda".
         act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
         seed (int): The seed to use.
@@ -217,6 +218,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
         special_token_field(default="bos")
     )
+    activations_mixing_fraction: float = 0.5

     # Misc
     device: str = "cpu"
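A hedged sketch of how the new field might be set on the runner config. Only `activations_mixing_fraction` and its default of 0.5 come from this diff; the other constructor arguments (and the `StandardTrainingSAEConfig` fields) are placeholders assumed for illustration.

```python
from sae_lens import LanguageModelSAERunnerConfig, StandardTrainingSAEConfig

cfg = LanguageModelSAERunnerConfig(
    sae=StandardTrainingSAEConfig(d_in=512, d_sae=4096),  # placeholder SAE config
    # Keep 25% of the buffer for temporal mixing; 0.0 would serve activations in
    # order with no shuffling, 0.5 is the default.
    activations_mixing_fraction=0.25,
)
```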
{sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/loading/pretrained_sae_loaders.py

@@ -959,7 +959,7 @@ def get_dictionary_learning_config_1_from_hf(
     architecture = "standard"
     if trainer["dict_class"] == "GatedAutoEncoder":
         architecture = "gated"
-    elif trainer["dict_class"] …
+    elif trainer["dict_class"] in ["MatryoshkaBatchTopKSAE", "BatchTopKSAE"]:
         architecture = "jumprelu"

     return {
@@ -1831,6 +1831,7 @@ def temporal_sae_huggingface_loader(
     Load TemporalSAE from canrager/temporalSAEs format (safetensors version).

     Expects folder_name to contain:
+
     - conf.yaml (configuration)
     - latest_ckpt.safetensors (model weights)
     """
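For readability, the new branching can be restated as a standalone helper. This is an illustrative restatement of the logic shown in the hunk, not a function that exists in the library.

```python
def detect_architecture(dict_class: str) -> str:
    """Map a dictionary_learning dict_class name to an SAELens architecture."""
    if dict_class == "GatedAutoEncoder":
        return "gated"
    if dict_class in ["MatryoshkaBatchTopKSAE", "BatchTopKSAE"]:
        # BatchTopK-style dictionaries are loaded as JumpReLU SAEs for inference.
        return "jumprelu"
    return "standard"


print(detect_architecture("BatchTopKSAE"))  # jumprelu
```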
{sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/pretrained_saes.yaml

@@ -9072,150 +9072,150 @@ gemma-scope-2-27b-it-transcoders-all:
   - id: layer_5_width_262k_l0_small_affine
     path: transcoder_all/layer_5_width_262k_l0_small_affine
     l0: 12
-  … [144 removed lines; their content is not shown in this diff view]
+  - id: layer_60_width_16k_l0_big
+    path: transcoder_all/layer_60_width_16k_l0_big
+    l0: 120
+  - id: layer_60_width_16k_l0_big_affine
+    path: transcoder_all/layer_60_width_16k_l0_big_affine
+    l0: 120
+  - id: layer_60_width_16k_l0_small
+    path: transcoder_all/layer_60_width_16k_l0_small
+    l0: 20
+  - id: layer_60_width_16k_l0_small_affine
+    path: transcoder_all/layer_60_width_16k_l0_small_affine
+    l0: 20
+  - id: layer_60_width_262k_l0_big
+    path: transcoder_all/layer_60_width_262k_l0_big
+    l0: 120
+  - id: layer_60_width_262k_l0_big_affine
+    path: transcoder_all/layer_60_width_262k_l0_big_affine
+    l0: 120
+  - id: layer_60_width_262k_l0_small
+    path: transcoder_all/layer_60_width_262k_l0_small
+    l0: 20
+  - id: layer_60_width_262k_l0_small_affine
+    path: transcoder_all/layer_60_width_262k_l0_small_affine
+    l0: 20
+  - id: layer_61_width_16k_l0_big
+    path: transcoder_all/layer_61_width_16k_l0_big
+    l0: 120
+  - id: layer_61_width_16k_l0_big_affine
+    path: transcoder_all/layer_61_width_16k_l0_big_affine
+    l0: 120
+  - id: layer_61_width_16k_l0_small
+    path: transcoder_all/layer_61_width_16k_l0_small
+    l0: 20
+  - id: layer_61_width_16k_l0_small_affine
+    path: transcoder_all/layer_61_width_16k_l0_small_affine
+    l0: 20
+  - id: layer_61_width_262k_l0_big
+    path: transcoder_all/layer_61_width_262k_l0_big
+    l0: 120
+  - id: layer_61_width_262k_l0_big_affine
+    path: transcoder_all/layer_61_width_262k_l0_big_affine
+    l0: 120
+  - id: layer_61_width_262k_l0_small
+    path: transcoder_all/layer_61_width_262k_l0_small
+    l0: 20
+  - id: layer_61_width_262k_l0_small_affine
+    path: transcoder_all/layer_61_width_262k_l0_small_affine
+    l0: 20
+  - id: layer_6_width_16k_l0_big
+    path: transcoder_all/layer_6_width_16k_l0_big
+    l0: 77
+  - id: layer_6_width_16k_l0_big_affine
+    path: transcoder_all/layer_6_width_16k_l0_big_affine
+    l0: 77
+  - id: layer_6_width_16k_l0_small
+    path: transcoder_all/layer_6_width_16k_l0_small
+    l0: 12
+  - id: layer_6_width_16k_l0_small_affine
+    path: transcoder_all/layer_6_width_16k_l0_small_affine
+    l0: 12
+  - id: layer_6_width_262k_l0_big
+    path: transcoder_all/layer_6_width_262k_l0_big
+    l0: 77
+  - id: layer_6_width_262k_l0_big_affine
+    path: transcoder_all/layer_6_width_262k_l0_big_affine
+    l0: 77
+  - id: layer_6_width_262k_l0_small
+    path: transcoder_all/layer_6_width_262k_l0_small
+    l0: 12
+  - id: layer_6_width_262k_l0_small_affine
+    path: transcoder_all/layer_6_width_262k_l0_small_affine
+    l0: 12
+  - id: layer_7_width_16k_l0_big
+    path: transcoder_all/layer_7_width_16k_l0_big
+    l0: 80
+  - id: layer_7_width_16k_l0_big_affine
+    path: transcoder_all/layer_7_width_16k_l0_big_affine
+    l0: 80
+  - id: layer_7_width_16k_l0_small
+    path: transcoder_all/layer_7_width_16k_l0_small
+    l0: 13
+  - id: layer_7_width_16k_l0_small_affine
+    path: transcoder_all/layer_7_width_16k_l0_small_affine
+    l0: 13
+  - id: layer_7_width_262k_l0_big
+    path: transcoder_all/layer_7_width_262k_l0_big
+    l0: 80
+  - id: layer_7_width_262k_l0_big_affine
+    path: transcoder_all/layer_7_width_262k_l0_big_affine
+    l0: 80
+  - id: layer_7_width_262k_l0_small
+    path: transcoder_all/layer_7_width_262k_l0_small
+    l0: 13
+  - id: layer_7_width_262k_l0_small_affine
+    path: transcoder_all/layer_7_width_262k_l0_small_affine
+    l0: 13
+  - id: layer_8_width_16k_l0_big
+    path: transcoder_all/layer_8_width_16k_l0_big
+    l0: 83
+  - id: layer_8_width_16k_l0_big_affine
+    path: transcoder_all/layer_8_width_16k_l0_big_affine
+    l0: 83
+  - id: layer_8_width_16k_l0_small
+    path: transcoder_all/layer_8_width_16k_l0_small
+    l0: 13
+  - id: layer_8_width_16k_l0_small_affine
+    path: transcoder_all/layer_8_width_16k_l0_small_affine
+    l0: 13
+  - id: layer_8_width_262k_l0_big
+    path: transcoder_all/layer_8_width_262k_l0_big
+    l0: 83
+  - id: layer_8_width_262k_l0_big_affine
+    path: transcoder_all/layer_8_width_262k_l0_big_affine
+    l0: 83
+  - id: layer_8_width_262k_l0_small
+    path: transcoder_all/layer_8_width_262k_l0_small
+    l0: 13
+  - id: layer_8_width_262k_l0_small_affine
+    path: transcoder_all/layer_8_width_262k_l0_small_affine
+    l0: 13
+  - id: layer_9_width_16k_l0_big
+    path: transcoder_all/layer_9_width_16k_l0_big
+    l0: 86
+  - id: layer_9_width_16k_l0_big_affine
+    path: transcoder_all/layer_9_width_16k_l0_big_affine
+    l0: 86
+  - id: layer_9_width_16k_l0_small
+    path: transcoder_all/layer_9_width_16k_l0_small
+    l0: 14
+  - id: layer_9_width_16k_l0_small_affine
+    path: transcoder_all/layer_9_width_16k_l0_small_affine
+    l0: 14
+  - id: layer_9_width_262k_l0_big
+    path: transcoder_all/layer_9_width_262k_l0_big
+    l0: 86
+  - id: layer_9_width_262k_l0_big_affine
+    path: transcoder_all/layer_9_width_262k_l0_big_affine
+    l0: 86
+  - id: layer_9_width_262k_l0_small
+    path: transcoder_all/layer_9_width_262k_l0_small
+    l0: 14
+  - id: layer_9_width_262k_l0_small_affine
+    path: transcoder_all/layer_9_width_262k_l0_small_affine
+    l0: 14
 gemma-scope-2-27b-it-transcoders:
   conversion_func: gemma_3
   model: google/gemma-3-27b-it
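Entries in pretrained_saes.yaml are addressed by release name and id. A hedged sketch of loading one of the entries listed above: the release and `sae_id` strings come from the YAML, but the exact shape of `SAE.from_pretrained` (whether it returns the SAE alone or a tuple) depends on the installed SAELens version and is an assumption here.

```python
from sae_lens import SAE

sae = SAE.from_pretrained(
    release="gemma-scope-2-27b-it-transcoders-all",
    sae_id="layer_60_width_16k_l0_small",  # l0: 20 per the YAML above
    device="cpu",
)
```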
{sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/gated_sae.py

@@ -118,6 +118,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     """
     GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
     It implements:
+
     - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
     - encode: calls encode_with_hidden_pre (standard training approach).
     - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
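The docstring above refers to gating parameters. A conceptual sketch of a gated encoder under simplified assumptions (a shared encoder matrix and no magnitude rescaling), not the library's implementation:

```python
import torch

d_in, d_sae = 16, 64
x = torch.randn(2, d_in)
W_enc = torch.randn(d_in, d_sae)
b_gate, b_mag, b_dec = torch.zeros(d_sae), torch.zeros(d_sae), torch.zeros(d_in)

pre = (x - b_dec) @ W_enc
gate = (pre + b_gate) > 0        # gate path: which features are active
mag = torch.relu(pre + b_mag)    # magnitude path: how strongly they fire
feats = gate * mag               # gated feature activations
```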
{sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/jumprelu_sae.py

@@ -105,6 +105,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
     activation function (e.g., ReLU etc.).

     It implements:
+
     - initialize_weights: sets up parameters, including a threshold.
     - encode: computes the feature activations using JumpReLU.
     - decode: reconstructs the input from the feature activations.
@@ -216,10 +217,12 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
     JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.

     Similar to the inference-only JumpReLUSAE, but with:
+
     - A learnable log-threshold parameter (instead of a raw threshold).
     - A specialized auxiliary loss term for sparsity (L0 or similar).

     Methods of interest include:
+
     - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
     - encode_with_hidden_pre_jumprelu: runs a forward pass for training.
     - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.
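A minimal illustration of the learnable log-threshold mentioned above: parameterizing the threshold as exp(log_threshold) keeps it positive during optimization, and JumpReLU passes pre-activations only where they exceed the threshold. This is a conceptual sketch, not the library's training code.

```python
import torch

log_threshold = torch.nn.Parameter(torch.zeros(8))  # threshold starts at exp(0) = 1
pre_acts = torch.randn(4, 8)

threshold = log_threshold.exp()
feats = pre_acts * (pre_acts > threshold)  # JumpReLU: keep values above threshold, else 0
```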
{sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/standard_sae.py

@@ -34,6 +34,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
     using a simple linear encoder and decoder.

     It implements the required abstract methods from BaseSAE:
+
     - initialize_weights: sets up simple parameter initializations for W_enc, b_enc, W_dec, and b_dec.
     - encode: computes the feature activations from an input.
     - decode: reconstructs the input from the feature activations.
@@ -99,6 +100,7 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
     """
     StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture.
     It implements:
+
     - initialize_weights: basic weight initialization for encoder/decoder.
     - encode: inference encoding (invokes encode_with_hidden_pre).
     - decode: a simple linear decoder.
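A minimal sketch of the standard encode/decode pair the docstring lists; the handling of the decoder bias and any normalization is an assumption here, not the library's exact implementation.

```python
import torch

d_in, d_sae = 16, 64
W_enc, b_enc = torch.randn(d_in, d_sae), torch.zeros(d_sae)
W_dec, b_dec = torch.randn(d_sae, d_in), torch.zeros(d_in)

x = torch.randn(2, d_in)
feats = torch.relu((x - b_dec) @ W_enc + b_enc)  # encode: sparse feature activations
x_hat = feats @ W_dec + b_dec                    # decode: linear reconstruction
```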
{sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/temporal_sae.py

@@ -167,6 +167,7 @@ class TemporalSAE(SAE[TemporalSAEConfig]):
     """TemporalSAE: Sparse Autoencoder with temporal attention.

     This SAE decomposes each activation x_t into:
+
     - x_pred: Information aggregated from context {x_0, ..., x_{t-1}}
     - x_novel: Novel information at position t (encoded sparsely)

sae_lens-6.28.0/sae_lens/synthetic/__init__.py (new file)

@@ -0,0 +1,89 @@
+"""
+Synthetic data utilities for SAE experiments.
+
+This module provides tools for creating feature dictionaries and generating
+synthetic activations for testing and experimenting with SAEs.
+
+Main components:
+
+- FeatureDictionary: Maps sparse feature activations to dense hidden activations
+- ActivationGenerator: Generates batches of synthetic feature activations
+- HierarchyNode: Enforces hierarchical structure on feature activations
+- Training utilities: Helpers for training and evaluating SAEs on synthetic data
+- Plotting utilities: Visualization helpers for understanding SAE behavior
+"""
+
+from sae_lens.synthetic.activation_generator import (
+    ActivationGenerator,
+    ActivationsModifier,
+    ActivationsModifierInput,
+)
+from sae_lens.synthetic.correlation import (
+    create_correlation_matrix_from_correlations,
+    generate_random_correlation_matrix,
+    generate_random_correlations,
+)
+from sae_lens.synthetic.evals import (
+    SyntheticDataEvalResult,
+    eval_sae_on_synthetic_data,
+    mean_correlation_coefficient,
+)
+from sae_lens.synthetic.feature_dictionary import (
+    FeatureDictionary,
+    FeatureDictionaryInitializer,
+    orthogonal_initializer,
+    orthogonalize_embeddings,
+)
+from sae_lens.synthetic.firing_probabilities import (
+    linear_firing_probabilities,
+    random_firing_probabilities,
+    zipfian_firing_probabilities,
+)
+from sae_lens.synthetic.hierarchy import HierarchyNode, hierarchy_modifier
+from sae_lens.synthetic.initialization import init_sae_to_match_feature_dict
+from sae_lens.synthetic.plotting import (
+    find_best_feature_ordering,
+    find_best_feature_ordering_across_saes,
+    find_best_feature_ordering_from_sae,
+    plot_sae_feature_similarity,
+)
+from sae_lens.synthetic.training import (
+    SyntheticActivationIterator,
+    train_toy_sae,
+)
+from sae_lens.util import cosine_similarities
+
+__all__ = [
+    # Main classes
+    "FeatureDictionary",
+    "HierarchyNode",
+    "hierarchy_modifier",
+    "ActivationGenerator",
+    # Activation generation
+    "zipfian_firing_probabilities",
+    "linear_firing_probabilities",
+    "random_firing_probabilities",
+    "create_correlation_matrix_from_correlations",
+    "generate_random_correlations",
+    "generate_random_correlation_matrix",
+    # Feature modifiers
+    "ActivationsModifier",
+    "ActivationsModifierInput",
+    # Utilities
+    "orthogonalize_embeddings",
+    "orthogonal_initializer",
+    "FeatureDictionaryInitializer",
+    "cosine_similarities",
+    # Training utilities
+    "SyntheticActivationIterator",
+    "SyntheticDataEvalResult",
+    "train_toy_sae",
+    "eval_sae_on_synthetic_data",
+    "mean_correlation_coefficient",
+    "init_sae_to_match_feature_dict",
+    # Plotting utilities
+    "find_best_feature_ordering",
+    "find_best_feature_ordering_from_sae",
+    "find_best_feature_ordering_across_saes",
+    "plot_sae_feature_similarity",
+]
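The names below are exactly what the new sae_lens.synthetic package exports in 6.28.0 (taken from its `__all__`); the comments restate their roles from the module docstring. No call signatures are shown because they are not part of this diff.

```python
from sae_lens.synthetic import (
    ActivationGenerator,          # generates batches of synthetic feature activations
    FeatureDictionary,            # maps sparse feature activations to dense hidden activations
    HierarchyNode,                # enforces hierarchical structure on feature activations
    eval_sae_on_synthetic_data,   # evaluation helper for SAEs trained on synthetic data
    train_toy_sae,                # training helper for toy/synthetic experiments
    zipfian_firing_probabilities, # firing-probability schedule for synthetic features
)
```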