sae-lens 6.22.1__py3-none-any.whl → 6.25.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.22.1"
+ __version__ = "6.25.1"

  import logging

@@ -15,6 +15,8 @@ from sae_lens.saes import (
      GatedTrainingSAEConfig,
      JumpReLUSAE,
      JumpReLUSAEConfig,
+     JumpReLUSkipTranscoder,
+     JumpReLUSkipTranscoderConfig,
      JumpReLUTrainingSAE,
      JumpReLUTrainingSAEConfig,
      JumpReLUTranscoder,
@@ -105,6 +107,8 @@ __all__ = [
      "SkipTranscoderConfig",
      "JumpReLUTranscoder",
      "JumpReLUTranscoderConfig",
+     "JumpReLUSkipTranscoder",
+     "JumpReLUSkipTranscoderConfig",
      "MatryoshkaBatchTopKTrainingSAE",
      "MatryoshkaBatchTopKTrainingSAEConfig",
      "TemporalSAE",
@@ -131,4 +135,7 @@ register_sae_training_class(
  register_sae_class("transcoder", Transcoder, TranscoderConfig)
  register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
  register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+ register_sae_class(
+     "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
+ )
  register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
sae_lens/cache_activations_runner.py CHANGED
@@ -14,9 +14,9 @@ from transformer_lens.HookedTransformer import HookedRootModule

  from sae_lens import logger
  from sae_lens.config import CacheActivationsRunnerConfig
- from sae_lens.constants import DTYPE_MAP
  from sae_lens.load_model import load_model
  from sae_lens.training.activations_store import ActivationsStore
+ from sae_lens.util import str_to_dtype


  def _mk_activations_store(
@@ -97,7 +97,7 @@ class CacheActivationsRunner:
          bytes_per_token = (
              self.cfg.d_in * self.cfg.dtype.itemsize
              if isinstance(self.cfg.dtype, torch.dtype)
-             else DTYPE_MAP[self.cfg.dtype].itemsize
+             else str_to_dtype(self.cfg.dtype).itemsize
          )
          total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
          total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9
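Both call sites now go through `str_to_dtype` instead of indexing `DTYPE_MAP` directly. The real helper lives in `sae_lens.util` and is not shown in this diff; a hypothetical stand-in, plus the disk-space arithmetic it feeds:

```python
import torch

# Illustrative stand-in for sae_lens.util.str_to_dtype: unlike a bare dict
# lookup, a function can accept both "float32" and "torch.float32" spellings.
def str_to_dtype(name: str) -> torch.dtype:
    dtypes = {
        "float32": torch.float32,
        "float64": torch.float64,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    key = name.removeprefix("torch.")
    if key not in dtypes:
        raise ValueError(f"unknown dtype string: {name}")
    return dtypes[key]

# Worked example of the estimate above: d_in=2304 activations in float32
# cost 2304 * 4 = 9216 bytes per token; 100M tokens is roughly 921.6 GB on disk.
bytes_per_token = 2304 * str_to_dtype("float32").itemsize
```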
sae_lens/config.py CHANGED
@@ -17,9 +17,9 @@ from datasets import (
  )

  from sae_lens import __version__, logger
- from sae_lens.constants import DTYPE_MAP
  from sae_lens.registry import get_sae_training_class
  from sae_lens.saes.sae import TrainingSAEConfig
+ from sae_lens.util import str_to_dtype

  if TYPE_CHECKING:
      pass
@@ -563,7 +563,7 @@ class CacheActivationsRunnerConfig:

      @property
      def bytes_per_token(self) -> int:
-         return self.d_in * DTYPE_MAP[self.dtype].itemsize
+         return self.d_in * str_to_dtype(self.dtype).itemsize

      @property
      def n_tokens_in_buffer(self) -> int:
sae_lens/constants.py CHANGED
@@ -11,6 +11,14 @@ DTYPE_MAP = {
      "torch.bfloat16": torch.bfloat16,
  }

+ # Reverse mapping from torch.dtype to canonical string format
+ DTYPE_TO_STR = {
+     torch.float32: "float32",
+     torch.float64: "float64",
+     torch.float16: "float16",
+     torch.bfloat16: "bfloat16",
+ }
+

  SPARSITY_FILENAME = "sparsity.safetensors"
  SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
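The new `DTYPE_TO_STR` inverts the existing `DTYPE_MAP`, giving a canonical string for each supported dtype. A round-trip sketch, using only the keys visible in this diff:

```python
import torch
from sae_lens.constants import DTYPE_MAP, DTYPE_TO_STR

# torch.dtype -> canonical string, and string -> torch.dtype
assert DTYPE_TO_STR[torch.bfloat16] == "bfloat16"
assert DTYPE_MAP["torch.bfloat16"] is torch.bfloat16
```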
sae_lens/loading/pretrained_sae_loaders.py CHANGED
@@ -9,14 +9,12 @@ import requests
  import torch
  import yaml
  from huggingface_hub import hf_hub_download, hf_hub_url
- from huggingface_hub.utils import EntryNotFoundError
+ from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
  from packaging.version import Version
  from safetensors import safe_open
- from safetensors.torch import load_file

  from sae_lens import logger
  from sae_lens.constants import (
-     DTYPE_MAP,
      SAE_CFG_FILENAME,
      SAE_WEIGHTS_FILENAME,
      SPARSIFY_WEIGHTS_FILENAME,
@@ -28,7 +26,7 @@ from sae_lens.loading.pretrained_saes_directory import (
      get_repo_id_and_folder_name,
  )
  from sae_lens.registry import get_sae_class
- from sae_lens.util import filter_valid_dataclass_fields
+ from sae_lens.util import filter_valid_dataclass_fields, str_to_dtype

  LLM_METADATA_KEYS = {
      "model_name",
@@ -46,9 +44,26 @@ LLM_METADATA_KEYS = {
      "sae_lens_training_version",
      "hook_name_out",
      "hook_head_index_out",
+     "hf_hook_name",
+     "hf_hook_name_out",
  }


+ def load_safetensors_weights(
+     path: str | Path, device: str = "cpu", dtype: torch.dtype | str | None = None
+ ) -> dict[str, torch.Tensor]:
+     """Load safetensors weights and optionally convert to a different dtype"""
+     loaded_weights = {}
+     dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
+     with safe_open(path, framework="pt", device=device) as f:
+         for k in f.keys():  # noqa: SIM118
+             weight = f.get_tensor(k)
+             if dtype is not None:
+                 weight = weight.to(dtype=dtype)
+             loaded_weights[k] = weight
+     return loaded_weights
+
+
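Usage sketch for the new helper, which the hunks below swap in for the previous `load_file` calls. The filename is illustrative, and the import path assumes the module header above:

```python
from sae_lens.loading.pretrained_sae_loaders import load_safetensors_weights

# Load all tensors from a safetensors file, casting to float32 on load.
weights = load_safetensors_weights(
    "sae_weights.safetensors", device="cpu", dtype="float32"
)
W_enc = weights["W_enc"]  # already in torch.float32
```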
  # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
  class PretrainedSaeHuggingfaceLoader(Protocol):
      def __call__(
@@ -339,7 +354,7 @@ def read_sae_components_from_disk(
      Given a loaded dictionary and a path to a weight file, load the weights and return the state_dict.
      """
      if dtype is None:
-         dtype = DTYPE_MAP[cfg_dict["dtype"]]
+         dtype = str_to_dtype(cfg_dict["dtype"])

      state_dict = {}
      with safe_open(weight_path, framework="pt", device=device) as f:  # type: ignore
@@ -523,6 +538,199 @@ def gemma_2_sae_huggingface_loader(
      return cfg_dict, state_dict, log_sparsity


+ def _infer_gemma_3_raw_cfg_dict(repo_id: str, folder_name: str) -> dict[str, Any]:
+     """
+     Infer the raw config dict for Gemma 3 SAEs from the repo_id and folder_name.
+     This is used when config.json doesn't exist in the repo.
+     """
+     # Extract layer number from folder name
+     layer_match = re.search(r"layer_(\d+)", folder_name)
+     if layer_match is None:
+         raise ValueError(
+             f"Could not extract layer number from folder_name: {folder_name}"
+         )
+     layer = int(layer_match.group(1))
+
+     # Convert repo_id to model_name: google/gemma-scope-2-{size}-{suffix} -> google/gemma-3-{size}-{suffix}
+     model_name = repo_id.replace("gemma-scope-2", "gemma-3")
+
+     # Determine hook type and HF hook points based on folder_name
+     if "transcoder" in folder_name or "clt" in folder_name:
+         hf_hook_point_in = f"model.layers.{layer}.pre_feedforward_layernorm.output"
+         hf_hook_point_out = f"model.layers.{layer}.post_feedforward_layernorm.output"
+     elif "resid_post" in folder_name:
+         hf_hook_point_in = f"model.layers.{layer}.output"
+         hf_hook_point_out = None
+     elif "attn_out" in folder_name:
+         hf_hook_point_in = f"model.layers.{layer}.self_attn.o_proj.input"
+         hf_hook_point_out = None
+     elif "mlp_out" in folder_name:
+         hf_hook_point_in = f"model.layers.{layer}.post_feedforward_layernorm.output"
+         hf_hook_point_out = None
+     else:
+         raise ValueError(f"Could not infer hook type from folder_name: {folder_name}")
+
+     cfg: dict[str, Any] = {
+         "architecture": "jump_relu",
+         "model_name": model_name,
+         "hf_hook_point_in": hf_hook_point_in,
+     }
+     if hf_hook_point_out is not None:
+         cfg["hf_hook_point_out"] = hf_hook_point_out
+
+     return cfg
+
+
+ def get_gemma_3_config_from_hf(
+     repo_id: str,
+     folder_name: str,
+     device: str,
+     force_download: bool = False,
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+     # Try to load config.json from the repo, fall back to inferring if it doesn't exist
+     try:
+         config_path = hf_hub_download(
+             repo_id, f"{folder_name}/config.json", force_download=force_download
+         )
+         with open(config_path) as config_file:
+             raw_cfg_dict = json.load(config_file)
+     except EntryNotFoundError:
+         raw_cfg_dict = _infer_gemma_3_raw_cfg_dict(repo_id, folder_name)
+
+     if raw_cfg_dict.get("architecture") != "jump_relu":
+         raise ValueError(
+             f"Unexpected architecture in Gemma 3 config: {raw_cfg_dict.get('architecture')}"
+         )
+
+     layer_match = re.search(r"layer_(\d+)", folder_name)
+     if layer_match is None:
+         raise ValueError(
+             f"Could not extract layer number from folder_name: {folder_name}"
+         )
+     layer = int(layer_match.group(1))
+     hook_name_out = None
+     d_out = None
+     if "resid_post" in folder_name:
+         hook_name = f"blocks.{layer}.hook_resid_post"
+     elif "attn_out" in folder_name:
+         hook_name = f"blocks.{layer}.hook_attn_out"
+     elif "mlp_out" in folder_name:
+         hook_name = f"blocks.{layer}.hook_mlp_out"
+     elif "transcoder" in folder_name or "clt" in folder_name:
+         hook_name = f"blocks.{layer}.ln2.hook_normalized"
+         hook_name_out = f"blocks.{layer}.hook_mlp_out"
+     else:
+         raise ValueError("Hook name not found in folder_name.")
+
+     # hackily deal with clt file names
+     params_file_part = "/params.safetensors"
+     if "clt" in folder_name:
+         params_file_part = ".safetensors"
+
+     shapes_dict = get_safetensors_tensor_shapes(
+         repo_id, f"{folder_name}{params_file_part}"
+     )
+     d_in, d_sae = shapes_dict["w_enc"]
+     # TODO: update this for real model info
+     model_name = raw_cfg_dict["model_name"]
+     if "google" not in model_name:
+         model_name = "google/" + model_name
+     model_name = model_name.replace("-v3", "-3")
+     if "270m" in model_name:
+         # for some reason the 270m model on huggingface doesn't have the -pt suffix
+         model_name = model_name.replace("-pt", "")
+
+     architecture = "jumprelu"
+     if "transcoder" in folder_name or "clt" in folder_name:
+         architecture = "jumprelu_skip_transcoder"
+         d_out = shapes_dict["w_dec"][-1]
+
+     cfg = {
+         "architecture": architecture,
+         "d_in": d_in,
+         "d_sae": d_sae,
+         "dtype": "float32",
+         "model_name": model_name,
+         "hook_name": hook_name,
+         "hook_head_index": None,
+         "finetuning_scaling_factor": False,
+         "sae_lens_training_version": None,
+         "prepend_bos": True,
+         "dataset_path": "monology/pile-uncopyrighted",
+         "context_size": 1024,
+         "apply_b_dec_to_input": False,
+         "normalize_activations": None,
+         "hf_hook_name": raw_cfg_dict.get("hf_hook_point_in"),
+     }
+     if hook_name_out is not None:
+         cfg["hook_name_out"] = hook_name_out
+         cfg["hf_hook_name_out"] = raw_cfg_dict.get("hf_hook_point_out")
+     if d_out is not None:
+         cfg["d_out"] = d_out
+     if device is not None:
+         cfg["device"] = device
+
+     if cfg_overrides is not None:
+         cfg.update(cfg_overrides)
+
+     return cfg
+
+
+ def gemma_3_sae_huggingface_loader(
+     repo_id: str,
+     folder_name: str,
+     device: str = "cpu",
+     force_download: bool = False,
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+     """
+     Custom loader for Gemma 3 SAEs.
+     """
+     cfg_dict = get_gemma_3_config_from_hf(
+         repo_id,
+         folder_name,
+         device,
+         force_download,
+         cfg_overrides,
+     )
+
+     params_file = "params.safetensors"
+     if "clt" in folder_name:
+         params_file = folder_name.split("/")[-1] + ".safetensors"
+         folder_name = "/".join(folder_name.split("/")[:-1])
+
+     # Download the SAE weights
+     sae_path = hf_hub_download(
+         repo_id=repo_id,
+         filename=params_file,
+         subfolder=folder_name,
+         force_download=force_download,
+     )
+
+     raw_state_dict = load_safetensors_weights(
+         sae_path, device=device, dtype=cfg_dict.get("dtype")
+     )
+
+     with torch.no_grad():
+         w_dec = raw_state_dict["w_dec"]
+         if "clt" in folder_name:
+             w_dec = w_dec.sum(dim=1).contiguous()
+
+     state_dict = {
+         "W_enc": raw_state_dict["w_enc"],
+         "W_dec": w_dec,
+         "b_enc": raw_state_dict["b_enc"],
+         "b_dec": raw_state_dict["b_dec"],
+         "threshold": raw_state_dict["threshold"],
+     }
+
+     if "affine_skip_connection" in raw_state_dict:
+         state_dict["W_skip"] = raw_state_dict["affine_skip_connection"]
+
+     return cfg_dict, state_dict, None
+
+
  def get_goodfire_config_from_hf(
      repo_id: str,
      folder_name: str,  # noqa: ARG001
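A hedged usage sketch for the new Gemma 3 loader. The repo_id echoes the docstring example further down in this diff ("gg-gs/gemma-scope-2-1b-pt"); the folder layout is a guess constrained only by the layer_N and hook-type substrings the parsing code above looks for, so treat both as placeholders:

```python
# Placeholders, not verified artifacts: only the "layer_12" and "resid_post"
# substrings are meaningful to the parsing logic above.
from sae_lens.loading.pretrained_sae_loaders import gemma_3_sae_huggingface_loader

cfg_dict, state_dict, _ = gemma_3_sae_huggingface_loader(
    repo_id="gg-gs/gemma-scope-2-1b-pt",
    folder_name="resid_post/layer_12",
    device="cpu",
)
print(cfg_dict["architecture"])    # "jumprelu" for a resid_post SAE
print(state_dict["W_enc"].shape)   # (d_in, d_sae), read from the safetensors header
```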
@@ -589,11 +797,13 @@ def get_goodfire_huggingface_loader(
      )
      raw_state_dict = torch.load(sae_path, map_location=device)

+     target_dtype = str_to_dtype(cfg_dict.get("dtype", "float32"))
+
      state_dict = {
-         "W_enc": raw_state_dict["encoder_linear.weight"].T,
-         "W_dec": raw_state_dict["decoder_linear.weight"].T,
-         "b_enc": raw_state_dict["encoder_linear.bias"],
-         "b_dec": raw_state_dict["decoder_linear.bias"],
+         "W_enc": raw_state_dict["encoder_linear.weight"].T.to(dtype=target_dtype),
+         "W_dec": raw_state_dict["decoder_linear.weight"].T.to(dtype=target_dtype),
+         "b_enc": raw_state_dict["encoder_linear.bias"].to(dtype=target_dtype),
+         "b_dec": raw_state_dict["decoder_linear.bias"].to(dtype=target_dtype),
      }

      return cfg_dict, state_dict, None
@@ -696,26 +906,19 @@ def llama_scope_sae_huggingface_loader(
          force_download=force_download,
      )

-     # Load the weights using load_file instead of safe_open
-     state_dict_loaded = load_file(sae_path, device=device)
+     state_dict_loaded = load_safetensors_weights(
+         sae_path, device=device, dtype=cfg_dict.get("dtype")
+     )

      # Convert and organize the weights
      state_dict = {
-         "W_enc": state_dict_loaded["encoder.weight"]
-         .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-         .T,
-         "W_dec": state_dict_loaded["decoder.weight"]
-         .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-         .T,
-         "b_enc": state_dict_loaded["encoder.bias"].to(
-             dtype=DTYPE_MAP[cfg_dict["dtype"]]
-         ),
-         "b_dec": state_dict_loaded["decoder.bias"].to(
-             dtype=DTYPE_MAP[cfg_dict["dtype"]]
-         ),
+         "W_enc": state_dict_loaded["encoder.weight"].T,
+         "W_dec": state_dict_loaded["decoder.weight"].T,
+         "b_enc": state_dict_loaded["encoder.bias"],
+         "b_dec": state_dict_loaded["decoder.bias"],
          "threshold": torch.ones(
              cfg_dict["d_sae"],
-             dtype=DTYPE_MAP[cfg_dict["dtype"]],
+             dtype=str_to_dtype(cfg_dict["dtype"]),
              device=cfg_dict["device"],
          )
          * cfg_dict["jump_relu_threshold"],
@@ -753,10 +956,14 @@ def get_dictionary_learning_config_1_from_hf(
      activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
      activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}

+     architecture = "standard"
+     if trainer["dict_class"] == "GatedAutoEncoder":
+         architecture = "gated"
+     elif trainer["dict_class"] == "MatryoshkaBatchTopKSAE":
+         architecture = "jumprelu"
+
      return {
-         "architecture": (
-             "gated" if trainer["dict_class"] == "GatedAutoEncoder" else "standard"
-         ),
+         "architecture": architecture,
          "d_in": trainer["activation_dim"],
          "d_sae": trainer["dict_size"],
          "dtype": "float32",
@@ -905,9 +1112,12 @@ def dictionary_learning_sae_huggingface_loader_1(
      )
      encoder = torch.load(encoder_path, map_location="cpu")

+     W_enc = encoder["W_enc"] if "W_enc" in encoder else encoder["encoder.weight"].T
+     W_dec = encoder["W_dec"] if "W_dec" in encoder else encoder["decoder.weight"].T
+
      state_dict = {
-         "W_enc": encoder["encoder.weight"].T,
-         "W_dec": encoder["decoder.weight"].T,
+         "W_enc": W_enc,
+         "W_dec": W_dec,
          "b_dec": encoder.get(
              "b_dec", encoder.get("bias", encoder.get("decoder_bias", None))
          ),
@@ -915,6 +1125,8 @@
      if "encoder.bias" in encoder:
          state_dict["b_enc"] = encoder["encoder.bias"]
+     if "b_enc" in encoder:
+         state_dict["b_enc"] = encoder["b_enc"]

      if "mag_bias" in encoder:
          state_dict["b_mag"] = encoder["mag_bias"]
@@ -923,6 +1135,12 @@
      if "r_mag" in encoder:
          state_dict["r_mag"] = encoder["r_mag"]

+     if "threshold" in encoder:
+         threshold = encoder["threshold"]
+         if threshold.ndim == 0:
+             threshold = torch.full((W_enc.size(1),), threshold)
+         state_dict["threshold"] = threshold
+
      return cfg_dict, state_dict, None
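The new `threshold` handling expands a 0-dim scalar from the checkpoint into one threshold per SAE feature (`W_enc` is `[d_in, d_sae]`, so `W_enc.size(1)` is `d_sae`). A standalone worked example with illustrative shapes; `.item()` is used here for portability across torch versions:

```python
import torch

W_enc = torch.randn(512, 4096)            # [d_in, d_sae]
threshold = torch.tensor(0.03)            # 0-dim scalar from a checkpoint
if threshold.ndim == 0:
    threshold = torch.full((W_enc.size(1),), threshold.item())
assert threshold.shape == (4096,)
```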
@@ -1011,26 +1229,17 @@ def llama_scope_r1_distill_sae_huggingface_loader(
          force_download=force_download,
      )

-     # Load the weights using load_file instead of safe_open
-     state_dict_loaded = load_file(sae_path, device=device)
+     state_dict_loaded = load_safetensors_weights(
+         sae_path, device=device, dtype=cfg_dict.get("dtype")
+     )

      # Convert and organize the weights
      state_dict = {
-         "W_enc": state_dict_loaded["encoder.weight"]
-         .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-         .T,
-         "W_dec": state_dict_loaded["decoder.weight"]
-         .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-         .T,
-         "b_enc": state_dict_loaded["encoder.bias"].to(
-             dtype=DTYPE_MAP[cfg_dict["dtype"]]
-         ),
-         "b_dec": state_dict_loaded["decoder.bias"].to(
-             dtype=DTYPE_MAP[cfg_dict["dtype"]]
-         ),
-         "threshold": state_dict_loaded["log_jumprelu_threshold"]
-         .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-         .exp(),
+         "W_enc": state_dict_loaded["encoder.weight"].T,
+         "W_dec": state_dict_loaded["decoder.weight"].T,
+         "b_enc": state_dict_loaded["encoder.bias"],
+         "b_dec": state_dict_loaded["decoder.bias"],
+         "threshold": state_dict_loaded["log_jumprelu_threshold"].exp(),
      }

      # No sparsity tensor for Llama Scope SAEs
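These checkpoints store the JumpReLU threshold in log space; exponentiating on load recovers a strictly positive per-feature threshold, which is the point of that parameterization. For example:

```python
import torch

log_jumprelu_threshold = torch.tensor([-2.0, 0.0, 1.0])
threshold = log_jumprelu_threshold.exp()
# tensor([0.1353, 1.0000, 2.7183]); every entry is guaranteed > 0
```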
@@ -1150,34 +1359,34 @@ def sparsify_disk_loader(
      cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)

      weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
-     state_dict_loaded = load_file(weight_path, device=device)
-
-     dtype = DTYPE_MAP[cfg_dict["dtype"]]
+     state_dict_loaded = load_safetensors_weights(
+         weight_path, device=device, dtype=cfg_dict.get("dtype")
+     )

      W_enc = (
          state_dict_loaded["W_enc"]
          if "W_enc" in state_dict_loaded
          else state_dict_loaded["encoder.weight"].T
-     ).to(dtype)
+     )

      if "W_dec" in state_dict_loaded:
-         W_dec = state_dict_loaded["W_dec"].T.to(dtype)
+         W_dec = state_dict_loaded["W_dec"].T
      else:
-         W_dec = state_dict_loaded["decoder.weight"].T.to(dtype)
+         W_dec = state_dict_loaded["decoder.weight"].T

      if "b_enc" in state_dict_loaded:
-         b_enc = state_dict_loaded["b_enc"].to(dtype)
+         b_enc = state_dict_loaded["b_enc"]
      elif "encoder.bias" in state_dict_loaded:
-         b_enc = state_dict_loaded["encoder.bias"].to(dtype)
+         b_enc = state_dict_loaded["encoder.bias"]
      else:
-         b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+         b_enc = torch.zeros(cfg_dict["d_sae"], dtype=W_dec.dtype, device=device)

      if "b_dec" in state_dict_loaded:
-         b_dec = state_dict_loaded["b_dec"].to(dtype)
+         b_dec = state_dict_loaded["b_dec"]
      elif "decoder.bias" in state_dict_loaded:
-         b_dec = state_dict_loaded["decoder.bias"].to(dtype)
+         b_dec = state_dict_loaded["decoder.bias"]
      else:
-         b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+         b_dec = torch.zeros(cfg_dict["d_in"], dtype=W_dec.dtype, device=device)

      state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
      return cfg_dict, state_dict
@@ -1408,44 +1617,44 @@ def mwhanna_transcoder_huggingface_loader(
      )

      # Load weights from safetensors
-     state_dict = load_file(file_path, device=device)
+     state_dict = load_safetensors_weights(
+         file_path, device=device, dtype=cfg_dict.get("dtype")
+     )
      state_dict["W_enc"] = state_dict["W_enc"].T

      return cfg_dict, state_dict, None


- def get_safetensors_tensor_shapes(url: str) -> dict[str, list[int]]:
+ def get_safetensors_tensor_shapes(repo_id: str, filename: str) -> dict[str, list[int]]:
      """
-     Get tensor shapes from a safetensors file using HTTP range requests
+     Get tensor shapes from a safetensors file on HuggingFace Hub
      without downloading the entire file.

+     Uses HTTP range requests to fetch only the metadata header.
+
      Args:
-         url: Direct URL to the safetensors file
+         repo_id: HuggingFace repo ID (e.g., "gg-gs/gemma-scope-2-1b-pt")
+         filename: Path to the safetensors file within the repo

      Returns:
          Dictionary mapping tensor names to their shapes
      """
-     # Check if server supports range requests
-     response = requests.head(url, timeout=10)
-     response.raise_for_status()
+     url = hf_hub_url(repo_id, filename)

-     accept_ranges = response.headers.get("Accept-Ranges", "")
-     if "bytes" not in accept_ranges:
-         raise ValueError("Server does not support range requests")
+     # Get HuggingFace headers (includes auth token if available)
+     hf_headers = build_hf_headers()

      # Fetch first 8 bytes to get metadata size
-     headers = {"Range": "bytes=0-7"}
+     headers = {**hf_headers, "Range": "bytes=0-7"}
      response = requests.get(url, headers=headers, timeout=10)
-     if response.status_code != 206:
-         raise ValueError("Failed to fetch initial bytes for metadata size")
+     response.raise_for_status()

      meta_size = int.from_bytes(response.content, byteorder="little")

      # Fetch the metadata header
-     headers = {"Range": f"bytes=8-{8 + meta_size - 1}"}
+     headers = {**hf_headers, "Range": f"bytes=8-{8 + meta_size - 1}"}
      response = requests.get(url, headers=headers, timeout=10)
-     if response.status_code != 206:
-         raise ValueError("Failed to fetch metadata header")
+     response.raise_for_status()

      metadata_json = response.content.decode("utf-8").strip()
      metadata = json.loads(metadata_json)
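For context, the range-request trick works because of the safetensors file layout: the first 8 bytes are a little-endian uint64 giving the length of a JSON header that maps each tensor name to its dtype, shape, and byte offsets. A local, no-HTTP sketch of the same parse (the filename is illustrative):

```python
import json

# Read only the safetensors JSON header; the tensor data that follows
# is never touched.
with open("weights.safetensors", "rb") as f:
    meta_size = int.from_bytes(f.read(8), byteorder="little")
    header = json.loads(f.read(meta_size))

shapes = {
    name: info["shape"]
    for name, info in header.items()
    if name != "__metadata__"  # skip the optional free-form metadata entry
}
```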
@@ -1494,8 +1703,12 @@ def mntss_clt_layer_huggingface_loader(
          force_download=force_download,
      )

-     encoder_state_dict = load_file(encoder_path, device=device)
-     decoder_state_dict = load_file(decoder_path, device=device)
+     encoder_state_dict = load_safetensors_weights(
+         encoder_path, device=device, dtype=cfg_dict.get("dtype")
+     )
+     decoder_state_dict = load_safetensors_weights(
+         decoder_path, device=device, dtype=cfg_dict.get("dtype")
+     )

      with torch.no_grad():
          state_dict = {
@@ -1525,9 +1738,10 @@ def get_mntss_clt_layer_config_from_hf(
      with open(base_config_path) as f:
          cfg_info: dict[str, Any] = yaml.safe_load(f)

-     # Get tensor shapes without downloading full files using HTTP range requests
-     encoder_url = hf_hub_url(repo_id, f"W_enc_{folder_name}.safetensors")
-     encoder_shapes = get_safetensors_tensor_shapes(encoder_url)
+     # Get tensor shapes without downloading full files
+     encoder_shapes = get_safetensors_tensor_shapes(
+         repo_id, f"W_enc_{folder_name}.safetensors"
+     )

      # Extract shapes for the required tensors
      b_dec_shape = encoder_shapes[f"b_dec_{folder_name}"]
@@ -1637,7 +1851,9 @@ def temporal_sae_huggingface_loader(
      )

      # Load checkpoint from safetensors
-     state_dict_raw = load_file(ckpt_path, device=device)
+     state_dict_raw = load_safetensors_weights(
+         ckpt_path, device=device, dtype=cfg_dict.get("dtype")
+     )

      # Convert to SAELens naming convention
      # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
@@ -1663,6 +1879,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
      "sae_lens": sae_lens_huggingface_loader,
      "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
      "gemma_2": gemma_2_sae_huggingface_loader,
+     "gemma_3": gemma_3_sae_huggingface_loader,
      "llama_scope": llama_scope_sae_huggingface_loader,
      "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
      "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
@@ -1680,6 +1897,7 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoader] = {
      "sae_lens": get_sae_lens_config_from_hf,
      "connor_rob_hook_z": get_connor_rob_hook_z_config_from_hf,
      "gemma_2": get_gemma_2_config_from_hf,
+     "gemma_3": get_gemma_3_config_from_hf,
      "llama_scope": get_llama_scope_config_from_hf,
      "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
      "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
sae_lens/pretokenize_runner.py CHANGED
@@ -186,13 +186,13 @@ class PretokenizeRunner:
          """
          Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
          """
-         dataset = load_dataset(
+         dataset = load_dataset(  # type: ignore
              self.cfg.dataset_path,
              name=self.cfg.dataset_name,
              data_dir=self.cfg.data_dir,
              data_files=self.cfg.data_files,
-             split=self.cfg.split,
-             streaming=self.cfg.streaming,
+             split=self.cfg.split,  # type: ignore
+             streaming=self.cfg.streaming,  # type: ignore
          )
          if isinstance(dataset, DatasetDict):
              raise ValueError(