sae-lens 6.23.0.tar.gz → 6.24.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.23.0 → sae_lens-6.24.0}/PKG-INFO +2 -2
- {sae_lens-6.23.0 → sae_lens-6.24.0}/README.md +1 -1
- {sae_lens-6.23.0 → sae_lens-6.24.0}/pyproject.toml +2 -1
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/__init__.py +8 -1
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/loading/pretrained_sae_loaders.py +222 -19
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/pretokenize_runner.py +3 -3
- sae_lens-6.24.0/sae_lens/pretrained_saes.yaml +41797 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/saes/__init__.py +4 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/saes/transcoder.py +41 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/training/activations_store.py +1 -1
- sae_lens-6.23.0/sae_lens/pretrained_saes.yaml +0 -15039
- {sae_lens-6.23.0 → sae_lens-6.24.0}/LICENSE +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/config.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/saes/temporal_sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/util.py +0 -0
{sae_lens-6.23.0 → sae_lens-6.24.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.23.0
+Version: 6.24.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -62,7 +62,7 @@ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [

 ## Loading Pre-trained SAEs.

-Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/
+Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.

 ## Migrating to SAELens v6

{sae_lens-6.23.0 → sae_lens-6.24.0}/README.md

@@ -26,7 +26,7 @@ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [

 ## Loading Pre-trained SAEs.

-Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/
+Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.

 ## Migrating to SAELens v6

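For context, "importing a pre-trained SAE via SAE Lens" normally goes through SAE.from_pretrained. The sketch below is illustrative only and is not part of this diff; the release and sae_id strings are example values that should be taken from the linked pretrained SAEs table, and it assumes the v6 behaviour where from_pretrained returns just the SAE object.

# Illustrative sketch, not part of the diff. Release/sae_id are example values;
# pick real ones from the pretrained SAEs table linked above.
from sae_lens import SAE

sae = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",  # example release name
    sae_id="layer_20/width_16k/canonical",      # example SAE id within the release
    device="cpu",
)
# Assumes the v6 API, where from_pretrained returns the SAE object directly.
print(sae.cfg.d_in, sae.cfg.d_sae)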
{sae_lens-6.23.0 → sae_lens-6.24.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.23.0"
+version = "6.24.0"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
@@ -47,6 +47,7 @@ docstr-coverage = "^2.3.2"
 mkdocs = "^1.6.1"
 mkdocs-material = "^9.5.34"
 mkdocs-autorefs = "^1.4.2"
+mkdocs-redirects = "^1.2.1"
 mkdocs-section-index = "^0.3.9"
 mkdocstrings = "^0.25.2"
 mkdocstrings-python = "^1.10.9"
{sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/__init__.py

@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.23.0"
+__version__ = "6.24.0"

 import logging

@@ -15,6 +15,8 @@ from sae_lens.saes import (
     GatedTrainingSAEConfig,
     JumpReLUSAE,
     JumpReLUSAEConfig,
+    JumpReLUSkipTranscoder,
+    JumpReLUSkipTranscoderConfig,
     JumpReLUTrainingSAE,
     JumpReLUTrainingSAEConfig,
     JumpReLUTranscoder,
@@ -105,6 +107,8 @@ __all__ = [
     "SkipTranscoderConfig",
     "JumpReLUTranscoder",
     "JumpReLUTranscoderConfig",
+    "JumpReLUSkipTranscoder",
+    "JumpReLUSkipTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
     "TemporalSAE",
@@ -131,4 +135,7 @@ register_sae_training_class(
 register_sae_class("transcoder", Transcoder, TranscoderConfig)
 register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
 register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+register_sae_class(
+    "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
+)
 register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
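The new "jumprelu_skip_transcoder" string is registered so that config-driven loading can resolve that architecture name to the right SAE and config classes. Below is a minimal, self-contained sketch of the registry pattern this relies on; it is a toy stand-in, not the actual sae_lens.registry module, whose internals are not shown in this diff.

# Toy sketch of the architecture-string registry pattern (not sae_lens.registry).
from typing import Any

SAE_CLASS_REGISTRY: dict[str, tuple[type, type]] = {}

def register_sae_class(architecture: str, sae_cls: type, cfg_cls: type) -> None:
    # Map an architecture name (as stored in a pretrained config) to its classes.
    SAE_CLASS_REGISTRY[architecture] = (sae_cls, cfg_cls)

def build_sae_from_cfg(cfg_dict: dict[str, Any]) -> Any:
    # Look up the registered classes and construct the SAE from a config dict.
    sae_cls, cfg_cls = SAE_CLASS_REGISTRY[cfg_dict["architecture"]]
    cfg = cfg_cls(**{k: v for k, v in cfg_dict.items() if k != "architecture"})
    return sae_cls(cfg)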
{sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/loading/pretrained_sae_loaders.py

@@ -9,7 +9,7 @@ import requests
 import torch
 import yaml
 from huggingface_hub import hf_hub_download, hf_hub_url
-from huggingface_hub.utils import EntryNotFoundError
+from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
 from packaging.version import Version
 from safetensors import safe_open
 from safetensors.torch import load_file
@@ -46,6 +46,8 @@ LLM_METADATA_KEYS = {
     "sae_lens_training_version",
     "hook_name_out",
     "hook_head_index_out",
+    "hf_hook_name",
+    "hf_hook_name_out",
 }


@@ -523,6 +525,206 @@ def gemma_2_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity


+def _infer_gemma_3_raw_cfg_dict(repo_id: str, folder_name: str) -> dict[str, Any]:
+    """
+    Infer the raw config dict for Gemma 3 SAEs from the repo_id and folder_name.
+    This is used when config.json doesn't exist in the repo.
+    """
+    # Extract layer number from folder name
+    layer_match = re.search(r"layer_(\d+)", folder_name)
+    if layer_match is None:
+        raise ValueError(
+            f"Could not extract layer number from folder_name: {folder_name}"
+        )
+    layer = int(layer_match.group(1))
+
+    # Convert repo_id to model_name: google/gemma-scope-2-{size}-{suffix} -> google/gemma-3-{size}-{suffix}
+    model_name = repo_id.replace("gemma-scope-2", "gemma-3")
+
+    # Determine hook type and HF hook points based on folder_name
+    if "transcoder" in folder_name or "clt" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.pre_feedforward_layernorm.output"
+        hf_hook_point_out = f"model.layers.{layer}.post_feedforward_layernorm.output"
+    elif "resid_post" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.output"
+        hf_hook_point_out = None
+    elif "attn_out" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.self_attn.o_proj.input"
+        hf_hook_point_out = None
+    elif "mlp_out" in folder_name:
+        hf_hook_point_in = f"model.layers.{layer}.post_feedforward_layernorm.output"
+        hf_hook_point_out = None
+    else:
+        raise ValueError(f"Could not infer hook type from folder_name: {folder_name}")
+
+    cfg: dict[str, Any] = {
+        "architecture": "jump_relu",
+        "model_name": model_name,
+        "hf_hook_point_in": hf_hook_point_in,
+    }
+    if hf_hook_point_out is not None:
+        cfg["hf_hook_point_out"] = hf_hook_point_out
+
+    return cfg
+
+
+def get_gemma_3_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    # Try to load config.json from the repo, fall back to inferring if it doesn't exist
+    try:
+        config_path = hf_hub_download(
+            repo_id, f"{folder_name}/config.json", force_download=force_download
+        )
+        with open(config_path) as config_file:
+            raw_cfg_dict = json.load(config_file)
+    except EntryNotFoundError:
+        raw_cfg_dict = _infer_gemma_3_raw_cfg_dict(repo_id, folder_name)
+
+    if raw_cfg_dict.get("architecture") != "jump_relu":
+        raise ValueError(
+            f"Unexpected architecture in Gemma 3 config: {raw_cfg_dict.get('architecture')}"
+        )
+
+    layer_match = re.search(r"layer_(\d+)", folder_name)
+    if layer_match is None:
+        raise ValueError(
+            f"Could not extract layer number from folder_name: {folder_name}"
+        )
+    layer = int(layer_match.group(1))
+    hook_name_out = None
+    d_out = None
+    if "resid_post" in folder_name:
+        hook_name = f"blocks.{layer}.hook_resid_post"
+    elif "attn_out" in folder_name:
+        hook_name = f"blocks.{layer}.hook_attn_out"
+    elif "mlp_out" in folder_name:
+        hook_name = f"blocks.{layer}.hook_mlp_out"
+    elif "transcoder" in folder_name or "clt" in folder_name:
+        hook_name = f"blocks.{layer}.ln2.hook_normalized"
+        hook_name_out = f"blocks.{layer}.hook_mlp_out"
+    else:
+        raise ValueError("Hook name not found in folder_name.")
+
+    # hackily deal with clt file names
+    params_file_part = "/params.safetensors"
+    if "clt" in folder_name:
+        params_file_part = ".safetensors"
+
+    shapes_dict = get_safetensors_tensor_shapes(
+        repo_id, f"{folder_name}{params_file_part}"
+    )
+    d_in, d_sae = shapes_dict["w_enc"]
+    # TODO: update this for real model info
+    model_name = raw_cfg_dict["model_name"]
+    if "google" not in model_name:
+        model_name = "google/" + model_name
+    model_name = model_name.replace("-v3", "-3")
+    if "270m" in model_name:
+        # for some reason the 270m model on huggingface doesn't have the -pt suffix
+        model_name = model_name.replace("-pt", "")
+
+    architecture = "jumprelu"
+    if "transcoder" in folder_name or "clt" in folder_name:
+        architecture = "jumprelu_skip_transcoder"
+        d_out = shapes_dict["w_dec"][-1]
+
+    cfg = {
+        "architecture": architecture,
+        "d_in": d_in,
+        "d_sae": d_sae,
+        "dtype": "float32",
+        "model_name": model_name,
+        "hook_name": hook_name,
+        "hook_head_index": None,
+        "finetuning_scaling_factor": False,
+        "sae_lens_training_version": None,
+        "prepend_bos": True,
+        "dataset_path": "monology/pile-uncopyrighted",
+        "context_size": 1024,
+        "apply_b_dec_to_input": False,
+        "normalize_activations": None,
+        "hf_hook_name": raw_cfg_dict.get("hf_hook_point_in"),
+    }
+    if hook_name_out is not None:
+        cfg["hook_name_out"] = hook_name_out
+        cfg["hf_hook_name_out"] = raw_cfg_dict.get("hf_hook_point_out")
+    if d_out is not None:
+        cfg["d_out"] = d_out
+    if device is not None:
+        cfg["device"] = device
+
+    if cfg_overrides is not None:
+        cfg.update(cfg_overrides)
+
+    return cfg
+
+
+def gemma_3_sae_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """
+    Custom loader for Gemma 3 SAEs.
+    """
+    cfg_dict = get_gemma_3_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
+    )
+
+    # replace folder name of 65k with 64k
+    # TODO: remove this workaround once weights are fixed
+    if "270m-pt" in repo_id:
+        if "65k" in folder_name:
+            folder_name = folder_name.replace("65k", "64k")
+        # replace folder name of 262k with 250k
+        if "262k" in folder_name:
+            folder_name = folder_name.replace("262k", "250k")
+
+    params_file = "params.safetensors"
+    if "clt" in folder_name:
+        params_file = folder_name.split("/")[-1] + ".safetensors"
+        folder_name = "/".join(folder_name.split("/")[:-1])
+
+    # Download the SAE weights
+    sae_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=params_file,
+        subfolder=folder_name,
+        force_download=force_download,
+    )
+
+    raw_state_dict = load_file(sae_path, device=device)
+
+    with torch.no_grad():
+        w_dec = raw_state_dict["w_dec"]
+        if "clt" in folder_name:
+            w_dec = w_dec.sum(dim=1).contiguous()
+
+        state_dict = {
+            "W_enc": raw_state_dict["w_enc"],
+            "W_dec": w_dec,
+            "b_enc": raw_state_dict["b_enc"],
+            "b_dec": raw_state_dict["b_dec"],
+            "threshold": raw_state_dict["threshold"],
+        }
+
+        if "affine_skip_connection" in raw_state_dict:
+            state_dict["W_skip"] = raw_state_dict["affine_skip_connection"]
+
+    return cfg_dict, state_dict, None
+
+
 def get_goodfire_config_from_hf(
     repo_id: str,
     folder_name: str,  # noqa: ARG001
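The config getter above infers everything from repo_id and folder_name. A tiny stand-alone sketch of that folder-name parsing follows; the folder name used here is a made-up example, not necessarily a real Gemma Scope path.

# Stand-alone sketch of the folder-name parsing in get_gemma_3_config_from_hf.
# The folder_name below is a hypothetical example for illustration only.
import re

folder_name = "layer_9_width_16k_resid_post"  # hypothetical
match = re.search(r"layer_(\d+)", folder_name)
assert match is not None
layer = int(match.group(1))

if "resid_post" in folder_name:
    hook_name = f"blocks.{layer}.hook_resid_post"
elif "attn_out" in folder_name:
    hook_name = f"blocks.{layer}.hook_attn_out"
elif "mlp_out" in folder_name:
    hook_name = f"blocks.{layer}.hook_mlp_out"
else:  # transcoder / clt folders
    hook_name = f"blocks.{layer}.ln2.hook_normalized"

print(layer, hook_name)  # -> 9 blocks.9.hook_resid_post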
@@ -1429,38 +1631,36 @@ def mwhanna_transcoder_huggingface_loader(
     return cfg_dict, state_dict, None


-def get_safetensors_tensor_shapes(
+def get_safetensors_tensor_shapes(repo_id: str, filename: str) -> dict[str, list[int]]:
     """
-    Get tensor shapes from a safetensors file
+    Get tensor shapes from a safetensors file on HuggingFace Hub
     without downloading the entire file.

+    Uses HTTP range requests to fetch only the metadata header.
+
     Args:
+        repo_id: HuggingFace repo ID (e.g., "gg-gs/gemma-scope-2-1b-pt")
+        filename: Path to the safetensors file within the repo

     Returns:
         Dictionary mapping tensor names to their shapes
     """
-    response = requests.head(url, timeout=10)
-    response.raise_for_status()
+    url = hf_hub_url(repo_id, filename)

-    raise ValueError("Server does not support range requests")
+    # Get HuggingFace headers (includes auth token if available)
+    hf_headers = build_hf_headers()

     # Fetch first 8 bytes to get metadata size
-    headers = {"Range": "bytes=0-7"}
+    headers = {**hf_headers, "Range": "bytes=0-7"}
     response = requests.get(url, headers=headers, timeout=10)
-    raise ValueError("Failed to fetch initial bytes for metadata size")
+    response.raise_for_status()

     meta_size = int.from_bytes(response.content, byteorder="little")

     # Fetch the metadata header
-    headers = {"Range": f"bytes=8-{8 + meta_size - 1}"}
+    headers = {**hf_headers, "Range": f"bytes=8-{8 + meta_size - 1}"}
     response = requests.get(url, headers=headers, timeout=10)
-    raise ValueError("Failed to fetch metadata header")
+    response.raise_for_status()

     metadata_json = response.content.decode("utf-8").strip()
     metadata = json.loads(metadata_json)
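For reference, the safetensors layout this function relies on is simple: the first 8 bytes are a little-endian header length, followed by a JSON header recording each tensor's dtype, shape, and byte offsets. A minimal local-file sketch of the same idea (the path is a placeholder you supply):

# Minimal sketch of reading safetensors metadata from a local file, mirroring
# what the HTTP range requests above fetch remotely. `path` is a placeholder.
import json

def local_safetensors_shapes(path: str) -> dict[str, list[int]]:
    with open(path, "rb") as f:
        meta_size = int.from_bytes(f.read(8), byteorder="little")  # header length
        header = json.loads(f.read(meta_size).decode("utf-8"))
    # Each entry (except the optional "__metadata__" key) has dtype/shape/data_offsets.
    return {k: v["shape"] for k, v in header.items() if k != "__metadata__"}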
@@ -1540,9 +1740,10 @@ def get_mntss_clt_layer_config_from_hf(
     with open(base_config_path) as f:
         cfg_info: dict[str, Any] = yaml.safe_load(f)

-    # Get tensor shapes without downloading full files
+    # Get tensor shapes without downloading full files
+    encoder_shapes = get_safetensors_tensor_shapes(
+        repo_id, f"W_enc_{folder_name}.safetensors"
+    )

     # Extract shapes for the required tensors
     b_dec_shape = encoder_shapes[f"b_dec_{folder_name}"]
@@ -1678,6 +1879,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
     "gemma_2": gemma_2_sae_huggingface_loader,
+    "gemma_3": gemma_3_sae_huggingface_loader,
     "llama_scope": llama_scope_sae_huggingface_loader,
     "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
@@ -1695,6 +1897,7 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoader] = {
     "sae_lens": get_sae_lens_config_from_hf,
     "connor_rob_hook_z": get_connor_rob_hook_z_config_from_hf,
     "gemma_2": get_gemma_2_config_from_hf,
+    "gemma_3": get_gemma_3_config_from_hf,
     "llama_scope": get_llama_scope_config_from_hf,
     "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
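Both mappings key the new entry on the same "gemma_3" conversion name, so a loader and its config getter resolve together. The sketch below shows how such a loader could be called directly; it is illustrative only, and the repo_id and folder_name strings are hypothetical placeholders, not verified Hugging Face paths.

# Illustrative only: calling the new Gemma 3 loader directly.
from sae_lens.loading.pretrained_sae_loaders import gemma_3_sae_huggingface_loader

cfg_dict, state_dict, log_sparsity = gemma_3_sae_huggingface_loader(
    repo_id="google/gemma-scope-2-1b-pt",        # hypothetical repo
    folder_name="layer_9_width_16k_resid_post",  # hypothetical folder
    device="cpu",
)
print(cfg_dict["architecture"], state_dict["W_enc"].shape)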
{sae_lens-6.23.0 → sae_lens-6.24.0}/sae_lens/pretokenize_runner.py

@@ -186,13 +186,13 @@ class PretokenizeRunner:
         """
         Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
         """
-        dataset = load_dataset(
+        dataset = load_dataset(  # type: ignore
            self.cfg.dataset_path,
            name=self.cfg.dataset_name,
            data_dir=self.cfg.data_dir,
            data_files=self.cfg.data_files,
-            split=self.cfg.split,
-            streaming=self.cfg.streaming,
+            split=self.cfg.split,  # type: ignore
+            streaming=self.cfg.streaming,  # type: ignore
         )
         if isinstance(dataset, DatasetDict):
             raise ValueError(