sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +60 -7
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +16 -14
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +170 -258
- sae_lens/constants.py +21 -0
- sae_lens/evals.py +59 -44
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +52 -4
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +254 -0
- sae_lens/saes/jumprelu_sae.py +348 -0
- sae_lens/saes/sae.py +1076 -0
- sae_lens/saes/standard_sae.py +178 -0
- sae_lens/saes/topk_sae.py +300 -0
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +103 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +155 -177
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +13 -7
- sae_lens/util.py +47 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
- sae_lens-6.0.0.dist-info/RECORD +37 -0
- sae_lens/sae.py +0 -747
- sae_lens/sae_training_runner.py +0 -251
- sae_lens/training/geometric_median.py +0 -101
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.11.0.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
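The most visible changes for downstream imports are the sae_lens/{toolkit → loading} rename, the new sae_lens/saes package (which supersedes the deleted top-level sae_lens/sae.py), and llm_sae_training_runner.py replacing sae_training_runner.py. A minimal migration sketch using only module paths that appear in the listing above; the symbols each module actually exports are not visible in this diff and are left out:

# Hypothetical migration sketch: module paths come from the file listing above;
# the public names exported by these modules are not shown in this diff.

# 5.11.0-era imports (modules removed or renamed in 6.0.0):
# from sae_lens.toolkit import pretrained_sae_loaders
# import sae_lens.sae                    # deleted; SAE classes now live under sae_lens/saes/
# import sae_lens.sae_training_runner    # deleted; see the diff below

# 6.0.0 module layout:
from sae_lens.loading import pretrained_sae_loaders  # toolkit/ -> loading/
from sae_lens.saes import gated_sae, jumprelu_sae, standard_sae, topk_sae
from sae_lens import llm_sae_training_runner, registry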
sae_lens/sae_training_runner.py
DELETED
@@ -1,251 +0,0 @@
import json
import signal
import sys
from collections.abc import Sequence
from pathlib import Path
from typing import Any, cast

import torch
import wandb
from simple_parsing import ArgumentParser
from transformer_lens.hook_points import HookedRootModule

from sae_lens import logger
from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
from sae_lens.load_model import load_model
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.geometric_median import compute_geometric_median
from sae_lens.training.sae_trainer import SAETrainer
from sae_lens.training.training_sae import TrainingSAE, TrainingSAEConfig


class InterruptedException(Exception):
    pass


def interrupt_callback(sig_num: Any, stack_frame: Any):  # noqa: ARG001
    raise InterruptedException()


class SAETrainingRunner:
    """
    Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
    """

    cfg: LanguageModelSAERunnerConfig
    model: HookedRootModule
    sae: TrainingSAE
    activations_store: ActivationsStore

    def __init__(
        self,
        cfg: LanguageModelSAERunnerConfig,
        override_dataset: HfDataset | None = None,
        override_model: HookedRootModule | None = None,
        override_sae: TrainingSAE | None = None,
    ):
        if override_dataset is not None:
            logger.warning(
                f"You just passed in a dataset which will override the one specified in your configuration: {cfg.dataset_path}. As a consequence this run will not be reproducible via configuration alone."
            )
        if override_model is not None:
            logger.warning(
                f"You just passed in a model which will override the one specified in your configuration: {cfg.model_name}. As a consequence this run will not be reproducible via configuration alone."
            )

        self.cfg = cfg

        if override_model is None:
            self.model = load_model(
                self.cfg.model_class_name,
                self.cfg.model_name,
                device=self.cfg.device,
                model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
            )
        else:
            self.model = override_model

        self.activations_store = ActivationsStore.from_config(
            self.model,
            self.cfg,
            override_dataset=override_dataset,
        )

        if override_sae is None:
            if self.cfg.from_pretrained_path is not None:
                self.sae = TrainingSAE.load_from_pretrained(
                    self.cfg.from_pretrained_path, self.cfg.device
                )
            else:
                self.sae = TrainingSAE(
                    TrainingSAEConfig.from_dict(
                        self.cfg.get_training_sae_cfg_dict(),
                    )
                )
                self._init_sae_group_b_decs()
        else:
            self.sae = override_sae

    def run(self):
        """
        Run the training of the SAE.
        """

        if self.cfg.log_to_wandb:
            wandb.init(
                project=self.cfg.wandb_project,
                entity=self.cfg.wandb_entity,
                config=cast(Any, self.cfg),
                name=self.cfg.run_name,
                id=self.cfg.wandb_id,
            )

        trainer = SAETrainer(
            model=self.model,
            sae=self.sae,
            activation_store=self.activations_store,
            save_checkpoint_fn=self.save_checkpoint,
            cfg=self.cfg,
        )

        self._compile_if_needed()
        sae = self.run_trainer_with_interruption_handling(trainer)

        if self.cfg.log_to_wandb:
            wandb.finish()

        return sae

    def _compile_if_needed(self):
        # Compile model and SAE
        # torch.compile can provide significant speedups (10-20% in testing)
        # using max-autotune gives the best speedups but:
        # (a) increases VRAM usage,
        # (b) can't be used on both SAE and LM (some issue with cudagraphs), and
        # (c) takes some time to compile
        # optimal settings seem to be:
        # use max-autotune on SAE and max-autotune-no-cudagraphs on LM
        # (also pylance seems to really hate this)
        if self.cfg.compile_llm:
            self.model = torch.compile(
                self.model,
                mode=self.cfg.llm_compilation_mode,
            )  # type: ignore

        if self.cfg.compile_sae:
            backend = "aot_eager" if self.cfg.device == "mps" else "inductor"

            self.sae.training_forward_pass = torch.compile(  # type: ignore
                self.sae.training_forward_pass,
                mode=self.cfg.sae_compilation_mode,
                backend=backend,
            )  # type: ignore

    def run_trainer_with_interruption_handling(self, trainer: SAETrainer):
        try:
            # signal handlers (if preempted)
            signal.signal(signal.SIGINT, interrupt_callback)
            signal.signal(signal.SIGTERM, interrupt_callback)

            # train SAE
            sae = trainer.fit()

        except (KeyboardInterrupt, InterruptedException):
            logger.warning("interrupted, saving progress")
            checkpoint_name = str(trainer.n_training_tokens)
            self.save_checkpoint(trainer, checkpoint_name=checkpoint_name)
            logger.info("done saving")
            raise

        return sae

    # TODO: move this into the SAE trainer or Training SAE class
    def _init_sae_group_b_decs(
        self,
    ) -> None:
        """
        extract all activations at a certain layer and use for sae b_dec initialization
        """

        if self.cfg.b_dec_init_method == "geometric_median":
            self.activations_store.set_norm_scaling_factor_if_needed()
            layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
            # get geometric median of the activations if we're using those.
            median = compute_geometric_median(
                layer_acts,
                maxiter=100,
            ).median
            self.sae.initialize_b_dec_with_precalculated(median)  # type: ignore
        elif self.cfg.b_dec_init_method == "mean":
            self.activations_store.set_norm_scaling_factor_if_needed()
            layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
            self.sae.initialize_b_dec_with_mean(layer_acts)  # type: ignore

    @staticmethod
    def save_checkpoint(
        trainer: SAETrainer,
        checkpoint_name: str,
        wandb_aliases: list[str] | None = None,
    ) -> None:
        base_path = Path(trainer.cfg.checkpoint_path) / checkpoint_name
        base_path.mkdir(exist_ok=True, parents=True)

        trainer.activations_store.save(
            str(base_path / "activations_store_state.safetensors")
        )

        if trainer.sae.cfg.normalize_sae_decoder:
            trainer.sae.set_decoder_norm_to_unit_norm()

        weights_path, cfg_path, sparsity_path = trainer.sae.save_model(
            str(base_path),
            trainer.log_feature_sparsity,
        )

        # let's over write the cfg file with the trainer cfg, which is a super set of the original cfg.
        # and should not cause issues but give us more info about SAEs we trained in SAE Lens.
        config = trainer.cfg.to_dict()
        with open(cfg_path, "w") as f:
            json.dump(config, f)

        if trainer.cfg.log_to_wandb:
            # Avoid wandb saving errors such as:
            # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
            sae_name = trainer.sae.get_name().replace("/", "__")

            # save model weights and cfg
            model_artifact = wandb.Artifact(
                sae_name,
                type="model",
                metadata=dict(trainer.cfg.__dict__),
            )
            model_artifact.add_file(str(weights_path))
            model_artifact.add_file(str(cfg_path))
            wandb.log_artifact(model_artifact, aliases=wandb_aliases)

            # save log feature sparsity
            sparsity_artifact = wandb.Artifact(
                f"{sae_name}_log_feature_sparsity",
                type="log_feature_sparsity",
                metadata=dict(trainer.cfg.__dict__),
            )
            sparsity_artifact.add_file(str(sparsity_path))
            wandb.log_artifact(sparsity_artifact)


def _parse_cfg_args(args: Sequence[str]) -> LanguageModelSAERunnerConfig:
    if len(args) == 0:
        args = ["--help"]
    parser = ArgumentParser(exit_on_error=False)
    parser.add_arguments(LanguageModelSAERunnerConfig, dest="cfg")
    return parser.parse_args(args).cfg


# moved into its own function to make it easier to test
def _run_cli(args: Sequence[str]):
    cfg = _parse_cfg_args(args)
    SAETrainingRunner(cfg=cfg).run()


if __name__ == "__main__":
    _run_cli(args=sys.argv[1:])
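For reference, this removed module was the 5.11.0 training entry point. A minimal usage sketch derived only from the code above; any LanguageModelSAERunnerConfig fields beyond those referenced in this file are assumptions and would need to be set for a real run:

# Sketch only, against the removed 5.11.0 API shown above.
from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.sae_training_runner import SAETrainingRunner  # module removed in 6.0.0

cfg = LanguageModelSAERunnerConfig()  # real runs set model_name, dataset_path, etc.
runner = SAETrainingRunner(cfg=cfg)   # loads the model, ActivationsStore, and TrainingSAE
sae = runner.run()                    # wraps SAETrainer.fit() with wandb logging and SIGINT/SIGTERM checkpointing

Invoked as a script, the module dispatched to _run_cli, which parses a LanguageModelSAERunnerConfig from the command line via simple_parsing and prints --help when given no arguments.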
sae_lens/training/geometric_median.py
DELETED
@@ -1,101 +0,0 @@
from types import SimpleNamespace

import torch
import tqdm


def weighted_average(points: torch.Tensor, weights: torch.Tensor):
    weights = weights / weights.sum()
    return (points * weights.view(-1, 1)).sum(dim=0)


@torch.no_grad()
def geometric_median_objective(
    median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
) -> torch.Tensor:
    norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore

    return (norms * weights).sum()


def compute_geometric_median(
    points: torch.Tensor,
    weights: torch.Tensor | None = None,
    eps: float = 1e-6,
    maxiter: int = 100,
    ftol: float = 1e-20,
    do_log: bool = False,
):
    """
    :param points: ``torch.Tensor`` of shape ``(n, d)``
    :param weights: Optional ``torch.Tensor`` of shape :math:``(n,)``.
    :param eps: Smallest allowed value of denominator, to avoid divide by zero.
        Equivalently, this is a smoothing parameter. Default 1e-6.
    :param maxiter: Maximum number of Weiszfeld iterations. Default 100
    :param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
    :param do_log: If true will return a log of function values encountered through the course of the algorithm
    :return: SimpleNamespace object with fields
        - `median`: estimate of the geometric median, which is a ``torch.Tensor`` object of shape :math:``(d,)``
        - `termination`: string explaining how the algorithm terminated.
        - `logs`: function values encountered through the course of the algorithm in a list (None if do_log is false).
    """
    with torch.no_grad():
        if weights is None:
            weights = torch.ones((points.shape[0],), device=points.device)
        # initialize median estimate at mean
        new_weights = weights
        median = weighted_average(points, weights)
        objective_value = geometric_median_objective(median, points, weights)
        logs = [objective_value] if do_log else None

        # Weiszfeld iterations
        early_termination = False
        pbar = tqdm.tqdm(range(maxiter))
        for _ in pbar:
            prev_obj_value = objective_value

            norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
            new_weights = weights / torch.clamp(norms, min=eps)
            median = weighted_average(points, new_weights)
            objective_value = geometric_median_objective(median, points, weights)

            if logs is not None:
                logs.append(objective_value)
            if abs(prev_obj_value - objective_value) <= ftol * objective_value:
                early_termination = True
                break

            pbar.set_description(f"Objective value: {objective_value:.4f}")

    median = weighted_average(points, new_weights)  # allow autodiff to track it
    return SimpleNamespace(
        median=median,
        new_weights=new_weights,
        termination=(
            "function value converged within tolerance"
            if early_termination
            else "maximum iterations reached"
        ),
        logs=logs,
    )


if __name__ == "__main__":
    import time

    TOLERANCE = 1e-2

    dim1 = 10000
    dim2 = 768
    device = "cuda" if torch.cuda.is_available() else "cpu"

    sample = (
        torch.randn((dim1, dim2), device=device) * 100
    )  # seems to be the order of magnitude of the actual use case
    weights = torch.randn((dim1,), device=device)

    torch.tensor(weights, device=device)

    tic = time.perf_counter()
    new = compute_geometric_median(sample, weights=weights, maxiter=100)
    print(f"new code takes {time.perf_counter()-tic} seconds!")  # noqa: T201