sae-lens 6.22.1__py3-none-any.whl → 6.25.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +8 -1
- sae_lens/cache_activations_runner.py +2 -2
- sae_lens/config.py +2 -2
- sae_lens/constants.py +8 -0
- sae_lens/loading/pretrained_sae_loaders.py +298 -80
- sae_lens/pretokenize_runner.py +3 -3
- sae_lens/pretrained_saes.yaml +26949 -97
- sae_lens/saes/__init__.py +4 -0
- sae_lens/saes/gated_sae.py +2 -2
- sae_lens/saes/jumprelu_sae.py +4 -4
- sae_lens/saes/sae.py +53 -13
- sae_lens/saes/topk_sae.py +1 -1
- sae_lens/saes/transcoder.py +41 -0
- sae_lens/training/activations_store.py +8 -7
- sae_lens/util.py +21 -0
- {sae_lens-6.22.1.dist-info → sae_lens-6.25.1.dist-info}/METADATA +2 -2
- {sae_lens-6.22.1.dist-info → sae_lens-6.25.1.dist-info}/RECORD +19 -19
- {sae_lens-6.22.1.dist-info → sae_lens-6.25.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.22.1.dist-info → sae_lens-6.25.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.22.1"
+__version__ = "6.25.1"

 import logging

@@ -15,6 +15,8 @@ from sae_lens.saes import (
     GatedTrainingSAEConfig,
     JumpReLUSAE,
     JumpReLUSAEConfig,
+    JumpReLUSkipTranscoder,
+    JumpReLUSkipTranscoderConfig,
     JumpReLUTrainingSAE,
     JumpReLUTrainingSAEConfig,
     JumpReLUTranscoder,
@@ -105,6 +107,8 @@ __all__ = [
     "SkipTranscoderConfig",
     "JumpReLUTranscoder",
     "JumpReLUTranscoderConfig",
+    "JumpReLUSkipTranscoder",
+    "JumpReLUSkipTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
     "TemporalSAE",
@@ -131,4 +135,7 @@ register_sae_training_class(
 register_sae_class("transcoder", Transcoder, TranscoderConfig)
 register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
 register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+register_sae_class(
+    "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
+)
 register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
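The new registrations above reuse the same `register_sae_class` pattern as the existing architectures. A minimal, hypothetical sketch of registering a custom architecture the same way (MyCustomSAE and MyCustomSAEConfig are made-up names, and the `sae_lens.registry` / `sae_lens.saes.sae` import paths are assumptions based on the registry imports visible elsewhere in this diff):

```python
from dataclasses import dataclass

from sae_lens.registry import register_sae_class  # assumed import path
from sae_lens.saes.sae import SAE, SAEConfig  # assumed base classes


@dataclass
class MyCustomSAEConfig(SAEConfig):
    """Hypothetical config for a custom architecture."""


class MyCustomSAE(SAE):
    """Hypothetical SAE subclass; encode/decode logic omitted in this sketch."""


# Mirrors the jumprelu_skip_transcoder registration added in 6.25.1.
register_sae_class("my_custom_architecture", MyCustomSAE, MyCustomSAEConfig)
```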
sae_lens/cache_activations_runner.py
CHANGED

@@ -14,9 +14,9 @@ from transformer_lens.HookedTransformer import HookedRootModule

 from sae_lens import logger
 from sae_lens.config import CacheActivationsRunnerConfig
-from sae_lens.constants import DTYPE_MAP
 from sae_lens.load_model import load_model
 from sae_lens.training.activations_store import ActivationsStore
+from sae_lens.util import str_to_dtype


 def _mk_activations_store(
@@ -97,7 +97,7 @@ class CacheActivationsRunner:
         bytes_per_token = (
             self.cfg.d_in * self.cfg.dtype.itemsize
             if isinstance(self.cfg.dtype, torch.dtype)
-            else DTYPE_MAP[self.cfg.dtype].itemsize
+            else str_to_dtype(self.cfg.dtype).itemsize
         )
         total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
         total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9
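Both this runner and CacheActivationsRunnerConfig below now size the activation cache via `str_to_dtype` instead of a raw `DTYPE_MAP` lookup. A worked example of the disk-space estimate, using hypothetical sizes:

```python
from sae_lens.util import str_to_dtype

# Hypothetical sizes, purely for illustration.
d_in, dtype = 2304, "float32"
n_seq_in_dataset, context_size = 100_000, 1024

bytes_per_token = d_in * str_to_dtype(dtype).itemsize     # 2304 * 4 = 9216 bytes
total_training_tokens = n_seq_in_dataset * context_size   # 102,400,000 tokens
total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9
print(round(total_disk_space_gb, 1))  # 943.7
```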
sae_lens/config.py
CHANGED
@@ -17,9 +17,9 @@ from datasets import (
 )

 from sae_lens import __version__, logger
-from sae_lens.constants import DTYPE_MAP
 from sae_lens.registry import get_sae_training_class
 from sae_lens.saes.sae import TrainingSAEConfig
+from sae_lens.util import str_to_dtype

 if TYPE_CHECKING:
     pass
@@ -563,7 +563,7 @@ class CacheActivationsRunnerConfig:

     @property
     def bytes_per_token(self) -> int:
-        return self.d_in * DTYPE_MAP[self.dtype].itemsize
+        return self.d_in * str_to_dtype(self.dtype).itemsize

     @property
     def n_tokens_in_buffer(self) -> int:
sae_lens/constants.py
CHANGED
@@ -11,6 +11,14 @@ DTYPE_MAP = {
     "torch.bfloat16": torch.bfloat16,
 }

+# Reverse mapping from torch.dtype to canonical string format
+DTYPE_TO_STR = {
+    torch.float32: "float32",
+    torch.float64: "float64",
+    torch.float16: "float16",
+    torch.bfloat16: "bfloat16",
+}
+

 SPARSITY_FILENAME = "sparsity.safetensors"
 SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
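DTYPE_TO_STR is the inverse of DTYPE_MAP and pairs with the new `str_to_dtype` helper in sae_lens/util.py (that file's +21 lines are not shown in this diff). A hedged sketch of how the helper plausibly behaves, based only on how its call sites use it in this release:

```python
import torch

from sae_lens.constants import DTYPE_MAP, DTYPE_TO_STR


def str_to_dtype_sketch(dtype: str | torch.dtype) -> torch.dtype:
    # Sketch only: the real str_to_dtype lives in sae_lens.util and is not shown
    # in this diff; call sites pass a string such as "float32" and use the
    # returned torch.dtype's .itemsize.
    if isinstance(dtype, torch.dtype):
        return dtype
    if dtype not in DTYPE_MAP:
        raise ValueError(f"Unknown dtype string: {dtype!r}")
    return DTYPE_MAP[dtype]


assert str_to_dtype_sketch("torch.bfloat16") is torch.bfloat16
assert DTYPE_TO_STR[torch.bfloat16] == "bfloat16"
```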
sae_lens/loading/pretrained_sae_loaders.py
CHANGED

@@ -9,14 +9,12 @@ import requests
 import torch
 import yaml
 from huggingface_hub import hf_hub_download, hf_hub_url
-from huggingface_hub.utils import EntryNotFoundError
+from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
 from packaging.version import Version
 from safetensors import safe_open
-from safetensors.torch import load_file

 from sae_lens import logger
 from sae_lens.constants import (
-    DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
     SPARSIFY_WEIGHTS_FILENAME,
@@ -28,7 +26,7 @@ from sae_lens.loading.pretrained_saes_directory import (
     get_repo_id_and_folder_name,
 )
 from sae_lens.registry import get_sae_class
-from sae_lens.util import filter_valid_dataclass_fields
+from sae_lens.util import filter_valid_dataclass_fields, str_to_dtype

 LLM_METADATA_KEYS = {
     "model_name",
@@ -46,9 +44,26 @@ LLM_METADATA_KEYS = {
     "sae_lens_training_version",
     "hook_name_out",
     "hook_head_index_out",
+    "hf_hook_name",
+    "hf_hook_name_out",
 }


+def load_safetensors_weights(
+    path: str | Path, device: str = "cpu", dtype: torch.dtype | str | None = None
+) -> dict[str, torch.Tensor]:
+    """Load safetensors weights and optionally convert to a different dtype"""
+    loaded_weights = {}
+    dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
+    with safe_open(path, framework="pt", device=device) as f:
+        for k in f.keys():  # noqa: SIM118
+            weight = f.get_tensor(k)
+            if dtype is not None:
+                weight = weight.to(dtype=dtype)
+            loaded_weights[k] = weight
+    return loaded_weights
+
+
 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
 class PretrainedSaeHuggingfaceLoader(Protocol):
     def __call__(
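The new `load_safetensors_weights` helper centralizes the load-then-cast pattern that several loaders below previously repeated inline. A small usage sketch (the checkpoint path is a hypothetical local file):

```python
from sae_lens.loading.pretrained_sae_loaders import load_safetensors_weights

# Hypothetical local checkpoint; any safetensors file works.
weights = load_safetensors_weights(
    "checkpoints/sae_weights.safetensors", device="cpu", dtype="float32"
)
print({name: tuple(t.shape) for name, t in weights.items()})
```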
@@ -339,7 +354,7 @@ def read_sae_components_from_disk(
     Given a loaded dictionary and a path to a weight file, load the weights and return the state_dict.
     """
     if dtype is None:
-        dtype = DTYPE_MAP[cfg_dict["dtype"]]
+        dtype = str_to_dtype(cfg_dict["dtype"])

     state_dict = {}
     with safe_open(weight_path, framework="pt", device=device) as f:  # type: ignore
@@ -523,6 +538,199 @@ def gemma_2_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity


+def _infer_gemma_3_raw_cfg_dict(repo_id: str, folder_name: str) -> dict[str, Any]:
+    """
+    Infer the raw config dict for Gemma 3 SAEs from the repo_id and folder_name.
+    This is used when config.json doesn't exist in the repo.
+    """
+    # Extract layer number from folder name
+    layer_match = re.search(r"layer_(\d+)", folder_name)
+    if layer_match is None:
+        raise ValueError(
+            f"Could not extract layer number from folder_name: {folder_name}"
+        )
+    layer = int(layer_match.group(1))
+
+    # Convert repo_id to model_name: google/gemma-scope-2-{size}-{suffix} -> google/gemma-3-{size}-{suffix}
+    model_name = repo_id.replace("gemma-scope-2", "gemma-3")
+
+    # Determine hook type and HF hook points based on folder_name
+    if "transcoder" in folder_name or "clt" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.pre_feedforward_layernorm.output"
+        hf_hook_point_out = f"model.layers.{layer}.post_feedforward_layernorm.output"
+    elif "resid_post" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.output"
+        hf_hook_point_out = None
+    elif "attn_out" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.self_attn.o_proj.input"
+        hf_hook_point_out = None
+    elif "mlp_out" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.post_feedforward_layernorm.output"
+        hf_hook_point_out = None
+    else:
+        raise ValueError(f"Could not infer hook type from folder_name: {folder_name}")
+
+    cfg: dict[str, Any] = {
+        "architecture": "jump_relu",
+        "model_name": model_name,
+        "hf_hook_point_in": hf_hook_point_in,
+    }
+    if hf_hook_point_out is not None:
+        cfg["hf_hook_point_out"] = hf_hook_point_out
+
+    return cfg
+
+
+def get_gemma_3_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    # Try to load config.json from the repo, fall back to inferring if it doesn't exist
+    try:
+        config_path = hf_hub_download(
+            repo_id, f"{folder_name}/config.json", force_download=force_download
+        )
+        with open(config_path) as config_file:
+            raw_cfg_dict = json.load(config_file)
+    except EntryNotFoundError:
+        raw_cfg_dict = _infer_gemma_3_raw_cfg_dict(repo_id, folder_name)
+
+    if raw_cfg_dict.get("architecture") != "jump_relu":
+        raise ValueError(
+            f"Unexpected architecture in Gemma 3 config: {raw_cfg_dict.get('architecture')}"
+        )
+
+    layer_match = re.search(r"layer_(\d+)", folder_name)
+    if layer_match is None:
+        raise ValueError(
+            f"Could not extract layer number from folder_name: {folder_name}"
+        )
+    layer = int(layer_match.group(1))
+    hook_name_out = None
+    d_out = None
+    if "resid_post" in folder_name:
+        hook_name = f"blocks.{layer}.hook_resid_post"
+    elif "attn_out" in folder_name:
+        hook_name = f"blocks.{layer}.hook_attn_out"
+    elif "mlp_out" in folder_name:
+        hook_name = f"blocks.{layer}.hook_mlp_out"
+    elif "transcoder" in folder_name or "clt" in folder_name:
+        hook_name = f"blocks.{layer}.ln2.hook_normalized"
+        hook_name_out = f"blocks.{layer}.hook_mlp_out"
+    else:
+        raise ValueError("Hook name not found in folder_name.")
+
+    # hackily deal with clt file names
+    params_file_part = "/params.safetensors"
+    if "clt" in folder_name:
+        params_file_part = ".safetensors"
+
+    shapes_dict = get_safetensors_tensor_shapes(
+        repo_id, f"{folder_name}{params_file_part}"
+    )
+    d_in, d_sae = shapes_dict["w_enc"]
+    # TODO: update this for real model info
+    model_name = raw_cfg_dict["model_name"]
+    if "google" not in model_name:
+        model_name = "google/" + model_name
+    model_name = model_name.replace("-v3", "-3")
+    if "270m" in model_name:
+        # for some reason the 270m model on huggingface doesn't have the -pt suffix
+        model_name = model_name.replace("-pt", "")
+
+    architecture = "jumprelu"
+    if "transcoder" in folder_name or "clt" in folder_name:
+        architecture = "jumprelu_skip_transcoder"
+        d_out = shapes_dict["w_dec"][-1]
+
+    cfg = {
+        "architecture": architecture,
+        "d_in": d_in,
+        "d_sae": d_sae,
+        "dtype": "float32",
+        "model_name": model_name,
+        "hook_name": hook_name,
+        "hook_head_index": None,
+        "finetuning_scaling_factor": False,
+        "sae_lens_training_version": None,
+        "prepend_bos": True,
+        "dataset_path": "monology/pile-uncopyrighted",
+        "context_size": 1024,
+        "apply_b_dec_to_input": False,
+        "normalize_activations": None,
+        "hf_hook_name": raw_cfg_dict.get("hf_hook_point_in"),
+    }
+    if hook_name_out is not None:
+        cfg["hook_name_out"] = hook_name_out
+        cfg["hf_hook_name_out"] = raw_cfg_dict.get("hf_hook_point_out")
+    if d_out is not None:
+        cfg["d_out"] = d_out
+    if device is not None:
+        cfg["device"] = device
+
+    if cfg_overrides is not None:
+        cfg.update(cfg_overrides)
+
+    return cfg
+
+
+def gemma_3_sae_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """
+    Custom loader for Gemma 3 SAEs.
+    """
+    cfg_dict = get_gemma_3_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
+    )
+
+    params_file = "params.safetensors"
+    if "clt" in folder_name:
+        params_file = folder_name.split("/")[-1] + ".safetensors"
+        folder_name = "/".join(folder_name.split("/")[:-1])
+
+    # Download the SAE weights
+    sae_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=params_file,
+        subfolder=folder_name,
+        force_download=force_download,
+    )
+
+    raw_state_dict = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )
+
+    with torch.no_grad():
+        w_dec = raw_state_dict["w_dec"]
+        if "clt" in folder_name:
+            w_dec = w_dec.sum(dim=1).contiguous()
+
+        state_dict = {
+            "W_enc": raw_state_dict["w_enc"],
+            "W_dec": w_dec,
+            "b_enc": raw_state_dict["b_enc"],
+            "b_dec": raw_state_dict["b_dec"],
+            "threshold": raw_state_dict["threshold"],
+        }
+
+        if "affine_skip_connection" in raw_state_dict:
+            state_dict["W_skip"] = raw_state_dict["affine_skip_connection"]
+
+    return cfg_dict, state_dict, None
+
+
 def get_goodfire_config_from_hf(
     repo_id: str,
     folder_name: str,  # noqa: ARG001
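When a repository ships no config.json, the Gemma 3 loader falls back to `_infer_gemma_3_raw_cfg_dict`. A worked example of that inference path, runnable offline; the repo_id and folder_name are hypothetical but follow the layer_{n} and hook-type naming the function expects:

```python
from sae_lens.loading.pretrained_sae_loaders import _infer_gemma_3_raw_cfg_dict

raw_cfg = _infer_gemma_3_raw_cfg_dict(
    repo_id="google/gemma-scope-2-1b-pt",         # hypothetical repo
    folder_name="transcoder/layer_12_width_16k",  # hypothetical folder
)
print(raw_cfg)
# {'architecture': 'jump_relu',
#  'model_name': 'google/gemma-3-1b-pt',
#  'hf_hook_point_in': 'model.layers.12.pre_feedforward_layernorm.output',
#  'hf_hook_point_out': 'model.layers.12.post_feedforward_layernorm.output'}
```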
@@ -589,11 +797,13 @@ def get_goodfire_huggingface_loader(
     )
     raw_state_dict = torch.load(sae_path, map_location=device)

+    target_dtype = str_to_dtype(cfg_dict.get("dtype", "float32"))
+
     state_dict = {
-        "W_enc": raw_state_dict["encoder_linear.weight"].T,
-        "W_dec": raw_state_dict["decoder_linear.weight"].T,
-        "b_enc": raw_state_dict["encoder_linear.bias"],
-        "b_dec": raw_state_dict["decoder_linear.bias"],
+        "W_enc": raw_state_dict["encoder_linear.weight"].T.to(dtype=target_dtype),
+        "W_dec": raw_state_dict["decoder_linear.weight"].T.to(dtype=target_dtype),
+        "b_enc": raw_state_dict["encoder_linear.bias"].to(dtype=target_dtype),
+        "b_dec": raw_state_dict["decoder_linear.bias"].to(dtype=target_dtype),
     }

     return cfg_dict, state_dict, None
@@ -696,26 +906,19 @@ def llama_scope_sae_huggingface_loader(
         force_download=force_download,
     )

-
-
+    state_dict_loaded = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     # Convert and organize the weights
     state_dict = {
-        "W_enc": state_dict_loaded["encoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "W_dec": state_dict_loaded["decoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "b_enc": state_dict_loaded["encoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "b_dec": state_dict_loaded["decoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
+        "W_enc": state_dict_loaded["encoder.weight"].T,
+        "W_dec": state_dict_loaded["decoder.weight"].T,
+        "b_enc": state_dict_loaded["encoder.bias"],
+        "b_dec": state_dict_loaded["decoder.bias"],
         "threshold": torch.ones(
             cfg_dict["d_sae"],
-            dtype=DTYPE_MAP[cfg_dict["dtype"]],
+            dtype=str_to_dtype(cfg_dict["dtype"]),
            device=cfg_dict["device"],
         )
         * cfg_dict["jump_relu_threshold"],
@@ -753,10 +956,14 @@ def get_dictionary_learning_config_1_from_hf(
     activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
     activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}

+    architecture = "standard"
+    if trainer["dict_class"] == "GatedAutoEncoder":
+        architecture = "gated"
+    elif trainer["dict_class"] == "MatryoshkaBatchTopKSAE":
+        architecture = "jumprelu"
+
     return {
-        "architecture": (
-            "gated" if trainer["dict_class"] == "GatedAutoEncoder" else "standard"
-        ),
+        "architecture": architecture,
         "d_in": trainer["activation_dim"],
         "d_sae": trainer["dict_size"],
         "dtype": "float32",
@@ -905,9 +1112,12 @@ def dictionary_learning_sae_huggingface_loader_1(
     )
     encoder = torch.load(encoder_path, map_location="cpu")

+    W_enc = encoder["W_enc"] if "W_enc" in encoder else encoder["encoder.weight"].T
+    W_dec = encoder["W_dec"] if "W_dec" in encoder else encoder["decoder.weight"].T
+
     state_dict = {
-        "W_enc": encoder["encoder.weight"].T,
-        "W_dec": encoder["decoder.weight"].T,
+        "W_enc": W_enc,
+        "W_dec": W_dec,
         "b_dec": encoder.get(
             "b_dec", encoder.get("bias", encoder.get("decoder_bias", None))
         ),
@@ -915,6 +1125,8 @@ def dictionary_learning_sae_huggingface_loader_1(

     if "encoder.bias" in encoder:
         state_dict["b_enc"] = encoder["encoder.bias"]
+    if "b_enc" in encoder:
+        state_dict["b_enc"] = encoder["b_enc"]

     if "mag_bias" in encoder:
         state_dict["b_mag"] = encoder["mag_bias"]
@@ -923,6 +1135,12 @@ def dictionary_learning_sae_huggingface_loader_1(
     if "r_mag" in encoder:
         state_dict["r_mag"] = encoder["r_mag"]

+    if "threshold" in encoder:
+        threshold = encoder["threshold"]
+        if threshold.ndim == 0:
+            threshold = torch.full((W_enc.size(1),), threshold)
+        state_dict["threshold"] = threshold
+
     return cfg_dict, state_dict, None


@@ -1011,26 +1229,17 @@ def llama_scope_r1_distill_sae_huggingface_loader(
         force_download=force_download,
     )

-
-
+    state_dict_loaded = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     # Convert and organize the weights
     state_dict = {
-        "W_enc": state_dict_loaded["encoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "W_dec": state_dict_loaded["decoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "b_enc": state_dict_loaded["encoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "b_dec": state_dict_loaded["decoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "threshold": state_dict_loaded["log_jumprelu_threshold"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .exp(),
+        "W_enc": state_dict_loaded["encoder.weight"].T,
+        "W_dec": state_dict_loaded["decoder.weight"].T,
+        "b_enc": state_dict_loaded["encoder.bias"],
+        "b_dec": state_dict_loaded["decoder.bias"],
+        "threshold": state_dict_loaded["log_jumprelu_threshold"].exp(),
     }

     # No sparsity tensor for Llama Scope SAEs
@@ -1150,34 +1359,34 @@ def sparsify_disk_loader(
     cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)

     weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
-    state_dict_loaded =
-
-
+    state_dict_loaded = load_safetensors_weights(
+        weight_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     W_enc = (
         state_dict_loaded["W_enc"]
         if "W_enc" in state_dict_loaded
         else state_dict_loaded["encoder.weight"].T
-    )
+    )

     if "W_dec" in state_dict_loaded:
-        W_dec = state_dict_loaded["W_dec"].T
+        W_dec = state_dict_loaded["W_dec"].T
     else:
-        W_dec = state_dict_loaded["decoder.weight"].T
+        W_dec = state_dict_loaded["decoder.weight"].T

     if "b_enc" in state_dict_loaded:
-        b_enc = state_dict_loaded["b_enc"]
+        b_enc = state_dict_loaded["b_enc"]
     elif "encoder.bias" in state_dict_loaded:
-        b_enc = state_dict_loaded["encoder.bias"]
+        b_enc = state_dict_loaded["encoder.bias"]
     else:
-        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=W_dec.dtype, device=device)

     if "b_dec" in state_dict_loaded:
-        b_dec = state_dict_loaded["b_dec"]
+        b_dec = state_dict_loaded["b_dec"]
     elif "decoder.bias" in state_dict_loaded:
-        b_dec = state_dict_loaded["decoder.bias"]
+        b_dec = state_dict_loaded["decoder.bias"]
     else:
-        b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+        b_dec = torch.zeros(cfg_dict["d_in"], dtype=W_dec.dtype, device=device)

     state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
     return cfg_dict, state_dict
@@ -1408,44 +1617,44 @@ def mwhanna_transcoder_huggingface_loader(
     )

     # Load weights from safetensors
-    state_dict =
+    state_dict = load_safetensors_weights(
+        file_path, device=device, dtype=cfg_dict.get("dtype")
+    )
     state_dict["W_enc"] = state_dict["W_enc"].T

     return cfg_dict, state_dict, None


-def get_safetensors_tensor_shapes(
+def get_safetensors_tensor_shapes(repo_id: str, filename: str) -> dict[str, list[int]]:
     """
-    Get tensor shapes from a safetensors file
+    Get tensor shapes from a safetensors file on HuggingFace Hub
     without downloading the entire file.

+    Uses HTTP range requests to fetch only the metadata header.
+
     Args:
-
+        repo_id: HuggingFace repo ID (e.g., "gg-gs/gemma-scope-2-1b-pt")
+        filename: Path to the safetensors file within the repo

     Returns:
         Dictionary mapping tensor names to their shapes
     """
-
-    response = requests.head(url, timeout=10)
-    response.raise_for_status()
+    url = hf_hub_url(repo_id, filename)

-
-
-        raise ValueError("Server does not support range requests")
+    # Get HuggingFace headers (includes auth token if available)
+    hf_headers = build_hf_headers()

     # Fetch first 8 bytes to get metadata size
-    headers = {"Range": "bytes=0-7"}
+    headers = {**hf_headers, "Range": "bytes=0-7"}
     response = requests.get(url, headers=headers, timeout=10)
-
-        raise ValueError("Failed to fetch initial bytes for metadata size")
+    response.raise_for_status()

     meta_size = int.from_bytes(response.content, byteorder="little")

     # Fetch the metadata header
-    headers = {"Range": f"bytes=8-{8 + meta_size - 1}"}
+    headers = {**hf_headers, "Range": f"bytes=8-{8 + meta_size - 1}"}
     response = requests.get(url, headers=headers, timeout=10)
-
-        raise ValueError("Failed to fetch metadata header")
+    response.raise_for_status()

     metadata_json = response.content.decode("utf-8").strip()
     metadata = json.loads(metadata_json)
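get_safetensors_tensor_shapes now sends authenticated HuggingFace headers while still using range requests so only the header is downloaded. A standalone sketch of the same trick outside sae-lens: a safetensors file starts with an 8-byte little-endian header length followed by a JSON header listing each tensor's dtype, shape, and byte offsets (the gpt2 file below is just an example of a publicly hosted safetensors file):

```python
import json

import requests
from huggingface_hub import hf_hub_url
from huggingface_hub.utils import build_hf_headers

url = hf_hub_url("openai-community/gpt2", "model.safetensors")
hf_headers = build_hf_headers()

# First 8 bytes: little-endian length of the JSON header.
head = requests.get(url, headers={**hf_headers, "Range": "bytes=0-7"}, timeout=10)
head.raise_for_status()
meta_size = int.from_bytes(head.content, byteorder="little")

# The JSON header itself: tensor name -> {dtype, shape, data_offsets}.
meta = requests.get(
    url, headers={**hf_headers, "Range": f"bytes=8-{8 + meta_size - 1}"}, timeout=10
)
meta.raise_for_status()
header = json.loads(meta.content)
shapes = {k: v["shape"] for k, v in header.items() if k != "__metadata__"}
print(list(shapes.items())[:3])
```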
@@ -1494,8 +1703,12 @@ def mntss_clt_layer_huggingface_loader(
         force_download=force_download,
     )

-    encoder_state_dict =
-
+    encoder_state_dict = load_safetensors_weights(
+        encoder_path, device=device, dtype=cfg_dict.get("dtype")
+    )
+    decoder_state_dict = load_safetensors_weights(
+        decoder_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     with torch.no_grad():
         state_dict = {
@@ -1525,9 +1738,10 @@ def get_mntss_clt_layer_config_from_hf(
     with open(base_config_path) as f:
         cfg_info: dict[str, Any] = yaml.safe_load(f)

-    # Get tensor shapes without downloading full files
-
-
+    # Get tensor shapes without downloading full files
+    encoder_shapes = get_safetensors_tensor_shapes(
+        repo_id, f"W_enc_{folder_name}.safetensors"
+    )

     # Extract shapes for the required tensors
     b_dec_shape = encoder_shapes[f"b_dec_{folder_name}"]
@@ -1637,7 +1851,9 @@ def temporal_sae_huggingface_loader(
     )

     # Load checkpoint from safetensors
-    state_dict_raw =
+    state_dict_raw = load_safetensors_weights(
+        ckpt_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     # Convert to SAELens naming convention
     # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
@@ -1663,6 +1879,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
     "gemma_2": gemma_2_sae_huggingface_loader,
+    "gemma_3": gemma_3_sae_huggingface_loader,
     "llama_scope": llama_scope_sae_huggingface_loader,
     "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
@@ -1680,6 +1897,7 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "sae_lens": get_sae_lens_config_from_hf,
     "connor_rob_hook_z": get_connor_rob_hook_z_config_from_hf,
     "gemma_2": get_gemma_2_config_from_hf,
+    "gemma_3": get_gemma_3_config_from_hf,
     "llama_scope": get_llama_scope_config_from_hf,
     "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
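The gemma_3 conversion is wired into both lookup tables above, so downstream code can dispatch on the name like any other registered loader:

```python
from sae_lens.loading.pretrained_sae_loaders import (
    NAMED_PRETRAINED_SAE_CONFIG_GETTERS,
    NAMED_PRETRAINED_SAE_LOADERS,
)

# The new "gemma_3" entry resolves like any other named loader.
loader = NAMED_PRETRAINED_SAE_LOADERS["gemma_3"]
config_getter = NAMED_PRETRAINED_SAE_CONFIG_GETTERS["gemma_3"]
print(loader.__name__, config_getter.__name__)
# gemma_3_sae_huggingface_loader get_gemma_3_config_from_hf
```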
sae_lens/pretokenize_runner.py
CHANGED
@@ -186,13 +186,13 @@ class PretokenizeRunner:
         """
         Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
         """
-        dataset = load_dataset(
+        dataset = load_dataset(  # type: ignore
             self.cfg.dataset_path,
             name=self.cfg.dataset_name,
             data_dir=self.cfg.data_dir,
             data_files=self.cfg.data_files,
-            split=self.cfg.split,
-            streaming=self.cfg.streaming,
+            split=self.cfg.split,  # type: ignore
+            streaming=self.cfg.streaming,  # type: ignore
         )
         if isinstance(dataset, DatasetDict):
             raise ValueError(