sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl

Files changed (36)
  1. sae_lens/__init__.py +60 -7
  2. sae_lens/analysis/hooked_sae_transformer.py +12 -12
  3. sae_lens/analysis/neuronpedia_integration.py +16 -14
  4. sae_lens/cache_activations_runner.py +9 -7
  5. sae_lens/config.py +170 -258
  6. sae_lens/constants.py +21 -0
  7. sae_lens/evals.py +59 -44
  8. sae_lens/llm_sae_training_runner.py +377 -0
  9. sae_lens/load_model.py +52 -4
  10. sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
  11. sae_lens/registry.py +49 -0
  12. sae_lens/saes/__init__.py +48 -0
  13. sae_lens/saes/gated_sae.py +254 -0
  14. sae_lens/saes/jumprelu_sae.py +348 -0
  15. sae_lens/saes/sae.py +1076 -0
  16. sae_lens/saes/standard_sae.py +178 -0
  17. sae_lens/saes/topk_sae.py +300 -0
  18. sae_lens/training/activation_scaler.py +53 -0
  19. sae_lens/training/activations_store.py +103 -184
  20. sae_lens/training/mixing_buffer.py +56 -0
  21. sae_lens/training/optim.py +60 -36
  22. sae_lens/training/sae_trainer.py +155 -177
  23. sae_lens/training/types.py +5 -0
  24. sae_lens/training/upload_saes_to_huggingface.py +13 -7
  25. sae_lens/util.py +47 -0
  26. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
  27. sae_lens-6.0.0.dist-info/RECORD +37 -0
  28. sae_lens/sae.py +0 -747
  29. sae_lens/sae_training_runner.py +0 -251
  30. sae_lens/training/geometric_median.py +0 -101
  31. sae_lens/training/training_sae.py +0 -710
  32. sae_lens-5.11.0.dist-info/RECORD +0 -28
  33. /sae_lens/{toolkit → loading}/__init__.py +0 -0
  34. /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
  35. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
  36. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
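The restructuring visible in this list is the headline change in 6.0.0: SAE implementations move into a per-architecture `sae_lens.saes` package, `sae_lens.toolkit` becomes `sae_lens.loading`, and the training entry point moves from `sae_lens/sae_training_runner.py` to `sae_lens/llm_sae_training_runner.py`. A hedged import-migration sketch, using only paths confirmed by the file list and the diffs below:

```python
# sae-lens 5.x import paths (these modules are deleted in 6.0.0):
# from sae_lens.sae import SAE
# from sae_lens.training.training_sae import TrainingSAE
# from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory

# sae-lens 6.x equivalents, per the moved/added files above:
from sae_lens.saes.sae import SAE, TrainingSAE
from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
```

The expanded `sae_lens/__init__.py` (+60 -7) suggests most of these names are also re-exported at the top level, but the submodule paths above are the ones this diff confirms.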
sae_lens/evals.py CHANGED
@@ -4,6 +4,7 @@ import json
 import math
 import re
 import subprocess
+import sys
 from collections import defaultdict
 from collections.abc import Mapping
 from dataclasses import dataclass, field
@@ -15,13 +16,15 @@ from typing import Any
 import einops
 import pandas as pd
 import torch
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from transformer_lens import HookedTransformer
 from transformer_lens.hook_points import HookedRootModule

-from sae_lens.sae import SAE
-from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
+from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
+from sae_lens.saes.sae import SAE, SAEConfig
+from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
+from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name


 def get_library_version() -> str:
@@ -100,15 +103,16 @@ def get_eval_everything_config(

 @torch.no_grad()
 def run_evals(
-    sae: SAE,
+    sae: SAE[Any],
     activation_store: ActivationsStore,
     model: HookedRootModule,
+    activation_scaler: ActivationScaler,
     eval_config: EvalConfig = EvalConfig(),
     model_kwargs: Mapping[str, Any] = {},
     ignore_tokens: set[int | None] = set(),
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
-    hook_name = sae.cfg.hook_name
+    hook_name = sae.cfg.metadata.hook_name
     actual_batch_size = (
         eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
     )
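Downstream of this signature change, every caller of `run_evals` must now supply an `ActivationScaler`. A minimal sketch of the new call shape, assuming `sae`, `activation_store`, and `model` are already constructed as in 5.x:

```python
from sae_lens.evals import EvalConfig, run_evals
from sae_lens.training.activation_scaler import ActivationScaler

# ActivationScaler() with no arguments applies no scaling, matching the
# pretrained-SAE usage in multiple_evals further down this file.
all_metrics, feature_metrics = run_evals(
    sae=sae,
    activation_store=activation_store,
    model=model,
    activation_scaler=ActivationScaler(),
    eval_config=EvalConfig(),
)
```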
@@ -140,6 +144,7 @@
         sae,
         model,
         activation_store,
+        activation_scaler,
         compute_kl=eval_config.compute_kl,
         compute_ce_loss=eval_config.compute_ce_loss,
         n_batches=eval_config.n_eval_reconstruction_batches,
@@ -189,6 +194,7 @@
         sae,
         model,
         activation_store,
+        activation_scaler,
         compute_l2_norms=eval_config.compute_l2_norms,
         compute_sparsity_metrics=eval_config.compute_sparsity_metrics,
         compute_variance_metrics=eval_config.compute_variance_metrics,
@@ -274,12 +280,11 @@
     return all_metrics, feature_metrics


-def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
+def get_featurewise_weight_based_metrics(sae: SAE[Any]) -> dict[str, Any]:
     unit_norm_encoders = (sae.W_enc / sae.W_enc.norm(dim=0, keepdim=True)).cpu()
     unit_norm_decoder = (sae.W_dec.T / sae.W_dec.T.norm(dim=0, keepdim=True)).cpu()

     encoder_norms = sae.W_enc.norm(dim=-2).cpu().tolist()
-    encoder_bias = sae.b_enc.cpu().tolist()
     encoder_decoder_cosine_sim = (
         torch.nn.functional.cosine_similarity(
             unit_norm_decoder.T,
@@ -289,17 +294,20 @@ def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
         .tolist()
     )

-    return {
-        "encoder_bias": encoder_bias,
+    metrics = {
         "encoder_norm": encoder_norms,
         "encoder_decoder_cosine_sim": encoder_decoder_cosine_sim,
     }
+    if hasattr(sae, "b_enc") and sae.b_enc is not None:
+        metrics["encoder_bias"] = sae.b_enc.cpu().tolist()  # type: ignore
+    return metrics


 def get_downstream_reconstruction_metrics(
-    sae: SAE,
+    sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
+    activation_scaler: ActivationScaler,
     compute_kl: bool,
     compute_ce_loss: bool,
     n_batches: int,
@@ -325,8 +333,8 @@ def get_downstream_reconstruction_metrics(
         for metric_name, metric_value in get_recons_loss(
             sae,
             model,
+            activation_scaler,
             batch_tokens,
-            activation_store,
             compute_kl=compute_kl,
             compute_ce_loss=compute_ce_loss,
             ignore_tokens=ignore_tokens,
@@ -365,9 +373,10 @@


 def get_sparsity_and_variance_metrics(
-    sae: SAE,
+    sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
+    activation_scaler: ActivationScaler,
     n_batches: int,
     compute_l2_norms: bool,
     compute_sparsity_metrics: bool,
@@ -378,8 +387,8 @@
     ignore_tokens: set[int | None] = set(),
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
-    hook_name = sae.cfg.hook_name
-    hook_head_index = sae.cfg.hook_head_index
+    hook_name = sae.cfg.metadata.hook_name
+    hook_head_index = sae.cfg.metadata.hook_head_index

     metric_dict = {}
     feature_metric_dict = {}
@@ -435,7 +444,7 @@
             batch_tokens,
             prepend_bos=False,
             names_filter=[hook_name],
-            stop_at_layer=sae.cfg.hook_layer + 1,
+            stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(hook_name),
             **model_kwargs,
         )

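`extract_stop_at_layer_from_tlens_hook_name` replaces the removed `cfg.hook_layer` field by deriving the stop layer from the hook name itself. Its implementation lives in the new `sae_lens/util.py` (+47 lines) and is not shown in this diff; a plausible sketch, consistent with the old `hook_layer + 1` behavior it replaces:

```python
import re

def extract_stop_at_layer_from_tlens_hook_name(hook_name: str) -> int | None:
    # "blocks.8.hook_resid_pre" -> 9; hooks without a block index -> None
    match = re.search(r"blocks\.(\d+)\.", hook_name)
    return int(match.group(1)) + 1 if match else None
```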
@@ -450,16 +459,14 @@
         original_act = cache[hook_name]

         # normalise if necessary (necessary in training only, otherwise we should fold the scaling in)
-        if activation_store.normalize_activations == "expected_average_only_in":
-            original_act = activation_store.apply_norm_scaling_factor(original_act)
+        original_act = activation_scaler.scale(original_act)

         # send the (maybe normalised) activations into the SAE
         sae_feature_activations = sae.encode(original_act.to(sae.device))
         sae_out = sae.decode(sae_feature_activations).to(original_act.device)
         del cache

-        if activation_store.normalize_activations == "expected_average_only_in":
-            sae_out = activation_store.unscale(sae_out)
+        sae_out = activation_scaler.unscale(sae_out)

         flattened_sae_input = einops.rearrange(original_act, "b ctx d -> (b ctx) d")
         flattened_sae_feature_acts = einops.rearrange(
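`ActivationScaler` (new in `sae_lens/training/activation_scaler.py`) centralizes the `expected_average_only_in` normalization that `ActivationsStore` used to own, which is why the conditionals disappear here. A minimal interface sketch consistent with its use in this file and with the no-argument construction seen later; the real class (+53 lines) presumably also estimates the scaling factor during training:

```python
from dataclasses import dataclass

import torch


@dataclass
class ActivationScaler:
    scaling_factor: float | None = None  # None => identity (the pretrained/eval case)

    def scale(self, acts: torch.Tensor) -> torch.Tensor:
        return acts if self.scaling_factor is None else acts * self.scaling_factor

    def unscale(self, acts: torch.Tensor) -> torch.Tensor:
        return acts if self.scaling_factor is None else acts / self.scaling_factor
```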
@@ -579,17 +586,21 @@

 @torch.no_grad()
 def get_recons_loss(
-    sae: SAE,
+    sae: SAE[SAEConfig],
     model: HookedRootModule,
+    activation_scaler: ActivationScaler,
     batch_tokens: torch.Tensor,
-    activation_store: ActivationsStore,
     compute_kl: bool,
     compute_ce_loss: bool,
     ignore_tokens: set[int | None] = set(),
     model_kwargs: Mapping[str, Any] = {},
+    hook_name: str | None = None,
 ) -> dict[str, Any]:
-    hook_name = sae.cfg.hook_name
-    head_index = sae.cfg.hook_head_index
+    hook_name = hook_name or sae.cfg.metadata.hook_name
+    head_index = sae.cfg.metadata.hook_head_index
+
+    if hook_name is None:
+        raise ValueError("hook_name must be provided")

     original_logits, original_ce_loss = model(
         batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
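The new optional `hook_name` parameter lets callers override the hook recorded in the SAE's metadata. An illustrative call matching the signature above (the hook string is a placeholder; `batch_tokens` is a `[batch, seq]` tensor of token ids):

```python
metrics = get_recons_loss(
    sae,
    model,
    ActivationScaler(),
    batch_tokens,
    compute_kl=True,
    compute_ce_loss=True,
    hook_name="blocks.8.hook_resid_pre",  # placeholder override
)
```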
@@ -613,15 +624,13 @@ def get_recons_loss(
         activations = activations.to(sae.device)

         # Handle rescaling if SAE expects it
-        if activation_store.normalize_activations == "expected_average_only_in":
-            activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)

         # SAE class agnostic forward pass.
         new_activations = sae.decode(sae.encode(activations)).to(activations.dtype)

         # Unscale if activations were scaled prior to going into the SAE
-        if activation_store.normalize_activations == "expected_average_only_in":
-            new_activations = activation_store.unscale(new_activations)
+        new_activations = activation_scaler.unscale(new_activations)

         new_activations = torch.where(mask[..., None], new_activations, activations)

@@ -632,8 +641,7 @@ def get_recons_loss(
         activations = activations.to(sae.device)

         # Handle rescaling if SAE expects it
-        if activation_store.normalize_activations == "expected_average_only_in":
-            activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)

         # SAE class agnostic forward pass.
         new_activations = sae.decode(sae.encode(activations.flatten(-2, -1))).to(
@@ -645,8 +653,7 @@ def get_recons_loss(
         )  # reshape to match original shape

         # Unscale if activations were scaled prior to going into the SAE
-        if activation_store.normalize_activations == "expected_average_only_in":
-            new_activations = activation_store.unscale(new_activations)
+        new_activations = activation_scaler.unscale(new_activations)

         return new_activations.to(original_device)

@@ -655,8 +662,7 @@ def get_recons_loss(
         activations = activations.to(sae.device)

         # Handle rescaling if SAE expects it
-        if activation_store.normalize_activations == "expected_average_only_in":
-            activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)

         new_activations = sae.decode(sae.encode(activations[:, :, head_index])).to(
             activations.dtype
@@ -664,8 +670,7 @@ def get_recons_loss(
         activations[:, :, head_index] = new_activations

         # Unscale if activations were scaled prior to going into the SAE
-        if activation_store.normalize_activations == "expected_average_only_in":
-            activations = activation_store.unscale(activations)
+        activations = activation_scaler.unscale(activations)

         return activations.to(original_device)

@@ -794,22 +799,23 @@ def multiple_evals(

     current_model = None
     current_model_str = None
-    print(filtered_saes)
     for sae_release_name, sae_id, _, _ in tqdm(filtered_saes):
         sae = SAE.from_pretrained(
             release=sae_release_name,  # see other options in sae_lens/pretrained_saes.yaml
             sae_id=sae_id,  # won't always be a hook point
             device=device,
-        )[0]
+        )

         # move SAE to device if not there already
         sae.to(device)

-        if current_model_str != sae.cfg.model_name:
+        if current_model_str != sae.cfg.metadata.model_name:
             del current_model  # potentially saves GPU memory
-            current_model_str = sae.cfg.model_name
+            current_model_str = sae.cfg.metadata.model_name
             current_model = HookedTransformer.from_pretrained_no_processing(
-                current_model_str, device=device, **sae.cfg.model_from_pretrained_kwargs
+                current_model_str,
+                device=device,
+                **sae.cfg.metadata.model_from_pretrained_kwargs,
             )
         assert current_model is not None

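This hunk reflects two breaking changes visible throughout the diff: `SAE.from_pretrained` now returns the SAE directly rather than a tuple indexed with `[0]`, and run-provenance fields such as `model_name`, `hook_name`, and `model_from_pretrained_kwargs` move from `sae.cfg` to `sae.cfg.metadata`. A before/after sketch (release and id values are illustrative):

```python
from sae_lens import SAE

# sae-lens 5.x: from_pretrained returned a tuple, so callers indexed [0]
# sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")[0]

# sae-lens 6.x: the SAE is returned directly; provenance lives on cfg.metadata
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",       # illustrative release name
    sae_id="blocks.8.hook_resid_pre",  # illustrative SAE id
    device="cpu",
)
print(sae.cfg.metadata.model_name, sae.cfg.metadata.hook_name)
```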
@@ -834,6 +840,7 @@ def multiple_evals(
         scalar_metrics, feature_metrics = run_evals(
             sae=sae,
             activation_store=activation_store,
+            activation_scaler=ActivationScaler(),
             model=current_model,
             eval_config=eval_config,
             ignore_tokens={
@@ -926,7 +933,7 @@ def process_results(
     }


-if __name__ == "__main__":
+def process_args(args: list[str]) -> argparse.Namespace:
     arg_parser = argparse.ArgumentParser(description="Run evaluations on SAEs")
     arg_parser.add_argument(
         "sae_regex_pattern",
@@ -1016,11 +1023,19 @@ if __name__ == "__main__":
         help="Enable verbose output with tqdm loaders.",
     )

-    args = arg_parser.parse_args()
-    eval_results = run_evaluations(args)
-    output_files = process_results(eval_results, args.output_dir)
+    return arg_parser.parse_args(args)
+
+
+def run_evals_cli(args: list[str]) -> None:
+    opts = process_args(args)
+    eval_results = run_evaluations(opts)
+    output_files = process_results(eval_results, opts.output_dir)

     print("Evaluation complete. Output files:")
     print(f"Individual JSONs: {len(output_files['individual_jsons'])}")  # type: ignore
     print(f"Combined JSON: {output_files['combined_json']}")
     print(f"CSV: {output_files['csv']}")
+
+
+if __name__ == "__main__":
+    run_evals_cli(sys.argv[1:])
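Moving the argument parsing into `process_args` and the entry point into `run_evals_cli` makes the eval CLI importable and testable. An illustrative programmatic invocation (the pattern values are placeholders; `process_args` defines the full interface, of which only `sae_regex_pattern` is shown in this diff):

```python
from sae_lens.evals import run_evals_cli

# Programmatic equivalent of running `python sae_lens/evals.py <args...>`.
run_evals_cli(["gpt2-small-res-jb", "blocks.8.hook_resid_pre"])
```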
sae_lens/llm_sae_training_runner.py ADDED
@@ -0,0 +1,377 @@
+import json
+import signal
+import sys
+from collections.abc import Sequence
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Generic
+
+import torch
+import wandb
+from simple_parsing import ArgumentParser
+from transformer_lens.hook_points import HookedRootModule
+from typing_extensions import deprecated
+
+from sae_lens import logger
+from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
+from sae_lens.evals import EvalConfig, run_evals
+from sae_lens.load_model import load_model
+from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
+from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
+from sae_lens.saes.sae import (
+    T_TRAINING_SAE,
+    T_TRAINING_SAE_CONFIG,
+    TrainingSAE,
+    TrainingSAEConfig,
+)
+from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
+from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
+from sae_lens.training.activation_scaler import ActivationScaler
+from sae_lens.training.activations_store import ActivationsStore
+from sae_lens.training.sae_trainer import SAETrainer
+from sae_lens.training.types import DataProvider
+
+
+class InterruptedException(Exception):
+    pass
+
+
+def interrupt_callback(sig_num: Any, stack_frame: Any):  # noqa: ARG001
+    raise InterruptedException()
+
+
+@dataclass
+class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
+    model: HookedRootModule
+    activations_store: ActivationsStore
+    eval_batch_size_prompts: int | None
+    n_eval_batches: int
+    model_kwargs: dict[str, Any]
+
+    def __call__(
+        self,
+        sae: T_TRAINING_SAE,
+        data_provider: DataProvider,
+        activation_scaler: ActivationScaler,
+    ) -> dict[str, Any]:
+        ignore_tokens = set()
+        if self.activations_store.exclude_special_tokens is not None:
+            ignore_tokens = set(self.activations_store.exclude_special_tokens.tolist())
+
+        eval_config = EvalConfig(
+            batch_size_prompts=self.eval_batch_size_prompts,
+            n_eval_reconstruction_batches=self.n_eval_batches,
+            n_eval_sparsity_variance_batches=self.n_eval_batches,
+            compute_ce_loss=True,
+            compute_l2_norms=True,
+            compute_sparsity_metrics=True,
+            compute_variance_metrics=True,
+        )
+
+        eval_metrics, _ = run_evals(
+            sae=sae,
+            activation_store=self.activations_store,
+            model=self.model,
+            activation_scaler=activation_scaler,
+            eval_config=eval_config,
+            ignore_tokens=ignore_tokens,
+            model_kwargs=self.model_kwargs,
+        )  # not calculating featurewise metrics here.
+
+        # Remove eval metrics that are already logged during training
+        eval_metrics.pop("metrics/explained_variance", None)
+        eval_metrics.pop("metrics/explained_variance_std", None)
+        eval_metrics.pop("metrics/l0", None)
+        eval_metrics.pop("metrics/l1", None)
+        eval_metrics.pop("metrics/mse", None)
+
+        # Remove metrics that are not useful for wandb logging
+        eval_metrics.pop("metrics/total_tokens_evaluated", None)
+
+        return eval_metrics
+
+
+class LanguageModelSAETrainingRunner:
+    """
+    Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
+    """
+
+    cfg: LanguageModelSAERunnerConfig[Any]
+    model: HookedRootModule
+    sae: TrainingSAE[Any]
+    activations_store: ActivationsStore
+
+    def __init__(
+        self,
+        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
+        override_dataset: HfDataset | None = None,
+        override_model: HookedRootModule | None = None,
+        override_sae: TrainingSAE[Any] | None = None,
+    ):
+        if override_dataset is not None:
+            logger.warning(
+                f"You just passed in a dataset which will override the one specified in your configuration: {cfg.dataset_path}. As a consequence this run will not be reproducible via configuration alone."
+            )
+        if override_model is not None:
+            logger.warning(
+                f"You just passed in a model which will override the one specified in your configuration: {cfg.model_name}. As a consequence this run will not be reproducible via configuration alone."
+            )
+
+        self.cfg = cfg
+
+        if override_model is None:
+            self.model = load_model(
+                self.cfg.model_class_name,
+                self.cfg.model_name,
+                device=self.cfg.device,
+                model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
+            )
+        else:
+            self.model = override_model
+
+        self.activations_store = ActivationsStore.from_config(
+            self.model,
+            self.cfg,
+            override_dataset=override_dataset,
+        )
+
+        if override_sae is None:
+            if self.cfg.from_pretrained_path is not None:
+                self.sae = TrainingSAE.load_from_disk(
+                    self.cfg.from_pretrained_path, self.cfg.device
+                )
+            else:
+                self.sae = TrainingSAE.from_dict(
+                    TrainingSAEConfig.from_dict(
+                        self.cfg.get_training_sae_cfg_dict(),
+                    ).to_dict()
+                )
+        else:
+            self.sae = override_sae
+        self.sae.to(self.cfg.device)
+
+    def run(self):
+        """
+        Run the training of the SAE.
+        """
+        self._set_sae_metadata()
+        if self.cfg.logger.log_to_wandb:
+            wandb.init(
+                project=self.cfg.logger.wandb_project,
+                entity=self.cfg.logger.wandb_entity,
+                config=self.cfg.to_dict(),
+                name=self.cfg.logger.run_name,
+                id=self.cfg.logger.wandb_id,
+            )
+
+        evaluator = LLMSaeEvaluator(
+            model=self.model,
+            activations_store=self.activations_store,
+            eval_batch_size_prompts=self.cfg.eval_batch_size_prompts,
+            n_eval_batches=self.cfg.n_eval_batches,
+            model_kwargs=self.cfg.model_kwargs,
+        )
+
+        trainer = SAETrainer(
+            sae=self.sae,
+            data_provider=self.activations_store,
+            evaluator=evaluator,
+            save_checkpoint_fn=self.save_checkpoint,
+            cfg=self.cfg.to_sae_trainer_config(),
+        )
+
+        self._compile_if_needed()
+        sae = self.run_trainer_with_interruption_handling(trainer)
+
+        if self.cfg.logger.log_to_wandb:
+            wandb.finish()
+
+        return sae
+
+    def _set_sae_metadata(self):
+        self.sae.cfg.metadata.dataset_path = self.cfg.dataset_path
+        self.sae.cfg.metadata.hook_name = self.cfg.hook_name
+        self.sae.cfg.metadata.model_name = self.cfg.model_name
+        self.sae.cfg.metadata.model_class_name = self.cfg.model_class_name
+        self.sae.cfg.metadata.hook_head_index = self.cfg.hook_head_index
+        self.sae.cfg.metadata.context_size = self.cfg.context_size
+        self.sae.cfg.metadata.seqpos_slice = self.cfg.seqpos_slice
+        self.sae.cfg.metadata.model_from_pretrained_kwargs = (
+            self.cfg.model_from_pretrained_kwargs
+        )
+        self.sae.cfg.metadata.prepend_bos = self.cfg.prepend_bos
+        self.sae.cfg.metadata.exclude_special_tokens = self.cfg.exclude_special_tokens
+
+    def _compile_if_needed(self):
+        # Compile model and SAE
+        # torch.compile can provide significant speedups (10-20% in testing)
+        # using max-autotune gives the best speedups but:
+        # (a) increases VRAM usage,
+        # (b) can't be used on both SAE and LM (some issue with cudagraphs), and
+        # (c) takes some time to compile
+        # optimal settings seem to be:
+        # use max-autotune on SAE and max-autotune-no-cudagraphs on LM
+        # (also pylance seems to really hate this)
+        if self.cfg.compile_llm:
+            self.model = torch.compile(
+                self.model,
+                mode=self.cfg.llm_compilation_mode,
+            )  # type: ignore
+
+        if self.cfg.compile_sae:
+            backend = "aot_eager" if self.cfg.device == "mps" else "inductor"
+
+            self.sae.training_forward_pass = torch.compile(  # type: ignore
+                self.sae.training_forward_pass,
+                mode=self.cfg.sae_compilation_mode,
+                backend=backend,
+            )  # type: ignore
+
+    def run_trainer_with_interruption_handling(
+        self, trainer: SAETrainer[TrainingSAE[TrainingSAEConfig], TrainingSAEConfig]
+    ):
+        try:
+            # signal handlers (if preempted)
+            signal.signal(signal.SIGINT, interrupt_callback)
+            signal.signal(signal.SIGTERM, interrupt_callback)
+
+            # train SAE
+            sae = trainer.fit()
+
+        except (KeyboardInterrupt, InterruptedException):
+            logger.warning("interrupted, saving progress")
+            checkpoint_path = Path(self.cfg.checkpoint_path) / str(
+                trainer.n_training_samples
+            )
+            self.save_checkpoint(checkpoint_path)
+            logger.info("done saving")
+            raise
+
+        return sae
+
+    def save_checkpoint(
+        self,
+        checkpoint_path: Path,
+    ) -> None:
+        self.activations_store.save(
+            str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
+        )
+
+        runner_config = self.cfg.to_dict()
+        with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
+            json.dump(runner_config, f)
+
+
+def _parse_cfg_args(
+    args: Sequence[str],
+) -> LanguageModelSAERunnerConfig[TrainingSAEConfig]:
+    """
+    Parse command line arguments into a LanguageModelSAERunnerConfig.
+
+    This function first parses the architecture argument to determine which
+    concrete SAE config class to use, then parses the full configuration
+    with that concrete type.
+    """
+    if len(args) == 0:
+        args = ["--help"]
+
+    # First, parse only the architecture to determine which concrete class to use
+    architecture_parser = ArgumentParser(
+        description="Parse architecture to determine SAE config class",
+        exit_on_error=False,
+        add_help=False,  # Don't add help to avoid conflicts
+    )
+    architecture_parser.add_argument(
+        "--architecture",
+        type=str,
+        choices=["standard", "gated", "jumprelu", "topk"],
+        default="standard",
+        help="SAE architecture to use",
+    )
+
+    # Parse known args to extract architecture, ignore unknown args for now
+    arch_args, remaining_args = architecture_parser.parse_known_args(args)
+    architecture = arch_args.architecture
+
+    # Remove architecture from remaining args if it exists
+    filtered_args = []
+    skip_next = False
+    for arg in remaining_args:
+        if skip_next:
+            skip_next = False
+            continue
+        if arg == "--architecture":
+            skip_next = True  # Skip the next argument (the architecture value)
+            continue
+        filtered_args.append(arg)
+
+    # Create a custom wrapper class that simple_parsing can handle
+    def create_config_class(
+        sae_config_type: type[TrainingSAEConfig],
+    ) -> type[LanguageModelSAERunnerConfig[TrainingSAEConfig]]:
+        """Create a concrete config class for the given SAE config type."""
+
+        # Create the base config without the sae field
+        from dataclasses import field as dataclass_field
+        from dataclasses import fields, make_dataclass
+
+        # Get all fields from LanguageModelSAERunnerConfig except the generic sae field
+        base_fields = []
+        for field_obj in fields(LanguageModelSAERunnerConfig):
+            if field_obj.name != "sae":
+                base_fields.append((field_obj.name, field_obj.type, field_obj))
+
+        # Add the concrete sae field
+        base_fields.append(
+            (
+                "sae",
+                sae_config_type,
+                dataclass_field(
+                    default_factory=lambda: sae_config_type(d_in=512, d_sae=1024)
+                ),
+            )
+        )
+
+        # Create the concrete class
+        return make_dataclass(
+            f"{sae_config_type.__name__}RunnerConfig",
+            base_fields,
+            bases=(LanguageModelSAERunnerConfig,),
+        )
+
+    # Map architecture to concrete config class
+    sae_config_map = {
+        "standard": StandardTrainingSAEConfig,
+        "gated": GatedTrainingSAEConfig,
+        "jumprelu": JumpReLUTrainingSAEConfig,
+        "topk": TopKTrainingSAEConfig,
+    }
+
+    sae_config_type = sae_config_map[architecture]
+    concrete_config_class = create_config_class(sae_config_type)
+
+    # Now parse the full configuration with the concrete type
+    parser = ArgumentParser(exit_on_error=False)
+    parser.add_arguments(concrete_config_class, dest="cfg")
+
+    # Parse the filtered arguments (without --architecture)
+    parsed_args = parser.parse_args(filtered_args)
+
+    # Return the parsed configuration
+    return parsed_args.cfg
+
+
+# moved into its own function to make it easier to test
+def _run_cli(args: Sequence[str]):
+    cfg = _parse_cfg_args(args)
+    LanguageModelSAETrainingRunner(cfg=cfg).run()
+
+
+if __name__ == "__main__":
+    _run_cli(args=sys.argv[1:])
+
+
+@deprecated("Use LanguageModelSAETrainingRunner instead")
+class SAETrainingRunner(LanguageModelSAETrainingRunner):
+    pass
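To close, a hedged sketch of the two ways this new module is meant to be driven, based only on the classes and fields shown above (all argument values are illustrative):

```python
from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.llm_sae_training_runner import LanguageModelSAETrainingRunner
from sae_lens.saes.standard_sae import StandardTrainingSAEConfig

# Programmatic use: the generic runner config carries a concrete SAE config
# in its `sae` field (cf. the default_factory in _parse_cfg_args above).
cfg = LanguageModelSAERunnerConfig(
    model_name="gpt2",                    # illustrative
    hook_name="blocks.8.hook_resid_pre",  # illustrative
    dataset_path="some/hf-dataset",       # placeholder
    sae=StandardTrainingSAEConfig(d_in=768, d_sae=768 * 16),
)
sae = LanguageModelSAETrainingRunner(cfg=cfg).run()

# CLI use: --architecture selects the concrete SAE config class, and the
# remaining flags are generated from the config dataclass by simple_parsing:
#   python sae_lens/llm_sae_training_runner.py --architecture topk ...
```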