sae-lens 6.3.1__py3-none-any.whl → 6.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +16 -1
- sae_lens/config.py +10 -1
- sae_lens/loading/pretrained_sae_loaders.py +141 -0
- sae_lens/pretrained_saes.yaml +110 -0
- sae_lens/saes/__init__.py +14 -0
- sae_lens/saes/sae.py +1 -1
- sae_lens/saes/transcoder.py +365 -0
- {sae_lens-6.3.1.dist-info → sae_lens-6.5.0.dist-info}/METADATA +1 -1
- {sae_lens-6.3.1.dist-info → sae_lens-6.5.0.dist-info}/RECORD +11 -10
- {sae_lens-6.3.1.dist-info → sae_lens-6.5.0.dist-info}/LICENSE +0 -0
- {sae_lens-6.3.1.dist-info → sae_lens-6.5.0.dist-info}/WHEEL +0 -0
sae_lens/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.3.1"
+__version__ = "6.5.0"
 
 import logging
 
@@ -17,7 +17,11 @@ from sae_lens.saes import (
     JumpReLUSAEConfig,
     JumpReLUTrainingSAE,
     JumpReLUTrainingSAEConfig,
+    JumpReLUTranscoder,
+    JumpReLUTranscoderConfig,
     SAEConfig,
+    SkipTranscoder,
+    SkipTranscoderConfig,
     StandardSAE,
     StandardSAEConfig,
     StandardTrainingSAE,
@@ -28,6 +32,8 @@ from sae_lens.saes import (
     TopKTrainingSAEConfig,
     TrainingSAE,
     TrainingSAEConfig,
+    Transcoder,
+    TranscoderConfig,
 )
 
 from .analysis.hooked_sae_transformer import HookedSAETransformer
@@ -89,6 +95,12 @@ __all__ = [
     "LoggingConfig",
     "BatchTopKTrainingSAE",
     "BatchTopKTrainingSAEConfig",
+    "Transcoder",
+    "TranscoderConfig",
+    "SkipTranscoder",
+    "SkipTranscoderConfig",
+    "JumpReLUTranscoder",
+    "JumpReLUTranscoderConfig",
 ]
 
 
@@ -103,3 +115,6 @@ register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAE
 register_sae_training_class(
     "batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
 )
+register_sae_class("transcoder", Transcoder, TranscoderConfig)
+register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
+register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
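With these registrations in place, the new transcoder variants load through the same entry point as any other SAE. A minimal sketch (the release and SAE ids come from the pretrained_saes.yaml entry added later in this diff; the exact return convention of SAE.from_pretrained in 6.x is assumed, not shown here):

    from sae_lens import SAE

    # Release and sae_id are defined in the pretrained_saes.yaml entry below.
    transcoder = SAE.from_pretrained(
        release="gemma-scope-2b-pt-transcoders",
        sae_id="layer_0/width_16k/average_l0_76",
        device="cpu",
    )
    print(type(transcoder).__name__)  # JumpReLUTranscoder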
sae_lens/config.py
CHANGED
@@ -1,5 +1,6 @@
 import json
 import math
+import warnings
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
@@ -125,7 +126,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
         model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
         hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
-        hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
+        hook_eval (str): DEPRECATED: Will be removed in v7.0.0. NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
         hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
         dataset_path (str): A Hugging Face dataset path.
         dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
@@ -264,6 +265,14 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     exclude_special_tokens: bool | list[int] = False
 
     def __post_init__(self):
+        if self.hook_eval != "NOT_IN_USE":
+            warnings.warn(
+                "The 'hook_eval' field is deprecated and will be removed in v7.0.0. "
+                "It is not currently used and can be safely removed from your config.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
         if self.use_cached_activations and self.cached_activations_path is None:
            self.cached_activations_path = _default_cached_activations_path(
                self.dataset_path,
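A sketch of the new __post_init__ behavior: any value other than the "NOT_IN_USE" sentinel now emits a DeprecationWarning. The sae= keyword and the StandardTrainingSAEConfig arguments below are assumptions about the surrounding 6.x config API, which this diff does not show:

    import warnings

    from sae_lens import LanguageModelSAERunnerConfig, StandardTrainingSAEConfig

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        LanguageModelSAERunnerConfig(
            sae=StandardTrainingSAEConfig(d_in=512, d_sae=4096),  # assumed required field
            hook_eval="blocks.0.hook_resid_post",  # anything but "NOT_IN_USE" warns
        )

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)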
sae_lens/loading/pretrained_sae_loaders.py
CHANGED
@@ -41,6 +41,8 @@ LLM_METADATA_KEYS = {
     "dataset_path",
     "sae_lens_version",
     "sae_lens_training_version",
+    "hook_name_out",
+    "hook_head_index_out",
 }
 
 
@@ -1092,6 +1094,143 @@ def sparsify_disk_loader(
     return cfg_dict, state_dict
 
 
+def get_gemma_2_transcoder_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str | None = None,
+    force_download: bool = False,  # noqa: ARG001
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Get config for Gemma-2 transcoders"""
+    width_map = {
+        "width_4k": 4096,
+        "width_16k": 16384,
+        "width_65k": 65536,
+        "width_262k": 262144,
+        "width_524k": 524288,
+        "width_1m": 1048576,
+    }
+
+    # Extract width from folder name
+    d_sae = None
+    for width_key, width_value in width_map.items():
+        if width_key in folder_name:
+            d_sae = width_value
+            break
+
+    if d_sae is None:
+        # Try to extract from pattern like "width_16k"
+        match = re.search(r"width_(\d+)k", folder_name)
+        if match:
+            d_sae = int(match.group(1)) * 1024
+        else:
+            raise ValueError(
+                f"Could not extract dictionary size from folder name: {folder_name}"
+            )
+
+    # Extract layer
+    layer_match = re.search(r"layer_(\d+)", folder_name)
+    if layer_match:
+        layer = int(layer_match.group(1))
+    else:
+        layer_match = re.search(r"layer_(\d+)", repo_id)
+        if layer_match:
+            layer = int(layer_match.group(1))
+        else:
+            raise ValueError("Could not extract layer index")
+
+    # Determine model and dimensions from repo_id
+    model_configs = {
+        "2b-it": ("gemma-2-2b-it", 2304),
+        "2b": ("gemma-2-2b", 2304),
+        "9b-it": ("gemma-2-9b-it", 3584),
+        "9b": ("gemma-2-9b", 3584),
+        "27b-it": ("gemma-2-27b-it", 4608),
+        "27b": ("gemma-2-27b", 4608),
+    }
+
+    model_name = None
+    d_model = None
+    for model_key, (name, dim) in model_configs.items():
+        if model_key in repo_id:
+            model_name = name
+            d_model = dim
+            break
+
+    if model_name is None:
+        raise ValueError(f"Could not determine model from repo_id: {repo_id}")
+
+    return {
+        "architecture": "jumprelu_transcoder",
+        "d_in": d_model,
+        "d_out": d_model,
+        "d_sae": d_sae,
+        "dtype": "float32",
+        "device": device if device is not None else "cpu",
+        "activation_fn": "relu",
+        "normalize_activations": "none",
+        "model_name": model_name,
+        "hook_name": f"blocks.{layer}.ln2.hook_normalized",
+        "hook_name_out": f"blocks.{layer}.hook_mlp_out",
+        "hook_head_index": None,
+        "hook_head_index_out": None,
+        "prepend_bos": True,
+        "dataset_path": "monology/pile-uncopyrighted",
+        "context_size": 1024,
+        **(cfg_overrides or {}),
+    }
+
+
+def gemma_2_transcoder_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """Load Gemma-2 transcoders from HuggingFace"""
+    cfg_dict = get_gemma_2_transcoder_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
+    )
+
+    # Download the npz file
+    revision = cfg_overrides.get("revision", None) if cfg_overrides else None
+    params_filename = f"{folder_name}/params.npz"
+
+    file_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=params_filename,
+        force_download=force_download,
+        revision=revision,
+    )
+
+    # Load weights from npz file
+    params = np.load(file_path)
+
+    # Convert to state dict with proper naming
+    state_dict = {}
+    for key in params.files:
+        tensor = torch.tensor(params[key], dtype=torch.float32, device=device)
+        # Handle various naming conventions
+        key_lower = key.lower()
+        if key_lower in ["w_enc", "wenc", "w_e"]:
+            state_dict["W_enc"] = tensor
+        elif key_lower in ["w_dec", "wdec", "w_d"]:
+            state_dict["W_dec"] = tensor
+        elif key_lower in ["b_enc", "benc", "b_e"]:
+            state_dict["b_enc"] = tensor
+        elif key_lower in ["b_dec", "bdec", "b_d"]:
+            state_dict["b_dec"] = tensor
+        if key_lower in ["threshold"]:
+            state_dict["threshold"] = tensor
+
+    return cfg_dict, state_dict, None
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1101,6 +1240,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
     "deepseek_r1": deepseek_r1_sae_huggingface_loader,
     "sparsify": sparsify_huggingface_loader,
+    "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
 }
 
 
@@ -1113,4 +1253,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
     "deepseek_r1": get_deepseek_r1_config_from_hf,
     "sparsify": get_sparsify_config_from_hf,
+    "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
 }
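To see the parsing logic concretely, here is what the new config getter returns for one of the GemmaScope folder names; every value below follows directly from the code above, assuming only that sae-lens is importable:

    from sae_lens.loading.pretrained_sae_loaders import (
        get_gemma_2_transcoder_config_from_hf,
    )

    cfg = get_gemma_2_transcoder_config_from_hf(
        repo_id="google/gemma-scope-2b-pt-transcoders",
        folder_name="layer_5/width_16k/average_l0_87",
    )
    assert cfg["architecture"] == "jumprelu_transcoder"
    assert cfg["d_sae"] == 16384                               # from "width_16k"
    assert cfg["d_in"] == cfg["d_out"] == 2304                 # gemma-2-2b width
    assert cfg["hook_name"] == "blocks.5.ln2.hook_normalized"  # MLP input
    assert cfg["hook_name_out"] == "blocks.5.hook_mlp_out"     # MLP output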
sae_lens/pretrained_saes.yaml
CHANGED
@@ -13974,3 +13974,113 @@ gemma-3-1b-res-matryoshka-dc:
   - id: blocks.24.hook_resid_post
     path: blocks.24.hook_resid_post
     l0: 40.0
+
+gemma-scope-2b-pt-transcoders:
+  conversion_func: gemma_2_transcoder
+  model: gemma-2-2b
+  repo_id: google/gemma-scope-2b-pt-transcoders
+  saes:
+  - id: layer_0/width_16k/average_l0_76
+    neuronpedia: gemma-2-2b/0-gemmascope-transcoder-16k
+    l0: 76
+    path: layer_0/width_16k/average_l0_76
+  - id: layer_1/width_16k/average_l0_65
+    neuronpedia: gemma-2-2b/1-gemmascope-transcoder-16k
+    l0: 65
+    path: layer_1/width_16k/average_l0_65
+  - id: layer_2/width_16k/average_l0_49
+    neuronpedia: gemma-2-2b/2-gemmascope-transcoder-16k
+    l0: 49
+    path: layer_2/width_16k/average_l0_49
+  - id: layer_3/width_16k/average_l0_54
+    neuronpedia: gemma-2-2b/3-gemmascope-transcoder-16k
+    l0: 54
+    path: layer_3/width_16k/average_l0_54
+  - id: layer_4/width_16k/average_l0_88
+    neuronpedia: gemma-2-2b/4-gemmascope-transcoder-16k
+    l0: 88
+    path: layer_4/width_16k/average_l0_88
+  - id: layer_5/width_16k/average_l0_87
+    neuronpedia: gemma-2-2b/5-gemmascope-transcoder-16k
+    l0: 87
+    path: layer_5/width_16k/average_l0_87
+  - id: layer_6/width_16k/average_l0_95
+    neuronpedia: gemma-2-2b/6-gemmascope-transcoder-16k
+    l0: 95
+    path: layer_6/width_16k/average_l0_95
+  - id: layer_7/width_16k/average_l0_70
+    neuronpedia: gemma-2-2b/7-gemmascope-transcoder-16k
+    l0: 70
+    path: layer_7/width_16k/average_l0_70
+  - id: layer_8/width_16k/average_l0_52
+    neuronpedia: gemma-2-2b/8-gemmascope-transcoder-16k
+    l0: 52
+    path: layer_8/width_16k/average_l0_52
+  - id: layer_9/width_16k/average_l0_72
+    neuronpedia: gemma-2-2b/9-gemmascope-transcoder-16k
+    l0: 72
+    path: layer_9/width_16k/average_l0_72
+  - id: layer_10/width_16k/average_l0_88
+    neuronpedia: gemma-2-2b/10-gemmascope-transcoder-16k
+    l0: 88
+    path: layer_10/width_16k/average_l0_88
+  - id: layer_11/width_16k/average_l0_5
+    neuronpedia: gemma-2-2b/11-gemmascope-transcoder-16k
+    l0: 5
+    path: layer_11/width_16k/average_l0_5
+  - id: layer_12/width_16k/average_l0_6
+    neuronpedia: gemma-2-2b/12-gemmascope-transcoder-16k
+    l0: 6
+    path: layer_12/width_16k/average_l0_6
+  - id: layer_13/width_16k/average_l0_8
+    neuronpedia: gemma-2-2b/13-gemmascope-transcoder-16k
+    l0: 8
+    path: layer_13/width_16k/average_l0_8
+  - id: layer_14/width_16k/average_l0_8
+    neuronpedia: gemma-2-2b/14-gemmascope-transcoder-16k
+    l0: 8
+    path: layer_14/width_16k/average_l0_8
+  - id: layer_15/width_16k/average_l0_8
+    neuronpedia: gemma-2-2b/15-gemmascope-transcoder-16k
+    l0: 8
+    path: layer_15/width_16k/average_l0_8
+  - id: layer_16/width_16k/average_l0_10
+    neuronpedia: gemma-2-2b/16-gemmascope-transcoder-16k
+    l0: 10
+    path: layer_16/width_16k/average_l0_10
+  - id: layer_17/width_16k/average_l0_12
+    neuronpedia: gemma-2-2b/17-gemmascope-transcoder-16k
+    l0: 12
+    path: layer_17/width_16k/average_l0_12
+  - id: layer_18/width_16k/average_l0_13
+    neuronpedia: gemma-2-2b/18-gemmascope-transcoder-16k
+    l0: 13
+    path: layer_18/width_16k/average_l0_13
+  - id: layer_19/width_16k/average_l0_12
+    neuronpedia: gemma-2-2b/19-gemmascope-transcoder-16k
+    l0: 12
+    path: layer_19/width_16k/average_l0_12
+  - id: layer_20/width_16k/average_l0_11
+    neuronpedia: gemma-2-2b/20-gemmascope-transcoder-16k
+    l0: 11
+    path: layer_20/width_16k/average_l0_11
+  - id: layer_21/width_16k/average_l0_13
+    neuronpedia: gemma-2-2b/21-gemmascope-transcoder-16k
+    l0: 13
+    path: layer_21/width_16k/average_l0_13
+  - id: layer_22/width_16k/average_l0_15
+    neuronpedia: gemma-2-2b/22-gemmascope-transcoder-16k
+    l0: 15
+    path: layer_22/width_16k/average_l0_15
+  - id: layer_23/width_16k/average_l0_25
+    neuronpedia: gemma-2-2b/23-gemmascope-transcoder-16k
+    l0: 25
+    path: layer_23/width_16k/average_l0_25
+  - id: layer_24/width_16k/average_l0_37
+    neuronpedia: gemma-2-2b/24-gemmascope-transcoder-16k
+    l0: 37
+    path: layer_24/width_16k/average_l0_37
+  - id: layer_25/width_16k/average_l0_41
+    neuronpedia: gemma-2-2b/25-gemmascope-transcoder-16k
+    l0: 41
+    path: layer_25/width_16k/average_l0_41
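The new release is then discoverable through the pretrained SAE directory. A sketch, assuming the get_pretrained_saes_directory helper and its saes_map attribute from the existing sae_lens/loading/pretrained_saes_directory.py module (unchanged in this diff):

    from sae_lens.loading.pretrained_saes_directory import (
        get_pretrained_saes_directory,
    )

    directory = get_pretrained_saes_directory()
    release = directory["gemma-scope-2b-pt-transcoders"]
    print(release.repo_id)        # google/gemma-scope-2b-pt-transcoders
    print(len(release.saes_map))  # 26 transcoders, one per layer 0-25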
sae_lens/saes/__init__.py
CHANGED
@@ -27,6 +27,14 @@ from .topk_sae import (
     TopKTrainingSAE,
     TopKTrainingSAEConfig,
 )
+from .transcoder import (
+    JumpReLUTranscoder,
+    JumpReLUTranscoderConfig,
+    SkipTranscoder,
+    SkipTranscoderConfig,
+    Transcoder,
+    TranscoderConfig,
+)
 
 __all__ = [
     "SAE",
@@ -51,4 +59,10 @@ __all__ = [
     "TopKTrainingSAEConfig",
     "BatchTopKTrainingSAE",
     "BatchTopKTrainingSAEConfig",
+    "Transcoder",
+    "TranscoderConfig",
+    "SkipTranscoder",
+    "SkipTranscoderConfig",
+    "JumpReLUTranscoder",
+    "JumpReLUTranscoderConfig",
 ]
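As a result, the transcoder classes are importable from either the subpackage or the package root:

    # Both import paths are exported as of 6.5.0:
    from sae_lens.saes import Transcoder, TranscoderConfig
    from sae_lens import JumpReLUTranscoder, SkipTranscoder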
sae_lens/saes/sae.py
CHANGED
@@ -245,7 +245,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
         self.cfg = cfg
 
-        if cfg.metadata and cfg.metadata
+        if cfg.metadata and cfg.metadata:
             warnings.warn(
                 "\nThis SAE has non-empty model_from_pretrained_kwargs. "
                 "\nFor optimal performance, load the model like so:\n"
sae_lens/saes/transcoder.py
ADDED
@@ -0,0 +1,365 @@
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+from torch import nn
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    SAEMetadata,
+)
+from sae_lens.util import filter_valid_dataclass_fields
+
+# pyright: reportIncompatibleVariableOverride=false
+
+
+@dataclass
+class TranscoderConfig(SAEConfig):
+    # Output dimension fields
+    d_out: int = 768
+    # hook_name_out: str = ""
+    # hook_layer_out: int = 0
+    # hook_head_index_out: int | None = None
+
+    @classmethod
+    def architecture(cls) -> str:
+        """Return the architecture name for this config."""
+        return "transcoder"
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "TranscoderConfig":
+        """Create a TranscoderConfig from a dictionary."""
+        # Filter to only include valid dataclass fields
+        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)
+
+        # Create the config instance
+        res = cls(**filtered_config_dict)
+
+        # Handle metadata if present
+        if "metadata" in config_dict:
+            res.metadata = SAEMetadata(**config_dict["metadata"])
+
+        return res
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary, including parent fields."""
+        # Get the base dictionary from parent
+        res = super().to_dict()
+
+        # Add transcoder-specific fields
+        res.update({"d_out": self.d_out})
+
+        return res
+
+
+class Transcoder(SAE[TranscoderConfig]):
+    """
+    A transcoder maps activations from one hook point to another with
+    potentially different dimensions. It extends the standard SAE but with a
+    decoder that maps to a different output dimension.
+    """
+
+    cfg: TranscoderConfig
+    W_enc: nn.Parameter
+    b_enc: nn.Parameter
+    W_dec: nn.Parameter
+    b_dec: nn.Parameter
+
+    def __init__(self, cfg: TranscoderConfig):
+        super().__init__(cfg)
+        self.cfg = cfg
+
+    def initialize_weights(self):
+        """Initialize transcoder weights with proper dimensions."""
+        # Initialize b_dec with output dimension
+        self.b_dec = nn.Parameter(
+            torch.zeros(self.cfg.d_out, dtype=self.dtype, device=self.device)
+        )
+
+        # Initialize W_dec with shape [d_sae, d_out]
+        w_dec_data = torch.empty(
+            self.cfg.d_sae, self.cfg.d_out, dtype=self.dtype, device=self.device
+        )
+        nn.init.kaiming_uniform_(w_dec_data)
+        self.W_dec = nn.Parameter(w_dec_data)
+
+        # Initialize W_enc with shape [d_in, d_sae]
+        w_enc_data = torch.empty(
+            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
+        )
+        nn.init.kaiming_uniform_(w_enc_data)
+        self.W_enc = nn.Parameter(w_enc_data)
+
+        # Initialize b_enc
+        self.b_enc = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+
+    def process_sae_in(self, sae_in: torch.Tensor) -> torch.Tensor:
+        """
+        Process input without applying decoder bias (which has wrong dimension
+        for transcoder).
+
+        Overrides the parent method to skip the bias subtraction since b_dec
+        has dimension d_out which doesn't match the input dimension d_in.
+        """
+        # Don't apply b_dec since it has different dimension
+        # Just handle dtype conversion and hooks
+        sae_in = sae_in.to(self.dtype)
+        sae_in = self.hook_sae_input(sae_in)
+        return self.run_time_activation_norm_fn_in(sae_in)
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Encode the input tensor into the feature space.
+        """
+        # Preprocess the SAE input (casting type, applying hooks, normalization)
+        sae_in = self.process_sae_in(x)
+        # Compute the pre-activation values
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+        # Apply the activation function (e.g., ReLU)
+        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
+        """Decode to output dimension."""
+        # W_dec has shape [d_sae, d_out], feature_acts has shape
+        # [batch, d_sae]
+        sae_out = feature_acts @ self.W_dec + self.b_dec
+        # Apply hooks
+        # Note: We don't apply run_time_activation_norm_fn_out since the
+        # output dimension is different from the input dimension
+        return self.hook_sae_recons(sae_out)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for transcoder.
+
+        Args:
+            x: Input activations from the input hook point [batch, d_in]
+
+        Returns:
+            sae_out: Reconstructed activations for the output hook point
+                [batch, d_out]
+        """
+        feature_acts = self.encode(x)
+        return self.decode(feature_acts)
+
+    def forward_with_activations(
+        self,
+        x: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass returning both output and feature activations.
+
+        Args:
+            x: Input activations from the input hook point [batch, d_in]
+
+        Returns:
+            sae_out: Reconstructed activations for the output hook point
+                [batch, d_out]
+            feature_acts: Hidden activations [batch, d_sae]
+        """
+        feature_acts = self.encode(x)
+        sae_out = self.decode(feature_acts)
+        return sae_out, feature_acts
+
+    @property
+    def d_out(self) -> int:
+        """Output dimension of the transcoder."""
+        return self.cfg.d_out
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "Transcoder":
+        cfg = TranscoderConfig.from_dict(config_dict)
+        return cls(cfg)
+
+
+@dataclass
+class SkipTranscoderConfig(TranscoderConfig):
+    @classmethod
+    def architecture(cls) -> str:
+        """Return the architecture name for this config."""
+        return "skip_transcoder"
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "SkipTranscoderConfig":
+        """Create a SkipTranscoderConfig from a dictionary."""
+        # Filter to only include valid dataclass fields
+        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)
+
+        # Create the config instance
+        res = cls(**filtered_config_dict)
+
+        # Handle metadata if present
+        if "metadata" in config_dict:
+            res.metadata = SAEMetadata(**config_dict["metadata"])
+
+        return res
+
+
+class SkipTranscoder(Transcoder):
+    """
+    A transcoder with a learnable skip connection.
+
+    Implements: f(x) = W_dec @ relu(W_enc @ x + b_enc) + W_skip @ x + b_dec
+    where W_skip is initialized to zeros.
+    """
+
+    cfg: SkipTranscoderConfig  # type: ignore[assignment]
+    W_skip: nn.Parameter
+
+    def __init__(self, cfg: SkipTranscoderConfig):
+        super().__init__(cfg)
+        self.cfg = cfg
+
+        # Initialize skip connection matrix
+        # Shape: [d_out, d_in] to map from input to output dimension
+        self.W_skip = nn.Parameter(torch.zeros(self.cfg.d_out, self.cfg.d_in))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for skip transcoder.
+
+        Args:
+            x: Input activations from the input hook point [batch, d_in]
+
+        Returns:
+            sae_out: Reconstructed activations for the output hook point
+                [batch, d_out]
+        """
+        feature_acts = self.encode(x)
+        sae_out = self.decode(feature_acts)
+
+        # Add skip connection: W_skip @ x
+        # x has shape [batch, d_in], W_skip has shape [d_out, d_in]
+        skip_out = x @ self.W_skip.T.to(x.device)
+        return sae_out + skip_out
+
+    def forward_with_activations(
+        self,
+        x: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Forward pass returning both output and feature activations.
+
+        Args:
+            x: Input activations from the input hook point [batch, d_in]
+
+        Returns:
+            sae_out: Reconstructed activations for the output hook point
+                [batch, d_out]
+            feature_acts: Hidden activations [batch, d_sae]
+        """
+        feature_acts = self.encode(x)
+        sae_out = self.decode(feature_acts)
+
+        # Add skip connection: W_skip @ x
+        # x has shape [batch, d_in], W_skip has shape [d_out, d_in]
+        skip_out = x @ self.W_skip.T.to(x.device)
+        sae_out = sae_out + skip_out
+
+        return sae_out, feature_acts
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "SkipTranscoder":
+        cfg = SkipTranscoderConfig.from_dict(config_dict)
+        return cls(cfg)
+
+
+# JumpReLU Transcoder Classes
+@dataclass
+class JumpReLUTranscoderConfig(TranscoderConfig):
+    """Configuration for JumpReLU transcoder."""
+
+    @classmethod
+    def architecture(cls) -> str:
+        """Return the architecture name for this config."""
+        return "jumprelu_transcoder"
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUTranscoderConfig":
+        """Create a JumpReLUTranscoderConfig from a dictionary."""
+        # Filter to only include valid dataclass fields
+        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)
+
+        # Create the config instance
+        res = cls(**filtered_config_dict)
+
+        # Handle metadata if present
+        if "metadata" in config_dict:
+            res.metadata = SAEMetadata(**config_dict["metadata"])
+
+        return res
+
+
+class JumpReLUTranscoder(Transcoder):
+    """
+    A transcoder with JumpReLU activation function.
+
+    JumpReLU applies a threshold to activations: if pre-activation <=
+    threshold, the unit is zeroed out; otherwise, it follows the base
+    activation function.
+    """
+
+    cfg: JumpReLUTranscoderConfig  # type: ignore[assignment]
+    threshold: nn.Parameter
+
+    def __init__(self, cfg: JumpReLUTranscoderConfig):
+        super().__init__(cfg)
+        self.cfg = cfg
+
+    def initialize_weights(self):
+        """Initialize transcoder weights including threshold parameter."""
+        super().initialize_weights()
+
+        # Initialize threshold parameter for JumpReLU
+        self.threshold = nn.Parameter(
+            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
+        )
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Encode using JumpReLU activation.
+
+        Applies base activation function (ReLU) then masks based on threshold.
+        """
+        # Preprocess the SAE input
+        sae_in = self.process_sae_in(x)
+
+        # Compute pre-activation values
+        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+
+        # Apply base activation function (ReLU)
+        feature_acts = self.activation_fn(hidden_pre)
+
+        # Apply JumpReLU threshold
+        # During training, use detached threshold to prevent gradient flow
+        threshold = self.threshold.detach() if self.training else self.threshold
+        jump_relu_mask = (hidden_pre > threshold).to(self.dtype)
+
+        # Apply mask and hook
+        return self.hook_sae_acts_post(feature_acts * jump_relu_mask)
+
+    def fold_W_dec_norm(self) -> None:
+        """
+        Fold the decoder weight norm into the threshold parameter.
+
+        This is important for JumpReLU as the threshold needs to be scaled
+        along with the decoder weights.
+        """
+        # Get the decoder weight norms before normalizing
+        with torch.no_grad():
+            W_dec_norms = self.W_dec.norm(dim=1)
+
+        # Fold the decoder norms as in the parent class
+        super().fold_W_dec_norm()
+
+        # Scale the threshold by the decoder weight norms
+        with torch.no_grad():
+            self.threshold.data = self.threshold.data * W_dec_norms
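To make the shape semantics concrete, a minimal sketch of the skip variant (assuming the inherited SAEConfig fields beyond d_in and d_sae have usable defaults):

    import torch

    from sae_lens.saes.transcoder import SkipTranscoder, SkipTranscoderConfig

    # The decoder targets d_out, which need not equal d_in; sizes are illustrative.
    cfg = SkipTranscoderConfig(d_in=2304, d_sae=16384, d_out=2304)
    tc = SkipTranscoder(cfg)

    x = torch.randn(8, 2304)
    out, feats = tc.forward_with_activations(x)
    assert out.shape == (8, 2304)     # reconstruction at the output hook point
    assert feats.shape == (8, 16384)  # sparse feature activations

    # W_skip starts at zeros, so the skip path initially contributes nothing.
    assert torch.equal(tc.W_skip, torch.zeros(2304, 2304))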
{sae_lens-6.3.1.dist-info → sae_lens-6.5.0.dist-info}/RECORD
CHANGED
@@ -1,26 +1,27 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=bVByXIUMDNWvGYcmDJ9cY0Me0iDR5c5TzEdB7yNkv6I,3588
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
 sae_lens/analysis/neuronpedia_integration.py,sha256=MrENqc81Mc2SMbxGjbwHzpkGUCAFKSf0i4EdaUF2Oj4,18707
 sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
-sae_lens/config.py,sha256=
+sae_lens/config.py,sha256=IrjbsKBbaZoFXYrsPJ5xBwIqi9uZJIIFXjV_uoErJaE,28176
 sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
 sae_lens/evals.py,sha256=kQyrzczKaVD9rHwfFa_DxL_gMXDxsoIVHmsFIPIU2bY,38696
 sae_lens/llm_sae_training_runner.py,sha256=exxNX_OEhdiUrlgmBP9bjX9DOf0HUcNQGO4unKeDjKM,13713
 sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
 sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/loading/pretrained_sae_loaders.py,sha256=
+sae_lens/loading/pretrained_sae_loaders.py,sha256=RfZhE8l2y5_ZW5-fW-iN-SJGJfHRaZo4d6AiooK1Xuc,40890
 sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
 sae_lens/pretokenize_runner.py,sha256=w0f6SfZLAxbp5eAAKnet8RqUB_DKofZ9RGsoJwFnYbA,7058
-sae_lens/pretrained_saes.yaml,sha256=
+sae_lens/pretrained_saes.yaml,sha256=BncZpkf-NhVOWV1FhrDPr_TPxAINP4-eiPZBw-HO4N4,578197
 sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
-sae_lens/saes/__init__.py,sha256=
+sae_lens/saes/__init__.py,sha256=jVwazK8Q6dW5J6_zFXPoNAuBvSxgziQ8eMOjGM3t-X8,1475
 sae_lens/saes/batchtopk_sae.py,sha256=CyaFG2hMyyDaEaXXrAMJC8wQDW1JoddTKF5mvxxBQKY,3395
 sae_lens/saes/gated_sae.py,sha256=qcmM9JwBA8aZR8z_IRHV1_gQX-q_63tKewWXRnhdXuo,8986
 sae_lens/saes/jumprelu_sae.py,sha256=3xkhBcCol2mEpIBLceymCpudocm2ypOjTeTXbpiXoA4,10794
-sae_lens/saes/sae.py,sha256=
+sae_lens/saes/sae.py,sha256=gdUZuLaOHQrPjbDj-nZI813B6-_mNAnV9i9z4qTnpHk,38255
 sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
 sae_lens/saes/topk_sae.py,sha256=CXMBI6CFvI5829bOhoQ350VXR9d8uFHUDlULTIWHXoU,8686
+sae_lens/saes/transcoder.py,sha256=qbsvIeRy7M1nuHapDut0URNg9fR-UtfBvwW4_hgsAac,11963
 sae_lens/tokenization_and_batching.py,sha256=now7caLbU3p-iGokNwmqZDyIvxYoXgnG1uklhgiLZN4,4656
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
@@ -32,7 +33,7 @@ sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
 sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
-sae_lens-6.3.1.dist-info/LICENSE,sha256=
-sae_lens-6.3.1.dist-info/METADATA,sha256=
-sae_lens-6.3.1.dist-info/WHEEL,sha256=
-sae_lens-6.3.1.dist-info/RECORD,,
+sae_lens-6.5.0.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.5.0.dist-info/METADATA,sha256=kzfWl_gUihYKkxaGbDBk3dLPNGAbjJ4NgnlcJRhUf0M,5555
+sae_lens-6.5.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+sae_lens-6.5.0.dist-info/RECORD,,

{sae_lens-6.3.1.dist-info → sae_lens-6.5.0.dist-info}/LICENSE
File without changes

{sae_lens-6.3.1.dist-info → sae_lens-6.5.0.dist-info}/WHEEL
File without changes