sae-lens 6.15.0__py3-none-any.whl → 6.24.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,7 +9,7 @@ import requests
  import torch
  import yaml
  from huggingface_hub import hf_hub_download, hf_hub_url
- from huggingface_hub.utils import EntryNotFoundError
+ from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
  from packaging.version import Version
  from safetensors import safe_open
  from safetensors.torch import load_file
@@ -46,6 +46,8 @@ LLM_METADATA_KEYS = {
      "sae_lens_training_version",
      "hook_name_out",
      "hook_head_index_out",
+     "hf_hook_name",
+     "hf_hook_name_out",
  }


@@ -523,6 +525,282 @@ def gemma_2_sae_huggingface_loader(
      return cfg_dict, state_dict, log_sparsity


+ def _infer_gemma_3_raw_cfg_dict(repo_id: str, folder_name: str) -> dict[str, Any]:
+     """
+     Infer the raw config dict for Gemma 3 SAEs from the repo_id and folder_name.
+     This is used when config.json doesn't exist in the repo.
+     """
+     # Extract layer number from folder name
+     layer_match = re.search(r"layer_(\d+)", folder_name)
+     if layer_match is None:
+         raise ValueError(
+             f"Could not extract layer number from folder_name: {folder_name}"
+         )
+     layer = int(layer_match.group(1))
+
+     # Convert repo_id to model_name: google/gemma-scope-2-{size}-{suffix} -> google/gemma-3-{size}-{suffix}
+     model_name = repo_id.replace("gemma-scope-2", "gemma-3")
+
+     # Determine hook type and HF hook points based on folder_name
+     if "transcoder" in folder_name or "clt" in folder_name:
+         hf_hook_point_in = f"model.layers.{layer}.pre_feedforward_layernorm.output"
+         hf_hook_point_out = f"model.layers.{layer}.post_feedforward_layernorm.output"
+     elif "resid_post" in folder_name:
+         hf_hook_point_in = f"model.layers.{layer}.output"
+         hf_hook_point_out = None
+     elif "attn_out" in folder_name:
+         hf_hook_point_in = f"model.layers.{layer}.self_attn.o_proj.input"
+         hf_hook_point_out = None
+     elif "mlp_out" in folder_name:
+         hf_hook_point_in = f"model.layers.{layer}.post_feedforward_layernorm.output"
+         hf_hook_point_out = None
+     else:
+         raise ValueError(f"Could not infer hook type from folder_name: {folder_name}")
+
+     cfg: dict[str, Any] = {
+         "architecture": "jump_relu",
+         "model_name": model_name,
+         "hf_hook_point_in": hf_hook_point_in,
+     }
+     if hf_hook_point_out is not None:
+         cfg["hf_hook_point_out"] = hf_hook_point_out
+
+     return cfg
+
+
+ def get_gemma_3_config_from_hf(
+     repo_id: str,
+     folder_name: str,
+     device: str,
+     force_download: bool = False,
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+     # Try to load config.json from the repo, fall back to inferring if it doesn't exist
+     try:
+         config_path = hf_hub_download(
+             repo_id, f"{folder_name}/config.json", force_download=force_download
+         )
+         with open(config_path) as config_file:
+             raw_cfg_dict = json.load(config_file)
+     except EntryNotFoundError:
+         raw_cfg_dict = _infer_gemma_3_raw_cfg_dict(repo_id, folder_name)
+
+     if raw_cfg_dict.get("architecture") != "jump_relu":
+         raise ValueError(
+             f"Unexpected architecture in Gemma 3 config: {raw_cfg_dict.get('architecture')}"
+         )
+
+     layer_match = re.search(r"layer_(\d+)", folder_name)
+     if layer_match is None:
+         raise ValueError(
+             f"Could not extract layer number from folder_name: {folder_name}"
+         )
+     layer = int(layer_match.group(1))
+     hook_name_out = None
+     d_out = None
+     if "resid_post" in folder_name:
+         hook_name = f"blocks.{layer}.hook_resid_post"
+     elif "attn_out" in folder_name:
+         hook_name = f"blocks.{layer}.hook_attn_out"
+     elif "mlp_out" in folder_name:
+         hook_name = f"blocks.{layer}.hook_mlp_out"
+     elif "transcoder" in folder_name or "clt" in folder_name:
+         hook_name = f"blocks.{layer}.ln2.hook_normalized"
+         hook_name_out = f"blocks.{layer}.hook_mlp_out"
+     else:
+         raise ValueError("Hook name not found in folder_name.")
+
+     # hackily deal with clt file names
+     params_file_part = "/params.safetensors"
+     if "clt" in folder_name:
+         params_file_part = ".safetensors"
+
+     shapes_dict = get_safetensors_tensor_shapes(
+         repo_id, f"{folder_name}{params_file_part}"
+     )
+     d_in, d_sae = shapes_dict["w_enc"]
+     # TODO: update this for real model info
+     model_name = raw_cfg_dict["model_name"]
+     if "google" not in model_name:
+         model_name = "google/" + model_name
+     model_name = model_name.replace("-v3", "-3")
+     if "270m" in model_name:
+         # for some reason the 270m model on huggingface doesn't have the -pt suffix
+         model_name = model_name.replace("-pt", "")
+
+     architecture = "jumprelu"
+     if "transcoder" in folder_name or "clt" in folder_name:
+         architecture = "jumprelu_skip_transcoder"
+         d_out = shapes_dict["w_dec"][-1]
+
+     cfg = {
+         "architecture": architecture,
+         "d_in": d_in,
+         "d_sae": d_sae,
+         "dtype": "float32",
+         "model_name": model_name,
+         "hook_name": hook_name,
+         "hook_head_index": None,
+         "finetuning_scaling_factor": False,
+         "sae_lens_training_version": None,
+         "prepend_bos": True,
+         "dataset_path": "monology/pile-uncopyrighted",
+         "context_size": 1024,
+         "apply_b_dec_to_input": False,
+         "normalize_activations": None,
+         "hf_hook_name": raw_cfg_dict.get("hf_hook_point_in"),
+     }
+     if hook_name_out is not None:
+         cfg["hook_name_out"] = hook_name_out
+         cfg["hf_hook_name_out"] = raw_cfg_dict.get("hf_hook_point_out")
+     if d_out is not None:
+         cfg["d_out"] = d_out
+     if device is not None:
+         cfg["device"] = device
+
+     if cfg_overrides is not None:
+         cfg.update(cfg_overrides)
+
+     return cfg
+
+
+ def gemma_3_sae_huggingface_loader(
+     repo_id: str,
+     folder_name: str,
+     device: str = "cpu",
+     force_download: bool = False,
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+     """
+     Custom loader for Gemma 3 SAEs.
+     """
+     cfg_dict = get_gemma_3_config_from_hf(
+         repo_id,
+         folder_name,
+         device,
+         force_download,
+         cfg_overrides,
+     )
+
+     # replace folder name of 65k with 64k
+     # TODO: remove this workaround once weights are fixed
+     if "270m-pt" in repo_id:
+         if "65k" in folder_name:
+             folder_name = folder_name.replace("65k", "64k")
+         # replace folder name of 262k with 250k
+         if "262k" in folder_name:
+             folder_name = folder_name.replace("262k", "250k")
+
+     params_file = "params.safetensors"
+     if "clt" in folder_name:
+         params_file = folder_name.split("/")[-1] + ".safetensors"
+         folder_name = "/".join(folder_name.split("/")[:-1])
+
+     # Download the SAE weights
+     sae_path = hf_hub_download(
+         repo_id=repo_id,
+         filename=params_file,
+         subfolder=folder_name,
+         force_download=force_download,
+     )
+
+     raw_state_dict = load_file(sae_path, device=device)
+
+     with torch.no_grad():
+         w_dec = raw_state_dict["w_dec"]
+         if "clt" in folder_name:
+             w_dec = w_dec.sum(dim=1).contiguous()
+
+     state_dict = {
+         "W_enc": raw_state_dict["w_enc"],
+         "W_dec": w_dec,
+         "b_enc": raw_state_dict["b_enc"],
+         "b_dec": raw_state_dict["b_dec"],
+         "threshold": raw_state_dict["threshold"],
+     }
+
+     if "affine_skip_connection" in raw_state_dict:
+         state_dict["W_skip"] = raw_state_dict["affine_skip_connection"]
+
+     return cfg_dict, state_dict, None
+
+
+ def get_goodfire_config_from_hf(
+     repo_id: str,
+     folder_name: str,  # noqa: ARG001
+     device: str,
+     force_download: bool = False,  # noqa: ARG001
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+     cfg_dict = None
+     if repo_id == "Goodfire/Llama-3.3-70B-Instruct-SAE-l50":
+         if folder_name != "Llama-3.3-70B-Instruct-SAE-l50.pt":
+             raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+         cfg_dict = {
+             "architecture": "standard",
+             "d_in": 8192,
+             "d_sae": 65536,
+             "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+             "hook_name": "blocks.50.hook_resid_post",
+             "hook_head_index": None,
+             "dataset_path": "lmsys/lmsys-chat-1m",
+             "apply_b_dec_to_input": False,
+         }
+     elif repo_id == "Goodfire/Llama-3.1-8B-Instruct-SAE-l19":
+         if folder_name != "Llama-3.1-8B-Instruct-SAE-l19.pth":
+             raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+         cfg_dict = {
+             "architecture": "standard",
+             "d_in": 4096,
+             "d_sae": 65536,
+             "model_name": "meta-llama/Llama-3.1-8B-Instruct",
+             "hook_name": "blocks.19.hook_resid_post",
+             "hook_head_index": None,
+             "dataset_path": "lmsys/lmsys-chat-1m",
+             "apply_b_dec_to_input": False,
+         }
+     if cfg_dict is None:
+         raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+     if device is not None:
+         cfg_dict["device"] = device
+     if cfg_overrides is not None:
+         cfg_dict.update(cfg_overrides)
+     return cfg_dict
+
+
+ def get_goodfire_huggingface_loader(
+     repo_id: str,
+     folder_name: str,
+     device: str = "cpu",
+     force_download: bool = False,
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+     cfg_dict = get_goodfire_config_from_hf(
+         repo_id,
+         folder_name,
+         device,
+         force_download,
+         cfg_overrides,
+     )
+
+     # Download the SAE weights
+     sae_path = hf_hub_download(
+         repo_id=repo_id,
+         filename=folder_name,
+         force_download=force_download,
+     )
+     raw_state_dict = torch.load(sae_path, map_location=device)
+
+     state_dict = {
+         "W_enc": raw_state_dict["encoder_linear.weight"].T,
+         "W_dec": raw_state_dict["decoder_linear.weight"].T,
+         "b_enc": raw_state_dict["encoder_linear.bias"],
+         "b_dec": raw_state_dict["decoder_linear.bias"],
+     }
+
+     return cfg_dict, state_dict, None
+
+
  def get_llama_scope_config_from_hf(
      repo_id: str,
      folder_name: str,
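
The hunk above adds the Gemma 3 SAE loaders. Below is a minimal usage sketch of the new `gemma_3_sae_huggingface_loader`; the `repo_id` and `folder_name` values are illustrative placeholders showing the naming scheme the loader parses (a `layer_<n>` segment plus one of `resid_post`/`attn_out`/`mlp_out`/`transcoder`/`clt`), not verified release names.

```python
# Sketch only: repo_id/folder_name are hypothetical examples of the naming
# scheme parsed by the loader; substitute a real gemma-scope-2 repo/folder.
cfg_dict, state_dict, _ = gemma_3_sae_huggingface_loader(
    repo_id="google/gemma-scope-2-270m-pt",      # assumed repo name
    folder_name="resid_post/layer_6_width_16k",  # assumed folder layout
    device="cpu",
)
print(cfg_dict["architecture"])   # "jumprelu" for non-transcoder folders
print(state_dict["W_enc"].shape)  # (d_in, d_sae), read from the safetensors header
```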
@@ -677,10 +955,14 @@ def get_dictionary_learning_config_1_from_hf(
      activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
      activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}

+     architecture = "standard"
+     if trainer["dict_class"] == "GatedAutoEncoder":
+         architecture = "gated"
+     elif trainer["dict_class"] == "MatryoshkaBatchTopKSAE":
+         architecture = "jumprelu"
+
      return {
-         "architecture": (
-             "gated" if trainer["dict_class"] == "GatedAutoEncoder" else "standard"
-         ),
+         "architecture": architecture,
          "d_in": trainer["activation_dim"],
          "d_sae": trainer["dict_size"],
          "dtype": "float32",
@@ -829,9 +1111,12 @@ def dictionary_learning_sae_huggingface_loader_1(
      )
      encoder = torch.load(encoder_path, map_location="cpu")

+     W_enc = encoder["W_enc"] if "W_enc" in encoder else encoder["encoder.weight"].T
+     W_dec = encoder["W_dec"] if "W_dec" in encoder else encoder["decoder.weight"].T
+
      state_dict = {
-         "W_enc": encoder["encoder.weight"].T,
-         "W_dec": encoder["decoder.weight"].T,
+         "W_enc": W_enc,
+         "W_dec": W_dec,
          "b_dec": encoder.get(
              "b_dec", encoder.get("bias", encoder.get("decoder_bias", None))
          ),
@@ -839,6 +1124,8 @@ def dictionary_learning_sae_huggingface_loader_1(

      if "encoder.bias" in encoder:
          state_dict["b_enc"] = encoder["encoder.bias"]
+     if "b_enc" in encoder:
+         state_dict["b_enc"] = encoder["b_enc"]

      if "mag_bias" in encoder:
          state_dict["b_mag"] = encoder["mag_bias"]
@@ -847,6 +1134,12 @@ def dictionary_learning_sae_huggingface_loader_1(
      if "r_mag" in encoder:
          state_dict["r_mag"] = encoder["r_mag"]

+     if "threshold" in encoder:
+         threshold = encoder["threshold"]
+         if threshold.ndim == 0:
+             threshold = torch.full((W_enc.size(1),), threshold)
+         state_dict["threshold"] = threshold
+
      return cfg_dict, state_dict, None


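The `threshold` handling added above expands a scalar (0-dim) threshold into a per-latent vector so downstream JumpReLU code can index one threshold per feature. A standalone sketch of that expansion, with illustrative values:

```python
import torch

d_sae = 8  # illustrative latent count; the loader uses W_enc.size(1)
threshold = torch.tensor(0.05)  # scalar threshold as stored in some checkpoints

if threshold.ndim == 0:
    # Broadcast the single value to one threshold per latent.
    threshold = torch.full((d_sae,), threshold.item())

print(threshold.shape)  # torch.Size([8])
```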
@@ -1338,38 +1631,36 @@ def mwhanna_transcoder_huggingface_loader(
      return cfg_dict, state_dict, None


- def get_safetensors_tensor_shapes(url: str) -> dict[str, list[int]]:
+ def get_safetensors_tensor_shapes(repo_id: str, filename: str) -> dict[str, list[int]]:
      """
-     Get tensor shapes from a safetensors file using HTTP range requests
+     Get tensor shapes from a safetensors file on HuggingFace Hub
      without downloading the entire file.

+     Uses HTTP range requests to fetch only the metadata header.
+
      Args:
-         url: Direct URL to the safetensors file
+         repo_id: HuggingFace repo ID (e.g., "gg-gs/gemma-scope-2-1b-pt")
+         filename: Path to the safetensors file within the repo

      Returns:
          Dictionary mapping tensor names to their shapes
      """
-     # Check if server supports range requests
-     response = requests.head(url, timeout=10)
-     response.raise_for_status()
+     url = hf_hub_url(repo_id, filename)

-     accept_ranges = response.headers.get("Accept-Ranges", "")
-     if "bytes" not in accept_ranges:
-         raise ValueError("Server does not support range requests")
+     # Get HuggingFace headers (includes auth token if available)
+     hf_headers = build_hf_headers()

      # Fetch first 8 bytes to get metadata size
-     headers = {"Range": "bytes=0-7"}
+     headers = {**hf_headers, "Range": "bytes=0-7"}
      response = requests.get(url, headers=headers, timeout=10)
-     if response.status_code != 206:
-         raise ValueError("Failed to fetch initial bytes for metadata size")
+     response.raise_for_status()

      meta_size = int.from_bytes(response.content, byteorder="little")

      # Fetch the metadata header
-     headers = {"Range": f"bytes=8-{8 + meta_size - 1}"}
+     headers = {**hf_headers, "Range": f"bytes=8-{8 + meta_size - 1}"}
      response = requests.get(url, headers=headers, timeout=10)
-     if response.status_code != 206:
-         raise ValueError("Failed to fetch metadata header")
+     response.raise_for_status()

      metadata_json = response.content.decode("utf-8").strip()
      metadata = json.loads(metadata_json)
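
The rewritten `get_safetensors_tensor_shapes` relies on the safetensors file layout: the first 8 bytes hold a little-endian header length, followed by a JSON header recording each tensor's dtype, shape, and data offsets. A local sketch of the same parsing, with no network access; file and tensor names are made up for illustration:

```python
import json

import torch
from safetensors.torch import save_file

# Write a small file so the header can be inspected directly.
save_file({"w_enc": torch.zeros(16, 64), "w_dec": torch.zeros(64, 16)}, "demo.safetensors")

with open("demo.safetensors", "rb") as f:
    meta_size = int.from_bytes(f.read(8), byteorder="little")  # header length
    header = json.loads(f.read(meta_size).decode("utf-8").strip())

shapes = {name: info["shape"] for name, info in header.items() if name != "__metadata__"}
print(shapes)  # e.g. {"w_enc": [16, 64], "w_dec": [64, 16]}
```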
@@ -1449,9 +1740,10 @@ def get_mntss_clt_layer_config_from_hf(
      with open(base_config_path) as f:
          cfg_info: dict[str, Any] = yaml.safe_load(f)

-     # Get tensor shapes without downloading full files using HTTP range requests
-     encoder_url = hf_hub_url(repo_id, f"W_enc_{folder_name}.safetensors")
-     encoder_shapes = get_safetensors_tensor_shapes(encoder_url)
+     # Get tensor shapes without downloading full files
+     encoder_shapes = get_safetensors_tensor_shapes(
+         repo_id, f"W_enc_{folder_name}.safetensors"
+     )

      # Extract shapes for the required tensors
      b_dec_shape = encoder_shapes[f"b_dec_{folder_name}"]
@@ -1475,10 +1767,119 @@ def get_mntss_clt_layer_config_from_hf(
      }


+ def get_temporal_sae_config_from_hf(
+     repo_id: str,
+     folder_name: str,
+     device: str,
+     force_download: bool = False,
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+     """Get TemporalSAE config without loading weights."""
+     # Download config file
+     conf_path = hf_hub_download(
+         repo_id=repo_id,
+         filename=f"{folder_name}/conf.yaml",
+         force_download=force_download,
+     )
+
+     # Load and parse config
+     with open(conf_path) as f:
+         yaml_config = yaml.safe_load(f)
+
+     # Extract parameters
+     d_in = yaml_config["llm"]["dimin"]
+     exp_factor = yaml_config["sae"]["exp_factor"]
+     d_sae = int(d_in * exp_factor)
+
+     # extract layer from folder_name eg : "layer_12/temporal"
+     layer = re.search(r"layer_(\d+)", folder_name)
+     if layer is None:
+         raise ValueError(f"Could not find layer in folder_name: {folder_name}")
+     layer = int(layer.group(1))
+
+     # Build config dict
+     cfg_dict = {
+         "architecture": "temporal",
+         "hook_name": f"blocks.{layer}.hook_resid_post",
+         "d_in": d_in,
+         "d_sae": d_sae,
+         "n_heads": yaml_config["sae"]["n_heads"],
+         "n_attn_layers": yaml_config["sae"]["n_attn_layers"],
+         "bottleneck_factor": yaml_config["sae"]["bottleneck_factor"],
+         "sae_diff_type": yaml_config["sae"]["sae_diff_type"],
+         "kval_topk": yaml_config["sae"]["kval_topk"],
+         "tied_weights": yaml_config["sae"]["tied_weights"],
+         "dtype": yaml_config["data"]["dtype"],
+         "device": device,
+         "normalize_activations": "constant_scalar_rescale",
+         "activation_normalization_factor": yaml_config["sae"]["scaling_factor"],
+         "apply_b_dec_to_input": True,
+     }
+
+     if cfg_overrides:
+         cfg_dict.update(cfg_overrides)
+
+     return cfg_dict
+
+
+ def temporal_sae_huggingface_loader(
+     repo_id: str,
+     folder_name: str,
+     device: str = "cpu",
+     force_download: bool = False,
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+     """
+     Load TemporalSAE from canrager/temporalSAEs format (safetensors version).
+
+     Expects folder_name to contain:
+     - conf.yaml (configuration)
+     - latest_ckpt.safetensors (model weights)
+     """
+
+     cfg_dict = get_temporal_sae_config_from_hf(
+         repo_id=repo_id,
+         folder_name=folder_name,
+         device=device,
+         force_download=force_download,
+         cfg_overrides=cfg_overrides,
+     )
+
+     # Download checkpoint (safetensors format)
+     ckpt_path = hf_hub_download(
+         repo_id=repo_id,
+         filename=f"{folder_name}/latest_ckpt.safetensors",
+         force_download=force_download,
+     )
+
+     # Load checkpoint from safetensors
+     state_dict_raw = load_file(ckpt_path, device=device)
+
+     # Convert to SAELens naming convention
+     # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
+     state_dict = {}
+
+     # Copy attention layers as-is
+     for key, value in state_dict_raw.items():
+         if key.startswith("attn_layers."):
+             state_dict[key] = value.to(device)
+
+     # Main parameters
+     state_dict["W_dec"] = state_dict_raw["D"].to(device)
+     state_dict["b_dec"] = state_dict_raw["b"].to(device)
+
+     # Handle tied/untied weights
+     if "E" in state_dict_raw:
+         state_dict["W_enc"] = state_dict_raw["E"].to(device)
+
+     return cfg_dict, state_dict, None
+
+
  NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
      "sae_lens": sae_lens_huggingface_loader,
      "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
      "gemma_2": gemma_2_sae_huggingface_loader,
+     "gemma_3": gemma_3_sae_huggingface_loader,
      "llama_scope": llama_scope_sae_huggingface_loader,
      "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
      "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
@@ -1487,6 +1888,8 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
      "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
      "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
      "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
+     "temporal": temporal_sae_huggingface_loader,
+     "goodfire": get_goodfire_huggingface_loader,
  }


@@ -1494,6 +1897,7 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
      "sae_lens": get_sae_lens_config_from_hf,
      "connor_rob_hook_z": get_connor_rob_hook_z_config_from_hf,
      "gemma_2": get_gemma_2_config_from_hf,
+     "gemma_3": get_gemma_3_config_from_hf,
      "llama_scope": get_llama_scope_config_from_hf,
      "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
      "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
@@ -1502,4 +1906,6 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
      "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
      "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
      "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
+     "temporal": get_temporal_sae_config_from_hf,
+     "goodfire": get_goodfire_config_from_hf,
  }
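
With the registry entries above in place, any of the new formats can be loaded by name. A sketch using the `goodfire` loader follows; the repo and file names are taken from the config getter in this diff, the snippet assumes the names defined in this module are in scope, and actually running it requires downloading the checkpoint from HuggingFace.

```python
# Sketch: look up the loader registered above and call it directly.
loader = NAMED_PRETRAINED_SAE_LOADERS["goodfire"]
cfg_dict, state_dict, _ = loader(
    repo_id="Goodfire/Llama-3.1-8B-Instruct-SAE-l19",
    folder_name="Llama-3.1-8B-Instruct-SAE-l19.pth",
    device="cpu",
)
print(cfg_dict["hook_name"])      # blocks.19.hook_resid_post
print(state_dict["W_enc"].shape)  # (4096, 65536) given the config's d_in/d_sae
```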
@@ -1,6 +1,6 @@
  from dataclasses import dataclass
  from functools import cache
- from importlib import resources
+ from importlib.resources import files
  from typing import Any

  import yaml
@@ -24,7 +24,8 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
      package = "sae_lens"
      # Access the file within the package using importlib.resources
      directory: dict[str, PretrainedSAELookup] = {}
-     with resources.open_text(package, "pretrained_saes.yaml") as file:
+     yaml_file = files(package).joinpath("pretrained_saes.yaml")
+     with yaml_file.open("r") as file:
          # Load the YAML file content
          data = yaml.safe_load(file)
          for release, value in data.items():
@@ -68,7 +69,8 @@ def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
          float | None: The norm_scaling_factor if it exists, None otherwise.
      """
      package = "sae_lens"
-     with resources.open_text(package, "pretrained_saes.yaml") as file:
+     yaml_file = files(package).joinpath("pretrained_saes.yaml")
+     with yaml_file.open("r") as file:
          data = yaml.safe_load(file)
          if release in data:
              for sae_info in data[release]["saes"]:
@@ -186,13 +186,13 @@ class PretokenizeRunner:
          """
          Load the dataset, tokenize it, and save it to disk and/or upload to Huggingface.
          """
-         dataset = load_dataset(
+         dataset = load_dataset(  # type: ignore
              self.cfg.dataset_path,
              name=self.cfg.dataset_name,
              data_dir=self.cfg.data_dir,
              data_files=self.cfg.data_files,
-             split=self.cfg.split,
-             streaming=self.cfg.streaming,
+             split=self.cfg.split,  # type: ignore
+             streaming=self.cfg.streaming,  # type: ignore
          )
          if isinstance(dataset, DatasetDict):
              raise ValueError(