sae-lens 6.14.1__py3-none-any.whl → 6.14.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.14.1"
+ __version__ = "6.14.2"

  import logging

sae_lens/evals.py CHANGED
@@ -11,7 +11,7 @@ from dataclasses import dataclass, field
  from functools import partial
  from importlib.metadata import PackageNotFoundError, version
  from pathlib import Path
- from typing import Any
+ from typing import Any, Iterable

  import einops
  import pandas as pd
@@ -24,7 +24,10 @@ from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_direc
  from sae_lens.saes.sae import SAE, SAEConfig
  from sae_lens.training.activation_scaler import ActivationScaler
  from sae_lens.training.activations_store import ActivationsStore
- from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
+ from sae_lens.util import (
+     extract_stop_at_layer_from_tlens_hook_name,
+     get_special_token_ids,
+ )


  def get_library_version() -> str:
@@ -109,9 +112,15 @@ def run_evals(
      activation_scaler: ActivationScaler,
      eval_config: EvalConfig = EvalConfig(),
      model_kwargs: Mapping[str, Any] = {},
-     ignore_tokens: set[int | None] = set(),
+     exclude_special_tokens: Iterable[int] | bool = True,
      verbose: bool = False,
  ) -> tuple[dict[str, Any], dict[str, Any]]:
+     ignore_tokens = None
+     if exclude_special_tokens is True:
+         ignore_tokens = list(get_special_token_ids(model.tokenizer))  # type: ignore
+     elif exclude_special_tokens:
+         ignore_tokens = list(exclude_special_tokens)
+
      hook_name = sae.cfg.metadata.hook_name
      actual_batch_size = (
          eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
@@ -312,7 +321,7 @@ def get_downstream_reconstruction_metrics(
      compute_ce_loss: bool,
      n_batches: int,
      eval_batch_size_prompts: int,
-     ignore_tokens: set[int | None] = set(),
+     ignore_tokens: list[int] | None = None,
      verbose: bool = False,
  ):
      metrics_dict = {}
@@ -339,7 +348,7 @@ def get_downstream_reconstruction_metrics(
              compute_ce_loss=compute_ce_loss,
              ignore_tokens=ignore_tokens,
          ).items():
-             if len(ignore_tokens) > 0:
+             if ignore_tokens:
                  mask = torch.logical_not(
                      torch.any(
                          torch.stack(
@@ -384,7 +393,7 @@ def get_sparsity_and_variance_metrics(
      compute_featurewise_density_statistics: bool,
      eval_batch_size_prompts: int,
      model_kwargs: Mapping[str, Any],
-     ignore_tokens: set[int | None] = set(),
+     ignore_tokens: list[int] | None = None,
      verbose: bool = False,
  ) -> tuple[dict[str, Any], dict[str, Any]]:
      hook_name = sae.cfg.metadata.hook_name
@@ -426,7 +435,7 @@ def get_sparsity_and_variance_metrics(
      for _ in batch_iter:
          batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)

-         if len(ignore_tokens) > 0:
+         if ignore_tokens:
              mask = torch.logical_not(
                  torch.any(
                      torch.stack(
@@ -596,7 +605,7 @@ def get_recons_loss(
      batch_tokens: torch.Tensor,
      compute_kl: bool,
      compute_ce_loss: bool,
-     ignore_tokens: set[int | None] = set(),
+     ignore_tokens: list[int] | None = None,
      model_kwargs: Mapping[str, Any] = {},
      hook_name: str | None = None,
  ) -> dict[str, Any]:
@@ -610,7 +619,7 @@ def get_recons_loss(
          batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
      )

-     if len(ignore_tokens) > 0:
+     if ignore_tokens:
          mask = torch.logical_not(
              torch.any(
                  torch.stack([batch_tokens == token for token in ignore_tokens], dim=0),
@@ -856,11 +865,6 @@ def multiple_evals(
              activation_scaler=ActivationScaler(),
              model=current_model,
              eval_config=eval_config,
-             ignore_tokens={
-                 current_model.tokenizer.pad_token_id,  # type: ignore
-                 current_model.tokenizer.eos_token_id,  # type: ignore
-                 current_model.tokenizer.bos_token_id,  # type: ignore
-             },
              verbose=verbose,
          )
          eval_metrics["metrics"] = scalar_metrics
sae_lens/llm_sae_training_runner.py CHANGED
@@ -61,9 +61,11 @@ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
          data_provider: DataProvider,
          activation_scaler: ActivationScaler,
      ) -> dict[str, Any]:
-         ignore_tokens = set()
+         exclude_special_tokens = False
          if self.activations_store.exclude_special_tokens is not None:
-             ignore_tokens = set(self.activations_store.exclude_special_tokens.tolist())
+             exclude_special_tokens = (
+                 self.activations_store.exclude_special_tokens.tolist()
+             )

          eval_config = EvalConfig(
              batch_size_prompts=self.eval_batch_size_prompts,
@@ -81,7 +83,7 @@ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
              model=self.model,
              activation_scaler=activation_scaler,
              eval_config=eval_config,
-             ignore_tokens=ignore_tokens,
+             exclude_special_tokens=exclude_special_tokens,
              model_kwargs=self.model_kwargs,
          )  # not calculating featurwise metrics here.

sae_lens/training/activations_store.py CHANGED
@@ -29,7 +29,10 @@ from sae_lens.pretokenize_runner import get_special_token_from_cfg
  from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
  from sae_lens.tokenization_and_batching import concat_and_batch_sequences
  from sae_lens.training.mixing_buffer import mixing_buffer
- from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
+ from sae_lens.util import (
+     extract_stop_at_layer_from_tlens_hook_name,
+     get_special_token_ids,
+ )


  # TODO: Make an activation store config class to be consistent with the rest of the code.
@@ -113,7 +116,7 @@ class ActivationsStore:
          if exclude_special_tokens is False:
              exclude_special_tokens = None
          if exclude_special_tokens is True:
-             exclude_special_tokens = _get_special_token_ids(model.tokenizer)  # type: ignore
+             exclude_special_tokens = get_special_token_ids(model.tokenizer)  # type: ignore
          if exclude_special_tokens is not None:
              exclude_special_tokens = torch.tensor(
                  exclude_special_tokens, dtype=torch.long, device=device
@@ -763,31 +766,6 @@ def _get_model_device(model: HookedRootModule) -> torch.device:
      return next(model.parameters()).device  # type: ignore


- def _get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
-     """Get all special token IDs from a tokenizer."""
-     special_tokens = set()
-
-     # Get special tokens from tokenizer attributes
-     for attr in dir(tokenizer):
-         if attr.endswith("_token_id"):
-             token_id = getattr(tokenizer, attr)
-             if token_id is not None:
-                 special_tokens.add(token_id)
-
-     # Get any additional special tokens from the tokenizer's special tokens map
-     if hasattr(tokenizer, "special_tokens_map"):
-         for token in tokenizer.special_tokens_map.values():
-             if isinstance(token, str):
-                 token_id = tokenizer.convert_tokens_to_ids(token)  # type: ignore
-                 special_tokens.add(token_id)
-             elif isinstance(token, list):
-                 for t in token:
-                     token_id = tokenizer.convert_tokens_to_ids(t)  # type: ignore
-                     special_tokens.add(token_id)
-
-     return list(special_tokens)
-
-
  def _filter_buffer_acts(
      buffer: tuple[torch.Tensor, torch.Tensor | None],
      exclude_tokens: torch.Tensor | None,
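The store keeps the excluded ids as a long tensor and filters buffered activations against it (`_filter_buffer_acts` above). A rough standalone sketch of that kind of position filtering, written with `torch.isin` rather than the package's internal helper (the tensors and shapes are made up for illustration):

import torch

# Hypothetical batch of token ids (2 prompts x 5 positions) and excluded ids.
batch_tokens = torch.tensor([[50256, 11, 257, 50256, 340], [50256, 464, 3290, 13, 50256]])
exclude_tokens = torch.tensor([50256], dtype=torch.long)

# Keep-mask: True at positions whose token id is not in exclude_tokens.
keep_mask = ~torch.isin(batch_tokens, exclude_tokens)
print(keep_mask)

# Select only the kept positions from a matching activations tensor.
acts = torch.randn(2, 5, 8)   # [batch, position, d_model]
kept_acts = acts[keep_mask]   # [n_kept_positions, d_model]
print(kept_acts.shape)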
sae_lens/util.py CHANGED
@@ -5,6 +5,8 @@ from dataclasses import asdict, fields, is_dataclass
  from pathlib import Path
  from typing import Sequence, TypeVar

+ from transformers import PreTrainedTokenizerBase
+
  K = TypeVar("K")
  V = TypeVar("V")

@@ -63,3 +65,28 @@ def path_or_tmp_dir(path: str | Path | None):
          yield Path(td)
      else:
          yield Path(path)
+
+
+ def get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
+     """Get all special token IDs from a tokenizer."""
+     special_tokens = set()
+
+     # Get special tokens from tokenizer attributes
+     for attr in dir(tokenizer):
+         if attr.endswith("_token_id"):
+             token_id = getattr(tokenizer, attr)
+             if token_id is not None:
+                 special_tokens.add(token_id)
+
+     # Get any additional special tokens from the tokenizer's special tokens map
+     if hasattr(tokenizer, "special_tokens_map"):
+         for token in tokenizer.special_tokens_map.values():
+             if isinstance(token, str):
+                 token_id = tokenizer.convert_tokens_to_ids(token)  # type: ignore
+                 special_tokens.add(token_id)
+             elif isinstance(token, list):
+                 for t in token:
+                     token_id = tokenizer.convert_tokens_to_ids(t)  # type: ignore
+                     special_tokens.add(token_id)
+
+     return list(special_tokens)
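The helper collects every `*_token_id` attribute plus entries from `special_tokens_map`, so tokenizers with several distinct special tokens return them all. A short usage sketch, assuming a `bert-base-uncased` tokenizer from `transformers` (an illustrative choice; the expected ids are hedged, not guaranteed):

from transformers import AutoTokenizer

from sae_lens.util import get_special_token_ids

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
# [PAD], [UNK], [CLS], [SEP], [MASK] — expected: [0, 100, 101, 102, 103]
print(sorted(get_special_token_ids(tok)))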
sae_lens-6.14.1.dist-info/METADATA → sae_lens-6.14.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sae-lens
- Version: 6.14.1
+ Version: 6.14.2
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
  License: MIT
  License-File: LICENSE
sae_lens-6.14.1.dist-info/RECORD → sae_lens-6.14.2.dist-info/RECORD CHANGED
@@ -1,12 +1,12 @@
- sae_lens/__init__.py,sha256=bh_CgiUTwniwjnBsHPO170zHd10hLM5fCeAgMZc-8n4,3589
+ sae_lens/__init__.py,sha256=U6PI8XxNzEqTakvBsTnn6i8EvoMbpcviRffWBke2frk,3589
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sae_lens/analysis/hooked_sae_transformer.py,sha256=vRu6JseH1lZaEeILD5bEkQEQ1wYHHDcxD-f2olKmE9Y,14275
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
  sae_lens/cache_activations_runner.py,sha256=cNeAtp2JQ_vKbeddZVM-tcPLYyyfTWL8NDna5KQpkLI,12583
  sae_lens/config.py,sha256=IdRXSKPfYY3hwUovj-u83eep8z52gkJHII0mY0KseYY,28739
  sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
- sae_lens/evals.py,sha256=p4AOueeemhJXyfLx2TxOva8LXxXj63JSKe9Lnib3mHs,39623
- sae_lens/llm_sae_training_runner.py,sha256=sJTcDX1bUJJ_jZLUT88-8KUYIAPeUGoXktX68PsBqw0,15137
+ sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
+ sae_lens/llm_sae_training_runner.py,sha256=8Km519LH080RZnUBeaG2T1trq5UqxoAqokNmpX4xMTM,15200
  sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
  sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sae_lens/loading/pretrained_sae_loaders.py,sha256=SM4aT8NM6ezYix5c2u7p72Fz2RfvTtf7gw5RdOSKXhc,49846
@@ -25,15 +25,15 @@ sae_lens/saes/transcoder.py,sha256=BfLSbTYVNZh-ruGxseZiZJ_acEL6_7QyTdfqUr0lDOg,1
  sae_lens/tokenization_and_batching.py,sha256=D_o7cXvRqhT89H3wNzoRymNALNE6eHojBWLdXOUwUGE,5438
  sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
- sae_lens/training/activations_store.py,sha256=2EUY2abqpT5El3T95sypM_JRDgiKL3VeT73U9SQIFGY,32903
+ sae_lens/training/activations_store.py,sha256=hHY6rW-T7sLq2a8JPEyWdm8leuIRm_MsObZs3jRTZmE,31931
  sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
  sae_lens/training/optim.py,sha256=TiI9nbffzXNsI8WjcIsqa2uheW6suxqL_KDDmWXobWI,5312
  sae_lens/training/sae_trainer.py,sha256=il4Evf-c4F3Uf2n_v-AOItCasX-uPxYTzn_sZLvLkl0,15633
  sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
- sae_lens/util.py,sha256=lW7fBn_b8quvRYlen9PUmB7km60YhKyjmuelB1f6KzQ,2253
- sae_lens-6.14.1.dist-info/METADATA,sha256=ZE2ppvNRrI_CAr7jQ2TdcPQmfEdhLoo-UMW83KVbtvU,5318
- sae_lens-6.14.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
- sae_lens-6.14.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
- sae_lens-6.14.1.dist-info/RECORD,,
+ sae_lens/util.py,sha256=tCovQ-eZa1L7thPpNDL6PGOJrIMML2yLI5e0EHCOpS8,3309
+ sae_lens-6.14.2.dist-info/METADATA,sha256=WDlgsdDyQT4jmu5hxMU-pqm5PfBh0h65MTEbyuMuE3c,5318
+ sae_lens-6.14.2.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+ sae_lens-6.14.2.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+ sae_lens-6.14.2.dist-info/RECORD,,