sae-lens 6.24.0.tar.gz → 6.26.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {sae_lens-6.24.0 → sae_lens-6.26.0}/PKG-INFO +1 -1
  2. {sae_lens-6.24.0 → sae_lens-6.26.0}/pyproject.toml +1 -1
  3. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/__init__.py +13 -1
  4. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/cache_activations_runner.py +2 -2
  5. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/config.py +7 -2
  6. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/constants.py +8 -0
  7. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/loading/pretrained_sae_loaders.py +66 -66
  8. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/pretrained_saes.yaml +160 -144
  9. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/__init__.py +10 -0
  10. sae_lens-6.26.0/sae_lens/saes/matching_pursuit_sae.py +334 -0
  11. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/sae.py +52 -12
  12. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/activations_store.py +3 -2
  13. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/util.py +21 -0
  14. {sae_lens-6.24.0 → sae_lens-6.26.0}/LICENSE +0 -0
  15. {sae_lens-6.24.0 → sae_lens-6.26.0}/README.md +0 -0
  16. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/analysis/__init__.py +0 -0
  17. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  18. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  19. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/evals.py +0 -0
  20. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/llm_sae_training_runner.py +0 -0
  21. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/load_model.py +0 -0
  22. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/loading/__init__.py +0 -0
  23. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  24. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/pretokenize_runner.py +0 -0
  25. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/registry.py +0 -0
  26. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  27. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/gated_sae.py +0 -0
  28. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/jumprelu_sae.py +0 -0
  29. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
  30. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/standard_sae.py +0 -0
  31. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/temporal_sae.py +0 -0
  32. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/topk_sae.py +0 -0
  33. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/saes/transcoder.py +0 -0
  34. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/tokenization_and_batching.py +0 -0
  35. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/__init__.py +0 -0
  36. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/activation_scaler.py +0 -0
  37. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/mixing_buffer.py +0 -0
  38. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/optim.py +0 -0
  39. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/sae_trainer.py +0 -0
  40. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/types.py +0 -0
  41. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  42. {sae_lens-6.24.0 → sae_lens-6.26.0}/sae_lens/tutorial/tsea.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sae-lens
- Version: 6.24.0
+ Version: 6.26.0
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
  License: MIT
  License-File: LICENSE
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "sae-lens"
- version = "6.24.0"
+ version = "6.26.0"
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
  authors = ["Joseph Bloom"]
  readme = "README.md"
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.24.0"
+ __version__ = "6.26.0"

  import logging

@@ -21,6 +21,10 @@ from sae_lens.saes import (
      JumpReLUTrainingSAEConfig,
      JumpReLUTranscoder,
      JumpReLUTranscoderConfig,
+     MatchingPursuitSAE,
+     MatchingPursuitSAEConfig,
+     MatchingPursuitTrainingSAE,
+     MatchingPursuitTrainingSAEConfig,
      MatryoshkaBatchTopKTrainingSAE,
      MatryoshkaBatchTopKTrainingSAEConfig,
      SAEConfig,
@@ -113,6 +117,10 @@ __all__ = [
      "MatryoshkaBatchTopKTrainingSAEConfig",
      "TemporalSAE",
      "TemporalSAEConfig",
+     "MatchingPursuitSAE",
+     "MatchingPursuitTrainingSAE",
+     "MatchingPursuitSAEConfig",
+     "MatchingPursuitTrainingSAEConfig",
  ]


@@ -139,3 +147,7 @@ register_sae_class(
      "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
  )
  register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
+ register_sae_class("matching_pursuit", MatchingPursuitSAE, MatchingPursuitSAEConfig)
+ register_sae_training_class(
+     "matching_pursuit", MatchingPursuitTrainingSAE, MatchingPursuitTrainingSAEConfig
+ )
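Note: the registrations above are what make the new architecture resolvable by name. A hedged sketch of the effect, assuming `get_sae_class` / `get_sae_training_class` return the registered (class, config class) pairs as they are used elsewhere in SAELens:

```python
# Hedged sketch: after the registrations above, the "matching_pursuit" architecture string
# resolves to the new classes. Assumes get_sae_class / get_sae_training_class return the
# registered (class, config class) pair.
from sae_lens.registry import get_sae_class, get_sae_training_class

sae_cls, sae_cfg_cls = get_sae_class("matching_pursuit")
training_cls, training_cfg_cls = get_sae_training_class("matching_pursuit")
print(sae_cls.__name__, training_cls.__name__)  # MatchingPursuitSAE MatchingPursuitTrainingSAE
```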
@@ -14,9 +14,9 @@ from transformer_lens.HookedTransformer import HookedRootModule

  from sae_lens import logger
  from sae_lens.config import CacheActivationsRunnerConfig
- from sae_lens.constants import DTYPE_MAP
  from sae_lens.load_model import load_model
  from sae_lens.training.activations_store import ActivationsStore
+ from sae_lens.util import str_to_dtype


  def _mk_activations_store(
@@ -97,7 +97,7 @@ class CacheActivationsRunner:
          bytes_per_token = (
              self.cfg.d_in * self.cfg.dtype.itemsize
              if isinstance(self.cfg.dtype, torch.dtype)
-             else DTYPE_MAP[self.cfg.dtype].itemsize
+             else str_to_dtype(self.cfg.dtype).itemsize
          )
          total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
          total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9
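For context, the disk-space estimate above is straightforward arithmetic; a small worked example with assumed values (the hidden size, token count, and dtype are illustrative only):

```python
# Worked example with assumed values; mirrors the estimate in the hunk above.
import torch

d_in = 768  # hypothetical hidden size
dtype = torch.float32  # 4 bytes per element
total_training_tokens = 100_000_000  # hypothetical cached-token count

bytes_per_token = d_in * dtype.itemsize  # 768 * 4 = 3072 bytes
total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9  # ~307.2 GB
print(f"{total_disk_space_gb:.1f} GB")
```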
@@ -17,9 +17,14 @@ from datasets import (
  )

  from sae_lens import __version__, logger
- from sae_lens.constants import DTYPE_MAP
+
+ # keeping this unused import since some SAELens deps import DTYPE_MAP from config
+ from sae_lens.constants import (
+     DTYPE_MAP,  # noqa: F401 # pyright: ignore[reportUnusedImport]
+ )
  from sae_lens.registry import get_sae_training_class
  from sae_lens.saes.sae import TrainingSAEConfig
+ from sae_lens.util import str_to_dtype

  if TYPE_CHECKING:
      pass
@@ -563,7 +568,7 @@ class CacheActivationsRunnerConfig:

      @property
      def bytes_per_token(self) -> int:
-         return self.d_in * DTYPE_MAP[self.dtype].itemsize
+         return self.d_in * str_to_dtype(self.dtype).itemsize

      @property
      def n_tokens_in_buffer(self) -> int:
@@ -11,6 +11,14 @@ DTYPE_MAP = {
      "torch.bfloat16": torch.bfloat16,
  }

+ # Reverse mapping from torch.dtype to canonical string format
+ DTYPE_TO_STR = {
+     torch.float32: "float32",
+     torch.float64: "float64",
+     torch.float16: "float16",
+     torch.bfloat16: "bfloat16",
+ }
+

  SPARSITY_FILENAME = "sparsity.safetensors"
  SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
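The new `DTYPE_TO_STR` table is the inverse of `DTYPE_MAP`, and the `str_to_dtype` helper used throughout this release lives in `sae_lens/util.py` (its body is not shown in this diff). A minimal sketch of what such a helper could look like, assuming it accepts both plain and `torch.`-prefixed strings; this is illustrative, not the actual implementation:

```python
# Illustrative sketch only; the real str_to_dtype lives in sae_lens/util.py and is not shown here.
import torch

# stand-in lookup table (hypothetical); the real code presumably reuses DTYPE_MAP from sae_lens.constants
_DTYPE_BY_NAME = {
    "float32": torch.float32,
    "float64": torch.float64,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "torch.float32": torch.float32,
    "torch.float64": torch.float64,
    "torch.float16": torch.float16,
    "torch.bfloat16": torch.bfloat16,
}


def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
    """Resolve a dtype from either a torch.dtype or one of the string spellings above."""
    if isinstance(dtype, torch.dtype):
        return dtype
    return _DTYPE_BY_NAME[dtype]


assert str_to_dtype("float32").itemsize == 4
assert str_to_dtype(torch.bfloat16) == torch.bfloat16
```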
@@ -12,11 +12,9 @@ from huggingface_hub import hf_hub_download, hf_hub_url
  from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
  from packaging.version import Version
  from safetensors import safe_open
- from safetensors.torch import load_file

  from sae_lens import logger
  from sae_lens.constants import (
-     DTYPE_MAP,
      SAE_CFG_FILENAME,
      SAE_WEIGHTS_FILENAME,
      SPARSIFY_WEIGHTS_FILENAME,
@@ -28,7 +26,7 @@ from sae_lens.loading.pretrained_saes_directory import (
      get_repo_id_and_folder_name,
  )
  from sae_lens.registry import get_sae_class
- from sae_lens.util import filter_valid_dataclass_fields
+ from sae_lens.util import filter_valid_dataclass_fields, str_to_dtype

  LLM_METADATA_KEYS = {
      "model_name",
@@ -51,6 +49,21 @@ LLM_METADATA_KEYS = {
  }


+ def load_safetensors_weights(
+     path: str | Path, device: str = "cpu", dtype: torch.dtype | str | None = None
+ ) -> dict[str, torch.Tensor]:
+     """Load safetensors weights and optionally convert to a different dtype"""
+     loaded_weights = {}
+     dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
+     with safe_open(path, framework="pt", device=device) as f:
+         for k in f.keys():  # noqa: SIM118
+             weight = f.get_tensor(k)
+             if dtype is not None:
+                 weight = weight.to(dtype=dtype)
+             loaded_weights[k] = weight
+     return loaded_weights
+
+
  # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
  class PretrainedSaeHuggingfaceLoader(Protocol):
      def __call__(
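A short usage sketch of the new `load_safetensors_weights` helper; the file path and dtype below are illustrative only:

```python
# Usage sketch; the path is a hypothetical local checkpoint, and the dtype argument is optional.
from sae_lens.loading.pretrained_sae_loaders import load_safetensors_weights

weights = load_safetensors_weights(
    "sae_weights.safetensors",  # hypothetical path to a safetensors checkpoint
    device="cpu",
    dtype="bfloat16",  # strings are resolved via str_to_dtype; pass None to keep stored dtypes
)
print({name: tensor.dtype for name, tensor in weights.items()})
```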
@@ -341,7 +354,7 @@ def read_sae_components_from_disk(
      Given a loaded dictionary and a path to a weight file, load the weights and return the state_dict.
      """
      if dtype is None:
-         dtype = DTYPE_MAP[cfg_dict["dtype"]]
+         dtype = str_to_dtype(cfg_dict["dtype"])

      state_dict = {}
      with safe_open(weight_path, framework="pt", device=device) as f:  # type: ignore
@@ -682,15 +695,6 @@ def gemma_3_sae_huggingface_loader(
          cfg_overrides,
      )

-     # replace folder name of 65k with 64k
-     # TODO: remove this workaround once weights are fixed
-     if "270m-pt" in repo_id:
-         if "65k" in folder_name:
-             folder_name = folder_name.replace("65k", "64k")
-         # replace folder name of 262k with 250k
-         if "262k" in folder_name:
-             folder_name = folder_name.replace("262k", "250k")
-
      params_file = "params.safetensors"
      if "clt" in folder_name:
          params_file = folder_name.split("/")[-1] + ".safetensors"
@@ -704,7 +708,9 @@ def gemma_3_sae_huggingface_loader(
          force_download=force_download,
      )

-     raw_state_dict = load_file(sae_path, device=device)
+     raw_state_dict = load_safetensors_weights(
+         sae_path, device=device, dtype=cfg_dict.get("dtype")
+     )

      with torch.no_grad():
          w_dec = raw_state_dict["w_dec"]
@@ -791,11 +797,13 @@ def get_goodfire_huggingface_loader(
      )
      raw_state_dict = torch.load(sae_path, map_location=device)

+     target_dtype = str_to_dtype(cfg_dict.get("dtype", "float32"))
+
      state_dict = {
-         "W_enc": raw_state_dict["encoder_linear.weight"].T,
-         "W_dec": raw_state_dict["decoder_linear.weight"].T,
-         "b_enc": raw_state_dict["encoder_linear.bias"],
-         "b_dec": raw_state_dict["decoder_linear.bias"],
+         "W_enc": raw_state_dict["encoder_linear.weight"].T.to(dtype=target_dtype),
+         "W_dec": raw_state_dict["decoder_linear.weight"].T.to(dtype=target_dtype),
+         "b_enc": raw_state_dict["encoder_linear.bias"].to(dtype=target_dtype),
+         "b_dec": raw_state_dict["decoder_linear.bias"].to(dtype=target_dtype),
      }

      return cfg_dict, state_dict, None
@@ -898,26 +906,19 @@ def llama_scope_sae_huggingface_loader(
          force_download=force_download,
      )

-     # Load the weights using load_file instead of safe_open
-     state_dict_loaded = load_file(sae_path, device=device)
+     state_dict_loaded = load_safetensors_weights(
+         sae_path, device=device, dtype=cfg_dict.get("dtype")
+     )

      # Convert and organize the weights
      state_dict = {
-         "W_enc": state_dict_loaded["encoder.weight"]
-         .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-         .T,
-         "W_dec": state_dict_loaded["decoder.weight"]
-         .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-         .T,
-         "b_enc": state_dict_loaded["encoder.bias"].to(
-             dtype=DTYPE_MAP[cfg_dict["dtype"]]
-         ),
-         "b_dec": state_dict_loaded["decoder.bias"].to(
-             dtype=DTYPE_MAP[cfg_dict["dtype"]]
-         ),
+         "W_enc": state_dict_loaded["encoder.weight"].T,
+         "W_dec": state_dict_loaded["decoder.weight"].T,
+         "b_enc": state_dict_loaded["encoder.bias"],
+         "b_dec": state_dict_loaded["decoder.bias"],
          "threshold": torch.ones(
              cfg_dict["d_sae"],
-             dtype=DTYPE_MAP[cfg_dict["dtype"]],
+             dtype=str_to_dtype(cfg_dict["dtype"]),
              device=cfg_dict["device"],
          )
          * cfg_dict["jump_relu_threshold"],
@@ -1228,26 +1229,17 @@ def llama_scope_r1_distill_sae_huggingface_loader(
          force_download=force_download,
      )

-     # Load the weights using load_file instead of safe_open
-     state_dict_loaded = load_file(sae_path, device=device)
+     state_dict_loaded = load_safetensors_weights(
+         sae_path, device=device, dtype=cfg_dict.get("dtype")
+     )

      # Convert and organize the weights
      state_dict = {
-         "W_enc": state_dict_loaded["encoder.weight"]
-         .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-         .T,
-         "W_dec": state_dict_loaded["decoder.weight"]
-         .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-         .T,
-         "b_enc": state_dict_loaded["encoder.bias"].to(
-             dtype=DTYPE_MAP[cfg_dict["dtype"]]
-         ),
-         "b_dec": state_dict_loaded["decoder.bias"].to(
-             dtype=DTYPE_MAP[cfg_dict["dtype"]]
-         ),
-         "threshold": state_dict_loaded["log_jumprelu_threshold"]
-         .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-         .exp(),
+         "W_enc": state_dict_loaded["encoder.weight"].T,
+         "W_dec": state_dict_loaded["decoder.weight"].T,
+         "b_enc": state_dict_loaded["encoder.bias"],
+         "b_dec": state_dict_loaded["decoder.bias"],
+         "threshold": state_dict_loaded["log_jumprelu_threshold"].exp(),
      }

      # No sparsity tensor for Llama Scope SAEs
@@ -1367,34 +1359,34 @@ def sparsify_disk_loader(
      cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)

      weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
-     state_dict_loaded = load_file(weight_path, device=device)
-
-     dtype = DTYPE_MAP[cfg_dict["dtype"]]
+     state_dict_loaded = load_safetensors_weights(
+         weight_path, device=device, dtype=cfg_dict.get("dtype")
+     )

      W_enc = (
          state_dict_loaded["W_enc"]
          if "W_enc" in state_dict_loaded
          else state_dict_loaded["encoder.weight"].T
-     ).to(dtype)
+     )

      if "W_dec" in state_dict_loaded:
-         W_dec = state_dict_loaded["W_dec"].T.to(dtype)
+         W_dec = state_dict_loaded["W_dec"].T
      else:
-         W_dec = state_dict_loaded["decoder.weight"].T.to(dtype)
+         W_dec = state_dict_loaded["decoder.weight"].T

      if "b_enc" in state_dict_loaded:
-         b_enc = state_dict_loaded["b_enc"].to(dtype)
+         b_enc = state_dict_loaded["b_enc"]
      elif "encoder.bias" in state_dict_loaded:
-         b_enc = state_dict_loaded["encoder.bias"].to(dtype)
+         b_enc = state_dict_loaded["encoder.bias"]
      else:
-         b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+         b_enc = torch.zeros(cfg_dict["d_sae"], dtype=W_dec.dtype, device=device)

      if "b_dec" in state_dict_loaded:
-         b_dec = state_dict_loaded["b_dec"].to(dtype)
+         b_dec = state_dict_loaded["b_dec"]
      elif "decoder.bias" in state_dict_loaded:
-         b_dec = state_dict_loaded["decoder.bias"].to(dtype)
+         b_dec = state_dict_loaded["decoder.bias"]
      else:
-         b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+         b_dec = torch.zeros(cfg_dict["d_in"], dtype=W_dec.dtype, device=device)

      state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
      return cfg_dict, state_dict
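One behavioral detail in the hunk above: when a sparsify checkpoint carries no bias entries, the zero vectors now inherit `W_dec`'s dtype rather than the config's dtype string. A tiny illustration with made-up shapes:

```python
# Made-up shapes, purely illustrative: missing biases default to W_dec's dtype.
import torch

W_dec = torch.randn(16, 8, dtype=torch.bfloat16)  # hypothetical d_sae=16, d_in=8
b_enc = torch.zeros(16, dtype=W_dec.dtype)  # follows the loaded weights (bfloat16)
b_dec = torch.zeros(8, dtype=W_dec.dtype)
assert b_enc.dtype == W_dec.dtype == torch.bfloat16
```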
@@ -1625,7 +1617,9 @@ def mwhanna_transcoder_huggingface_loader(
      )

      # Load weights from safetensors
-     state_dict = load_file(file_path, device=device)
+     state_dict = load_safetensors_weights(
+         file_path, device=device, dtype=cfg_dict.get("dtype")
+     )
      state_dict["W_enc"] = state_dict["W_enc"].T

      return cfg_dict, state_dict, None
@@ -1709,8 +1703,12 @@ def mntss_clt_layer_huggingface_loader(
          force_download=force_download,
      )

-     encoder_state_dict = load_file(encoder_path, device=device)
-     decoder_state_dict = load_file(decoder_path, device=device)
+     encoder_state_dict = load_safetensors_weights(
+         encoder_path, device=device, dtype=cfg_dict.get("dtype")
+     )
+     decoder_state_dict = load_safetensors_weights(
+         decoder_path, device=device, dtype=cfg_dict.get("dtype")
+     )

      with torch.no_grad():
          state_dict = {
@@ -1853,7 +1851,9 @@ def temporal_sae_huggingface_loader(
      )

      # Load checkpoint from safetensors
-     state_dict_raw = load_file(ckpt_path, device=device)
+     state_dict_raw = load_safetensors_weights(
+         ckpt_path, device=device, dtype=cfg_dict.get("dtype")
+     )

      # Convert to SAELens naming convention
      # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*