sae-lens 6.0.0rc4__tar.gz → 6.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/PKG-INFO +2 -2
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/README.md +1 -1
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/pyproject.toml +1 -1
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/__init__.py +8 -1
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/config.py +0 -23
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/constants.py +1 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/evals.py +0 -11
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/loading/pretrained_sae_loaders.py +154 -2
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/pretrained_saes.yaml +12 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/saes/__init__.py +6 -0
- sae_lens-6.1.0/sae_lens/saes/batchtopk_sae.py +102 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/saes/gated_sae.py +13 -20
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/saes/jumprelu_sae.py +1 -37
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/saes/sae.py +71 -49
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/saes/standard_sae.py +1 -9
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/saes/topk_sae.py +18 -48
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/training/activations_store.py +0 -15
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/training/optim.py +0 -33
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/training/sae_trainer.py +4 -2
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/LICENSE +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/util.py +0 -0
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: sae-lens
-Version: 6.0.0rc4
+Version: 6.1.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -80,7 +80,7 @@ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page

 ## Join the Slack!

-Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-
+Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!

 ## Citation

{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/README.md

@@ -40,7 +40,7 @@ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page

 ## Join the Slack!

-Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-
+Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!

 ## Citation

{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/__init__.py

@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.0.0rc4"
+__version__ = "6.1.0"

 import logging

@@ -7,6 +7,8 @@ logger = logging.getLogger(__name__)

 from sae_lens.saes import (
     SAE,
+    BatchTopKTrainingSAE,
+    BatchTopKTrainingSAEConfig,
     GatedSAE,
     GatedSAEConfig,
     GatedTrainingSAE,
@@ -85,6 +87,8 @@ __all__ = [
     "JumpReLUTrainingSAEConfig",
     "SAETrainingRunner",
     "LoggingConfig",
+    "BatchTopKTrainingSAE",
+    "BatchTopKTrainingSAEConfig",
 ]


@@ -96,3 +100,6 @@ register_sae_class("topk", TopKSAE, TopKSAEConfig)
 register_sae_training_class("topk", TopKTrainingSAE, TopKTrainingSAEConfig)
 register_sae_class("jumprelu", JumpReLUSAE, JumpReLUSAEConfig)
 register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAEConfig)
+register_sae_training_class(
+    "batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
+)
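The new BatchTopK training classes are now part of the public API and are registered under the "batchtopk" training architecture. A minimal, illustrative import using only what the hunks above add:

```python
# Illustrative: new public exports in sae-lens 6.1.0
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
```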
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/config.py

@@ -1,6 +1,5 @@
 import json
 import math
-import os
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
@@ -353,28 +352,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         d["act_store_device"] = str(self.act_store_device)
         return d

-    def to_json(self, path: str) -> None:
-        if not os.path.exists(os.path.dirname(path)):
-            os.makedirs(os.path.dirname(path))
-
-        with open(path + "cfg.json", "w") as f:
-            json.dump(self.to_dict(), f, indent=2)
-
-    @classmethod
-    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig[Any]":
-        with open(path + "cfg.json") as f:
-            cfg = json.load(f)
-
-        # ensure that seqpos slices is a tuple
-        # Ensure seqpos_slice is a tuple
-        if "seqpos_slice" in cfg:
-            if isinstance(cfg["seqpos_slice"], list):
-                cfg["seqpos_slice"] = tuple(cfg["seqpos_slice"])
-            elif not isinstance(cfg["seqpos_slice"], tuple):
-                cfg["seqpos_slice"] = (cfg["seqpos_slice"],)
-
-        return cls(**cfg)
-
     def to_sae_trainer_config(self) -> "SAETrainerConfig":
         return SAETrainerConfig(
             n_checkpoints=self.n_checkpoints,
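With `to_json`/`from_json` gone from `LanguageModelSAERunnerConfig`, a config can still be round-tripped through `to_dict()`, which remains in 6.1.0. A minimal sketch, assuming only that `to_dict()` returns JSON-serializable values and mirroring the `seqpos_slice` tuple coercion the removed `from_json` performed; the helper names here are hypothetical:

```python
import json
from pathlib import Path

def save_runner_cfg(cfg, path: str) -> None:
    # cfg: a LanguageModelSAERunnerConfig; to_dict() still exists in 6.1.0
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as f:
        json.dump(cfg.to_dict(), f, indent=2)

def load_runner_cfg_dict(path: str) -> dict:
    with open(path) as f:
        cfg = json.load(f)
    # mirror the tuple coercion the removed from_json applied
    if isinstance(cfg.get("seqpos_slice"), list):
        cfg["seqpos_slice"] = tuple(cfg["seqpos_slice"])
    return cfg
```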
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/constants.py

@@ -16,5 +16,6 @@ SPARSITY_FILENAME = "sparsity.safetensors"
 SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
 SAE_CFG_FILENAME = "cfg.json"
 RUNNER_CFG_FILENAME = "runner_cfg.json"
+SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
 ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
 ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/evals.py

@@ -769,17 +769,6 @@ def nested_dict() -> defaultdict[Any, Any]:
     return defaultdict(nested_dict)


-def dict_to_nested(flat_dict: dict[str, Any]) -> defaultdict[Any, Any]:
-    nested = nested_dict()
-    for key, value in flat_dict.items():
-        parts = key.split("/")
-        d = nested
-        for part in parts[:-1]:
-            d = d[part]
-        d[parts[-1]] = value
-    return nested
-
-
 def multiple_evals(
     sae_regex_pattern: str,
     sae_block_pattern: str,
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/loading/pretrained_sae_loaders.py

@@ -16,6 +16,7 @@ from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
+    SPARSIFY_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
 )
 from sae_lens.loading.pretrained_saes_directory import (
@@ -248,7 +249,7 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
     config_class = get_sae_class(architecture)[1]

     sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
-    if architecture == "topk":
+    if architecture == "topk" and "activation_fn_kwargs" in new_cfg:
         sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]

     sae_cfg_dict["metadata"] = {
@@ -530,11 +531,20 @@ def get_llama_scope_config_from_hf(
     # Model specific parameters
     model_name, d_in = "meta-llama/Llama-3.1-8B", old_cfg_dict["d_model"]

+    # Get norm scaling factor to rescale jumprelu threshold.
+    # We need this because sae.fold_activation_norm_scaling_factor folds scaling norm into W_enc.
+    # This requires jumprelu threshold to be scaled in the same way
+    norm_scaling_factor = (
+        d_in**0.5 / old_cfg_dict["dataset_average_activation_norm"]["in"]
+    )
+
     cfg_dict = {
         "architecture": "jumprelu",
-        "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"]
+        "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"]
+        * norm_scaling_factor,
         # We use a scalar jump_relu_threshold for all features
         # This is different from Gemma Scope JumpReLU SAEs.
+        # Scaled with norm_scaling_factor to match sae.fold_activation_norm_scaling_factor
         "d_in": d_in,
         "d_sae": old_cfg_dict["d_sae"],
         "dtype": "bfloat16",
@@ -942,6 +952,146 @@ def llama_scope_r1_distill_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity


+def get_sparsify_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    cfg_filename = f"{folder_name}/{SAE_CFG_FILENAME}"
+    cfg_path = hf_hub_download(
+        repo_id,
+        filename=cfg_filename,
+        force_download=force_download,
+    )
+    sae_path = Path(cfg_path).parent
+    return get_sparsify_config_from_disk(
+        sae_path, device=device, cfg_overrides=cfg_overrides
+    )
+
+
+def get_sparsify_config_from_disk(
+    path: str | Path,
+    device: str | None = None,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    path = Path(path)
+
+    with open(path / SAE_CFG_FILENAME) as f:
+        old_cfg_dict = json.load(f)
+
+    config_path = path.parent / "config.json"
+    if config_path.exists():
+        with open(config_path) as f:
+            config_dict = json.load(f)
+    else:
+        config_dict = {}
+
+    folder_name = path.name
+    if folder_name == "embed_tokens":
+        hook_name, layer = "hook_embed", 0
+    else:
+        match = re.search(r"layers[._](\d+)", folder_name)
+        if match is None:
+            raise ValueError(f"Unrecognized Sparsify folder: {folder_name}")
+        layer = int(match.group(1))
+        hook_name = f"blocks.{layer}.hook_resid_post"
+
+    cfg_dict: dict[str, Any] = {
+        "architecture": "standard",
+        "d_in": old_cfg_dict["d_in"],
+        "d_sae": old_cfg_dict["d_in"] * old_cfg_dict["expansion_factor"],
+        "dtype": "bfloat16",
+        "device": device or "cpu",
+        "model_name": config_dict.get("model", path.parts[-2]),
+        "hook_name": hook_name,
+        "hook_layer": layer,
+        "hook_head_index": None,
+        "activation_fn_str": "topk",
+        "activation_fn_kwargs": {
+            "k": old_cfg_dict["k"],
+            "signed": old_cfg_dict.get("signed", False),
+        },
+        "apply_b_dec_to_input": not old_cfg_dict.get("normalize_decoder", False),
+        "dataset_path": config_dict.get(
+            "dataset", "togethercomputer/RedPajama-Data-1T-Sample"
+        ),
+        "context_size": config_dict.get("ctx_len", 2048),
+        "finetuning_scaling_factor": False,
+        "sae_lens_training_version": None,
+        "prepend_bos": True,
+        "dataset_trust_remote_code": True,
+        "normalize_activations": "none",
+        "neuronpedia_id": None,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def sparsify_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], None]:
+    weights_filename = f"{folder_name}/{SPARSIFY_WEIGHTS_FILENAME}"
+    sae_path = hf_hub_download(
+        repo_id,
+        filename=weights_filename,
+        force_download=force_download,
+    )
+    cfg_dict, state_dict = sparsify_disk_loader(
+        Path(sae_path).parent, device=device, cfg_overrides=cfg_overrides
+    )
+    return cfg_dict, state_dict, None
+
+
+def sparsify_disk_loader(
+    path: str | Path,
+    device: str = "cpu",
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
+    cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)
+
+    weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
+    state_dict_loaded = load_file(weight_path, device=device)
+
+    dtype = DTYPE_MAP[cfg_dict["dtype"]]
+
+    W_enc = (
+        state_dict_loaded["W_enc"]
+        if "W_enc" in state_dict_loaded
+        else state_dict_loaded["encoder.weight"].T
+    ).to(dtype)
+
+    if "W_dec" in state_dict_loaded:
+        W_dec = state_dict_loaded["W_dec"].T.to(dtype)
+    else:
+        W_dec = state_dict_loaded["decoder.weight"].T.to(dtype)
+
+    if "b_enc" in state_dict_loaded:
+        b_enc = state_dict_loaded["b_enc"].to(dtype)
+    elif "encoder.bias" in state_dict_loaded:
+        b_enc = state_dict_loaded["encoder.bias"].to(dtype)
+    else:
+        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+
+    if "b_dec" in state_dict_loaded:
+        b_dec = state_dict_loaded["b_dec"].to(dtype)
+    elif "decoder.bias" in state_dict_loaded:
+        b_dec = state_dict_loaded["decoder.bias"].to(dtype)
+    else:
+        b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+
+    state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
+    return cfg_dict, state_dict
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -950,6 +1100,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
     "deepseek_r1": deepseek_r1_sae_huggingface_loader,
+    "sparsify": sparsify_huggingface_loader,
 }


@@ -961,4 +1112,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
     "deepseek_r1": get_deepseek_r1_config_from_hf,
+    "sparsify": get_sparsify_config_from_hf,
 }
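The new loader can also be called directly. A minimal sketch; the `repo_id` and `folder_name` below are hypothetical placeholders for a repository in the EleutherAI sparsify layout, and the call downloads from the Hugging Face Hub:

```python
# Illustrative use of the "sparsify" loader added in 6.1.0.
from sae_lens.loading.pretrained_sae_loaders import sparsify_huggingface_loader

cfg_dict, state_dict, _ = sparsify_huggingface_loader(
    repo_id="EleutherAI/example-sparsify-saes",  # hypothetical repo
    folder_name="layers.5",  # parsed into hook blocks.5.hook_resid_post
    device="cpu",
)
print(cfg_dict["architecture"], state_dict["W_enc"].shape)
```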
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/pretrained_saes.yaml

@@ -13634,39 +13634,51 @@ gemma-2-2b-res-matryoshka-dc:
   - id: blocks.13.hook_resid_post
     path: standard/blocks.13.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/13-res-matryoshka-dc
   - id: blocks.14.hook_resid_post
     path: standard/blocks.14.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/14-res-matryoshka-dc
   - id: blocks.15.hook_resid_post
     path: standard/blocks.15.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/15-res-matryoshka-dc
   - id: blocks.16.hook_resid_post
     path: standard/blocks.16.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/16-res-matryoshka-dc
   - id: blocks.17.hook_resid_post
     path: standard/blocks.17.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/17-res-matryoshka-dc
   - id: blocks.18.hook_resid_post
     path: standard/blocks.18.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/18-res-matryoshka-dc
   - id: blocks.19.hook_resid_post
     path: standard/blocks.19.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/19-res-matryoshka-dc
   - id: blocks.20.hook_resid_post
     path: standard/blocks.20.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/20-res-matryoshka-dc
   - id: blocks.21.hook_resid_post
     path: standard/blocks.21.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/21-res-matryoshka-dc
   - id: blocks.22.hook_resid_post
     path: standard/blocks.22.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/22-res-matryoshka-dc
   - id: blocks.23.hook_resid_post
     path: standard/blocks.23.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/23-res-matryoshka-dc
   - id: blocks.24.hook_resid_post
     path: standard/blocks.24.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/24-res-matryoshka-dc
 gemma-2-2b-res-snap-matryoshka-dc:
   conversion_func: null
   links:
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/saes/__init__.py

@@ -1,3 +1,7 @@
+from .batchtopk_sae import (
+    BatchTopKTrainingSAE,
+    BatchTopKTrainingSAEConfig,
+)
 from .gated_sae import (
     GatedSAE,
     GatedSAEConfig,
@@ -45,4 +49,6 @@ __all__ = [
     "TopKSAEConfig",
     "TopKTrainingSAE",
     "TopKTrainingSAEConfig",
+    "BatchTopKTrainingSAE",
+    "BatchTopKTrainingSAEConfig",
 ]
sae_lens-6.1.0/sae_lens/saes/batchtopk_sae.py (new file)

@@ -0,0 +1,102 @@
+from dataclasses import dataclass
+from typing import Any, Callable
+
+import torch
+import torch.nn as nn
+from typing_extensions import override
+
+from sae_lens.saes.jumprelu_sae import JumpReLUSAEConfig
+from sae_lens.saes.sae import SAEConfig, TrainStepInput, TrainStepOutput
+from sae_lens.saes.topk_sae import TopKTrainingSAE, TopKTrainingSAEConfig
+
+
+class BatchTopK(nn.Module):
+    """BatchTopK activation function"""
+
+    def __init__(
+        self,
+        k: int,
+    ):
+        super().__init__()
+        self.k = k
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        acts = x.relu()
+        flat_acts = acts.flatten()
+        acts_topk_flat = torch.topk(flat_acts, self.k * acts.shape[0], dim=-1)
+        return (
+            torch.zeros_like(flat_acts)
+            .scatter(-1, acts_topk_flat.indices, acts_topk_flat.values)
+            .reshape(acts.shape)
+        )
+
+
+@dataclass
+class BatchTopKTrainingSAEConfig(TopKTrainingSAEConfig):
+    """
+    Configuration class for training a BatchTopKTrainingSAE.
+    """
+
+    topk_threshold_lr: float = 0.01
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "batchtopk"
+
+    @override
+    def get_inference_config_class(self) -> type[SAEConfig]:
+        return JumpReLUSAEConfig
+
+
+class BatchTopKTrainingSAE(TopKTrainingSAE):
+    """
+    Global Batch TopK Training SAE
+
+    This SAE will maintain the k on average across the batch, rather than enforcing the k per-sample as in standard TopK.
+
+    BatchTopK SAEs are saved as JumpReLU SAEs after training.
+    """
+
+    topk_threshold: torch.Tensor
+    cfg: BatchTopKTrainingSAEConfig  # type: ignore[assignment]
+
+    def __init__(self, cfg: BatchTopKTrainingSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+
+        self.register_buffer(
+            "topk_threshold",
+            # use double precision as otherwise we can run into numerical issues
+            torch.tensor(0.0, dtype=torch.double, device=self.W_dec.device),
+        )
+
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return BatchTopK(self.cfg.k)
+
+    @override
+    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
+        output = super().training_forward_pass(step_input)
+        self.update_topk_threshold(output.feature_acts)
+        output.metrics["topk_threshold"] = self.topk_threshold
+        return output
+
+    @torch.no_grad()
+    def update_topk_threshold(self, acts_topk: torch.Tensor) -> None:
+        positive_mask = acts_topk > 0
+        lr = self.cfg.topk_threshold_lr
+        # autocast can cause numerical issues with the threshold update
+        with torch.autocast(self.topk_threshold.device.type, enabled=False):
+            if positive_mask.any():
+                min_positive = (
+                    acts_topk[positive_mask].min().to(self.topk_threshold.dtype)
+                )
+                self.topk_threshold = (1 - lr) * self.topk_threshold + lr * min_positive
+
+    @override
+    def process_state_dict_for_saving_inference(
+        self, state_dict: dict[str, Any]
+    ) -> None:
+        super().process_state_dict_for_saving_inference(state_dict)
+        # turn the topk threshold into jumprelu threshold
+        topk_threshold = state_dict.pop("topk_threshold").item()
+        state_dict["threshold"] = torch.ones_like(self.b_enc) * topk_threshold
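The BatchTopK activation keeps roughly k active features per sample on average: it selects the k * batch_size largest post-ReLU values across the flattened batch, so individual rows may keep more or fewer than k. A small, self-contained sketch of that behaviour using the module defined above:

```python
import torch
from sae_lens.saes.batchtopk_sae import BatchTopK

acts = torch.randn(4, 16)       # batch of 4 samples, 16 features
out = BatchTopK(k=3)(acts)      # keeps the 3 * 4 = 12 largest values across the whole batch
print((out != 0).sum().item())  # at most 12 nonzeros in total; per-row counts can vary
```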
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/saes/gated_sae.py

@@ -15,7 +15,6 @@ from sae_lens.saes.sae import (
     TrainingSAEConfig,
     TrainStepInput,
 )
-from sae_lens.util import filter_valid_dataclass_fields


 @dataclass
@@ -100,16 +99,10 @@ class GatedSAE(SAE[GatedSAEConfig]):
         self.W_enc.data = self.W_enc.data * W_dec_norms.T

         # Gated-specific parameters need special handling
-
+        # r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path
         self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
         self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()

-    @torch.no_grad()
-    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
-        """Initialize decoder with constant norm."""
-        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-        self.W_dec.data *= norm
-

 @dataclass
 class GatedTrainingSAEConfig(TrainingSAEConfig):
@@ -133,7 +126,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
     - encode: calls encode_with_hidden_pre (standard training approach).
     - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
-    - encode_with_hidden_pre: gating logic
+    - encode_with_hidden_pre: gating logic.
     - calculate_aux_loss: includes an auxiliary reconstruction path and gating-based sparsity penalty.
     - training_forward_pass: calls encode_with_hidden_pre, decode, and sums up MSE + gating losses.
     """
@@ -158,7 +151,6 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
         """
         Gated forward pass with pre-activation (for training).
-        We also inject noise if self.training is True.
         """
         sae_in = self.process_sae_in(x)

@@ -219,12 +211,6 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
             "weights/b_mag": b_mag_dist,
         }

-    @torch.no_grad()
-    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
-        """Initialize decoder with constant norm"""
-        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-        self.W_dec.data *= norm
-
     def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
         return {
             "l1": TrainCoefficientConfig(
@@ -233,10 +219,17 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
             ),
         }

-
-
-
-    )
+    @torch.no_grad()
+    def fold_W_dec_norm(self):
+        """Override to handle gated-specific parameters."""
+        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        self.W_dec.data = self.W_dec.data / W_dec_norms
+        self.W_enc.data = self.W_enc.data * W_dec_norms.T
+
+        # Gated-specific parameters need special handling
+        # r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path
+        self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
+        self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()


 def _init_weights_gated(
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/saes/jumprelu_sae.py

@@ -14,9 +14,7 @@ from sae_lens.saes.sae import (
     TrainingSAE,
     TrainingSAEConfig,
     TrainStepInput,
-    TrainStepOutput,
 )
-from sae_lens.util import filter_valid_dataclass_fields


 def rectangle(x: torch.Tensor) -> torch.Tensor:
@@ -208,12 +206,11 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):

     Similar to the inference-only JumpReLUSAE, but with:
     - A learnable log-threshold parameter (instead of a raw threshold).
-    - Forward passes that add noise during training, if configured.
     - A specialized auxiliary loss term for sparsity (L0 or similar).

     Methods of interest include:
     - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
-    - encode_with_hidden_pre_jumprelu: runs a forward pass for training
+    - encode_with_hidden_pre_jumprelu: runs a forward pass for training.
     - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.
     """

@@ -300,34 +297,6 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         # Fix: Use squeeze() instead of squeeze(-1) to match old behavior
         self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze())

-    def _create_train_step_output(
-        self,
-        sae_in: torch.Tensor,
-        sae_out: torch.Tensor,
-        feature_acts: torch.Tensor,
-        hidden_pre: torch.Tensor,
-        loss: torch.Tensor,
-        losses: dict[str, torch.Tensor],
-    ) -> TrainStepOutput:
-        """
-        Helper to produce a TrainStepOutput from the trainer.
-        The old code expects a method named _create_train_step_output().
-        """
-        return TrainStepOutput(
-            sae_in=sae_in,
-            sae_out=sae_out,
-            feature_acts=feature_acts,
-            hidden_pre=hidden_pre,
-            loss=loss,
-            losses=losses,
-        )
-
-    @torch.no_grad()
-    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
-        """Initialize decoder with constant norm"""
-        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-        self.W_dec.data *= norm
-
     def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
         """Convert log_threshold to threshold for saving"""
         if "log_threshold" in state_dict:
@@ -341,8 +310,3 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         threshold = state_dict["threshold"]
         del state_dict["threshold"]
         state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
-
-    def to_inference_config_dict(self) -> dict[str, Any]:
-        return filter_valid_dataclass_fields(
-            self.cfg.to_dict(), JumpReLUSAEConfig, ["architecture"]
-        )
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/saes/sae.py

@@ -27,7 +27,7 @@ from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
 from typing_extensions import deprecated, overload, override

-from sae_lens import __version__
+from sae_lens import __version__
 from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
@@ -207,6 +207,8 @@ class TrainStepOutput:
     hidden_pre: torch.Tensor
     loss: torch.Tensor  # we need to call backwards on this
     losses: dict[str, torch.Tensor]
+    # any extra metrics to log can be added here
+    metrics: dict[str, torch.Tensor | float | int] = field(default_factory=dict)


 @dataclass
@@ -528,28 +530,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

         return model_weights_path, cfg_path

-    ## Initialization Methods
-    @torch.no_grad()
-    def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
-        out = torch.tensor(origin, dtype=self.dtype, device=self.device)
-        self.b_dec.data = out
-
-    @torch.no_grad()
-    def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
-        previous_b_dec = self.b_dec.clone().cpu()
-        out = all_activations.mean(dim=0)
-
-        previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
-        distances = torch.norm(all_activations - out, dim=-1)
-
-        logger.info("Reinitializing b_dec with mean of activations")
-        logger.debug(
-            f"Previous distances: {previous_distances.median(0).values.mean().item()}"
-        )
-        logger.debug(f"New distances: {distances.median(0).values.mean().item()}")
-
-        self.b_dec.data = out.to(self.dtype).to(self.device)
-
     # Class methods for loading models
     @classmethod
     @deprecated("Use load_from_disk instead")
@@ -732,6 +712,64 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     ) -> type[SAEConfig]:
         return SAEConfig

+    ### Methods to support deprecated usage of SAE.from_pretrained() ###
+
+    def __getitem__(self, index: int) -> Any:
+        """
+        Support indexing for backward compatibility with tuple unpacking.
+        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
+        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
+        """
+        warnings.warn(
+            "Indexing SAE objects is deprecated. SAE.from_pretrained() now returns "
+            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
+            "to get the config dict and sparsity as well.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+        if index == 0:
+            return self
+        if index == 1:
+            return self.cfg.to_dict()
+        if index == 2:
+            return None
+        raise IndexError(f"SAE tuple index {index} out of range")
+
+    def __iter__(self):
+        """
+        Support unpacking for backward compatibility with tuple unpacking.
+        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
+        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
+        """
+        warnings.warn(
+            "Unpacking SAE objects is deprecated. SAE.from_pretrained() now returns "
+            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
+            "to get the config dict and sparsity as well.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+        yield self
+        yield self.cfg.to_dict()
+        yield None
+
+    def __len__(self) -> int:
+        """
+        Support len() for backward compatibility with tuple unpacking.
+        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
+        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
+        """
+        warnings.warn(
+            "Getting length of SAE objects is deprecated. SAE.from_pretrained() now returns "
+            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
+            "to get the config dict and sparsity as well.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+        return 3
+

 @dataclass(kw_only=True)
 class TrainingSAEConfig(SAEConfig, ABC):
@@ -789,20 +827,26 @@ class TrainingSAEConfig(SAEConfig, ABC):
             "architecture": self.architecture(),
         }

+    def get_inference_config_class(self) -> type[SAEConfig]:
+        """
+        Get the architecture for inference.
+        """
+        return get_sae_class(self.architecture())[1]
+
     # this needs to exist so we can initialize the parent sae cfg without the training specific
     # parameters. Maybe there's a cleaner way to do this
-    def
+    def get_inference_sae_cfg_dict(self) -> dict[str, Any]:
         """
         Creates a dictionary containing attributes corresponding to the fields
         defined in the base SAEConfig class.
         """
-        base_sae_cfg_class =
+        base_sae_cfg_class = self.get_inference_config_class()
         base_config_field_names = {f.name for f in fields(base_sae_cfg_class)}
         result_dict = {
             field_name: getattr(self, field_name)
             for field_name in base_config_field_names
         }
-        result_dict["architecture"] =
+        result_dict["architecture"] = base_sae_cfg_class.architecture()
         result_dict["metadata"] = self.metadata.to_dict()
         return result_dict

@@ -930,18 +974,13 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
         save_file(state_dict, model_weights_path)

         # Save the config
-        config = self.
+        config = self.cfg.get_inference_sae_cfg_dict()
         cfg_path = path / SAE_CFG_FILENAME
         with open(cfg_path, "w") as f:
             json.dump(config, f)

         return model_weights_path, cfg_path

-    @abstractmethod
-    def to_inference_config_dict(self) -> dict[str, Any]:
-        """Convert the config into an inference SAE config dict."""
-        ...
-
     def process_state_dict_for_saving_inference(
         self, state_dict: dict[str, Any]
     ) -> None:
@@ -951,23 +990,6 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
         """
         return self.process_state_dict_for_saving(state_dict)

-    @torch.no_grad()
-    def remove_gradient_parallel_to_decoder_directions(self) -> None:
-        """Remove gradient components parallel to decoder directions."""
-        # Implement the original logic since this may not be in the base class
-        assert self.W_dec.grad is not None
-
-        parallel_component = einops.einsum(
-            self.W_dec.grad,
-            self.W_dec.data,
-            "d_sae d_in, d_sae d_in -> d_sae",
-        )
-        self.W_dec.grad -= einops.einsum(
-            parallel_component,
-            self.W_dec.data,
-            "d_sae, d_sae d_in -> d_sae d_in",
-        )
-
     @torch.no_grad()
     def log_histograms(self) -> dict[str, NDArray[Any]]:
         """Log histograms of the weights and biases."""
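Because `SAE` now implements `__getitem__`, `__iter__`, and `__len__`, code written against the old tuple return of `SAE.from_pretrained()` keeps working but emits a `DeprecationWarning`; the warnings point to `from_pretrained_with_cfg_and_sparsity()` instead. An illustrative sketch (the release and SAE id are examples, not part of this diff):

```python
from sae_lens import SAE

# New style: from_pretrained() returns just the SAE.
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

# Old style still unpacks via the shims above, but warns.
sae, cfg_dict, sparsity = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

# Replacement named in the deprecation messages:
sae, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(
    "gpt2-small-res-jb", "blocks.8.hook_resid_pre"
)
```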
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/saes/standard_sae.py

@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import Any

 import numpy as np
 import torch
@@ -16,7 +15,6 @@ from sae_lens.saes.sae import (
     TrainingSAEConfig,
     TrainStepInput,
 )
-from sae_lens.util import filter_valid_dataclass_fields


 @dataclass
@@ -61,7 +59,6 @@ class StandardSAE(SAE[StandardSAEConfig]):
     ) -> Float[torch.Tensor, "... d_sae"]:
         """
         Encode the input tensor into the feature space.
-        For inference, no noise is added.
         """
         # Preprocess the SAE input (casting type, applying hooks, normalization)
         sae_in = self.process_sae_in(x)
@@ -110,7 +107,7 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
     - initialize_weights: basic weight initialization for encoder/decoder.
     - encode: inference encoding (invokes encode_with_hidden_pre).
     - decode: a simple linear decoder.
-    - encode_with_hidden_pre: computes
+    - encode_with_hidden_pre: computes activations and pre-activations.
     - calculate_aux_loss: computes a sparsity penalty based on the (optionally scaled) p-norm of feature activations.
     """

@@ -164,11 +161,6 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
             "weights/b_e": b_e_dist,
         }

-    def to_inference_config_dict(self) -> dict[str, Any]:
-        return filter_valid_dataclass_fields(
-            self.cfg.to_dict(), StandardSAEConfig, ["architecture"]
-        )
-

 def _init_weights_standard(
     sae: SAE[StandardSAEConfig] | TrainingSAE[StandardTrainingSAEConfig],
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/saes/topk_sae.py

@@ -1,7 +1,7 @@
 """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""

 from dataclasses import dataclass
-from typing import
+from typing import Callable

 import torch
 from jaxtyping import Float
@@ -16,13 +16,12 @@ from sae_lens.saes.sae import (
     TrainingSAEConfig,
     TrainStepInput,
 )
-from sae_lens.util import filter_valid_dataclass_fields


 class TopK(nn.Module):
     """
     A simple TopK activation that zeroes out all but the top K elements along the last dimension,
-
+    and applies ReLU to the top K elements.
     """

     b_enc: nn.Parameter
@@ -30,20 +29,18 @@ class TopK(nn.Module):
     def __init__(
         self,
         k: int,
-        postact_fn: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
     ):
         super().__init__()
         self.k = k
-        self.postact_fn = postact_fn

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         1) Select top K elements along the last dimension.
-        2) Apply
+        2) Apply ReLU.
         3) Zero out all other entries.
         """
         topk = torch.topk(x, k=self.k, dim=-1)
-        values =
+        values = topk.values.relu()
         result = torch.zeros_like(x)
         result.scatter_(-1, topk.indices, values)
         return result
@@ -139,8 +136,7 @@ class TopKTrainingSAEConfig(TrainingSAEConfig):

 class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     """
-    TopK variant with training functionality.
-    calculates a topk-related auxiliary loss, etc.
+    TopK variant with training functionality. Calculates a topk-related auxiliary loss, etc.
     """

     b_enc: nn.Parameter
@@ -157,7 +153,7 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         self, x: Float[torch.Tensor, "... d_in"]
     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
         """
-        Similar to the base training method:
+        Similar to the base training method: calculate pre-activations, then apply TopK.
         """
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
@@ -237,50 +233,24 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
         return scale * auxk_loss

-    def _calculate_topk_aux_acts(
-        self,
-        k_aux: int,
-        hidden_pre: torch.Tensor,
-        dead_neuron_mask: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Helper method to calculate activations for the auxiliary loss.
-
-        Args:
-            k_aux: Number of top dead neurons to select
-            hidden_pre: Pre-activation values from encoder
-            dead_neuron_mask: Boolean mask indicating which neurons are dead
-
-        Returns:
-            Tensor with activations for only the top-k dead neurons, zeros elsewhere
-        """
-        # Don't include living latents in this loss (set them to -inf so they won't be selected)
-        auxk_latents = torch.where(
-            dead_neuron_mask[None],
-            hidden_pre,
-            torch.tensor(-float("inf"), device=hidden_pre.device),
-        )
-
-        # Find topk values among dead neurons
-        auxk_topk = auxk_latents.topk(k_aux, dim=-1, sorted=False)
-
-        # Create a tensor of zeros, then place the topk values at their proper indices
-        auxk_acts = torch.zeros_like(hidden_pre)
-        auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
-
-        return auxk_acts
-
-    def to_inference_config_dict(self) -> dict[str, Any]:
-        return filter_valid_dataclass_fields(
-            self.cfg.to_dict(), TopKSAEConfig, ["architecture"]
-        )
-

 def _calculate_topk_aux_acts(
     k_aux: int,
     hidden_pre: torch.Tensor,
     dead_neuron_mask: torch.Tensor,
 ) -> torch.Tensor:
+    """
+    Helper method to calculate activations for the auxiliary loss.
+
+    Args:
+        k_aux: Number of top dead neurons to select
+        hidden_pre: Pre-activation values from encoder
+        dead_neuron_mask: Boolean mask indicating which neurons are dead
+
+    Returns:
+        Tensor with activations for only the top-k dead neurons, zeros elsewhere
+    """
+
     # Don't include living latents in this loss
     auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
     # Top-k dead latents
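After this change the `TopK` module always applies ReLU to the selected values rather than a configurable `postact_fn`. A small sketch of the resulting behaviour:

```python
import torch
from sae_lens.saes.topk_sae import TopK

x = torch.randn(2, 10)
out = TopK(k=3)(x)             # per-row top-3, with ReLU applied to the kept values
print((out != 0).sum(dim=-1))  # at most 3 nonzeros per row (fewer if kept values were negative)
```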
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/training/activations_store.py

@@ -7,7 +7,6 @@ from collections.abc import Generator, Iterator, Sequence
 from typing import Any, Literal, cast

 import datasets
-import numpy as np
 import torch
 from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
 from huggingface_hub import hf_hub_download
@@ -420,20 +419,6 @@ class ActivationsStore:

         return activations_dataset

-    @torch.no_grad()
-    def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)):
-        norms_per_batch = []
-        for _ in tqdm(
-            range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
-        ):
-            # temporalily set estimated_norm_scaling_factor to 1.0 so the dataloader works
-            self.estimated_norm_scaling_factor = 1.0
-            acts = self.next_batch()[0]
-            self.estimated_norm_scaling_factor = None
-            norms_per_batch.append(acts.norm(dim=-1).mean().item())
-        mean_norm = np.mean(norms_per_batch)
-        return np.sqrt(self.d_in) / mean_norm
-
     def shuffle_input_dataset(self, seed: int, buffer_size: int = 1):
         """
         This applies a shuffle to the huggingface dataset that is the input to the activations store. This
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/training/optim.py

@@ -2,8 +2,6 @@
 Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
 """

-from typing import Any
-
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler

@@ -152,34 +150,3 @@ class CoefficientScheduler:
     def value(self) -> float:
         """Returns the current scalar value."""
         return self.current_value
-
-    def state_dict(self) -> dict[str, Any]:
-        """State dict for serialization."""
-        return {
-            "warm_up_steps": self.warm_up_steps,
-            "final_value": self.final_value,
-            "current_step": self.current_step,
-            "current_value": self.current_value,
-        }
-
-    def load_state_dict(self, state_dict: dict[str, Any]):
-        """Loads the scheduler state."""
-        self.warm_up_steps = state_dict["warm_up_steps"]
-        self.final_value = state_dict["final_value"]
-        self.current_step = state_dict["current_step"]
-        # Maintain consistency: re-calculate current_value based on loaded step
-        # This handles resuming correctly if stopped mid-warmup.
-        if self.current_step <= self.warm_up_steps and self.warm_up_steps > 0:
-            # Use max(0, ...) to handle case where current_step might be loaded as -1 or similar before first step
-            step_for_calc = max(0, self.current_step)
-            # Recalculate based on the step *before* the one about to be taken
-            # Or simply use the saved current_value if available and consistent
-            if "current_value" in state_dict:
-                self.current_value = state_dict["current_value"]
-            else:  # Legacy state dicts might not have current_value
-                self.current_value = self.final_value * (
-                    step_for_calc / self.warm_up_steps
-                )
-
-        else:
-            self.current_value = self.final_value
{sae_lens-6.0.0rc4 → sae_lens-6.1.0}/sae_lens/training/sae_trainer.py

@@ -349,8 +349,10 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
             },
         }
         for loss_name, loss_value in output.losses.items():
-
-
+            log_dict[f"losses/{loss_name}"] = _unwrap_item(loss_value)
+
+        for metric_name, metric_value in output.metrics.items():
+            log_dict[f"metrics/{metric_name}"] = _unwrap_item(metric_value)

         return log_dict

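With the new `TrainStepOutput.metrics` field, anything a training SAE adds there is logged by `SAETrainer` under a `metrics/` prefix, which is exactly how `BatchTopKTrainingSAE` surfaces its `topk_threshold`. A minimal sketch of the pattern; the subclass and metric name are hypothetical:

```python
from typing_extensions import override

from sae_lens.saes.sae import TrainStepInput, TrainStepOutput
from sae_lens.saes.standard_sae import StandardTrainingSAE

class MyTrainingSAE(StandardTrainingSAE):  # hypothetical subclass
    @override
    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
        output = super().training_forward_pass(step_input)
        # will appear in the trainer's log dict as "metrics/mean_feature_act"
        output.metrics["mean_feature_act"] = output.feature_acts.mean()
        return output
```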