sae-lens 6.22.3.tar.gz → 6.24.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.22.3 → sae_lens-6.24.0}/PKG-INFO +2 -2
- {sae_lens-6.22.3 → sae_lens-6.24.0}/README.md +1 -1
- {sae_lens-6.22.3 → sae_lens-6.24.0}/pyproject.toml +2 -1
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/__init__.py +8 -1
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/loading/pretrained_sae_loaders.py +242 -24
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/pretokenize_runner.py +3 -3
- sae_lens-6.24.0/sae_lens/pretrained_saes.yaml +41797 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/__init__.py +4 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/transcoder.py +41 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/activations_store.py +1 -1
- sae_lens-6.22.3/sae_lens/pretrained_saes.yaml +0 -14961
- {sae_lens-6.22.3 → sae_lens-6.24.0}/LICENSE +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/config.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/temporal_sae.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/util.py +0 -0

{sae_lens-6.22.3 → sae_lens-6.24.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.22.3
+Version: 6.24.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -62,7 +62,7 @@ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [
 
 ## Loading Pre-trained SAEs.
 
-Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/
+Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
 
 ## Migrating to SAELens v6
 

{sae_lens-6.22.3 → sae_lens-6.24.0}/README.md

@@ -26,7 +26,7 @@ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [
 
 ## Loading Pre-trained SAEs.
 
-Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/
+Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
 
 ## Migrating to SAELens v6
 

{sae_lens-6.22.3 → sae_lens-6.24.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.22.3"
+version = "6.24.0"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
@@ -47,6 +47,7 @@ docstr-coverage = "^2.3.2"
 mkdocs = "^1.6.1"
 mkdocs-material = "^9.5.34"
 mkdocs-autorefs = "^1.4.2"
+mkdocs-redirects = "^1.2.1"
 mkdocs-section-index = "^0.3.9"
 mkdocstrings = "^0.25.2"
 mkdocstrings-python = "^1.10.9"

{sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/__init__.py

@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.22.3"
+__version__ = "6.24.0"
 
 import logging
 
@@ -15,6 +15,8 @@ from sae_lens.saes import (
     GatedTrainingSAEConfig,
     JumpReLUSAE,
     JumpReLUSAEConfig,
+    JumpReLUSkipTranscoder,
+    JumpReLUSkipTranscoderConfig,
     JumpReLUTrainingSAE,
     JumpReLUTrainingSAEConfig,
     JumpReLUTranscoder,
@@ -105,6 +107,8 @@ __all__ = [
     "SkipTranscoderConfig",
     "JumpReLUTranscoder",
     "JumpReLUTranscoderConfig",
+    "JumpReLUSkipTranscoder",
+    "JumpReLUSkipTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
     "TemporalSAE",
@@ -131,4 +135,7 @@ register_sae_training_class(
 register_sae_class("transcoder", Transcoder, TranscoderConfig)
 register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
 register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+register_sae_class(
+    "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
+)
 register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
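The registration above wires the new `jumprelu_skip_transcoder` architecture string to the `JumpReLUSkipTranscoder` class and its config. A minimal sketch of the same registry call for a hypothetical custom architecture name (the `sae_lens.registry` import path is an assumption based on the `sae_lens/registry.py` module in the file list, and `"my_skip_transcoder"` is illustrative, not from this diff):

```python
# Sketch only: mirrors the register_sae_class call added in the hunk above.
# Import path and the custom architecture name are assumptions.
from sae_lens.registry import register_sae_class
from sae_lens.saes import JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig

# After this call, configs whose architecture field is "my_skip_transcoder"
# resolve to JumpReLUSkipTranscoder / JumpReLUSkipTranscoderConfig.
register_sae_class(
    "my_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
)
```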

{sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/loading/pretrained_sae_loaders.py

@@ -9,7 +9,7 @@ import requests
 import torch
 import yaml
 from huggingface_hub import hf_hub_download, hf_hub_url
-from huggingface_hub.utils import EntryNotFoundError
+from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
 from packaging.version import Version
 from safetensors import safe_open
 from safetensors.torch import load_file
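The only import change here pulls in `build_hf_headers`, which the rewritten shape helper further down uses so that its raw HTTP range requests carry the same authentication as `hf_hub_download`. A small illustrative sketch of that pattern (the printed key names are indicative, not guaranteed):

```python
# Sketch only: merge the standard huggingface_hub request headers (user agent,
# plus an Authorization header when a token is configured) into a manual
# byte-range request, as the updated get_safetensors_tensor_shapes does below.
from huggingface_hub.utils import build_hf_headers

headers = {**build_hf_headers(), "Range": "bytes=0-7"}
print(sorted(headers))  # e.g. includes "Range", "user-agent", and "authorization" when a token is set
```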
@@ -46,6 +46,8 @@ LLM_METADATA_KEYS = {
     "sae_lens_training_version",
     "hook_name_out",
     "hook_head_index_out",
+    "hf_hook_name",
+    "hf_hook_name_out",
 }
 
 
@@ -523,6 +525,206 @@ def gemma_2_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity
 
 
+def _infer_gemma_3_raw_cfg_dict(repo_id: str, folder_name: str) -> dict[str, Any]:
+    """
+    Infer the raw config dict for Gemma 3 SAEs from the repo_id and folder_name.
+    This is used when config.json doesn't exist in the repo.
+    """
+    # Extract layer number from folder name
+    layer_match = re.search(r"layer_(\d+)", folder_name)
+    if layer_match is None:
+        raise ValueError(
+            f"Could not extract layer number from folder_name: {folder_name}"
+        )
+    layer = int(layer_match.group(1))
+
+    # Convert repo_id to model_name: google/gemma-scope-2-{size}-{suffix} -> google/gemma-3-{size}-{suffix}
+    model_name = repo_id.replace("gemma-scope-2", "gemma-3")
+
+    # Determine hook type and HF hook points based on folder_name
+    if "transcoder" in folder_name or "clt" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.pre_feedforward_layernorm.output"
+        hf_hook_point_out = f"model.layers.{layer}.post_feedforward_layernorm.output"
+    elif "resid_post" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.output"
+        hf_hook_point_out = None
+    elif "attn_out" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.self_attn.o_proj.input"
+        hf_hook_point_out = None
+    elif "mlp_out" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.post_feedforward_layernorm.output"
+        hf_hook_point_out = None
+    else:
+        raise ValueError(f"Could not infer hook type from folder_name: {folder_name}")
+
+    cfg: dict[str, Any] = {
+        "architecture": "jump_relu",
+        "model_name": model_name,
+        "hf_hook_point_in": hf_hook_point_in,
+    }
+    if hf_hook_point_out is not None:
+        cfg["hf_hook_point_out"] = hf_hook_point_out
+
+    return cfg
+
+
+def get_gemma_3_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    # Try to load config.json from the repo, fall back to inferring if it doesn't exist
+    try:
+        config_path = hf_hub_download(
+            repo_id, f"{folder_name}/config.json", force_download=force_download
+        )
+        with open(config_path) as config_file:
+            raw_cfg_dict = json.load(config_file)
+    except EntryNotFoundError:
+        raw_cfg_dict = _infer_gemma_3_raw_cfg_dict(repo_id, folder_name)
+
+    if raw_cfg_dict.get("architecture") != "jump_relu":
+        raise ValueError(
+            f"Unexpected architecture in Gemma 3 config: {raw_cfg_dict.get('architecture')}"
+        )
+
+    layer_match = re.search(r"layer_(\d+)", folder_name)
+    if layer_match is None:
+        raise ValueError(
+            f"Could not extract layer number from folder_name: {folder_name}"
+        )
+    layer = int(layer_match.group(1))
+    hook_name_out = None
+    d_out = None
+    if "resid_post" in folder_name:
+        hook_name = f"blocks.{layer}.hook_resid_post"
+    elif "attn_out" in folder_name:
+        hook_name = f"blocks.{layer}.hook_attn_out"
+    elif "mlp_out" in folder_name:
+        hook_name = f"blocks.{layer}.hook_mlp_out"
+    elif "transcoder" in folder_name or "clt" in folder_name:
+        hook_name = f"blocks.{layer}.ln2.hook_normalized"
+        hook_name_out = f"blocks.{layer}.hook_mlp_out"
+    else:
+        raise ValueError("Hook name not found in folder_name.")
+
+    # hackily deal with clt file names
+    params_file_part = "/params.safetensors"
+    if "clt" in folder_name:
+        params_file_part = ".safetensors"
+
+    shapes_dict = get_safetensors_tensor_shapes(
+        repo_id, f"{folder_name}{params_file_part}"
+    )
+    d_in, d_sae = shapes_dict["w_enc"]
+    # TODO: update this for real model info
+    model_name = raw_cfg_dict["model_name"]
+    if "google" not in model_name:
+        model_name = "google/" + model_name
+    model_name = model_name.replace("-v3", "-3")
+    if "270m" in model_name:
+        # for some reason the 270m model on huggingface doesn't have the -pt suffix
+        model_name = model_name.replace("-pt", "")
+
+    architecture = "jumprelu"
+    if "transcoder" in folder_name or "clt" in folder_name:
+        architecture = "jumprelu_skip_transcoder"
+        d_out = shapes_dict["w_dec"][-1]
+
+    cfg = {
+        "architecture": architecture,
+        "d_in": d_in,
+        "d_sae": d_sae,
+        "dtype": "float32",
+        "model_name": model_name,
+        "hook_name": hook_name,
+        "hook_head_index": None,
+        "finetuning_scaling_factor": False,
+        "sae_lens_training_version": None,
+        "prepend_bos": True,
+        "dataset_path": "monology/pile-uncopyrighted",
+        "context_size": 1024,
+        "apply_b_dec_to_input": False,
+        "normalize_activations": None,
+        "hf_hook_name": raw_cfg_dict.get("hf_hook_point_in"),
+    }
+    if hook_name_out is not None:
+        cfg["hook_name_out"] = hook_name_out
+        cfg["hf_hook_name_out"] = raw_cfg_dict.get("hf_hook_point_out")
+    if d_out is not None:
+        cfg["d_out"] = d_out
+    if device is not None:
+        cfg["device"] = device
+
+    if cfg_overrides is not None:
+        cfg.update(cfg_overrides)
+
+    return cfg
+
+
+def gemma_3_sae_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """
+    Custom loader for Gemma 3 SAEs.
+    """
+    cfg_dict = get_gemma_3_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
+    )
+
+    # replace folder name of 65k with 64k
+    # TODO: remove this workaround once weights are fixed
+    if "270m-pt" in repo_id:
+        if "65k" in folder_name:
+            folder_name = folder_name.replace("65k", "64k")
+        # replace folder name of 262k with 250k
+        if "262k" in folder_name:
+            folder_name = folder_name.replace("262k", "250k")
+
+    params_file = "params.safetensors"
+    if "clt" in folder_name:
+        params_file = folder_name.split("/")[-1] + ".safetensors"
+        folder_name = "/".join(folder_name.split("/")[:-1])
+
+    # Download the SAE weights
+    sae_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=params_file,
+        subfolder=folder_name,
+        force_download=force_download,
+    )
+
+    raw_state_dict = load_file(sae_path, device=device)
+
+    with torch.no_grad():
+        w_dec = raw_state_dict["w_dec"]
+        if "clt" in folder_name:
+            w_dec = w_dec.sum(dim=1).contiguous()
+
+        state_dict = {
+            "W_enc": raw_state_dict["w_enc"],
+            "W_dec": w_dec,
+            "b_enc": raw_state_dict["b_enc"],
+            "b_dec": raw_state_dict["b_dec"],
+            "threshold": raw_state_dict["threshold"],
+        }
+
+        if "affine_skip_connection" in raw_state_dict:
+            state_dict["W_skip"] = raw_state_dict["affine_skip_connection"]
+
+    return cfg_dict, state_dict, None
+
+
 def get_goodfire_config_from_hf(
     repo_id: str,
     folder_name: str,  # noqa: ARG001
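The new loader returns the same `(cfg_dict, state_dict, log_sparsity)` triple as the other converters, with `log_sparsity` always `None`. A hedged usage sketch of calling it directly (the repo id and folder name below are illustrative placeholders shaped like the patterns the loader parses, not values confirmed by this diff, and running it requires network access to the Hub):

```python
# Illustrative only: folder_name must contain "layer_<n>" plus one of
# resid_post / attn_out / mlp_out / transcoder / clt for the loader to parse it.
from sae_lens.loading.pretrained_sae_loaders import gemma_3_sae_huggingface_loader

cfg_dict, state_dict, _ = gemma_3_sae_huggingface_loader(
    repo_id="google/gemma-scope-2-1b-pt",        # hypothetical repo id
    folder_name="resid_post/layer_9_width_16k",  # hypothetical folder name
    device="cpu",
)
print(cfg_dict["architecture"])    # "jumprelu" for a resid_post SAE
print(state_dict["W_enc"].shape)   # (d_in, d_sae), inferred from the safetensors header
```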
@@ -753,10 +955,14 @@ def get_dictionary_learning_config_1_from_hf(
     activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
     activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}
 
+    architecture = "standard"
+    if trainer["dict_class"] == "GatedAutoEncoder":
+        architecture = "gated"
+    elif trainer["dict_class"] == "MatryoshkaBatchTopKSAE":
+        architecture = "jumprelu"
+
     return {
-        "architecture": (
-            "gated" if trainer["dict_class"] == "GatedAutoEncoder" else "standard"
-        ),
+        "architecture": architecture,
         "d_in": trainer["activation_dim"],
         "d_sae": trainer["dict_size"],
         "dtype": "float32",
@@ -905,9 +1111,12 @@ def dictionary_learning_sae_huggingface_loader_1(
     )
     encoder = torch.load(encoder_path, map_location="cpu")
 
+    W_enc = encoder["W_enc"] if "W_enc" in encoder else encoder["encoder.weight"].T
+    W_dec = encoder["W_dec"] if "W_dec" in encoder else encoder["decoder.weight"].T
+
     state_dict = {
-        "W_enc": encoder["encoder.weight"].T,
-        "W_dec": encoder["decoder.weight"].T,
+        "W_enc": W_enc,
+        "W_dec": W_dec,
         "b_dec": encoder.get(
             "b_dec", encoder.get("bias", encoder.get("decoder_bias", None))
         ),
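The fallback branches preserve the earlier behaviour of transposing `nn.Linear`-style weights into SAE Lens layout. A small sketch of that layout convention (sizes are made up; that the dictionary_learning checkpoints store `encoder.weight` in `(d_sae, d_in)` order is an assumption consistent with the `.T` above):

```python
# Illustrative sizes only. An nn.Linear(d_in, d_sae) stores .weight with shape
# (d_sae, d_in); transposing gives the (d_in, d_sae) W_enc layout SAE Lens uses.
import torch

d_in, d_sae = 768, 16_384
encoder_weight = torch.randn(d_sae, d_in)  # stand-in for encoder["encoder.weight"]
W_enc = encoder_weight.T                   # (d_in, d_sae)
assert W_enc.shape == (d_in, d_sae)
```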
@@ -915,6 +1124,8 @@ def dictionary_learning_sae_huggingface_loader_1(
 
     if "encoder.bias" in encoder:
         state_dict["b_enc"] = encoder["encoder.bias"]
+    if "b_enc" in encoder:
+        state_dict["b_enc"] = encoder["b_enc"]
 
     if "mag_bias" in encoder:
         state_dict["b_mag"] = encoder["mag_bias"]
@@ -923,6 +1134,12 @@ def dictionary_learning_sae_huggingface_loader_1(
     if "r_mag" in encoder:
         state_dict["r_mag"] = encoder["r_mag"]
 
+    if "threshold" in encoder:
+        threshold = encoder["threshold"]
+        if threshold.ndim == 0:
+            threshold = torch.full((W_enc.size(1),), threshold)
+        state_dict["threshold"] = threshold
+
     return cfg_dict, state_dict, None
 
 
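The new block accepts either a per-latent threshold vector or a single scalar threshold and expands the scalar to one value per latent. A minimal sketch of that expansion (sizes are illustrative; `.item()` is used here for clarity where the hunk passes the 0-dim tensor directly):

```python
# Illustrative only: expand a scalar JumpReLU threshold to one value per latent
# so it matches the (d_sae,) threshold vector expected downstream.
import torch

W_enc = torch.randn(768, 16_384)   # (d_in, d_sae), made-up sizes
threshold = torch.tensor(0.05)     # 0-dim scalar loaded from a checkpoint
if threshold.ndim == 0:
    threshold = torch.full((W_enc.size(1),), threshold.item())
assert threshold.shape == (W_enc.size(1),)
```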
@@ -1414,38 +1631,36 @@ def mwhanna_transcoder_huggingface_loader(
     return cfg_dict, state_dict, None
 
 
-def get_safetensors_tensor_shapes(
+def get_safetensors_tensor_shapes(repo_id: str, filename: str) -> dict[str, list[int]]:
     """
-    Get tensor shapes from a safetensors file
+    Get tensor shapes from a safetensors file on HuggingFace Hub
     without downloading the entire file.
 
+    Uses HTTP range requests to fetch only the metadata header.
+
     Args:
-
+        repo_id: HuggingFace repo ID (e.g., "gg-gs/gemma-scope-2-1b-pt")
+        filename: Path to the safetensors file within the repo
 
     Returns:
         Dictionary mapping tensor names to their shapes
     """
-
-    response = requests.head(url, timeout=10)
-    response.raise_for_status()
+    url = hf_hub_url(repo_id, filename)
 
-
-
-        raise ValueError("Server does not support range requests")
+    # Get HuggingFace headers (includes auth token if available)
+    hf_headers = build_hf_headers()
 
     # Fetch first 8 bytes to get metadata size
-    headers = {"Range": "bytes=0-7"}
+    headers = {**hf_headers, "Range": "bytes=0-7"}
     response = requests.get(url, headers=headers, timeout=10)
-
-        raise ValueError("Failed to fetch initial bytes for metadata size")
+    response.raise_for_status()
 
     meta_size = int.from_bytes(response.content, byteorder="little")
 
     # Fetch the metadata header
-    headers = {"Range": f"bytes=8-{8 + meta_size - 1}"}
+    headers = {**hf_headers, "Range": f"bytes=8-{8 + meta_size - 1}"}
     response = requests.get(url, headers=headers, timeout=10)
-
-        raise ValueError("Failed to fetch metadata header")
+    response.raise_for_status()
 
     metadata_json = response.content.decode("utf-8").strip()
     metadata = json.loads(metadata_json)
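The rewritten helper relies on the safetensors on-disk layout: the first 8 bytes are a little-endian unsigned 64-bit header length, followed by that many bytes of JSON mapping tensor names to their dtype, shape, and data offsets. A standalone sketch of parsing the same header from a local file (the function name and local-file setting are illustrative, not part of the library):

```python
# Sketch of the safetensors header format the loader reads remotely via HTTP
# range requests: an 8-byte little-endian length, then a JSON header with a
# "shape" entry per tensor (plus an optional "__metadata__" entry).
import json
import struct


def local_safetensors_shapes(path: str) -> dict[str, list[int]]:
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    return {
        name: entry["shape"]
        for name, entry in header.items()
        if name != "__metadata__"
    }
```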
@@ -1525,9 +1740,10 @@ def get_mntss_clt_layer_config_from_hf(
     with open(base_config_path) as f:
         cfg_info: dict[str, Any] = yaml.safe_load(f)
 
-    # Get tensor shapes without downloading full files
-
-
+    # Get tensor shapes without downloading full files
+    encoder_shapes = get_safetensors_tensor_shapes(
+        repo_id, f"W_enc_{folder_name}.safetensors"
+    )
 
     # Extract shapes for the required tensors
     b_dec_shape = encoder_shapes[f"b_dec_{folder_name}"]
@@ -1663,6 +1879,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
     "gemma_2": gemma_2_sae_huggingface_loader,
+    "gemma_3": gemma_3_sae_huggingface_loader,
     "llama_scope": llama_scope_sae_huggingface_loader,
     "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
@@ -1680,6 +1897,7 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "sae_lens": get_sae_lens_config_from_hf,
     "connor_rob_hook_z": get_connor_rob_hook_z_config_from_hf,
     "gemma_2": get_gemma_2_config_from_hf,
+    "gemma_3": get_gemma_3_config_from_hf,
     "llama_scope": get_llama_scope_config_from_hf,
     "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
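With both registry entries in place, the `"gemma_3"` converter name can be dispatched like any other loader. A hedged sketch of looking the pair up directly (the dict names and keys come from the hunks above; calling them still requires a real repo_id and folder_name):

```python
# Sketch only: fetch the Gemma 3 loader and config getter from the registries
# added above, the same dicts the library consults when a pretrained_saes.yaml
# entry names the "gemma_3" converter.
from sae_lens.loading.pretrained_sae_loaders import (
    NAMED_PRETRAINED_SAE_CONFIG_GETTERS,
    NAMED_PRETRAINED_SAE_LOADERS,
)

gemma_3_loader = NAMED_PRETRAINED_SAE_LOADERS["gemma_3"]
gemma_3_cfg_getter = NAMED_PRETRAINED_SAE_CONFIG_GETTERS["gemma_3"]
print(gemma_3_loader.__name__, gemma_3_cfg_getter.__name__)
```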

{sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/pretokenize_runner.py

@@ -186,13 +186,13 @@ class PretokenizeRunner:
         """
         Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
         """
-        dataset = load_dataset(
+        dataset = load_dataset(  # type: ignore
             self.cfg.dataset_path,
             name=self.cfg.dataset_name,
             data_dir=self.cfg.data_dir,
             data_files=self.cfg.data_files,
-            split=self.cfg.split,
-            streaming=self.cfg.streaming,
+            split=self.cfg.split,  # type: ignore
+            streaming=self.cfg.streaming,  # type: ignore
         )
         if isinstance(dataset, DatasetDict):
             raise ValueError(