sae-lens 6.25.0__tar.gz → 6.26.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {sae_lens-6.25.0 → sae_lens-6.26.0}/PKG-INFO +1 -1
  2. {sae_lens-6.25.0 → sae_lens-6.26.0}/pyproject.toml +1 -1
  3. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/__init__.py +13 -1
  4. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/cache_activations_runner.py +2 -2
  5. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/config.py +7 -2
  6. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/constants.py +8 -0
  7. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/loading/pretrained_sae_loaders.py +66 -57
  8. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/__init__.py +10 -0
  9. sae_lens-6.26.0/sae_lens/saes/matching_pursuit_sae.py +334 -0
  10. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/sae.py +52 -12
  11. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/activations_store.py +3 -2
  12. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/util.py +21 -0
  13. {sae_lens-6.25.0 → sae_lens-6.26.0}/LICENSE +0 -0
  14. {sae_lens-6.25.0 → sae_lens-6.26.0}/README.md +0 -0
  15. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/analysis/__init__.py +0 -0
  16. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  17. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  18. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/evals.py +0 -0
  19. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/llm_sae_training_runner.py +0 -0
  20. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/load_model.py +0 -0
  21. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/loading/__init__.py +0 -0
  22. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  23. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/pretokenize_runner.py +0 -0
  24. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/pretrained_saes.yaml +0 -0
  25. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/registry.py +0 -0
  26. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  27. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/gated_sae.py +0 -0
  28. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/jumprelu_sae.py +0 -0
  29. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
  30. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/standard_sae.py +0 -0
  31. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/temporal_sae.py +0 -0
  32. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/topk_sae.py +0 -0
  33. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/saes/transcoder.py +0 -0
  34. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/tokenization_and_batching.py +0 -0
  35. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/__init__.py +0 -0
  36. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/activation_scaler.py +0 -0
  37. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/mixing_buffer.py +0 -0
  38. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/optim.py +0 -0
  39. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/sae_trainer.py +0 -0
  40. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/types.py +0 -0
  41. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  42. {sae_lens-6.25.0 → sae_lens-6.26.0}/sae_lens/tutorial/tsea.py +0 -0
--- sae_lens-6.25.0/PKG-INFO
+++ sae_lens-6.26.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.25.0
+Version: 6.26.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
--- sae_lens-6.25.0/pyproject.toml
+++ sae_lens-6.26.0/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.25.0"
+version = "6.26.0"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
--- sae_lens-6.25.0/sae_lens/__init__.py
+++ sae_lens-6.26.0/sae_lens/__init__.py
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.25.0"
+__version__ = "6.26.0"

 import logging

@@ -21,6 +21,10 @@ from sae_lens.saes import (
     JumpReLUTrainingSAEConfig,
     JumpReLUTranscoder,
     JumpReLUTranscoderConfig,
+    MatchingPursuitSAE,
+    MatchingPursuitSAEConfig,
+    MatchingPursuitTrainingSAE,
+    MatchingPursuitTrainingSAEConfig,
     MatryoshkaBatchTopKTrainingSAE,
     MatryoshkaBatchTopKTrainingSAEConfig,
     SAEConfig,
@@ -113,6 +117,10 @@ __all__ = [
     "MatryoshkaBatchTopKTrainingSAEConfig",
     "TemporalSAE",
     "TemporalSAEConfig",
+    "MatchingPursuitSAE",
+    "MatchingPursuitTrainingSAE",
+    "MatchingPursuitSAEConfig",
+    "MatchingPursuitTrainingSAEConfig",
 ]


@@ -139,3 +147,7 @@ register_sae_class(
     "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
 )
 register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
+register_sae_class("matching_pursuit", MatchingPursuitSAE, MatchingPursuitSAEConfig)
+register_sae_training_class(
+    "matching_pursuit", MatchingPursuitTrainingSAE, MatchingPursuitTrainingSAEConfig
+)
--- sae_lens-6.25.0/sae_lens/cache_activations_runner.py
+++ sae_lens-6.26.0/sae_lens/cache_activations_runner.py
@@ -14,9 +14,9 @@ from transformer_lens.HookedTransformer import HookedRootModule

 from sae_lens import logger
 from sae_lens.config import CacheActivationsRunnerConfig
-from sae_lens.constants import DTYPE_MAP
 from sae_lens.load_model import load_model
 from sae_lens.training.activations_store import ActivationsStore
+from sae_lens.util import str_to_dtype


 def _mk_activations_store(
@@ -97,7 +97,7 @@ class CacheActivationsRunner:
         bytes_per_token = (
             self.cfg.d_in * self.cfg.dtype.itemsize
             if isinstance(self.cfg.dtype, torch.dtype)
-            else DTYPE_MAP[self.cfg.dtype].itemsize
+            else str_to_dtype(self.cfg.dtype).itemsize
         )
         total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
         total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9
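A minimal sketch of the disk-space arithmetic above, with assumed placeholder values for d_in, the dtype, and the token count (these are not values taken from the package):

    import torch

    d_in = 768                    # activation width (assumed)
    dtype = torch.float32         # 4 bytes per element
    n_tokens = 100_000_000        # tokens to cache (assumed)

    bytes_per_token = d_in * dtype.itemsize                    # 768 * 4 = 3072 bytes
    total_disk_space_gb = n_tokens * bytes_per_token / 10**9   # ~307 GB
    print(f"{total_disk_space_gb:.1f} GB")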
--- sae_lens-6.25.0/sae_lens/config.py
+++ sae_lens-6.26.0/sae_lens/config.py
@@ -17,9 +17,14 @@ from datasets import (
 )

 from sae_lens import __version__, logger
-from sae_lens.constants import DTYPE_MAP
+
+# keeping this unused import since some SAELens deps import DTYPE_MAP from config
+from sae_lens.constants import (
+    DTYPE_MAP,  # noqa: F401 # pyright: ignore[reportUnusedImport]
+)
 from sae_lens.registry import get_sae_training_class
 from sae_lens.saes.sae import TrainingSAEConfig
+from sae_lens.util import str_to_dtype

 if TYPE_CHECKING:
     pass
@@ -563,7 +568,7 @@ class CacheActivationsRunnerConfig:

     @property
     def bytes_per_token(self) -> int:
-        return self.d_in * DTYPE_MAP[self.dtype].itemsize
+        return self.d_in * str_to_dtype(self.dtype).itemsize

     @property
     def n_tokens_in_buffer(self) -> int:
--- sae_lens-6.25.0/sae_lens/constants.py
+++ sae_lens-6.26.0/sae_lens/constants.py
@@ -11,6 +11,14 @@ DTYPE_MAP = {
     "torch.bfloat16": torch.bfloat16,
 }

+# Reverse mapping from torch.dtype to canonical string format
+DTYPE_TO_STR = {
+    torch.float32: "float32",
+    torch.float64: "float64",
+    torch.float16: "float16",
+    torch.bfloat16: "bfloat16",
+}
+

 SPARSITY_FILENAME = "sparsity.safetensors"
 SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
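A quick sanity check of the forward and reverse tables, using only entries visible in the hunk above (illustrative only):

    import torch

    from sae_lens.constants import DTYPE_MAP, DTYPE_TO_STR

    # forward lookup: torch-prefixed string key shown in the diff
    assert DTYPE_MAP["torch.bfloat16"] is torch.bfloat16
    # reverse lookup via the new table returns the canonical short name
    assert DTYPE_TO_STR[torch.bfloat16] == "bfloat16"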
--- sae_lens-6.25.0/sae_lens/loading/pretrained_sae_loaders.py
+++ sae_lens-6.26.0/sae_lens/loading/pretrained_sae_loaders.py
@@ -12,11 +12,9 @@ from huggingface_hub import hf_hub_download, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError, build_hf_headers
 from packaging.version import Version
 from safetensors import safe_open
-from safetensors.torch import load_file

 from sae_lens import logger
 from sae_lens.constants import (
-    DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
     SPARSIFY_WEIGHTS_FILENAME,
@@ -28,7 +26,7 @@ from sae_lens.loading.pretrained_saes_directory import (
     get_repo_id_and_folder_name,
 )
 from sae_lens.registry import get_sae_class
-from sae_lens.util import filter_valid_dataclass_fields
+from sae_lens.util import filter_valid_dataclass_fields, str_to_dtype

 LLM_METADATA_KEYS = {
     "model_name",
@@ -51,6 +49,21 @@ LLM_METADATA_KEYS = {
 }


+def load_safetensors_weights(
+    path: str | Path, device: str = "cpu", dtype: torch.dtype | str | None = None
+) -> dict[str, torch.Tensor]:
+    """Load safetensors weights and optionally convert to a different dtype"""
+    loaded_weights = {}
+    dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
+    with safe_open(path, framework="pt", device=device) as f:
+        for k in f.keys():  # noqa: SIM118
+            weight = f.get_tensor(k)
+            if dtype is not None:
+                weight = weight.to(dtype=dtype)
+            loaded_weights[k] = weight
+    return loaded_weights
+
+
 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
 class PretrainedSaeHuggingfaceLoader(Protocol):
     def __call__(
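The new helper above centralizes what the loaders below previously did with `load_file` followed by per-tensor `.to(dtype=...)` casts. A minimal usage sketch (the file path is a placeholder, not a file shipped with the package):

    from sae_lens.loading.pretrained_sae_loaders import load_safetensors_weights

    state_dict = load_safetensors_weights(
        "weights.safetensors",   # placeholder path
        device="cpu",
        dtype="bfloat16",        # strings go through str_to_dtype; a torch.dtype also works
    )
    for name, tensor in state_dict.items():
        print(name, tuple(tensor.shape), tensor.dtype)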
@@ -341,7 +354,7 @@ def read_sae_components_from_disk(
     Given a loaded dictionary and a path to a weight file, load the weights and return the state_dict.
     """
     if dtype is None:
-        dtype = DTYPE_MAP[cfg_dict["dtype"]]
+        dtype = str_to_dtype(cfg_dict["dtype"])

     state_dict = {}
     with safe_open(weight_path, framework="pt", device=device) as f:  # type: ignore
@@ -695,7 +708,9 @@ def gemma_3_sae_huggingface_loader(
         force_download=force_download,
     )

-    raw_state_dict = load_file(sae_path, device=device)
+    raw_state_dict = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     with torch.no_grad():
         w_dec = raw_state_dict["w_dec"]
@@ -782,11 +797,13 @@ def get_goodfire_huggingface_loader(
     )
     raw_state_dict = torch.load(sae_path, map_location=device)

+    target_dtype = str_to_dtype(cfg_dict.get("dtype", "float32"))
+
     state_dict = {
-        "W_enc": raw_state_dict["encoder_linear.weight"].T,
-        "W_dec": raw_state_dict["decoder_linear.weight"].T,
-        "b_enc": raw_state_dict["encoder_linear.bias"],
-        "b_dec": raw_state_dict["decoder_linear.bias"],
+        "W_enc": raw_state_dict["encoder_linear.weight"].T.to(dtype=target_dtype),
+        "W_dec": raw_state_dict["decoder_linear.weight"].T.to(dtype=target_dtype),
+        "b_enc": raw_state_dict["encoder_linear.bias"].to(dtype=target_dtype),
+        "b_dec": raw_state_dict["decoder_linear.bias"].to(dtype=target_dtype),
     }

     return cfg_dict, state_dict, None
@@ -889,26 +906,19 @@ def llama_scope_sae_huggingface_loader(
         force_download=force_download,
     )

-    # Load the weights using load_file instead of safe_open
-    state_dict_loaded = load_file(sae_path, device=device)
+    state_dict_loaded = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     # Convert and organize the weights
     state_dict = {
-        "W_enc": state_dict_loaded["encoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "W_dec": state_dict_loaded["decoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "b_enc": state_dict_loaded["encoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "b_dec": state_dict_loaded["decoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
+        "W_enc": state_dict_loaded["encoder.weight"].T,
+        "W_dec": state_dict_loaded["decoder.weight"].T,
+        "b_enc": state_dict_loaded["encoder.bias"],
+        "b_dec": state_dict_loaded["decoder.bias"],
         "threshold": torch.ones(
             cfg_dict["d_sae"],
-            dtype=DTYPE_MAP[cfg_dict["dtype"]],
+            dtype=str_to_dtype(cfg_dict["dtype"]),
             device=cfg_dict["device"],
         )
         * cfg_dict["jump_relu_threshold"],
@@ -1219,26 +1229,17 @@ def llama_scope_r1_distill_sae_huggingface_loader(
         force_download=force_download,
     )

-    # Load the weights using load_file instead of safe_open
-    state_dict_loaded = load_file(sae_path, device=device)
+    state_dict_loaded = load_safetensors_weights(
+        sae_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     # Convert and organize the weights
     state_dict = {
-        "W_enc": state_dict_loaded["encoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "W_dec": state_dict_loaded["decoder.weight"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .T,
-        "b_enc": state_dict_loaded["encoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "b_dec": state_dict_loaded["decoder.bias"].to(
-            dtype=DTYPE_MAP[cfg_dict["dtype"]]
-        ),
-        "threshold": state_dict_loaded["log_jumprelu_threshold"]
-        .to(dtype=DTYPE_MAP[cfg_dict["dtype"]])
-        .exp(),
+        "W_enc": state_dict_loaded["encoder.weight"].T,
+        "W_dec": state_dict_loaded["decoder.weight"].T,
+        "b_enc": state_dict_loaded["encoder.bias"],
+        "b_dec": state_dict_loaded["decoder.bias"],
+        "threshold": state_dict_loaded["log_jumprelu_threshold"].exp(),
     }

     # No sparsity tensor for Llama Scope SAEs
@@ -1358,34 +1359,34 @@ def sparsify_disk_loader(
     cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)

     weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
-    state_dict_loaded = load_file(weight_path, device=device)
-
-    dtype = DTYPE_MAP[cfg_dict["dtype"]]
+    state_dict_loaded = load_safetensors_weights(
+        weight_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     W_enc = (
         state_dict_loaded["W_enc"]
         if "W_enc" in state_dict_loaded
         else state_dict_loaded["encoder.weight"].T
-    ).to(dtype)
+    )

     if "W_dec" in state_dict_loaded:
-        W_dec = state_dict_loaded["W_dec"].T.to(dtype)
+        W_dec = state_dict_loaded["W_dec"].T
     else:
-        W_dec = state_dict_loaded["decoder.weight"].T.to(dtype)
+        W_dec = state_dict_loaded["decoder.weight"].T

     if "b_enc" in state_dict_loaded:
-        b_enc = state_dict_loaded["b_enc"].to(dtype)
+        b_enc = state_dict_loaded["b_enc"]
     elif "encoder.bias" in state_dict_loaded:
-        b_enc = state_dict_loaded["encoder.bias"].to(dtype)
+        b_enc = state_dict_loaded["encoder.bias"]
     else:
-        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=W_dec.dtype, device=device)

     if "b_dec" in state_dict_loaded:
-        b_dec = state_dict_loaded["b_dec"].to(dtype)
+        b_dec = state_dict_loaded["b_dec"]
     elif "decoder.bias" in state_dict_loaded:
-        b_dec = state_dict_loaded["decoder.bias"].to(dtype)
+        b_dec = state_dict_loaded["decoder.bias"]
     else:
-        b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+        b_dec = torch.zeros(cfg_dict["d_in"], dtype=W_dec.dtype, device=device)

     state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
     return cfg_dict, state_dict
@@ -1616,7 +1617,9 @@ def mwhanna_transcoder_huggingface_loader(
     )

     # Load weights from safetensors
-    state_dict = load_file(file_path, device=device)
+    state_dict = load_safetensors_weights(
+        file_path, device=device, dtype=cfg_dict.get("dtype")
+    )
     state_dict["W_enc"] = state_dict["W_enc"].T

     return cfg_dict, state_dict, None
@@ -1700,8 +1703,12 @@ def mntss_clt_layer_huggingface_loader(
         force_download=force_download,
     )

-    encoder_state_dict = load_file(encoder_path, device=device)
-    decoder_state_dict = load_file(decoder_path, device=device)
+    encoder_state_dict = load_safetensors_weights(
+        encoder_path, device=device, dtype=cfg_dict.get("dtype")
+    )
+    decoder_state_dict = load_safetensors_weights(
+        decoder_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     with torch.no_grad():
         state_dict = {
@@ -1844,7 +1851,9 @@ def temporal_sae_huggingface_loader(
     )

     # Load checkpoint from safetensors
-    state_dict_raw = load_file(ckpt_path, device=device)
+    state_dict_raw = load_safetensors_weights(
+        ckpt_path, device=device, dtype=cfg_dict.get("dtype")
+    )

     # Convert to SAELens naming convention
     # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
--- sae_lens-6.25.0/sae_lens/saes/__init__.py
+++ sae_lens-6.26.0/sae_lens/saes/__init__.py
@@ -14,6 +14,12 @@ from .jumprelu_sae import (
     JumpReLUTrainingSAE,
     JumpReLUTrainingSAEConfig,
 )
+from .matching_pursuit_sae import (
+    MatchingPursuitSAE,
+    MatchingPursuitSAEConfig,
+    MatchingPursuitTrainingSAE,
+    MatchingPursuitTrainingSAEConfig,
+)
 from .matryoshka_batchtopk_sae import (
     MatryoshkaBatchTopKTrainingSAE,
     MatryoshkaBatchTopKTrainingSAEConfig,
@@ -78,4 +84,8 @@ __all__ = [
     "MatryoshkaBatchTopKTrainingSAEConfig",
     "TemporalSAE",
     "TemporalSAEConfig",
+    "MatchingPursuitSAE",
+    "MatchingPursuitTrainingSAE",
+    "MatchingPursuitSAEConfig",
+    "MatchingPursuitTrainingSAEConfig",
 ]
--- /dev/null
+++ sae_lens-6.26.0/sae_lens/saes/matching_pursuit_sae.py
@@ -0,0 +1,334 @@
+"""Matching Pursuit SAE"""
+
+import warnings
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+    TrainStepOutput,
+)
+
+# --- inference ---
+
+
+@dataclass
+class MatchingPursuitSAEConfig(SAEConfig):
+    """
+    Configuration class for MatchingPursuitSAE inference.
+
+    Args:
+        residual_threshold (float): residual error at which to stop selecting latents. Default 1e-2.
+        max_iterations (int | None): Maximum iterations (default: d_in if set to None).
+            Defaults to None.
+        stop_on_duplicate_support (bool): Whether to stop selecting latents if the support set has not changed from the previous iteration. Defaults to True.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
+            before encoding. Inherited from SAEConfig. Defaults to True.
+        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
+            Normalization strategy for input activations. Inherited from SAEConfig.
+            Defaults to "none".
+        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
+            (useful for attention head outputs). Inherited from SAEConfig.
+            Defaults to "none".
+        metadata (SAEMetadata): Metadata about the SAE (model name, hook name, etc.).
+            Inherited from SAEConfig.
+    """
+
+    residual_threshold: float = 1e-2
+    max_iterations: int | None = None
+    stop_on_duplicate_support: bool = True
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "matching_pursuit"
+
+
+class MatchingPursuitSAE(SAE[MatchingPursuitSAEConfig]):
+    """
+    An inference-only sparse autoencoder using a "matching pursuit" activation function.
+    """
+
+    # Matching pursuit is a tied SAE, so we use W_enc as the decoder transposed
+    @property
+    def W_enc(self) -> torch.Tensor:  # pyright: ignore[reportIncompatibleVariableOverride]
+        return self.W_dec.T
+
+    # hacky way to get around the base class having W_enc.
+    # TODO: harmonize with the base class in next major release
+    @override
+    def __setattr__(self, name: str, value: Any):
+        if name == "W_enc":
+            return
+        super().__setattr__(name, value)
+
+    @override
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Converts input x into feature activations.
+        """
+        sae_in = self.process_sae_in(x)
+        return _encode_matching_pursuit(
+            sae_in,
+            self.W_dec,
+            self.cfg.residual_threshold,
+            max_iterations=self.cfg.max_iterations,
+            stop_on_duplicate_support=self.cfg.stop_on_duplicate_support,
+        )
+
+    @override
+    @torch.no_grad()
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError(
+            "Folding W_dec_norm is not safe for MatchingPursuit SAEs, as this may change the resulting activations"
+        )
+
+    @override
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
+        """
+        Decode the feature activations back to the input space.
+        Now, if hook_z reshaping is turned on, we reverse the flattening.
+        """
+        sae_out_pre = feature_acts @ self.W_dec
+        # since this is a tied SAE, we need to make sure b_dec is only applied if applied at input
+        if self.cfg.apply_b_dec_to_input:
+            sae_out_pre = sae_out_pre + self.b_dec
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+
+# --- training ---
+
+
+@dataclass
+class MatchingPursuitTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a MatchingPursuitTrainingSAE.
+
+    Args:
+        residual_threshold (float): residual error at which to stop selecting latents. Default 1e-2.
+        max_iterations (int | None): Maximum iterations (default: d_in if set to None).
+            Defaults to None.
+        stop_on_duplicate_support (bool): Whether to stop selecting latents if the support set has not changed from the previous iteration. Defaults to True.
+        decoder_init_norm (float | None): Norm to initialize decoder weights to.
+            0.1 corresponds to the "heuristic" initialization from Anthropic's April update.
+            Use None to disable. Inherited from TrainingSAEConfig. Defaults to 0.1.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
+            before encoding. Inherited from SAEConfig. Defaults to True.
+        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
+            Normalization strategy for input activations. Inherited from SAEConfig.
+            Defaults to "none".
+        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
+            (useful for attention head outputs). Inherited from SAEConfig.
+            Defaults to "none".
+        metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.).
+            Inherited from SAEConfig.
+    """
+
+    residual_threshold: float = 1e-2
+    max_iterations: int | None = None
+    stop_on_duplicate_support: bool = True
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "matching_pursuit"
+
+    @override
+    def __post_init__(self):
+        super().__post_init__()
+        if self.decoder_init_norm != 1.0:
+            self.decoder_init_norm = 1.0
+            warnings.warn(
+                "decoder_init_norm must be set to 1.0 for MatchingPursuitTrainingSAE, setting to 1.0"
+            )
+
+
+class MatchingPursuitTrainingSAE(TrainingSAE[MatchingPursuitTrainingSAEConfig]):
+    # Matching pursuit is a tied SAE, so we use W_enc as the decoder transposed
+    @property
+    def W_enc(self) -> torch.Tensor:  # pyright: ignore[reportIncompatibleVariableOverride]
+        return self.W_dec.T
+
+    # hacky way to get around the base class having W_enc.
+    # TODO: harmonize with the base class in next major release
+    @override
+    def __setattr__(self, name: str, value: Any):
+        if name == "W_enc":
+            return
+        super().__setattr__(name, value)
+
+    @override
+    def encode_with_hidden_pre(
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        hidden_pre doesn't make sense for matching pursuit, since there is not a single pre-activation.
+        We just return zeros for the hidden_pre.
+        """
+
+        sae_in = self.process_sae_in(x)
+        acts = _encode_matching_pursuit(
+            sae_in,
+            self.W_dec,
+            self.cfg.residual_threshold,
+            max_iterations=self.cfg.max_iterations,
+            stop_on_duplicate_support=self.cfg.stop_on_duplicate_support,
+        )
+        return acts, torch.zeros_like(acts)
+
+    @override
+    @torch.no_grad()
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError(
+            "Folding W_dec_norm is not safe for MatchingPursuit SAEs, as this may change the resulting activations"
+        )
+
+    @override
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+        return {}
+
+    @override
+    def calculate_aux_loss(
+        self,
+        step_input: TrainStepInput,
+        feature_acts: torch.Tensor,
+        hidden_pre: torch.Tensor,
+        sae_out: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        return {}
+
+    @override
+    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
+        output = super().training_forward_pass(step_input)
+        l0 = output.feature_acts.bool().float().sum(-1).to_dense()
+        residual_norm = (step_input.sae_in - output.sae_out).norm(dim=-1)
+        output.metrics["max_l0"] = l0.max()
+        output.metrics["min_l0"] = l0.min()
+        output.metrics["residual_norm"] = residual_norm.mean()
+        output.metrics["residual_threshold_converged_portion"] = (
+            (residual_norm < self.cfg.residual_threshold).float().mean()
+        )
+        return output
+
+    @override
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
+        """
+        Decode the feature activations back to the input space.
+        Now, if hook_z reshaping is turned on, we reverse the flattening.
+        """
+        sae_out_pre = feature_acts @ self.W_dec
+        # since this is a tied SAE, we need to make sure b_dec is only applied if applied at input
+        if self.cfg.apply_b_dec_to_input:
+            sae_out_pre = sae_out_pre + self.b_dec
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+
+# --- shared ---
+
+
+def _encode_matching_pursuit(
+    sae_in_centered: torch.Tensor,
+    W_dec: torch.Tensor,
+    residual_threshold: float,
+    max_iterations: int | None,
+    stop_on_duplicate_support: bool,
+) -> torch.Tensor:
+    """
+    Matching pursuit encoding.
+
+    Args:
+        sae_in_centered: Input activations, centered by b_dec. Shape [..., d_in].
+        W_dec: Decoder weight matrix. Shape [d_sae, d_in].
+        residual_threshold: Stop when residual norm falls below this.
+        max_iterations: Maximum iterations (default: d_in). Prevents infinite loops.
+        stop_on_duplicate_support: Whether to stop selecting latents if the support set has not changed from the previous iteration.
+    """
+    residual = sae_in_centered.clone()
+
+    stop_on_residual_threshold = residual_threshold > 0
+
+    # Handle multi-dimensional inputs by flattening all but the last dimension
+    original_shape = residual.shape
+    if residual.ndim > 2:
+        residual = residual.reshape(-1, residual.shape[-1])
+
+    batch_size = residual.shape[0]
+    d_sae, d_in = W_dec.shape
+
+    if max_iterations is None:
+        max_iterations = d_in  # Sensible upper bound
+
+    acts = torch.zeros(batch_size, d_sae, device=W_dec.device, dtype=residual.dtype)
+    prev_support = torch.zeros(batch_size, d_sae, dtype=torch.bool, device=W_dec.device)
+    done = torch.zeros(batch_size, dtype=torch.bool, device=W_dec.device)
+
+    for _ in range(max_iterations):
+        # Find indices without gradients - the full [batch, d_sae] matmul result
+        # doesn't need to be saved for backward since max indices don't need gradients
+        with torch.no_grad():
+            indices = (residual @ W_dec.T).relu().max(dim=1, keepdim=True).indices
+        indices_flat = indices.squeeze(1)  # [batch_size]
+
+        # Compute values with gradients using only the selected decoder rows.
+        # This stores [batch, d_in] for backward instead of [batch, d_sae].
+        selected_dec = W_dec[indices_flat]  # [batch_size, d_in]
+        values = (residual * selected_dec).sum(dim=-1, keepdim=True).relu()
+
+        # Mask values for samples that are already done
+        active_mask = (~done).unsqueeze(1)
+        masked_values = (values * active_mask.to(values.dtype)).to(acts.dtype)
+
+        acts.scatter_add_(1, indices, masked_values)
+
+        # Update residual
+        residual = residual - masked_values * selected_dec
+
+        if stop_on_duplicate_support or stop_on_residual_threshold:
+            with torch.no_grad():
+                support = acts != 0
+
+                # A sample is considered converged if:
+                # (1) the support set hasn't changed from the previous iteration (stability), or
+                # (2) the residual norm is below a given threshold (good enough reconstruction)
+                if stop_on_duplicate_support:
+                    done = done | (support == prev_support).all(dim=1)
+                    prev_support = support
+                if stop_on_residual_threshold:
+                    done = done | (residual.norm(dim=-1) < residual_threshold)
+
+            if done.all():
+                break
+
+    # Reshape acts back to original shape (replacing last dimension with d_sae)
+    if len(original_shape) > 2:
+        acts = acts.reshape(*original_shape[:-1], acts.shape[-1])
+
+    return acts
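A minimal usage sketch of the new inference SAE defined above. The dimensions and config values are illustrative placeholders, not values taken from any released SAE:

    import torch

    from sae_lens import MatchingPursuitSAE, MatchingPursuitSAEConfig

    cfg = MatchingPursuitSAEConfig(
        d_in=768,                        # placeholder activation width
        d_sae=16384,                     # placeholder number of latents
        residual_threshold=1e-2,         # stop once the residual norm drops below this
        max_iterations=None,             # None falls back to d_in iterations
        stop_on_duplicate_support=True,  # also stop once the support set stabilizes
    )
    sae = MatchingPursuitSAE(cfg)

    x = torch.randn(8, cfg.d_in)
    feature_acts = sae.encode(x)       # greedy, iterative selection of latents
    recon = sae.decode(feature_acts)   # tied decode: feature_acts @ W_dec (+ b_dec)
    print(feature_acts.shape, recon.shape)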
--- sae_lens-6.25.0/sae_lens/saes/sae.py
+++ sae_lens-6.26.0/sae_lens/saes/sae.py
@@ -27,11 +27,10 @@ from typing_extensions import deprecated, overload, override

 from sae_lens import __version__
 from sae_lens.constants import (
-    DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
 )
-from sae_lens.util import filter_valid_dataclass_fields
+from sae_lens.util import dtype_to_str, filter_valid_dataclass_fields, str_to_dtype

 if TYPE_CHECKING:
     from sae_lens.config import LanguageModelSAERunnerConfig
@@ -253,7 +252,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
                 stacklevel=1,
             )

-        self.dtype = DTYPE_MAP[cfg.dtype]
+        self.dtype = str_to_dtype(cfg.dtype)
         self.device = torch.device(cfg.device)
         self.use_error_term = use_error_term

@@ -437,8 +436,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

         # Update dtype in config if provided
         if dtype_arg is not None:
-            # Update the cfg.dtype
-            self.cfg.dtype = str(dtype_arg)
+            # Update the cfg.dtype (use canonical short form like "float32")
+            self.cfg.dtype = dtype_to_str(dtype_arg)

         # Update the dtype property
         self.dtype = dtype_arg
@@ -534,6 +533,15 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         dtype: str | None = None,
         converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
     ) -> T_SAE:
+        """
+        Load a SAE from disk.
+
+        Args:
+            path: The path to the SAE weights and config.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE on, defaults to None. If None, the dtype will be inferred from the SAE config.
+            converter: The converter to use to load the SAE, defaults to sae_lens_disk_loader.
+        """
         overrides = {"dtype": dtype} if dtype is not None else None
         cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
         cfg_dict = handle_config_defaulting(cfg_dict)
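A minimal usage sketch of load_from_disk with the arguments documented in the new docstring (the path is a placeholder):

    from sae_lens import SAE

    sae = SAE.load_from_disk(
        "checkpoints/my_sae",   # placeholder directory containing the saved weights and config
        device="cuda",
        dtype="bfloat16",       # optional; overrides the dtype stored in the config
    )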
@@ -542,10 +550,17 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         )
         sae_cfg = sae_config_cls.from_dict(cfg_dict)
         sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
+        # hack to avoid using double memory when loading the SAE.
+        # first put the SAE on the meta device, then load the weights.
+        device = sae_cfg.device
+        sae_cfg.device = "meta"
         sae = sae_cls(sae_cfg)
+        sae.cfg.device = device
         sae.process_state_dict_for_loading(state_dict)
-        sae.load_state_dict(state_dict)
-        return sae
+        sae.load_state_dict(state_dict, assign=True)
+        # the loaders should already handle the dtype / device conversion
+        # but this is a fallback to guarantee the SAE is on the correct device and dtype
+        return sae.to(dtype=str_to_dtype(sae_cfg.dtype), device=device)

     @classmethod
     def from_pretrained(
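The meta-device pattern used above is plain PyTorch rather than anything SAELens-specific. A self-contained sketch of the same trick (assumes a PyTorch version that supports load_state_dict(..., assign=True)):

    import torch
    import torch.nn as nn

    state_dict = {"weight": torch.randn(16, 8), "bias": torch.zeros(16)}

    with torch.device("meta"):
        layer = nn.Linear(8, 16)        # parameters are meta tensors: no real storage allocated
    layer.load_state_dict(state_dict, assign=True)  # adopt the loaded tensors instead of copying
    print(layer.weight.device, layer.weight.shape)

Without the meta device, the module would first allocate its own parameters and then copy the loaded tensors into them, briefly holding two full copies of the weights in memory.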
@@ -553,6 +568,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         release: str,
         sae_id: str,
         device: str = "cpu",
+        dtype: str = "float32",
         force_download: bool = False,
         converter: PretrainedSaeHuggingfaceLoader | None = None,
     ) -> T_SAE:
@@ -562,10 +578,18 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

             release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
             id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
-            device: The device to load the SAE on.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE on, defaults to "float32".
+            force_download: Whether to force download the SAE weights and config, defaults to False.
+            converter: The converter to use to load the SAE, defaults to None. If None, the converter will be inferred from the release.
         """
         return cls.from_pretrained_with_cfg_and_sparsity(
-            release, sae_id, device, force_download, converter=converter
+            release,
+            sae_id,
+            device,
+            force_download=force_download,
+            dtype=dtype,
+            converter=converter,
         )[0]

     @classmethod
@@ -574,6 +598,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         release: str,
         sae_id: str,
         device: str = "cpu",
+        dtype: str = "float32",
         force_download: bool = False,
         converter: PretrainedSaeHuggingfaceLoader | None = None,
     ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
@@ -584,7 +609,10 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         Args:
             release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
             id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
-            device: The device to load the SAE on.
+            device: The device to load the SAE on, defaults to "cpu".
+            dtype: The dtype to load the SAE on, defaults to "float32".
+            force_download: Whether to force download the SAE weights and config, defaults to False.
+            converter: The converter to use to load the SAE, defaults to None. If None, the converter will be inferred from the release.
         """

         # get sae directory
@@ -634,6 +662,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
         config_overrides = get_config_overrides(release, sae_id)
         config_overrides["device"] = device
+        config_overrides["dtype"] = dtype

         # Load config and weights
         cfg_dict, state_dict, log_sparsities = conversion_loader(
@@ -651,9 +680,14 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         )
         sae_cfg = sae_config_cls.from_dict(cfg_dict)
         sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
+        # hack to avoid using double memory when loading the SAE.
+        # first put the SAE on the meta device, then load the weights.
+        device = sae_cfg.device
+        sae_cfg.device = "meta"
         sae = sae_cls(sae_cfg)
+        sae.cfg.device = device
         sae.process_state_dict_for_loading(state_dict)
-        sae.load_state_dict(state_dict)
+        sae.load_state_dict(state_dict, assign=True)

         # Apply normalization if needed
         if cfg_dict.get("normalize_activations") == "expected_average_only_in":
@@ -666,7 +700,13 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
                 f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
             )

-        return sae, cfg_dict, log_sparsities
+        # the loaders should already handle the dtype / device conversion
+        # but this is a fallback to guarantee the SAE is on the correct device and dtype
+        return (
+            sae.to(dtype=str_to_dtype(dtype), device=device),
+            cfg_dict,
+            log_sparsities,
+        )

     @classmethod
     def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
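A minimal usage sketch of the new dtype argument on from_pretrained. The release and sae_id strings are placeholders; real values come from pretrained_saes.yaml:

    from sae_lens import SAE

    sae = SAE.from_pretrained(
        release="some-release-name",          # placeholder
        sae_id="blocks.0.hook_resid_pre",     # placeholder
        device="cuda",
        dtype="bfloat16",   # weights are converted as they are loaded
    )
    print(sae.dtype)        # torch.bfloat16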
--- sae_lens-6.25.0/sae_lens/training/activations_store.py
+++ sae_lens-6.26.0/sae_lens/training/activations_store.py
@@ -24,7 +24,7 @@ from sae_lens.config import (
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, DTYPE_MAP
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME
 from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
@@ -32,6 +32,7 @@ from sae_lens.training.mixing_buffer import mixing_buffer
 from sae_lens.util import (
     extract_stop_at_layer_from_tlens_hook_name,
     get_special_token_ids,
+    str_to_dtype,
 )


@@ -258,7 +259,7 @@ class ActivationsStore:
         self.prepend_bos = prepend_bos
         self.normalize_activations = normalize_activations
         self.device = torch.device(device)
-        self.dtype = DTYPE_MAP[dtype]
+        self.dtype = str_to_dtype(dtype)
         self.cached_activations_path = cached_activations_path
         self.autocast_lm = autocast_lm
         self.seqpos_slice = seqpos_slice
--- sae_lens-6.25.0/sae_lens/util.py
+++ sae_lens-6.26.0/sae_lens/util.py
@@ -5,8 +5,11 @@ from dataclasses import asdict, fields, is_dataclass
 from pathlib import Path
 from typing import Sequence, TypeVar

+import torch
 from transformers import PreTrainedTokenizerBase

+from sae_lens.constants import DTYPE_MAP, DTYPE_TO_STR
+
 K = TypeVar("K")
 V = TypeVar("V")

@@ -90,3 +93,21 @@ def get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
             special_tokens.add(token_id)

     return list(special_tokens)
+
+
+def str_to_dtype(dtype: str) -> torch.dtype:
+    """Convert a string to a torch.dtype."""
+    if dtype not in DTYPE_MAP:
+        raise ValueError(
+            f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_MAP.keys())}"
+        )
+    return DTYPE_MAP[dtype]
+
+
+def dtype_to_str(dtype: torch.dtype) -> str:
+    """Convert a torch.dtype to a string."""
+    if dtype not in DTYPE_TO_STR:
+        raise ValueError(
+            f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_TO_STR.keys())}"
+        )
+    return DTYPE_TO_STR[dtype]
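A minimal sketch of the round-trip behavior of the two helpers, based on the DTYPE_MAP and DTYPE_TO_STR tables in constants.py above:

    import torch

    from sae_lens.util import dtype_to_str, str_to_dtype

    assert str_to_dtype("torch.bfloat16") is torch.bfloat16   # torch-prefixed key shown in DTYPE_MAP
    assert dtype_to_str(torch.bfloat16) == "bfloat16"         # canonical short form
    assert dtype_to_str(str_to_dtype("torch.bfloat16")) == "bfloat16"

    # unsupported strings raise ValueError rather than a bare KeyError
    try:
        str_to_dtype("not_a_dtype")
    except ValueError as e:
        print(e)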