sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. sae_lens/__init__.py +60 -7
  2. sae_lens/analysis/hooked_sae_transformer.py +12 -12
  3. sae_lens/analysis/neuronpedia_integration.py +16 -14
  4. sae_lens/cache_activations_runner.py +9 -7
  5. sae_lens/config.py +170 -258
  6. sae_lens/constants.py +21 -0
  7. sae_lens/evals.py +59 -44
  8. sae_lens/llm_sae_training_runner.py +377 -0
  9. sae_lens/load_model.py +52 -4
  10. sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
  11. sae_lens/registry.py +49 -0
  12. sae_lens/saes/__init__.py +48 -0
  13. sae_lens/saes/gated_sae.py +254 -0
  14. sae_lens/saes/jumprelu_sae.py +348 -0
  15. sae_lens/saes/sae.py +1076 -0
  16. sae_lens/saes/standard_sae.py +178 -0
  17. sae_lens/saes/topk_sae.py +300 -0
  18. sae_lens/training/activation_scaler.py +53 -0
  19. sae_lens/training/activations_store.py +103 -184
  20. sae_lens/training/mixing_buffer.py +56 -0
  21. sae_lens/training/optim.py +60 -36
  22. sae_lens/training/sae_trainer.py +155 -177
  23. sae_lens/training/types.py +5 -0
  24. sae_lens/training/upload_saes_to_huggingface.py +13 -7
  25. sae_lens/util.py +47 -0
  26. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
  27. sae_lens-6.0.0.dist-info/RECORD +37 -0
  28. sae_lens/sae.py +0 -747
  29. sae_lens/sae_training_runner.py +0 -251
  30. sae_lens/training/geometric_median.py +0 -101
  31. sae_lens/training/training_sae.py +0 -710
  32. sae_lens-5.11.0.dist-info/RECORD +0 -28
  33. /sae_lens/{toolkit → loading}/__init__.py +0 -0
  34. /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
  35. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
  36. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
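Entries 10, 33, and 34 show the sae_lens.toolkit package moving to sae_lens.loading. The sketch below illustrates the import change that rename implies; only the module paths are visible in this file list, so the sketch imports whole modules and any symbols inside them are not confirmed by this diff.

# Hypothetical import update implied by the toolkit -> loading rename above.
# Only the module paths are confirmed by this diff; the contents of these
# modules may have changed between 5.11.0 and 6.0.0.

# sae-lens 5.11.0:
# from sae_lens.toolkit import pretrained_sae_loaders, pretrained_saes_directory

# sae-lens 6.0.0:
from sae_lens.loading import pretrained_sae_loaders, pretrained_saes_directory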
sae_lens/sae_training_runner.py (deleted)
@@ -1,251 +0,0 @@
- import json
- import signal
- import sys
- from collections.abc import Sequence
- from pathlib import Path
- from typing import Any, cast
-
- import torch
- import wandb
- from simple_parsing import ArgumentParser
- from transformer_lens.hook_points import HookedRootModule
-
- from sae_lens import logger
- from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
- from sae_lens.load_model import load_model
- from sae_lens.training.activations_store import ActivationsStore
- from sae_lens.training.geometric_median import compute_geometric_median
- from sae_lens.training.sae_trainer import SAETrainer
- from sae_lens.training.training_sae import TrainingSAE, TrainingSAEConfig
-
-
- class InterruptedException(Exception):
-     pass
-
-
- def interrupt_callback(sig_num: Any, stack_frame: Any):  # noqa: ARG001
-     raise InterruptedException()
-
-
- class SAETrainingRunner:
-     """
-     Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
-     """
-
-     cfg: LanguageModelSAERunnerConfig
-     model: HookedRootModule
-     sae: TrainingSAE
-     activations_store: ActivationsStore
-
-     def __init__(
-         self,
-         cfg: LanguageModelSAERunnerConfig,
-         override_dataset: HfDataset | None = None,
-         override_model: HookedRootModule | None = None,
-         override_sae: TrainingSAE | None = None,
-     ):
-         if override_dataset is not None:
-             logger.warning(
-                 f"You just passed in a dataset which will override the one specified in your configuration: {cfg.dataset_path}. As a consequence this run will not be reproducible via configuration alone."
-             )
-         if override_model is not None:
-             logger.warning(
-                 f"You just passed in a model which will override the one specified in your configuration: {cfg.model_name}. As a consequence this run will not be reproducible via configuration alone."
-             )
-
-         self.cfg = cfg
-
-         if override_model is None:
-             self.model = load_model(
-                 self.cfg.model_class_name,
-                 self.cfg.model_name,
-                 device=self.cfg.device,
-                 model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
-             )
-         else:
-             self.model = override_model
-
-         self.activations_store = ActivationsStore.from_config(
-             self.model,
-             self.cfg,
-             override_dataset=override_dataset,
-         )
-
-         if override_sae is None:
-             if self.cfg.from_pretrained_path is not None:
-                 self.sae = TrainingSAE.load_from_pretrained(
-                     self.cfg.from_pretrained_path, self.cfg.device
-                 )
-             else:
-                 self.sae = TrainingSAE(
-                     TrainingSAEConfig.from_dict(
-                         self.cfg.get_training_sae_cfg_dict(),
-                     )
-                 )
-                 self._init_sae_group_b_decs()
-         else:
-             self.sae = override_sae
-
-     def run(self):
-         """
-         Run the training of the SAE.
-         """
-
-         if self.cfg.log_to_wandb:
-             wandb.init(
-                 project=self.cfg.wandb_project,
-                 entity=self.cfg.wandb_entity,
-                 config=cast(Any, self.cfg),
-                 name=self.cfg.run_name,
-                 id=self.cfg.wandb_id,
-             )
-
-         trainer = SAETrainer(
-             model=self.model,
-             sae=self.sae,
-             activation_store=self.activations_store,
-             save_checkpoint_fn=self.save_checkpoint,
-             cfg=self.cfg,
-         )
-
-         self._compile_if_needed()
-         sae = self.run_trainer_with_interruption_handling(trainer)
-
-         if self.cfg.log_to_wandb:
-             wandb.finish()
-
-         return sae
-
-     def _compile_if_needed(self):
-         # Compile model and SAE
-         # torch.compile can provide significant speedups (10-20% in testing)
-         # using max-autotune gives the best speedups but:
-         # (a) increases VRAM usage,
-         # (b) can't be used on both SAE and LM (some issue with cudagraphs), and
-         # (c) takes some time to compile
-         # optimal settings seem to be:
-         # use max-autotune on SAE and max-autotune-no-cudagraphs on LM
-         # (also pylance seems to really hate this)
-         if self.cfg.compile_llm:
-             self.model = torch.compile(
-                 self.model,
-                 mode=self.cfg.llm_compilation_mode,
-             )  # type: ignore
-
-         if self.cfg.compile_sae:
-             backend = "aot_eager" if self.cfg.device == "mps" else "inductor"
-
-             self.sae.training_forward_pass = torch.compile(  # type: ignore
-                 self.sae.training_forward_pass,
-                 mode=self.cfg.sae_compilation_mode,
-                 backend=backend,
-             )  # type: ignore
-
-     def run_trainer_with_interruption_handling(self, trainer: SAETrainer):
-         try:
-             # signal handlers (if preempted)
-             signal.signal(signal.SIGINT, interrupt_callback)
-             signal.signal(signal.SIGTERM, interrupt_callback)
-
-             # train SAE
-             sae = trainer.fit()
-
-         except (KeyboardInterrupt, InterruptedException):
-             logger.warning("interrupted, saving progress")
-             checkpoint_name = str(trainer.n_training_tokens)
-             self.save_checkpoint(trainer, checkpoint_name=checkpoint_name)
-             logger.info("done saving")
-             raise
-
-         return sae
-
-     # TODO: move this into the SAE trainer or Training SAE class
-     def _init_sae_group_b_decs(
-         self,
-     ) -> None:
-         """
-         extract all activations at a certain layer and use for sae b_dec initialization
-         """
-
-         if self.cfg.b_dec_init_method == "geometric_median":
-             self.activations_store.set_norm_scaling_factor_if_needed()
-             layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
-             # get geometric median of the activations if we're using those.
-             median = compute_geometric_median(
-                 layer_acts,
-                 maxiter=100,
-             ).median
-             self.sae.initialize_b_dec_with_precalculated(median)  # type: ignore
-         elif self.cfg.b_dec_init_method == "mean":
-             self.activations_store.set_norm_scaling_factor_if_needed()
-             layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
-             self.sae.initialize_b_dec_with_mean(layer_acts)  # type: ignore
-
-     @staticmethod
-     def save_checkpoint(
-         trainer: SAETrainer,
-         checkpoint_name: str,
-         wandb_aliases: list[str] | None = None,
-     ) -> None:
-         base_path = Path(trainer.cfg.checkpoint_path) / checkpoint_name
-         base_path.mkdir(exist_ok=True, parents=True)
-
-         trainer.activations_store.save(
-             str(base_path / "activations_store_state.safetensors")
-         )
-
-         if trainer.sae.cfg.normalize_sae_decoder:
-             trainer.sae.set_decoder_norm_to_unit_norm()
-
-         weights_path, cfg_path, sparsity_path = trainer.sae.save_model(
-             str(base_path),
-             trainer.log_feature_sparsity,
-         )
-
-         # let's over write the cfg file with the trainer cfg, which is a super set of the original cfg.
-         # and should not cause issues but give us more info about SAEs we trained in SAE Lens.
-         config = trainer.cfg.to_dict()
-         with open(cfg_path, "w") as f:
-             json.dump(config, f)
-
-         if trainer.cfg.log_to_wandb:
-             # Avoid wandb saving errors such as:
-             # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
-             sae_name = trainer.sae.get_name().replace("/", "__")
-
-             # save model weights and cfg
-             model_artifact = wandb.Artifact(
-                 sae_name,
-                 type="model",
-                 metadata=dict(trainer.cfg.__dict__),
-             )
-             model_artifact.add_file(str(weights_path))
-             model_artifact.add_file(str(cfg_path))
-             wandb.log_artifact(model_artifact, aliases=wandb_aliases)
-
-             # save log feature sparsity
-             sparsity_artifact = wandb.Artifact(
-                 f"{sae_name}_log_feature_sparsity",
-                 type="log_feature_sparsity",
-                 metadata=dict(trainer.cfg.__dict__),
-             )
-             sparsity_artifact.add_file(str(sparsity_path))
-             wandb.log_artifact(sparsity_artifact)
-
-
- def _parse_cfg_args(args: Sequence[str]) -> LanguageModelSAERunnerConfig:
-     if len(args) == 0:
-         args = ["--help"]
-     parser = ArgumentParser(exit_on_error=False)
-     parser.add_arguments(LanguageModelSAERunnerConfig, dest="cfg")
-     return parser.parse_args(args).cfg
-
-
- # moved into its own function to make it easier to test
- def _run_cli(args: Sequence[str]):
-     cfg = _parse_cfg_args(args)
-     SAETrainingRunner(cfg=cfg).run()
-
-
- if __name__ == "__main__":
-     _run_cli(args=sys.argv[1:])
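For reference, the deleted sae_lens/sae_training_runner.py above was both an importable training runner and a CLI entry point. Below is a minimal sketch of how it was driven under 5.11.0, using only the symbols visible in the removed code; the config values a real run needs (model, dataset, hook point, and so on) are omitted and would have to be filled in.

# Sketch of the 5.11.0-era entry point removed above (requires sae-lens 5.11.0).
# LanguageModelSAERunnerConfig has many fields; a real run sets model_name,
# dataset_path, hook names, etc. -- the bare defaults here are only illustrative.
from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.sae_training_runner import SAETrainingRunner

cfg = LanguageModelSAERunnerConfig()
sae = SAETrainingRunner(cfg=cfg).run()

# The same runner was exposed on the command line via simple_parsing
# (running it with no arguments printed --help, per _parse_cfg_args above):
#   python -m sae_lens.sae_training_runner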
sae_lens/training/geometric_median.py (deleted)
@@ -1,101 +0,0 @@
- from types import SimpleNamespace
-
- import torch
- import tqdm
-
-
- def weighted_average(points: torch.Tensor, weights: torch.Tensor):
-     weights = weights / weights.sum()
-     return (points * weights.view(-1, 1)).sum(dim=0)
-
-
- @torch.no_grad()
- def geometric_median_objective(
-     median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
- ) -> torch.Tensor:
-     norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
-
-     return (norms * weights).sum()
-
-
- def compute_geometric_median(
-     points: torch.Tensor,
-     weights: torch.Tensor | None = None,
-     eps: float = 1e-6,
-     maxiter: int = 100,
-     ftol: float = 1e-20,
-     do_log: bool = False,
- ):
-     """
-     :param points: ``torch.Tensor`` of shape ``(n, d)``
-     :param weights: Optional ``torch.Tensor`` of shape :math:``(n,)``.
-     :param eps: Smallest allowed value of denominator, to avoid divide by zero.
-         Equivalently, this is a smoothing parameter. Default 1e-6.
-     :param maxiter: Maximum number of Weiszfeld iterations. Default 100
-     :param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
-     :param do_log: If true will return a log of function values encountered through the course of the algorithm
-     :return: SimpleNamespace object with fields
-         - `median`: estimate of the geometric median, which is a ``torch.Tensor`` object of shape :math:``(d,)``
-         - `termination`: string explaining how the algorithm terminated.
-         - `logs`: function values encountered through the course of the algorithm in a list (None if do_log is false).
-     """
-     with torch.no_grad():
-         if weights is None:
-             weights = torch.ones((points.shape[0],), device=points.device)
-         # initialize median estimate at mean
-         new_weights = weights
-         median = weighted_average(points, weights)
-         objective_value = geometric_median_objective(median, points, weights)
-         logs = [objective_value] if do_log else None
-
-         # Weiszfeld iterations
-         early_termination = False
-         pbar = tqdm.tqdm(range(maxiter))
-         for _ in pbar:
-             prev_obj_value = objective_value
-
-             norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
-             new_weights = weights / torch.clamp(norms, min=eps)
-             median = weighted_average(points, new_weights)
-             objective_value = geometric_median_objective(median, points, weights)
-
-             if logs is not None:
-                 logs.append(objective_value)
-             if abs(prev_obj_value - objective_value) <= ftol * objective_value:
-                 early_termination = True
-                 break
-
-             pbar.set_description(f"Objective value: {objective_value:.4f}")
-
-     median = weighted_average(points, new_weights)  # allow autodiff to track it
-     return SimpleNamespace(
-         median=median,
-         new_weights=new_weights,
-         termination=(
-             "function value converged within tolerance"
-             if early_termination
-             else "maximum iterations reached"
-         ),
-         logs=logs,
-     )
-
-
- if __name__ == "__main__":
-     import time
-
-     TOLERANCE = 1e-2
-
-     dim1 = 10000
-     dim2 = 768
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-
-     sample = (
-         torch.randn((dim1, dim2), device=device) * 100
-     )  # seems to be the order of magnitude of the actual use case
-     weights = torch.randn((dim1,), device=device)
-
-     torch.tensor(weights, device=device)
-
-     tic = time.perf_counter()
-     new = compute_geometric_median(sample, weights=weights, maxiter=100)
-     print(f"new code takes {time.perf_counter()-tic} seconds!")  # noqa: T201