sae-lens 6.26.0__tar.gz → 6.28.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52):
  1. {sae_lens-6.26.0 → sae_lens-6.28.0}/PKG-INFO +3 -1
  2. {sae_lens-6.26.0 → sae_lens-6.28.0}/README.md +2 -0
  3. {sae_lens-6.26.0 → sae_lens-6.28.0}/pyproject.toml +3 -1
  4. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/__init__.py +3 -1
  5. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/cache_activations_runner.py +12 -5
  6. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/config.py +2 -0
  7. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/loading/pretrained_sae_loaders.py +2 -1
  8. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/pretrained_saes.yaml +144 -144
  9. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/gated_sae.py +1 -0
  10. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/jumprelu_sae.py +3 -0
  11. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/standard_sae.py +2 -0
  12. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/temporal_sae.py +1 -0
  13. sae_lens-6.28.0/sae_lens/synthetic/__init__.py +89 -0
  14. sae_lens-6.28.0/sae_lens/synthetic/activation_generator.py +215 -0
  15. sae_lens-6.28.0/sae_lens/synthetic/correlation.py +170 -0
  16. sae_lens-6.28.0/sae_lens/synthetic/evals.py +141 -0
  17. sae_lens-6.28.0/sae_lens/synthetic/feature_dictionary.py +138 -0
  18. sae_lens-6.28.0/sae_lens/synthetic/firing_probabilities.py +104 -0
  19. sae_lens-6.28.0/sae_lens/synthetic/hierarchy.py +335 -0
  20. sae_lens-6.28.0/sae_lens/synthetic/initialization.py +40 -0
  21. sae_lens-6.28.0/sae_lens/synthetic/plotting.py +230 -0
  22. sae_lens-6.28.0/sae_lens/synthetic/training.py +145 -0
  23. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/activations_store.py +51 -91
  24. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/mixing_buffer.py +14 -5
  25. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/sae_trainer.py +1 -1
  26. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/util.py +26 -1
  27. {sae_lens-6.26.0 → sae_lens-6.28.0}/LICENSE +0 -0
  28. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/analysis/__init__.py +0 -0
  29. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  30. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  31. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/constants.py +0 -0
  32. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/evals.py +0 -0
  33. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/llm_sae_training_runner.py +0 -0
  34. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/load_model.py +0 -0
  35. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/loading/__init__.py +0 -0
  36. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  37. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/pretokenize_runner.py +0 -0
  38. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/registry.py +0 -0
  39. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/__init__.py +0 -0
  40. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  41. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/matching_pursuit_sae.py +0 -0
  42. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
  43. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/sae.py +0 -0
  44. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/topk_sae.py +0 -0
  45. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/saes/transcoder.py +0 -0
  46. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/tokenization_and_batching.py +1 -1
  47. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/__init__.py +0 -0
  48. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/activation_scaler.py +0 -0
  49. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/optim.py +0 -0
  50. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/types.py +0 -0
  51. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  52. {sae_lens-6.26.0 → sae_lens-6.28.0}/sae_lens/tutorial/tsea.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.26.0
3
+ Version: 6.28.0
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -77,6 +77,8 @@ The new v6 update is a major refactor to SAELens and changes the way training co
77
77
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
78
78
  - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
79
79
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
80
+ - [Training SAEs on Synthetic Data](tutorials/training_saes_on_synthetic_data.ipynb)
81
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_saes_on_synthetic_data.ipynb)
80
82
 
81
83
  ## Join the Slack!
82
84
 
@@ -41,6 +41,8 @@ The new v6 update is a major refactor to SAELens and changes the way training co
41
41
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
42
42
  - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
43
43
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
44
+ - [Training SAEs on Synthetic Data](tutorials/training_saes_on_synthetic_data.ipynb)
45
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_saes_on_synthetic_data.ipynb)
44
46
 
45
47
  ## Join the Slack!
46
48
 
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "sae-lens"
3
- version = "6.26.0"
3
+ version = "6.28.0"
4
4
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
5
5
  authors = ["Joseph Bloom"]
6
6
  readme = "README.md"
@@ -51,12 +51,14 @@ mkdocs-redirects = "^1.2.1"
51
51
  mkdocs-section-index = "^0.3.9"
52
52
  mkdocstrings = "^0.25.2"
53
53
  mkdocstrings-python = "^1.10.9"
54
+ beautifulsoup4 = "^4.12.0"
54
55
  tabulate = "^0.9.0"
55
56
  ruff = "^0.7.4"
56
57
  eai-sparsify = "^1.1.1"
57
58
  mike = "^2.0.0"
58
59
  trio = "^0.30.0"
59
60
  dictionary-learning = "^0.1.0"
61
+ kaleido = "^1.2.0"
60
62
 
61
63
  [tool.poetry.extras]
62
64
  mamba = ["mamba-lens"]
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.26.0"
2
+ __version__ = "6.28.0"
3
3
 
4
4
  import logging
5
5
 
@@ -63,6 +63,7 @@ from .loading.pretrained_sae_loaders import (
63
63
  from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
64
64
  from .registry import register_sae_class, register_sae_training_class
65
65
  from .training.activations_store import ActivationsStore
66
+ from .training.sae_trainer import SAETrainer
66
67
  from .training.upload_saes_to_huggingface import upload_saes_to_huggingface
67
68
 
68
69
  __all__ = [
@@ -102,6 +103,7 @@ __all__ = [
102
103
  "JumpReLUTrainingSAE",
103
104
  "JumpReLUTrainingSAEConfig",
104
105
  "SAETrainingRunner",
106
+ "SAETrainer",
105
107
  "LoggingConfig",
106
108
  "BatchTopKTrainingSAE",
107
109
  "BatchTopKTrainingSAEConfig",
@@ -263,14 +263,21 @@ class CacheActivationsRunner:
263
263
 
264
264
  for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
265
265
  try:
266
- buffer = self.activations_store.get_raw_buffer(
267
- self.cfg.n_batches_in_buffer, shuffle=False
268
- )
269
- shard = self._create_shard(buffer)
266
+ # Accumulate n_batches_in_buffer batches into one shard
267
+ buffers: list[tuple[torch.Tensor, torch.Tensor | None]] = []
268
+ for _ in range(self.cfg.n_batches_in_buffer):
269
+ buffers.append(self.activations_store.get_raw_llm_batch())
270
+ # Concatenate all batches
271
+ acts = torch.cat([b[0] for b in buffers], dim=0)
272
+ token_ids: torch.Tensor | None = None
273
+ if buffers[0][1] is not None:
274
+ # All batches have token_ids if the first one does
275
+ token_ids = torch.cat([b[1] for b in buffers], dim=0) # type: ignore[arg-type]
276
+ shard = self._create_shard((acts, token_ids))
270
277
  shard.save_to_disk(
271
278
  f"{tmp_cached_activation_path}/shard_{i:05d}", num_shards=1
272
279
  )
273
- del buffer, shard
280
+ del buffers, acts, token_ids, shard
274
281
  except StopIteration:
275
282
  logger.warning(
276
283
  f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.cfg.n_buffers} batches."
@@ -148,6 +148,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
148
148
  seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
149
149
  disable_concat_sequences (bool): Whether to disable concatenating sequences and ignore sequences shorter than the context size. If True, disables concatenating and ignores short sequences.
150
150
  sequence_separator_token (int | Literal["bos", "eos", "sep"] | None): If not `None`, this token will be placed between sentences in a batch to act as a separator. By default, this is the `<bos>` token.
151
+ activations_mixing_fraction (float): Fraction of the activation buffer to keep for mixing with new activations (default 0.5). Higher values mean more temporal shuffling but slower throughput. If 0, activations are served in order without shuffling (no temporal mixing).
151
152
  device (str): The device to use. Usually "cuda".
152
153
  act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
153
154
  seed (int): The seed to use.
@@ -217,6 +218,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
217
218
  sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
218
219
  special_token_field(default="bos")
219
220
  )
221
+ activations_mixing_fraction: float = 0.5
220
222
 
221
223
  # Misc
222
224
  device: str = "cpu"
@@ -959,7 +959,7 @@ def get_dictionary_learning_config_1_from_hf(
959
959
  architecture = "standard"
960
960
  if trainer["dict_class"] == "GatedAutoEncoder":
961
961
  architecture = "gated"
962
- elif trainer["dict_class"] == "MatryoshkaBatchTopKSAE":
962
+ elif trainer["dict_class"] in ["MatryoshkaBatchTopKSAE", "BatchTopKSAE"]:
963
963
  architecture = "jumprelu"
964
964
 
965
965
  return {
@@ -1831,6 +1831,7 @@ def temporal_sae_huggingface_loader(
1831
1831
  Load TemporalSAE from canrager/temporalSAEs format (safetensors version).
1832
1832
 
1833
1833
  Expects folder_name to contain:
1834
+
1834
1835
  - conf.yaml (configuration)
1835
1836
  - latest_ckpt.safetensors (model weights)
1836
1837
  """
@@ -9072,150 +9072,150 @@ gemma-scope-2-27b-it-transcoders-all:
9072
9072
  - id: layer_5_width_262k_l0_small_affine
9073
9073
  path: transcoder_all/layer_5_width_262k_l0_small_affine
9074
9074
  l0: 12
9075
- # - id: layer_60_width_16k_l0_big
9076
- # path: transcoder_all/layer_60_width_16k_l0_big
9077
- # l0: 120
9078
- # - id: layer_60_width_16k_l0_big_affine
9079
- # path: transcoder_all/layer_60_width_16k_l0_big_affine
9080
- # l0: 120
9081
- # - id: layer_60_width_16k_l0_small
9082
- # path: transcoder_all/layer_60_width_16k_l0_small
9083
- # l0: 20
9084
- # - id: layer_60_width_16k_l0_small_affine
9085
- # path: transcoder_all/layer_60_width_16k_l0_small_affine
9086
- # l0: 20
9087
- # - id: layer_60_width_262k_l0_big
9088
- # path: transcoder_all/layer_60_width_262k_l0_big
9089
- # l0: 120
9090
- # - id: layer_60_width_262k_l0_big_affine
9091
- # path: transcoder_all/layer_60_width_262k_l0_big_affine
9092
- # l0: 120
9093
- # - id: layer_60_width_262k_l0_small
9094
- # path: transcoder_all/layer_60_width_262k_l0_small
9095
- # l0: 20
9096
- # - id: layer_60_width_262k_l0_small_affine
9097
- # path: transcoder_all/layer_60_width_262k_l0_small_affine
9098
- # l0: 20
9099
- # - id: layer_61_width_16k_l0_big
9100
- # path: transcoder_all/layer_61_width_16k_l0_big
9101
- # l0: 120
9102
- # - id: layer_61_width_16k_l0_big_affine
9103
- # path: transcoder_all/layer_61_width_16k_l0_big_affine
9104
- # l0: 120
9105
- # - id: layer_61_width_16k_l0_small
9106
- # path: transcoder_all/layer_61_width_16k_l0_small
9107
- # l0: 20
9108
- # - id: layer_61_width_16k_l0_small_affine
9109
- # path: transcoder_all/layer_61_width_16k_l0_small_affine
9110
- # l0: 20
9111
- # - id: layer_61_width_262k_l0_big
9112
- # path: transcoder_all/layer_61_width_262k_l0_big
9113
- # l0: 120
9114
- # - id: layer_61_width_262k_l0_big_affine
9115
- # path: transcoder_all/layer_61_width_262k_l0_big_affine
9116
- # l0: 120
9117
- # - id: layer_61_width_262k_l0_small
9118
- # path: transcoder_all/layer_61_width_262k_l0_small
9119
- # l0: 20
9120
- # - id: layer_61_width_262k_l0_small_affine
9121
- # path: transcoder_all/layer_61_width_262k_l0_small_affine
9122
- # l0: 20
9123
- # - id: layer_6_width_16k_l0_big
9124
- # path: transcoder_all/layer_6_width_16k_l0_big
9125
- # l0: 77
9126
- # - id: layer_6_width_16k_l0_big_affine
9127
- # path: transcoder_all/layer_6_width_16k_l0_big_affine
9128
- # l0: 77
9129
- # - id: layer_6_width_16k_l0_small
9130
- # path: transcoder_all/layer_6_width_16k_l0_small
9131
- # l0: 12
9132
- # - id: layer_6_width_16k_l0_small_affine
9133
- # path: transcoder_all/layer_6_width_16k_l0_small_affine
9134
- # l0: 12
9135
- # - id: layer_6_width_262k_l0_big
9136
- # path: transcoder_all/layer_6_width_262k_l0_big
9137
- # l0: 77
9138
- # - id: layer_6_width_262k_l0_big_affine
9139
- # path: transcoder_all/layer_6_width_262k_l0_big_affine
9140
- # l0: 77
9141
- # - id: layer_6_width_262k_l0_small
9142
- # path: transcoder_all/layer_6_width_262k_l0_small
9143
- # l0: 12
9144
- # - id: layer_6_width_262k_l0_small_affine
9145
- # path: transcoder_all/layer_6_width_262k_l0_small_affine
9146
- # l0: 12
9147
- # - id: layer_7_width_16k_l0_big
9148
- # path: transcoder_all/layer_7_width_16k_l0_big
9149
- # l0: 80
9150
- # - id: layer_7_width_16k_l0_big_affine
9151
- # path: transcoder_all/layer_7_width_16k_l0_big_affine
9152
- # l0: 80
9153
- # - id: layer_7_width_16k_l0_small
9154
- # path: transcoder_all/layer_7_width_16k_l0_small
9155
- # l0: 13
9156
- # - id: layer_7_width_16k_l0_small_affine
9157
- # path: transcoder_all/layer_7_width_16k_l0_small_affine
9158
- # l0: 13
9159
- # - id: layer_7_width_262k_l0_big
9160
- # path: transcoder_all/layer_7_width_262k_l0_big
9161
- # l0: 80
9162
- # - id: layer_7_width_262k_l0_big_affine
9163
- # path: transcoder_all/layer_7_width_262k_l0_big_affine
9164
- # l0: 80
9165
- # - id: layer_7_width_262k_l0_small
9166
- # path: transcoder_all/layer_7_width_262k_l0_small
9167
- # l0: 13
9168
- # - id: layer_7_width_262k_l0_small_affine
9169
- # path: transcoder_all/layer_7_width_262k_l0_small_affine
9170
- # l0: 13
9171
- # - id: layer_8_width_16k_l0_big
9172
- # path: transcoder_all/layer_8_width_16k_l0_big
9173
- # l0: 83
9174
- # - id: layer_8_width_16k_l0_big_affine
9175
- # path: transcoder_all/layer_8_width_16k_l0_big_affine
9176
- # l0: 83
9177
- # - id: layer_8_width_16k_l0_small
9178
- # path: transcoder_all/layer_8_width_16k_l0_small
9179
- # l0: 13
9180
- # - id: layer_8_width_16k_l0_small_affine
9181
- # path: transcoder_all/layer_8_width_16k_l0_small_affine
9182
- # l0: 13
9183
- # - id: layer_8_width_262k_l0_big
9184
- # path: transcoder_all/layer_8_width_262k_l0_big
9185
- # l0: 83
9186
- # - id: layer_8_width_262k_l0_big_affine
9187
- # path: transcoder_all/layer_8_width_262k_l0_big_affine
9188
- # l0: 83
9189
- # - id: layer_8_width_262k_l0_small
9190
- # path: transcoder_all/layer_8_width_262k_l0_small
9191
- # l0: 13
9192
- # - id: layer_8_width_262k_l0_small_affine
9193
- # path: transcoder_all/layer_8_width_262k_l0_small_affine
9194
- # l0: 13
9195
- # - id: layer_9_width_16k_l0_big
9196
- # path: transcoder_all/layer_9_width_16k_l0_big
9197
- # l0: 86
9198
- # - id: layer_9_width_16k_l0_big_affine
9199
- # path: transcoder_all/layer_9_width_16k_l0_big_affine
9200
- # l0: 86
9201
- # - id: layer_9_width_16k_l0_small
9202
- # path: transcoder_all/layer_9_width_16k_l0_small
9203
- # l0: 14
9204
- # - id: layer_9_width_16k_l0_small_affine
9205
- # path: transcoder_all/layer_9_width_16k_l0_small_affine
9206
- # l0: 14
9207
- # - id: layer_9_width_262k_l0_big
9208
- # path: transcoder_all/layer_9_width_262k_l0_big
9209
- # l0: 86
9210
- # - id: layer_9_width_262k_l0_big_affine
9211
- # path: transcoder_all/layer_9_width_262k_l0_big_affine
9212
- # l0: 86
9213
- # - id: layer_9_width_262k_l0_small
9214
- # path: transcoder_all/layer_9_width_262k_l0_small
9215
- # l0: 14
9216
- # - id: layer_9_width_262k_l0_small_affine
9217
- # path: transcoder_all/layer_9_width_262k_l0_small_affine
9218
- # l0: 14
9075
+ - id: layer_60_width_16k_l0_big
9076
+ path: transcoder_all/layer_60_width_16k_l0_big
9077
+ l0: 120
9078
+ - id: layer_60_width_16k_l0_big_affine
9079
+ path: transcoder_all/layer_60_width_16k_l0_big_affine
9080
+ l0: 120
9081
+ - id: layer_60_width_16k_l0_small
9082
+ path: transcoder_all/layer_60_width_16k_l0_small
9083
+ l0: 20
9084
+ - id: layer_60_width_16k_l0_small_affine
9085
+ path: transcoder_all/layer_60_width_16k_l0_small_affine
9086
+ l0: 20
9087
+ - id: layer_60_width_262k_l0_big
9088
+ path: transcoder_all/layer_60_width_262k_l0_big
9089
+ l0: 120
9090
+ - id: layer_60_width_262k_l0_big_affine
9091
+ path: transcoder_all/layer_60_width_262k_l0_big_affine
9092
+ l0: 120
9093
+ - id: layer_60_width_262k_l0_small
9094
+ path: transcoder_all/layer_60_width_262k_l0_small
9095
+ l0: 20
9096
+ - id: layer_60_width_262k_l0_small_affine
9097
+ path: transcoder_all/layer_60_width_262k_l0_small_affine
9098
+ l0: 20
9099
+ - id: layer_61_width_16k_l0_big
9100
+ path: transcoder_all/layer_61_width_16k_l0_big
9101
+ l0: 120
9102
+ - id: layer_61_width_16k_l0_big_affine
9103
+ path: transcoder_all/layer_61_width_16k_l0_big_affine
9104
+ l0: 120
9105
+ - id: layer_61_width_16k_l0_small
9106
+ path: transcoder_all/layer_61_width_16k_l0_small
9107
+ l0: 20
9108
+ - id: layer_61_width_16k_l0_small_affine
9109
+ path: transcoder_all/layer_61_width_16k_l0_small_affine
9110
+ l0: 20
9111
+ - id: layer_61_width_262k_l0_big
9112
+ path: transcoder_all/layer_61_width_262k_l0_big
9113
+ l0: 120
9114
+ - id: layer_61_width_262k_l0_big_affine
9115
+ path: transcoder_all/layer_61_width_262k_l0_big_affine
9116
+ l0: 120
9117
+ - id: layer_61_width_262k_l0_small
9118
+ path: transcoder_all/layer_61_width_262k_l0_small
9119
+ l0: 20
9120
+ - id: layer_61_width_262k_l0_small_affine
9121
+ path: transcoder_all/layer_61_width_262k_l0_small_affine
9122
+ l0: 20
9123
+ - id: layer_6_width_16k_l0_big
9124
+ path: transcoder_all/layer_6_width_16k_l0_big
9125
+ l0: 77
9126
+ - id: layer_6_width_16k_l0_big_affine
9127
+ path: transcoder_all/layer_6_width_16k_l0_big_affine
9128
+ l0: 77
9129
+ - id: layer_6_width_16k_l0_small
9130
+ path: transcoder_all/layer_6_width_16k_l0_small
9131
+ l0: 12
9132
+ - id: layer_6_width_16k_l0_small_affine
9133
+ path: transcoder_all/layer_6_width_16k_l0_small_affine
9134
+ l0: 12
9135
+ - id: layer_6_width_262k_l0_big
9136
+ path: transcoder_all/layer_6_width_262k_l0_big
9137
+ l0: 77
9138
+ - id: layer_6_width_262k_l0_big_affine
9139
+ path: transcoder_all/layer_6_width_262k_l0_big_affine
9140
+ l0: 77
9141
+ - id: layer_6_width_262k_l0_small
9142
+ path: transcoder_all/layer_6_width_262k_l0_small
9143
+ l0: 12
9144
+ - id: layer_6_width_262k_l0_small_affine
9145
+ path: transcoder_all/layer_6_width_262k_l0_small_affine
9146
+ l0: 12
9147
+ - id: layer_7_width_16k_l0_big
9148
+ path: transcoder_all/layer_7_width_16k_l0_big
9149
+ l0: 80
9150
+ - id: layer_7_width_16k_l0_big_affine
9151
+ path: transcoder_all/layer_7_width_16k_l0_big_affine
9152
+ l0: 80
9153
+ - id: layer_7_width_16k_l0_small
9154
+ path: transcoder_all/layer_7_width_16k_l0_small
9155
+ l0: 13
9156
+ - id: layer_7_width_16k_l0_small_affine
9157
+ path: transcoder_all/layer_7_width_16k_l0_small_affine
9158
+ l0: 13
9159
+ - id: layer_7_width_262k_l0_big
9160
+ path: transcoder_all/layer_7_width_262k_l0_big
9161
+ l0: 80
9162
+ - id: layer_7_width_262k_l0_big_affine
9163
+ path: transcoder_all/layer_7_width_262k_l0_big_affine
9164
+ l0: 80
9165
+ - id: layer_7_width_262k_l0_small
9166
+ path: transcoder_all/layer_7_width_262k_l0_small
9167
+ l0: 13
9168
+ - id: layer_7_width_262k_l0_small_affine
9169
+ path: transcoder_all/layer_7_width_262k_l0_small_affine
9170
+ l0: 13
9171
+ - id: layer_8_width_16k_l0_big
9172
+ path: transcoder_all/layer_8_width_16k_l0_big
9173
+ l0: 83
9174
+ - id: layer_8_width_16k_l0_big_affine
9175
+ path: transcoder_all/layer_8_width_16k_l0_big_affine
9176
+ l0: 83
9177
+ - id: layer_8_width_16k_l0_small
9178
+ path: transcoder_all/layer_8_width_16k_l0_small
9179
+ l0: 13
9180
+ - id: layer_8_width_16k_l0_small_affine
9181
+ path: transcoder_all/layer_8_width_16k_l0_small_affine
9182
+ l0: 13
9183
+ - id: layer_8_width_262k_l0_big
9184
+ path: transcoder_all/layer_8_width_262k_l0_big
9185
+ l0: 83
9186
+ - id: layer_8_width_262k_l0_big_affine
9187
+ path: transcoder_all/layer_8_width_262k_l0_big_affine
9188
+ l0: 83
9189
+ - id: layer_8_width_262k_l0_small
9190
+ path: transcoder_all/layer_8_width_262k_l0_small
9191
+ l0: 13
9192
+ - id: layer_8_width_262k_l0_small_affine
9193
+ path: transcoder_all/layer_8_width_262k_l0_small_affine
9194
+ l0: 13
9195
+ - id: layer_9_width_16k_l0_big
9196
+ path: transcoder_all/layer_9_width_16k_l0_big
9197
+ l0: 86
9198
+ - id: layer_9_width_16k_l0_big_affine
9199
+ path: transcoder_all/layer_9_width_16k_l0_big_affine
9200
+ l0: 86
9201
+ - id: layer_9_width_16k_l0_small
9202
+ path: transcoder_all/layer_9_width_16k_l0_small
9203
+ l0: 14
9204
+ - id: layer_9_width_16k_l0_small_affine
9205
+ path: transcoder_all/layer_9_width_16k_l0_small_affine
9206
+ l0: 14
9207
+ - id: layer_9_width_262k_l0_big
9208
+ path: transcoder_all/layer_9_width_262k_l0_big
9209
+ l0: 86
9210
+ - id: layer_9_width_262k_l0_big_affine
9211
+ path: transcoder_all/layer_9_width_262k_l0_big_affine
9212
+ l0: 86
9213
+ - id: layer_9_width_262k_l0_small
9214
+ path: transcoder_all/layer_9_width_262k_l0_small
9215
+ l0: 14
9216
+ - id: layer_9_width_262k_l0_small_affine
9217
+ path: transcoder_all/layer_9_width_262k_l0_small_affine
9218
+ l0: 14
9219
9219
  gemma-scope-2-27b-it-transcoders:
9220
9220
  conversion_func: gemma_3
9221
9221
  model: google/gemma-3-27b-it
@@ -118,6 +118,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
118
118
  """
119
119
  GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
120
120
  It implements:
121
+
121
122
  - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
122
123
  - encode: calls encode_with_hidden_pre (standard training approach).
123
124
  - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
@@ -105,6 +105,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
105
105
  activation function (e.g., ReLU etc.).
106
106
 
107
107
  It implements:
108
+
108
109
  - initialize_weights: sets up parameters, including a threshold.
109
110
  - encode: computes the feature activations using JumpReLU.
110
111
  - decode: reconstructs the input from the feature activations.
@@ -216,10 +217,12 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
216
217
  JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.
217
218
 
218
219
  Similar to the inference-only JumpReLUSAE, but with:
220
+
219
221
  - A learnable log-threshold parameter (instead of a raw threshold).
220
222
  - A specialized auxiliary loss term for sparsity (L0 or similar).
221
223
 
222
224
  Methods of interest include:
225
+
223
226
  - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
224
227
  - encode_with_hidden_pre_jumprelu: runs a forward pass for training.
225
228
  - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.
@@ -34,6 +34,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
34
34
  using a simple linear encoder and decoder.
35
35
 
36
36
  It implements the required abstract methods from BaseSAE:
37
+
37
38
  - initialize_weights: sets up simple parameter initializations for W_enc, b_enc, W_dec, and b_dec.
38
39
  - encode: computes the feature activations from an input.
39
40
  - decode: reconstructs the input from the feature activations.
@@ -99,6 +100,7 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
99
100
  """
100
101
  StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture.
101
102
  It implements:
103
+
102
104
  - initialize_weights: basic weight initialization for encoder/decoder.
103
105
  - encode: inference encoding (invokes encode_with_hidden_pre).
104
106
  - decode: a simple linear decoder.
@@ -167,6 +167,7 @@ class TemporalSAE(SAE[TemporalSAEConfig]):
167
167
  """TemporalSAE: Sparse Autoencoder with temporal attention.
168
168
 
169
169
  This SAE decomposes each activation x_t into:
170
+
170
171
  - x_pred: Information aggregated from context {x_0, ..., x_{t-1}}
171
172
  - x_novel: Novel information at position t (encoded sparsely)
172
173
 
@@ -0,0 +1,89 @@
1
+ """
2
+ Synthetic data utilities for SAE experiments.
3
+
4
+ This module provides tools for creating feature dictionaries and generating
5
+ synthetic activations for testing and experimenting with SAEs.
6
+
7
+ Main components:
8
+
9
+ - FeatureDictionary: Maps sparse feature activations to dense hidden activations
10
+ - ActivationGenerator: Generates batches of synthetic feature activations
11
+ - HierarchyNode: Enforces hierarchical structure on feature activations
12
+ - Training utilities: Helpers for training and evaluating SAEs on synthetic data
13
+ - Plotting utilities: Visualization helpers for understanding SAE behavior
14
+ """
15
+
16
+ from sae_lens.synthetic.activation_generator import (
17
+ ActivationGenerator,
18
+ ActivationsModifier,
19
+ ActivationsModifierInput,
20
+ )
21
+ from sae_lens.synthetic.correlation import (
22
+ create_correlation_matrix_from_correlations,
23
+ generate_random_correlation_matrix,
24
+ generate_random_correlations,
25
+ )
26
+ from sae_lens.synthetic.evals import (
27
+ SyntheticDataEvalResult,
28
+ eval_sae_on_synthetic_data,
29
+ mean_correlation_coefficient,
30
+ )
31
+ from sae_lens.synthetic.feature_dictionary import (
32
+ FeatureDictionary,
33
+ FeatureDictionaryInitializer,
34
+ orthogonal_initializer,
35
+ orthogonalize_embeddings,
36
+ )
37
+ from sae_lens.synthetic.firing_probabilities import (
38
+ linear_firing_probabilities,
39
+ random_firing_probabilities,
40
+ zipfian_firing_probabilities,
41
+ )
42
+ from sae_lens.synthetic.hierarchy import HierarchyNode, hierarchy_modifier
43
+ from sae_lens.synthetic.initialization import init_sae_to_match_feature_dict
44
+ from sae_lens.synthetic.plotting import (
45
+ find_best_feature_ordering,
46
+ find_best_feature_ordering_across_saes,
47
+ find_best_feature_ordering_from_sae,
48
+ plot_sae_feature_similarity,
49
+ )
50
+ from sae_lens.synthetic.training import (
51
+ SyntheticActivationIterator,
52
+ train_toy_sae,
53
+ )
54
+ from sae_lens.util import cosine_similarities
55
+
56
+ __all__ = [
57
+ # Main classes
58
+ "FeatureDictionary",
59
+ "HierarchyNode",
60
+ "hierarchy_modifier",
61
+ "ActivationGenerator",
62
+ # Activation generation
63
+ "zipfian_firing_probabilities",
64
+ "linear_firing_probabilities",
65
+ "random_firing_probabilities",
66
+ "create_correlation_matrix_from_correlations",
67
+ "generate_random_correlations",
68
+ "generate_random_correlation_matrix",
69
+ # Feature modifiers
70
+ "ActivationsModifier",
71
+ "ActivationsModifierInput",
72
+ # Utilities
73
+ "orthogonalize_embeddings",
74
+ "orthogonal_initializer",
75
+ "FeatureDictionaryInitializer",
76
+ "cosine_similarities",
77
+ # Training utilities
78
+ "SyntheticActivationIterator",
79
+ "SyntheticDataEvalResult",
80
+ "train_toy_sae",
81
+ "eval_sae_on_synthetic_data",
82
+ "mean_correlation_coefficient",
83
+ "init_sae_to_match_feature_dict",
84
+ # Plotting utilities
85
+ "find_best_feature_ordering",
86
+ "find_best_feature_ordering_from_sae",
87
+ "find_best_feature_ordering_across_saes",
88
+ "plot_sae_feature_similarity",
89
+ ]