sae-lens 6.14.1__py3-none-any.whl → 6.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +10 -1
- sae_lens/evals.py +18 -14
- sae_lens/llm_sae_training_runner.py +8 -14
- sae_lens/saes/__init__.py +6 -0
- sae_lens/saes/matryoshka_batchtopk_sae.py +143 -0
- sae_lens/training/activations_store.py +5 -27
- sae_lens/util.py +27 -0
- {sae_lens-6.14.1.dist-info → sae_lens-6.15.0.dist-info}/METADATA +1 -1
- {sae_lens-6.14.1.dist-info → sae_lens-6.15.0.dist-info}/RECORD +11 -10
- {sae_lens-6.14.1.dist-info → sae_lens-6.15.0.dist-info}/WHEEL +0 -0
- {sae_lens-6.14.1.dist-info → sae_lens-6.15.0.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED

@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.14.1"
+__version__ = "6.15.0"
 
 import logging
 
@@ -19,6 +19,8 @@ from sae_lens.saes import (
     JumpReLUTrainingSAEConfig,
     JumpReLUTranscoder,
     JumpReLUTranscoderConfig,
+    MatryoshkaBatchTopKTrainingSAE,
+    MatryoshkaBatchTopKTrainingSAEConfig,
     SAEConfig,
     SkipTranscoder,
     SkipTranscoderConfig,
@@ -101,6 +103,8 @@ __all__ = [
     "SkipTranscoderConfig",
     "JumpReLUTranscoder",
     "JumpReLUTranscoderConfig",
+    "MatryoshkaBatchTopKTrainingSAE",
+    "MatryoshkaBatchTopKTrainingSAEConfig",
 ]
 
 
@@ -115,6 +119,11 @@ register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAEConfig)
 register_sae_training_class(
     "batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
 )
+register_sae_training_class(
+    "matryoshka_batchtopk",
+    MatryoshkaBatchTopKTrainingSAE,
+    MatryoshkaBatchTopKTrainingSAEConfig,
+)
 register_sae_class("transcoder", Transcoder, TranscoderConfig)
 register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
 register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
sae_lens/evals.py
CHANGED

@@ -11,7 +11,7 @@ from dataclasses import dataclass, field
 from functools import partial
 from importlib.metadata import PackageNotFoundError, version
 from pathlib import Path
-from typing import Any
+from typing import Any, Iterable
 
 import einops
 import pandas as pd
@@ -24,7 +24,10 @@ from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
 from sae_lens.saes.sae import SAE, SAEConfig
 from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
-from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
+from sae_lens.util import (
+    extract_stop_at_layer_from_tlens_hook_name,
+    get_special_token_ids,
+)
 
 
 def get_library_version() -> str:
@@ -109,9 +112,15 @@ def run_evals(
     activation_scaler: ActivationScaler,
     eval_config: EvalConfig = EvalConfig(),
     model_kwargs: Mapping[str, Any] = {},
-
+    exclude_special_tokens: Iterable[int] | bool = True,
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
+    ignore_tokens = None
+    if exclude_special_tokens is True:
+        ignore_tokens = list(get_special_token_ids(model.tokenizer))  # type: ignore
+    elif exclude_special_tokens:
+        ignore_tokens = list(exclude_special_tokens)
+
     hook_name = sae.cfg.metadata.hook_name
     actual_batch_size = (
         eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
@@ -312,7 +321,7 @@ def get_downstream_reconstruction_metrics(
     compute_ce_loss: bool,
     n_batches: int,
     eval_batch_size_prompts: int,
-    ignore_tokens:
+    ignore_tokens: list[int] | None = None,
     verbose: bool = False,
 ):
     metrics_dict = {}
@@ -339,7 +348,7 @@ def get_downstream_reconstruction_metrics(
             compute_ce_loss=compute_ce_loss,
             ignore_tokens=ignore_tokens,
         ).items():
-            if
+            if ignore_tokens:
                 mask = torch.logical_not(
                     torch.any(
                         torch.stack(
@@ -384,7 +393,7 @@ def get_sparsity_and_variance_metrics(
     compute_featurewise_density_statistics: bool,
     eval_batch_size_prompts: int,
     model_kwargs: Mapping[str, Any],
-    ignore_tokens:
+    ignore_tokens: list[int] | None = None,
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     hook_name = sae.cfg.metadata.hook_name
@@ -426,7 +435,7 @@ def get_sparsity_and_variance_metrics(
     for _ in batch_iter:
         batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
 
-        if
+        if ignore_tokens:
            mask = torch.logical_not(
                torch.any(
                    torch.stack(
@@ -596,7 +605,7 @@ def get_recons_loss(
     batch_tokens: torch.Tensor,
     compute_kl: bool,
     compute_ce_loss: bool,
-    ignore_tokens:
+    ignore_tokens: list[int] | None = None,
     model_kwargs: Mapping[str, Any] = {},
     hook_name: str | None = None,
 ) -> dict[str, Any]:
@@ -610,7 +619,7 @@ def get_recons_loss(
         batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
     )
 
-    if
+    if ignore_tokens:
         mask = torch.logical_not(
             torch.any(
                 torch.stack([batch_tokens == token for token in ignore_tokens], dim=0),
@@ -856,11 +865,6 @@ def multiple_evals(
             activation_scaler=ActivationScaler(),
             model=current_model,
             eval_config=eval_config,
-            ignore_tokens={
-                current_model.tokenizer.pad_token_id,  # type: ignore
-                current_model.tokenizer.eos_token_id,  # type: ignore
-                current_model.tokenizer.bos_token_id,  # type: ignore
-            },
             verbose=verbose,
         )
         eval_metrics["metrics"] = scalar_metrics
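
The run_evals change replaces the old ignore_tokens parameter with exclude_special_tokens, which may be True (derive ids from the model's tokenizer via get_special_token_ids), an explicit iterable of token ids, or False. A standalone sketch of that resolution, mirroring only the branching shown in the hunk above (the helper name and its second argument are hypothetical):

from typing import Iterable


def resolve_ignore_tokens(
    exclude_special_tokens: Iterable[int] | bool,
    special_token_ids: list[int],
) -> list[int] | None:
    if exclude_special_tokens is True:
        # True -> mask every special token id reported by the tokenizer
        return list(special_token_ids)
    if exclude_special_tokens:
        # a non-empty iterable -> mask exactly those ids
        return list(exclude_special_tokens)
    # False (or empty) -> nothing is masked out of the metrics
    return None


assert resolve_ignore_tokens(True, [0, 1, 2]) == [0, 1, 2]
assert resolve_ignore_tokens([50256], [0, 1, 2]) == [50256]
assert resolve_ignore_tokens(False, [0, 1, 2]) is None
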
sae_lens/llm_sae_training_runner.py
CHANGED

@@ -22,17 +22,13 @@ from sae_lens.constants import (
 )
 from sae_lens.evals import EvalConfig, run_evals
 from sae_lens.load_model import load_model
-from sae_lens.
-from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
-from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
+from sae_lens.registry import SAE_TRAINING_CLASS_REGISTRY
 from sae_lens.saes.sae import (
     T_TRAINING_SAE,
     T_TRAINING_SAE_CONFIG,
     TrainingSAE,
     TrainingSAEConfig,
 )
-from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
-from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
 from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.sae_trainer import SAETrainer
@@ -61,9 +57,11 @@ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
         data_provider: DataProvider,
         activation_scaler: ActivationScaler,
     ) -> dict[str, Any]:
-
+        exclude_special_tokens = False
         if self.activations_store.exclude_special_tokens is not None:
-
+            exclude_special_tokens = (
+                self.activations_store.exclude_special_tokens.tolist()
+            )
 
         eval_config = EvalConfig(
             batch_size_prompts=self.eval_batch_size_prompts,
@@ -81,7 +79,7 @@ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
             model=self.model,
             activation_scaler=activation_scaler,
             eval_config=eval_config,
-
+            exclude_special_tokens=exclude_special_tokens,
             model_kwargs=self.model_kwargs,
         )  # not calculating featurwise metrics here.
 
@@ -393,12 +391,8 @@ def _parse_cfg_args(
     )
 
     # Map architecture to concrete config class
-    sae_config_map = {
-        "standard": StandardTrainingSAEConfig,
-        "gated": GatedTrainingSAEConfig,
-        "jumprelu": JumpReLUTrainingSAEConfig,
-        "topk": TopKTrainingSAEConfig,
-        "batchtopk": BatchTopKTrainingSAEConfig,
+    sae_config_map: dict[str, type[TrainingSAEConfig]] = {
+        name: cfg for name, (_, cfg) in SAE_TRAINING_CLASS_REGISTRY.items()
     }
 
     sae_config_type = sae_config_map[architecture]
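
With the registry-driven map, _parse_cfg_args no longer hard-codes one import per architecture: anything passed to register_sae_training_class, including the new "matryoshka_batchtopk" entry, is picked up automatically. A small sketch, assuming the registry values are (SAE class, config class) tuples as the comprehension above implies:

from sae_lens.registry import SAE_TRAINING_CLASS_REGISTRY

sae_config_map = {name: cfg for name, (_, cfg) in SAE_TRAINING_CLASS_REGISTRY.items()}
assert "matryoshka_batchtopk" in sae_config_map  # newly registered architectures appear here
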
sae_lens/saes/__init__.py
CHANGED

@@ -14,6 +14,10 @@ from .jumprelu_sae import (
     JumpReLUTrainingSAE,
     JumpReLUTrainingSAEConfig,
 )
+from .matryoshka_batchtopk_sae import (
+    MatryoshkaBatchTopKTrainingSAE,
+    MatryoshkaBatchTopKTrainingSAEConfig,
+)
 from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
 from .standard_sae import (
     StandardSAE,
@@ -65,4 +69,6 @@ __all__ = [
     "SkipTranscoderConfig",
     "JumpReLUTranscoder",
     "JumpReLUTranscoderConfig",
+    "MatryoshkaBatchTopKTrainingSAE",
+    "MatryoshkaBatchTopKTrainingSAEConfig",
 ]

sae_lens/saes/matryoshka_batchtopk_sae.py
ADDED

@@ -0,0 +1,143 @@
+import warnings
+from dataclasses import dataclass, field
+
+import torch
+from jaxtyping import Float
+from typing_extensions import override
+
+from sae_lens.saes.batchtopk_sae import (
+    BatchTopKTrainingSAE,
+    BatchTopKTrainingSAEConfig,
+)
+from sae_lens.saes.sae import TrainStepInput, TrainStepOutput
+from sae_lens.saes.topk_sae import _sparse_matmul_nd
+
+
+@dataclass
+class MatryoshkaBatchTopKTrainingSAEConfig(BatchTopKTrainingSAEConfig):
+    """
+    Configuration class for training a MatryoshkaBatchTopKTrainingSAE.
+
+    [Matryoshka SAEs](https://arxiv.org/pdf/2503.17547) use a series of nested reconstruction
+    losses of different widths during training to avoid feature absorption. This also has a
+    nice side-effect of encouraging higher-frequency features to be learned in earlier levels.
+    However, this SAE has more hyperparameters to tune than standard BatchTopK SAEs, and takes
+    longer to train due to requiring multiple forward passes per training step.
+
+    After training, MatryoshkaBatchTopK SAEs are saved as JumpReLU SAEs.
+
+    Args:
+        matryoshka_widths (list[int]): The widths of the matryoshka levels. Defaults to an empty list.
+        k (float): The number of features to keep active. Inherited from BatchTopKTrainingSAEConfig.
+            Defaults to 100.
+        topk_threshold_lr (float): Learning rate for updating the global topk threshold.
+            The threshold is updated using an exponential moving average of the minimum
+            positive activation value. Defaults to 0.01.
+        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
+            dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
+            Defaults to 1.0.
+        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
+            Inherited from TopKTrainingSAEConfig. Defaults to True.
+        decoder_init_norm (float | None): Norm to initialize decoder weights to.
+            Inherited from TrainingSAEConfig. Defaults to 0.1.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+    """
+
+    matryoshka_widths: list[int] = field(default_factory=list)
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "matryoshka_batchtopk"
+
+
+class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
+    """
+    Global Batch TopK Training SAE
+
+    This SAE will maintain the k on average across the batch, rather than enforcing the k per-sample as in standard TopK.
+
+    BatchTopK SAEs are saved as JumpReLU SAEs after training.
+    """
+
+    cfg: MatryoshkaBatchTopKTrainingSAEConfig  # type: ignore[assignment]
+
+    def __init__(
+        self, cfg: MatryoshkaBatchTopKTrainingSAEConfig, use_error_term: bool = False
+    ):
+        super().__init__(cfg, use_error_term)
+        _validate_matryoshka_config(cfg)
+
+    @override
+    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
+        base_output = super().training_forward_pass(step_input)
+        hidden_pre = base_output.hidden_pre
+        inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
+        # the outer matryoshka level is the base SAE, so we don't need to add an extra loss for it
+        for width in self.cfg.matryoshka_widths[:-1]:
+            inner_hidden_pre = hidden_pre[:, :width]
+            inner_feat_acts = self.activation_fn(inner_hidden_pre)
+            inner_reconstruction = self._decode_matryoshka_level(
+                inner_feat_acts, width, inv_W_dec_norm
+            )
+            inner_mse_loss = (
+                self.mse_loss_fn(inner_reconstruction, step_input.sae_in)
+                .sum(dim=-1)
+                .mean()
+            )
+            base_output.losses[f"inner_mse_loss_{width}"] = inner_mse_loss
+            base_output.loss = base_output.loss + inner_mse_loss
+        return base_output
+
+    def _decode_matryoshka_level(
+        self,
+        feature_acts: Float[torch.Tensor, "... d_sae"],
+        width: int,
+        inv_W_dec_norm: torch.Tensor,
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decodes feature activations back into input space for a matryoshka level
+        """
+        # Handle sparse tensors using efficient sparse matrix multiplication
+        if self.cfg.rescale_acts_by_decoder_norm:
+            # need to multiply by the inverse of the norm because division is illegal with sparse tensors
+            feature_acts = feature_acts * inv_W_dec_norm[:width]
+        if feature_acts.is_sparse:
+            sae_out_pre = (
+                _sparse_matmul_nd(feature_acts, self.W_dec[:width]) + self.b_dec
+            )
+        else:
+            sae_out_pre = feature_acts @ self.W_dec[:width] + self.b_dec
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+
+def _validate_matryoshka_config(cfg: MatryoshkaBatchTopKTrainingSAEConfig) -> None:
+    if cfg.matryoshka_widths[-1] != cfg.d_sae:
+        # warn the users that we will add a final matryoshka level
+        warnings.warn(
+            "WARNING: The final matryoshka level width is not set to cfg.d_sae. "
+            "A final matryoshka level of width=cfg.d_sae will be added."
+        )
+        cfg.matryoshka_widths.append(cfg.d_sae)
+
+    for prev_width, curr_width in zip(
+        cfg.matryoshka_widths[:-1], cfg.matryoshka_widths[1:]
+    ):
+        if prev_width >= curr_width:
+            raise ValueError("cfg.matryoshka_widths must be strictly increasing.")
+    if len(cfg.matryoshka_widths) == 1:
+        warnings.warn(
+            "WARNING: You have only set one matryoshka level. This is equivalent to using a standard BatchTopK SAE and is likely not what you want."
+        )
+    if cfg.matryoshka_widths[0] < cfg.k:
+        raise ValueError(
+            "The smallest matryoshka level width cannot be smaller than cfg.k."
+        )
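
To illustrate the validation rules above (widths strictly increasing, the final width equal to d_sae, the smallest width at least k), a hypothetical configuration might look like the following; the dimensions are illustrative, not taken from the diff:

from sae_lens import (
    MatryoshkaBatchTopKTrainingSAE,
    MatryoshkaBatchTopKTrainingSAEConfig,
)

cfg = MatryoshkaBatchTopKTrainingSAEConfig(
    d_in=768,
    d_sae=16384,
    k=64,
    # strictly increasing; if the last width != d_sae, a final level is appended with a warning
    matryoshka_widths=[1024, 4096, 16384],
)
sae = MatryoshkaBatchTopKTrainingSAE(cfg)  # _validate_matryoshka_config runs in __init__
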
sae_lens/training/activations_store.py
CHANGED

@@ -29,7 +29,10 @@ from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
 from sae_lens.training.mixing_buffer import mixing_buffer
-from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
+from sae_lens.util import (
+    extract_stop_at_layer_from_tlens_hook_name,
+    get_special_token_ids,
+)
 
 
 # TODO: Make an activation store config class to be consistent with the rest of the code.
@@ -113,7 +116,7 @@ class ActivationsStore:
         if exclude_special_tokens is False:
             exclude_special_tokens = None
         if exclude_special_tokens is True:
-            exclude_special_tokens = _get_special_token_ids(model.tokenizer)
+            exclude_special_tokens = get_special_token_ids(model.tokenizer)  # type: ignore
         if exclude_special_tokens is not None:
             exclude_special_tokens = torch.tensor(
                 exclude_special_tokens, dtype=torch.long, device=device
@@ -763,31 +766,6 @@ def _get_model_device(model: HookedRootModule) -> torch.device:
     return next(model.parameters()).device  # type: ignore
 
 
-def _get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
-    """Get all special token IDs from a tokenizer."""
-    special_tokens = set()
-
-    # Get special tokens from tokenizer attributes
-    for attr in dir(tokenizer):
-        if attr.endswith("_token_id"):
-            token_id = getattr(tokenizer, attr)
-            if token_id is not None:
-                special_tokens.add(token_id)
-
-    # Get any additional special tokens from the tokenizer's special tokens map
-    if hasattr(tokenizer, "special_tokens_map"):
-        for token in tokenizer.special_tokens_map.values():
-            if isinstance(token, str):
-                token_id = tokenizer.convert_tokens_to_ids(token)  # type: ignore
-                special_tokens.add(token_id)
-            elif isinstance(token, list):
-                for t in token:
-                    token_id = tokenizer.convert_tokens_to_ids(t)  # type: ignore
-                    special_tokens.add(token_id)
-
-    return list(special_tokens)
-
-
 def _filter_buffer_acts(
     buffer: tuple[torch.Tensor, torch.Tensor | None],
     exclude_tokens: torch.Tensor | None,
sae_lens/util.py
CHANGED

@@ -5,6 +5,8 @@ from dataclasses import asdict, fields, is_dataclass
 from pathlib import Path
 from typing import Sequence, TypeVar
 
+from transformers import PreTrainedTokenizerBase
+
 K = TypeVar("K")
 V = TypeVar("V")
 
@@ -63,3 +65,28 @@ def path_or_tmp_dir(path: str | Path | None):
         yield Path(td)
     else:
         yield Path(path)
+
+
+def get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
+    """Get all special token IDs from a tokenizer."""
+    special_tokens = set()
+
+    # Get special tokens from tokenizer attributes
+    for attr in dir(tokenizer):
+        if attr.endswith("_token_id"):
+            token_id = getattr(tokenizer, attr)
+            if token_id is not None:
+                special_tokens.add(token_id)
+
+    # Get any additional special tokens from the tokenizer's special tokens map
+    if hasattr(tokenizer, "special_tokens_map"):
+        for token in tokenizer.special_tokens_map.values():
+            if isinstance(token, str):
+                token_id = tokenizer.convert_tokens_to_ids(token)  # type: ignore
+                special_tokens.add(token_id)
+            elif isinstance(token, list):
+                for t in token:
+                    token_id = tokenizer.convert_tokens_to_ids(t)  # type: ignore
+                    special_tokens.add(token_id)
+
+    return list(special_tokens)
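
Since the helper is now public in sae_lens.util, it can be called directly on any Hugging Face tokenizer; GPT-2 is used below purely as an example (its bos/eos/unk tokens all share a single id):

from transformers import AutoTokenizer

from sae_lens.util import get_special_token_ids

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(get_special_token_ids(tokenizer))  # e.g. [50256]
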
{sae_lens-6.14.1.dist-info → sae_lens-6.15.0.dist-info}/RECORD
CHANGED

@@ -1,12 +1,12 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=ab8Lj2QJE3i1uOP_4B9LLh_vCgi__3XXx66_eO8rcrA,3886
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=vRu6JseH1lZaEeILD5bEkQEQ1wYHHDcxD-f2olKmE9Y,14275
 sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
 sae_lens/cache_activations_runner.py,sha256=cNeAtp2JQ_vKbeddZVM-tcPLYyyfTWL8NDna5KQpkLI,12583
 sae_lens/config.py,sha256=IdRXSKPfYY3hwUovj-u83eep8z52gkJHII0mY0KseYY,28739
 sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
-sae_lens/evals.py,sha256=
-sae_lens/llm_sae_training_runner.py,sha256=
+sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
+sae_lens/llm_sae_training_runner.py,sha256=UHRcLqvtnORsZ7u7ymbrv-Ib2BD84czHBvu03jNbtcE,14834
 sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
 sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/loading/pretrained_sae_loaders.py,sha256=SM4aT8NM6ezYix5c2u7p72Fz2RfvTtf7gw5RdOSKXhc,49846
@@ -14,10 +14,11 @@ sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gk
 sae_lens/pretokenize_runner.py,sha256=x-reJzVPFDS9iRFbZtrFYSzNguJYki9gd0pbHjYJ3r4,7085
 sae_lens/pretrained_saes.yaml,sha256=6ca3geEB6NyhULUrmdtPDK8ea0YdpLp8_au78vIFC5w,602553
 sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
-sae_lens/saes/__init__.py,sha256=
+sae_lens/saes/__init__.py,sha256=sIfZUxZ4m3HPtPymCJlpBEofiCrL8_QziE6ChS-v4lE,1677
 sae_lens/saes/batchtopk_sae.py,sha256=zxIke8lOBKkQEMVFk6sSW6q_s6F9RKhysLqfqG9ecwI,5300
 sae_lens/saes/gated_sae.py,sha256=qcmM9JwBA8aZR8z_IRHV1_gQX-q_63tKewWXRnhdXuo,8986
 sae_lens/saes/jumprelu_sae.py,sha256=HHBF1sJ95lZvxwP5vwLSQFKdnJN2KKYK0WAEaLTrta0,13399
+sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=zrS4MksbxdhhftmU3UWjRCWjR7iEBpAk6N00c6GrXks,6291
 sae_lens/saes/sae.py,sha256=nuII6ZmaVtJWhPjyhasHQyiv_Wj-zdAtRQqJRYbVBQs,38274
 sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
 sae_lens/saes/topk_sae.py,sha256=tzQM5eQFifMe--8_8NUBYWY7hpjQa6A_olNe6U71FE8,21275
@@ -25,15 +26,15 @@ sae_lens/saes/transcoder.py,sha256=BfLSbTYVNZh-ruGxseZiZJ_acEL6_7QyTdfqUr0lDOg,1
 sae_lens/tokenization_and_batching.py,sha256=D_o7cXvRqhT89H3wNzoRymNALNE6eHojBWLdXOUwUGE,5438
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
-sae_lens/training/activations_store.py,sha256=
+sae_lens/training/activations_store.py,sha256=hHY6rW-T7sLq2a8JPEyWdm8leuIRm_MsObZs3jRTZmE,31931
 sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
 sae_lens/training/optim.py,sha256=TiI9nbffzXNsI8WjcIsqa2uheW6suxqL_KDDmWXobWI,5312
 sae_lens/training/sae_trainer.py,sha256=il4Evf-c4F3Uf2n_v-AOItCasX-uPxYTzn_sZLvLkl0,15633
 sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
-sae_lens/util.py,sha256=
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
+sae_lens/util.py,sha256=tCovQ-eZa1L7thPpNDL6PGOJrIMML2yLI5e0EHCOpS8,3309
+sae_lens-6.15.0.dist-info/METADATA,sha256=UmBQ8quUJBWyLclhbnDcXAkL-6jnOW4SbT8_X3rrcbE,5318
+sae_lens-6.15.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+sae_lens-6.15.0.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.15.0.dist-info/RECORD,,

{sae_lens-6.14.1.dist-info → sae_lens-6.15.0.dist-info}/WHEEL
File without changes

{sae_lens-6.14.1.dist-info → sae_lens-6.15.0.dist-info}/licenses/LICENSE
File without changes