sae-lens 6.22.3__tar.gz → 6.24.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {sae_lens-6.22.3 → sae_lens-6.24.0}/PKG-INFO +2 -2
  2. {sae_lens-6.22.3 → sae_lens-6.24.0}/README.md +1 -1
  3. {sae_lens-6.22.3 → sae_lens-6.24.0}/pyproject.toml +2 -1
  4. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/__init__.py +8 -1
  5. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/loading/pretrained_sae_loaders.py +242 -24
  6. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/pretokenize_runner.py +3 -3
  7. sae_lens-6.24.0/sae_lens/pretrained_saes.yaml +41797 -0
  8. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/__init__.py +4 -0
  9. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/transcoder.py +41 -0
  10. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/activations_store.py +1 -1
  11. sae_lens-6.22.3/sae_lens/pretrained_saes.yaml +0 -14961
  12. {sae_lens-6.22.3 → sae_lens-6.24.0}/LICENSE +0 -0
  13. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/analysis/__init__.py +0 -0
  14. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  15. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  16. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/cache_activations_runner.py +0 -0
  17. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/config.py +0 -0
  18. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/constants.py +0 -0
  19. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/evals.py +0 -0
  20. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/llm_sae_training_runner.py +0 -0
  21. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/load_model.py +0 -0
  22. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/loading/__init__.py +0 -0
  23. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  24. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/registry.py +0 -0
  25. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  26. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/gated_sae.py +0 -0
  27. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/jumprelu_sae.py +0 -0
  28. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
  29. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/sae.py +0 -0
  30. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/standard_sae.py +0 -0
  31. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/temporal_sae.py +0 -0
  32. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/saes/topk_sae.py +0 -0
  33. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/tokenization_and_batching.py +0 -0
  34. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/__init__.py +0 -0
  35. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/activation_scaler.py +0 -0
  36. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/mixing_buffer.py +0 -0
  37. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/optim.py +0 -0
  38. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/sae_trainer.py +0 -0
  39. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/types.py +0 -0
  40. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  41. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/tutorial/tsea.py +0 -0
  42. {sae_lens-6.22.3 → sae_lens-6.24.0}/sae_lens/util.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.22.3
3
+ Version: 6.24.0
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -62,7 +62,7 @@ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [
62
62
 
63
63
  ## Loading Pre-trained SAEs.
64
64
 
65
- Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.
65
+ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
66
66
 
67
67
  ## Migrating to SAELens v6
68
68
 
@@ -26,7 +26,7 @@ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [
26
26
 
27
27
  ## Loading Pre-trained SAEs.
28
28
 
29
- Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.
29
+ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
30
30
 
31
31
  ## Migrating to SAELens v6
32
32
 
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "sae-lens"
3
- version = "6.22.3"
3
+ version = "6.24.0"
4
4
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
5
5
  authors = ["Joseph Bloom"]
6
6
  readme = "README.md"
@@ -47,6 +47,7 @@ docstr-coverage = "^2.3.2"
47
47
  mkdocs = "^1.6.1"
48
48
  mkdocs-material = "^9.5.34"
49
49
  mkdocs-autorefs = "^1.4.2"
50
+ mkdocs-redirects = "^1.2.1"
50
51
  mkdocs-section-index = "^0.3.9"
51
52
  mkdocstrings = "^0.25.2"
52
53
  mkdocstrings-python = "^1.10.9"
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.22.3"
2
+ __version__ = "6.24.0"
3
3
 
4
4
  import logging
5
5
 
@@ -15,6 +15,8 @@ from sae_lens.saes import (
15
15
  GatedTrainingSAEConfig,
16
16
  JumpReLUSAE,
17
17
  JumpReLUSAEConfig,
18
+ JumpReLUSkipTranscoder,
19
+ JumpReLUSkipTranscoderConfig,
18
20
  JumpReLUTrainingSAE,
19
21
  JumpReLUTrainingSAEConfig,
20
22
  JumpReLUTranscoder,
@@ -105,6 +107,8 @@ __all__ = [
105
107
  "SkipTranscoderConfig",
106
108
  "JumpReLUTranscoder",
107
109
  "JumpReLUTranscoderConfig",
110
+ "JumpReLUSkipTranscoder",
111
+ "JumpReLUSkipTranscoderConfig",
108
112
  "MatryoshkaBatchTopKTrainingSAE",
109
113
  "MatryoshkaBatchTopKTrainingSAEConfig",
110
114
  "TemporalSAE",
@@ -131,4 +135,7 @@ register_sae_training_class(
131
135
  register_sae_class("transcoder", Transcoder, TranscoderConfig)
132
136
  register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
133
137
  register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
138
+ register_sae_class(
139
+ "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
140
+ )
134
141
  register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
@@ -9,7 +9,7 @@ import requests
9
9
  import torch
10
10
  import yaml
11
11
  from huggingface_hub import hf_hub_download, hf_hub_url
12
- from huggingface_hub.utils import EntryNotFoundError
12
+ from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
13
13
  from packaging.version import Version
14
14
  from safetensors import safe_open
15
15
  from safetensors.torch import load_file
@@ -46,6 +46,8 @@ LLM_METADATA_KEYS = {
46
46
  "sae_lens_training_version",
47
47
  "hook_name_out",
48
48
  "hook_head_index_out",
49
+ "hf_hook_name",
50
+ "hf_hook_name_out",
49
51
  }
50
52
 
51
53
 
@@ -523,6 +525,206 @@ def gemma_2_sae_huggingface_loader(
523
525
  return cfg_dict, state_dict, log_sparsity
524
526
 
525
527
 
528
+ def _infer_gemma_3_raw_cfg_dict(repo_id: str, folder_name: str) -> dict[str, Any]:
529
+ """
530
+ Infer the raw config dict for Gemma 3 SAEs from the repo_id and folder_name.
531
+ This is used when config.json doesn't exist in the repo.
532
+ """
533
+ # Extract layer number from folder name
534
+ layer_match = re.search(r"layer_(\d+)", folder_name)
535
+ if layer_match is None:
536
+ raise ValueError(
537
+ f"Could not extract layer number from folder_name: {folder_name}"
538
+ )
539
+ layer = int(layer_match.group(1))
540
+
541
+ # Convert repo_id to model_name: google/gemma-scope-2-{size}-{suffix} -> google/gemma-3-{size}-{suffix}
542
+ model_name = repo_id.replace("gemma-scope-2", "gemma-3")
543
+
544
+ # Determine hook type and HF hook points based on folder_name
545
+ if "transcoder" in folder_name or "clt" in folder_name:
546
+ hf_hook_point_in = f"model.layers.{layer}.pre_feedforward_layernorm.output"
547
+ hf_hook_point_out = f"model.layers.{layer}.post_feedforward_layernorm.output"
548
+ elif "resid_post" in folder_name:
549
+ hf_hook_point_in = f"model.layers.{layer}.output"
550
+ hf_hook_point_out = None
551
+ elif "attn_out" in folder_name:
552
+ hf_hook_point_in = f"model.layers.{layer}.self_attn.o_proj.input"
553
+ hf_hook_point_out = None
554
+ elif "mlp_out" in folder_name:
555
+ hf_hook_point_in = f"model.layers.{layer}.post_feedforward_layernorm.output"
556
+ hf_hook_point_out = None
557
+ else:
558
+ raise ValueError(f"Could not infer hook type from folder_name: {folder_name}")
559
+
560
+ cfg: dict[str, Any] = {
561
+ "architecture": "jump_relu",
562
+ "model_name": model_name,
563
+ "hf_hook_point_in": hf_hook_point_in,
564
+ }
565
+ if hf_hook_point_out is not None:
566
+ cfg["hf_hook_point_out"] = hf_hook_point_out
567
+
568
+ return cfg
569
+
570
+
571
+ def get_gemma_3_config_from_hf(
572
+ repo_id: str,
573
+ folder_name: str,
574
+ device: str,
575
+ force_download: bool = False,
576
+ cfg_overrides: dict[str, Any] | None = None,
577
+ ) -> dict[str, Any]:
578
+ # Try to load config.json from the repo, fall back to inferring if it doesn't exist
579
+ try:
580
+ config_path = hf_hub_download(
581
+ repo_id, f"{folder_name}/config.json", force_download=force_download
582
+ )
583
+ with open(config_path) as config_file:
584
+ raw_cfg_dict = json.load(config_file)
585
+ except EntryNotFoundError:
586
+ raw_cfg_dict = _infer_gemma_3_raw_cfg_dict(repo_id, folder_name)
587
+
588
+ if raw_cfg_dict.get("architecture") != "jump_relu":
589
+ raise ValueError(
590
+ f"Unexpected architecture in Gemma 3 config: {raw_cfg_dict.get('architecture')}"
591
+ )
592
+
593
+ layer_match = re.search(r"layer_(\d+)", folder_name)
594
+ if layer_match is None:
595
+ raise ValueError(
596
+ f"Could not extract layer number from folder_name: {folder_name}"
597
+ )
598
+ layer = int(layer_match.group(1))
599
+ hook_name_out = None
600
+ d_out = None
601
+ if "resid_post" in folder_name:
602
+ hook_name = f"blocks.{layer}.hook_resid_post"
603
+ elif "attn_out" in folder_name:
604
+ hook_name = f"blocks.{layer}.hook_attn_out"
605
+ elif "mlp_out" in folder_name:
606
+ hook_name = f"blocks.{layer}.hook_mlp_out"
607
+ elif "transcoder" in folder_name or "clt" in folder_name:
608
+ hook_name = f"blocks.{layer}.ln2.hook_normalized"
609
+ hook_name_out = f"blocks.{layer}.hook_mlp_out"
610
+ else:
611
+ raise ValueError("Hook name not found in folder_name.")
612
+
613
+ # hackily deal with clt file names
614
+ params_file_part = "/params.safetensors"
615
+ if "clt" in folder_name:
616
+ params_file_part = ".safetensors"
617
+
618
+ shapes_dict = get_safetensors_tensor_shapes(
619
+ repo_id, f"{folder_name}{params_file_part}"
620
+ )
621
+ d_in, d_sae = shapes_dict["w_enc"]
622
+ # TODO: update this for real model info
623
+ model_name = raw_cfg_dict["model_name"]
624
+ if "google" not in model_name:
625
+ model_name = "google/" + model_name
626
+ model_name = model_name.replace("-v3", "-3")
627
+ if "270m" in model_name:
628
+ # for some reason the 270m model on huggingface doesn't have the -pt suffix
629
+ model_name = model_name.replace("-pt", "")
630
+
631
+ architecture = "jumprelu"
632
+ if "transcoder" in folder_name or "clt" in folder_name:
633
+ architecture = "jumprelu_skip_transcoder"
634
+ d_out = shapes_dict["w_dec"][-1]
635
+
636
+ cfg = {
637
+ "architecture": architecture,
638
+ "d_in": d_in,
639
+ "d_sae": d_sae,
640
+ "dtype": "float32",
641
+ "model_name": model_name,
642
+ "hook_name": hook_name,
643
+ "hook_head_index": None,
644
+ "finetuning_scaling_factor": False,
645
+ "sae_lens_training_version": None,
646
+ "prepend_bos": True,
647
+ "dataset_path": "monology/pile-uncopyrighted",
648
+ "context_size": 1024,
649
+ "apply_b_dec_to_input": False,
650
+ "normalize_activations": None,
651
+ "hf_hook_name": raw_cfg_dict.get("hf_hook_point_in"),
652
+ }
653
+ if hook_name_out is not None:
654
+ cfg["hook_name_out"] = hook_name_out
655
+ cfg["hf_hook_name_out"] = raw_cfg_dict.get("hf_hook_point_out")
656
+ if d_out is not None:
657
+ cfg["d_out"] = d_out
658
+ if device is not None:
659
+ cfg["device"] = device
660
+
661
+ if cfg_overrides is not None:
662
+ cfg.update(cfg_overrides)
663
+
664
+ return cfg
665
+
666
+
667
+ def gemma_3_sae_huggingface_loader(
668
+ repo_id: str,
669
+ folder_name: str,
670
+ device: str = "cpu",
671
+ force_download: bool = False,
672
+ cfg_overrides: dict[str, Any] | None = None,
673
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
674
+ """
675
+ Custom loader for Gemma 3 SAEs.
676
+ """
677
+ cfg_dict = get_gemma_3_config_from_hf(
678
+ repo_id,
679
+ folder_name,
680
+ device,
681
+ force_download,
682
+ cfg_overrides,
683
+ )
684
+
685
+ # replace folder name of 65k with 64k
686
+ # TODO: remove this workaround once weights are fixed
687
+ if "270m-pt" in repo_id:
688
+ if "65k" in folder_name:
689
+ folder_name = folder_name.replace("65k", "64k")
690
+ # replace folder name of 262k with 250k
691
+ if "262k" in folder_name:
692
+ folder_name = folder_name.replace("262k", "250k")
693
+
694
+ params_file = "params.safetensors"
695
+ if "clt" in folder_name:
696
+ params_file = folder_name.split("/")[-1] + ".safetensors"
697
+ folder_name = "/".join(folder_name.split("/")[:-1])
698
+
699
+ # Download the SAE weights
700
+ sae_path = hf_hub_download(
701
+ repo_id=repo_id,
702
+ filename=params_file,
703
+ subfolder=folder_name,
704
+ force_download=force_download,
705
+ )
706
+
707
+ raw_state_dict = load_file(sae_path, device=device)
708
+
709
+ with torch.no_grad():
710
+ w_dec = raw_state_dict["w_dec"]
711
+ if "clt" in folder_name:
712
+ w_dec = w_dec.sum(dim=1).contiguous()
713
+
714
+ state_dict = {
715
+ "W_enc": raw_state_dict["w_enc"],
716
+ "W_dec": w_dec,
717
+ "b_enc": raw_state_dict["b_enc"],
718
+ "b_dec": raw_state_dict["b_dec"],
719
+ "threshold": raw_state_dict["threshold"],
720
+ }
721
+
722
+ if "affine_skip_connection" in raw_state_dict:
723
+ state_dict["W_skip"] = raw_state_dict["affine_skip_connection"]
724
+
725
+ return cfg_dict, state_dict, None
726
+
727
+
526
728
  def get_goodfire_config_from_hf(
527
729
  repo_id: str,
528
730
  folder_name: str, # noqa: ARG001
@@ -753,10 +955,14 @@ def get_dictionary_learning_config_1_from_hf(
753
955
  activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
754
956
  activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}
755
957
 
958
+ architecture = "standard"
959
+ if trainer["dict_class"] == "GatedAutoEncoder":
960
+ architecture = "gated"
961
+ elif trainer["dict_class"] == "MatryoshkaBatchTopKSAE":
962
+ architecture = "jumprelu"
963
+
756
964
  return {
757
- "architecture": (
758
- "gated" if trainer["dict_class"] == "GatedAutoEncoder" else "standard"
759
- ),
965
+ "architecture": architecture,
760
966
  "d_in": trainer["activation_dim"],
761
967
  "d_sae": trainer["dict_size"],
762
968
  "dtype": "float32",
@@ -905,9 +1111,12 @@ def dictionary_learning_sae_huggingface_loader_1(
905
1111
  )
906
1112
  encoder = torch.load(encoder_path, map_location="cpu")
907
1113
 
1114
+ W_enc = encoder["W_enc"] if "W_enc" in encoder else encoder["encoder.weight"].T
1115
+ W_dec = encoder["W_dec"] if "W_dec" in encoder else encoder["decoder.weight"].T
1116
+
908
1117
  state_dict = {
909
- "W_enc": encoder["encoder.weight"].T,
910
- "W_dec": encoder["decoder.weight"].T,
1118
+ "W_enc": W_enc,
1119
+ "W_dec": W_dec,
911
1120
  "b_dec": encoder.get(
912
1121
  "b_dec", encoder.get("bias", encoder.get("decoder_bias", None))
913
1122
  ),
@@ -915,6 +1124,8 @@ def dictionary_learning_sae_huggingface_loader_1(
915
1124
 
916
1125
  if "encoder.bias" in encoder:
917
1126
  state_dict["b_enc"] = encoder["encoder.bias"]
1127
+ if "b_enc" in encoder:
1128
+ state_dict["b_enc"] = encoder["b_enc"]
918
1129
 
919
1130
  if "mag_bias" in encoder:
920
1131
  state_dict["b_mag"] = encoder["mag_bias"]
@@ -923,6 +1134,12 @@ def dictionary_learning_sae_huggingface_loader_1(
923
1134
  if "r_mag" in encoder:
924
1135
  state_dict["r_mag"] = encoder["r_mag"]
925
1136
 
1137
+ if "threshold" in encoder:
1138
+ threshold = encoder["threshold"]
1139
+ if threshold.ndim == 0:
1140
+ threshold = torch.full((W_enc.size(1),), threshold)
1141
+ state_dict["threshold"] = threshold
1142
+
926
1143
  return cfg_dict, state_dict, None
927
1144
 
928
1145
 
@@ -1414,38 +1631,36 @@ def mwhanna_transcoder_huggingface_loader(
1414
1631
  return cfg_dict, state_dict, None
1415
1632
 
1416
1633
 
1417
- def get_safetensors_tensor_shapes(url: str) -> dict[str, list[int]]:
1634
+ def get_safetensors_tensor_shapes(repo_id: str, filename: str) -> dict[str, list[int]]:
1418
1635
  """
1419
- Get tensor shapes from a safetensors file using HTTP range requests
1636
+ Get tensor shapes from a safetensors file on HuggingFace Hub
1420
1637
  without downloading the entire file.
1421
1638
 
1639
+ Uses HTTP range requests to fetch only the metadata header.
1640
+
1422
1641
  Args:
1423
- url: Direct URL to the safetensors file
1642
+ repo_id: HuggingFace repo ID (e.g., "gg-gs/gemma-scope-2-1b-pt")
1643
+ filename: Path to the safetensors file within the repo
1424
1644
 
1425
1645
  Returns:
1426
1646
  Dictionary mapping tensor names to their shapes
1427
1647
  """
1428
- # Check if server supports range requests
1429
- response = requests.head(url, timeout=10)
1430
- response.raise_for_status()
1648
+ url = hf_hub_url(repo_id, filename)
1431
1649
 
1432
- accept_ranges = response.headers.get("Accept-Ranges", "")
1433
- if "bytes" not in accept_ranges:
1434
- raise ValueError("Server does not support range requests")
1650
+ # Get HuggingFace headers (includes auth token if available)
1651
+ hf_headers = build_hf_headers()
1435
1652
 
1436
1653
  # Fetch first 8 bytes to get metadata size
1437
- headers = {"Range": "bytes=0-7"}
1654
+ headers = {**hf_headers, "Range": "bytes=0-7"}
1438
1655
  response = requests.get(url, headers=headers, timeout=10)
1439
- if response.status_code != 206:
1440
- raise ValueError("Failed to fetch initial bytes for metadata size")
1656
+ response.raise_for_status()
1441
1657
 
1442
1658
  meta_size = int.from_bytes(response.content, byteorder="little")
1443
1659
 
1444
1660
  # Fetch the metadata header
1445
- headers = {"Range": f"bytes=8-{8 + meta_size - 1}"}
1661
+ headers = {**hf_headers, "Range": f"bytes=8-{8 + meta_size - 1}"}
1446
1662
  response = requests.get(url, headers=headers, timeout=10)
1447
- if response.status_code != 206:
1448
- raise ValueError("Failed to fetch metadata header")
1663
+ response.raise_for_status()
1449
1664
 
1450
1665
  metadata_json = response.content.decode("utf-8").strip()
1451
1666
  metadata = json.loads(metadata_json)
@@ -1525,9 +1740,10 @@ def get_mntss_clt_layer_config_from_hf(
1525
1740
  with open(base_config_path) as f:
1526
1741
  cfg_info: dict[str, Any] = yaml.safe_load(f)
1527
1742
 
1528
- # Get tensor shapes without downloading full files using HTTP range requests
1529
- encoder_url = hf_hub_url(repo_id, f"W_enc_{folder_name}.safetensors")
1530
- encoder_shapes = get_safetensors_tensor_shapes(encoder_url)
1743
+ # Get tensor shapes without downloading full files
1744
+ encoder_shapes = get_safetensors_tensor_shapes(
1745
+ repo_id, f"W_enc_{folder_name}.safetensors"
1746
+ )
1531
1747
 
1532
1748
  # Extract shapes for the required tensors
1533
1749
  b_dec_shape = encoder_shapes[f"b_dec_{folder_name}"]
@@ -1663,6 +1879,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
1663
1879
  "sae_lens": sae_lens_huggingface_loader,
1664
1880
  "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
1665
1881
  "gemma_2": gemma_2_sae_huggingface_loader,
1882
+ "gemma_3": gemma_3_sae_huggingface_loader,
1666
1883
  "llama_scope": llama_scope_sae_huggingface_loader,
1667
1884
  "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
1668
1885
  "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
@@ -1680,6 +1897,7 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
1680
1897
  "sae_lens": get_sae_lens_config_from_hf,
1681
1898
  "connor_rob_hook_z": get_connor_rob_hook_z_config_from_hf,
1682
1899
  "gemma_2": get_gemma_2_config_from_hf,
1900
+ "gemma_3": get_gemma_3_config_from_hf,
1683
1901
  "llama_scope": get_llama_scope_config_from_hf,
1684
1902
  "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
1685
1903
  "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
@@ -186,13 +186,13 @@ class PretokenizeRunner:
186
186
  """
187
187
  Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
188
188
  """
189
- dataset = load_dataset(
189
+ dataset = load_dataset( # type: ignore
190
190
  self.cfg.dataset_path,
191
191
  name=self.cfg.dataset_name,
192
192
  data_dir=self.cfg.data_dir,
193
193
  data_files=self.cfg.data_files,
194
- split=self.cfg.split,
195
- streaming=self.cfg.streaming,
194
+ split=self.cfg.split, # type: ignore
195
+ streaming=self.cfg.streaming, # type: ignore
196
196
  )
197
197
  if isinstance(dataset, DatasetDict):
198
198
  raise ValueError(