sae-lens 6.23.0__tar.gz → 6.25.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.23.0 → sae_lens-6.25.0}/PKG-INFO +2 -2
- {sae_lens-6.23.0 → sae_lens-6.25.0}/README.md +1 -1
- {sae_lens-6.23.0 → sae_lens-6.25.0}/pyproject.toml +2 -1
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/__init__.py +8 -1
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/loading/pretrained_sae_loaders.py +213 -19
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/pretokenize_runner.py +3 -3
- sae_lens-6.25.0/sae_lens/pretrained_saes.yaml +41813 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/saes/__init__.py +4 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/saes/transcoder.py +41 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/training/activations_store.py +1 -1
- sae_lens-6.23.0/sae_lens/pretrained_saes.yaml +0 -15039
- {sae_lens-6.23.0 → sae_lens-6.25.0}/LICENSE +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/config.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/saes/temporal_sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/util.py +0 -0

{sae_lens-6.23.0 → sae_lens-6.25.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.23.0
+Version: 6.25.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -62,7 +62,7 @@ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [

 ## Loading Pre-trained SAEs.

-Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/
+Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.

 ## Migrating to SAELens v6

{sae_lens-6.23.0 → sae_lens-6.25.0}/README.md

@@ -26,7 +26,7 @@ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [

 ## Loading Pre-trained SAEs.

-Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/
+Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.

 ## Migrating to SAELens v6

{sae_lens-6.23.0 → sae_lens-6.25.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.23.0"
+version = "6.25.0"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
@@ -47,6 +47,7 @@ docstr-coverage = "^2.3.2"
 mkdocs = "^1.6.1"
 mkdocs-material = "^9.5.34"
 mkdocs-autorefs = "^1.4.2"
+mkdocs-redirects = "^1.2.1"
 mkdocs-section-index = "^0.3.9"
 mkdocstrings = "^0.25.2"
 mkdocstrings-python = "^1.10.9"

{sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/__init__.py

@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.23.0"
+__version__ = "6.25.0"

 import logging

@@ -15,6 +15,8 @@ from sae_lens.saes import (
     GatedTrainingSAEConfig,
     JumpReLUSAE,
     JumpReLUSAEConfig,
+    JumpReLUSkipTranscoder,
+    JumpReLUSkipTranscoderConfig,
     JumpReLUTrainingSAE,
     JumpReLUTrainingSAEConfig,
     JumpReLUTranscoder,
@@ -105,6 +107,8 @@ __all__ = [
     "SkipTranscoderConfig",
     "JumpReLUTranscoder",
     "JumpReLUTranscoderConfig",
+    "JumpReLUSkipTranscoder",
+    "JumpReLUSkipTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
     "TemporalSAE",
@@ -131,4 +135,7 @@ register_sae_training_class(
 register_sae_class("transcoder", Transcoder, TranscoderConfig)
 register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
 register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+register_sae_class(
+    "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
+)
 register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)

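The added `register_sae_class` call maps the new `"jumprelu_skip_transcoder"` architecture string to the new class, and the class and its config are also exported from the package root. A minimal sanity check, assuming sae-lens 6.25.0 is installed:

# Sanity check (assumes sae-lens 6.25.0 is installed): the new exports are at the top level.
import sae_lens
from sae_lens import JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig

print(sae_lens.__version__)             # "6.25.0"
print(JumpReLUSkipTranscoder.__name__)  # "JumpReLUSkipTranscoder"
print(JumpReLUSkipTranscoderConfig.__name__)
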
{sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/loading/pretrained_sae_loaders.py

@@ -9,7 +9,7 @@ import requests
 import torch
 import yaml
 from huggingface_hub import hf_hub_download, hf_hub_url
-from huggingface_hub.utils import EntryNotFoundError
+from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
 from packaging.version import Version
 from safetensors import safe_open
 from safetensors.torch import load_file
@@ -46,6 +46,8 @@ LLM_METADATA_KEYS = {
     "sae_lens_training_version",
     "hook_name_out",
     "hook_head_index_out",
+    "hf_hook_name",
+    "hf_hook_name_out",
 }


@@ -523,6 +525,197 @@ def gemma_2_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity


+def _infer_gemma_3_raw_cfg_dict(repo_id: str, folder_name: str) -> dict[str, Any]:
+    """
+    Infer the raw config dict for Gemma 3 SAEs from the repo_id and folder_name.
+    This is used when config.json doesn't exist in the repo.
+    """
+    # Extract layer number from folder name
+    layer_match = re.search(r"layer_(\d+)", folder_name)
+    if layer_match is None:
+        raise ValueError(
+            f"Could not extract layer number from folder_name: {folder_name}"
+        )
+    layer = int(layer_match.group(1))
+
+    # Convert repo_id to model_name: google/gemma-scope-2-{size}-{suffix} -> google/gemma-3-{size}-{suffix}
+    model_name = repo_id.replace("gemma-scope-2", "gemma-3")
+
+    # Determine hook type and HF hook points based on folder_name
+    if "transcoder" in folder_name or "clt" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.pre_feedforward_layernorm.output"
+        hf_hook_point_out = f"model.layers.{layer}.post_feedforward_layernorm.output"
+    elif "resid_post" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.output"
+        hf_hook_point_out = None
+    elif "attn_out" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.self_attn.o_proj.input"
+        hf_hook_point_out = None
+    elif "mlp_out" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.post_feedforward_layernorm.output"
+        hf_hook_point_out = None
+    else:
+        raise ValueError(f"Could not infer hook type from folder_name: {folder_name}")
+
+    cfg: dict[str, Any] = {
+        "architecture": "jump_relu",
+        "model_name": model_name,
+        "hf_hook_point_in": hf_hook_point_in,
+    }
+    if hf_hook_point_out is not None:
+        cfg["hf_hook_point_out"] = hf_hook_point_out
+
+    return cfg
+
+
+def get_gemma_3_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    # Try to load config.json from the repo, fall back to inferring if it doesn't exist
+    try:
+        config_path = hf_hub_download(
+            repo_id, f"{folder_name}/config.json", force_download=force_download
+        )
+        with open(config_path) as config_file:
+            raw_cfg_dict = json.load(config_file)
+    except EntryNotFoundError:
+        raw_cfg_dict = _infer_gemma_3_raw_cfg_dict(repo_id, folder_name)
+
+    if raw_cfg_dict.get("architecture") != "jump_relu":
+        raise ValueError(
+            f"Unexpected architecture in Gemma 3 config: {raw_cfg_dict.get('architecture')}"
+        )
+
+    layer_match = re.search(r"layer_(\d+)", folder_name)
+    if layer_match is None:
+        raise ValueError(
+            f"Could not extract layer number from folder_name: {folder_name}"
+        )
+    layer = int(layer_match.group(1))
+    hook_name_out = None
+    d_out = None
+    if "resid_post" in folder_name:
+        hook_name = f"blocks.{layer}.hook_resid_post"
+    elif "attn_out" in folder_name:
+        hook_name = f"blocks.{layer}.hook_attn_out"
+    elif "mlp_out" in folder_name:
+        hook_name = f"blocks.{layer}.hook_mlp_out"
+    elif "transcoder" in folder_name or "clt" in folder_name:
+        hook_name = f"blocks.{layer}.ln2.hook_normalized"
+        hook_name_out = f"blocks.{layer}.hook_mlp_out"
+    else:
+        raise ValueError("Hook name not found in folder_name.")
+
+    # hackily deal with clt file names
+    params_file_part = "/params.safetensors"
+    if "clt" in folder_name:
+        params_file_part = ".safetensors"
+
+    shapes_dict = get_safetensors_tensor_shapes(
+        repo_id, f"{folder_name}{params_file_part}"
+    )
+    d_in, d_sae = shapes_dict["w_enc"]
+    # TODO: update this for real model info
+    model_name = raw_cfg_dict["model_name"]
+    if "google" not in model_name:
+        model_name = "google/" + model_name
+    model_name = model_name.replace("-v3", "-3")
+    if "270m" in model_name:
+        # for some reason the 270m model on huggingface doesn't have the -pt suffix
+        model_name = model_name.replace("-pt", "")
+
+    architecture = "jumprelu"
+    if "transcoder" in folder_name or "clt" in folder_name:
+        architecture = "jumprelu_skip_transcoder"
+        d_out = shapes_dict["w_dec"][-1]
+
+    cfg = {
+        "architecture": architecture,
+        "d_in": d_in,
+        "d_sae": d_sae,
+        "dtype": "float32",
+        "model_name": model_name,
+        "hook_name": hook_name,
+        "hook_head_index": None,
+        "finetuning_scaling_factor": False,
+        "sae_lens_training_version": None,
+        "prepend_bos": True,
+        "dataset_path": "monology/pile-uncopyrighted",
+        "context_size": 1024,
+        "apply_b_dec_to_input": False,
+        "normalize_activations": None,
+        "hf_hook_name": raw_cfg_dict.get("hf_hook_point_in"),
+    }
+    if hook_name_out is not None:
+        cfg["hook_name_out"] = hook_name_out
+        cfg["hf_hook_name_out"] = raw_cfg_dict.get("hf_hook_point_out")
+    if d_out is not None:
+        cfg["d_out"] = d_out
+    if device is not None:
+        cfg["device"] = device
+
+    if cfg_overrides is not None:
+        cfg.update(cfg_overrides)
+
+    return cfg
+
+
+def gemma_3_sae_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """
+    Custom loader for Gemma 3 SAEs.
+    """
+    cfg_dict = get_gemma_3_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
+    )
+
+    params_file = "params.safetensors"
+    if "clt" in folder_name:
+        params_file = folder_name.split("/")[-1] + ".safetensors"
+        folder_name = "/".join(folder_name.split("/")[:-1])
+
+    # Download the SAE weights
+    sae_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=params_file,
+        subfolder=folder_name,
+        force_download=force_download,
+    )
+
+    raw_state_dict = load_file(sae_path, device=device)
+
+    with torch.no_grad():
+        w_dec = raw_state_dict["w_dec"]
+        if "clt" in folder_name:
+            w_dec = w_dec.sum(dim=1).contiguous()
+
+    state_dict = {
+        "W_enc": raw_state_dict["w_enc"],
+        "W_dec": w_dec,
+        "b_enc": raw_state_dict["b_enc"],
+        "b_dec": raw_state_dict["b_dec"],
+        "threshold": raw_state_dict["threshold"],
+    }
+
+    if "affine_skip_connection" in raw_state_dict:
+        state_dict["W_skip"] = raw_state_dict["affine_skip_connection"]
+
+    return cfg_dict, state_dict, None
+
+
 def get_goodfire_config_from_hf(
     repo_id: str,
     folder_name: str,  # noqa: ARG001
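As a usage sketch (not part of the diff itself), the new loader can be called directly. The function, its signature, and the returned `(cfg_dict, state_dict, log_sparsity)` tuple all come from the hunk above; the repo id reuses the `gg-gs/gemma-scope-2-1b-pt` example from a docstring further down in this file, and the folder name is a hypothetical `layer_<n>`/hook-type path:

# Sketch: calling the new Gemma 3 loader directly. The repo id and folder name are
# illustrative; the folder must contain "layer_<n>" plus a hook-type substring.
from sae_lens.loading.pretrained_sae_loaders import gemma_3_sae_huggingface_loader

cfg_dict, state_dict, log_sparsity = gemma_3_sae_huggingface_loader(
    repo_id="gg-gs/gemma-scope-2-1b-pt",
    folder_name="layer_12/width_16k/resid_post",
    device="cpu",
)
# For a resid_post folder, cfg_dict["hook_name"] is "blocks.12.hook_resid_post";
# state_dict holds W_enc, W_dec, b_enc, b_dec and the JumpReLU threshold;
# log_sparsity is None for this loader.
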
@@ -1429,38 +1622,36 @@ def mwhanna_transcoder_huggingface_loader(
     return cfg_dict, state_dict, None


-def get_safetensors_tensor_shapes(
+def get_safetensors_tensor_shapes(repo_id: str, filename: str) -> dict[str, list[int]]:
     """
-    Get tensor shapes from a safetensors file
+    Get tensor shapes from a safetensors file on HuggingFace Hub
     without downloading the entire file.

+    Uses HTTP range requests to fetch only the metadata header.
+
     Args:
-
+        repo_id: HuggingFace repo ID (e.g., "gg-gs/gemma-scope-2-1b-pt")
+        filename: Path to the safetensors file within the repo

     Returns:
         Dictionary mapping tensor names to their shapes
     """
-
-    response = requests.head(url, timeout=10)
-    response.raise_for_status()
+    url = hf_hub_url(repo_id, filename)

-
-
-        raise ValueError("Server does not support range requests")
+    # Get HuggingFace headers (includes auth token if available)
+    hf_headers = build_hf_headers()

     # Fetch first 8 bytes to get metadata size
-    headers = {"Range": "bytes=0-7"}
+    headers = {**hf_headers, "Range": "bytes=0-7"}
     response = requests.get(url, headers=headers, timeout=10)
-
-        raise ValueError("Failed to fetch initial bytes for metadata size")
+    response.raise_for_status()

     meta_size = int.from_bytes(response.content, byteorder="little")

     # Fetch the metadata header
-    headers = {"Range": f"bytes=8-{8 + meta_size - 1}"}
+    headers = {**hf_headers, "Range": f"bytes=8-{8 + meta_size - 1}"}
     response = requests.get(url, headers=headers, timeout=10)
-
-        raise ValueError("Failed to fetch metadata header")
+    response.raise_for_status()

     metadata_json = response.content.decode("utf-8").strip()
     metadata = json.loads(metadata_json)
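The rewritten helper above fetches only the safetensors header rather than the whole file. For reference, here is a self-contained sketch of the same range-request technique against an arbitrary URL; it is independent of `huggingface_hub`, and the URL in the usage comment is a placeholder:

# Sketch: read tensor shapes from a remote .safetensors file via HTTP range requests.
# Assumes the server honors Range headers; the example URL below is a placeholder.
import json
import requests


def remote_safetensors_shapes(url: str) -> dict[str, list[int]]:
    # The first 8 bytes are a little-endian uint64 giving the size of the JSON header.
    head = requests.get(url, headers={"Range": "bytes=0-7"}, timeout=10)
    head.raise_for_status()
    header_size = int.from_bytes(head.content, byteorder="little")

    # The JSON header follows immediately and maps tensor names to dtype/shape/offsets.
    body = requests.get(
        url, headers={"Range": f"bytes=8-{8 + header_size - 1}"}, timeout=10
    )
    body.raise_for_status()
    header = json.loads(body.content.decode("utf-8").strip())

    # Drop the optional "__metadata__" entry, which has no shape.
    return {
        name: info["shape"] for name, info in header.items() if name != "__metadata__"
    }


# Example usage (placeholder URL):
# shapes = remote_safetensors_shapes("https://example.com/model.safetensors")
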
@@ -1540,9 +1731,10 @@ def get_mntss_clt_layer_config_from_hf(
     with open(base_config_path) as f:
         cfg_info: dict[str, Any] = yaml.safe_load(f)

-    # Get tensor shapes without downloading full files
-
-
+    # Get tensor shapes without downloading full files
+    encoder_shapes = get_safetensors_tensor_shapes(
+        repo_id, f"W_enc_{folder_name}.safetensors"
+    )

     # Extract shapes for the required tensors
     b_dec_shape = encoder_shapes[f"b_dec_{folder_name}"]
@@ -1678,6 +1870,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
     "gemma_2": gemma_2_sae_huggingface_loader,
+    "gemma_3": gemma_3_sae_huggingface_loader,
     "llama_scope": llama_scope_sae_huggingface_loader,
     "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
@@ -1695,6 +1888,7 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "sae_lens": get_sae_lens_config_from_hf,
     "connor_rob_hook_z": get_connor_rob_hook_z_config_from_hf,
     "gemma_2": get_gemma_2_config_from_hf,
+    "gemma_3": get_gemma_3_config_from_hf,
     "llama_scope": get_llama_scope_config_from_hf,
     "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,

{sae_lens-6.23.0 → sae_lens-6.25.0}/sae_lens/pretokenize_runner.py

@@ -186,13 +186,13 @@ class PretokenizeRunner:
         """
         Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
         """
-        dataset = load_dataset(
+        dataset = load_dataset(  # type: ignore
            self.cfg.dataset_path,
            name=self.cfg.dataset_name,
            data_dir=self.cfg.data_dir,
            data_files=self.cfg.data_files,
-            split=self.cfg.split,
-            streaming=self.cfg.streaming,
+            split=self.cfg.split,  # type: ignore
+            streaming=self.cfg.streaming,  # type: ignore
         )
         if isinstance(dataset, DatasetDict):
             raise ValueError(