sae-lens 6.24.1__py3-none-any.whl → 6.25.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.24.1"
+ __version__ = "6.25.1"

  import logging

sae_lens/cache_activations_runner.py CHANGED
@@ -14,9 +14,9 @@ from transformer_lens.HookedTransformer import HookedRootModule

  from sae_lens import logger
  from sae_lens.config import CacheActivationsRunnerConfig
- from sae_lens.constants import DTYPE_MAP
  from sae_lens.load_model import load_model
  from sae_lens.training.activations_store import ActivationsStore
+ from sae_lens.util import str_to_dtype


  def _mk_activations_store(
@@ -97,7 +97,7 @@ class CacheActivationsRunner:
  bytes_per_token = (
  self.cfg.d_in * self.cfg.dtype.itemsize
  if isinstance(self.cfg.dtype, torch.dtype)
- else DTYPE_MAP[self.cfg.dtype].itemsize
+ else str_to_dtype(self.cfg.dtype).itemsize
  )
  total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
  total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9
sae_lens/config.py CHANGED
@@ -17,9 +17,9 @@ from datasets import (
  )

  from sae_lens import __version__, logger
- from sae_lens.constants import DTYPE_MAP
  from sae_lens.registry import get_sae_training_class
  from sae_lens.saes.sae import TrainingSAEConfig
+ from sae_lens.util import str_to_dtype

  if TYPE_CHECKING:
  pass
@@ -563,7 +563,7 @@ class CacheActivationsRunnerConfig:

  @property
  def bytes_per_token(self) -> int:
- return self.d_in * DTYPE_MAP[self.dtype].itemsize
+ return self.d_in * str_to_dtype(self.dtype).itemsize

  @property
  def n_tokens_in_buffer(self) -> int:
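Note: the bytes_per_token property above is just d_in times the element size of the configured dtype. A minimal sketch of the same arithmetic, using a made-up d_in of 768 purely for illustration:

    import torch
    from sae_lens.util import str_to_dtype

    d_in = 768  # illustrative value, not taken from any real config
    dtype = str_to_dtype("float32")          # torch.float32
    bytes_per_token = d_in * dtype.itemsize  # 768 * 4 = 3072 bytes
    print(bytes_per_token)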
sae_lens/constants.py CHANGED
@@ -11,6 +11,14 @@ DTYPE_MAP = {
  "torch.bfloat16": torch.bfloat16,
  }

+ # Reverse mapping from torch.dtype to canonical string format
+ DTYPE_TO_STR = {
+ torch.float32: "float32",
+ torch.float64: "float64",
+ torch.float16: "float16",
+ torch.bfloat16: "bfloat16",
+ }
+

  SPARSITY_FILENAME = "sparsity.safetensors"
  SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
sae_lens/loading/pretrained_sae_loaders.py CHANGED
@@ -12,11 +12,9 @@ from huggingface_hub import hf_hub_download, hf_hub_url
  from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
  from packaging.version import Version
  from safetensors import safe_open
- from safetensors.torch import load_file

  from sae_lens import logger
  from sae_lens.constants import (
- DTYPE_MAP,
  SAE_CFG_FILENAME,
  SAE_WEIGHTS_FILENAME,
  SPARSIFY_WEIGHTS_FILENAME,
@@ -28,7 +26,7 @@ from sae_lens.loading.pretrained_saes_directory import (
  get_repo_id_and_folder_name,
  )
  from sae_lens.registry import get_sae_class
- from sae_lens.util import filter_valid_dataclass_fields
+ from sae_lens.util import filter_valid_dataclass_fields, str_to_dtype

  LLM_METADATA_KEYS = {
  "model_name",
@@ -51,6 +49,21 @@ LLM_METADATA_KEYS = {
  }


+ def load_safetensors_weights(
+ path: str | Path, device: str = "cpu", dtype: torch.dtype | str | None = None
+ ) -> dict[str, torch.Tensor]:
+ """Load safetensors weights and optionally convert to a different dtype"""
+ loaded_weights = {}
+ dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
+ with safe_open(path, framework="pt", device=device) as f:
+ for k in f.keys(): # noqa: SIM118
+ weight = f.get_tensor(k)
+ if dtype is not None:
+ weight = weight.to(dtype=dtype)
+ loaded_weights[k] = weight
+ return loaded_weights
+
+
  # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
  class PretrainedSaeHuggingfaceLoader(Protocol):
  def __call__(
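Note: the new load_safetensors_weights helper added above replaces safetensors' load_file throughout the loaders so tensors can be cast to a target dtype as they are read. A minimal usage sketch, assuming a hypothetical local file weights.safetensors:

    import torch
    from sae_lens.loading.pretrained_sae_loaders import load_safetensors_weights

    # "weights.safetensors" is a placeholder path for illustration.
    state_dict = load_safetensors_weights(
        "weights.safetensors", device="cpu", dtype="bfloat16"
    )
    assert all(t.dtype == torch.bfloat16 for t in state_dict.values())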
@@ -341,7 +354,7 @@ def read_sae_components_from_disk(
  Given a loaded dictionary and a path to a weight file, load the weights and return the state_dict.
  """
  if dtype is None:
- dtype = DTYPE_MAP[cfg_dict["dtype"]]
+ dtype = str_to_dtype(cfg_dict["dtype"])

  state_dict = {}
  with safe_open(weight_path, framework="pt", device=device) as f:  # type: ignore
@@ -682,15 +695,6 @@ def gemma_3_sae_huggingface_loader(
  cfg_overrides,
  )

- # replace folder name of 65k with 64k
- # TODO: remove this workaround once weights are fixed
- if "270m-pt" in repo_id:
- if "65k" in folder_name:
- folder_name = folder_name.replace("65k", "64k")
- # replace folder name of 262k with 250k
- if "262k" in folder_name:
- folder_name = folder_name.replace("262k", "250k")
-
  params_file = "params.safetensors"
  if "clt" in folder_name:
  params_file = folder_name.split("/")[-1] + ".safetensors"
@@ -704,7 +708,9 @@ def gemma_3_sae_huggingface_loader(
  force_download=force_download,
  )

- raw_state_dict = load_file(sae_path, device=device)
+ raw_state_dict = load_safetensors_weights(
+ sae_path, device=device, dtype=cfg_dict.get("dtype")
+ )

  with torch.no_grad():
  w_dec = raw_state_dict["w_dec"]
@@ -791,11 +797,13 @@ def get_goodfire_huggingface_loader(
  )
  raw_state_dict = torch.load(sae_path, map_location=device)

+ target_dtype = str_to_dtype(cfg_dict.get("dtype", "float32"))
+
  state_dict = {
- "W_enc": raw_state_dict["encoder_linear.weight"].T,
- "W_dec": raw_state_dict["decoder_linear.weight"].T,
- "b_enc": raw_state_dict["encoder_linear.bias"],
- "b_dec": raw_state_dict["decoder_linear.bias"],
+ "W_enc": raw_state_dict["encoder_linear.weight"].T.to(dtype=target_dtype),
+ "W_dec": raw_state_dict["decoder_linear.weight"].T.to(dtype=target_dtype),
+ "b_enc": raw_state_dict["encoder_linear.bias"].to(dtype=target_dtype),
+ "b_dec": raw_state_dict["decoder_linear.bias"].to(dtype=target_dtype),
  }

  return cfg_dict, state_dict, None
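Note: the Goodfire loader now casts each tensor to the configured dtype (defaulting to float32) while it builds the state dict. The effect is roughly the sketch below, which is illustrative and not the loader itself:

    import torch

    def cast_state_dict(
        state_dict: dict[str, torch.Tensor], dtype: torch.dtype
    ) -> dict[str, torch.Tensor]:
        # Cast every tensor in a state dict to the target dtype.
        return {k: v.to(dtype=dtype) for k, v in state_dict.items()}

    sd = {"W_enc": torch.randn(16, 64), "b_enc": torch.zeros(64)}
    sd = cast_state_dict(sd, torch.bfloat16)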
@@ -898,26 +906,19 @@ def llama_scope_sae_huggingface_loader(
  force_download=force_download,
  )

- # Load the weights using load_file instead of safe_open
- state_dict_loaded = load_file(sae_path, device=device)
+ state_dict_loaded = load_safetensors_weights(
+ sae_path, device=device, dtype=cfg_dict.get("dtype")
+ )

  # Convert and organize the weights
  state_dict = {
- "W_enc": state_dict_loaded["encoder.weight"]
- .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
- .T,
- "W_dec": state_dict_loaded["decoder.weight"]
- .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
- .T,
- "b_enc": state_dict_loaded["encoder.bias"].to(
- dtype=DTYPE_MAP[cfg_dict["dtype"]]
- ),
- "b_dec": state_dict_loaded["decoder.bias"].to(
- dtype=DTYPE_MAP[cfg_dict["dtype"]]
- ),
+ "W_enc": state_dict_loaded["encoder.weight"].T,
+ "W_dec": state_dict_loaded["decoder.weight"].T,
+ "b_enc": state_dict_loaded["encoder.bias"],
+ "b_dec": state_dict_loaded["decoder.bias"],
  "threshold": torch.ones(
  cfg_dict["d_sae"],
- dtype=DTYPE_MAP[cfg_dict["dtype"]],
+ dtype=str_to_dtype(cfg_dict["dtype"]),
  device=cfg_dict["device"],
  )
  * cfg_dict["jump_relu_threshold"],
@@ -1228,26 +1229,17 @@ def llama_scope_r1_distill_sae_huggingface_loader(
  force_download=force_download,
  )

- # Load the weights using load_file instead of safe_open
- state_dict_loaded = load_file(sae_path, device=device)
+ state_dict_loaded = load_safetensors_weights(
+ sae_path, device=device, dtype=cfg_dict.get("dtype")
+ )

  # Convert and organize the weights
  state_dict = {
- "W_enc": state_dict_loaded["encoder.weight"]
- .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
- .T,
- "W_dec": state_dict_loaded["decoder.weight"]
- .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
- .T,
- "b_enc": state_dict_loaded["encoder.bias"].to(
- dtype=DTYPE_MAP[cfg_dict["dtype"]]
- ),
- "b_dec": state_dict_loaded["decoder.bias"].to(
- dtype=DTYPE_MAP[cfg_dict["dtype"]]
- ),
- "threshold": state_dict_loaded["log_jumprelu_threshold"]
- .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
- .exp(),
+ "W_enc": state_dict_loaded["encoder.weight"].T,
+ "W_dec": state_dict_loaded["decoder.weight"].T,
+ "b_enc": state_dict_loaded["encoder.bias"],
+ "b_dec": state_dict_loaded["decoder.bias"],
+ "threshold": state_dict_loaded["log_jumprelu_threshold"].exp(),
  }

  # No sparsity tensor for Llama Scope SAEs
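Note: these checkpoints store the JumpReLU threshold as a log value, so the loader exponentiates it after loading; the dtype cast now happens inside load_safetensors_weights rather than per tensor. A tiny sketch of that relationship with made-up numbers:

    import torch

    log_threshold = torch.tensor([-2.0, 0.0, 1.0])  # hypothetical stored values
    threshold = log_threshold.exp()                 # tensor([0.1353, 1.0000, 2.7183])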
@@ -1367,34 +1359,34 @@ def sparsify_disk_loader(
  cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)

  weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
- state_dict_loaded = load_file(weight_path, device=device)
-
- dtype = DTYPE_MAP[cfg_dict["dtype"]]
+ state_dict_loaded = load_safetensors_weights(
+ weight_path, device=device, dtype=cfg_dict.get("dtype")
+ )

  W_enc = (
  state_dict_loaded["W_enc"]
  if "W_enc" in state_dict_loaded
  else state_dict_loaded["encoder.weight"].T
- ).to(dtype)
+ )

  if "W_dec" in state_dict_loaded:
- W_dec = state_dict_loaded["W_dec"].T.to(dtype)
+ W_dec = state_dict_loaded["W_dec"].T
  else:
- W_dec = state_dict_loaded["decoder.weight"].T.to(dtype)
+ W_dec = state_dict_loaded["decoder.weight"].T

  if "b_enc" in state_dict_loaded:
- b_enc = state_dict_loaded["b_enc"].to(dtype)
+ b_enc = state_dict_loaded["b_enc"]
  elif "encoder.bias" in state_dict_loaded:
- b_enc = state_dict_loaded["encoder.bias"].to(dtype)
+ b_enc = state_dict_loaded["encoder.bias"]
  else:
- b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+ b_enc = torch.zeros(cfg_dict["d_sae"], dtype=W_dec.dtype, device=device)

  if "b_dec" in state_dict_loaded:
- b_dec = state_dict_loaded["b_dec"].to(dtype)
+ b_dec = state_dict_loaded["b_dec"]
  elif "decoder.bias" in state_dict_loaded:
- b_dec = state_dict_loaded["decoder.bias"].to(dtype)
+ b_dec = state_dict_loaded["decoder.bias"]
  else:
- b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+ b_dec = torch.zeros(cfg_dict["d_in"], dtype=W_dec.dtype, device=device)

  state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
  return cfg_dict, state_dict
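Note: sparsify_disk_loader no longer casts each tensor itself; load_safetensors_weights handles the dtype, and missing biases now default to zeros that match the decoder's dtype. A hedged sketch of that fallback, with illustrative shapes:

    import torch

    d_in, d_sae = 64, 512  # illustrative; the real values come from the config
    W_dec = torch.randn(d_sae, d_in, dtype=torch.bfloat16)

    # Missing biases default to zeros matching the decoder's dtype and device.
    b_enc = torch.zeros(d_sae, dtype=W_dec.dtype, device=W_dec.device)
    b_dec = torch.zeros(d_in, dtype=W_dec.dtype, device=W_dec.device)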
@@ -1625,7 +1617,9 @@ def mwhanna_transcoder_huggingface_loader(
  )

  # Load weights from safetensors
- state_dict = load_file(file_path, device=device)
+ state_dict = load_safetensors_weights(
+ file_path, device=device, dtype=cfg_dict.get("dtype")
+ )
  state_dict["W_enc"] = state_dict["W_enc"].T

  return cfg_dict, state_dict, None
@@ -1709,8 +1703,12 @@ def mntss_clt_layer_huggingface_loader(
  force_download=force_download,
  )

- encoder_state_dict = load_file(encoder_path, device=device)
- decoder_state_dict = load_file(decoder_path, device=device)
+ encoder_state_dict = load_safetensors_weights(
+ encoder_path, device=device, dtype=cfg_dict.get("dtype")
+ )
+ decoder_state_dict = load_safetensors_weights(
+ decoder_path, device=device, dtype=cfg_dict.get("dtype")
+ )

  with torch.no_grad():
  state_dict = {
@@ -1853,7 +1851,9 @@ def temporal_sae_huggingface_loader(
  )

  # Load checkpoint from safetensors
- state_dict_raw = load_file(ckpt_path, device=device)
+ state_dict_raw = load_safetensors_weights(
+ ckpt_path, device=device, dtype=cfg_dict.get("dtype")
+ )

  # Convert to SAELens naming convention
  # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
sae_lens/pretrained_saes.yaml CHANGED
@@ -10197,6 +10197,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_16_width_16k_l0_medium
  path: resid_post/layer_16_width_16k_l0_medium
  l0: 53
+ neuronpedia: gemma-3-27b-it/16-gemmascope-2-res-16k
  - id: layer_16_width_16k_l0_small
  path: resid_post/layer_16_width_16k_l0_small
  l0: 17
@@ -10206,6 +10207,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_16_width_1m_l0_medium
  path: resid_post/layer_16_width_1m_l0_medium
  l0: 53
+ neuronpedia: gemma-3-27b-it/16-gemmascope-2-res-1m
  - id: layer_16_width_1m_l0_small
  path: resid_post/layer_16_width_1m_l0_small
  l0: 17
@@ -10215,6 +10217,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_16_width_262k_l0_medium
  path: resid_post/layer_16_width_262k_l0_medium
  l0: 53
+ neuronpedia: gemma-3-27b-it/16-gemmascope-2-res-262k
  - id: layer_16_width_262k_l0_medium_seed_1
  path: resid_post/layer_16_width_262k_l0_medium_seed_1
  l0: 53
@@ -10227,6 +10230,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_16_width_65k_l0_medium
  path: resid_post/layer_16_width_65k_l0_medium
  l0: 53
+ neuronpedia: gemma-3-27b-it/16-gemmascope-2-res-65k
  - id: layer_16_width_65k_l0_small
  path: resid_post/layer_16_width_65k_l0_small
  l0: 17
@@ -10236,6 +10240,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_31_width_16k_l0_medium
  path: resid_post/layer_31_width_16k_l0_medium
  l0: 60
+ neuronpedia: gemma-3-27b-it/31-gemmascope-2-res-16k
  - id: layer_31_width_16k_l0_small
  path: resid_post/layer_31_width_16k_l0_small
  l0: 20
@@ -10245,6 +10250,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_31_width_1m_l0_medium
  path: resid_post/layer_31_width_1m_l0_medium
  l0: 60
+ neuronpedia: gemma-3-27b-it/31-gemmascope-2-res-1m
  - id: layer_31_width_1m_l0_small
  path: resid_post/layer_31_width_1m_l0_small
  l0: 20
@@ -10254,6 +10260,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_31_width_262k_l0_medium
  path: resid_post/layer_31_width_262k_l0_medium
  l0: 60
+ neuronpedia: gemma-3-27b-it/31-gemmascope-2-res-262k
  - id: layer_31_width_262k_l0_medium_seed_1
  path: resid_post/layer_31_width_262k_l0_medium_seed_1
  l0: 60
@@ -10266,6 +10273,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_31_width_65k_l0_medium
  path: resid_post/layer_31_width_65k_l0_medium
  l0: 60
+ neuronpedia: gemma-3-27b-it/31-gemmascope-2-res-65k
  - id: layer_31_width_65k_l0_small
  path: resid_post/layer_31_width_65k_l0_small
  l0: 20
@@ -10275,6 +10283,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_40_width_16k_l0_medium
  path: resid_post/layer_40_width_16k_l0_medium
  l0: 60
+ neuronpedia: gemma-3-27b-it/40-gemmascope-2-res-16k
  - id: layer_40_width_16k_l0_small
  path: resid_post/layer_40_width_16k_l0_small
  l0: 20
@@ -10284,6 +10293,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_40_width_1m_l0_medium
  path: resid_post/layer_40_width_1m_l0_medium
  l0: 60
+ neuronpedia: gemma-3-27b-it/40-gemmascope-2-res-1m
  - id: layer_40_width_1m_l0_small
  path: resid_post/layer_40_width_1m_l0_small
  l0: 20
@@ -10293,6 +10303,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_40_width_262k_l0_medium
  path: resid_post/layer_40_width_262k_l0_medium
  l0: 60
+ neuronpedia: gemma-3-27b-it/40-gemmascope-2-res-262k
  - id: layer_40_width_262k_l0_medium_seed_1
  path: resid_post/layer_40_width_262k_l0_medium_seed_1
  l0: 60
@@ -10305,6 +10316,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_40_width_65k_l0_medium
  path: resid_post/layer_40_width_65k_l0_medium
  l0: 60
+ neuronpedia: gemma-3-27b-it/40-gemmascope-2-res-65k
  - id: layer_40_width_65k_l0_small
  path: resid_post/layer_40_width_65k_l0_small
  l0: 20
@@ -10314,6 +10326,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_53_width_16k_l0_medium
  path: resid_post/layer_53_width_16k_l0_medium
  l0: 60
+ neuronpedia: gemma-3-27b-it/53-gemmascope-2-res-16k
  - id: layer_53_width_16k_l0_small
  path: resid_post/layer_53_width_16k_l0_small
  l0: 20
@@ -10323,6 +10336,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_53_width_1m_l0_medium
  path: resid_post/layer_53_width_1m_l0_medium
  l0: 60
+ neuronpedia: gemma-3-27b-it/53-gemmascope-2-res-1m
  - id: layer_53_width_1m_l0_small
  path: resid_post/layer_53_width_1m_l0_small
  l0: 20
@@ -10332,6 +10346,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_53_width_262k_l0_medium
  path: resid_post/layer_53_width_262k_l0_medium
  l0: 60
+ neuronpedia: gemma-3-27b-it/53-gemmascope-2-res-262k
  - id: layer_53_width_262k_l0_medium_seed_1
  path: resid_post/layer_53_width_262k_l0_medium_seed_1
  l0: 60
@@ -10344,6 +10359,7 @@ gemma-scope-2-27b-it-res:
  - id: layer_53_width_65k_l0_medium
  path: resid_post/layer_53_width_65k_l0_medium
  l0: 60
+ neuronpedia: gemma-3-27b-it/53-gemmascope-2-res-65k
  - id: layer_53_width_65k_l0_small
  path: resid_post/layer_53_width_65k_l0_small
  l0: 20
sae_lens/saes/sae.py CHANGED
@@ -27,11 +27,10 @@ from typing_extensions import deprecated, overload, override

  from sae_lens import __version__
  from sae_lens.constants import (
- DTYPE_MAP,
  SAE_CFG_FILENAME,
  SAE_WEIGHTS_FILENAME,
  )
- from sae_lens.util import filter_valid_dataclass_fields
+ from sae_lens.util import dtype_to_str, filter_valid_dataclass_fields, str_to_dtype

  if TYPE_CHECKING:
  from sae_lens.config import LanguageModelSAERunnerConfig
@@ -253,7 +252,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  stacklevel=1,
  )

- self.dtype = DTYPE_MAP[cfg.dtype]
+ self.dtype = str_to_dtype(cfg.dtype)
  self.device = torch.device(cfg.device)
  self.use_error_term = use_error_term

@@ -437,8 +436,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

  # Update dtype in config if provided
  if dtype_arg is not None:
- # Update the cfg.dtype
- self.cfg.dtype = str(dtype_arg)
+ # Update the cfg.dtype (use canonical short form like "float32")
+ self.cfg.dtype = dtype_to_str(dtype_arg)

  # Update the dtype property
  self.dtype = dtype_arg
@@ -534,6 +533,15 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  dtype: str | None = None,
  converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
  ) -> T_SAE:
+ """
+ Load a SAE from disk.
+
+ Args:
+ path: The path to the SAE weights and config.
+ device: The device to load the SAE on, defaults to "cpu".
+ dtype: The dtype to load the SAE on, defaults to None. If None, the dtype will be inferred from the SAE config.
+ converter: The converter to use to load the SAE, defaults to sae_lens_disk_loader.
+ """
  overrides = {"dtype": dtype} if dtype is not None else None
  cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
  cfg_dict = handle_config_defaulting(cfg_dict)
@@ -542,10 +550,17 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  )
  sae_cfg = sae_config_cls.from_dict(cfg_dict)
  sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
+ # hack to avoid using double memory when loading the SAE.
+ # first put the SAE on the meta device, then load the weights.
+ device = sae_cfg.device
+ sae_cfg.device = "meta"
  sae = sae_cls(sae_cfg)
+ sae.cfg.device = device
  sae.process_state_dict_for_loading(state_dict)
- sae.load_state_dict(state_dict)
- return sae
+ sae.load_state_dict(state_dict, assign=True)
+ # the loaders should already handle the dtype / device conversion
+ # but this is a fallback to guarantee the SAE is on the correct device and dtype
+ return sae.to(dtype=str_to_dtype(sae_cfg.dtype), device=device)

  @classmethod
  def from_pretrained(
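Note: the loader above now constructs the SAE on the meta device and loads weights with assign=True, so the randomly initialized parameters never allocate real memory and peak usage is roughly halved. A minimal sketch of the same PyTorch pattern using a plain nn.Linear rather than the SAE class:

    import torch
    from torch import nn

    state_dict = {"weight": torch.randn(8, 4), "bias": torch.zeros(8)}

    # Construct on "meta": parameters have shapes but no storage is allocated.
    with torch.device("meta"):
        layer = nn.Linear(4, 8)

    # assign=True swaps the meta parameters for the loaded tensors directly.
    layer.load_state_dict(state_dict, assign=True)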
@@ -553,6 +568,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  release: str,
  sae_id: str,
  device: str = "cpu",
+ dtype: str = "float32",
  force_download: bool = False,
  converter: PretrainedSaeHuggingfaceLoader | None = None,
  ) -> T_SAE:
@@ -562,10 +578,18 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  Args:
  release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
  id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
- device: The device to load the SAE on.
+ device: The device to load the SAE on, defaults to "cpu".
+ dtype: The dtype to load the SAE on, defaults to "float32".
+ force_download: Whether to force download the SAE weights and config, defaults to False.
+ converter: The converter to use to load the SAE, defaults to None. If None, the converter will be inferred from the release.
  """
  return cls.from_pretrained_with_cfg_and_sparsity(
- release, sae_id, device, force_download, converter=converter
+ release,
+ sae_id,
+ device,
+ force_download=force_download,
+ dtype=dtype,
+ converter=converter,
  )[0]

  @classmethod
@@ -574,6 +598,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  release: str,
  sae_id: str,
  device: str = "cpu",
+ dtype: str = "float32",
  force_download: bool = False,
  converter: PretrainedSaeHuggingfaceLoader | None = None,
  ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
@@ -584,7 +609,10 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  Args:
  release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
  id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
- device: The device to load the SAE on.
+ device: The device to load the SAE on, defaults to "cpu".
+ dtype: The dtype to load the SAE on, defaults to "float32".
+ force_download: Whether to force download the SAE weights and config, defaults to False.
+ converter: The converter to use to load the SAE, defaults to None. If None, the converter will be inferred from the release.
  """

  # get sae directory
@@ -634,6 +662,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
  config_overrides = get_config_overrides(release, sae_id)
  config_overrides["device"] = device
+ config_overrides["dtype"] = dtype

  # Load config and weights
  cfg_dict, state_dict, log_sparsities = conversion_loader(
@@ -651,9 +680,14 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  )
  sae_cfg = sae_config_cls.from_dict(cfg_dict)
  sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
+ # hack to avoid using double memory when loading the SAE.
+ # first put the SAE on the meta device, then load the weights.
+ device = sae_cfg.device
+ sae_cfg.device = "meta"
  sae = sae_cls(sae_cfg)
+ sae.cfg.device = device
  sae.process_state_dict_for_loading(state_dict)
- sae.load_state_dict(state_dict)
+ sae.load_state_dict(state_dict, assign=True)

  # Apply normalization if needed
  if cfg_dict.get("normalize_activations") == "expected_average_only_in":
@@ -666,7 +700,13 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
  )

- return sae, cfg_dict, log_sparsities
+ # the loaders should already handle the dtype / device conversion
+ # but this is a fallback to guarantee the SAE is on the correct device and dtype
+ return (
+ sae.to(dtype=str_to_dtype(dtype), device=device),
+ cfg_dict,
+ log_sparsities,
+ )

  @classmethod
  def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
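Note: from_pretrained and from_pretrained_with_cfg_and_sparsity now take a dtype string (default "float32") that is forwarded as a config override and applied as a final cast. A usage sketch; the release and sae_id come from the pretrained_saes.yaml entries shown earlier, and the call downloads the weights from Hugging Face:

    import torch
    from sae_lens import SAE

    sae = SAE.from_pretrained(
        release="gemma-scope-2-27b-it-res",
        sae_id="layer_16_width_16k_l0_medium",
        device="cpu",
        dtype="bfloat16",  # parameter added in this release
    )
    assert sae.dtype == torch.bfloat16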
sae_lens/training/activations_store.py CHANGED
@@ -24,7 +24,7 @@ from sae_lens.config import (
  HfDataset,
  LanguageModelSAERunnerConfig,
  )
- from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, DTYPE_MAP
+ from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME
  from sae_lens.pretokenize_runner import get_special_token_from_cfg
  from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
  from sae_lens.tokenization_and_batching import concat_and_batch_sequences
@@ -32,6 +32,7 @@ from sae_lens.training.mixing_buffer import mixing_buffer
  from sae_lens.util import (
  extract_stop_at_layer_from_tlens_hook_name,
  get_special_token_ids,
+ str_to_dtype,
  )


@@ -258,7 +259,7 @@ class ActivationsStore:
  self.prepend_bos = prepend_bos
  self.normalize_activations = normalize_activations
  self.device = torch.device(device)
- self.dtype = DTYPE_MAP[dtype]
+ self.dtype = str_to_dtype(dtype)
  self.cached_activations_path = cached_activations_path
  self.autocast_lm = autocast_lm
  self.seqpos_slice = seqpos_slice
sae_lens/util.py CHANGED
@@ -5,8 +5,11 @@ from dataclasses import asdict, fields, is_dataclass
  from pathlib import Path
  from typing import Sequence, TypeVar

+ import torch
  from transformers import PreTrainedTokenizerBase

+ from sae_lens.constants import DTYPE_MAP, DTYPE_TO_STR
+
  K = TypeVar("K")
  V = TypeVar("V")

@@ -90,3 +93,21 @@ def get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
  special_tokens.add(token_id)

  return list(special_tokens)
+
+
+ def str_to_dtype(dtype: str) -> torch.dtype:
+ """Convert a string to a torch.dtype."""
+ if dtype not in DTYPE_MAP:
+ raise ValueError(
+ f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_MAP.keys())}"
+ )
+ return DTYPE_MAP[dtype]
+
+
+ def dtype_to_str(dtype: torch.dtype) -> str:
+ """Convert a torch.dtype to a string."""
+ if dtype not in DTYPE_TO_STR:
+ raise ValueError(
+ f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_TO_STR.keys())}"
+ )
+ return DTYPE_TO_STR[dtype]
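Note: str_to_dtype and dtype_to_str round-trip between dtype strings and torch.dtype objects, raising ValueError for anything outside DTYPE_MAP / DTYPE_TO_STR. A short usage sketch:

    import torch
    from sae_lens.util import dtype_to_str, str_to_dtype

    assert str_to_dtype("float32") is torch.float32
    assert str_to_dtype("torch.bfloat16") is torch.bfloat16  # "torch."-prefixed keys exist in DTYPE_MAP too
    assert dtype_to_str(torch.float16) == "float16"

    try:
        str_to_dtype("int8")  # not a key in DTYPE_MAP
    except ValueError as err:
        print(err)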
sae_lens-6.24.1.dist-info/METADATA → sae_lens-6.25.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sae-lens
- Version: 6.24.1
+ Version: 6.25.1
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
  License: MIT
  License-File: LICENSE
sae_lens-6.24.1.dist-info/RECORD → sae_lens-6.25.1.dist-info/RECORD CHANGED
@@ -1,25 +1,25 @@
- sae_lens/__init__.py,sha256=spLEw4TR2BzzKc3R-ik8MbHlYOAR__wVmkSmJqOB4Tc,4268
+ sae_lens/__init__.py,sha256=vWuA8EbynIJadj666RoFNCTIvoH9-HFpUxuHwoYt8Ks,4268
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
- sae_lens/cache_activations_runner.py,sha256=RNN_nDQkH0lqEIxTAIDx3g1cgAzRxQWBSBEXA6nbWh0,12565
- sae_lens/config.py,sha256=fxvpQxFfPOVUkryiHD19q9O1AJDSkIguWeYlbJuTxmY,30329
- sae_lens/constants.py,sha256=qX12uAE_xkha6hjss_0MGTbakI7gEkJzHABkZaHWQFU,683
+ sae_lens/cache_activations_runner.py,sha256=Lvlz-k5-3XxVRtUdC4b1CiKyx5s0ckLa8GDGv9_kcxs,12566
+ sae_lens/config.py,sha256=JmcrXT4orJV2OulbEZAciz8RQmYv7DrtUtRbOLsNQ2Y,30330
+ sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
  sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
  sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
  sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
  sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- sae_lens/loading/pretrained_sae_loaders.py,sha256=W2eIvUU1wAHrYxGiZs4s2D6DnGBQqqKjq0wvXzWbD5c,63561
+ sae_lens/loading/pretrained_sae_loaders.py,sha256=hq-dhxsEdUmlAnZEiZBqX7lNyQQwZ6KXmXZWpzAc5FY,63638
  sae_lens/loading/pretrained_saes_directory.py,sha256=hejNfLUepYCSGPalRfQwxxCEUqMMUPsn1tufwvwct5k,3820
  sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
- sae_lens/pretrained_saes.yaml,sha256=Hd1GgaPL4TAXoS2gizG9e_9jc_9LpfI4w_hwGkEz9xQ,1509314
+ sae_lens/pretrained_saes.yaml,sha256=Hy9mk4Liy50B0CIBD4ER1ETcho2drFFiIy-bPVCN_lc,1510210
  sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
  sae_lens/saes/__init__.py,sha256=fYVujOzNnUgpzLL0MBLBt_DNX2CPcTaheukzCd2bEPo,1906
  sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
  sae_lens/saes/gated_sae.py,sha256=mHnmw-RD7hqIbP9_EBj3p2SK0OqQIkZivdOKRygeRgw,8825
  sae_lens/saes/jumprelu_sae.py,sha256=udjGHp3WTABQSL2Qq57j-bINWX61GCmo68EmdjMOXoo,13310
  sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
- sae_lens/saes/sae.py,sha256=Vb1aGSDPRv_0J2aL8-EICRSkIxsO6Q4lJaJE9NNmfdA,37749
+ sae_lens/saes/sae.py,sha256=fzXv8lwHskSxsf8hm_wlKPkpq50iafmBjBNQzwZ6a00,40050
  sae_lens/saes/standard_sae.py,sha256=nEVETwAmRD2tyX7ESIic1fij48gAq1Dh7s_GQ2fqCZ4,5747
  sae_lens/saes/temporal_sae.py,sha256=DsecivcHWId-MTuJpQbz8OhqtmGhZACxJauYZGHo0Ok,13272
  sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
@@ -27,15 +27,15 @@ sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,1
  sae_lens/tokenization_and_batching.py,sha256=D_o7cXvRqhT89H3wNzoRymNALNE6eHojBWLdXOUwUGE,5438
  sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
- sae_lens/training/activations_store.py,sha256=yDWw7TZGPFM_O8_Oi78j8lLIHJJesxq9TKVP_TrMX-M,33768
+ sae_lens/training/activations_store.py,sha256=rQadexm2BiwK7_MZIPlRkcKSqabi3iuOTC-R8aJchS8,33778
  sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
  sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
  sae_lens/training/sae_trainer.py,sha256=zhkabyIKxI_tZTV3_kwz6zMrHZ95Ecr97krmwc-9ffs,17600
  sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
- sae_lens/util.py,sha256=tCovQ-eZa1L7thPpNDL6PGOJrIMML2yLI5e0EHCOpS8,3309
- sae_lens-6.24.1.dist-info/METADATA,sha256=5TlxCqEZoJV4S0F9IP6Ak_aitVkMkFfUhlFOl5NIJBc,5361
- sae_lens-6.24.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
- sae_lens-6.24.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
- sae_lens-6.24.1.dist-info/RECORD,,
+ sae_lens/util.py,sha256=spkcmQUsjVYFn5H2032nQYr1CKGVnv3tAdfIpY59-Mg,3919
+ sae_lens-6.25.1.dist-info/METADATA,sha256=gClFVWzEWNNjrXsGqvCY6ry6ehXIFwp8PB0jIOhmQvc,5361
+ sae_lens-6.25.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+ sae_lens-6.25.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+ sae_lens-6.25.1.dist-info/RECORD,,