sae-lens 6.25.0__tar.gz → 6.26.0__tar.gz
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- {sae_lens-6.25.0 → sae_lens-6.26.0}/PKG-INFO +1 -1
- {sae_lens-6.25.0 → sae_lens-6.26.0}/pyproject.toml +1 -1
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/__init__.py +13 -1
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/cache_activations_runner.py +2 -2
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/config.py +7 -2
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/constants.py +8 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/loading/pretrained_sae_loaders.py +66 -57
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/__init__.py +10 -0
- sae_lens-6.26.0/sae_lens/saes/matching_pursuit_sae.py +334 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/sae.py +52 -12
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/activations_store.py +3 -2
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/util.py +21 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/LICENSE +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/README.md +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/temporal_sae.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/tutorial/tsea.py +0 -0

{sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/__init__.py

@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.25.0"
+__version__ = "6.26.0"
 
 import logging
 
@@ -21,6 +21,10 @@ from sae_lens.saes import (
     JumpReLUTrainingSAEConfig,
     JumpReLUTranscoder,
     JumpReLUTranscoderConfig,
+    MatchingPursuitSAE,
+    MatchingPursuitSAEConfig,
+    MatchingPursuitTrainingSAE,
+    MatchingPursuitTrainingSAEConfig,
     MatryoshkaBatchTopKTrainingSAE,
     MatryoshkaBatchTopKTrainingSAEConfig,
     SAEConfig,
@@ -113,6 +117,10 @@ __all__ = [
     "MatryoshkaBatchTopKTrainingSAEConfig",
     "TemporalSAE",
     "TemporalSAEConfig",
+    "MatchingPursuitSAE",
+    "MatchingPursuitTrainingSAE",
+    "MatchingPursuitSAEConfig",
+    "MatchingPursuitTrainingSAEConfig",
 ]
 
 
@@ -139,3 +147,7 @@ register_sae_class(
     "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
 )
 register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
+register_sae_class("matching_pursuit", MatchingPursuitSAE, MatchingPursuitSAEConfig)
+register_sae_training_class(
+    "matching_pursuit", MatchingPursuitTrainingSAE, MatchingPursuitTrainingSAEConfig
+)

{sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/cache_activations_runner.py

@@ -14,9 +14,9 @@ from transformer_lens.HookedTransformer import HookedRootModule
 
 from sae_lens import logger
 from sae_lens.config import CacheActivationsRunnerConfig
-from sae_lens.constants import DTYPE_MAP
 from sae_lens.load_model import load_model
 from sae_lens.training.activations_store import ActivationsStore
+from sae_lens.util import str_to_dtype
 
 
 def _mk_activations_store(
@@ -97,7 +97,7 @@ class CacheActivationsRunner:
         bytes_per_token = (
             self.cfg.d_in * self.cfg.dtype.itemsize
             if isinstance(self.cfg.dtype, torch.dtype)
-            else DTYPE_MAP[self.cfg.dtype].itemsize
+            else str_to_dtype(self.cfg.dtype).itemsize
         )
         total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
         total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9

{sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/config.py

@@ -17,9 +17,14 @@ from datasets import (
 )
 
 from sae_lens import __version__, logger
-
+
+# keeping this unused import since some SAELens deps import DTYPE_MAP from config
+from sae_lens.constants import (
+    DTYPE_MAP,  # noqa: F401 # pyright: ignore[reportUnusedImport]
+)
 from sae_lens.registry import get_sae_training_class
 from sae_lens.saes.sae import TrainingSAEConfig
+from sae_lens.util import str_to_dtype
 
 if TYPE_CHECKING:
     pass
@@ -563,7 +568,7 @@ class CacheActivationsRunnerConfig:
 
     @property
    def bytes_per_token(self) -> int:
-        return self.d_in * DTYPE_MAP[self.dtype].itemsize
+        return self.d_in * str_to_dtype(self.dtype).itemsize
 
     @property
     def n_tokens_in_buffer(self) -> int:
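
As a quick sanity check of the new bytes_per_token path, the sketch below recomputes the same quantity outside the config class, assuming sae_lens 6.26.0 is installed. The dimensions and token count are made-up illustrative numbers, not values taken from this diff.

from sae_lens.util import str_to_dtype

# hypothetical cache settings, for illustration only
d_in = 768
dtype = "float32"
n_tokens = 1_000_000_000

bytes_per_token = d_in * str_to_dtype(dtype).itemsize  # 768 * 4 = 3072 bytes
total_disk_space_gb = n_tokens * bytes_per_token / 10**9  # ~3072 GB of cached activations
print(bytes_per_token, total_disk_space_gb)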

{sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/constants.py

@@ -11,6 +11,14 @@ DTYPE_MAP = {
     "torch.bfloat16": torch.bfloat16,
 }
 
+# Reverse mapping from torch.dtype to canonical string format
+DTYPE_TO_STR = {
+    torch.float32: "float32",
+    torch.float64: "float64",
+    torch.float16: "float16",
+    torch.bfloat16: "bfloat16",
+}
+
 
 SPARSITY_FILENAME = "sparsity.safetensors"
 SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"

{sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/loading/pretrained_sae_loaders.py

@@ -12,11 +12,9 @@ from huggingface_hub import hf_hub_download, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
 from packaging.version import Version
 from safetensors import safe_open
-from safetensors.torch import load_file
 
 from sae_lens import logger
 from sae_lens.constants import (
-    DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
     SPARSIFY_WEIGHTS_FILENAME,
@@ -28,7 +26,7 @@ from sae_lens.loading.pretrained_saes_directory import (
     get_repo_id_and_folder_name,
 )
 from sae_lens.registry import get_sae_class
-from sae_lens.util import filter_valid_dataclass_fields
+from sae_lens.util import filter_valid_dataclass_fields, str_to_dtype
 
 LLM_METADATA_KEYS = {
     "model_name",
@@ -51,6 +49,21 @@ LLM_METADATA_KEYS = {
 }
 
 
+def load_safetensors_weights(
+    path: str | Path, device: str = "cpu", dtype: torch.dtype | str | None = None
+) -> dict[str, torch.Tensor]:
+    """Load safetensors weights and optionally convert to a different dtype"""
+    loaded_weights = {}
+    dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
+    with safe_open(path, framework="pt", device=device) as f:
+        for k in f.keys():  # noqa: SIM118
+            weight = f.get_tensor(k)
+            if dtype is not None:
+                weight = weight.to(dtype=dtype)
+            loaded_weights[k] = weight
+    return loaded_weights
+
+
 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
 class PretrainedSaeHuggingfaceLoader(Protocol):
     def __call__(
@@ -341,7 +354,7 @@ def read_sae_components_from_disk(
     Given a loaded dictionary and a path to a weight file, load the weights and return the state_dict.
     """
     if dtype is None:
-        dtype = DTYPE_MAP[cfg_dict["dtype"]]
+        dtype = str_to_dtype(cfg_dict["dtype"])
 
     state_dict = {}
     with safe_open(weight_path, framework="pt", device=device) as f:  # type: ignore
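
A minimal usage sketch of the new load_safetensors_weights helper, assuming sae_lens 6.26.0 is installed; the tiny file written here is a throwaway example, not part of the package.

import torch
from safetensors.torch import save_file

from sae_lens.loading.pretrained_sae_loaders import load_safetensors_weights

# write a small safetensors file so there is something to load back (illustration only)
save_file({"W_dec": torch.randn(16, 8), "b_dec": torch.zeros(8)}, "tiny_sae.safetensors")

# load on CPU and down-cast every tensor to bfloat16 in a single pass
weights = load_safetensors_weights("tiny_sae.safetensors", device="cpu", dtype="bfloat16")
print({k: v.dtype for k, v in weights.items()})  # all torch.bfloat16
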
@@ -695,7 +708,9 @@ def gemma_3_sae_huggingface_loader(
         force_download=force_download,
     )
 
-    raw_state_dict =
+    raw_state_dict = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )
 
     with torch.no_grad():
         w_dec = raw_state_dict["w_dec"]
@@ -782,11 +797,13 @@ def get_goodfire_huggingface_loader(
     )
     raw_state_dict = torch.load(sae_path, map_location=device)
 
+    target_dtype = str_to_dtype(cfg_dict.get("dtype", "float32"))
+
     state_dict = {
-        "W_enc": raw_state_dict["encoder_linear.weight"].T,
-        "W_dec": raw_state_dict["decoder_linear.weight"].T,
-        "b_enc": raw_state_dict["encoder_linear.bias"],
-        "b_dec": raw_state_dict["decoder_linear.bias"],
+        "W_enc": raw_state_dict["encoder_linear.weight"].T.to(dtype=target_dtype),
+        "W_dec": raw_state_dict["decoder_linear.weight"].T.to(dtype=target_dtype),
+        "b_enc": raw_state_dict["encoder_linear.bias"].to(dtype=target_dtype),
+        "b_dec": raw_state_dict["decoder_linear.bias"].to(dtype=target_dtype),
     }
 
     return cfg_dict, state_dict, None
@@ -889,26 +906,19 @@ def llama_scope_sae_huggingface_loader(
         force_download=force_download,
     )
 
-
-
+    state_dict_loaded = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )
 
     # Convert and organize the weights
     state_dict = {
-        "W_enc": state_dict_loaded["encoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "W_dec": state_dict_loaded["decoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "b_enc": state_dict_loaded["encoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "b_dec": state_dict_loaded["decoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
+        "W_enc": state_dict_loaded["encoder.weight"].T,
+        "W_dec": state_dict_loaded["decoder.weight"].T,
+        "b_enc": state_dict_loaded["encoder.bias"],
+        "b_dec": state_dict_loaded["decoder.bias"],
         "threshold": torch.ones(
            cfg_dict["d_sae"],
-            dtype=DTYPE_MAP[cfg_dict["dtype"]],
+            dtype=str_to_dtype(cfg_dict["dtype"]),
             device=cfg_dict["device"],
         )
         * cfg_dict["jump_relu_threshold"],
@@ -1219,26 +1229,17 @@ def llama_scope_r1_distill_sae_huggingface_loader(
         force_download=force_download,
     )
 
-
-
+    state_dict_loaded = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )
 
     # Convert and organize the weights
     state_dict = {
-        "W_enc": state_dict_loaded["encoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "W_dec": state_dict_loaded["decoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "b_enc": state_dict_loaded["encoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "b_dec": state_dict_loaded["decoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "threshold": state_dict_loaded["log_jumprelu_threshold"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .exp(),
+        "W_enc": state_dict_loaded["encoder.weight"].T,
+        "W_dec": state_dict_loaded["decoder.weight"].T,
+        "b_enc": state_dict_loaded["encoder.bias"],
+        "b_dec": state_dict_loaded["decoder.bias"],
+        "threshold": state_dict_loaded["log_jumprelu_threshold"].exp(),
     }
 
     # No sparsity tensor for Llama Scope SAEs
@@ -1358,34 +1359,34 @@ def sparsify_disk_loader(
     cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)
 
     weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
-    state_dict_loaded =
-
-
+    state_dict_loaded = load_safetensors_weights(
+        weight_path, device=device, dtype=cfg_dict.get("dtype")
+    )
 
     W_enc = (
         state_dict_loaded["W_enc"]
         if "W_enc" in state_dict_loaded
         else state_dict_loaded["encoder.weight"].T
-    )
+    )
 
     if "W_dec" in state_dict_loaded:
-        W_dec = state_dict_loaded["W_dec"].T
+        W_dec = state_dict_loaded["W_dec"].T
     else:
-        W_dec = state_dict_loaded["decoder.weight"].T
+        W_dec = state_dict_loaded["decoder.weight"].T
 
     if "b_enc" in state_dict_loaded:
-        b_enc = state_dict_loaded["b_enc"]
+        b_enc = state_dict_loaded["b_enc"]
     elif "encoder.bias" in state_dict_loaded:
-        b_enc = state_dict_loaded["encoder.bias"]
+        b_enc = state_dict_loaded["encoder.bias"]
     else:
-        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=W_dec.dtype, device=device)
 
     if "b_dec" in state_dict_loaded:
-        b_dec = state_dict_loaded["b_dec"]
+        b_dec = state_dict_loaded["b_dec"]
     elif "decoder.bias" in state_dict_loaded:
-        b_dec = state_dict_loaded["decoder.bias"]
+        b_dec = state_dict_loaded["decoder.bias"]
     else:
-        b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+        b_dec = torch.zeros(cfg_dict["d_in"], dtype=W_dec.dtype, device=device)
 
     state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
     return cfg_dict, state_dict
@@ -1616,7 +1617,9 @@ def mwhanna_transcoder_huggingface_loader(
     )
 
     # Load weights from safetensors
-    state_dict =
+    state_dict = load_safetensors_weights(
+        file_path, device=device, dtype=cfg_dict.get("dtype")
+    )
     state_dict["W_enc"] = state_dict["W_enc"].T
 
     return cfg_dict, state_dict, None
@@ -1700,8 +1703,12 @@ def mntss_clt_layer_huggingface_loader(
         force_download=force_download,
     )
 
-    encoder_state_dict =
-
+    encoder_state_dict = load_safetensors_weights(
+        encoder_path, device=device, dtype=cfg_dict.get("dtype")
+    )
+    decoder_state_dict = load_safetensors_weights(
+        decoder_path, device=device, dtype=cfg_dict.get("dtype")
+    )
 
     with torch.no_grad():
         state_dict = {
@@ -1844,7 +1851,9 @@ def temporal_sae_huggingface_loader(
     )
 
     # Load checkpoint from safetensors
-    state_dict_raw =
+    state_dict_raw = load_safetensors_weights(
+        ckpt_path, device=device, dtype=cfg_dict.get("dtype")
+    )
 
     # Convert to SAELens naming convention
     # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*

{sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/__init__.py

@@ -14,6 +14,12 @@ from .jumprelu_sae import (
     JumpReLUTrainingSAE,
     JumpReLUTrainingSAEConfig,
 )
+from .matching_pursuit_sae import (
+    MatchingPursuitSAE,
+    MatchingPursuitSAEConfig,
+    MatchingPursuitTrainingSAE,
+    MatchingPursuitTrainingSAEConfig,
+)
 from .matryoshka_batchtopk_sae import (
     MatryoshkaBatchTopKTrainingSAE,
     MatryoshkaBatchTopKTrainingSAEConfig,
@@ -78,4 +84,8 @@ __all__ = [
     "MatryoshkaBatchTopKTrainingSAEConfig",
     "TemporalSAE",
     "TemporalSAEConfig",
+    "MatchingPursuitSAE",
+    "MatchingPursuitTrainingSAE",
+    "MatchingPursuitSAEConfig",
+    "MatchingPursuitTrainingSAEConfig",
 ]

sae_lens-6.26.0/sae_lens/saes/matching_pursuit_sae.py

@@ -0,0 +1,334 @@
+"""Matching Pursuit SAE"""
+
+import warnings
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+    TrainStepOutput,
+)
+
+# --- inference ---
+
+
+@dataclass
+class MatchingPursuitSAEConfig(SAEConfig):
+    """
+    Configuration class for MatchingPursuitSAE inference.
+
+    Args:
+        residual_threshold (float): residual error at which to stop selecting latents. Default 1e-2.
+        max_iterations (int | None): Maximum iterations (default: d_in if set to None).
+            Defaults to None.
+        stop_on_duplicate_support (bool): Whether to stop selecting latents if the support set has not changed from the previous iteration. Defaults to True.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
+            before encoding. Inherited from SAEConfig. Defaults to True.
+        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
+            Normalization strategy for input activations. Inherited from SAEConfig.
+            Defaults to "none".
+        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
+            (useful for attention head outputs). Inherited from SAEConfig.
+            Defaults to "none".
+        metadata (SAEMetadata): Metadata about the SAE (model name, hook name, etc.).
+            Inherited from SAEConfig.
+    """
+
+    residual_threshold: float = 1e-2
+    max_iterations: int | None = None
+    stop_on_duplicate_support: bool = True
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "matching_pursuit"
+
+
+class MatchingPursuitSAE(SAE[MatchingPursuitSAEConfig]):
+    """
+    An inference-only sparse autoencoder using a "matching pursuit" activation function.
+    """
+
+    # Matching pursuit is a tied SAE, so we use W_enc as the decoder transposed
+    @property
+    def W_enc(self) -> torch.Tensor:  # pyright: ignore[reportIncompatibleVariableOverride]
+        return self.W_dec.T
+
+    # hacky way to get around the base class having W_enc.
+    # TODO: harmonize with the base class in next major release
+    @override
+    def __setattr__(self, name: str, value: Any):
+        if name == "W_enc":
+            return
+        super().__setattr__(name, value)
+
+    @override
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Converts input x into feature activations.
+        """
+        sae_in = self.process_sae_in(x)
+        return _encode_matching_pursuit(
+            sae_in,
+            self.W_dec,
+            self.cfg.residual_threshold,
+            max_iterations=self.cfg.max_iterations,
+            stop_on_duplicate_support=self.cfg.stop_on_duplicate_support,
+        )
+
+    @override
+    @torch.no_grad()
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError(
+            "Folding W_dec_norm is not safe for MatchingPursuit SAEs, as this may change the resulting activations"
+        )
+
+    @override
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
+        """
+        Decode the feature activations back to the input space.
+        Now, if hook_z reshaping is turned on, we reverse the flattening.
+        """
+        sae_out_pre = feature_acts @ self.W_dec
+        # since this is a tied SAE, we need to make sure b_dec is only applied if applied at input
+        if self.cfg.apply_b_dec_to_input:
+            sae_out_pre = sae_out_pre + self.b_dec
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+
+# --- training ---
+
+
+@dataclass
+class MatchingPursuitTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a MatchingPursuitTrainingSAE.
+
+    Args:
+        residual_threshold (float): residual error at which to stop selecting latents. Default 1e-2.
+        max_iterations (int | None): Maximum iterations (default: d_in if set to None).
+            Defaults to None.
+        stop_on_duplicate_support (bool): Whether to stop selecting latents if the support set has not changed from the previous iteration. Defaults to True.
+        decoder_init_norm (float | None): Norm to initialize decoder weights to.
+            0.1 corresponds to the "heuristic" initialization from Anthropic's April update.
+            Use None to disable. Inherited from TrainingSAEConfig. Defaults to 0.1.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
+            before encoding. Inherited from SAEConfig. Defaults to True.
+        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
+            Normalization strategy for input activations. Inherited from SAEConfig.
+            Defaults to "none".
+        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
+            (useful for attention head outputs). Inherited from SAEConfig.
+            Defaults to "none".
+        metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.).
+            Inherited from SAEConfig.
+    """
+
+    residual_threshold: float = 1e-2
+    max_iterations: int | None = None
+    stop_on_duplicate_support: bool = True
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "matching_pursuit"
+
+    @override
+    def __post_init__(self):
+        super().__post_init__()
+        if self.decoder_init_norm != 1.0:
+            self.decoder_init_norm = 1.0
+            warnings.warn(
+                "decoder_init_norm must be set to 1.0 for MatchingPursuitTrainingSAE, setting to 1.0"
+            )
+
+
+class MatchingPursuitTrainingSAE(TrainingSAE[MatchingPursuitTrainingSAEConfig]):
+    # Matching pursuit is a tied SAE, so we use W_enc as the decoder transposed
+    @property
+    def W_enc(self) -> torch.Tensor:  # pyright: ignore[reportIncompatibleVariableOverride]
+        return self.W_dec.T
+
+    # hacky way to get around the base class having W_enc.
+    # TODO: harmonize with the base class in next major release
+    @override
+    def __setattr__(self, name: str, value: Any):
+        if name == "W_enc":
+            return
+        super().__setattr__(name, value)
+
+    @override
+    def encode_with_hidden_pre(
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        hidden_pre doesn't make sense for matching pursuit, since there is not a single pre-activation.
+        We just return zeros for the hidden_pre.
+        """
+
+        sae_in = self.process_sae_in(x)
+        acts = _encode_matching_pursuit(
+            sae_in,
+            self.W_dec,
+            self.cfg.residual_threshold,
+            max_iterations=self.cfg.max_iterations,
+            stop_on_duplicate_support=self.cfg.stop_on_duplicate_support,
+        )
+        return acts, torch.zeros_like(acts)
+
+    @override
+    @torch.no_grad()
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError(
+            "Folding W_dec_norm is not safe for MatchingPursuit SAEs, as this may change the resulting activations"
+        )
+
+    @override
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+        return {}
+
+    @override
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        return {}
+
+    @override
+    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
+        output = super().training_forward_pass(step_input)
+        l0 = output.feature_acts.bool().float().sum(-1).to_dense()
+        residual_norm = (step_input.sae_in - output.sae_out).norm(dim=-1)
+        output.metrics["max_l0"] = l0.max()
+        output.metrics["min_l0"] = l0.min()
+        output.metrics["residual_norm"] = residual_norm.mean()
+        output.metrics["residual_threshold_converged_portion"] = (
+            (residual_norm < self.cfg.residual_threshold).float().mean()
+        )
+        return output
+
+    @override
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
+        """
+        Decode the feature activations back to the input space.
+        Now, if hook_z reshaping is turned on, we reverse the flattening.
+        """
+        sae_out_pre = feature_acts @ self.W_dec
+        # since this is a tied SAE, we need to make sure b_dec is only applied if applied at input
+        if self.cfg.apply_b_dec_to_input:
+            sae_out_pre = sae_out_pre + self.b_dec
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+
+# --- shared ---
+
+
+def _encode_matching_pursuit(
+    sae_in_centered: torch.Tensor,
+    W_dec: torch.Tensor,
+    residual_threshold: float,
+    max_iterations: int | None,
+    stop_on_duplicate_support: bool,
+) -> torch.Tensor:
+    """
+    Matching pursuit encoding.
+
+    Args:
+        sae_in_centered: Input activations, centered by b_dec. Shape [..., d_in].
+        W_dec: Decoder weight matrix. Shape [d_sae, d_in].
+        residual_threshold: Stop when residual norm falls below this.
+        max_iterations: Maximum iterations (default: d_in). Prevents infinite loops.
+        stop_on_duplicate_support: Whether to stop selecting latents if the support set has not changed from the previous iteration.
+    """
+    residual = sae_in_centered.clone()
+
+    stop_on_residual_threshold = residual_threshold > 0
+
+    # Handle multi-dimensional inputs by flattening all but the last dimension
+    original_shape = residual.shape
+    if residual.ndim > 2:
+        residual = residual.reshape(-1, residual.shape[-1])
+
+    batch_size = residual.shape[0]
+    d_sae, d_in = W_dec.shape
+
+    if max_iterations is None:
+        max_iterations = d_in  # Sensible upper bound
+
+    acts = torch.zeros(batch_size, d_sae, device=W_dec.device, dtype=residual.dtype)
+    prev_support = torch.zeros(batch_size, d_sae, dtype=torch.bool, device=W_dec.device)
+    done = torch.zeros(batch_size, dtype=torch.bool, device=W_dec.device)
+
+    for _ in range(max_iterations):
+        # Find indices without gradients - the full [batch, d_sae] matmul result
+        # doesn't need to be saved for backward since max indices don't need gradients
+        with torch.no_grad():
+            indices = (residual @ W_dec.T).relu().max(dim=1, keepdim=True).indices
+        indices_flat = indices.squeeze(1)  # [batch_size]
+
+        # Compute values with gradients using only the selected decoder rows.
+        # This stores [batch, d_in] for backward instead of [batch, d_sae].
+        selected_dec = W_dec[indices_flat]  # [batch_size, d_in]
+        values = (residual * selected_dec).sum(dim=-1, keepdim=True).relu()
+
+        # Mask values for samples that are already done
+        active_mask = (~done).unsqueeze(1)
+        masked_values = (values * active_mask.to(values.dtype)).to(acts.dtype)
+
+        acts.scatter_add_(1, indices, masked_values)
+
+        # Update residual
+        residual = residual - masked_values * selected_dec
+
+        if stop_on_duplicate_support or stop_on_residual_threshold:
+            with torch.no_grad():
+                support = acts != 0
+
+                # A sample is considered converged if:
+                # (1) the support set hasn't changed from the previous iteration (stability), or
+                # (2) the residual norm is below a given threshold (good enough reconstruction)
+                if stop_on_duplicate_support:
+                    done = done | (support == prev_support).all(dim=1)
+                    prev_support = support
+                if stop_on_residual_threshold:
+                    done = done | (residual.norm(dim=-1) < residual_threshold)
+
+            if done.all():
+                break
+
+    # Reshape acts back to original shape (replacing last dimension with d_sae)
+    if len(original_shape) > 2:
+        acts = acts.reshape(*original_shape[:-1], acts.shape[-1])
+
+    return acts
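
For intuition about what _encode_matching_pursuit computes, here is a self-contained toy version of the same greedy loop on a random dictionary. It is a simplified sketch (single stopping rule, no per-sample masking or gradient bookkeeping), not the library implementation above.

import torch

def toy_matching_pursuit(x, W_dec, max_iters=32, residual_threshold=1e-2):
    # x: [batch, d_in]; W_dec: [d_sae, d_in] with roughly unit-norm rows
    residual = x.clone()
    acts = torch.zeros(x.shape[0], W_dec.shape[0])
    for _ in range(max_iters):
        scores = (residual @ W_dec.T).relu()  # match every latent against the residual
        values, indices = scores.max(dim=1, keepdim=True)  # greedily pick the best latent per sample
        acts.scatter_add_(1, indices, values)  # accumulate the chosen coefficient
        residual = residual - values * W_dec[indices.squeeze(1)]  # remove the explained component
        if (residual.norm(dim=-1) < residual_threshold).all():
            break
    return acts

W_dec = torch.nn.functional.normalize(torch.randn(256, 64), dim=-1)
x = torch.randn(4, 64)
print((toy_matching_pursuit(x, W_dec) != 0).sum(-1))  # latents selected per sample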

{sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/sae.py

@@ -27,11 +27,10 @@ from typing_extensions import deprecated, overload, override
 
 from sae_lens import __version__
 from sae_lens.constants import (
-    DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
 )
-from sae_lens.util import filter_valid_dataclass_fields
+from sae_lens.util import dtype_to_str, filter_valid_dataclass_fields, str_to_dtype
 
 if TYPE_CHECKING:
     from sae_lens.config import LanguageModelSAERunnerConfig
@@ -253,7 +252,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
             stacklevel=1,
         )
 
-        self.dtype = DTYPE_MAP[cfg.dtype]
+        self.dtype = str_to_dtype(cfg.dtype)
         self.device = torch.device(cfg.device)
         self.use_error_term = use_error_term
 
@@ -437,8 +436,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
         # Update dtype in config if provided
         if dtype_arg is not None:
-            # Update the cfg.dtype
-            self.cfg.dtype =
+            # Update the cfg.dtype (use canonical short form like "float32")
+            self.cfg.dtype = dtype_to_str(dtype_arg)
 
             # Update the dtype property
             self.dtype = dtype_arg
@@ -534,6 +533,15 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         dtype: str | None = None,
         converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
     ) -> T_SAE:
+        """
+        Load a SAE from disk.
+
+        Args:
+            path: The path to the SAE weights and config.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE on, defaults to None. If None, the dtype will be inferred from the SAE config.
+            converter: The converter to use to load the SAE, defaults to sae_lens_disk_loader.
+        """
         overrides = {"dtype": dtype} if dtype is not None else None
         cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
         cfg_dict = handle_config_defaulting(cfg_dict)
@@ -542,10 +550,17 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         )
         sae_cfg = sae_config_cls.from_dict(cfg_dict)
         sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
+        # hack to avoid using double memory when loading the SAE.
+        # first put the SAE on the meta device, then load the weights.
+        device = sae_cfg.device
+        sae_cfg.device = "meta"
         sae = sae_cls(sae_cfg)
+        sae.cfg.device = device
         sae.process_state_dict_for_loading(state_dict)
-        sae.load_state_dict(state_dict)
-
+        sae.load_state_dict(state_dict, assign=True)
+        # the loaders should already handle the dtype / device conversion
+        # but this is a fallback to guarantee the SAE is on the correct device and dtype
+        return sae.to(dtype=str_to_dtype(sae_cfg.dtype), device=device)
 
     @classmethod
     def from_pretrained(
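
The meta-device trick used in load_from_disk above is plain PyTorch rather than anything SAELens-specific; a generic sketch of the pattern (assuming PyTorch 2.1+ for the assign flag) looks like this:

import torch
from torch import nn

state_dict = {"weight": torch.randn(16, 32), "bias": torch.zeros(16)}

# build the module on the meta device: parameters are shape-only and allocate no memory
with torch.device("meta"):
    layer = nn.Linear(32, 16)

# assign=True swaps the loaded tensors in instead of copying into the (meta) parameters,
# so only one real copy of the weights ever exists
layer.load_state_dict(state_dict, assign=True)
print(layer.weight.device, layer.weight.dtype)
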
@@ -553,6 +568,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         release: str,
         sae_id: str,
         device: str = "cpu",
+        dtype: str = "float32",
         force_download: bool = False,
         converter: PretrainedSaeHuggingfaceLoader | None = None,
     ) -> T_SAE:
@@ -562,10 +578,18 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         Args:
             release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
             id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
-            device: The device to load the SAE on.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE on, defaults to "float32".
+            force_download: Whether to force download the SAE weights and config, defaults to False.
+            converter: The converter to use to load the SAE, defaults to None. If None, the converter will be inferred from the release.
         """
         return cls.from_pretrained_with_cfg_and_sparsity(
-            release,
+            release,
+            sae_id,
+            device,
+            force_download=force_download,
+            dtype=dtype,
+            converter=converter,
         )[0]
 
     @classmethod
@@ -574,6 +598,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         release: str,
         sae_id: str,
         device: str = "cpu",
+        dtype: str = "float32",
         force_download: bool = False,
         converter: PretrainedSaeHuggingfaceLoader | None = None,
     ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
@@ -584,7 +609,10 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         Args:
             release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
             id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
-            device: The device to load the SAE on.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE on, defaults to "float32".
+            force_download: Whether to force download the SAE weights and config, defaults to False.
+            converter: The converter to use to load the SAE, defaults to None. If None, the converter will be inferred from the release.
         """
 
         # get sae directory
@@ -634,6 +662,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
         config_overrides = get_config_overrides(release, sae_id)
         config_overrides["device"] = device
+        config_overrides["dtype"] = dtype
 
         # Load config and weights
         cfg_dict, state_dict, log_sparsities = conversion_loader(
@@ -651,9 +680,14 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         )
         sae_cfg = sae_config_cls.from_dict(cfg_dict)
         sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
+        # hack to avoid using double memory when loading the SAE.
+        # first put the SAE on the meta device, then load the weights.
+        device = sae_cfg.device
+        sae_cfg.device = "meta"
         sae = sae_cls(sae_cfg)
+        sae.cfg.device = device
         sae.process_state_dict_for_loading(state_dict)
-        sae.load_state_dict(state_dict)
+        sae.load_state_dict(state_dict, assign=True)
 
         # Apply normalization if needed
         if cfg_dict.get("normalize_activations") == "expected_average_only_in":
@@ -666,7 +700,13 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
                 f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
             )
 
-
+        # the loaders should already handle the dtype / device conversion
+        # but this is a fallback to guarantee the SAE is on the correct device and dtype
+        return (
+            sae.to(dtype=str_to_dtype(dtype), device=device),
+            cfg_dict,
+            log_sparsities,
+        )
 
     @classmethod
     def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
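
With dtype threaded through from_pretrained, loading a pretrained SAE in reduced precision becomes a one-liner. The release and SAE id below are illustrative placeholders; substitute any entry from pretrained_saes.yaml.

from sae_lens import SAE

sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",       # placeholder release name
    sae_id="blocks.8.hook_resid_pre",  # placeholder SAE id
    device="cpu",
    dtype="bfloat16",                  # new in 6.26.0
)
print(sae.W_dec.dtype)  # torch.bfloat16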

{sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/activations_store.py

@@ -24,7 +24,7 @@ from sae_lens.config import (
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME
 from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
@@ -32,6 +32,7 @@ from sae_lens.training.mixing_buffer import mixing_buffer
 from sae_lens.util import (
     extract_stop_at_layer_from_tlens_hook_name,
     get_special_token_ids,
+    str_to_dtype,
 )
 
 
@@ -258,7 +259,7 @@ class ActivationsStore:
         self.prepend_bos = prepend_bos
         self.normalize_activations = normalize_activations
         self.device = torch.device(device)
-        self.dtype =
+        self.dtype = str_to_dtype(dtype)
         self.cached_activations_path = cached_activations_path
         self.autocast_lm = autocast_lm
         self.seqpos_slice = seqpos_slice

{sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/util.py

@@ -5,8 +5,11 @@ from dataclasses import asdict, fields, is_dataclass
 from pathlib import Path
 from typing import Sequence, TypeVar
 
+import torch
 from transformers import PreTrainedTokenizerBase
 
+from sae_lens.constants import DTYPE_MAP, DTYPE_TO_STR
+
 K = TypeVar("K")
 V = TypeVar("V")
 
@@ -90,3 +93,21 @@ def get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
             special_tokens.add(token_id)
 
     return list(special_tokens)
+
+
+def str_to_dtype(dtype: str) -> torch.dtype:
+    """Convert a string to a torch.dtype."""
+    if dtype not in DTYPE_MAP:
+        raise ValueError(
+            f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_MAP.keys())}"
+        )
+    return DTYPE_MAP[dtype]
+
+
+def dtype_to_str(dtype: torch.dtype) -> str:
+    """Convert a torch.dtype to a string."""
+    if dtype not in DTYPE_TO_STR:
+        raise ValueError(
+            f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_TO_STR.keys())}"
+        )
+    return DTYPE_TO_STR[dtype]
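
A quick round trip through the two new dtype helpers, assuming sae_lens 6.26.0 is installed:

import torch

from sae_lens.util import dtype_to_str, str_to_dtype

assert str_to_dtype("float32") is torch.float32
assert dtype_to_str(torch.bfloat16) == "bfloat16"

# unknown strings raise instead of silently defaulting
try:
    str_to_dtype("not_a_dtype")
except ValueError as err:
    print(err)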