sae-lens 6.3.1.tar.gz → 6.5.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {sae_lens-6.3.1 → sae_lens-6.5.0}/PKG-INFO +1 -1
  2. {sae_lens-6.3.1 → sae_lens-6.5.0}/pyproject.toml +1 -1
  3. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/__init__.py +16 -1
  4. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/config.py +10 -1
  5. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/loading/pretrained_sae_loaders.py +141 -0
  6. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/pretrained_saes.yaml +110 -0
  7. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/saes/__init__.py +14 -0
  8. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/saes/sae.py +1 -1
  9. sae_lens-6.5.0/sae_lens/saes/transcoder.py +365 -0
  10. {sae_lens-6.3.1 → sae_lens-6.5.0}/LICENSE +0 -0
  11. {sae_lens-6.3.1 → sae_lens-6.5.0}/README.md +0 -0
  12. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/analysis/__init__.py +0 -0
  13. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  14. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  15. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/cache_activations_runner.py +0 -0
  16. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/constants.py +0 -0
  17. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/evals.py +0 -0
  18. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/llm_sae_training_runner.py +0 -0
  19. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/load_model.py +0 -0
  20. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/loading/__init__.py +0 -0
  21. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  22. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/pretokenize_runner.py +0 -0
  23. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/registry.py +0 -0
  24. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  25. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/saes/gated_sae.py +0 -0
  26. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/saes/jumprelu_sae.py +0 -0
  27. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/saes/standard_sae.py +0 -0
  28. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/saes/topk_sae.py +0 -0
  29. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/tokenization_and_batching.py +0 -0
  30. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/training/__init__.py +0 -0
  31. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/training/activation_scaler.py +0 -0
  32. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/training/activations_store.py +0 -0
  33. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/training/mixing_buffer.py +0 -0
  34. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/training/optim.py +0 -0
  35. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/training/sae_trainer.py +0 -0
  36. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/training/types.py +0 -0
  37. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  38. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/tutorial/tsea.py +0 -0
  39. {sae_lens-6.3.1 → sae_lens-6.5.0}/sae_lens/util.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: sae-lens
-Version: 6.3.1
+Version: 6.5.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.3.1"
+version = "6.5.0"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.3.1"
+__version__ = "6.5.0"
 
 import logging
 
@@ -17,7 +17,11 @@ from sae_lens.saes import (
     JumpReLUSAEConfig,
     JumpReLUTrainingSAE,
     JumpReLUTrainingSAEConfig,
+    JumpReLUTranscoder,
+    JumpReLUTranscoderConfig,
     SAEConfig,
+    SkipTranscoder,
+    SkipTranscoderConfig,
     StandardSAE,
     StandardSAEConfig,
     StandardTrainingSAE,
@@ -28,6 +32,8 @@ from sae_lens.saes import (
     TopKTrainingSAEConfig,
     TrainingSAE,
     TrainingSAEConfig,
+    Transcoder,
+    TranscoderConfig,
 )
 
 from .analysis.hooked_sae_transformer import HookedSAETransformer
@@ -89,6 +95,12 @@ __all__ = [
     "LoggingConfig",
     "BatchTopKTrainingSAE",
     "BatchTopKTrainingSAEConfig",
+    "Transcoder",
+    "TranscoderConfig",
+    "SkipTranscoder",
+    "SkipTranscoderConfig",
+    "JumpReLUTranscoder",
+    "JumpReLUTranscoderConfig",
 ]
 
 
@@ -103,3 +115,6 @@ register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAE
 register_sae_training_class(
     "batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
 )
+register_sae_class("transcoder", Transcoder, TranscoderConfig)
+register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
+register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
@@ -1,5 +1,6 @@
 import json
 import math
+import warnings
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
@@ -125,7 +126,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
         model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
         hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
-        hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
+        hook_eval (str): DEPRECATED: Will be removed in v7.0.0. NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
         hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
         dataset_path (str): A Hugging Face dataset path.
         dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
@@ -264,6 +265,14 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     exclude_special_tokens: bool | list[int] = False
 
     def __post_init__(self):
+        if self.hook_eval != "NOT_IN_USE":
+            warnings.warn(
+                "The 'hook_eval' field is deprecated and will be removed in v7.0.0. "
+                "It is not currently used and can be safely removed from your config.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
        if self.use_cached_activations and self.cached_activations_path is None:
            self.cached_activations_path = _default_cached_activations_path(
                self.dataset_path,
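Because the check runs in `__post_init__`, the warning fires at config construction time rather than when training starts. A hedged sketch of the new behavior — it assumes the `sae` sub-config field and the `StandardTrainingSAEConfig` export from the 6.x API, with all other fields left at their defaults:

```python
from sae_lens import LanguageModelSAERunnerConfig, StandardTrainingSAEConfig

# Setting hook_eval to anything other than the "NOT_IN_USE" sentinel
# now emits a DeprecationWarning as soon as the config is constructed.
cfg = LanguageModelSAERunnerConfig(
    sae=StandardTrainingSAEConfig(d_in=512, d_sae=4096),
    hook_eval="blocks.0.hook_resid_post",
)
# DeprecationWarning: The 'hook_eval' field is deprecated and will be
# removed in v7.0.0. ...
```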
@@ -41,6 +41,8 @@ LLM_METADATA_KEYS = {
     "dataset_path",
     "sae_lens_version",
     "sae_lens_training_version",
+    "hook_name_out",
+    "hook_head_index_out",
 }
 
 
@@ -1092,6 +1094,143 @@ def sparsify_disk_loader(
     return cfg_dict, state_dict
 
 
+def get_gemma_2_transcoder_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str | None = None,
+    force_download: bool = False,  # noqa: ARG001
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Get the config for Gemma-2 transcoders."""
+    width_map = {
+        "width_4k": 4096,
+        "width_16k": 16384,
+        "width_65k": 65536,
+        "width_262k": 262144,
+        "width_524k": 524288,
+        "width_1m": 1048576,
+    }
+
+    # Extract width from folder name
+    d_sae = None
+    for width_key, width_value in width_map.items():
+        if width_key in folder_name:
+            d_sae = width_value
+            break
+
+    if d_sae is None:
+        # Try to extract from a pattern like "width_16k"
+        match = re.search(r"width_(\d+)k", folder_name)
+        if match:
+            d_sae = int(match.group(1)) * 1024
+        else:
+            raise ValueError(
+                f"Could not extract dictionary size from folder name: {folder_name}"
+            )
+
+    # Extract layer
+    layer_match = re.search(r"layer_(\d+)", folder_name)
+    if layer_match:
+        layer = int(layer_match.group(1))
+    else:
+        layer_match = re.search(r"layer_(\d+)", repo_id)
+        if layer_match:
+            layer = int(layer_match.group(1))
+        else:
+            raise ValueError("Could not extract layer index")
+
+    # Determine model and dimensions from repo_id
+    model_configs = {
+        "2b-it": ("gemma-2-2b-it", 2304),
+        "2b": ("gemma-2-2b", 2304),
+        "9b-it": ("gemma-2-9b-it", 3584),
+        "9b": ("gemma-2-9b", 3584),
+        "27b-it": ("gemma-2-27b-it", 4608),
+        "27b": ("gemma-2-27b", 4608),
+    }
+
+    model_name = None
+    d_model = None
+    for model_key, (name, dim) in model_configs.items():
+        if model_key in repo_id:
+            model_name = name
+            d_model = dim
+            break
+
+    if model_name is None:
+        raise ValueError(f"Could not determine model from repo_id: {repo_id}")
+
+    return {
+        "architecture": "jumprelu_transcoder",
+        "d_in": d_model,
+        "d_out": d_model,
+        "d_sae": d_sae,
+        "dtype": "float32",
+        "device": device if device is not None else "cpu",
+        "activation_fn": "relu",
+        "normalize_activations": "none",
+        "model_name": model_name,
+        "hook_name": f"blocks.{layer}.ln2.hook_normalized",
+        "hook_name_out": f"blocks.{layer}.hook_mlp_out",
+        "hook_head_index": None,
+        "hook_head_index_out": None,
+        "prepend_bos": True,
+        "dataset_path": "monology/pile-uncopyrighted",
+        "context_size": 1024,
+        **(cfg_overrides or {}),
+    }
+
+
+def gemma_2_transcoder_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """Load Gemma-2 transcoders from HuggingFace."""
+    cfg_dict = get_gemma_2_transcoder_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
+    )
+
+    # Download the npz file
+    revision = cfg_overrides.get("revision", None) if cfg_overrides else None
+    params_filename = f"{folder_name}/params.npz"
+
+    file_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=params_filename,
+        force_download=force_download,
+        revision=revision,
+    )
+
+    # Load weights from the npz file
+    params = np.load(file_path)
+
+    # Convert to a state dict with proper naming
+    state_dict = {}
+    for key in params.files:
+        tensor = torch.tensor(params[key], dtype=torch.float32, device=device)
+        # Handle various naming conventions
+        key_lower = key.lower()
+        if key_lower in ["w_enc", "wenc", "w_e"]:
+            state_dict["W_enc"] = tensor
+        elif key_lower in ["w_dec", "wdec", "w_d"]:
+            state_dict["W_dec"] = tensor
+        elif key_lower in ["b_enc", "benc", "b_e"]:
+            state_dict["b_enc"] = tensor
+        elif key_lower in ["b_dec", "bdec", "b_d"]:
+            state_dict["b_dec"] = tensor
+        elif key_lower in ["threshold"]:
+            state_dict["threshold"] = tensor
+
+    return cfg_dict, state_dict, None
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1101,6 +1240,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
     "deepseek_r1": deepseek_r1_sae_huggingface_loader,
     "sparsify": sparsify_huggingface_loader,
+    "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
 }
 
 
@@ -1113,4 +1253,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
     "deepseek_r1": get_deepseek_r1_config_from_hf,
     "sparsify": get_sparsify_config_from_hf,
+    "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
 }
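Since the config getter derives everything from the repo and folder names (no file is downloaded in that step), its behavior is easy to sanity-check. A short sketch, using a release/folder pair that appears in the YAML entries added below:

```python
from sae_lens.loading.pretrained_sae_loaders import (
    get_gemma_2_transcoder_config_from_hf,
)

cfg = get_gemma_2_transcoder_config_from_hf(
    repo_id="google/gemma-scope-2b-pt-transcoders",
    folder_name="layer_0/width_16k/average_l0_76",
)
print(cfg["architecture"])         # jumprelu_transcoder
print(cfg["d_in"], cfg["d_sae"])   # 2304 16384
print(cfg["hook_name"])            # blocks.0.ln2.hook_normalized
print(cfg["hook_name_out"])        # blocks.0.hook_mlp_out
```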
@@ -13974,3 +13974,113 @@ gemma-3-1b-res-matryoshka-dc:
   - id: blocks.24.hook_resid_post
     path: blocks.24.hook_resid_post
     l0: 40.0
+
+gemma-scope-2b-pt-transcoders:
+  conversion_func: gemma_2_transcoder
+  model: gemma-2-2b
+  repo_id: google/gemma-scope-2b-pt-transcoders
+  saes:
+  - id: layer_0/width_16k/average_l0_76
+    neuronpedia: gemma-2-2b/0-gemmascope-transcoder-16k
+    l0: 76
+    path: layer_0/width_16k/average_l0_76
+  - id: layer_1/width_16k/average_l0_65
+    neuronpedia: gemma-2-2b/1-gemmascope-transcoder-16k
+    l0: 65
+    path: layer_1/width_16k/average_l0_65
+  - id: layer_2/width_16k/average_l0_49
+    neuronpedia: gemma-2-2b/2-gemmascope-transcoder-16k
+    l0: 49
+    path: layer_2/width_16k/average_l0_49
+  - id: layer_3/width_16k/average_l0_54
+    neuronpedia: gemma-2-2b/3-gemmascope-transcoder-16k
+    l0: 54
+    path: layer_3/width_16k/average_l0_54
+  - id: layer_4/width_16k/average_l0_88
+    neuronpedia: gemma-2-2b/4-gemmascope-transcoder-16k
+    l0: 88
+    path: layer_4/width_16k/average_l0_88
+  - id: layer_5/width_16k/average_l0_87
+    neuronpedia: gemma-2-2b/5-gemmascope-transcoder-16k
+    l0: 87
+    path: layer_5/width_16k/average_l0_87
+  - id: layer_6/width_16k/average_l0_95
+    neuronpedia: gemma-2-2b/6-gemmascope-transcoder-16k
+    l0: 95
+    path: layer_6/width_16k/average_l0_95
+  - id: layer_7/width_16k/average_l0_70
+    neuronpedia: gemma-2-2b/7-gemmascope-transcoder-16k
+    l0: 70
+    path: layer_7/width_16k/average_l0_70
+  - id: layer_8/width_16k/average_l0_52
+    neuronpedia: gemma-2-2b/8-gemmascope-transcoder-16k
+    l0: 52
+    path: layer_8/width_16k/average_l0_52
+  - id: layer_9/width_16k/average_l0_72
+    neuronpedia: gemma-2-2b/9-gemmascope-transcoder-16k
+    l0: 72
+    path: layer_9/width_16k/average_l0_72
+  - id: layer_10/width_16k/average_l0_88
+    neuronpedia: gemma-2-2b/10-gemmascope-transcoder-16k
+    l0: 88
+    path: layer_10/width_16k/average_l0_88
+  - id: layer_11/width_16k/average_l0_5
+    neuronpedia: gemma-2-2b/11-gemmascope-transcoder-16k
+    l0: 5
+    path: layer_11/width_16k/average_l0_5
+  - id: layer_12/width_16k/average_l0_6
+    neuronpedia: gemma-2-2b/12-gemmascope-transcoder-16k
+    l0: 6
+    path: layer_12/width_16k/average_l0_6
+  - id: layer_13/width_16k/average_l0_8
+    neuronpedia: gemma-2-2b/13-gemmascope-transcoder-16k
+    l0: 8
+    path: layer_13/width_16k/average_l0_8
+  - id: layer_14/width_16k/average_l0_8
+    neuronpedia: gemma-2-2b/14-gemmascope-transcoder-16k
+    l0: 8
+    path: layer_14/width_16k/average_l0_8
+  - id: layer_15/width_16k/average_l0_8
+    neuronpedia: gemma-2-2b/15-gemmascope-transcoder-16k
+    l0: 8
+    path: layer_15/width_16k/average_l0_8
+  - id: layer_16/width_16k/average_l0_10
+    neuronpedia: gemma-2-2b/16-gemmascope-transcoder-16k
+    l0: 10
+    path: layer_16/width_16k/average_l0_10
+  - id: layer_17/width_16k/average_l0_12
+    neuronpedia: gemma-2-2b/17-gemmascope-transcoder-16k
+    l0: 12
+    path: layer_17/width_16k/average_l0_12
+  - id: layer_18/width_16k/average_l0_13
+    neuronpedia: gemma-2-2b/18-gemmascope-transcoder-16k
+    l0: 13
+    path: layer_18/width_16k/average_l0_13
+  - id: layer_19/width_16k/average_l0_12
+    neuronpedia: gemma-2-2b/19-gemmascope-transcoder-16k
+    l0: 12
+    path: layer_19/width_16k/average_l0_12
+  - id: layer_20/width_16k/average_l0_11
+    neuronpedia: gemma-2-2b/20-gemmascope-transcoder-16k
+    l0: 11
+    path: layer_20/width_16k/average_l0_11
+  - id: layer_21/width_16k/average_l0_13
+    neuronpedia: gemma-2-2b/21-gemmascope-transcoder-16k
+    l0: 13
+    path: layer_21/width_16k/average_l0_13
+  - id: layer_22/width_16k/average_l0_15
+    neuronpedia: gemma-2-2b/22-gemmascope-transcoder-16k
+    l0: 15
+    path: layer_22/width_16k/average_l0_15
+  - id: layer_23/width_16k/average_l0_25
+    neuronpedia: gemma-2-2b/23-gemmascope-transcoder-16k
+    l0: 25
+    path: layer_23/width_16k/average_l0_25
+  - id: layer_24/width_16k/average_l0_37
+    neuronpedia: gemma-2-2b/24-gemmascope-transcoder-16k
+    l0: 37
+    path: layer_24/width_16k/average_l0_37
+  - id: layer_25/width_16k/average_l0_41
+    neuronpedia: gemma-2-2b/25-gemmascope-transcoder-16k
+    l0: 41
+    path: layer_25/width_16k/average_l0_41
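With the release registered in the directory, the standard loading path should pick these transcoders up by name. A hedged sketch, assuming the 6.x `SAE.from_pretrained(release, sae_id, device=...)` signature that returns the SAE object directly:

```python
from sae_lens import SAE

# The release and sae_id below are taken from the YAML entries above.
sae = SAE.from_pretrained(
    release="gemma-scope-2b-pt-transcoders",
    sae_id="layer_0/width_16k/average_l0_76",
    device="cpu",
)
print(type(sae).__name__)            # JumpReLUTranscoder
print(sae.cfg.d_in, sae.cfg.d_out)   # 2304 2304
```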
@@ -27,6 +27,14 @@ from .topk_sae import (
     TopKTrainingSAE,
     TopKTrainingSAEConfig,
 )
+from .transcoder import (
+    JumpReLUTranscoder,
+    JumpReLUTranscoderConfig,
+    SkipTranscoder,
+    SkipTranscoderConfig,
+    Transcoder,
+    TranscoderConfig,
+)
 
 __all__ = [
     "SAE",
@@ -51,4 +59,10 @@ __all__ = [
     "TopKTrainingSAEConfig",
     "BatchTopKTrainingSAE",
     "BatchTopKTrainingSAEConfig",
+    "Transcoder",
+    "TranscoderConfig",
+    "SkipTranscoder",
+    "SkipTranscoderConfig",
+    "JumpReLUTranscoder",
+    "JumpReLUTranscoderConfig",
 ]
@@ -245,7 +245,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
         self.cfg = cfg
 
-        if cfg.metadata and cfg.metadata.model_from_pretrained_kwargs:
+        if cfg.metadata.model_from_pretrained_kwargs:
             warnings.warn(
                 "\nThis SAE has non-empty model_from_pretrained_kwargs. "
                 "\nFor optimal performance, load the model like so:\n"
@@ -0,0 +1,365 @@
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+from torch import nn
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    SAEMetadata,
+)
+from sae_lens.util import filter_valid_dataclass_fields
+
+# pyright: reportIncompatibleVariableOverride=false
+
+
+@dataclass
+class TranscoderConfig(SAEConfig):
+    # Output dimension fields
+    d_out: int = 768
+    # hook_name_out: str = ""
+    # hook_layer_out: int = 0
+    # hook_head_index_out: int | None = None
+
+    @classmethod
+    def architecture(cls) -> str:
+        """Return the architecture name for this config."""
+        return "transcoder"
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "TranscoderConfig":
+        """Create a TranscoderConfig from a dictionary."""
+        # Filter to only include valid dataclass fields
+        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)
+
+        # Create the config instance
+        res = cls(**filtered_config_dict)
+
+        # Handle metadata if present
+        if "metadata" in config_dict:
+            res.metadata = SAEMetadata(**config_dict["metadata"])
+
+        return res
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to a dictionary, including parent fields."""
+        # Get the base dictionary from the parent
+        res = super().to_dict()
+
+        # Add transcoder-specific fields
+        res.update({"d_out": self.d_out})
+
+        return res
+
+
+class Transcoder(SAE[TranscoderConfig]):
+    """
+    A transcoder maps activations from one hook point to another with
+    potentially different dimensions. It extends the standard SAE but with a
+    decoder that maps to a different output dimension.
+    """
+
+    cfg: TranscoderConfig
+    W_enc: nn.Parameter
+    b_enc: nn.Parameter
+    W_dec: nn.Parameter
+    b_dec: nn.Parameter
+
+    def __init__(self, cfg: TranscoderConfig):
+        super().__init__(cfg)
+        self.cfg = cfg
+
+    def initialize_weights(self):
+        """Initialize transcoder weights with the proper dimensions."""
+        # Initialize b_dec with the output dimension
+        self.b_dec = nn.Parameter(
+            torch.zeros(self.cfg.d_out, dtype=self.dtype, device=self.device)
+        )
+
+        # Initialize W_dec with shape [d_sae, d_out]
+        w_dec_data = torch.empty(
+            self.cfg.d_sae, self.cfg.d_out, dtype=self.dtype, device=self.device
+        )
+        nn.init.kaiming_uniform_(w_dec_data)
+        self.W_dec = nn.Parameter(w_dec_data)
+
+        # Initialize W_enc with shape [d_in, d_sae]
+        w_enc_data = torch.empty(
+            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
+        )
+        nn.init.kaiming_uniform_(w_enc_data)
+        self.W_enc = nn.Parameter(w_enc_data)
+
+        # Initialize b_enc
+        self.b_enc = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+
+    def process_sae_in(self, sae_in: torch.Tensor) -> torch.Tensor:
+        """
+        Process the input without applying the decoder bias (which has the
+        wrong dimension for a transcoder).
+
+        Overrides the parent method to skip the bias subtraction since b_dec
+        has dimension d_out, which doesn't match the input dimension d_in.
+        """
+        # Don't apply b_dec since it has a different dimension.
+        # Just handle dtype conversion and hooks.
+        sae_in = sae_in.to(self.dtype)
+        sae_in = self.hook_sae_input(sae_in)
+        return self.run_time_activation_norm_fn_in(sae_in)
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Encode the input tensor into the feature space.
+        """
+        # Preprocess the SAE input (casting type, applying hooks, normalization)
+        sae_in = self.process_sae_in(x)
+        # Compute the pre-activation values
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+        # Apply the activation function (e.g., ReLU)
+        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
+        """Decode to the output dimension."""
+        # W_dec has shape [d_sae, d_out], feature_acts has shape
+        # [batch, d_sae]
+        sae_out = feature_acts @ self.W_dec + self.b_dec
+        # Apply hooks
+        # Note: We don't apply run_time_activation_norm_fn_out since the
+        # output dimension is different from the input dimension
+        return self.hook_sae_recons(sae_out)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for the transcoder.
+
+        Args:
+            x: Input activations from the input hook point [batch, d_in]
+
+        Returns:
+            sae_out: Reconstructed activations for the output hook point
+                [batch, d_out]
+        """
+        feature_acts = self.encode(x)
+        return self.decode(feature_acts)
+
+    def forward_with_activations(
+        self,
+        x: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass returning both the output and the feature activations.
+
+        Args:
+            x: Input activations from the input hook point [batch, d_in]
+
+        Returns:
+            sae_out: Reconstructed activations for the output hook point
+                [batch, d_out]
+            feature_acts: Hidden activations [batch, d_sae]
+        """
+        feature_acts = self.encode(x)
+        sae_out = self.decode(feature_acts)
+        return sae_out, feature_acts
+
+    @property
+    def d_out(self) -> int:
+        """Output dimension of the transcoder."""
+        return self.cfg.d_out
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "Transcoder":
+        cfg = TranscoderConfig.from_dict(config_dict)
+        return cls(cfg)
+
+
+@dataclass
+class SkipTranscoderConfig(TranscoderConfig):
+    @classmethod
+    def architecture(cls) -> str:
+        """Return the architecture name for this config."""
+        return "skip_transcoder"
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "SkipTranscoderConfig":
+        """Create a SkipTranscoderConfig from a dictionary."""
+        # Filter to only include valid dataclass fields
+        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)
+
+        # Create the config instance
+        res = cls(**filtered_config_dict)
+
+        # Handle metadata if present
+        if "metadata" in config_dict:
+            res.metadata = SAEMetadata(**config_dict["metadata"])
+
+        return res
+
+
+class SkipTranscoder(Transcoder):
+    """
+    A transcoder with a learnable skip connection.
+
+    Implements: f(x) = W_dec @ relu(W_enc @ x + b_enc) + W_skip @ x + b_dec,
+    where W_skip is initialized to zeros.
+    """
+
+    cfg: SkipTranscoderConfig  # type: ignore[assignment]
+    W_skip: nn.Parameter
+
+    def __init__(self, cfg: SkipTranscoderConfig):
+        super().__init__(cfg)
+        self.cfg = cfg
+
+        # Initialize the skip connection matrix
+        # Shape: [d_out, d_in] to map from the input to the output dimension
+        self.W_skip = nn.Parameter(torch.zeros(self.cfg.d_out, self.cfg.d_in))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for the skip transcoder.
+
+        Args:
+            x: Input activations from the input hook point [batch, d_in]
+
+        Returns:
+            sae_out: Reconstructed activations for the output hook point
+                [batch, d_out]
+        """
+        feature_acts = self.encode(x)
+        sae_out = self.decode(feature_acts)
+
+        # Add the skip connection: W_skip @ x
+        # x has shape [batch, d_in], W_skip has shape [d_out, d_in]
+        skip_out = x @ self.W_skip.T.to(x.device)
+        return sae_out + skip_out
+
+    def forward_with_activations(
+        self,
+        x: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass returning both the output and the feature activations.
+
+        Args:
+            x: Input activations from the input hook point [batch, d_in]
+
+        Returns:
+            sae_out: Reconstructed activations for the output hook point
+                [batch, d_out]
+            feature_acts: Hidden activations [batch, d_sae]
+        """
+        feature_acts = self.encode(x)
+        sae_out = self.decode(feature_acts)
+
+        # Add the skip connection: W_skip @ x
+        # x has shape [batch, d_in], W_skip has shape [d_out, d_in]
+        skip_out = x @ self.W_skip.T.to(x.device)
+        sae_out = sae_out + skip_out
+
+        return sae_out, feature_acts
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "SkipTranscoder":
+        cfg = SkipTranscoderConfig.from_dict(config_dict)
+        return cls(cfg)
+
+
+# JumpReLU Transcoder Classes
+@dataclass
+class JumpReLUTranscoderConfig(TranscoderConfig):
+    """Configuration for the JumpReLU transcoder."""
+
+    @classmethod
+    def architecture(cls) -> str:
+        """Return the architecture name for this config."""
+        return "jumprelu_transcoder"
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUTranscoderConfig":
+        """Create a JumpReLUTranscoderConfig from a dictionary."""
+        # Filter to only include valid dataclass fields
+        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)
+
+        # Create the config instance
+        res = cls(**filtered_config_dict)
+
+        # Handle metadata if present
+        if "metadata" in config_dict:
+            res.metadata = SAEMetadata(**config_dict["metadata"])
+
+        return res
+
+
+class JumpReLUTranscoder(Transcoder):
+    """
+    A transcoder with a JumpReLU activation function.
+
+    JumpReLU applies a threshold to activations: if the pre-activation is <=
+    the threshold, the unit is zeroed out; otherwise, it follows the base
+    activation function.
+    """
+
+    cfg: JumpReLUTranscoderConfig  # type: ignore[assignment]
+    threshold: nn.Parameter
+
+    def __init__(self, cfg: JumpReLUTranscoderConfig):
+        super().__init__(cfg)
+        self.cfg = cfg
+
+    def initialize_weights(self):
+        """Initialize transcoder weights including the threshold parameter."""
+        super().initialize_weights()
+
+        # Initialize the threshold parameter for JumpReLU
+        self.threshold = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Encode using the JumpReLU activation.
+
+        Applies the base activation function (ReLU), then masks based on the
+        threshold.
+        """
+        # Preprocess the SAE input
+        sae_in = self.process_sae_in(x)
+
+        # Compute the pre-activation values
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+
+        # Apply the base activation function (ReLU)
+        feature_acts = self.activation_fn(hidden_pre)
+
+        # Apply the JumpReLU threshold
+        # During training, use a detached threshold to prevent gradient flow
+        threshold = self.threshold.detach() if self.training else self.threshold
+        jump_relu_mask = (hidden_pre > threshold).to(self.dtype)
+
+        # Apply the mask and hook
+        return self.hook_sae_acts_post(feature_acts * jump_relu_mask)
+
+    def fold_W_dec_norm(self) -> None:
+        """
+        Fold the decoder weight norm into the threshold parameter.
+
+        This is important for JumpReLU since the threshold needs to be scaled
+        along with the decoder weights.
+        """
+        # Get the decoder weight norms before normalizing
+        with torch.no_grad():
+            W_dec_norms = self.W_dec.norm(dim=1)
+
+        # Fold the decoder norms as in the parent class
+        super().fold_W_dec_norm()
+
+        # Scale the threshold by the decoder weight norms
+        with torch.no_grad():
+            self.threshold.data = self.threshold.data * W_dec_norms
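A short sketch of the JumpReLU gating defined above, with hypothetical toy dimensions; the threshold is set by hand here only to make the masking visible (pretrained checkpoints supply real values):

```python
import torch

from sae_lens import JumpReLUTranscoder, JumpReLUTranscoderConfig

cfg = JumpReLUTranscoderConfig(d_in=16, d_sae=32, d_out=16)
tc = JumpReLUTranscoder(cfg)
tc.eval()  # use the non-detached threshold branch in encode()

with torch.no_grad():
    tc.threshold.fill_(0.5)  # units fire only where hidden_pre > 0.5

x = torch.randn(2, cfg.d_in)
acts = tc.encode(x)          # zeroed wherever pre-activation <= threshold
out = tc.decode(acts)        # reconstruction at the output hook point
assert acts.shape == (2, cfg.d_sae) and out.shape == (2, cfg.d_out)
```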