sae-lens 6.0.0rc1.tar.gz → 6.0.0rc3.tar.gz

This diff shows the contents of publicly available package versions as published to their public registry. It is provided for informational purposes only and reflects the changes between the two released versions.
Files changed (40)
  1. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/PKG-INFO +1 -1
  2. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/pyproject.toml +3 -2
  3. sae_lens-6.0.0rc3/sae_lens/__init__.py +98 -0
  4. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/analysis/hooked_sae_transformer.py +10 -10
  5. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/analysis/neuronpedia_integration.py +13 -11
  6. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/cache_activations_runner.py +9 -7
  7. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/config.py +105 -235
  8. sae_lens-6.0.0rc3/sae_lens/constants.py +20 -0
  9. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/evals.py +34 -31
  10. sae_lens-6.0.0rc1/sae_lens/sae_training_runner.py → sae_lens-6.0.0rc3/sae_lens/llm_sae_training_runner.py +103 -70
  11. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/load_model.py +53 -5
  12. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/loading/pretrained_sae_loaders.py +36 -10
  13. sae_lens-6.0.0rc3/sae_lens/registry.py +49 -0
  14. sae_lens-6.0.0rc3/sae_lens/saes/__init__.py +48 -0
  15. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/saes/gated_sae.py +70 -59
  16. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/saes/jumprelu_sae.py +58 -72
  17. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/saes/sae.py +248 -273
  18. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/saes/standard_sae.py +75 -57
  19. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/saes/topk_sae.py +72 -83
  20. sae_lens-6.0.0rc3/sae_lens/training/activation_scaler.py +53 -0
  21. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/training/activations_store.py +105 -184
  22. sae_lens-6.0.0rc3/sae_lens/training/mixing_buffer.py +56 -0
  23. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/training/optim.py +60 -36
  24. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/training/sae_trainer.py +134 -158
  25. sae_lens-6.0.0rc3/sae_lens/training/types.py +5 -0
  26. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/training/upload_saes_to_huggingface.py +11 -5
  27. sae_lens-6.0.0rc3/sae_lens/util.py +47 -0
  28. sae_lens-6.0.0rc1/sae_lens/__init__.py +0 -61
  29. sae_lens-6.0.0rc1/sae_lens/regsitry.py +0 -34
  30. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/LICENSE +0 -0
  31. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/README.md +0 -0
  32. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/analysis/__init__.py +0 -0
  33. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/loading/__init__.py +0 -0
  34. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  35. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/pretokenize_runner.py +0 -0
  36. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/pretrained_saes.yaml +0 -0
  37. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/tokenization_and_batching.py +0 -0
  38. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/training/__init__.py +0 -0
  39. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/training/geometric_median.py +0 -0
  40. {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc3}/sae_lens/tutorial/tsea.py +0 -0
--- sae_lens-6.0.0rc1/PKG-INFO
+++ sae_lens-6.0.0rc3/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: sae-lens
-Version: 6.0.0rc1
+Version: 6.0.0rc3
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
--- sae_lens-6.0.0rc1/pyproject.toml
+++ sae_lens-6.0.0rc3/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.0.0-rc.1"
+version = "6.0.0-rc.3"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
@@ -52,13 +52,14 @@ boto3 = "^1.34.101"
 docstr-coverage = "^2.3.2"
 mkdocs = "^1.6.1"
 mkdocs-material = "^9.5.34"
-mkdocs-autorefs = "^1.1.0"
+mkdocs-autorefs = "^1.4.2"
 mkdocs-section-index = "^0.3.9"
 mkdocstrings = "^0.25.2"
 mkdocstrings-python = "^1.10.9"
 tabulate = "^0.9.0"
 ruff = "^0.7.4"
 eai-sparsify = "^1.1.1"
+mike = "^2.0.0"

 [tool.poetry.extras]
 mamba = ["mamba-lens"]
--- /dev/null
+++ sae_lens-6.0.0rc3/sae_lens/__init__.py
@@ -0,0 +1,98 @@
+# ruff: noqa: E402
+__version__ = "6.0.0-rc.3"
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+from sae_lens.saes import (
+    SAE,
+    GatedSAE,
+    GatedSAEConfig,
+    GatedTrainingSAE,
+    GatedTrainingSAEConfig,
+    JumpReLUSAE,
+    JumpReLUSAEConfig,
+    JumpReLUTrainingSAE,
+    JumpReLUTrainingSAEConfig,
+    SAEConfig,
+    StandardSAE,
+    StandardSAEConfig,
+    StandardTrainingSAE,
+    StandardTrainingSAEConfig,
+    TopKSAE,
+    TopKSAEConfig,
+    TopKTrainingSAE,
+    TopKTrainingSAEConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+)
+
+from .analysis.hooked_sae_transformer import HookedSAETransformer
+from .cache_activations_runner import CacheActivationsRunner
+from .config import (
+    CacheActivationsRunnerConfig,
+    LanguageModelSAERunnerConfig,
+    LoggingConfig,
+    PretokenizeRunnerConfig,
+)
+from .evals import run_evals
+from .llm_sae_training_runner import LanguageModelSAETrainingRunner, SAETrainingRunner
+from .loading.pretrained_sae_loaders import (
+    PretrainedSaeDiskLoader,
+    PretrainedSaeHuggingfaceLoader,
+)
+from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
+from .registry import register_sae_class, register_sae_training_class
+from .training.activations_store import ActivationsStore
+from .training.upload_saes_to_huggingface import upload_saes_to_huggingface
+
+__all__ = [
+    "SAE",
+    "SAEConfig",
+    "TrainingSAE",
+    "TrainingSAEConfig",
+    "HookedSAETransformer",
+    "ActivationsStore",
+    "LanguageModelSAERunnerConfig",
+    "LanguageModelSAETrainingRunner",
+    "CacheActivationsRunnerConfig",
+    "CacheActivationsRunner",
+    "PretokenizeRunnerConfig",
+    "PretokenizeRunner",
+    "pretokenize_runner",
+    "run_evals",
+    "upload_saes_to_huggingface",
+    "PretrainedSaeHuggingfaceLoader",
+    "PretrainedSaeDiskLoader",
+    "register_sae_class",
+    "register_sae_training_class",
+    "StandardSAE",
+    "StandardSAEConfig",
+    "StandardTrainingSAE",
+    "StandardTrainingSAEConfig",
+    "GatedSAE",
+    "GatedSAEConfig",
+    "GatedTrainingSAE",
+    "GatedTrainingSAEConfig",
+    "TopKSAE",
+    "TopKSAEConfig",
+    "TopKTrainingSAE",
+    "TopKTrainingSAEConfig",
+    "JumpReLUSAE",
+    "JumpReLUSAEConfig",
+    "JumpReLUTrainingSAE",
+    "JumpReLUTrainingSAEConfig",
+    "SAETrainingRunner",
+    "LoggingConfig",
+]
+
+
+register_sae_class("standard", StandardSAE, StandardSAEConfig)
+register_sae_training_class("standard", StandardTrainingSAE, StandardTrainingSAEConfig)
+register_sae_class("gated", GatedSAE, GatedSAEConfig)
+register_sae_training_class("gated", GatedTrainingSAE, GatedTrainingSAEConfig)
+register_sae_class("topk", TopKSAE, TopKSAEConfig)
+register_sae_training_class("topk", TopKTrainingSAE, TopKTrainingSAEConfig)
+register_sae_class("jumprelu", JumpReLUSAE, JumpReLUSAEConfig)
+register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAEConfig)
--- sae_lens-6.0.0rc1/sae_lens/analysis/hooked_sae_transformer.py
+++ sae_lens-6.0.0rc3/sae_lens/analysis/hooked_sae_transformer.py
@@ -68,7 +68,7 @@ class HookedSAETransformer(HookedTransformer):
         super().__init__(*model_args, **model_kwargs)
         self.acts_to_saes: dict[str, SAE] = {}  # type: ignore

-    def add_sae(self, sae: SAE, use_error_term: bool | None = None):
+    def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
         """Attaches an SAE to the model

         WARNING: This sae will be permanantly attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.
@@ -77,7 +77,7 @@ class HookedSAETransformer(HookedTransformer):
             sae: SparseAutoencoderBase. The SAE to attach to the model
             use_error_term: (bool | None) If provided, will set the use_error_term attribute of the SAE to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
         """
-        act_name = sae.cfg.hook_name
+        act_name = sae.cfg.metadata.hook_name
         if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
             logging.warning(
                 f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
@@ -92,7 +92,7 @@ class HookedSAETransformer(HookedTransformer):
         set_deep_attr(self, act_name, sae)
         self.setup()

-    def _reset_sae(self, act_name: str, prev_sae: SAE | None = None):
+    def _reset_sae(self, act_name: str, prev_sae: SAE[Any] | None = None):
         """Resets an SAE that was attached to the model

         By default will remove the SAE from that hook_point.
@@ -124,7 +124,7 @@ class HookedSAETransformer(HookedTransformer):
     def reset_saes(
         self,
         act_names: str | list[str] | None = None,
-        prev_saes: list[SAE | None] | None = None,
+        prev_saes: list[SAE[Any] | None] | None = None,
     ):
         """Reset the SAEs attached to the model

@@ -154,7 +154,7 @@ class HookedSAETransformer(HookedTransformer):
     def run_with_saes(
         self,
         *model_args: Any,
-        saes: SAE | list[SAE] = [],
+        saes: SAE[Any] | list[SAE[Any]] = [],
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
         **model_kwargs: Any,
@@ -183,7 +183,7 @@ class HookedSAETransformer(HookedTransformer):
     def run_with_cache_with_saes(
         self,
         *model_args: Any,
-        saes: SAE | list[SAE] = [],
+        saes: SAE[Any] | list[SAE[Any]] = [],
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
         return_cache_object: bool = True,
@@ -225,7 +225,7 @@ class HookedSAETransformer(HookedTransformer):
     def run_with_hooks_with_saes(
         self,
         *model_args: Any,
-        saes: SAE | list[SAE] = [],
+        saes: SAE[Any] | list[SAE[Any]] = [],
         reset_saes_end: bool = True,
         fwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
         bwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
@@ -261,7 +261,7 @@ class HookedSAETransformer(HookedTransformer):
     @contextmanager
     def saes(
         self,
-        saes: SAE | list[SAE] = [],
+        saes: SAE[Any] | list[SAE[Any]] = [],
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
     ):
@@ -295,8 +295,8 @@ class HookedSAETransformer(HookedTransformer):
             saes = [saes]
         try:
             for sae in saes:
-                act_names_to_reset.append(sae.cfg.hook_name)
-                prev_sae = self.acts_to_saes.get(sae.cfg.hook_name, None)
+                act_names_to_reset.append(sae.cfg.metadata.hook_name)
+                prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
                 prev_saes.append(prev_sae)
                 self.add_sae(sae, use_error_term=use_error_term)
             yield self
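As background, the hunks above reflect two 6.0 API changes: SAE is now generic over its config type (hence SAE[Any] in signatures), and the hook point is read from sae.cfg.metadata.hook_name instead of sae.cfg.hook_name. A minimal sketch of caller code adapting to the renamed attribute, assuming model is a HookedSAETransformer and sae an already-constructed SAE (their construction is outside this diff):

# Sketch only: `model` and `sae` are assumed to already exist.
# 5.x callers read the hook point as `sae.cfg.hook_name`;
# 6.0 callers read it from the config metadata, matching add_sae() above.
hook_name = sae.cfg.metadata.hook_name

model.add_sae(sae, use_error_term=True)  # attached until reset_saes() is called
logits = model("Hello world")            # forward pass now runs through the SAE at hook_name
model.reset_saes([hook_name])            # detach the SAE again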
--- sae_lens-6.0.0rc1/sae_lens/analysis/neuronpedia_integration.py
+++ sae_lens-6.0.0rc3/sae_lens/analysis/neuronpedia_integration.py
@@ -58,7 +58,7 @@ def NanAndInfReplacer(value: str):
     return NAN_REPLACEMENT


-def open_neuronpedia_feature_dashboard(sae: SAE, index: int):
+def open_neuronpedia_feature_dashboard(sae: SAE[Any], index: int):
     sae_id = sae.cfg.neuronpedia_id
     if sae_id is None:
         logger.warning(
@@ -70,7 +70,7 @@ def open_neuronpedia_feature_dashboard(sae: SAE, index: int):


 def get_neuronpedia_quick_list(
-    sae: SAE,
+    sae: SAE[Any],
     features: list[int],
     name: str = "temporary_list",
 ):
@@ -157,9 +157,10 @@ def sleep_identity(x: T) -> T:


 @retry(wait=wait_random_exponential(min=1, max=500), stop=stop_after_attempt(10))
-async def simulate_and_score(
-    simulator: NeuronSimulator, activation_records: list[ActivationRecord]
-) -> ScoredSimulation:
+async def simulate_and_score(  # type: ignore
+    simulator: NeuronSimulator,
+    activation_records: list[ActivationRecord],  # type: ignore
+) -> ScoredSimulation:  # type: ignore
     """Score an explanation of a neuron by how well it predicts activations on the given text sequences."""
     scored_sequence_simulations = await asyncio.gather(
         *[
@@ -330,8 +331,9 @@ async def autointerp_neuronpedia_features(  # noqa: C901
             feature.activations = []
         activation_records = [
             ActivationRecord(
-                tokens=activation.tokens, activations=activation.act_values
-            )
+                tokens=activation.tokens,  # type: ignore
+                activations=activation.act_values,  # type: ignore
+            )  # type: ignore
             for activation in feature.activations
         ]

@@ -384,15 +386,15 @@ async def autointerp_neuronpedia_features(  # noqa: C901

         temp_activation_records = [
             ActivationRecord(
-                tokens=[
+                tokens=[  # type: ignore
                     token.replace("<|endoftext|>", "<|not_endoftext|>")
                     .replace(" 55", "_55")
                     .encode("ascii", errors="backslashreplace")
                     .decode("ascii")
-                    for token in activation_record.tokens
+                    for token in activation_record.tokens  # type: ignore
                 ],
-                activations=activation_record.activations,
-            )
+                activations=activation_record.activations,  # type: ignore
+            )  # type: ignore
             for activation_record in activation_records
         ]

--- sae_lens-6.0.0rc1/sae_lens/cache_activations_runner.py
+++ sae_lens-6.0.0rc3/sae_lens/cache_activations_runner.py
@@ -14,7 +14,8 @@ from tqdm import tqdm
 from transformer_lens.HookedTransformer import HookedRootModule

 from sae_lens import logger
-from sae_lens.config import DTYPE_MAP, CacheActivationsRunnerConfig
+from sae_lens.config import CacheActivationsRunnerConfig
+from sae_lens.constants import DTYPE_MAP
 from sae_lens.load_model import load_model
 from sae_lens.training.activations_store import ActivationsStore

@@ -33,7 +34,6 @@ def _mk_activations_store(
         dataset=override_dataset or cfg.dataset_path,
         streaming=cfg.streaming,
         hook_name=cfg.hook_name,
-        hook_layer=cfg.hook_layer,
         hook_head_index=None,
         context_size=cfg.context_size,
         d_in=cfg.d_in,
@@ -264,7 +264,7 @@ class CacheActivationsRunner:

         for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
             try:
-                buffer = self.activations_store.get_buffer(
+                buffer = self.activations_store.get_raw_buffer(
                     self.cfg.n_batches_in_buffer, shuffle=False
                 )
                 shard = self._create_shard(buffer)
@@ -318,7 +318,7 @@ class CacheActivationsRunner:
     def _create_shard(
         self,
         buffer: tuple[
-            Float[torch.Tensor, "(bs context_size) num_layers d_in"],
+            Float[torch.Tensor, "(bs context_size) d_in"],
            Int[torch.Tensor, "(bs context_size)"] | None,
         ],
     ) -> Dataset:
@@ -326,13 +326,15 @@ class CacheActivationsRunner:
         acts, token_ids = buffer
         acts = einops.rearrange(
             acts,
-            "(bs context_size) num_layers d_in -> num_layers bs context_size d_in",
+            "(bs context_size) d_in -> bs context_size d_in",
             bs=self.cfg.n_seq_in_buffer,
             context_size=self.context_size,
             d_in=self.cfg.d_in,
-            num_layers=len(hook_names),
         )
-        shard_dict = {hook_name: act for hook_name, act in zip(hook_names, acts)}
+        shard_dict: dict[str, object] = {
+            hook_name: act_batch
+            for hook_name, act_batch in zip(hook_names, [acts], strict=True)
+        }

         if token_ids is not None:
             token_ids = einops.rearrange(
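For reference, dropping the num_layers axis means a raw buffer now holds activations for a single hook with shape (bs * context_size, d_in), and _create_shard reshapes it back to one row per sequence before building the shard dict. A standalone toy check of that einops pattern (illustrative shapes, not sae-lens code):

import einops
import torch

# Toy sizes standing in for cfg.n_seq_in_buffer, context_size and d_in.
bs, context_size, d_in = 4, 8, 16

# A flattened single-hook buffer, shaped like the first element of the tuple above.
acts = torch.randn(bs * context_size, d_in)

# Same rearrange pattern as the new _create_shard code.
per_sequence = einops.rearrange(
    acts,
    "(bs context_size) d_in -> bs context_size d_in",
    bs=bs,
    context_size=context_size,
    d_in=d_in,
)
assert per_sequence.shape == (bs, context_size, d_in)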