sae-lens 6.24.0__tar.gz → 6.26.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.24.0 → sae_lens-6.26.0}/PKG-INFO +1 -1
- {sae_lens-6.24.0 → sae_lens-6.26.0}/pyproject.toml +1 -1
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/__init__.py +13 -1
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/cache_activations_runner.py +2 -2
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/config.py +7 -2
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/constants.py +8 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/loading/pretrained_sae_loaders.py +66 -66
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/pretrained_saes.yaml +160 -144
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/__init__.py +10 -0
- sae_lens-6.26.0/sae_lens/saes/matching_pursuit_sae.py +334 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/sae.py +52 -12
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/activations_store.py +3 -2
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/util.py +21 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/LICENSE +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/README.md +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/temporal_sae.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/tutorial/tsea.py +0 -0

{sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/__init__.py

@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.24.0"
+__version__ = "6.26.0"
 
 import logging

@@ -21,6 +21,10 @@ from sae_lens.saes import (
     JumpReLUTrainingSAEConfig,
     JumpReLUTranscoder,
     JumpReLUTranscoderConfig,
+    MatchingPursuitSAE,
+    MatchingPursuitSAEConfig,
+    MatchingPursuitTrainingSAE,
+    MatchingPursuitTrainingSAEConfig,
     MatryoshkaBatchTopKTrainingSAE,
     MatryoshkaBatchTopKTrainingSAEConfig,
     SAEConfig,

@@ -113,6 +117,10 @@ __all__ = [
     "MatryoshkaBatchTopKTrainingSAEConfig",
     "TemporalSAE",
     "TemporalSAEConfig",
+    "MatchingPursuitSAE",
+    "MatchingPursuitTrainingSAE",
+    "MatchingPursuitSAEConfig",
+    "MatchingPursuitTrainingSAEConfig",
 ]

@@ -139,3 +147,7 @@ register_sae_class(
     "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
 )
 register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
+register_sae_class("matching_pursuit", MatchingPursuitSAE, MatchingPursuitSAEConfig)
+register_sae_training_class(
+    "matching_pursuit", MatchingPursuitTrainingSAE, MatchingPursuitTrainingSAEConfig
+)
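
These registrations follow the same pattern used for the existing architectures: one string key ("matching_pursuit") is mapped to the inference SAE class and its config, and the training registry maps the same key to the training classes. As a hedged sketch (not part of this diff), the new entry can be resolved back through `sae_lens.registry`, assuming `get_sae_class` and `get_sae_training_class` return the (class, config) pairs stored by the registration calls above:

```python
# Sketch only: look up the newly registered architecture by its string key.
# Assumes get_sae_class / get_sae_training_class return the (class, config)
# pairs stored by register_sae_class / register_sae_training_class above.
from sae_lens.registry import get_sae_class, get_sae_training_class

sae_cls, sae_cfg_cls = get_sae_class("matching_pursuit")
train_cls, train_cfg_cls = get_sae_training_class("matching_pursuit")

print(sae_cls.__name__, sae_cfg_cls.__name__)      # MatchingPursuitSAE MatchingPursuitSAEConfig
print(train_cls.__name__, train_cfg_cls.__name__)  # MatchingPursuitTrainingSAE MatchingPursuitTrainingSAEConfig
```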

{sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/cache_activations_runner.py

@@ -14,9 +14,9 @@ from transformer_lens.HookedTransformer import HookedRootModule
 
 from sae_lens import logger
 from sae_lens.config import CacheActivationsRunnerConfig
-from sae_lens.constants import DTYPE_MAP
 from sae_lens.load_model import load_model
 from sae_lens.training.activations_store import ActivationsStore
+from sae_lens.util import str_to_dtype
 
 
 def _mk_activations_store(

@@ -97,7 +97,7 @@ class CacheActivationsRunner:
         bytes_per_token = (
             self.cfg.d_in * self.cfg.dtype.itemsize
             if isinstance(self.cfg.dtype, torch.dtype)
-            else
+            else str_to_dtype(self.cfg.dtype).itemsize
         )
         total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
         total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9
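
For reference, the disk estimate in this hunk is just bytes-per-token times total tokens, where bytes per token is `d_in * itemsize` and the dtype may now be given either as a `torch.dtype` or as a string resolved via `str_to_dtype`. A small worked example with made-up numbers (not from the diff):

```python
# Illustrative only: reproduces the arithmetic from the hunk above with
# hypothetical sizes (d_in, dataset size, context length are made up).
import torch

d_in = 768                              # hypothetical hook width
bytes_per_token = d_in * torch.float32.itemsize  # 768 * 4 bytes

n_seq_in_dataset = 100_000              # hypothetical number of sequences
context_size = 128                      # hypothetical context length
total_training_tokens = n_seq_in_dataset * context_size

total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9
print(f"{total_disk_space_gb:.1f} GB")  # ~39.3 GB
```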

{sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/config.py

@@ -17,9 +17,14 @@ from datasets import (
 )
 
 from sae_lens import __version__, logger
-
+
+# keeping this unused import since some SAELens deps import DTYPE_MAP from config
+from sae_lens.constants import (
+    DTYPE_MAP,  # noqa: F401 # pyright: ignore[reportUnusedImport]
+)
 from sae_lens.registry import get_sae_training_class
 from sae_lens.saes.sae import TrainingSAEConfig
+from sae_lens.util import str_to_dtype
 
 if TYPE_CHECKING:
     pass
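
As the comment in the hunk notes, the import is kept only so that downstream code importing `DTYPE_MAP` from `sae_lens.config` keeps working. A minimal compatibility check (not from the diff):

```python
# DTYPE_MAP is still re-exported from sae_lens.config even though config.py
# no longer uses it directly; both names refer to the same dict.
from sae_lens.config import DTYPE_MAP
from sae_lens.constants import DTYPE_MAP as CANONICAL_DTYPE_MAP

assert DTYPE_MAP is CANONICAL_DTYPE_MAP
```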

@@ -563,7 +568,7 @@ class CacheActivationsRunnerConfig:
 
     @property
     def bytes_per_token(self) -> int:
-        return self.d_in *
+        return self.d_in * str_to_dtype(self.dtype).itemsize
 
     @property
     def n_tokens_in_buffer(self) -> int:

{sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/constants.py

@@ -11,6 +11,14 @@ DTYPE_MAP = {
     "torch.bfloat16": torch.bfloat16,
 }
 
+# Reverse mapping from torch.dtype to canonical string format
+DTYPE_TO_STR = {
+    torch.float32: "float32",
+    torch.float64: "float64",
+    torch.float16: "float16",
+    torch.bfloat16: "bfloat16",
+}
+
 
 SPARSITY_FILENAME = "sparsity.safetensors"
 SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
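
`DTYPE_TO_STR` is the inverse of the existing `DTYPE_MAP` lookup. The companion `str_to_dtype` helper added to `sae_lens/util.py` (used throughout the hunks below) is not shown in this diff; the sketch below only illustrates the string-to-dtype round trip it supports, under the assumption that strings are resolved through `DTYPE_MAP`:

```python
# Hedged sketch: the real str_to_dtype lives in sae_lens/util.py and is not
# shown in this diff. This only demonstrates the round trip between the two
# module-level tables in sae_lens/constants.py.
import torch

from sae_lens.constants import DTYPE_MAP, DTYPE_TO_STR


def str_to_dtype_sketch(dtype: str | torch.dtype) -> torch.dtype:
    # Assumption: strings are resolved via DTYPE_MAP's keys
    # (e.g. "torch.bfloat16" as shown in the hunk above).
    if isinstance(dtype, torch.dtype):
        return dtype
    return DTYPE_MAP[dtype]


assert str_to_dtype_sketch("torch.bfloat16") is torch.bfloat16
assert DTYPE_TO_STR[torch.bfloat16] == "bfloat16"
```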

{sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/loading/pretrained_sae_loaders.py

@@ -12,11 +12,9 @@ from huggingface_hub import hf_hub_download, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
 from packaging.version import Version
 from safetensors import safe_open
-from safetensors.torch import load_file
 
 from sae_lens import logger
 from sae_lens.constants import (
-    DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
     SPARSIFY_WEIGHTS_FILENAME,

@@ -28,7 +26,7 @@ from sae_lens.loading.pretrained_saes_directory import (
     get_repo_id_and_folder_name,
 )
 from sae_lens.registry import get_sae_class
-from sae_lens.util import filter_valid_dataclass_fields
+from sae_lens.util import filter_valid_dataclass_fields, str_to_dtype
 
 LLM_METADATA_KEYS = {
     "model_name",

@@ -51,6 +49,21 @@
 }
 
 
+def load_safetensors_weights(
+    path: str | Path, device: str = "cpu", dtype: torch.dtype | str | None = None
+) -> dict[str, torch.Tensor]:
+    """Load safetensors weights and optionally convert to a different dtype"""
+    loaded_weights = {}
+    dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
+    with safe_open(path, framework="pt", device=device) as f:
+        for k in f.keys():  # noqa: SIM118
+            weight = f.get_tensor(k)
+            if dtype is not None:
+                weight = weight.to(dtype=dtype)
+            loaded_weights[k] = weight
+    return loaded_weights
+
+
 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
 class PretrainedSaeHuggingfaceLoader(Protocol):
     def __call__(
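
The new helper centralizes the `safe_open` plus optional dtype-cast pattern that the loaders below previously repeated with `load_file` and `DTYPE_MAP`. Usage mirrors the call sites later in this diff; the file path here is hypothetical:

```python
# Hypothetical path; mirrors how the loaders below call the new helper.
from sae_lens.loading.pretrained_sae_loaders import load_safetensors_weights

state_dict = load_safetensors_weights(
    "path/to/sae_weights.safetensors",  # hypothetical local file
    device="cpu",
    dtype="float32",  # may also be a torch.dtype, or None to keep stored dtypes
)
print({k: (tuple(v.shape), v.dtype) for k, v in state_dict.items()})
```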

@@ -341,7 +354,7 @@ def read_sae_components_from_disk(
     Given a loaded dictionary and a path to a weight file, load the weights and return the state_dict.
     """
     if dtype is None:
-        dtype =
+        dtype = str_to_dtype(cfg_dict["dtype"])
 
     state_dict = {}
     with safe_open(weight_path, framework="pt", device=device) as f:  # type: ignore

@@ -682,15 +695,6 @@ def gemma_3_sae_huggingface_loader(
         cfg_overrides,
     )
 
-    # replace folder name of 65k with 64k
-    # TODO: remove this workaround once weights are fixed
-    if "270m-pt" in repo_id:
-        if "65k" in folder_name:
-            folder_name = folder_name.replace("65k", "64k")
-        # replace folder name of 262k with 250k
-        if "262k" in folder_name:
-            folder_name = folder_name.replace("262k", "250k")
-
     params_file = "params.safetensors"
     if "clt" in folder_name:
         params_file = folder_name.split("/")[-1] + ".safetensors"

@@ -704,7 +708,9 @@
         force_download=force_download,
     )
 
-    raw_state_dict =
+    raw_state_dict = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )
 
     with torch.no_grad():
         w_dec = raw_state_dict["w_dec"]

@@ -791,11 +797,13 @@ def get_goodfire_huggingface_loader(
     )
     raw_state_dict = torch.load(sae_path, map_location=device)
 
+    target_dtype = str_to_dtype(cfg_dict.get("dtype", "float32"))
+
     state_dict = {
-        "W_enc": raw_state_dict["encoder_linear.weight"].T,
-        "W_dec": raw_state_dict["decoder_linear.weight"].T,
-        "b_enc": raw_state_dict["encoder_linear.bias"],
-        "b_dec": raw_state_dict["decoder_linear.bias"],
+        "W_enc": raw_state_dict["encoder_linear.weight"].T.to(dtype=target_dtype),
+        "W_dec": raw_state_dict["decoder_linear.weight"].T.to(dtype=target_dtype),
+        "b_enc": raw_state_dict["encoder_linear.bias"].to(dtype=target_dtype),
+        "b_dec": raw_state_dict["decoder_linear.bias"].to(dtype=target_dtype),
     }
 
     return cfg_dict, state_dict, None

@@ -898,26 +906,19 @@ def llama_scope_sae_huggingface_loader(
         force_download=force_download,
     )
 
-
-
+    state_dict_loaded = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )
 
     # Convert and organize the weights
     state_dict = {
-        "W_enc": state_dict_loaded["encoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "W_dec": state_dict_loaded["decoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "b_enc": state_dict_loaded["encoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "b_dec": state_dict_loaded["decoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
+        "W_enc": state_dict_loaded["encoder.weight"].T,
+        "W_dec": state_dict_loaded["decoder.weight"].T,
+        "b_enc": state_dict_loaded["encoder.bias"],
+        "b_dec": state_dict_loaded["decoder.bias"],
         "threshold": torch.ones(
             cfg_dict["d_sae"],
-            dtype=
+            dtype=str_to_dtype(cfg_dict["dtype"]),
             device=cfg_dict["device"],
         )
         * cfg_dict["jump_relu_threshold"],
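
In this loader the config carries a single scalar `jump_relu_threshold`, which is broadcast to one value per latent, whereas the R1-Distill loader in the next hunk reads a per-latent `log_jumprelu_threshold` from the weights and exponentiates it. A tiny illustration of the broadcast with made-up sizes (not from the diff):

```python
# Illustrative only: the "threshold" entry above is a scalar broadcast across
# d_sae latents (sizes here are hypothetical).
import torch

d_sae = 8                    # hypothetical latent count
jump_relu_threshold = 0.05   # hypothetical scalar from the SAE config
threshold = torch.ones(d_sae, dtype=torch.float32) * jump_relu_threshold
print(threshold)  # eight identical 0.05 values
```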

@@ -1228,26 +1229,17 @@ def llama_scope_r1_distill_sae_huggingface_loader(
         force_download=force_download,
     )
 
-
-
+    state_dict_loaded = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )
 
     # Convert and organize the weights
     state_dict = {
-        "W_enc": state_dict_loaded["encoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "W_dec": state_dict_loaded["decoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "b_enc": state_dict_loaded["encoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "b_dec": state_dict_loaded["decoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "threshold": state_dict_loaded["log_jumprelu_threshold"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .exp(),
+        "W_enc": state_dict_loaded["encoder.weight"].T,
+        "W_dec": state_dict_loaded["decoder.weight"].T,
+        "b_enc": state_dict_loaded["encoder.bias"],
+        "b_dec": state_dict_loaded["decoder.bias"],
+        "threshold": state_dict_loaded["log_jumprelu_threshold"].exp(),
     }
 
     # No sparsity tensor for Llama Scope SAEs

@@ -1367,34 +1359,34 @@ def sparsify_disk_loader(
     cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)
 
     weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
-    state_dict_loaded =
-
-
+    state_dict_loaded = load_safetensors_weights(
+        weight_path, device=device, dtype=cfg_dict.get("dtype")
+    )
 
     W_enc = (
         state_dict_loaded["W_enc"]
         if "W_enc" in state_dict_loaded
         else state_dict_loaded["encoder.weight"].T
-    )
+    )
 
     if "W_dec" in state_dict_loaded:
-        W_dec = state_dict_loaded["W_dec"].T
+        W_dec = state_dict_loaded["W_dec"].T
     else:
-        W_dec = state_dict_loaded["decoder.weight"].T
+        W_dec = state_dict_loaded["decoder.weight"].T
 
     if "b_enc" in state_dict_loaded:
-        b_enc = state_dict_loaded["b_enc"]
+        b_enc = state_dict_loaded["b_enc"]
     elif "encoder.bias" in state_dict_loaded:
-        b_enc = state_dict_loaded["encoder.bias"]
+        b_enc = state_dict_loaded["encoder.bias"]
     else:
-        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=W_dec.dtype, device=device)
 
     if "b_dec" in state_dict_loaded:
-        b_dec = state_dict_loaded["b_dec"]
+        b_dec = state_dict_loaded["b_dec"]
     elif "decoder.bias" in state_dict_loaded:
-        b_dec = state_dict_loaded["decoder.bias"]
+        b_dec = state_dict_loaded["decoder.bias"]
     else:
-        b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+        b_dec = torch.zeros(cfg_dict["d_in"], dtype=W_dec.dtype, device=device)
 
     state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
     return cfg_dict, state_dict

@@ -1625,7 +1617,9 @@ def mwhanna_transcoder_huggingface_loader(
     )
 
     # Load weights from safetensors
-    state_dict =
+    state_dict = load_safetensors_weights(
+        file_path, device=device, dtype=cfg_dict.get("dtype")
+    )
     state_dict["W_enc"] = state_dict["W_enc"].T
 
     return cfg_dict, state_dict, None

@@ -1709,8 +1703,12 @@ def mntss_clt_layer_huggingface_loader(
         force_download=force_download,
     )
 
-    encoder_state_dict =
-
+    encoder_state_dict = load_safetensors_weights(
+        encoder_path, device=device, dtype=cfg_dict.get("dtype")
+    )
+    decoder_state_dict = load_safetensors_weights(
+        decoder_path, device=device, dtype=cfg_dict.get("dtype")
+    )
 
     with torch.no_grad():
         state_dict = {

@@ -1853,7 +1851,9 @@ def temporal_sae_huggingface_loader(
     )
 
     # Load checkpoint from safetensors
-    state_dict_raw =
+    state_dict_raw = load_safetensors_weights(
+        ckpt_path, device=device, dtype=cfg_dict.get("dtype")
+    )
 
     # Convert to SAELens naming convention
     # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*