sae-lens 6.14.0__py3-none-any.whl → 6.14.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +1 -1
- sae_lens/evals.py +18 -14
- sae_lens/llm_sae_training_runner.py +5 -3
- sae_lens/training/activations_store.py +5 -27
- sae_lens/util.py +27 -0
- {sae_lens-6.14.0.dist-info → sae_lens-6.14.2.dist-info}/METADATA +1 -1
- {sae_lens-6.14.0.dist-info → sae_lens-6.14.2.dist-info}/RECORD +9 -9
- {sae_lens-6.14.0.dist-info → sae_lens-6.14.2.dist-info}/WHEEL +0 -0
- {sae_lens-6.14.0.dist-info → sae_lens-6.14.2.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
sae_lens/evals.py
CHANGED
@@ -11,7 +11,7 @@ from dataclasses import dataclass, field
 from functools import partial
 from importlib.metadata import PackageNotFoundError, version
 from pathlib import Path
-from typing import Any
+from typing import Any, Iterable
 
 import einops
 import pandas as pd
@@ -24,7 +24,10 @@ from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_direc
 from sae_lens.saes.sae import SAE, SAEConfig
 from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
-from sae_lens.util import
+from sae_lens.util import (
+    extract_stop_at_layer_from_tlens_hook_name,
+    get_special_token_ids,
+)
 
 
 def get_library_version() -> str:
@@ -109,9 +112,15 @@ def run_evals(
     activation_scaler: ActivationScaler,
     eval_config: EvalConfig = EvalConfig(),
     model_kwargs: Mapping[str, Any] = {},
-
+    exclude_special_tokens: Iterable[int] | bool = True,
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
+    ignore_tokens = None
+    if exclude_special_tokens is True:
+        ignore_tokens = list(get_special_token_ids(model.tokenizer))  # type: ignore
+    elif exclude_special_tokens:
+        ignore_tokens = list(exclude_special_tokens)
+
     hook_name = sae.cfg.metadata.hook_name
     actual_batch_size = (
         eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
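run_evals now accepts exclude_special_tokens instead of a raw ignore_tokens collection: True derives the ids from model.tokenizer via get_special_token_ids, an iterable is used as given, and False disables filtering. A minimal standalone sketch of that resolution (resolve_ignore_tokens and special_ids are illustrative names, not part of sae-lens):

from typing import Iterable


def resolve_ignore_tokens(
    exclude_special_tokens: Iterable[int] | bool,
    special_ids: list[int],
) -> list[int] | None:
    # Mirrors the branching run_evals applies to its new argument.
    if exclude_special_tokens is True:
        return list(special_ids)  # True: take every special token id from the tokenizer
    if exclude_special_tokens:
        return list(exclude_special_tokens)  # an iterable of ids is used as given
    return None  # False (or an empty iterable): no token filtering


assert resolve_ignore_tokens(True, [0, 1, 2]) == [0, 1, 2]
assert resolve_ignore_tokens([50256], [0, 1, 2]) == [50256]
assert resolve_ignore_tokens(False, [0, 1, 2]) is None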
@@ -312,7 +321,7 @@ def get_downstream_reconstruction_metrics(
     compute_ce_loss: bool,
     n_batches: int,
     eval_batch_size_prompts: int,
-    ignore_tokens:
+    ignore_tokens: list[int] | None = None,
     verbose: bool = False,
 ):
     metrics_dict = {}
@@ -339,7 +348,7 @@ def get_downstream_reconstruction_metrics(
             compute_ce_loss=compute_ce_loss,
             ignore_tokens=ignore_tokens,
         ).items():
-            if
+            if ignore_tokens:
                 mask = torch.logical_not(
                     torch.any(
                         torch.stack(
@@ -384,7 +393,7 @@ def get_sparsity_and_variance_metrics(
     compute_featurewise_density_statistics: bool,
     eval_batch_size_prompts: int,
     model_kwargs: Mapping[str, Any],
-    ignore_tokens:
+    ignore_tokens: list[int] | None = None,
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     hook_name = sae.cfg.metadata.hook_name
@@ -426,7 +435,7 @@ def get_sparsity_and_variance_metrics(
     for _ in batch_iter:
         batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
 
-        if
+        if ignore_tokens:
            mask = torch.logical_not(
                torch.any(
                    torch.stack(
@@ -596,7 +605,7 @@ def get_recons_loss(
     batch_tokens: torch.Tensor,
     compute_kl: bool,
     compute_ce_loss: bool,
-    ignore_tokens:
+    ignore_tokens: list[int] | None = None,
     model_kwargs: Mapping[str, Any] = {},
     hook_name: str | None = None,
 ) -> dict[str, Any]:
@@ -610,7 +619,7 @@ def get_recons_loss(
         batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
     )
 
-    if
+    if ignore_tokens:
        mask = torch.logical_not(
            torch.any(
                torch.stack([batch_tokens == token for token in ignore_tokens], dim=0),
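The "if ignore_tokens" branches above all build the same position mask; reproduced standalone with made-up token values:

import torch

batch_tokens = torch.tensor([[50256, 12, 50256, 7]])  # hypothetical token ids
ignore_tokens = [50256]

# True at positions holding none of the ignored tokens, False where they appear.
mask = torch.logical_not(
    torch.any(
        torch.stack([batch_tokens == token for token in ignore_tokens], dim=0),
        dim=0,
    )
)
print(mask)  # tensor([[False,  True, False,  True]])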
@@ -856,11 +865,6 @@ def multiple_evals(
             activation_scaler=ActivationScaler(),
             model=current_model,
             eval_config=eval_config,
-            ignore_tokens={
-                current_model.tokenizer.pad_token_id,  # type: ignore
-                current_model.tokenizer.eos_token_id,  # type: ignore
-                current_model.tokenizer.bos_token_id,  # type: ignore
-            },
             verbose=verbose,
         )
         eval_metrics["metrics"] = scalar_metrics
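With the explicit pad/eos/bos set removed from this call site, the default exclude_special_tokens=True now decides which tokens are ignored. A quick standalone check that the new default covers the ids the old set contained (GPT-2 is used purely as an example tokenizer):

from transformers import AutoTokenizer

from sae_lens.util import get_special_token_ids  # added in 6.14.2

tokenizer = AutoTokenizer.from_pretrained("gpt2")
old_style = {
    tokenizer.pad_token_id,
    tokenizer.eos_token_id,
    tokenizer.bos_token_id,
} - {None}  # GPT-2 has no pad token, so drop the None

assert old_style.issubset(set(get_special_token_ids(tokenizer)))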
sae_lens/llm_sae_training_runner.py
CHANGED
@@ -61,9 +61,11 @@ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
         data_provider: DataProvider,
         activation_scaler: ActivationScaler,
     ) -> dict[str, Any]:
-
+        exclude_special_tokens = False
         if self.activations_store.exclude_special_tokens is not None:
-
+            exclude_special_tokens = (
+                self.activations_store.exclude_special_tokens.tolist()
+            )
 
         eval_config = EvalConfig(
             batch_size_prompts=self.eval_batch_size_prompts,
@@ -81,7 +83,7 @@ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
             model=self.model,
             activation_scaler=activation_scaler,
             eval_config=eval_config,
-
+            exclude_special_tokens=exclude_special_tokens,
             model_kwargs=self.model_kwargs,
         )  # not calculating featurwise metrics here.
 
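The evaluator converts the store's tensor of excluded ids into the list-or-bool form run_evals now accepts. A standalone sketch of that conversion, with illustrative values rather than sae-lens state:

import torch

# What the store holds after normalisation: a long tensor of ids, or None.
store_exclude: torch.Tensor | None = torch.tensor([0, 1, 2], dtype=torch.long)

exclude_special_tokens: list[int] | bool = False
if store_exclude is not None:
    exclude_special_tokens = store_exclude.tolist()

assert exclude_special_tokens == [0, 1, 2]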
sae_lens/training/activations_store.py
CHANGED
@@ -29,7 +29,10 @@ from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
 from sae_lens.training.mixing_buffer import mixing_buffer
-from sae_lens.util import
+from sae_lens.util import (
+    extract_stop_at_layer_from_tlens_hook_name,
+    get_special_token_ids,
+)
 
 
 # TODO: Make an activation store config class to be consistent with the rest of the code.
@@ -113,7 +116,7 @@ class ActivationsStore:
         if exclude_special_tokens is False:
             exclude_special_tokens = None
         if exclude_special_tokens is True:
-            exclude_special_tokens =
+            exclude_special_tokens = get_special_token_ids(model.tokenizer)  # type: ignore
         if exclude_special_tokens is not None:
             exclude_special_tokens = torch.tensor(
                 exclude_special_tokens, dtype=torch.long, device=device
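The store normalises exclude_special_tokens in three steps: False becomes None, True becomes the tokenizer's special ids, and any remaining iterable becomes a long tensor. A standalone mirror of that cascade (normalize_exclude and special_ids are illustrative names, not sae-lens API):

import torch


def normalize_exclude(exclude, special_ids, device="cpu"):
    # False -> None, True -> the tokenizer's special ids, iterable -> long tensor.
    if exclude is False:
        exclude = None
    if exclude is True:
        exclude = special_ids
    if exclude is not None:
        exclude = torch.tensor(exclude, dtype=torch.long, device=device)
    return exclude


assert normalize_exclude(False, [0, 1]) is None
assert normalize_exclude(True, [0, 1]).tolist() == [0, 1]
assert normalize_exclude([7], [0, 1]).tolist() == [7]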
@@ -763,31 +766,6 @@ def _get_model_device(model: HookedRootModule) -> torch.device:
     return next(model.parameters()).device  # type: ignore
 
 
-def _get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
-    """Get all special token IDs from a tokenizer."""
-    special_tokens = set()
-
-    # Get special tokens from tokenizer attributes
-    for attr in dir(tokenizer):
-        if attr.endswith("_token_id"):
-            token_id = getattr(tokenizer, attr)
-            if token_id is not None:
-                special_tokens.add(token_id)
-
-    # Get any additional special tokens from the tokenizer's special tokens map
-    if hasattr(tokenizer, "special_tokens_map"):
-        for token in tokenizer.special_tokens_map.values():
-            if isinstance(token, str):
-                token_id = tokenizer.convert_tokens_to_ids(token)  # type: ignore
-                special_tokens.add(token_id)
-            elif isinstance(token, list):
-                for t in token:
-                    token_id = tokenizer.convert_tokens_to_ids(t)  # type: ignore
-                    special_tokens.add(token_id)
-
-    return list(special_tokens)
-
-
 def _filter_buffer_acts(
     buffer: tuple[torch.Tensor, torch.Tensor | None],
     exclude_tokens: torch.Tensor | None,
sae_lens/util.py
CHANGED
@@ -5,6 +5,8 @@ from dataclasses import asdict, fields, is_dataclass
 from pathlib import Path
 from typing import Sequence, TypeVar
 
+from transformers import PreTrainedTokenizerBase
+
 K = TypeVar("K")
 V = TypeVar("V")
 
@@ -63,3 +65,28 @@ def path_or_tmp_dir(path: str | Path | None):
         yield Path(td)
     else:
         yield Path(path)
+
+
+def get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
+    """Get all special token IDs from a tokenizer."""
+    special_tokens = set()
+
+    # Get special tokens from tokenizer attributes
+    for attr in dir(tokenizer):
+        if attr.endswith("_token_id"):
+            token_id = getattr(tokenizer, attr)
+            if token_id is not None:
+                special_tokens.add(token_id)
+
+    # Get any additional special tokens from the tokenizer's special tokens map
+    if hasattr(tokenizer, "special_tokens_map"):
+        for token in tokenizer.special_tokens_map.values():
+            if isinstance(token, str):
+                token_id = tokenizer.convert_tokens_to_ids(token)  # type: ignore
+                special_tokens.add(token_id)
+            elif isinstance(token, list):
+                for t in token:
+                    token_id = tokenizer.convert_tokens_to_ids(t)  # type: ignore
+                    special_tokens.add(token_id)
+
+    return list(special_tokens)
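get_special_token_ids, previously a private helper in activations_store, is now exposed from sae_lens.util. A short usage sketch (GPT-2 is only an example tokenizer):

from transformers import AutoTokenizer

from sae_lens.util import get_special_token_ids

tokenizer = AutoTokenizer.from_pretrained("gpt2")
special_ids = get_special_token_ids(tokenizer)
print(special_ids)  # a GPT-2-style tokenizer typically yields a single id, since bos/eos/unk share one token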
{sae_lens-6.14.0.dist-info → sae_lens-6.14.2.dist-info}/RECORD
CHANGED
@@ -1,12 +1,12 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=U6PI8XxNzEqTakvBsTnn6i8EvoMbpcviRffWBke2frk,3589
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=vRu6JseH1lZaEeILD5bEkQEQ1wYHHDcxD-f2olKmE9Y,14275
 sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
 sae_lens/cache_activations_runner.py,sha256=cNeAtp2JQ_vKbeddZVM-tcPLYyyfTWL8NDna5KQpkLI,12583
 sae_lens/config.py,sha256=IdRXSKPfYY3hwUovj-u83eep8z52gkJHII0mY0KseYY,28739
 sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
-sae_lens/evals.py,sha256=
-sae_lens/llm_sae_training_runner.py,sha256=
+sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
+sae_lens/llm_sae_training_runner.py,sha256=8Km519LH080RZnUBeaG2T1trq5UqxoAqokNmpX4xMTM,15200
 sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
 sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/loading/pretrained_sae_loaders.py,sha256=SM4aT8NM6ezYix5c2u7p72Fz2RfvTtf7gw5RdOSKXhc,49846
@@ -25,15 +25,15 @@ sae_lens/saes/transcoder.py,sha256=BfLSbTYVNZh-ruGxseZiZJ_acEL6_7QyTdfqUr0lDOg,1
 sae_lens/tokenization_and_batching.py,sha256=D_o7cXvRqhT89H3wNzoRymNALNE6eHojBWLdXOUwUGE,5438
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
-sae_lens/training/activations_store.py,sha256=
+sae_lens/training/activations_store.py,sha256=hHY6rW-T7sLq2a8JPEyWdm8leuIRm_MsObZs3jRTZmE,31931
 sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
 sae_lens/training/optim.py,sha256=TiI9nbffzXNsI8WjcIsqa2uheW6suxqL_KDDmWXobWI,5312
 sae_lens/training/sae_trainer.py,sha256=il4Evf-c4F3Uf2n_v-AOItCasX-uPxYTzn_sZLvLkl0,15633
 sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
-sae_lens/util.py,sha256=
-sae_lens-6.14.
-sae_lens-6.14.
-sae_lens-6.14.
-sae_lens-6.14.
+sae_lens/util.py,sha256=tCovQ-eZa1L7thPpNDL6PGOJrIMML2yLI5e0EHCOpS8,3309
+sae_lens-6.14.2.dist-info/METADATA,sha256=WDlgsdDyQT4jmu5hxMU-pqm5PfBh0h65MTEbyuMuE3c,5318
+sae_lens-6.14.2.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+sae_lens-6.14.2.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.14.2.dist-info/RECORD,,
{sae_lens-6.14.0.dist-info → sae_lens-6.14.2.dist-info}/WHEEL
File without changes
{sae_lens-6.14.0.dist-info → sae_lens-6.14.2.dist-info}/licenses/LICENSE
File without changes