sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.0.0-rc.2"
+ __version__ = "6.0.0-rc.4"

  import logging

@@ -33,16 +33,17 @@ from .cache_activations_runner import CacheActivationsRunner
  from .config import (
      CacheActivationsRunnerConfig,
      LanguageModelSAERunnerConfig,
+     LoggingConfig,
      PretokenizeRunnerConfig,
  )
  from .evals import run_evals
+ from .llm_sae_training_runner import LanguageModelSAETrainingRunner, SAETrainingRunner
  from .loading.pretrained_sae_loaders import (
      PretrainedSaeDiskLoader,
      PretrainedSaeHuggingfaceLoader,
  )
  from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
  from .registry import register_sae_class, register_sae_training_class
- from .sae_training_runner import SAETrainingRunner
  from .training.activations_store import ActivationsStore
  from .training.upload_saes_to_huggingface import upload_saes_to_huggingface

@@ -54,7 +55,7 @@ __all__ = [
      "HookedSAETransformer",
      "ActivationsStore",
      "LanguageModelSAERunnerConfig",
-     "SAETrainingRunner",
+     "LanguageModelSAETrainingRunner",
      "CacheActivationsRunnerConfig",
      "CacheActivationsRunner",
      "PretokenizeRunnerConfig",
@@ -82,6 +83,8 @@ __all__ = [
      "JumpReLUSAEConfig",
      "JumpReLUTrainingSAE",
      "JumpReLUTrainingSAEConfig",
+     "SAETrainingRunner",
+     "LoggingConfig",
  ]


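The training runner exports change here: `LanguageModelSAETrainingRunner` is the new name, and both it and `SAETrainingRunner` are now imported from `sae_lens.llm_sae_training_runner`, with `LoggingConfig` also re-exported at the top level. A minimal import sketch under the new layout (it only shows which names resolve, nothing is executed):

    # Sketch only: top-level imports available in 6.0.0rc4 per the diff above.
    from sae_lens import (
        LanguageModelSAERunnerConfig,
        LanguageModelSAETrainingRunner,
        LoggingConfig,
        SAETrainingRunner,  # still exported, now sourced from llm_sae_training_runner
    )
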
@@ -59,7 +59,7 @@ def NanAndInfReplacer(value: str):


  def open_neuronpedia_feature_dashboard(sae: SAE[Any], index: int):
-     sae_id = sae.cfg.neuronpedia_id
+     sae_id = sae.cfg.metadata.neuronpedia_id
      if sae_id is None:
          logger.warning(
              "SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
@@ -74,7 +74,7 @@ def get_neuronpedia_quick_list(
      features: list[int],
      name: str = "temporary_list",
  ):
-     sae_id = sae.cfg.neuronpedia_id
+     sae_id = sae.cfg.metadata.neuronpedia_id
      if sae_id is None:
          logger.warning(
              "SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
@@ -86,7 +86,7 @@ def get_neuronpedia_quick_list(
      url = url + "?name=" + name
      list_feature = [
          {
-             "modelId": sae.cfg.model_name,
+             "modelId": sae.cfg.metadata.model_name,
              "layer": sae_id.split("/")[1],
              "index": str(feature),
          }
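These hunks move SAE-level metadata behind `sae.cfg.metadata`. A hedged sketch of the new access pattern; the release and SAE id strings below are placeholders rather than values taken from this diff:

    # Sketch only: reading metadata fields after the cfg -> cfg.metadata move.
    from sae_lens import SAE

    sae = SAE.from_pretrained(
        release="gpt2-small-res-jb",       # placeholder release name
        sae_id="blocks.8.hook_resid_pre",  # placeholder SAE id
    )
    print(sae.cfg.metadata.model_name)
    print(sae.cfg.metadata.neuronpedia_id)  # None if no dashboard exists for this SAE
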
@@ -34,7 +34,6 @@ def _mk_activations_store(
          dataset=override_dataset or cfg.dataset_path,
          streaming=cfg.streaming,
          hook_name=cfg.hook_name,
-         hook_layer=cfg.hook_layer,
          hook_head_index=None,
          context_size=cfg.context_size,
          d_in=cfg.d_in,
@@ -265,7 +264,7 @@ class CacheActivationsRunner:

          for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
              try:
-                 buffer = self.activations_store.get_buffer(
+                 buffer = self.activations_store.get_raw_buffer(
                      self.cfg.n_batches_in_buffer, shuffle=False
                  )
                  shard = self._create_shard(buffer)
@@ -319,7 +318,7 @@ class CacheActivationsRunner:
      def _create_shard(
          self,
          buffer: tuple[
-             Float[torch.Tensor, "(bs context_size) num_layers d_in"],
+             Float[torch.Tensor, "(bs context_size) d_in"],
              Int[torch.Tensor, "(bs context_size)"] | None,
          ],
      ) -> Dataset:
@@ -327,13 +326,15 @@ class CacheActivationsRunner:
          acts, token_ids = buffer
          acts = einops.rearrange(
              acts,
-             "(bs context_size) num_layers d_in -> num_layers bs context_size d_in",
+             "(bs context_size) d_in -> bs context_size d_in",
              bs=self.cfg.n_seq_in_buffer,
              context_size=self.context_size,
              d_in=self.cfg.d_in,
-             num_layers=len(hook_names),
          )
-         shard_dict = {hook_name: act for hook_name, act in zip(hook_names, acts)}
+         shard_dict: dict[str, object] = {
+             hook_name: act_batch
+             for hook_name, act_batch in zip(hook_names, [acts], strict=True)
+         }

          if token_ids is not None:
              token_ids = einops.rearrange(
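`_create_shard` now assumes a single hook: the `num_layers` axis disappears from both the buffer type and the rearrange pattern. A self-contained sketch of the new reshape with made-up sizes, just to make the shape contract concrete:

    # Sketch only: the single-hook reshape now used when building a shard.
    # The sizes below are arbitrary examples.
    import einops
    import torch

    n_seq_in_buffer, context_size, d_in = 4, 8, 16
    acts = torch.randn(n_seq_in_buffer * context_size, d_in)  # "(bs context_size) d_in"

    acts = einops.rearrange(
        acts,
        "(bs context_size) d_in -> bs context_size d_in",
        bs=n_seq_in_buffer,
        context_size=context_size,
        d_in=d_in,
    )
    assert acts.shape == (n_seq_in_buffer, context_size, d_in)
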
sae_lens/config.py CHANGED
@@ -23,7 +23,9 @@ from sae_lens.saes.sae import TrainingSAEConfig
  if TYPE_CHECKING:
      pass

- T_TRAINING_SAE_CONFIG = TypeVar("T_TRAINING_SAE_CONFIG", bound=TrainingSAEConfig)
+ T_TRAINING_SAE_CONFIG = TypeVar(
+     "T_TRAINING_SAE_CONFIG", bound=TrainingSAEConfig, covariant=True
+ )

  HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset

@@ -102,7 +104,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
          model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
          hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
          hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
-         hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing.
          hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
          dataset_path (str): A Hugging Face dataset path.
          dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
@@ -159,7 +160,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
      model_class_name: str = "HookedTransformer"
      hook_name: str = "blocks.0.hook_mlp_out"
      hook_eval: str = "NOT_IN_USE"
-     hook_layer: int = 0
      hook_head_index: int | None = None
      dataset_path: str = ""
      dataset_trust_remote_code: bool = True
@@ -201,7 +201,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
      train_batch_size_tokens: int = 4096

      ## Adam
-     adam_beta1: float = 0.0
+     adam_beta1: float = 0.9
      adam_beta2: float = 0.999

      ## Learning Rate Schedule
@@ -375,6 +375,27 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):

          return cls(**cfg)

+     def to_sae_trainer_config(self) -> "SAETrainerConfig":
+         return SAETrainerConfig(
+             n_checkpoints=self.n_checkpoints,
+             checkpoint_path=self.checkpoint_path,
+             total_training_samples=self.total_training_tokens,
+             device=self.device,
+             autocast=self.autocast,
+             lr=self.lr,
+             lr_end=self.lr_end,
+             lr_scheduler_name=self.lr_scheduler_name,
+             lr_warm_up_steps=self.lr_warm_up_steps,
+             adam_beta1=self.adam_beta1,
+             adam_beta2=self.adam_beta2,
+             lr_decay_steps=self.lr_decay_steps,
+             n_restart_cycles=self.n_restart_cycles,
+             train_batch_size_samples=self.train_batch_size_tokens,
+             dead_feature_window=self.dead_feature_window,
+             feature_sampling_window=self.feature_sampling_window,
+             logger=self.logger,
+         )
+

  @dataclass
  class CacheActivationsRunnerConfig:
@@ -386,7 +407,6 @@ class CacheActivationsRunnerConfig:
          model_name (str): The name of the model to use.
          model_batch_size (int): How many prompts are in the batch of the language model when generating activations.
          hook_name (str): The name of the hook to use.
-         hook_layer (int): The layer of the final hook. Currently only support a single hook, so this should be the same as hook_name.
          d_in (int): Dimension of the model.
          total_training_tokens (int): Total number of tokens to process.
          context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized.
@@ -416,7 +436,6 @@ class CacheActivationsRunnerConfig:
      model_name: str
      model_batch_size: int
      hook_name: str
-     hook_layer: int
      d_in: int
      training_tokens: int

@@ -576,3 +595,28 @@ class PretokenizeRunnerConfig:
      hf_num_shards: int = 64
      hf_revision: str = "main"
      hf_is_private_repo: bool = False
+
+
+ @dataclass
+ class SAETrainerConfig:
+     n_checkpoints: int
+     checkpoint_path: str
+     total_training_samples: int
+     device: str
+     autocast: bool
+     lr: float
+     lr_end: float | None
+     lr_scheduler_name: str
+     lr_warm_up_steps: int
+     adam_beta1: float
+     adam_beta2: float
+     lr_decay_steps: int
+     n_restart_cycles: int
+     train_batch_size_samples: int
+     dead_feature_window: int
+     feature_sampling_window: int
+     logger: LoggingConfig
+
+     @property
+     def total_training_steps(self) -> int:
+         return self.total_training_samples // self.train_batch_size_samples
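The new `SAETrainerConfig` collects the optimizer and scheduler fields that `LanguageModelSAERunnerConfig.to_sae_trainer_config()` copies over, mapping `total_training_tokens` onto `total_training_samples` and `train_batch_size_tokens` onto `train_batch_size_samples`. The derived step count is a plain floor division; a worked example with illustrative numbers:

    # Sketch only: the step count implied by SAETrainerConfig.total_training_steps.
    total_training_samples = 2_000_000  # illustrative value
    train_batch_size_samples = 4096     # illustrative value

    total_training_steps = total_training_samples // train_batch_size_samples
    print(total_training_steps)  # 488
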
sae_lens/constants.py CHANGED
@@ -16,3 +16,5 @@ SPARSITY_FILENAME = "sparsity.safetensors"
  SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
  SAE_CFG_FILENAME = "cfg.json"
  RUNNER_CFG_FILENAME = "runner_cfg.json"
+ ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
+ ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
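Two filenames are added for checkpoint artifacts: the activations-store state and the activation-scaler config. A small sketch of joining them onto a checkpoint directory (the directory name is made up, not taken from this diff):

    # Sketch only: building paths to the new checkpoint artifacts.
    import os

    from sae_lens.constants import (
        ACTIVATION_SCALER_CFG_FILENAME,
        ACTIVATIONS_STORE_STATE_FILENAME,
    )

    checkpoint_dir = "checkpoints/my_run"  # hypothetical path
    print(os.path.join(checkpoint_dir, ACTIVATIONS_STORE_STATE_FILENAME))
    print(os.path.join(checkpoint_dir, ACTIVATION_SCALER_CFG_FILENAME))
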
sae_lens/evals.py CHANGED
@@ -4,6 +4,7 @@ import json
  import math
  import re
  import subprocess
+ import sys
  from collections import defaultdict
  from collections.abc import Mapping
  from dataclasses import dataclass, field
@@ -15,13 +16,15 @@ from typing import Any
  import einops
  import pandas as pd
  import torch
- from tqdm import tqdm
+ from tqdm.auto import tqdm
  from transformer_lens import HookedTransformer
  from transformer_lens.hook_points import HookedRootModule

  from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
  from sae_lens.saes.sae import SAE, SAEConfig
+ from sae_lens.training.activation_scaler import ActivationScaler
  from sae_lens.training.activations_store import ActivationsStore
+ from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name


  def get_library_version() -> str:
@@ -103,6 +106,7 @@ def run_evals(
      sae: SAE[Any],
      activation_store: ActivationsStore,
      model: HookedRootModule,
+     activation_scaler: ActivationScaler,
      eval_config: EvalConfig = EvalConfig(),
      model_kwargs: Mapping[str, Any] = {},
      ignore_tokens: set[int | None] = set(),
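`run_evals` now requires an `ActivationScaler` alongside the SAE, activation store, and model. A hedged call sketch, assuming the remaining keyword arguments keep their defaults and that the inputs are constructed elsewhere:

    # Sketch only: calling run_evals with the new required activation_scaler.
    from typing import Any

    from transformer_lens.hook_points import HookedRootModule

    from sae_lens.evals import run_evals
    from sae_lens.saes.sae import SAE
    from sae_lens.training.activation_scaler import ActivationScaler
    from sae_lens.training.activations_store import ActivationsStore


    def evaluate(
        sae: SAE[Any],
        activation_store: ActivationsStore,
        model: HookedRootModule,
    ):
        # A fresh ActivationScaler() mirrors how multiple_evals constructs one below.
        return run_evals(
            sae=sae,
            activation_store=activation_store,
            model=model,
            activation_scaler=ActivationScaler(),
        )
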
@@ -140,6 +144,7 @@ def run_evals(
          sae,
          model,
          activation_store,
+         activation_scaler,
          compute_kl=eval_config.compute_kl,
          compute_ce_loss=eval_config.compute_ce_loss,
          n_batches=eval_config.n_eval_reconstruction_batches,
@@ -189,6 +194,7 @@ def run_evals(
          sae,
          model,
          activation_store,
+         activation_scaler,
          compute_l2_norms=eval_config.compute_l2_norms,
          compute_sparsity_metrics=eval_config.compute_sparsity_metrics,
          compute_variance_metrics=eval_config.compute_variance_metrics,
@@ -301,6 +307,7 @@ def get_downstream_reconstruction_metrics(
      sae: SAE[Any],
      model: HookedRootModule,
      activation_store: ActivationsStore,
+     activation_scaler: ActivationScaler,
      compute_kl: bool,
      compute_ce_loss: bool,
      n_batches: int,
@@ -326,8 +333,8 @@ def get_downstream_reconstruction_metrics(
          for metric_name, metric_value in get_recons_loss(
              sae,
              model,
+             activation_scaler,
              batch_tokens,
-             activation_store,
              compute_kl=compute_kl,
              compute_ce_loss=compute_ce_loss,
              ignore_tokens=ignore_tokens,
@@ -369,6 +376,7 @@ def get_sparsity_and_variance_metrics(
      sae: SAE[Any],
      model: HookedRootModule,
      activation_store: ActivationsStore,
+     activation_scaler: ActivationScaler,
      n_batches: int,
      compute_l2_norms: bool,
      compute_sparsity_metrics: bool,
@@ -436,7 +444,7 @@ def get_sparsity_and_variance_metrics(
              batch_tokens,
              prepend_bos=False,
              names_filter=[hook_name],
-             stop_at_layer=sae.cfg.metadata.hook_layer + 1,
+             stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(hook_name),
              **model_kwargs,
          )

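With `hook_layer` gone from the configs, the early-stop layer for `run_with_cache` is derived from the hook name via `extract_stop_at_layer_from_tlens_hook_name`, whose implementation is not shown in this diff. A rough sketch of the likely idea (parse the block index out of a TransformerLens hook name and stop one layer later), not sae_lens's actual code:

    # Rough sketch of the idea only -- not sae_lens's actual implementation.
    import re


    def stop_at_layer_from_hook_name(hook_name: str) -> int | None:
        """Return block_index + 1 for hook names like 'blocks.8.hook_resid_pre'."""
        match = re.search(r"blocks\.(\d+)\.", hook_name)
        return int(match.group(1)) + 1 if match else None


    print(stop_at_layer_from_hook_name("blocks.8.hook_resid_pre"))  # 9
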
@@ -451,16 +459,14 @@ def get_sparsity_and_variance_metrics(
          original_act = cache[hook_name]

          # normalise if necessary (necessary in training only, otherwise we should fold the scaling in)
-         if activation_store.normalize_activations == "expected_average_only_in":
-             original_act = activation_store.apply_norm_scaling_factor(original_act)
+         original_act = activation_scaler.scale(original_act)

          # send the (maybe normalised) activations into the SAE
          sae_feature_activations = sae.encode(original_act.to(sae.device))
          sae_out = sae.decode(sae_feature_activations).to(original_act.device)
          del cache

-         if activation_store.normalize_activations == "expected_average_only_in":
-             sae_out = activation_store.unscale(sae_out)
+         sae_out = activation_scaler.unscale(sae_out)

          flattened_sae_input = einops.rearrange(original_act, "b ctx d -> (b ctx) d")
          flattened_sae_feature_acts = einops.rearrange(
@@ -582,8 +588,8 @@ def get_sparsity_and_variance_metrics(
  def get_recons_loss(
      sae: SAE[SAEConfig],
      model: HookedRootModule,
+     activation_scaler: ActivationScaler,
      batch_tokens: torch.Tensor,
-     activation_store: ActivationsStore,
      compute_kl: bool,
      compute_ce_loss: bool,
      ignore_tokens: set[int | None] = set(),
@@ -618,15 +624,13 @@ def get_recons_loss(
          activations = activations.to(sae.device)

          # Handle rescaling if SAE expects it
-         if activation_store.normalize_activations == "expected_average_only_in":
-             activations = activation_store.apply_norm_scaling_factor(activations)
+         activations = activation_scaler.scale(activations)

          # SAE class agnost forward forward pass.
          new_activations = sae.decode(sae.encode(activations)).to(activations.dtype)

          # Unscale if activations were scaled prior to going into the SAE
-         if activation_store.normalize_activations == "expected_average_only_in":
-             new_activations = activation_store.unscale(new_activations)
+         new_activations = activation_scaler.unscale(new_activations)

          new_activations = torch.where(mask[..., None], new_activations, activations)

@@ -637,8 +641,7 @@ def get_recons_loss(
          activations = activations.to(sae.device)

          # Handle rescaling if SAE expects it
-         if activation_store.normalize_activations == "expected_average_only_in":
-             activations = activation_store.apply_norm_scaling_factor(activations)
+         activations = activation_scaler.scale(activations)

          # SAE class agnost forward forward pass.
          new_activations = sae.decode(sae.encode(activations.flatten(-2, -1))).to(
@@ -650,8 +653,7 @@ def get_recons_loss(
          )  # reshape to match original shape

          # Unscale if activations were scaled prior to going into the SAE
-         if activation_store.normalize_activations == "expected_average_only_in":
-             new_activations = activation_store.unscale(new_activations)
+         new_activations = activation_scaler.unscale(new_activations)

          return new_activations.to(original_device)

@@ -660,8 +662,7 @@ def get_recons_loss(
          activations = activations.to(sae.device)

          # Handle rescaling if SAE expects it
-         if activation_store.normalize_activations == "expected_average_only_in":
-             activations = activation_store.apply_norm_scaling_factor(activations)
+         activations = activation_scaler.scale(activations)

          new_activations = sae.decode(sae.encode(activations[:, :, head_index])).to(
              activations.dtype
@@ -669,8 +670,7 @@ def get_recons_loss(
          activations[:, :, head_index] = new_activations

          # Unscale if activations were scaled prior to going into the SAE
-         if activation_store.normalize_activations == "expected_average_only_in":
-             activations = activation_store.unscale(activations)
+         activations = activation_scaler.unscale(activations)

          return activations.to(original_device)

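Across the reconstruction hooks above, the conditional `normalize_activations` checks are replaced by unconditional `activation_scaler.scale(...)` / `unscale(...)` calls around the SAE forward pass. A minimal sketch of that wrapping pattern, assuming an `SAE` and an `ActivationScaler` are already available:

    # Sketch only: the scale -> encode/decode -> unscale pattern used in the hooks above.
    import torch

    from sae_lens.saes.sae import SAE, SAEConfig
    from sae_lens.training.activation_scaler import ActivationScaler


    def reconstruct(
        sae: SAE[SAEConfig],
        activation_scaler: ActivationScaler,
        activations: torch.Tensor,
    ) -> torch.Tensor:
        scaled = activation_scaler.scale(activations)
        sae_out = sae.decode(sae.encode(scaled)).to(scaled.dtype)
        return activation_scaler.unscale(sae_out)
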
@@ -815,16 +815,18 @@ def multiple_evals(
              release=sae_release_name,  # see other options in sae_lens/pretrained_saes.yaml
              sae_id=sae_id,  # won't always be a hook point
              device=device,
-         )[0]
+         )

          # move SAE to device if not there already
          sae.to(device)

-         if current_model_str != sae.cfg.model_name:
+         if current_model_str != sae.cfg.metadata.model_name:
              del current_model  # potentially saves GPU memory
-             current_model_str = sae.cfg.model_name
+             current_model_str = sae.cfg.metadata.model_name
              current_model = HookedTransformer.from_pretrained_no_processing(
-                 current_model_str, device=device, **sae.cfg.model_from_pretrained_kwargs
+                 current_model_str,
+                 device=device,
+                 **sae.cfg.metadata.model_from_pretrained_kwargs,
              )
          assert current_model is not None

@@ -849,6 +851,7 @@ def multiple_evals(
              scalar_metrics, feature_metrics = run_evals(
                  sae=sae,
                  activation_store=activation_store,
+                 activation_scaler=ActivationScaler(),
                  model=current_model,
                  eval_config=eval_config,
                  ignore_tokens={
@@ -941,7 +944,7 @@ def process_results(
      }


- if __name__ == "__main__":
+ def process_args(args: list[str]) -> argparse.Namespace:
      arg_parser = argparse.ArgumentParser(description="Run evaluations on SAEs")
      arg_parser.add_argument(
          "sae_regex_pattern",
@@ -1031,11 +1034,19 @@ if __name__ == "__main__":
          help="Enable verbose output with tqdm loaders.",
      )

-     args = arg_parser.parse_args()
-     eval_results = run_evaluations(args)
-     output_files = process_results(eval_results, args.output_dir)
+     return arg_parser.parse_args(args)
+
+
+ def run_evals_cli(args: list[str]) -> None:
+     opts = process_args(args)
+     eval_results = run_evaluations(opts)
+     output_files = process_results(eval_results, opts.output_dir)

      print("Evaluation complete. Output files:")
      print(f"Individual JSONs: {len(output_files['individual_jsons'])}")  # type: ignore
      print(f"Combined JSON: {output_files['combined_json']}")
      print(f"CSV: {output_files['csv']}")
+
+
+ if __name__ == "__main__":
+     run_evals_cli(sys.argv[1:])