sae-lens 6.14.1__tar.gz → 6.15.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {sae_lens-6.14.1 → sae_lens-6.15.0}/PKG-INFO +1 -1
  2. {sae_lens-6.14.1 → sae_lens-6.15.0}/pyproject.toml +1 -1
  3. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/__init__.py +10 -1
  4. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/evals.py +18 -14
  5. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/llm_sae_training_runner.py +8 -14
  6. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/saes/__init__.py +6 -0
  7. sae_lens-6.15.0/sae_lens/saes/matryoshka_batchtopk_sae.py +143 -0
  8. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/training/activations_store.py +5 -27
  9. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/util.py +27 -0
  10. {sae_lens-6.14.1 → sae_lens-6.15.0}/LICENSE +0 -0
  11. {sae_lens-6.14.1 → sae_lens-6.15.0}/README.md +0 -0
  12. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/analysis/__init__.py +0 -0
  13. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  14. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  15. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/cache_activations_runner.py +0 -0
  16. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/config.py +0 -0
  17. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/constants.py +0 -0
  18. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/load_model.py +0 -0
  19. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/loading/__init__.py +0 -0
  20. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
  21. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  22. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/pretokenize_runner.py +0 -0
  23. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/pretrained_saes.yaml +0 -0
  24. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/registry.py +0 -0
  25. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  26. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/saes/gated_sae.py +0 -0
  27. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/saes/jumprelu_sae.py +0 -0
  28. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/saes/sae.py +0 -0
  29. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/saes/standard_sae.py +0 -0
  30. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/saes/topk_sae.py +0 -0
  31. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/saes/transcoder.py +0 -0
  32. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/tokenization_and_batching.py +0 -0
  33. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/training/__init__.py +0 -0
  34. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/training/activation_scaler.py +0 -0
  35. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/training/mixing_buffer.py +0 -0
  36. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/training/optim.py +0 -0
  37. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/training/sae_trainer.py +0 -0
  38. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/training/types.py +0 -0
  39. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  40. {sae_lens-6.14.1 → sae_lens-6.15.0}/sae_lens/tutorial/tsea.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.14.1
3
+ Version: 6.15.0
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "sae-lens"
3
- version = "6.14.1"
3
+ version = "6.15.0"
4
4
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
5
5
  authors = ["Joseph Bloom"]
6
6
  readme = "README.md"
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.14.1"
2
+ __version__ = "6.15.0"
3
3
 
4
4
  import logging
5
5
 
@@ -19,6 +19,8 @@ from sae_lens.saes import (
19
19
  JumpReLUTrainingSAEConfig,
20
20
  JumpReLUTranscoder,
21
21
  JumpReLUTranscoderConfig,
22
+ MatryoshkaBatchTopKTrainingSAE,
23
+ MatryoshkaBatchTopKTrainingSAEConfig,
22
24
  SAEConfig,
23
25
  SkipTranscoder,
24
26
  SkipTranscoderConfig,
@@ -101,6 +103,8 @@ __all__ = [
101
103
  "SkipTranscoderConfig",
102
104
  "JumpReLUTranscoder",
103
105
  "JumpReLUTranscoderConfig",
106
+ "MatryoshkaBatchTopKTrainingSAE",
107
+ "MatryoshkaBatchTopKTrainingSAEConfig",
104
108
  ]
105
109
 
106
110
 
@@ -115,6 +119,11 @@ register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAE
115
119
  register_sae_training_class(
116
120
  "batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
117
121
  )
122
+ register_sae_training_class(
123
+ "matryoshka_batchtopk",
124
+ MatryoshkaBatchTopKTrainingSAE,
125
+ MatryoshkaBatchTopKTrainingSAEConfig,
126
+ )
118
127
  register_sae_class("transcoder", Transcoder, TranscoderConfig)
119
128
  register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
120
129
  register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
@@ -11,7 +11,7 @@ from dataclasses import dataclass, field
11
11
  from functools import partial
12
12
  from importlib.metadata import PackageNotFoundError, version
13
13
  from pathlib import Path
14
- from typing import Any
14
+ from typing import Any, Iterable
15
15
 
16
16
  import einops
17
17
  import pandas as pd
@@ -24,7 +24,10 @@ from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_direc
24
24
  from sae_lens.saes.sae import SAE, SAEConfig
25
25
  from sae_lens.training.activation_scaler import ActivationScaler
26
26
  from sae_lens.training.activations_store import ActivationsStore
27
- from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
27
+ from sae_lens.util import (
28
+ extract_stop_at_layer_from_tlens_hook_name,
29
+ get_special_token_ids,
30
+ )
28
31
 
29
32
 
30
33
  def get_library_version() -> str:
@@ -109,9 +112,15 @@ def run_evals(
109
112
  activation_scaler: ActivationScaler,
110
113
  eval_config: EvalConfig = EvalConfig(),
111
114
  model_kwargs: Mapping[str, Any] = {},
112
- ignore_tokens: set[int | None] = set(),
115
+ exclude_special_tokens: Iterable[int] | bool = True,
113
116
  verbose: bool = False,
114
117
  ) -> tuple[dict[str, Any], dict[str, Any]]:
118
+ ignore_tokens = None
119
+ if exclude_special_tokens is True:
120
+ ignore_tokens = list(get_special_token_ids(model.tokenizer)) # type: ignore
121
+ elif exclude_special_tokens:
122
+ ignore_tokens = list(exclude_special_tokens)
123
+
115
124
  hook_name = sae.cfg.metadata.hook_name
116
125
  actual_batch_size = (
117
126
  eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
@@ -312,7 +321,7 @@ def get_downstream_reconstruction_metrics(
312
321
  compute_ce_loss: bool,
313
322
  n_batches: int,
314
323
  eval_batch_size_prompts: int,
315
- ignore_tokens: set[int | None] = set(),
324
+ ignore_tokens: list[int] | None = None,
316
325
  verbose: bool = False,
317
326
  ):
318
327
  metrics_dict = {}
@@ -339,7 +348,7 @@ def get_downstream_reconstruction_metrics(
339
348
  compute_ce_loss=compute_ce_loss,
340
349
  ignore_tokens=ignore_tokens,
341
350
  ).items():
342
- if len(ignore_tokens) > 0:
351
+ if ignore_tokens:
343
352
  mask = torch.logical_not(
344
353
  torch.any(
345
354
  torch.stack(
@@ -384,7 +393,7 @@ def get_sparsity_and_variance_metrics(
384
393
  compute_featurewise_density_statistics: bool,
385
394
  eval_batch_size_prompts: int,
386
395
  model_kwargs: Mapping[str, Any],
387
- ignore_tokens: set[int | None] = set(),
396
+ ignore_tokens: list[int] | None = None,
388
397
  verbose: bool = False,
389
398
  ) -> tuple[dict[str, Any], dict[str, Any]]:
390
399
  hook_name = sae.cfg.metadata.hook_name
@@ -426,7 +435,7 @@ def get_sparsity_and_variance_metrics(
426
435
  for _ in batch_iter:
427
436
  batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
428
437
 
429
- if len(ignore_tokens) > 0:
438
+ if ignore_tokens:
430
439
  mask = torch.logical_not(
431
440
  torch.any(
432
441
  torch.stack(
@@ -596,7 +605,7 @@ def get_recons_loss(
596
605
  batch_tokens: torch.Tensor,
597
606
  compute_kl: bool,
598
607
  compute_ce_loss: bool,
599
- ignore_tokens: set[int | None] = set(),
608
+ ignore_tokens: list[int] | None = None,
600
609
  model_kwargs: Mapping[str, Any] = {},
601
610
  hook_name: str | None = None,
602
611
  ) -> dict[str, Any]:
@@ -610,7 +619,7 @@ def get_recons_loss(
610
619
  batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
611
620
  )
612
621
 
613
- if len(ignore_tokens) > 0:
622
+ if ignore_tokens:
614
623
  mask = torch.logical_not(
615
624
  torch.any(
616
625
  torch.stack([batch_tokens == token for token in ignore_tokens], dim=0),
@@ -856,11 +865,6 @@ def multiple_evals(
856
865
  activation_scaler=ActivationScaler(),
857
866
  model=current_model,
858
867
  eval_config=eval_config,
859
- ignore_tokens={
860
- current_model.tokenizer.pad_token_id, # type: ignore
861
- current_model.tokenizer.eos_token_id, # type: ignore
862
- current_model.tokenizer.bos_token_id, # type: ignore
863
- },
864
868
  verbose=verbose,
865
869
  )
866
870
  eval_metrics["metrics"] = scalar_metrics
@@ -22,17 +22,13 @@ from sae_lens.constants import (
22
22
  )
23
23
  from sae_lens.evals import EvalConfig, run_evals
24
24
  from sae_lens.load_model import load_model
25
- from sae_lens.saes.batchtopk_sae import BatchTopKTrainingSAEConfig
26
- from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
27
- from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
25
+ from sae_lens.registry import SAE_TRAINING_CLASS_REGISTRY
28
26
  from sae_lens.saes.sae import (
29
27
  T_TRAINING_SAE,
30
28
  T_TRAINING_SAE_CONFIG,
31
29
  TrainingSAE,
32
30
  TrainingSAEConfig,
33
31
  )
34
- from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
35
- from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
36
32
  from sae_lens.training.activation_scaler import ActivationScaler
37
33
  from sae_lens.training.activations_store import ActivationsStore
38
34
  from sae_lens.training.sae_trainer import SAETrainer
@@ -61,9 +57,11 @@ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
61
57
  data_provider: DataProvider,
62
58
  activation_scaler: ActivationScaler,
63
59
  ) -> dict[str, Any]:
64
- ignore_tokens = set()
60
+ exclude_special_tokens = False
65
61
  if self.activations_store.exclude_special_tokens is not None:
66
- ignore_tokens = set(self.activations_store.exclude_special_tokens.tolist())
62
+ exclude_special_tokens = (
63
+ self.activations_store.exclude_special_tokens.tolist()
64
+ )
67
65
 
68
66
  eval_config = EvalConfig(
69
67
  batch_size_prompts=self.eval_batch_size_prompts,
@@ -81,7 +79,7 @@ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
81
79
  model=self.model,
82
80
  activation_scaler=activation_scaler,
83
81
  eval_config=eval_config,
84
- ignore_tokens=ignore_tokens,
82
+ exclude_special_tokens=exclude_special_tokens,
85
83
  model_kwargs=self.model_kwargs,
86
84
  ) # not calculating featurwise metrics here.
87
85
 
@@ -393,12 +391,8 @@ def _parse_cfg_args(
393
391
  )
394
392
 
395
393
  # Map architecture to concrete config class
396
- sae_config_map = {
397
- "standard": StandardTrainingSAEConfig,
398
- "gated": GatedTrainingSAEConfig,
399
- "jumprelu": JumpReLUTrainingSAEConfig,
400
- "topk": TopKTrainingSAEConfig,
401
- "batchtopk": BatchTopKTrainingSAEConfig,
394
+ sae_config_map: dict[str, type[TrainingSAEConfig]] = {
395
+ name: cfg for name, (_, cfg) in SAE_TRAINING_CLASS_REGISTRY.items()
402
396
  }
403
397
 
404
398
  sae_config_type = sae_config_map[architecture]
@@ -14,6 +14,10 @@ from .jumprelu_sae import (
14
14
  JumpReLUTrainingSAE,
15
15
  JumpReLUTrainingSAEConfig,
16
16
  )
17
+ from .matryoshka_batchtopk_sae import (
18
+ MatryoshkaBatchTopKTrainingSAE,
19
+ MatryoshkaBatchTopKTrainingSAEConfig,
20
+ )
17
21
  from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
18
22
  from .standard_sae import (
19
23
  StandardSAE,
@@ -65,4 +69,6 @@ __all__ = [
65
69
  "SkipTranscoderConfig",
66
70
  "JumpReLUTranscoder",
67
71
  "JumpReLUTranscoderConfig",
72
+ "MatryoshkaBatchTopKTrainingSAE",
73
+ "MatryoshkaBatchTopKTrainingSAEConfig",
68
74
  ]
@@ -0,0 +1,143 @@
1
+ import warnings
2
+ from dataclasses import dataclass, field
3
+
4
+ import torch
5
+ from jaxtyping import Float
6
+ from typing_extensions import override
7
+
8
+ from sae_lens.saes.batchtopk_sae import (
9
+ BatchTopKTrainingSAE,
10
+ BatchTopKTrainingSAEConfig,
11
+ )
12
+ from sae_lens.saes.sae import TrainStepInput, TrainStepOutput
13
+ from sae_lens.saes.topk_sae import _sparse_matmul_nd
14
+
15
+
16
+ @dataclass
17
+ class MatryoshkaBatchTopKTrainingSAEConfig(BatchTopKTrainingSAEConfig):
18
+ """
19
+ Configuration class for training a MatryoshkaBatchTopKTrainingSAE.
20
+
21
+ [Matryoshka SAEs](https://arxiv.org/pdf/2503.17547) use a series of nested reconstruction
22
+ losses of different widths during training to avoid feature absorption. This also has a
23
+ nice side-effect of encouraging higher-frequency features to be learned in earlier levels.
24
+ However, this SAE has more hyperparameters to tune than standard BatchTopK SAEs, and takes
25
+ longer to train due to requiring multiple forward passes per training step.
26
+
27
+ After training, MatryoshkaBatchTopK SAEs are saved as JumpReLU SAEs.
28
+
29
+ Args:
30
+ matryoshka_widths (list[int]): The widths of the matryoshka levels. Defaults to an empty list.
31
+ k (float): The number of features to keep active. Inherited from BatchTopKTrainingSAEConfig.
32
+ Defaults to 100.
33
+ topk_threshold_lr (float): Learning rate for updating the global topk threshold.
34
+ The threshold is updated using an exponential moving average of the minimum
35
+ positive activation value. Defaults to 0.01.
36
+ aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
37
+ dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
38
+ Defaults to 1.0.
39
+ rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
40
+ Inherited from TopKTrainingSAEConfig. Defaults to True.
41
+ decoder_init_norm (float | None): Norm to initialize decoder weights to.
42
+ Inherited from TrainingSAEConfig. Defaults to 0.1.
43
+ d_in (int): Input dimension (dimensionality of the activations being encoded).
44
+ Inherited from SAEConfig.
45
+ d_sae (int): SAE latent dimension (number of features in the SAE).
46
+ Inherited from SAEConfig.
47
+ dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
48
+ Defaults to "float32".
49
+ device (str): Device to place the SAE on. Inherited from SAEConfig.
50
+ Defaults to "cpu".
51
+ """
52
+
53
+ matryoshka_widths: list[int] = field(default_factory=list)
54
+
55
+ @override
56
+ @classmethod
57
+ def architecture(cls) -> str:
58
+ return "matryoshka_batchtopk"
59
+
60
+
61
+ class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
62
+ """
63
+ Global Batch TopK Training SAE
64
+
65
+ This SAE will maintain the k on average across the batch, rather than enforcing the k per-sample as in standard TopK.
66
+
67
+ BatchTopK SAEs are saved as JumpReLU SAEs after training.
68
+ """
69
+
70
+ cfg: MatryoshkaBatchTopKTrainingSAEConfig # type: ignore[assignment]
71
+
72
+ def __init__(
73
+ self, cfg: MatryoshkaBatchTopKTrainingSAEConfig, use_error_term: bool = False
74
+ ):
75
+ super().__init__(cfg, use_error_term)
76
+ _validate_matryoshka_config(cfg)
77
+
78
+ @override
79
+ def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
80
+ base_output = super().training_forward_pass(step_input)
81
+ hidden_pre = base_output.hidden_pre
82
+ inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
83
+ # the outer matryoshka level is the base SAE, so we don't need to add an extra loss for it
84
+ for width in self.cfg.matryoshka_widths[:-1]:
85
+ inner_hidden_pre = hidden_pre[:, :width]
86
+ inner_feat_acts = self.activation_fn(inner_hidden_pre)
87
+ inner_reconstruction = self._decode_matryoshka_level(
88
+ inner_feat_acts, width, inv_W_dec_norm
89
+ )
90
+ inner_mse_loss = (
91
+ self.mse_loss_fn(inner_reconstruction, step_input.sae_in)
92
+ .sum(dim=-1)
93
+ .mean()
94
+ )
95
+ base_output.losses[f"inner_mse_loss_{width}"] = inner_mse_loss
96
+ base_output.loss = base_output.loss + inner_mse_loss
97
+ return base_output
98
+
99
+ def _decode_matryoshka_level(
100
+ self,
101
+ feature_acts: Float[torch.Tensor, "... d_sae"],
102
+ width: int,
103
+ inv_W_dec_norm: torch.Tensor,
104
+ ) -> Float[torch.Tensor, "... d_in"]:
105
+ """
106
+ Decodes feature activations back into input space for a matryoshka level
107
+ """
108
+ # Handle sparse tensors using efficient sparse matrix multiplication
109
+ if self.cfg.rescale_acts_by_decoder_norm:
110
+ # need to multiply by the inverse of the norm because division is illegal with sparse tensors
111
+ feature_acts = feature_acts * inv_W_dec_norm[:width]
112
+ if feature_acts.is_sparse:
113
+ sae_out_pre = (
114
+ _sparse_matmul_nd(feature_acts, self.W_dec[:width]) + self.b_dec
115
+ )
116
+ else:
117
+ sae_out_pre = feature_acts @ self.W_dec[:width] + self.b_dec
118
+ sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
119
+ return self.reshape_fn_out(sae_out_pre, self.d_head)
120
+
121
+
122
+ def _validate_matryoshka_config(cfg: MatryoshkaBatchTopKTrainingSAEConfig) -> None:
123
+ if cfg.matryoshka_widths[-1] != cfg.d_sae:
124
+ # warn the users that we will add a final matryoshka level
125
+ warnings.warn(
126
+ "WARNING: The final matryoshka level width is not set to cfg.d_sae. "
127
+ "A final matryoshka level of width=cfg.d_sae will be added."
128
+ )
129
+ cfg.matryoshka_widths.append(cfg.d_sae)
130
+
131
+ for prev_width, curr_width in zip(
132
+ cfg.matryoshka_widths[:-1], cfg.matryoshka_widths[1:]
133
+ ):
134
+ if prev_width >= curr_width:
135
+ raise ValueError("cfg.matryoshka_widths must be strictly increasing.")
136
+ if len(cfg.matryoshka_widths) == 1:
137
+ warnings.warn(
138
+ "WARNING: You have only set one matryoshka level. This is equivalent to using a standard BatchTopK SAE and is likely not what you want."
139
+ )
140
+ if cfg.matryoshka_widths[0] < cfg.k:
141
+ raise ValueError(
142
+ "The smallest matryoshka level width cannot be smaller than cfg.k."
143
+ )
@@ -29,7 +29,10 @@ from sae_lens.pretokenize_runner import get_special_token_from_cfg
29
29
  from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
30
30
  from sae_lens.tokenization_and_batching import concat_and_batch_sequences
31
31
  from sae_lens.training.mixing_buffer import mixing_buffer
32
- from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
32
+ from sae_lens.util import (
33
+ extract_stop_at_layer_from_tlens_hook_name,
34
+ get_special_token_ids,
35
+ )
33
36
 
34
37
 
35
38
  # TODO: Make an activation store config class to be consistent with the rest of the code.
@@ -113,7 +116,7 @@ class ActivationsStore:
113
116
  if exclude_special_tokens is False:
114
117
  exclude_special_tokens = None
115
118
  if exclude_special_tokens is True:
116
- exclude_special_tokens = _get_special_token_ids(model.tokenizer) # type: ignore
119
+ exclude_special_tokens = get_special_token_ids(model.tokenizer) # type: ignore
117
120
  if exclude_special_tokens is not None:
118
121
  exclude_special_tokens = torch.tensor(
119
122
  exclude_special_tokens, dtype=torch.long, device=device
@@ -763,31 +766,6 @@ def _get_model_device(model: HookedRootModule) -> torch.device:
763
766
  return next(model.parameters()).device # type: ignore
764
767
 
765
768
 
766
- def _get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
767
- """Get all special token IDs from a tokenizer."""
768
- special_tokens = set()
769
-
770
- # Get special tokens from tokenizer attributes
771
- for attr in dir(tokenizer):
772
- if attr.endswith("_token_id"):
773
- token_id = getattr(tokenizer, attr)
774
- if token_id is not None:
775
- special_tokens.add(token_id)
776
-
777
- # Get any additional special tokens from the tokenizer's special tokens map
778
- if hasattr(tokenizer, "special_tokens_map"):
779
- for token in tokenizer.special_tokens_map.values():
780
- if isinstance(token, str):
781
- token_id = tokenizer.convert_tokens_to_ids(token) # type: ignore
782
- special_tokens.add(token_id)
783
- elif isinstance(token, list):
784
- for t in token:
785
- token_id = tokenizer.convert_tokens_to_ids(t) # type: ignore
786
- special_tokens.add(token_id)
787
-
788
- return list(special_tokens)
789
-
790
-
791
769
  def _filter_buffer_acts(
792
770
  buffer: tuple[torch.Tensor, torch.Tensor | None],
793
771
  exclude_tokens: torch.Tensor | None,
@@ -5,6 +5,8 @@ from dataclasses import asdict, fields, is_dataclass
5
5
  from pathlib import Path
6
6
  from typing import Sequence, TypeVar
7
7
 
8
+ from transformers import PreTrainedTokenizerBase
9
+
8
10
  K = TypeVar("K")
9
11
  V = TypeVar("V")
10
12
 
@@ -63,3 +65,28 @@ def path_or_tmp_dir(path: str | Path | None):
63
65
  yield Path(td)
64
66
  else:
65
67
  yield Path(path)
68
+
69
+
70
+ def get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
71
+ """Get all special token IDs from a tokenizer."""
72
+ special_tokens = set()
73
+
74
+ # Get special tokens from tokenizer attributes
75
+ for attr in dir(tokenizer):
76
+ if attr.endswith("_token_id"):
77
+ token_id = getattr(tokenizer, attr)
78
+ if token_id is not None:
79
+ special_tokens.add(token_id)
80
+
81
+ # Get any additional special tokens from the tokenizer's special tokens map
82
+ if hasattr(tokenizer, "special_tokens_map"):
83
+ for token in tokenizer.special_tokens_map.values():
84
+ if isinstance(token, str):
85
+ token_id = tokenizer.convert_tokens_to_ids(token) # type: ignore
86
+ special_tokens.add(token_id)
87
+ elif isinstance(token, list):
88
+ for t in token:
89
+ token_id = tokenizer.convert_tokens_to_ids(t) # type: ignore
90
+ special_tokens.add(token_id)
91
+
92
+ return list(special_tokens)
File without changes
File without changes
File without changes