sae-lens 6.15.0__py3-none-any.whl → 6.24.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +13 -1
- sae_lens/analysis/hooked_sae_transformer.py +4 -13
- sae_lens/cache_activations_runner.py +3 -4
- sae_lens/config.py +39 -2
- sae_lens/constants.py +1 -0
- sae_lens/llm_sae_training_runner.py +9 -4
- sae_lens/loading/pretrained_sae_loaders.py +430 -24
- sae_lens/loading/pretrained_saes_directory.py +5 -3
- sae_lens/pretokenize_runner.py +3 -3
- sae_lens/pretrained_saes.yaml +26977 -65
- sae_lens/saes/__init__.py +7 -0
- sae_lens/saes/batchtopk_sae.py +3 -1
- sae_lens/saes/gated_sae.py +6 -11
- sae_lens/saes/jumprelu_sae.py +8 -13
- sae_lens/saes/matryoshka_batchtopk_sae.py +8 -15
- sae_lens/saes/sae.py +20 -32
- sae_lens/saes/standard_sae.py +4 -9
- sae_lens/saes/temporal_sae.py +365 -0
- sae_lens/saes/topk_sae.py +8 -11
- sae_lens/saes/transcoder.py +41 -0
- sae_lens/training/activation_scaler.py +7 -0
- sae_lens/training/activations_store.py +54 -12
- sae_lens/training/optim.py +11 -0
- sae_lens/training/sae_trainer.py +50 -11
- {sae_lens-6.15.0.dist-info → sae_lens-6.24.1.dist-info}/METADATA +16 -16
- sae_lens-6.24.1.dist-info/RECORD +41 -0
- sae_lens-6.15.0.dist-info/RECORD +0 -40
- {sae_lens-6.15.0.dist-info → sae_lens-6.24.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.15.0.dist-info → sae_lens-6.24.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/loading/pretrained_sae_loaders.py CHANGED

@@ -9,7 +9,7 @@ import requests
 import torch
 import yaml
 from huggingface_hub import hf_hub_download, hf_hub_url
-from huggingface_hub.utils import EntryNotFoundError
+from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
 from packaging.version import Version
 from safetensors import safe_open
 from safetensors.torch import load_file
@@ -46,6 +46,8 @@ LLM_METADATA_KEYS = {
     "sae_lens_training_version",
     "hook_name_out",
     "hook_head_index_out",
+    "hf_hook_name",
+    "hf_hook_name_out",
 }
 
 
@@ -523,6 +525,282 @@ def gemma_2_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity
 
 
+def _infer_gemma_3_raw_cfg_dict(repo_id: str, folder_name: str) -> dict[str, Any]:
+    """
+    Infer the raw config dict for Gemma 3 SAEs from the repo_id and folder_name.
+    This is used when config.json doesn't exist in the repo.
+    """
+    # Extract layer number from folder name
+    layer_match = re.search(r"layer_(\d+)", folder_name)
+    if layer_match is None:
+        raise ValueError(
+            f"Could not extract layer number from folder_name: {folder_name}"
+        )
+    layer = int(layer_match.group(1))
+
+    # Convert repo_id to model_name: google/gemma-scope-2-{size}-{suffix} -> google/gemma-3-{size}-{suffix}
+    model_name = repo_id.replace("gemma-scope-2", "gemma-3")
+
+    # Determine hook type and HF hook points based on folder_name
+    if "transcoder" in folder_name or "clt" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.pre_feedforward_layernorm.output"
+        hf_hook_point_out = f"model.layers.{layer}.post_feedforward_layernorm.output"
+    elif "resid_post" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.output"
+        hf_hook_point_out = None
+    elif "attn_out" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.self_attn.o_proj.input"
+        hf_hook_point_out = None
+    elif "mlp_out" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.post_feedforward_layernorm.output"
+        hf_hook_point_out = None
+    else:
+        raise ValueError(f"Could not infer hook type from folder_name: {folder_name}")
+
+    cfg: dict[str, Any] = {
+        "architecture": "jump_relu",
+        "model_name": model_name,
+        "hf_hook_point_in": hf_hook_point_in,
+    }
+    if hf_hook_point_out is not None:
+        cfg["hf_hook_point_out"] = hf_hook_point_out
+
+    return cfg
+
+
+def get_gemma_3_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    # Try to load config.json from the repo, fall back to inferring if it doesn't exist
+    try:
+        config_path = hf_hub_download(
+            repo_id, f"{folder_name}/config.json", force_download=force_download
+        )
+        with open(config_path) as config_file:
+            raw_cfg_dict = json.load(config_file)
+    except EntryNotFoundError:
+        raw_cfg_dict = _infer_gemma_3_raw_cfg_dict(repo_id, folder_name)
+
+    if raw_cfg_dict.get("architecture") != "jump_relu":
+        raise ValueError(
+            f"Unexpected architecture in Gemma 3 config: {raw_cfg_dict.get('architecture')}"
+        )
+
+    layer_match = re.search(r"layer_(\d+)", folder_name)
+    if layer_match is None:
+        raise ValueError(
+            f"Could not extract layer number from folder_name: {folder_name}"
+        )
+    layer = int(layer_match.group(1))
+    hook_name_out = None
+    d_out = None
+    if "resid_post" in folder_name:
+        hook_name = f"blocks.{layer}.hook_resid_post"
+    elif "attn_out" in folder_name:
+        hook_name = f"blocks.{layer}.hook_attn_out"
+    elif "mlp_out" in folder_name:
+        hook_name = f"blocks.{layer}.hook_mlp_out"
+    elif "transcoder" in folder_name or "clt" in folder_name:
+        hook_name = f"blocks.{layer}.ln2.hook_normalized"
+        hook_name_out = f"blocks.{layer}.hook_mlp_out"
+    else:
+        raise ValueError("Hook name not found in folder_name.")
+
+    # hackily deal with clt file names
+    params_file_part = "/params.safetensors"
+    if "clt" in folder_name:
+        params_file_part = ".safetensors"
+
+    shapes_dict = get_safetensors_tensor_shapes(
+        repo_id, f"{folder_name}{params_file_part}"
+    )
+    d_in, d_sae = shapes_dict["w_enc"]
+    # TODO: update this for real model info
+    model_name = raw_cfg_dict["model_name"]
+    if "google" not in model_name:
+        model_name = "google/" + model_name
+    model_name = model_name.replace("-v3", "-3")
+    if "270m" in model_name:
+        # for some reason the 270m model on huggingface doesn't have the -pt suffix
+        model_name = model_name.replace("-pt", "")
+
+    architecture = "jumprelu"
+    if "transcoder" in folder_name or "clt" in folder_name:
+        architecture = "jumprelu_skip_transcoder"
+        d_out = shapes_dict["w_dec"][-1]
+
+    cfg = {
+        "architecture": architecture,
+        "d_in": d_in,
+        "d_sae": d_sae,
+        "dtype": "float32",
+        "model_name": model_name,
+        "hook_name": hook_name,
+        "hook_head_index": None,
+        "finetuning_scaling_factor": False,
+        "sae_lens_training_version": None,
+        "prepend_bos": True,
+        "dataset_path": "monology/pile-uncopyrighted",
+        "context_size": 1024,
+        "apply_b_dec_to_input": False,
+        "normalize_activations": None,
+        "hf_hook_name": raw_cfg_dict.get("hf_hook_point_in"),
+    }
+    if hook_name_out is not None:
+        cfg["hook_name_out"] = hook_name_out
+        cfg["hf_hook_name_out"] = raw_cfg_dict.get("hf_hook_point_out")
+    if d_out is not None:
+        cfg["d_out"] = d_out
+    if device is not None:
+        cfg["device"] = device
+
+    if cfg_overrides is not None:
+        cfg.update(cfg_overrides)
+
+    return cfg
+
+
+def gemma_3_sae_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """
+    Custom loader for Gemma 3 SAEs.
+    """
+    cfg_dict = get_gemma_3_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
+    )
+
+    # replace folder name of 65k with 64k
+    # TODO: remove this workaround once weights are fixed
+    if "270m-pt" in repo_id:
+        if "65k" in folder_name:
+            folder_name = folder_name.replace("65k", "64k")
+        # replace folder name of 262k with 250k
+        if "262k" in folder_name:
+            folder_name = folder_name.replace("262k", "250k")
+
+    params_file = "params.safetensors"
+    if "clt" in folder_name:
+        params_file = folder_name.split("/")[-1] + ".safetensors"
+        folder_name = "/".join(folder_name.split("/")[:-1])
+
+    # Download the SAE weights
+    sae_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=params_file,
+        subfolder=folder_name,
+        force_download=force_download,
+    )
+
+    raw_state_dict = load_file(sae_path, device=device)
+
+    with torch.no_grad():
+        w_dec = raw_state_dict["w_dec"]
+        if "clt" in folder_name:
+            w_dec = w_dec.sum(dim=1).contiguous()
+
+        state_dict = {
+            "W_enc": raw_state_dict["w_enc"],
+            "W_dec": w_dec,
+            "b_enc": raw_state_dict["b_enc"],
+            "b_dec": raw_state_dict["b_dec"],
+            "threshold": raw_state_dict["threshold"],
+        }
+
+        if "affine_skip_connection" in raw_state_dict:
+            state_dict["W_skip"] = raw_state_dict["affine_skip_connection"]
+
+    return cfg_dict, state_dict, None
+
+
+def get_goodfire_config_from_hf(
+    repo_id: str,
+    folder_name: str,  # noqa: ARG001
+    device: str,
+    force_download: bool = False,  # noqa: ARG001
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    cfg_dict = None
+    if repo_id == "Goodfire/Llama-3.3-70B-Instruct-SAE-l50":
+        if folder_name != "Llama-3.3-70B-Instruct-SAE-l50.pt":
+            raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+        cfg_dict = {
+            "architecture": "standard",
+            "d_in": 8192,
+            "d_sae": 65536,
+            "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+            "hook_name": "blocks.50.hook_resid_post",
+            "hook_head_index": None,
+            "dataset_path": "lmsys/lmsys-chat-1m",
+            "apply_b_dec_to_input": False,
+        }
+    elif repo_id == "Goodfire/Llama-3.1-8B-Instruct-SAE-l19":
+        if folder_name != "Llama-3.1-8B-Instruct-SAE-l19.pth":
+            raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+        cfg_dict = {
+            "architecture": "standard",
+            "d_in": 4096,
+            "d_sae": 65536,
+            "model_name": "meta-llama/Llama-3.1-8B-Instruct",
+            "hook_name": "blocks.19.hook_resid_post",
+            "hook_head_index": None,
+            "dataset_path": "lmsys/lmsys-chat-1m",
+            "apply_b_dec_to_input": False,
+        }
+    if cfg_dict is None:
+        raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+    if device is not None:
+        cfg_dict["device"] = device
+    if cfg_overrides is not None:
+        cfg_dict.update(cfg_overrides)
+    return cfg_dict
+
+
+def get_goodfire_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    cfg_dict = get_goodfire_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
+    )
+
+    # Download the SAE weights
+    sae_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=folder_name,
+        force_download=force_download,
+    )
+    raw_state_dict = torch.load(sae_path, map_location=device)
+
+    state_dict = {
+        "W_enc": raw_state_dict["encoder_linear.weight"].T,
+        "W_dec": raw_state_dict["decoder_linear.weight"].T,
+        "b_enc": raw_state_dict["encoder_linear.bias"],
+        "b_dec": raw_state_dict["decoder_linear.bias"],
+    }
+
+    return cfg_dict, state_dict, None
+
+
 def get_llama_scope_config_from_hf(
     repo_id: str,
     folder_name: str,
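Aside, not part of the package diff: the transposes in the new Goodfire loader follow from `torch.nn.Linear` storing its weight as `[out_features, in_features]`, while SAELens keeps `W_enc` as `[d_in, d_sae]` and `W_dec` as `[d_sae, d_in]`. A minimal sketch of that layout conversion with toy dimensions (the variable names here are illustrative):

```python
import torch
import torch.nn as nn

d_in, d_sae = 16, 64  # toy sizes for illustration

# An SAE stored as two nn.Linear layers, as in the Goodfire checkpoints:
# the encoder weight is [d_sae, d_in], the decoder weight is [d_in, d_sae].
encoder_linear = nn.Linear(d_in, d_sae)
decoder_linear = nn.Linear(d_sae, d_in)

# SAELens-style parameters: transpose each Linear weight on load.
W_enc = encoder_linear.weight.data.T  # [d_in, d_sae]
W_dec = decoder_linear.weight.data.T  # [d_sae, d_in]
b_enc = encoder_linear.bias.data
b_dec = decoder_linear.bias.data

# Both parameterizations encode identically.
x = torch.randn(3, d_in)
assert torch.allclose(encoder_linear(x), x @ W_enc + b_enc, atol=1e-6)
```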
@@ -677,10 +955,14 @@ def get_dictionary_learning_config_1_from_hf(
     activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
     activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}
 
+    architecture = "standard"
+    if trainer["dict_class"] == "GatedAutoEncoder":
+        architecture = "gated"
+    elif trainer["dict_class"] == "MatryoshkaBatchTopKSAE":
+        architecture = "jumprelu"
+
     return {
-        "architecture": (
-            "gated" if trainer["dict_class"] == "GatedAutoEncoder" else "standard"
-        ),
+        "architecture": architecture,
         "d_in": trainer["activation_dim"],
         "d_sae": trainer["dict_size"],
         "dtype": "float32",
@@ -829,9 +1111,12 @@ def dictionary_learning_sae_huggingface_loader_1(
     )
     encoder = torch.load(encoder_path, map_location="cpu")
 
+    W_enc = encoder["W_enc"] if "W_enc" in encoder else encoder["encoder.weight"].T
+    W_dec = encoder["W_dec"] if "W_dec" in encoder else encoder["decoder.weight"].T
+
     state_dict = {
-        "W_enc": encoder["encoder.weight"].T,
-        "W_dec": encoder["decoder.weight"].T,
+        "W_enc": W_enc,
+        "W_dec": W_dec,
         "b_dec": encoder.get(
             "b_dec", encoder.get("bias", encoder.get("decoder_bias", None))
         ),
@@ -839,6 +1124,8 @@ def dictionary_learning_sae_huggingface_loader_1(
 
     if "encoder.bias" in encoder:
         state_dict["b_enc"] = encoder["encoder.bias"]
+    if "b_enc" in encoder:
+        state_dict["b_enc"] = encoder["b_enc"]
 
     if "mag_bias" in encoder:
         state_dict["b_mag"] = encoder["mag_bias"]
@@ -847,6 +1134,12 @@ def dictionary_learning_sae_huggingface_loader_1(
     if "r_mag" in encoder:
         state_dict["r_mag"] = encoder["r_mag"]
 
+    if "threshold" in encoder:
+        threshold = encoder["threshold"]
+        if threshold.ndim == 0:
+            threshold = torch.full((W_enc.size(1),), threshold)
+        state_dict["threshold"] = threshold
+
     return cfg_dict, state_dict, None
 
 
@@ -1338,38 +1631,36 @@ def mwhanna_transcoder_huggingface_loader(
     return cfg_dict, state_dict, None
 
 
-def get_safetensors_tensor_shapes(url: str) -> dict[str, list[int]]:
+def get_safetensors_tensor_shapes(repo_id: str, filename: str) -> dict[str, list[int]]:
     """
-    Get tensor shapes from a safetensors file
+    Get tensor shapes from a safetensors file on HuggingFace Hub
     without downloading the entire file.
 
+    Uses HTTP range requests to fetch only the metadata header.
+
     Args:
-        url: URL of the safetensors file
+        repo_id: HuggingFace repo ID (e.g., "gg-gs/gemma-scope-2-1b-pt")
+        filename: Path to the safetensors file within the repo
 
     Returns:
         Dictionary mapping tensor names to their shapes
     """
-    response = requests.head(url, timeout=10)
-    response.raise_for_status()
+    url = hf_hub_url(repo_id, filename)
 
-    if response.headers.get("Accept-Ranges") != "bytes":
-        raise ValueError("Server does not support range requests")
+    # Get HuggingFace headers (includes auth token if available)
+    hf_headers = build_hf_headers()
 
     # Fetch first 8 bytes to get metadata size
-    headers = {"Range": "bytes=0-7"}
+    headers = {**hf_headers, "Range": "bytes=0-7"}
     response = requests.get(url, headers=headers, timeout=10)
-    if response.status_code != 206:
-        raise ValueError("Failed to fetch initial bytes for metadata size")
+    response.raise_for_status()
 
     meta_size = int.from_bytes(response.content, byteorder="little")
 
     # Fetch the metadata header
-    headers = {"Range": f"bytes=8-{8 + meta_size - 1}"}
+    headers = {**hf_headers, "Range": f"bytes=8-{8 + meta_size - 1}"}
     response = requests.get(url, headers=headers, timeout=10)
-    if response.status_code != 206:
-        raise ValueError("Failed to fetch metadata header")
+    response.raise_for_status()
 
     metadata_json = response.content.decode("utf-8").strip()
     metadata = json.loads(metadata_json)
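Aside, not part of the package diff: a safetensors file begins with an unsigned 64-bit little-endian header length, followed by a JSON header mapping tensor names to their dtype, shape, and data offsets; the two range requests above read exactly those bytes. A standalone sketch of the same header-only fetch against any URL whose server honors `Range` headers (the function name and token handling are illustrative, not the package's API):

```python
import json

import requests


def read_safetensors_header(url: str, token: str | None = None) -> dict:
    """Fetch only the JSON header of a remote .safetensors file via HTTP range requests."""
    auth = {"Authorization": f"Bearer {token}"} if token else {}

    # First 8 bytes: little-endian u64 giving the JSON header size.
    resp = requests.get(url, headers={**auth, "Range": "bytes=0-7"}, timeout=10)
    resp.raise_for_status()
    header_len = int.from_bytes(resp.content, byteorder="little")

    # Next header_len bytes: JSON with dtype/shape/data_offsets per tensor
    # (plus an optional "__metadata__" entry).
    resp = requests.get(
        url, headers={**auth, "Range": f"bytes=8-{8 + header_len - 1}"}, timeout=10
    )
    resp.raise_for_status()
    return json.loads(resp.content.decode("utf-8"))


# Example: tensor shapes only, skipping the optional metadata entry.
# shapes = {k: v["shape"] for k, v in read_safetensors_header(url).items() if k != "__metadata__"}
```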
@@ -1449,9 +1740,10 @@ def get_mntss_clt_layer_config_from_hf(
     with open(base_config_path) as f:
         cfg_info: dict[str, Any] = yaml.safe_load(f)
 
-    # Get tensor shapes without downloading full files
-    encoder_url = hf_hub_url(repo_id, f"W_enc_{folder_name}.safetensors")
-    encoder_shapes = get_safetensors_tensor_shapes(encoder_url)
+    # Get tensor shapes without downloading full files
+    encoder_shapes = get_safetensors_tensor_shapes(
+        repo_id, f"W_enc_{folder_name}.safetensors"
+    )
 
     # Extract shapes for the required tensors
     b_dec_shape = encoder_shapes[f"b_dec_{folder_name}"]
@@ -1475,10 +1767,119 @@ def get_mntss_clt_layer_config_from_hf(
     }
 
 
+def get_temporal_sae_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Get TemporalSAE config without loading weights."""
+    # Download config file
+    conf_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=f"{folder_name}/conf.yaml",
+        force_download=force_download,
+    )
+
+    # Load and parse config
+    with open(conf_path) as f:
+        yaml_config = yaml.safe_load(f)
+
+    # Extract parameters
+    d_in = yaml_config["llm"]["dimin"]
+    exp_factor = yaml_config["sae"]["exp_factor"]
+    d_sae = int(d_in * exp_factor)
+
+    # extract layer from folder_name eg : "layer_12/temporal"
+    layer = re.search(r"layer_(\d+)", folder_name)
+    if layer is None:
+        raise ValueError(f"Could not find layer in folder_name: {folder_name}")
+    layer = int(layer.group(1))
+
+    # Build config dict
+    cfg_dict = {
+        "architecture": "temporal",
+        "hook_name": f"blocks.{layer}.hook_resid_post",
+        "d_in": d_in,
+        "d_sae": d_sae,
+        "n_heads": yaml_config["sae"]["n_heads"],
+        "n_attn_layers": yaml_config["sae"]["n_attn_layers"],
+        "bottleneck_factor": yaml_config["sae"]["bottleneck_factor"],
+        "sae_diff_type": yaml_config["sae"]["sae_diff_type"],
+        "kval_topk": yaml_config["sae"]["kval_topk"],
+        "tied_weights": yaml_config["sae"]["tied_weights"],
+        "dtype": yaml_config["data"]["dtype"],
+        "device": device,
+        "normalize_activations": "constant_scalar_rescale",
+        "activation_normalization_factor": yaml_config["sae"]["scaling_factor"],
+        "apply_b_dec_to_input": True,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def temporal_sae_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """
+    Load TemporalSAE from canrager/temporalSAEs format (safetensors version).
+
+    Expects folder_name to contain:
+    - conf.yaml (configuration)
+    - latest_ckpt.safetensors (model weights)
+    """
+
+    cfg_dict = get_temporal_sae_config_from_hf(
+        repo_id=repo_id,
+        folder_name=folder_name,
+        device=device,
+        force_download=force_download,
+        cfg_overrides=cfg_overrides,
+    )
+
+    # Download checkpoint (safetensors format)
+    ckpt_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=f"{folder_name}/latest_ckpt.safetensors",
+        force_download=force_download,
+    )
+
+    # Load checkpoint from safetensors
+    state_dict_raw = load_file(ckpt_path, device=device)
+
+    # Convert to SAELens naming convention
+    # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
+    state_dict = {}
+
+    # Copy attention layers as-is
+    for key, value in state_dict_raw.items():
+        if key.startswith("attn_layers."):
+            state_dict[key] = value.to(device)
+
+    # Main parameters
+    state_dict["W_dec"] = state_dict_raw["D"].to(device)
+    state_dict["b_dec"] = state_dict_raw["b"].to(device)
+
+    # Handle tied/untied weights
+    if "E" in state_dict_raw:
+        state_dict["W_enc"] = state_dict_raw["E"].to(device)
+
+    return cfg_dict, state_dict, None
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
     "gemma_2": gemma_2_sae_huggingface_loader,
+    "gemma_3": gemma_3_sae_huggingface_loader,
     "llama_scope": llama_scope_sae_huggingface_loader,
     "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
@@ -1487,6 +1888,8 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
     "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
     "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
+    "temporal": temporal_sae_huggingface_loader,
+    "goodfire": get_goodfire_huggingface_loader,
 }
 
 
@@ -1494,6 +1897,7 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoader] = {
     "sae_lens": get_sae_lens_config_from_hf,
     "connor_rob_hook_z": get_connor_rob_hook_z_config_from_hf,
     "gemma_2": get_gemma_2_config_from_hf,
+    "gemma_3": get_gemma_3_config_from_hf,
     "llama_scope": get_llama_scope_config_from_hf,
     "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
@@ -1502,4 +1906,6 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoader] = {
     "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
     "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
     "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
+    "temporal": get_temporal_sae_config_from_hf,
+    "goodfire": get_goodfire_config_from_hf,
 }
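Aside, not part of the package diff: every entry registered above follows the loader signature used throughout this file, `(repo_id, folder_name, device, force_download, cfg_overrides) -> (cfg_dict, state_dict, log_sparsity)`. A sketch of calling one registered loader directly, reusing the Goodfire repo and filename from the diff (this downloads the full checkpoint; the import path assumes the module layout shown in the file list):

```python
from sae_lens.loading.pretrained_sae_loaders import NAMED_PRETRAINED_SAE_LOADERS

# Look up a loader by its registry name and call it with the standard signature.
loader = NAMED_PRETRAINED_SAE_LOADERS["goodfire"]
cfg_dict, state_dict, log_sparsity = loader(
    repo_id="Goodfire/Llama-3.1-8B-Instruct-SAE-l19",
    folder_name="Llama-3.1-8B-Instruct-SAE-l19.pth",
    device="cpu",
)
print(cfg_dict["d_in"], cfg_dict["d_sae"], state_dict["W_enc"].shape)
```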
sae_lens/loading/pretrained_saes_directory.py CHANGED

@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from functools import cache
-from importlib import resources
+from importlib.resources import files
 from typing import Any
 
 import yaml
@@ -24,7 +24,8 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
     package = "sae_lens"
     # Access the file within the package using importlib.resources
     directory: dict[str, PretrainedSAELookup] = {}
-    with resources.open_text(package, "pretrained_saes.yaml") as file:
+    yaml_file = files(package).joinpath("pretrained_saes.yaml")
+    with yaml_file.open("r") as file:
         # Load the YAML file content
         data = yaml.safe_load(file)
         for release, value in data.items():
@@ -68,7 +69,8 @@ def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
         float | None: The norm_scaling_factor if it exists, None otherwise.
     """
     package = "sae_lens"
-    with resources.open_text(package, "pretrained_saes.yaml") as file:
+    yaml_file = files(package).joinpath("pretrained_saes.yaml")
+    with yaml_file.open("r") as file:
         data = yaml.safe_load(file)
         if release in data:
             for sae_info in data[release]["saes"]:
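Aside, not part of the package diff: the added lines read the packaged YAML through `importlib.resources.files`, which returns a traversable object that can be opened directly. A minimal standalone sketch of that pattern (assumes `sae_lens` is installed so the data file is importable):

```python
from importlib.resources import files

import yaml

# Open a data file shipped inside an installed package, as the hunks above do.
yaml_file = files("sae_lens").joinpath("pretrained_saes.yaml")
with yaml_file.open("r") as f:
    data = yaml.safe_load(f)
print(len(data), "releases listed")
```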
sae_lens/pretokenize_runner.py CHANGED

@@ -186,13 +186,13 @@ class PretokenizeRunner:
         """
         Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
         """
-        dataset = load_dataset(
+        dataset = load_dataset(  # type: ignore
             self.cfg.dataset_path,
             name=self.cfg.dataset_name,
             data_dir=self.cfg.data_dir,
             data_files=self.cfg.data_files,
-            split=self.cfg.split,
-            streaming=self.cfg.streaming,
+            split=self.cfg.split,  # type: ignore
+            streaming=self.cfg.streaming,  # type: ignore
         )
         if isinstance(dataset, DatasetDict):
             raise ValueError(