sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc4__py3-none-any.whl

This diff shows the differences between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
sae_lens/sae_training_runner.py
@@ -1,237 +0,0 @@
- import json
- import signal
- import sys
- from collections.abc import Sequence
- from pathlib import Path
- from typing import Any, cast
-
- import torch
- import wandb
- from safetensors.torch import save_file
- from simple_parsing import ArgumentParser
- from transformer_lens.hook_points import HookedRootModule
-
- from sae_lens import logger
- from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
- from sae_lens.constants import RUNNER_CFG_FILENAME, SPARSITY_FILENAME
- from sae_lens.load_model import load_model
- from sae_lens.saes.sae import T_TRAINING_SAE_CONFIG, TrainingSAE, TrainingSAEConfig
- from sae_lens.training.activations_store import ActivationsStore
- from sae_lens.training.geometric_median import compute_geometric_median
- from sae_lens.training.sae_trainer import SAETrainer
-
-
- class InterruptedException(Exception):
-     pass
-
-
- def interrupt_callback(sig_num: Any, stack_frame: Any):  # noqa: ARG001
-     raise InterruptedException()
-
-
- class SAETrainingRunner:
-     """
-     Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
-     """
-
-     cfg: LanguageModelSAERunnerConfig[Any]
-     model: HookedRootModule
-     sae: TrainingSAE[Any]
-     activations_store: ActivationsStore
-
-     def __init__(
-         self,
-         cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
-         override_dataset: HfDataset | None = None,
-         override_model: HookedRootModule | None = None,
-         override_sae: TrainingSAE[Any] | None = None,
-     ):
-         if override_dataset is not None:
-             logger.warning(
-                 f"You just passed in a dataset which will override the one specified in your configuration: {cfg.dataset_path}. As a consequence this run will not be reproducible via configuration alone."
-             )
-         if override_model is not None:
-             logger.warning(
-                 f"You just passed in a model which will override the one specified in your configuration: {cfg.model_name}. As a consequence this run will not be reproducible via configuration alone."
-             )
-
-         self.cfg = cfg
-
-         if override_model is None:
-             self.model = load_model(
-                 self.cfg.model_class_name,
-                 self.cfg.model_name,
-                 device=self.cfg.device,
-                 model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
-             )
-         else:
-             self.model = override_model
-
-         self.activations_store = ActivationsStore.from_config(
-             self.model,
-             self.cfg,
-             override_dataset=override_dataset,
-         )
-
-         if override_sae is None:
-             if self.cfg.from_pretrained_path is not None:
-                 self.sae = TrainingSAE.load_from_disk(
-                     self.cfg.from_pretrained_path, self.cfg.device
-                 )
-             else:
-                 self.sae = TrainingSAE.from_dict(
-                     TrainingSAEConfig.from_dict(
-                         self.cfg.get_training_sae_cfg_dict(),
-                     ).to_dict()
-                 )
-                 self._init_sae_group_b_decs()
-         else:
-             self.sae = override_sae
-
-     def run(self):
-         """
-         Run the training of the SAE.
-         """
-
-         if self.cfg.logger.log_to_wandb:
-             wandb.init(
-                 project=self.cfg.logger.wandb_project,
-                 entity=self.cfg.logger.wandb_entity,
-                 config=cast(Any, self.cfg),
-                 name=self.cfg.logger.run_name,
-                 id=self.cfg.logger.wandb_id,
-             )
-
-         trainer = SAETrainer(
-             model=self.model,
-             sae=self.sae,
-             activation_store=self.activations_store,
-             save_checkpoint_fn=self.save_checkpoint,
-             cfg=self.cfg,
-         )
-
-         self._compile_if_needed()
-         sae = self.run_trainer_with_interruption_handling(trainer)
-
-         if self.cfg.logger.log_to_wandb:
-             wandb.finish()
-
-         return sae
-
-     def _compile_if_needed(self):
-         # Compile model and SAE
-         # torch.compile can provide significant speedups (10-20% in testing)
-         # using max-autotune gives the best speedups but:
-         # (a) increases VRAM usage,
-         # (b) can't be used on both SAE and LM (some issue with cudagraphs), and
-         # (c) takes some time to compile
-         # optimal settings seem to be:
-         # use max-autotune on SAE and max-autotune-no-cudagraphs on LM
-         # (also pylance seems to really hate this)
-         if self.cfg.compile_llm:
-             self.model = torch.compile(
-                 self.model,
-                 mode=self.cfg.llm_compilation_mode,
-             )  # type: ignore
-
-         if self.cfg.compile_sae:
-             backend = "aot_eager" if self.cfg.device == "mps" else "inductor"
-
-             self.sae.training_forward_pass = torch.compile(  # type: ignore
-                 self.sae.training_forward_pass,
-                 mode=self.cfg.sae_compilation_mode,
-                 backend=backend,
-             )  # type: ignore
-
-     def run_trainer_with_interruption_handling(
-         self, trainer: SAETrainer[TrainingSAE[TrainingSAEConfig], TrainingSAEConfig]
-     ):
-         try:
-             # signal handlers (if preempted)
-             signal.signal(signal.SIGINT, interrupt_callback)
-             signal.signal(signal.SIGTERM, interrupt_callback)
-
-             # train SAE
-             sae = trainer.fit()
-
-         except (KeyboardInterrupt, InterruptedException):
-             logger.warning("interrupted, saving progress")
-             checkpoint_name = str(trainer.n_training_tokens)
-             self.save_checkpoint(trainer, checkpoint_name=checkpoint_name)
-             logger.info("done saving")
-             raise
-
-         return sae
-
-     # TODO: move this into the SAE trainer or Training SAE class
-     def _init_sae_group_b_decs(
-         self,
-     ) -> None:
-         """
-         extract all activations at a certain layer and use for sae b_dec initialization
-         """
-
-         if self.cfg.sae.b_dec_init_method == "geometric_median":
-             self.activations_store.set_norm_scaling_factor_if_needed()
-             layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
-             # get geometric median of the activations if we're using those.
-             median = compute_geometric_median(
-                 layer_acts,
-                 maxiter=100,
-             ).median
-             self.sae.initialize_b_dec_with_precalculated(median)
-         elif self.cfg.sae.b_dec_init_method == "mean":
-             self.activations_store.set_norm_scaling_factor_if_needed()
-             layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
-             self.sae.initialize_b_dec_with_mean(layer_acts)  # type: ignore
-
-     @staticmethod
-     def save_checkpoint(
-         trainer: SAETrainer[TrainingSAE[Any], Any],
-         checkpoint_name: str,
-         wandb_aliases: list[str] | None = None,
-     ) -> None:
-         base_path = Path(trainer.cfg.checkpoint_path) / checkpoint_name
-         base_path.mkdir(exist_ok=True, parents=True)
-
-         trainer.activations_store.save(
-             str(base_path / "activations_store_state.safetensors")
-         )
-
-         weights_path, cfg_path = trainer.sae.save_model(str(base_path))
-
-         sparsity_path = base_path / SPARSITY_FILENAME
-         save_file({"sparsity": trainer.log_feature_sparsity}, sparsity_path)
-
-         runner_config = trainer.cfg.to_dict()
-         with open(base_path / RUNNER_CFG_FILENAME, "w") as f:
-             json.dump(runner_config, f)
-
-         if trainer.cfg.logger.log_to_wandb:
-             trainer.cfg.logger.log(
-                 trainer,
-                 weights_path,
-                 cfg_path,
-                 sparsity_path=sparsity_path,
-                 wandb_aliases=wandb_aliases,
-             )
-
-
- def _parse_cfg_args(
-     args: Sequence[str],
- ) -> LanguageModelSAERunnerConfig[TrainingSAEConfig]:
-     if len(args) == 0:
-         args = ["--help"]
-     parser = ArgumentParser(exit_on_error=False)
-     parser.add_arguments(LanguageModelSAERunnerConfig, dest="cfg")
-     return parser.parse_args(args).cfg
-
-
- # moved into its own function to make it easier to test
- def _run_cli(args: Sequence[str]):
-     cfg = _parse_cfg_args(args)
-     SAETrainingRunner(cfg=cfg).run()
-
-
- if __name__ == "__main__":
-     _run_cli(args=sys.argv[1:])
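The removed runner uses two patterns worth noting. First, it converts SIGINT/SIGTERM into an exception so that a preempted training job can save a checkpoint before exiting. A minimal sketch of that pattern; `run_with_checkpointing`, `train_fn`, and `save_fn` are hypothetical stand-ins for `trainer.fit` and `save_checkpoint`:

```python
import signal
from typing import Any, Callable, TypeVar

T = TypeVar("T")


class InterruptedException(Exception):
    pass


def interrupt_callback(sig_num: Any, stack_frame: Any) -> None:
    # Surface the signal as an exception so ordinary try/except cleanup runs.
    raise InterruptedException()


def run_with_checkpointing(train_fn: Callable[[], T], save_fn: Callable[[], None]) -> T:
    signal.signal(signal.SIGINT, interrupt_callback)   # Ctrl-C
    signal.signal(signal.SIGTERM, interrupt_callback)  # e.g. cluster preemption
    try:
        return train_fn()
    except (KeyboardInterrupt, InterruptedException):
        save_fn()  # persist progress before propagating the interruption
        raise
```

Second, the comments in `_compile_if_needed` recommend max-autotune for the SAE forward pass and max-autotune-no-cudagraphs for the language model. Assuming a plain module and forward callable, that advice amounts to roughly this hypothetical helper:

```python
import torch
from typing import Callable


def compile_for_training(model: torch.nn.Module, sae_forward: Callable, device: str):
    # Mirrors the settings described in _compile_if_needed: cudagraphs cannot be
    # used on both the SAE and the LM at once, and the code falls back to the
    # aot_eager backend on Apple Silicon ("mps"), where inductor is not used.
    model = torch.compile(model, mode="max-autotune-no-cudagraphs")
    backend = "aot_eager" if device == "mps" else "inductor"
    sae_forward = torch.compile(sae_forward, mode="max-autotune", backend=backend)
    return model, sae_forward
```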
sae_lens/training/geometric_median.py
@@ -1,101 +0,0 @@
- from types import SimpleNamespace
-
- import torch
- import tqdm
-
-
- def weighted_average(points: torch.Tensor, weights: torch.Tensor):
-     weights = weights / weights.sum()
-     return (points * weights.view(-1, 1)).sum(dim=0)
-
-
- @torch.no_grad()
- def geometric_median_objective(
-     median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
- ) -> torch.Tensor:
-     norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
-
-     return (norms * weights).sum()
-
-
- def compute_geometric_median(
-     points: torch.Tensor,
-     weights: torch.Tensor | None = None,
-     eps: float = 1e-6,
-     maxiter: int = 100,
-     ftol: float = 1e-20,
-     do_log: bool = False,
- ):
-     """
-     :param points: ``torch.Tensor`` of shape ``(n, d)``
-     :param weights: Optional ``torch.Tensor`` of shape ``(n,)``.
-     :param eps: Smallest allowed value of denominator, to avoid divide by zero.
-         Equivalently, this is a smoothing parameter. Default 1e-6.
-     :param maxiter: Maximum number of Weiszfeld iterations. Default 100
-     :param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
-     :param do_log: If true will return a log of function values encountered through the course of the algorithm
-     :return: SimpleNamespace object with fields
-         - `median`: estimate of the geometric median, which is a ``torch.Tensor`` object of shape ``(d,)``
-         - `termination`: string explaining how the algorithm terminated.
-         - `logs`: function values encountered through the course of the algorithm in a list (None if do_log is false).
-     """
-     with torch.no_grad():
-         if weights is None:
-             weights = torch.ones((points.shape[0],), device=points.device)
-         # initialize median estimate at mean
-         new_weights = weights
-         median = weighted_average(points, weights)
-         objective_value = geometric_median_objective(median, points, weights)
-         logs = [objective_value] if do_log else None
-
-         # Weiszfeld iterations
-         early_termination = False
-         pbar = tqdm.tqdm(range(maxiter))
-         for _ in pbar:
-             prev_obj_value = objective_value
-
-             norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
-             new_weights = weights / torch.clamp(norms, min=eps)
-             median = weighted_average(points, new_weights)
-             objective_value = geometric_median_objective(median, points, weights)
-
-             if logs is not None:
-                 logs.append(objective_value)
-             if abs(prev_obj_value - objective_value) <= ftol * objective_value:
-                 early_termination = True
-                 break
-
-             pbar.set_description(f"Objective value: {objective_value:.4f}")
-
-     median = weighted_average(points, new_weights)  # allow autodiff to track it
-     return SimpleNamespace(
-         median=median,
-         new_weights=new_weights,
-         termination=(
-             "function value converged within tolerance"
-             if early_termination
-             else "maximum iterations reached"
-         ),
-         logs=logs,
-     )
-
-
- if __name__ == "__main__":
-     import time
-
-     TOLERANCE = 1e-2
-
-     dim1 = 10000
-     dim2 = 768
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-
-     sample = (
-         torch.randn((dim1, dim2), device=device) * 100
-     )  # seems to be the order of magnitude of the actual use case
-     weights = torch.randn((dim1,), device=device)
-
-     torch.tensor(weights, device=device)
-
-     tic = time.perf_counter()
-     new = compute_geometric_median(sample, weights=weights, maxiter=100)
-     print(f"new code takes {time.perf_counter()-tic} seconds!")  # noqa: T201
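The removed module implements the classic Weiszfeld algorithm: each iteration reweights every point by the inverse of its distance to the current estimate and re-averages. A self-contained sketch of the core update (the function name and defaults are illustrative, not the library API):

```python
import torch


def weiszfeld_median(points: torch.Tensor, iters: int = 100, eps: float = 1e-6) -> torch.Tensor:
    """Approximate the geometric median of an (n, d) tensor of points."""
    median = points.mean(dim=0)  # initialize the estimate at the arithmetic mean
    for _ in range(iters):
        # Distance of each point to the current estimate, clamped so the
        # update stays finite when the estimate lands on a data point.
        norms = torch.linalg.norm(points - median, dim=1).clamp(min=eps)
        inv = 1.0 / norms
        median = (points * inv.unsqueeze(1)).sum(dim=0) / inv.sum()
    return median


# e.g. initializing an SAE decoder bias from a buffer of activations:
# b_dec = weiszfeld_median(activations)  # activations: (n, d_model)
```

Unlike the mean, the geometric median minimizes the sum of distances rather than squared distances, which is why the runner offers it as a b_dec initialization that is less sensitive to outlier activations.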
sae_lens-6.0.0rc2.dist-info/RECORD
@@ -1,35 +0,0 @@
- sae_lens/__init__.py,sha256=JZATcdlWGVOXYTHb41hn7dPp7pR2tWgpLAz2ztQOE-A,2747
- sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
- sae_lens/analysis/neuronpedia_integration.py,sha256=DlI08ThI0zwMrBthICt1OFCMyqmaCUDeZxhOk7b7teY,18680
- sae_lens/cache_activations_runner.py,sha256=27jp2hFxZj4foWCRCJJd2VCwYJtMgkvPx6MuIhQBofc,12591
- sae_lens/config.py,sha256=Ff6MRzRlVk8xtgkvHdJEmuPh9Owc10XIWBaUwdypzkU,26062
- sae_lens/constants.py,sha256=HSiSp0j2Umak2buT30seFhkmj7KNuPmB3u4yLXrgfOg,462
- sae_lens/evals.py,sha256=aR0pJMBWBUdZElXPcxUyNnNYWbM2LC5UeaESKAwdOMY,39098
- sae_lens/load_model.py,sha256=tE70sXsyyyGYW7o506O3eiw1MXyyW6DCQojLG49hWYI,6771
- sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- sae_lens/loading/pretrained_sae_loaders.py,sha256=IgQ-XSJ5VTLCzmJavPmk1vExBVB-36wW7w-ZNo7tzPY,31214
- sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
- sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
- sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
- sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
- sae_lens/sae_training_runner.py,sha256=lI_d3ywS312dIz0wctm_Sgt3W9ffBOS7ahnDXBljX1s,8320
- sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
- sae_lens/saes/gated_sae.py,sha256=IgWvZxeJpdiu7VqeUnJLC-VWVhz6o8OXvmwCS-LJ-WQ,9426
- sae_lens/saes/jumprelu_sae.py,sha256=lkhafpoYYn4-62tBlmmufmUomoo3CmFFQQ3NNylBNSM,12264
- sae_lens/saes/sae.py,sha256=edJK3VFzOVBPXUX6QJ5fhhoY0wcfEisDmVXiqFRA7Xg,35089
- sae_lens/saes/standard_sae.py,sha256=tMs6Z6Cv44PWa7pLo53xhXFnHMvO5BM6eVYHtRPLpos,6652
- sae_lens/saes/topk_sae.py,sha256=CfF59K4J2XwUvztwg4fBbvFO3PyucLkg4Elkxdk0ozs,9786
- sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
- sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- sae_lens/training/activations_store.py,sha256=5V5dExeXWoE0dw-ePOZVnQIbBJwrepRMdsQrRam9Lg8,36790
- sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
- sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
- sae_lens/training/sae_trainer.py,sha256=zYAk_9QJ8AJi2TjDZ1qW_lyoovSBqrJvBHzyYgb89ZY,15251
- sae_lens/training/upload_saes_to_huggingface.py,sha256=tXvR4j25IgMjJ8R9oczwSdy00Tg-P_jAtnPHRt8yF64,4489
- sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
- sae_lens/util.py,sha256=4lqtl7HT9OiyRK8fe8nXtkcn2lOR1uX7ANrAClf6Bv8,1026
- sae_lens-6.0.0rc2.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
- sae_lens-6.0.0rc2.dist-info/METADATA,sha256=Z8Zwb6EknAPB5dOvfduYZewr4nldot-1dQoqz50Co3k,5326
- sae_lens-6.0.0rc2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
- sae_lens-6.0.0rc2.dist-info/RECORD,,
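Per the wheel spec, each RECORD row is a relative path, a hash of the form `algorithm=urlsafe-b64(digest)` with padding stripped, and a size in bytes; RECORD lists itself with the last two fields empty. A sketch of verifying an unpacked wheel against its RECORD; `verify_record` is a hypothetical helper, not part of sae-lens:

```python
import base64
import csv
import hashlib
from pathlib import Path


def verify_record(root: Path, record: Path) -> list[str]:
    """Return paths whose on-disk contents do not match their RECORD hash."""
    mismatched = []
    with open(record, newline="") as f:
        for path, hash_spec, _size in csv.reader(f):
            if not hash_spec:  # RECORD's own entry carries no hash
                continue
            algo, _, expected = hash_spec.partition("=")
            digest = hashlib.new(algo, (root / path).read_bytes()).digest()
            if base64.urlsafe_b64encode(digest).rstrip(b"=").decode() != expected:
                mismatched.append(path)
    return mismatched
```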