sae-lens 6.0.0rc4.tar.gz → 6.0.0rc5.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/PKG-INFO +2 -2
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/README.md +1 -1
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/pyproject.toml +1 -1
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/__init__.py +1 -1
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/constants.py +1 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/evals.py +0 -11
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/loading/pretrained_sae_loaders.py +154 -2
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/pretrained_saes.yaml +12 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/saes/sae.py +58 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/training/activations_store.py +1 -1
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/LICENSE +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/config.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/load_model.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/registry.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/training/types.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.0.0rc4 → sae_lens-6.0.0rc5}/sae_lens/util.py +0 -0
PKG-INFO:

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: sae-lens
-Version: 6.0.0rc4
+Version: 6.0.0rc5
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -80,7 +80,7 @@ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page
 
 ## Join the Slack!
 
-Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-
+Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
 
 ## Citation
```
README.md:

```diff
@@ -40,7 +40,7 @@ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page
 
 ## Join the Slack!
 
-Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-
+Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
 
 ## Citation
```
sae_lens/constants.py:

```diff
@@ -16,5 +16,6 @@ SPARSITY_FILENAME = "sparsity.safetensors"
 SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
 SAE_CFG_FILENAME = "cfg.json"
 RUNNER_CFG_FILENAME = "runner_cfg.json"
+SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
 ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
 ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
```
sae_lens/evals.py:

```diff
@@ -769,17 +769,6 @@ def nested_dict() -> defaultdict[Any, Any]:
     return defaultdict(nested_dict)
 
 
-def dict_to_nested(flat_dict: dict[str, Any]) -> defaultdict[Any, Any]:
-    nested = nested_dict()
-    for key, value in flat_dict.items():
-        parts = key.split("/")
-        d = nested
-        for part in parts[:-1]:
-            d = d[part]
-        d[parts[-1]] = value
-    return nested
-
-
 def multiple_evals(
     sae_regex_pattern: str,
     sae_block_pattern: str,
```
sae_lens/loading/pretrained_sae_loaders.py:

```diff
@@ -16,6 +16,7 @@ from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
+    SPARSIFY_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
 )
 from sae_lens.loading.pretrained_saes_directory import (
```
sae_lens/loading/pretrained_sae_loaders.py:

```diff
@@ -248,7 +249,7 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
     config_class = get_sae_class(architecture)[1]
 
     sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
-    if architecture == "topk":
+    if architecture == "topk" and "activation_fn_kwargs" in new_cfg:
         sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
 
     sae_cfg_dict["metadata"] = {
```
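For illustration, a minimal sketch of what the extra check guards against; the config dict below is hypothetical, not one shipped with the library:

```python
# Hypothetical pre-6.0 style topk config that never stored activation_fn_kwargs.
new_cfg = {"architecture": "topk", "d_in": 768, "d_sae": 12288}
sae_cfg_dict: dict = {}

# Mirrors the guarded lookup above: k is only copied when the kwargs exist.
if new_cfg["architecture"] == "topk" and "activation_fn_kwargs" in new_cfg:
    sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
# Without the added check, new_cfg["activation_fn_kwargs"] would raise KeyError here.
```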
sae_lens/loading/pretrained_sae_loaders.py:

```diff
@@ -530,11 +531,20 @@ def get_llama_scope_config_from_hf(
     # Model specific parameters
     model_name, d_in = "meta-llama/Llama-3.1-8B", old_cfg_dict["d_model"]
 
+    # Get norm scaling factor to rescale jumprelu threshold.
+    # We need this because sae.fold_activation_norm_scaling_factor folds scaling norm into W_enc.
+    # This requires jumprelu threshold to be scaled in the same way
+    norm_scaling_factor = (
+        d_in**0.5 / old_cfg_dict["dataset_average_activation_norm"]["in"]
+    )
+
     cfg_dict = {
         "architecture": "jumprelu",
-        "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"],
+        "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"]
+        * norm_scaling_factor,
         # We use a scalar jump_relu_threshold for all features
         # This is different from Gemma Scope JumpReLU SAEs.
+        # Scaled with norm_scaling_factor to match sae.fold_activation_norm_scaling_factor
         "d_in": d_in,
         "d_sae": old_cfg_dict["d_sae"],
         "dtype": "bfloat16",
```
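For reference, the rescaling applied above can be sketched in isolation. This is a minimal sketch, not the loader itself; the numeric values are hypothetical stand-ins for fields read from the Llama Scope config:

```python
# fold_activation_norm_scaling_factor() folds sqrt(d_in) / E[||x||] into W_enc,
# so a JumpReLU threshold defined on normalized activations must be multiplied
# by the same factor to keep the gate consistent after folding.
d_in = 4096            # hypothetical d_model
avg_norm_in = 130.0    # hypothetical dataset_average_activation_norm["in"]
raw_threshold = 0.05   # hypothetical jump_relu_threshold from the repo config

norm_scaling_factor = d_in**0.5 / avg_norm_in
folded_threshold = raw_threshold * norm_scaling_factor
print(folded_threshold)
```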
sae_lens/loading/pretrained_sae_loaders.py:

```diff
@@ -942,6 +952,146 @@ def llama_scope_r1_distill_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity
 
 
+def get_sparsify_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    cfg_filename = f"{folder_name}/{SAE_CFG_FILENAME}"
+    cfg_path = hf_hub_download(
+        repo_id,
+        filename=cfg_filename,
+        force_download=force_download,
+    )
+    sae_path = Path(cfg_path).parent
+    return get_sparsify_config_from_disk(
+        sae_path, device=device, cfg_overrides=cfg_overrides
+    )
+
+
+def get_sparsify_config_from_disk(
+    path: str | Path,
+    device: str | None = None,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    path = Path(path)
+
+    with open(path / SAE_CFG_FILENAME) as f:
+        old_cfg_dict = json.load(f)
+
+    config_path = path.parent / "config.json"
+    if config_path.exists():
+        with open(config_path) as f:
+            config_dict = json.load(f)
+    else:
+        config_dict = {}
+
+    folder_name = path.name
+    if folder_name == "embed_tokens":
+        hook_name, layer = "hook_embed", 0
+    else:
+        match = re.search(r"layers[._](\d+)", folder_name)
+        if match is None:
+            raise ValueError(f"Unrecognized Sparsify folder: {folder_name}")
+        layer = int(match.group(1))
+        hook_name = f"blocks.{layer}.hook_resid_post"
+
+    cfg_dict: dict[str, Any] = {
+        "architecture": "standard",
+        "d_in": old_cfg_dict["d_in"],
+        "d_sae": old_cfg_dict["d_in"] * old_cfg_dict["expansion_factor"],
+        "dtype": "bfloat16",
+        "device": device or "cpu",
+        "model_name": config_dict.get("model", path.parts[-2]),
+        "hook_name": hook_name,
+        "hook_layer": layer,
+        "hook_head_index": None,
+        "activation_fn_str": "topk",
+        "activation_fn_kwargs": {
+            "k": old_cfg_dict["k"],
+            "signed": old_cfg_dict.get("signed", False),
+        },
+        "apply_b_dec_to_input": not old_cfg_dict.get("normalize_decoder", False),
+        "dataset_path": config_dict.get(
+            "dataset", "togethercomputer/RedPajama-Data-1T-Sample"
+        ),
+        "context_size": config_dict.get("ctx_len", 2048),
+        "finetuning_scaling_factor": False,
+        "sae_lens_training_version": None,
+        "prepend_bos": True,
+        "dataset_trust_remote_code": True,
+        "normalize_activations": "none",
+        "neuronpedia_id": None,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def sparsify_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], None]:
+    weights_filename = f"{folder_name}/{SPARSIFY_WEIGHTS_FILENAME}"
+    sae_path = hf_hub_download(
+        repo_id,
+        filename=weights_filename,
+        force_download=force_download,
+    )
+    cfg_dict, state_dict = sparsify_disk_loader(
+        Path(sae_path).parent, device=device, cfg_overrides=cfg_overrides
+    )
+    return cfg_dict, state_dict, None
+
+
+def sparsify_disk_loader(
+    path: str | Path,
+    device: str = "cpu",
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
+    cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)
+
+    weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
+    state_dict_loaded = load_file(weight_path, device=device)
+
+    dtype = DTYPE_MAP[cfg_dict["dtype"]]
+
+    W_enc = (
+        state_dict_loaded["W_enc"]
+        if "W_enc" in state_dict_loaded
+        else state_dict_loaded["encoder.weight"].T
+    ).to(dtype)
+
+    if "W_dec" in state_dict_loaded:
+        W_dec = state_dict_loaded["W_dec"].T.to(dtype)
+    else:
+        W_dec = state_dict_loaded["decoder.weight"].T.to(dtype)
+
+    if "b_enc" in state_dict_loaded:
+        b_enc = state_dict_loaded["b_enc"].to(dtype)
+    elif "encoder.bias" in state_dict_loaded:
+        b_enc = state_dict_loaded["encoder.bias"].to(dtype)
+    else:
+        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+
+    if "b_dec" in state_dict_loaded:
+        b_dec = state_dict_loaded["b_dec"].to(dtype)
+    elif "decoder.bias" in state_dict_loaded:
+        b_dec = state_dict_loaded["decoder.bias"].to(dtype)
+    else:
+        b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+
+    state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
+    return cfg_dict, state_dict
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
```
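The config getter infers the hook point from the folder name, so the new loaders expect a layout like the one sketched below. This is a minimal sketch under assumed directory names; only the file names and the hook-name inference come from the code above:

```python
# Hypothetical local checkpoint layout readable by the new sparsify loaders:
#
#   my-sparsify-run/
#     config.json           <- optional run config ("model", "dataset", "ctx_len")
#     layers.12/
#       cfg.json            <- per-SAE config ("d_in", "expansion_factor", "k", ...)
#       sae.safetensors     <- weights (SPARSIFY_WEIGHTS_FILENAME)
from sae_lens.loading.pretrained_sae_loaders import sparsify_disk_loader

cfg_dict, state_dict = sparsify_disk_loader("my-sparsify-run/layers.12", device="cpu")
print(cfg_dict["hook_name"])  # inferred from the folder name: "blocks.12.hook_resid_post"
```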
sae_lens/loading/pretrained_sae_loaders.py:

```diff
@@ -950,6 +1100,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
     "deepseek_r1": deepseek_r1_sae_huggingface_loader,
+    "sparsify": sparsify_huggingface_loader,
 }
 
 
@@ -961,4 +1112,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoader] = {
     "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
     "deepseek_r1": get_deepseek_r1_config_from_hf,
+    "sparsify": get_sparsify_config_from_hf,
 }
```
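With both entries registered under the "sparsify" key, a Hugging Face repo in this format can also be loaded directly. A minimal sketch; the repo id and folder name are hypothetical placeholders, not SAEs known to exist:

```python
from sae_lens.loading.pretrained_sae_loaders import sparsify_huggingface_loader

# Hypothetical repo id and folder; any repo following the sparsify layout should work.
cfg_dict, state_dict, log_sparsity = sparsify_huggingface_loader(
    repo_id="some-org/some-sparsify-saes",
    folder_name="layers.23",
    device="cpu",
)
assert log_sparsity is None  # sparsify checkpoints ship no sparsity tensor
print(cfg_dict["d_sae"], state_dict["W_enc"].shape)
```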
sae_lens/pretrained_saes.yaml:

```diff
@@ -13634,39 +13634,51 @@ gemma-2-2b-res-matryoshka-dc:
   - id: blocks.13.hook_resid_post
     path: standard/blocks.13.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/13-res-matryoshka-dc
   - id: blocks.14.hook_resid_post
     path: standard/blocks.14.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/14-res-matryoshka-dc
   - id: blocks.15.hook_resid_post
     path: standard/blocks.15.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/15-res-matryoshka-dc
   - id: blocks.16.hook_resid_post
     path: standard/blocks.16.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/16-res-matryoshka-dc
   - id: blocks.17.hook_resid_post
     path: standard/blocks.17.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/17-res-matryoshka-dc
   - id: blocks.18.hook_resid_post
     path: standard/blocks.18.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/18-res-matryoshka-dc
   - id: blocks.19.hook_resid_post
     path: standard/blocks.19.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/19-res-matryoshka-dc
   - id: blocks.20.hook_resid_post
     path: standard/blocks.20.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/20-res-matryoshka-dc
   - id: blocks.21.hook_resid_post
     path: standard/blocks.21.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/21-res-matryoshka-dc
   - id: blocks.22.hook_resid_post
     path: standard/blocks.22.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/22-res-matryoshka-dc
   - id: blocks.23.hook_resid_post
     path: standard/blocks.23.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/23-res-matryoshka-dc
   - id: blocks.24.hook_resid_post
     path: standard/blocks.24.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/24-res-matryoshka-dc
 gemma-2-2b-res-snap-matryoshka-dc:
   conversion_func: null
   links:
```
sae_lens/saes/sae.py:

```diff
@@ -732,6 +732,64 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     ) -> type[SAEConfig]:
         return SAEConfig
 
+    ### Methods to support deprecated usage of SAE.from_pretrained() ###
+
+    def __getitem__(self, index: int) -> Any:
+        """
+        Support indexing for backward compatibility with tuple unpacking.
+        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
+        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
+        """
+        warnings.warn(
+            "Indexing SAE objects is deprecated. SAE.from_pretrained() now returns "
+            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
+            "to get the config dict and sparsity as well.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+        if index == 0:
+            return self
+        if index == 1:
+            return self.cfg.to_dict()
+        if index == 2:
+            return None
+        raise IndexError(f"SAE tuple index {index} out of range")
+
+    def __iter__(self):
+        """
+        Support unpacking for backward compatibility with tuple unpacking.
+        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
+        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
+        """
+        warnings.warn(
+            "Unpacking SAE objects is deprecated. SAE.from_pretrained() now returns "
+            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
+            "to get the config dict and sparsity as well.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+        yield self
+        yield self.cfg.to_dict()
+        yield None
+
+    def __len__(self) -> int:
+        """
+        Support len() for backward compatibility with tuple unpacking.
+        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
+        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
+        """
+        warnings.warn(
+            "Getting length of SAE objects is deprecated. SAE.from_pretrained() now returns "
+            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
+            "to get the config dict and sparsity as well.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+        return 3
+
 
 @dataclass(kw_only=True)
 class TrainingSAEConfig(SAEConfig, ABC):
```
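The shim keeps old call sites working while steering users to the new API. A short sketch of both styles, assuming from_pretrained_with_cfg_and_sparsity() mirrors from_pretrained()'s signature; the release and SAE id below are placeholders to substitute with your own:

```python
from sae_lens import SAE

# New style: from_pretrained() returns only the SAE object.
sae = SAE.from_pretrained("some-release", "some-sae-id")

# Old tuple-unpacking style still works via __iter__/__getitem__/__len__ above,
# but now emits a DeprecationWarning; the sparsity slot is always None here.
sae, cfg_dict, sparsity = SAE.from_pretrained("some-release", "some-sae-id")

# Preferred replacement when the config dict and sparsity are actually needed.
sae, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(
    "some-release", "some-sae-id"
)
```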
sae_lens/training/activations_store.py:

```diff
@@ -428,7 +428,7 @@ class ActivationsStore:
         ):
             # temporalily set estimated_norm_scaling_factor to 1.0 so the dataloader works
             self.estimated_norm_scaling_factor = 1.0
-            acts = self.next_batch()[0]
+            acts = self.next_batch()[:, 0]
             self.estimated_norm_scaling_factor = None
             norms_per_batch.append(acts.norm(dim=-1).mean().item())
         mean_norm = np.mean(norms_per_batch)
```
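The norm-estimation fix changes which slice of the batch gets averaged. A small sketch of the difference, assuming next_batch() yields activations shaped [batch, n_hook_points, d_in]:

```python
import torch

# Hypothetical activation batch: 8 samples, 1 hook point, d_in = 4.
acts_batch = torch.randn(8, 1, 4)

old_slice = acts_batch[0]     # shape [1, 4]: only the first sample's activations
new_slice = acts_batch[:, 0]  # shape [8, 4]: every sample at the first hook point

# The norm scaling factor is now estimated from the whole batch rather than a single row.
print(old_slice.norm(dim=-1).mean().item(), new_slice.norm(dim=-1).mean().item())
```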