sae-lens 6.25.0__tar.gz → 6.26.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.25.0 → sae_lens-6.26.1}/PKG-INFO +1 -1
- {sae_lens-6.25.0 → sae_lens-6.26.1}/pyproject.toml +2 -1
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/__init__.py +13 -1
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/cache_activations_runner.py +2 -2
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/config.py +7 -2
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/constants.py +8 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/loading/pretrained_sae_loaders.py +66 -57
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/pretrained_saes.yaml +144 -144
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/saes/__init__.py +10 -0
- sae_lens-6.26.1/sae_lens/saes/matching_pursuit_sae.py +334 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/saes/sae.py +52 -12
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/training/activations_store.py +3 -2
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/util.py +21 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/LICENSE +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/README.md +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/evals.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/load_model.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/registry.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/saes/temporal_sae.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/training/types.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.1}/sae_lens/tutorial/tsea.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "sae-lens"
|
|
3
|
-
version = "6.
|
|
3
|
+
version = "6.26.1"
|
|
4
4
|
description = "Training and Analyzing Sparse Autoencoders (SAEs)"
|
|
5
5
|
authors = ["Joseph Bloom"]
|
|
6
6
|
readme = "README.md"
|
|
@@ -51,6 +51,7 @@ mkdocs-redirects = "^1.2.1"
|
|
|
51
51
|
mkdocs-section-index = "^0.3.9"
|
|
52
52
|
mkdocstrings = "^0.25.2"
|
|
53
53
|
mkdocstrings-python = "^1.10.9"
|
|
54
|
+
beautifulsoup4 = "^4.12.0"
|
|
54
55
|
tabulate = "^0.9.0"
|
|
55
56
|
ruff = "^0.7.4"
|
|
56
57
|
eai-sparsify = "^1.1.1"
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
# ruff: noqa: E402
|
|
2
|
-
__version__ = "6.
|
|
2
|
+
__version__ = "6.26.1"
|
|
3
3
|
|
|
4
4
|
import logging
|
|
5
5
|
|
|
@@ -21,6 +21,10 @@ from sae_lens.saes import (
|
|
|
21
21
|
JumpReLUTrainingSAEConfig,
|
|
22
22
|
JumpReLUTranscoder,
|
|
23
23
|
JumpReLUTranscoderConfig,
|
|
24
|
+
MatchingPursuitSAE,
|
|
25
|
+
MatchingPursuitSAEConfig,
|
|
26
|
+
MatchingPursuitTrainingSAE,
|
|
27
|
+
MatchingPursuitTrainingSAEConfig,
|
|
24
28
|
MatryoshkaBatchTopKTrainingSAE,
|
|
25
29
|
MatryoshkaBatchTopKTrainingSAEConfig,
|
|
26
30
|
SAEConfig,
|
|
@@ -113,6 +117,10 @@ __all__ = [
|
|
|
113
117
|
"MatryoshkaBatchTopKTrainingSAEConfig",
|
|
114
118
|
"TemporalSAE",
|
|
115
119
|
"TemporalSAEConfig",
|
|
120
|
+
"MatchingPursuitSAE",
|
|
121
|
+
"MatchingPursuitTrainingSAE",
|
|
122
|
+
"MatchingPursuitSAEConfig",
|
|
123
|
+
"MatchingPursuitTrainingSAEConfig",
|
|
116
124
|
]
|
|
117
125
|
|
|
118
126
|
|
|
@@ -139,3 +147,7 @@ register_sae_class(
|
|
|
139
147
|
"jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
|
|
140
148
|
)
|
|
141
149
|
register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
|
|
150
|
+
register_sae_class("matching_pursuit", MatchingPursuitSAE, MatchingPursuitSAEConfig)
|
|
151
|
+
register_sae_training_class(
|
|
152
|
+
"matching_pursuit", MatchingPursuitTrainingSAE, MatchingPursuitTrainingSAEConfig
|
|
153
|
+
)
|
|
@@ -14,9 +14,9 @@ from transformer_lens.HookedTransformer import HookedRootModule
|
|
|
14
14
|
|
|
15
15
|
from sae_lens import logger
|
|
16
16
|
from sae_lens.config import CacheActivationsRunnerConfig
|
|
17
|
-
from sae_lens.constants import DTYPE_MAP
|
|
18
17
|
from sae_lens.load_model import load_model
|
|
19
18
|
from sae_lens.training.activations_store import ActivationsStore
|
|
19
|
+
from sae_lens.util import str_to_dtype
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
def _mk_activations_store(
|
|
@@ -97,7 +97,7 @@ class CacheActivationsRunner:
|
|
|
97
97
|
bytes_per_token = (
|
|
98
98
|
self.cfg.d_in * self.cfg.dtype.itemsize
|
|
99
99
|
if isinstance(self.cfg.dtype, torch.dtype)
|
|
100
|
-
else
|
|
100
|
+
else str_to_dtype(self.cfg.dtype).itemsize
|
|
101
101
|
)
|
|
102
102
|
total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
|
|
103
103
|
total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9
|
|
@@ -17,9 +17,14 @@ from datasets import (
|
|
|
17
17
|
)
|
|
18
18
|
|
|
19
19
|
from sae_lens import __version__, logger
|
|
20
|
-
|
|
20
|
+
|
|
21
|
+
# keeping this unused import since some SAELens deps import DTYPE_MAP from config
|
|
22
|
+
from sae_lens.constants import (
|
|
23
|
+
DTYPE_MAP, # noqa: F401 # pyright: ignore[reportUnusedImport]
|
|
24
|
+
)
|
|
21
25
|
from sae_lens.registry import get_sae_training_class
|
|
22
26
|
from sae_lens.saes.sae import TrainingSAEConfig
|
|
27
|
+
from sae_lens.util import str_to_dtype
|
|
23
28
|
|
|
24
29
|
if TYPE_CHECKING:
|
|
25
30
|
pass
|
|
@@ -563,7 +568,7 @@ class CacheActivationsRunnerConfig:
|
|
|
563
568
|
|
|
564
569
|
@property
|
|
565
570
|
def bytes_per_token(self) -> int:
|
|
566
|
-
return self.d_in *
|
|
571
|
+
return self.d_in * str_to_dtype(self.dtype).itemsize
|
|
567
572
|
|
|
568
573
|
@property
|
|
569
574
|
def n_tokens_in_buffer(self) -> int:
|
|
@@ -11,6 +11,14 @@ DTYPE_MAP = {
|
|
|
11
11
|
"torch.bfloat16": torch.bfloat16,
|
|
12
12
|
}
|
|
13
13
|
|
|
14
|
+
# Reverse mapping from torch.dtype to canonical string format
|
|
15
|
+
DTYPE_TO_STR = {
|
|
16
|
+
torch.float32: "float32",
|
|
17
|
+
torch.float64: "float64",
|
|
18
|
+
torch.float16: "float16",
|
|
19
|
+
torch.bfloat16: "bfloat16",
|
|
20
|
+
}
|
|
21
|
+
|
|
14
22
|
|
|
15
23
|
SPARSITY_FILENAME = "sparsity.safetensors"
|
|
16
24
|
SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
|
|
@@ -12,11 +12,9 @@ from huggingface_hub import hf_hub_download, hf_hub_url
|
|
|
12
12
|
from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
|
|
13
13
|
from packaging.version import Version
|
|
14
14
|
from safetensors import safe_open
|
|
15
|
-
from safetensors.torch import load_file
|
|
16
15
|
|
|
17
16
|
from sae_lens import logger
|
|
18
17
|
from sae_lens.constants import (
|
|
19
|
-
DTYPE_MAP,
|
|
20
18
|
SAE_CFG_FILENAME,
|
|
21
19
|
SAE_WEIGHTS_FILENAME,
|
|
22
20
|
SPARSIFY_WEIGHTS_FILENAME,
|
|
@@ -28,7 +26,7 @@ from sae_lens.loading.pretrained_saes_directory import (
|
|
|
28
26
|
get_repo_id_and_folder_name,
|
|
29
27
|
)
|
|
30
28
|
from sae_lens.registry import get_sae_class
|
|
31
|
-
from sae_lens.util import filter_valid_dataclass_fields
|
|
29
|
+
from sae_lens.util import filter_valid_dataclass_fields, str_to_dtype
|
|
32
30
|
|
|
33
31
|
LLM_METADATA_KEYS = {
|
|
34
32
|
"model_name",
|
|
@@ -51,6 +49,21 @@ LLM_METADATA_KEYS = {
|
|
|
51
49
|
}
|
|
52
50
|
|
|
53
51
|
|
|
52
|
+
def load_safetensors_weights(
|
|
53
|
+
path: str | Path, device: str = "cpu", dtype: torch.dtype | str | None = None
|
|
54
|
+
) -> dict[str, torch.Tensor]:
|
|
55
|
+
"""Load safetensors weights and optionally convert to a different dtype"""
|
|
56
|
+
loaded_weights = {}
|
|
57
|
+
dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
|
|
58
|
+
with safe_open(path, framework="pt", device=device) as f:
|
|
59
|
+
for k in f.keys(): # noqa: SIM118
|
|
60
|
+
weight = f.get_tensor(k)
|
|
61
|
+
if dtype is not None:
|
|
62
|
+
weight = weight.to(dtype=dtype)
|
|
63
|
+
loaded_weights[k] = weight
|
|
64
|
+
return loaded_weights
|
|
65
|
+
|
|
66
|
+
|
|
54
67
|
# loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
|
|
55
68
|
class PretrainedSaeHuggingfaceLoader(Protocol):
|
|
56
69
|
def __call__(
|
|
@@ -341,7 +354,7 @@ def read_sae_components_from_disk(
|
|
|
341
354
|
Given a loaded dictionary and a path to a weight file, load the weights and return the state_dict.
|
|
342
355
|
"""
|
|
343
356
|
if dtype is None:
|
|
344
|
-
dtype =
|
|
357
|
+
dtype = str_to_dtype(cfg_dict["dtype"])
|
|
345
358
|
|
|
346
359
|
state_dict = {}
|
|
347
360
|
with safe_open(weight_path, framework="pt", device=device) as f: # type: ignore
|
|
@@ -695,7 +708,9 @@ def gemma_3_sae_huggingface_loader(
|
|
|
695
708
|
force_download=force_download,
|
|
696
709
|
)
|
|
697
710
|
|
|
698
|
-
raw_state_dict =
|
|
711
|
+
raw_state_dict = load_safetensors_weights(
|
|
712
|
+
sae_path, device=device, dtype=cfg_dict.get("dtype")
|
|
713
|
+
)
|
|
699
714
|
|
|
700
715
|
with torch.no_grad():
|
|
701
716
|
w_dec = raw_state_dict["w_dec"]
|
|
@@ -782,11 +797,13 @@ def get_goodfire_huggingface_loader(
|
|
|
782
797
|
)
|
|
783
798
|
raw_state_dict = torch.load(sae_path, map_location=device)
|
|
784
799
|
|
|
800
|
+
target_dtype = str_to_dtype(cfg_dict.get("dtype", "float32"))
|
|
801
|
+
|
|
785
802
|
state_dict = {
|
|
786
|
-
"W_enc": raw_state_dict["encoder_linear.weight"].T,
|
|
787
|
-
"W_dec": raw_state_dict["decoder_linear.weight"].T,
|
|
788
|
-
"b_enc": raw_state_dict["encoder_linear.bias"],
|
|
789
|
-
"b_dec": raw_state_dict["decoder_linear.bias"],
|
|
803
|
+
"W_enc": raw_state_dict["encoder_linear.weight"].T.to(dtype=target_dtype),
|
|
804
|
+
"W_dec": raw_state_dict["decoder_linear.weight"].T.to(dtype=target_dtype),
|
|
805
|
+
"b_enc": raw_state_dict["encoder_linear.bias"].to(dtype=target_dtype),
|
|
806
|
+
"b_dec": raw_state_dict["decoder_linear.bias"].to(dtype=target_dtype),
|
|
790
807
|
}
|
|
791
808
|
|
|
792
809
|
return cfg_dict, state_dict, None
|
|
@@ -889,26 +906,19 @@ def llama_scope_sae_huggingface_loader(
|
|
|
889
906
|
force_download=force_download,
|
|
890
907
|
)
|
|
891
908
|
|
|
892
|
-
|
|
893
|
-
|
|
909
|
+
state_dict_loaded = load_safetensors_weights(
|
|
910
|
+
sae_path, device=device, dtype=cfg_dict.get("dtype")
|
|
911
|
+
)
|
|
894
912
|
|
|
895
913
|
# Convert and organize the weights
|
|
896
914
|
state_dict = {
|
|
897
|
-
"W_enc": state_dict_loaded["encoder.weight"]
|
|
898
|
-
|
|
899
|
-
.
|
|
900
|
-
"
|
|
901
|
-
.to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
|
|
902
|
-
.T,
|
|
903
|
-
"b_enc": state_dict_loaded["encoder.bias"].to(
|
|
904
|
-
dtype=DTYPE_MAP[cfg_dict["dtype"]]
|
|
905
|
-
),
|
|
906
|
-
"b_dec": state_dict_loaded["decoder.bias"].to(
|
|
907
|
-
dtype=DTYPE_MAP[cfg_dict["dtype"]]
|
|
908
|
-
),
|
|
915
|
+
"W_enc": state_dict_loaded["encoder.weight"].T,
|
|
916
|
+
"W_dec": state_dict_loaded["decoder.weight"].T,
|
|
917
|
+
"b_enc": state_dict_loaded["encoder.bias"],
|
|
918
|
+
"b_dec": state_dict_loaded["decoder.bias"],
|
|
909
919
|
"threshold": torch.ones(
|
|
910
920
|
cfg_dict["d_sae"],
|
|
911
|
-
dtype=
|
|
921
|
+
dtype=str_to_dtype(cfg_dict["dtype"]),
|
|
912
922
|
device=cfg_dict["device"],
|
|
913
923
|
)
|
|
914
924
|
* cfg_dict["jump_relu_threshold"],
|
|
@@ -1219,26 +1229,17 @@ def llama_scope_r1_distill_sae_huggingface_loader(
|
|
|
1219
1229
|
force_download=force_download,
|
|
1220
1230
|
)
|
|
1221
1231
|
|
|
1222
|
-
|
|
1223
|
-
|
|
1232
|
+
state_dict_loaded = load_safetensors_weights(
|
|
1233
|
+
sae_path, device=device, dtype=cfg_dict.get("dtype")
|
|
1234
|
+
)
|
|
1224
1235
|
|
|
1225
1236
|
# Convert and organize the weights
|
|
1226
1237
|
state_dict = {
|
|
1227
|
-
"W_enc": state_dict_loaded["encoder.weight"]
|
|
1228
|
-
|
|
1229
|
-
.
|
|
1230
|
-
"
|
|
1231
|
-
|
|
1232
|
-
.T,
|
|
1233
|
-
"b_enc": state_dict_loaded["encoder.bias"].to(
|
|
1234
|
-
dtype=DTYPE_MAP[cfg_dict["dtype"]]
|
|
1235
|
-
),
|
|
1236
|
-
"b_dec": state_dict_loaded["decoder.bias"].to(
|
|
1237
|
-
dtype=DTYPE_MAP[cfg_dict["dtype"]]
|
|
1238
|
-
),
|
|
1239
|
-
"threshold": state_dict_loaded["log_jumprelu_threshold"]
|
|
1240
|
-
.to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
|
|
1241
|
-
.exp(),
|
|
1238
|
+
"W_enc": state_dict_loaded["encoder.weight"].T,
|
|
1239
|
+
"W_dec": state_dict_loaded["decoder.weight"].T,
|
|
1240
|
+
"b_enc": state_dict_loaded["encoder.bias"],
|
|
1241
|
+
"b_dec": state_dict_loaded["decoder.bias"],
|
|
1242
|
+
"threshold": state_dict_loaded["log_jumprelu_threshold"].exp(),
|
|
1242
1243
|
}
|
|
1243
1244
|
|
|
1244
1245
|
# No sparsity tensor for Llama Scope SAEs
|
|
@@ -1358,34 +1359,34 @@ def sparsify_disk_loader(
|
|
|
1358
1359
|
cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)
|
|
1359
1360
|
|
|
1360
1361
|
weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
|
|
1361
|
-
state_dict_loaded =
|
|
1362
|
-
|
|
1363
|
-
|
|
1362
|
+
state_dict_loaded = load_safetensors_weights(
|
|
1363
|
+
weight_path, device=device, dtype=cfg_dict.get("dtype")
|
|
1364
|
+
)
|
|
1364
1365
|
|
|
1365
1366
|
W_enc = (
|
|
1366
1367
|
state_dict_loaded["W_enc"]
|
|
1367
1368
|
if "W_enc" in state_dict_loaded
|
|
1368
1369
|
else state_dict_loaded["encoder.weight"].T
|
|
1369
|
-
)
|
|
1370
|
+
)
|
|
1370
1371
|
|
|
1371
1372
|
if "W_dec" in state_dict_loaded:
|
|
1372
|
-
W_dec = state_dict_loaded["W_dec"].T
|
|
1373
|
+
W_dec = state_dict_loaded["W_dec"].T
|
|
1373
1374
|
else:
|
|
1374
|
-
W_dec = state_dict_loaded["decoder.weight"].T
|
|
1375
|
+
W_dec = state_dict_loaded["decoder.weight"].T
|
|
1375
1376
|
|
|
1376
1377
|
if "b_enc" in state_dict_loaded:
|
|
1377
|
-
b_enc = state_dict_loaded["b_enc"]
|
|
1378
|
+
b_enc = state_dict_loaded["b_enc"]
|
|
1378
1379
|
elif "encoder.bias" in state_dict_loaded:
|
|
1379
|
-
b_enc = state_dict_loaded["encoder.bias"]
|
|
1380
|
+
b_enc = state_dict_loaded["encoder.bias"]
|
|
1380
1381
|
else:
|
|
1381
|
-
b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
|
|
1382
|
+
b_enc = torch.zeros(cfg_dict["d_sae"], dtype=W_dec.dtype, device=device)
|
|
1382
1383
|
|
|
1383
1384
|
if "b_dec" in state_dict_loaded:
|
|
1384
|
-
b_dec = state_dict_loaded["b_dec"]
|
|
1385
|
+
b_dec = state_dict_loaded["b_dec"]
|
|
1385
1386
|
elif "decoder.bias" in state_dict_loaded:
|
|
1386
|
-
b_dec = state_dict_loaded["decoder.bias"]
|
|
1387
|
+
b_dec = state_dict_loaded["decoder.bias"]
|
|
1387
1388
|
else:
|
|
1388
|
-
b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
|
|
1389
|
+
b_dec = torch.zeros(cfg_dict["d_in"], dtype=W_dec.dtype, device=device)
|
|
1389
1390
|
|
|
1390
1391
|
state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
|
|
1391
1392
|
return cfg_dict, state_dict
|
|
@@ -1616,7 +1617,9 @@ def mwhanna_transcoder_huggingface_loader(
|
|
|
1616
1617
|
)
|
|
1617
1618
|
|
|
1618
1619
|
# Load weights from safetensors
|
|
1619
|
-
state_dict =
|
|
1620
|
+
state_dict = load_safetensors_weights(
|
|
1621
|
+
file_path, device=device, dtype=cfg_dict.get("dtype")
|
|
1622
|
+
)
|
|
1620
1623
|
state_dict["W_enc"] = state_dict["W_enc"].T
|
|
1621
1624
|
|
|
1622
1625
|
return cfg_dict, state_dict, None
|
|
@@ -1700,8 +1703,12 @@ def mntss_clt_layer_huggingface_loader(
|
|
|
1700
1703
|
force_download=force_download,
|
|
1701
1704
|
)
|
|
1702
1705
|
|
|
1703
|
-
encoder_state_dict =
|
|
1704
|
-
|
|
1706
|
+
encoder_state_dict = load_safetensors_weights(
|
|
1707
|
+
encoder_path, device=device, dtype=cfg_dict.get("dtype")
|
|
1708
|
+
)
|
|
1709
|
+
decoder_state_dict = load_safetensors_weights(
|
|
1710
|
+
decoder_path, device=device, dtype=cfg_dict.get("dtype")
|
|
1711
|
+
)
|
|
1705
1712
|
|
|
1706
1713
|
with torch.no_grad():
|
|
1707
1714
|
state_dict = {
|
|
@@ -1844,7 +1851,9 @@ def temporal_sae_huggingface_loader(
|
|
|
1844
1851
|
)
|
|
1845
1852
|
|
|
1846
1853
|
# Load checkpoint from safetensors
|
|
1847
|
-
state_dict_raw =
|
|
1854
|
+
state_dict_raw = load_safetensors_weights(
|
|
1855
|
+
ckpt_path, device=device, dtype=cfg_dict.get("dtype")
|
|
1856
|
+
)
|
|
1848
1857
|
|
|
1849
1858
|
# Convert to SAELens naming convention
|
|
1850
1859
|
# TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
|
|
@@ -9072,150 +9072,150 @@ gemma-scope-2-27b-it-transcoders-all:
|
|
|
9072
9072
|
- id: layer_5_width_262k_l0_small_affine
|
|
9073
9073
|
path: transcoder_all/layer_5_width_262k_l0_small_affine
|
|
9074
9074
|
l0: 12
|
|
9075
|
-
|
|
9076
|
-
|
|
9077
|
-
|
|
9078
|
-
|
|
9079
|
-
|
|
9080
|
-
|
|
9081
|
-
|
|
9082
|
-
|
|
9083
|
-
|
|
9084
|
-
|
|
9085
|
-
|
|
9086
|
-
|
|
9087
|
-
|
|
9088
|
-
|
|
9089
|
-
|
|
9090
|
-
|
|
9091
|
-
|
|
9092
|
-
|
|
9093
|
-
|
|
9094
|
-
|
|
9095
|
-
|
|
9096
|
-
|
|
9097
|
-
|
|
9098
|
-
|
|
9099
|
-
|
|
9100
|
-
|
|
9101
|
-
|
|
9102
|
-
|
|
9103
|
-
|
|
9104
|
-
|
|
9105
|
-
|
|
9106
|
-
|
|
9107
|
-
|
|
9108
|
-
|
|
9109
|
-
|
|
9110
|
-
|
|
9111
|
-
|
|
9112
|
-
|
|
9113
|
-
|
|
9114
|
-
|
|
9115
|
-
|
|
9116
|
-
|
|
9117
|
-
|
|
9118
|
-
|
|
9119
|
-
|
|
9120
|
-
|
|
9121
|
-
|
|
9122
|
-
|
|
9123
|
-
|
|
9124
|
-
|
|
9125
|
-
|
|
9126
|
-
|
|
9127
|
-
|
|
9128
|
-
|
|
9129
|
-
|
|
9130
|
-
|
|
9131
|
-
|
|
9132
|
-
|
|
9133
|
-
|
|
9134
|
-
|
|
9135
|
-
|
|
9136
|
-
|
|
9137
|
-
|
|
9138
|
-
|
|
9139
|
-
|
|
9140
|
-
|
|
9141
|
-
|
|
9142
|
-
|
|
9143
|
-
|
|
9144
|
-
|
|
9145
|
-
|
|
9146
|
-
|
|
9147
|
-
|
|
9148
|
-
|
|
9149
|
-
|
|
9150
|
-
|
|
9151
|
-
|
|
9152
|
-
|
|
9153
|
-
|
|
9154
|
-
|
|
9155
|
-
|
|
9156
|
-
|
|
9157
|
-
|
|
9158
|
-
|
|
9159
|
-
|
|
9160
|
-
|
|
9161
|
-
|
|
9162
|
-
|
|
9163
|
-
|
|
9164
|
-
|
|
9165
|
-
|
|
9166
|
-
|
|
9167
|
-
|
|
9168
|
-
|
|
9169
|
-
|
|
9170
|
-
|
|
9171
|
-
|
|
9172
|
-
|
|
9173
|
-
|
|
9174
|
-
|
|
9175
|
-
|
|
9176
|
-
|
|
9177
|
-
|
|
9178
|
-
|
|
9179
|
-
|
|
9180
|
-
|
|
9181
|
-
|
|
9182
|
-
|
|
9183
|
-
|
|
9184
|
-
|
|
9185
|
-
|
|
9186
|
-
|
|
9187
|
-
|
|
9188
|
-
|
|
9189
|
-
|
|
9190
|
-
|
|
9191
|
-
|
|
9192
|
-
|
|
9193
|
-
|
|
9194
|
-
|
|
9195
|
-
|
|
9196
|
-
|
|
9197
|
-
|
|
9198
|
-
|
|
9199
|
-
|
|
9200
|
-
|
|
9201
|
-
|
|
9202
|
-
|
|
9203
|
-
|
|
9204
|
-
|
|
9205
|
-
|
|
9206
|
-
|
|
9207
|
-
|
|
9208
|
-
|
|
9209
|
-
|
|
9210
|
-
|
|
9211
|
-
|
|
9212
|
-
|
|
9213
|
-
|
|
9214
|
-
|
|
9215
|
-
|
|
9216
|
-
|
|
9217
|
-
|
|
9218
|
-
|
|
9075
|
+
- id: layer_60_width_16k_l0_big
|
|
9076
|
+
path: transcoder_all/layer_60_width_16k_l0_big
|
|
9077
|
+
l0: 120
|
|
9078
|
+
- id: layer_60_width_16k_l0_big_affine
|
|
9079
|
+
path: transcoder_all/layer_60_width_16k_l0_big_affine
|
|
9080
|
+
l0: 120
|
|
9081
|
+
- id: layer_60_width_16k_l0_small
|
|
9082
|
+
path: transcoder_all/layer_60_width_16k_l0_small
|
|
9083
|
+
l0: 20
|
|
9084
|
+
- id: layer_60_width_16k_l0_small_affine
|
|
9085
|
+
path: transcoder_all/layer_60_width_16k_l0_small_affine
|
|
9086
|
+
l0: 20
|
|
9087
|
+
- id: layer_60_width_262k_l0_big
|
|
9088
|
+
path: transcoder_all/layer_60_width_262k_l0_big
|
|
9089
|
+
l0: 120
|
|
9090
|
+
- id: layer_60_width_262k_l0_big_affine
|
|
9091
|
+
path: transcoder_all/layer_60_width_262k_l0_big_affine
|
|
9092
|
+
l0: 120
|
|
9093
|
+
- id: layer_60_width_262k_l0_small
|
|
9094
|
+
path: transcoder_all/layer_60_width_262k_l0_small
|
|
9095
|
+
l0: 20
|
|
9096
|
+
- id: layer_60_width_262k_l0_small_affine
|
|
9097
|
+
path: transcoder_all/layer_60_width_262k_l0_small_affine
|
|
9098
|
+
l0: 20
|
|
9099
|
+
- id: layer_61_width_16k_l0_big
|
|
9100
|
+
path: transcoder_all/layer_61_width_16k_l0_big
|
|
9101
|
+
l0: 120
|
|
9102
|
+
- id: layer_61_width_16k_l0_big_affine
|
|
9103
|
+
path: transcoder_all/layer_61_width_16k_l0_big_affine
|
|
9104
|
+
l0: 120
|
|
9105
|
+
- id: layer_61_width_16k_l0_small
|
|
9106
|
+
path: transcoder_all/layer_61_width_16k_l0_small
|
|
9107
|
+
l0: 20
|
|
9108
|
+
- id: layer_61_width_16k_l0_small_affine
|
|
9109
|
+
path: transcoder_all/layer_61_width_16k_l0_small_affine
|
|
9110
|
+
l0: 20
|
|
9111
|
+
- id: layer_61_width_262k_l0_big
|
|
9112
|
+
path: transcoder_all/layer_61_width_262k_l0_big
|
|
9113
|
+
l0: 120
|
|
9114
|
+
- id: layer_61_width_262k_l0_big_affine
|
|
9115
|
+
path: transcoder_all/layer_61_width_262k_l0_big_affine
|
|
9116
|
+
l0: 120
|
|
9117
|
+
- id: layer_61_width_262k_l0_small
|
|
9118
|
+
path: transcoder_all/layer_61_width_262k_l0_small
|
|
9119
|
+
l0: 20
|
|
9120
|
+
- id: layer_61_width_262k_l0_small_affine
|
|
9121
|
+
path: transcoder_all/layer_61_width_262k_l0_small_affine
|
|
9122
|
+
l0: 20
|
|
9123
|
+
- id: layer_6_width_16k_l0_big
|
|
9124
|
+
path: transcoder_all/layer_6_width_16k_l0_big
|
|
9125
|
+
l0: 77
|
|
9126
|
+
- id: layer_6_width_16k_l0_big_affine
|
|
9127
|
+
path: transcoder_all/layer_6_width_16k_l0_big_affine
|
|
9128
|
+
l0: 77
|
|
9129
|
+
- id: layer_6_width_16k_l0_small
|
|
9130
|
+
path: transcoder_all/layer_6_width_16k_l0_small
|
|
9131
|
+
l0: 12
|
|
9132
|
+
- id: layer_6_width_16k_l0_small_affine
|
|
9133
|
+
path: transcoder_all/layer_6_width_16k_l0_small_affine
|
|
9134
|
+
l0: 12
|
|
9135
|
+
- id: layer_6_width_262k_l0_big
|
|
9136
|
+
path: transcoder_all/layer_6_width_262k_l0_big
|
|
9137
|
+
l0: 77
|
|
9138
|
+
- id: layer_6_width_262k_l0_big_affine
|
|
9139
|
+
path: transcoder_all/layer_6_width_262k_l0_big_affine
|
|
9140
|
+
l0: 77
|
|
9141
|
+
- id: layer_6_width_262k_l0_small
|
|
9142
|
+
path: transcoder_all/layer_6_width_262k_l0_small
|
|
9143
|
+
l0: 12
|
|
9144
|
+
- id: layer_6_width_262k_l0_small_affine
|
|
9145
|
+
path: transcoder_all/layer_6_width_262k_l0_small_affine
|
|
9146
|
+
l0: 12
|
|
9147
|
+
- id: layer_7_width_16k_l0_big
|
|
9148
|
+
path: transcoder_all/layer_7_width_16k_l0_big
|
|
9149
|
+
l0: 80
|
|
9150
|
+
- id: layer_7_width_16k_l0_big_affine
|
|
9151
|
+
path: transcoder_all/layer_7_width_16k_l0_big_affine
|
|
9152
|
+
l0: 80
|
|
9153
|
+
- id: layer_7_width_16k_l0_small
|
|
9154
|
+
path: transcoder_all/layer_7_width_16k_l0_small
|
|
9155
|
+
l0: 13
|
|
9156
|
+
- id: layer_7_width_16k_l0_small_affine
|
|
9157
|
+
path: transcoder_all/layer_7_width_16k_l0_small_affine
|
|
9158
|
+
l0: 13
|
|
9159
|
+
- id: layer_7_width_262k_l0_big
|
|
9160
|
+
path: transcoder_all/layer_7_width_262k_l0_big
|
|
9161
|
+
l0: 80
|
|
9162
|
+
- id: layer_7_width_262k_l0_big_affine
|
|
9163
|
+
path: transcoder_all/layer_7_width_262k_l0_big_affine
|
|
9164
|
+
l0: 80
|
|
9165
|
+
- id: layer_7_width_262k_l0_small
|
|
9166
|
+
path: transcoder_all/layer_7_width_262k_l0_small
|
|
9167
|
+
l0: 13
|
|
9168
|
+
- id: layer_7_width_262k_l0_small_affine
|
|
9169
|
+
path: transcoder_all/layer_7_width_262k_l0_small_affine
|
|
9170
|
+
l0: 13
|
|
9171
|
+
- id: layer_8_width_16k_l0_big
|
|
9172
|
+
path: transcoder_all/layer_8_width_16k_l0_big
|
|
9173
|
+
l0: 83
|
|
9174
|
+
- id: layer_8_width_16k_l0_big_affine
|
|
9175
|
+
path: transcoder_all/layer_8_width_16k_l0_big_affine
|
|
9176
|
+
l0: 83
|
|
9177
|
+
- id: layer_8_width_16k_l0_small
|
|
9178
|
+
path: transcoder_all/layer_8_width_16k_l0_small
|
|
9179
|
+
l0: 13
|
|
9180
|
+
- id: layer_8_width_16k_l0_small_affine
|
|
9181
|
+
path: transcoder_all/layer_8_width_16k_l0_small_affine
|
|
9182
|
+
l0: 13
|
|
9183
|
+
- id: layer_8_width_262k_l0_big
|
|
9184
|
+
path: transcoder_all/layer_8_width_262k_l0_big
|
|
9185
|
+
l0: 83
|
|
9186
|
+
- id: layer_8_width_262k_l0_big_affine
|
|
9187
|
+
path: transcoder_all/layer_8_width_262k_l0_big_affine
|
|
9188
|
+
l0: 83
|
|
9189
|
+
- id: layer_8_width_262k_l0_small
|
|
9190
|
+
path: transcoder_all/layer_8_width_262k_l0_small
|
|
9191
|
+
l0: 13
|
|
9192
|
+
- id: layer_8_width_262k_l0_small_affine
|
|
9193
|
+
path: transcoder_all/layer_8_width_262k_l0_small_affine
|
|
9194
|
+
l0: 13
|
|
9195
|
+
- id: layer_9_width_16k_l0_big
|
|
9196
|
+
path: transcoder_all/layer_9_width_16k_l0_big
|
|
9197
|
+
l0: 86
|
|
9198
|
+
- id: layer_9_width_16k_l0_big_affine
|
|
9199
|
+
path: transcoder_all/layer_9_width_16k_l0_big_affine
|
|
9200
|
+
l0: 86
|
|
9201
|
+
- id: layer_9_width_16k_l0_small
|
|
9202
|
+
path: transcoder_all/layer_9_width_16k_l0_small
|
|
9203
|
+
l0: 14
|
|
9204
|
+
- id: layer_9_width_16k_l0_small_affine
|
|
9205
|
+
path: transcoder_all/layer_9_width_16k_l0_small_affine
|
|
9206
|
+
l0: 14
|
|
9207
|
+
- id: layer_9_width_262k_l0_big
|
|
9208
|
+
path: transcoder_all/layer_9_width_262k_l0_big
|
|
9209
|
+
l0: 86
|
|
9210
|
+
- id: layer_9_width_262k_l0_big_affine
|
|
9211
|
+
path: transcoder_all/layer_9_width_262k_l0_big_affine
|
|
9212
|
+
l0: 86
|
|
9213
|
+
- id: layer_9_width_262k_l0_small
|
|
9214
|
+
path: transcoder_all/layer_9_width_262k_l0_small
|
|
9215
|
+
l0: 14
|
|
9216
|
+
- id: layer_9_width_262k_l0_small_affine
|
|
9217
|
+
path: transcoder_all/layer_9_width_262k_l0_small_affine
|
|
9218
|
+
l0: 14
|
|
9219
9219
|
gemma-scope-2-27b-it-transcoders:
|
|
9220
9220
|
conversion_func: gemma_3
|
|
9221
9221
|
model: google/gemma-3-27b-it
|