sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +6 -3
- sae_lens/analysis/neuronpedia_integration.py +3 -3
- sae_lens/cache_activations_runner.py +7 -6
- sae_lens/config.py +50 -6
- sae_lens/constants.py +2 -0
- sae_lens/evals.py +39 -28
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +24 -12
- sae_lens/saes/gated_sae.py +0 -4
- sae_lens/saes/jumprelu_sae.py +4 -10
- sae_lens/saes/sae.py +121 -51
- sae_lens/saes/standard_sae.py +4 -11
- sae_lens/saes/topk_sae.py +18 -12
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +77 -174
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/sae_trainer.py +107 -98
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +1 -1
- sae_lens/util.py +19 -0
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc4.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc4.dist-info/RECORD +37 -0
- sae_lens/sae_training_runner.py +0 -237
- sae_lens/training/geometric_median.py +0 -101
- sae_lens-6.0.0rc2.dist-info/RECORD +0 -35
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc4.dist-info}/LICENSE +0 -0
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc4.dist-info}/WHEEL +0 -0
sae_lens/llm_sae_training_runner.py
ADDED
@@ -0,0 +1,377 @@
+import json
+import signal
+import sys
+from collections.abc import Sequence
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Generic
+
+import torch
+import wandb
+from simple_parsing import ArgumentParser
+from transformer_lens.hook_points import HookedRootModule
+from typing_extensions import deprecated
+
+from sae_lens import logger
+from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
+from sae_lens.evals import EvalConfig, run_evals
+from sae_lens.load_model import load_model
+from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
+from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
+from sae_lens.saes.sae import (
+    T_TRAINING_SAE,
+    T_TRAINING_SAE_CONFIG,
+    TrainingSAE,
+    TrainingSAEConfig,
+)
+from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
+from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
+from sae_lens.training.activation_scaler import ActivationScaler
+from sae_lens.training.activations_store import ActivationsStore
+from sae_lens.training.sae_trainer import SAETrainer
+from sae_lens.training.types import DataProvider
+
+
+class InterruptedException(Exception):
+    pass
+
+
+def interrupt_callback(sig_num: Any, stack_frame: Any):  # noqa: ARG001
+    raise InterruptedException()
+
+
+@dataclass
+class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
+    model: HookedRootModule
+    activations_store: ActivationsStore
+    eval_batch_size_prompts: int | None
+    n_eval_batches: int
+    model_kwargs: dict[str, Any]
+
+    def __call__(
+        self,
+        sae: T_TRAINING_SAE,
+        data_provider: DataProvider,
+        activation_scaler: ActivationScaler,
+    ) -> dict[str, Any]:
+        ignore_tokens = set()
+        if self.activations_store.exclude_special_tokens is not None:
+            ignore_tokens = set(self.activations_store.exclude_special_tokens.tolist())
+
+        eval_config = EvalConfig(
+            batch_size_prompts=self.eval_batch_size_prompts,
+            n_eval_reconstruction_batches=self.n_eval_batches,
+            n_eval_sparsity_variance_batches=self.n_eval_batches,
+            compute_ce_loss=True,
+            compute_l2_norms=True,
+            compute_sparsity_metrics=True,
+            compute_variance_metrics=True,
+        )
+
+        eval_metrics, _ = run_evals(
+            sae=sae,
+            activation_store=self.activations_store,
+            model=self.model,
+            activation_scaler=activation_scaler,
+            eval_config=eval_config,
+            ignore_tokens=ignore_tokens,
+            model_kwargs=self.model_kwargs,
+        )  # not calculating featurwise metrics here.
+
+        # Remove eval metrics that are already logged during training
+        eval_metrics.pop("metrics/explained_variance", None)
+        eval_metrics.pop("metrics/explained_variance_std", None)
+        eval_metrics.pop("metrics/l0", None)
+        eval_metrics.pop("metrics/l1", None)
+        eval_metrics.pop("metrics/mse", None)
+
+        # Remove metrics that are not useful for wandb logging
+        eval_metrics.pop("metrics/total_tokens_evaluated", None)
+
+        return eval_metrics
+
+
+class LanguageModelSAETrainingRunner:
+    """
+    Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
+    """
+
+    cfg: LanguageModelSAERunnerConfig[Any]
+    model: HookedRootModule
+    sae: TrainingSAE[Any]
+    activations_store: ActivationsStore
+
+    def __init__(
+        self,
+        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
+        override_dataset: HfDataset | None = None,
+        override_model: HookedRootModule | None = None,
+        override_sae: TrainingSAE[Any] | None = None,
+    ):
+        if override_dataset is not None:
+            logger.warning(
+                f"You just passed in a dataset which will override the one specified in your configuration: {cfg.dataset_path}. As a consequence this run will not be reproducible via configuration alone."
+            )
+        if override_model is not None:
+            logger.warning(
+                f"You just passed in a model which will override the one specified in your configuration: {cfg.model_name}. As a consequence this run will not be reproducible via configuration alone."
+            )
+
+        self.cfg = cfg
+
+        if override_model is None:
+            self.model = load_model(
+                self.cfg.model_class_name,
+                self.cfg.model_name,
+                device=self.cfg.device,
+                model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
+            )
+        else:
+            self.model = override_model
+
+        self.activations_store = ActivationsStore.from_config(
+            self.model,
+            self.cfg,
+            override_dataset=override_dataset,
+        )
+
+        if override_sae is None:
+            if self.cfg.from_pretrained_path is not None:
+                self.sae = TrainingSAE.load_from_disk(
+                    self.cfg.from_pretrained_path, self.cfg.device
+                )
+            else:
+                self.sae = TrainingSAE.from_dict(
+                    TrainingSAEConfig.from_dict(
+                        self.cfg.get_training_sae_cfg_dict(),
+                    ).to_dict()
+                )
+        else:
+            self.sae = override_sae
+        self.sae.to(self.cfg.device)
+
+    def run(self):
+        """
+        Run the training of the SAE.
+        """
+        self._set_sae_metadata()
+        if self.cfg.logger.log_to_wandb:
+            wandb.init(
+                project=self.cfg.logger.wandb_project,
+                entity=self.cfg.logger.wandb_entity,
+                config=self.cfg.to_dict(),
+                name=self.cfg.logger.run_name,
+                id=self.cfg.logger.wandb_id,
+            )
+
+        evaluator = LLMSaeEvaluator(
+            model=self.model,
+            activations_store=self.activations_store,
+            eval_batch_size_prompts=self.cfg.eval_batch_size_prompts,
+            n_eval_batches=self.cfg.n_eval_batches,
+            model_kwargs=self.cfg.model_kwargs,
+        )
+
+        trainer = SAETrainer(
+            sae=self.sae,
+            data_provider=self.activations_store,
+            evaluator=evaluator,
+            save_checkpoint_fn=self.save_checkpoint,
+            cfg=self.cfg.to_sae_trainer_config(),
+        )
+
+        self._compile_if_needed()
+        sae = self.run_trainer_with_interruption_handling(trainer)
+
+        if self.cfg.logger.log_to_wandb:
+            wandb.finish()
+
+        return sae
+
+    def _set_sae_metadata(self):
+        self.sae.cfg.metadata.dataset_path = self.cfg.dataset_path
+        self.sae.cfg.metadata.hook_name = self.cfg.hook_name
+        self.sae.cfg.metadata.model_name = self.cfg.model_name
+        self.sae.cfg.metadata.model_class_name = self.cfg.model_class_name
+        self.sae.cfg.metadata.hook_head_index = self.cfg.hook_head_index
+        self.sae.cfg.metadata.context_size = self.cfg.context_size
+        self.sae.cfg.metadata.seqpos_slice = self.cfg.seqpos_slice
+        self.sae.cfg.metadata.model_from_pretrained_kwargs = (
+            self.cfg.model_from_pretrained_kwargs
+        )
+        self.sae.cfg.metadata.prepend_bos = self.cfg.prepend_bos
+        self.sae.cfg.metadata.exclude_special_tokens = self.cfg.exclude_special_tokens
+
+    def _compile_if_needed(self):
+        # Compile model and SAE
+        # torch.compile can provide significant speedups (10-20% in testing)
+        # using max-autotune gives the best speedups but:
+        # (a) increases VRAM usage,
+        # (b) can't be used on both SAE and LM (some issue with cudagraphs), and
+        # (c) takes some time to compile
+        # optimal settings seem to be:
+        # use max-autotune on SAE and max-autotune-no-cudagraphs on LM
+        # (also pylance seems to really hate this)
+        if self.cfg.compile_llm:
+            self.model = torch.compile(
+                self.model,
+                mode=self.cfg.llm_compilation_mode,
+            )  # type: ignore
+
+        if self.cfg.compile_sae:
+            backend = "aot_eager" if self.cfg.device == "mps" else "inductor"
+
+            self.sae.training_forward_pass = torch.compile(  # type: ignore
+                self.sae.training_forward_pass,
+                mode=self.cfg.sae_compilation_mode,
+                backend=backend,
+            )  # type: ignore
+
+    def run_trainer_with_interruption_handling(
+        self, trainer: SAETrainer[TrainingSAE[TrainingSAEConfig], TrainingSAEConfig]
+    ):
+        try:
+            # signal handlers (if preempted)
+            signal.signal(signal.SIGINT, interrupt_callback)
+            signal.signal(signal.SIGTERM, interrupt_callback)
+
+            # train SAE
+            sae = trainer.fit()
+
+        except (KeyboardInterrupt, InterruptedException):
+            logger.warning("interrupted, saving progress")
+            checkpoint_path = Path(self.cfg.checkpoint_path) / str(
+                trainer.n_training_samples
+            )
+            self.save_checkpoint(checkpoint_path)
+            logger.info("done saving")
+            raise
+
+        return sae
+
+    def save_checkpoint(
+        self,
+        checkpoint_path: Path,
+    ) -> None:
+        self.activations_store.save(
+            str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
+        )
+
+        runner_config = self.cfg.to_dict()
+        with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
+            json.dump(runner_config, f)
+
+
+def _parse_cfg_args(
+    args: Sequence[str],
+) -> LanguageModelSAERunnerConfig[TrainingSAEConfig]:
+    """
+    Parse command line arguments into a LanguageModelSAERunnerConfig.
+
+    This function first parses the architecture argument to determine which
+    concrete SAE config class to use, then parses the full configuration
+    with that concrete type.
+    """
+    if len(args) == 0:
+        args = ["--help"]
+
+    # First, parse only the architecture to determine which concrete class to use
+    architecture_parser = ArgumentParser(
+        description="Parse architecture to determine SAE config class",
+        exit_on_error=False,
+        add_help=False,  # Don't add help to avoid conflicts
+    )
+    architecture_parser.add_argument(
+        "--architecture",
+        type=str,
+        choices=["standard", "gated", "jumprelu", "topk"],
+        default="standard",
+        help="SAE architecture to use",
+    )
+
+    # Parse known args to extract architecture, ignore unknown args for now
+    arch_args, remaining_args = architecture_parser.parse_known_args(args)
+    architecture = arch_args.architecture
+
+    # Remove architecture from remaining args if it exists
+    filtered_args = []
+    skip_next = False
+    for arg in remaining_args:
+        if skip_next:
+            skip_next = False
+            continue
+        if arg == "--architecture":
+            skip_next = True  # Skip the next argument (the architecture value)
+            continue
+        filtered_args.append(arg)
+
+    # Create a custom wrapper class that simple_parsing can handle
+    def create_config_class(
+        sae_config_type: type[TrainingSAEConfig],
+    ) -> type[LanguageModelSAERunnerConfig[TrainingSAEConfig]]:
+        """Create a concrete config class for the given SAE config type."""
+
+        # Create the base config without the sae field
+        from dataclasses import field as dataclass_field
+        from dataclasses import fields, make_dataclass
+
+        # Get all fields from LanguageModelSAERunnerConfig except the generic sae field
+        base_fields = []
+        for field_obj in fields(LanguageModelSAERunnerConfig):
+            if field_obj.name != "sae":
+                base_fields.append((field_obj.name, field_obj.type, field_obj))
+
+        # Add the concrete sae field
+        base_fields.append(
+            (
+                "sae",
+                sae_config_type,
+                dataclass_field(
+                    default_factory=lambda: sae_config_type(d_in=512, d_sae=1024)
+                ),
+            )
+        )
+
+        # Create the concrete class
+        return make_dataclass(
+            f"{sae_config_type.__name__}RunnerConfig",
+            base_fields,
+            bases=(LanguageModelSAERunnerConfig,),
+        )
+
+    # Map architecture to concrete config class
+    sae_config_map = {
+        "standard": StandardTrainingSAEConfig,
+        "gated": GatedTrainingSAEConfig,
+        "jumprelu": JumpReLUTrainingSAEConfig,
+        "topk": TopKTrainingSAEConfig,
+    }
+
+    sae_config_type = sae_config_map[architecture]
+    concrete_config_class = create_config_class(sae_config_type)
+
+    # Now parse the full configuration with the concrete type
+    parser = ArgumentParser(exit_on_error=False)
+    parser.add_arguments(concrete_config_class, dest="cfg")
+
+    # Parse the filtered arguments (without --architecture)
+    parsed_args = parser.parse_args(filtered_args)
+
+    # Return the parsed configuration
+    return parsed_args.cfg
+
+
+# moved into its own function to make it easier to test
+def _run_cli(args: Sequence[str]):
+    cfg = _parse_cfg_args(args)
+    LanguageModelSAETrainingRunner(cfg=cfg).run()
+
+
+if __name__ == "__main__":
+    _run_cli(args=sys.argv[1:])
+
+
+@deprecated("Use LanguageModelSAETrainingRunner instead")
+class SAETrainingRunner(LanguageModelSAETrainingRunner):
+    pass
sae_lens/load_model.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, Literal, cast
+from typing import Any, Callable, Literal, cast
 
 import torch
 from transformer_lens import HookedTransformer
@@ -77,6 +77,7 @@ class HookedProxyLM(HookedRootModule):
     # copied and modified from base HookedRootModule
     def setup(self):
         self.mod_dict = {}
+        self.named_modules_dict = {}
         self.hook_dict: dict[str, HookPoint] = {}
         for name, module in self.model.named_modules():
             if name == "":
@@ -89,14 +90,21 @@ class HookedProxyLM(HookedRootModule):
 
             self.hook_dict[name] = hook_point
             self.mod_dict[name] = hook_point
+            self.named_modules_dict[name] = module
+
+    def run_with_cache(self, *args: Any, **kwargs: Any):  # type: ignore
+        if "names_filter" in kwargs:
+            # hacky way to make sure that the names_filter is passed to our forward method
+            kwargs["_names_filter"] = kwargs["names_filter"]
+        return super().run_with_cache(*args, **kwargs)
 
     def forward(
         self,
         tokens: torch.Tensor,
         return_type: Literal["both", "logits"] = "logits",
         loss_per_token: bool = False,
-        # TODO: implement real support for stop_at_layer
         stop_at_layer: int | None = None,
+        _names_filter: list[str] | None = None,
        **kwargs: Any,
    ) -> Output | Loss:
        # This is just what's needed for evals, not everything that HookedTransformer has
@@ -107,8 +115,28 @@ class HookedProxyLM(HookedRootModule):
             raise NotImplementedError(
                 "Only return_type supported is 'both' or 'logits' to match what's in evals.py and ActivationsStore"
             )
-
-
+
+        stop_hooks = []
+        if stop_at_layer is not None and _names_filter is not None:
+            if return_type != "logits":
+                raise NotImplementedError(
+                    "stop_at_layer is not supported for return_type='both'"
+                )
+            stop_manager = StopManager(_names_filter)
+
+            for hook_name in _names_filter:
+                module = self.named_modules_dict[hook_name]
+                stop_fn = stop_manager.get_stop_hook_fn(hook_name)
+                stop_hooks.append(module.register_forward_hook(stop_fn))
+        try:
+            output = self.model(tokens)
+            logits = _extract_logits_from_output(output)
+        except StopForward:
+            # If we stop early, we don't care about the return output
+            return None  # type: ignore
+        finally:
+            for stop_hook in stop_hooks:
+                stop_hook.remove()
 
         if return_type == "logits":
             return logits
@@ -159,7 +187,7 @@
 
         # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
         if hasattr(self.tokenizer, "add_bos_token") and self.tokenizer.add_bos_token:  # type: ignore
-            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)
+            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)  # type: ignore
         return tokens  # type: ignore
 
 
@@ -183,3 +211,23 @@ def get_hook_fn(hook_point: HookPoint):
         return output
 
     return hook_fn
+
+
+class StopForward(Exception):
+    pass
+
+
+class StopManager:
+    def __init__(self, hook_names: list[str]):
+        self.hook_names = hook_names
+        self.total_hook_names = len(set(hook_names))
+        self.called_hook_names = set()
+
+    def get_stop_hook_fn(self, hook_name: str) -> Callable[[Any, Any, Any], Any]:
+        def stop_hook_fn(module: Any, input: Any, output: Any) -> Any:  # noqa: ARG001
+            self.called_hook_names.add(hook_name)
+            if len(self.called_hook_names) == self.total_hook_names:
+                raise StopForward()
+            return output
+
+        return stop_hook_fn
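The StopForward/StopManager additions above are a standard PyTorch early-exit trick: forward hooks are registered on just the modules whose activations are needed, and once every one of them has fired, an exception aborts the rest of the forward pass. A self-contained toy version of the same pattern (assumed example model and names, not sae_lens code):

import torch
import torch.nn as nn


class StopForward(Exception):
    pass


model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4))
captured = {}


def stop_hook(module, inputs, output):
    # Grab the activation we care about, then abort the remaining layers.
    captured["layer0"] = output
    raise StopForward()


handle = model[0].register_forward_hook(stop_hook)
try:
    model(torch.randn(2, 4))
except StopForward:
    pass  # expected: we stopped the forward pass on purpose
finally:
    handle.remove()

print(captured["layer0"].shape)  # torch.Size([2, 4])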
sae_lens/loading/pretrained_sae_loaders.py
CHANGED
@@ -26,6 +26,22 @@ from sae_lens.loading.pretrained_saes_directory import (
 from sae_lens.registry import get_sae_class
 from sae_lens.util import filter_valid_dataclass_fields
 
+LLM_METADATA_KEYS = {
+    "model_name",
+    "hook_name",
+    "model_class_name",
+    "hook_head_index",
+    "model_from_pretrained_kwargs",
+    "prepend_bos",
+    "exclude_special_tokens",
+    "neuronpedia_id",
+    "context_size",
+    "seqpos_slice",
+    "dataset_path",
+    "sae_lens_version",
+    "sae_lens_training_version",
+}
+
 
 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
 class PretrainedSaeHuggingfaceLoader(Protocol):
@@ -193,7 +209,6 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
 
     rename_keys_map = {
         "hook_point": "hook_name",
-        "hook_point_layer": "hook_layer",
         "hook_point_head_index": "hook_head_index",
         "activation_fn_str": "activation_fn",
     }
@@ -208,6 +223,10 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
     new_cfg.setdefault("activation_fn", new_cfg.get("activation_fn", "relu"))
     new_cfg.setdefault("architecture", "standard")
     new_cfg.setdefault("neuronpedia_id", None)
+    new_cfg.setdefault(
+        "reshape_activations",
+        "hook_z" if "hook_z" in new_cfg.get("hook_name", "") else "none",
+    )
 
     if "normalize_activations" in new_cfg and isinstance(
         new_cfg["normalize_activations"], bool
@@ -232,11 +251,9 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
     if architecture == "topk":
         sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
 
-
-
-
-    meta_dict = filter_valid_dataclass_fields(new_cfg, SAEMetadata)
-    sae_cfg_dict["metadata"] = meta_dict
+    sae_cfg_dict["metadata"] = {
+        k: v for k, v in new_cfg.items() if k in LLM_METADATA_KEYS
+    }
     sae_cfg_dict["architecture"] = architecture
     return sae_cfg_dict
 
@@ -262,7 +279,6 @@ def get_connor_rob_hook_z_config_from_hf(
         "device": device if device is not None else "cpu",
         "model_name": "gpt2-small",
         "hook_name": old_cfg_dict["act_name"],
-        "hook_layer": old_cfg_dict["layer"],
         "hook_head_index": None,
         "activation_fn": "relu",
         "apply_b_dec_to_input": True,
@@ -273,6 +289,7 @@ def get_connor_rob_hook_z_config_from_hf(
         "context_size": 128,
         "normalize_activations": "none",
         "dataset_trust_remote_code": True,
+        "reshape_activations": "hook_z",
         **(cfg_overrides or {}),
     }
 
@@ -411,7 +428,6 @@ def get_gemma_2_config_from_hf(
         "dtype": "float32",
         "model_name": model_name,
         "hook_name": hook_name,
-        "hook_layer": layer,
         "hook_head_index": None,
         "activation_fn": "relu",
         "finetuning_scaling_factor": False,
@@ -524,7 +540,6 @@ def get_llama_scope_config_from_hf(
         "dtype": "bfloat16",
         "model_name": model_name,
         "hook_name": old_cfg_dict["hook_point_in"],
-        "hook_layer": int(old_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
         "activation_fn": "relu",
         "finetuning_scaling_factor": False,
@@ -651,7 +666,6 @@ def get_dictionary_learning_config_1_from_hf(
         "device": device,
         "model_name": trainer["lm_name"].split("/")[-1],
         "hook_name": hook_point_name,
-        "hook_layer": trainer["layer"],
         "hook_head_index": None,
         "activation_fn": activation_fn,
         "activation_fn_kwargs": activation_fn_kwargs,
@@ -690,7 +704,6 @@ def get_deepseek_r1_config_from_hf(
         "context_size": 1024,
         "model_name": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         "hook_name": f"blocks.{layer}.hook_resid_post",
-        "hook_layer": layer,
         "hook_head_index": None,
         "prepend_bos": True,
         "dataset_path": "lmsys/lmsys-chat-1m",
@@ -849,7 +862,6 @@ def get_llama_scope_r1_distill_config_from_hf(
         "device": device,
         "model_name": model_name,
         "hook_name": huggingface_cfg_dict["hook_point_in"],
-        "hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
         "activation_fn": "relu",
         "finetuning_scaling_factor": False,
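In practical terms, handle_pre_6_0_config now builds the metadata sub-dict with a plain key filter instead of filter_valid_dataclass_fields. A toy illustration with made-up config values (and a trimmed key set for brevity):

# Hypothetical values; only the filtering pattern matches the diff above.
LLM_METADATA_KEYS = {"model_name", "hook_name", "context_size"}  # trimmed for the example

old_cfg = {
    "model_name": "gpt2",
    "hook_name": "blocks.0.hook_resid_post",
    "context_size": 128,
    "d_in": 768,  # non-metadata keys stay in the SAE config itself
}
sae_cfg_dict = {"d_in": old_cfg["d_in"]}
sae_cfg_dict["metadata"] = {k: v for k, v in old_cfg.items() if k in LLM_METADATA_KEYS}
print(sae_cfg_dict["metadata"])
# {'model_name': 'gpt2', 'hook_name': 'blocks.0.hook_resid_post', 'context_size': 128}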
sae_lens/saes/gated_sae.py
CHANGED
@@ -168,10 +168,6 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
 
         # Magnitude path
         magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
-        if self.training and self.cfg.noise_scale > 0:
-            magnitude_pre_activation += (
-                torch.randn_like(magnitude_pre_activation) * self.cfg.noise_scale
-            )
         magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)
 
         feature_magnitudes = self.activation_fn(magnitude_pre_activation)
sae_lens/saes/jumprelu_sae.py
CHANGED
@@ -105,7 +105,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
     JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
     using a JumpReLU activation. For each unit, if its pre-activation is
     <= threshold, that unit is zeroed out; otherwise, it follows a user-specified
-    activation function (e.g., ReLU
+    activation function (e.g., ReLU etc.).
 
     It implements:
     - initialize_weights: sets up parameters, including a threshold.
@@ -142,7 +142,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
 
-        # 1) Apply the base "activation_fn" from config (e.g., ReLU
+        # 1) Apply the base "activation_fn" from config (e.g., ReLU).
         base_acts = self.activation_fn(hidden_pre)
 
         # 2) Zero out any unit whose (hidden_pre <= threshold).
@@ -191,8 +191,8 @@ class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
     Configuration class for training a JumpReLUTrainingSAE.
     """
 
-    jumprelu_init_threshold: float = 0.
-    jumprelu_bandwidth: float = 0.
+    jumprelu_init_threshold: float = 0.01
+    jumprelu_bandwidth: float = 0.05
     l0_coefficient: float = 1.0
     l0_warm_up_steps: int = 0
 
@@ -257,12 +257,6 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         sae_in = self.process_sae_in(x)
 
         hidden_pre = sae_in @ self.W_enc + self.b_enc
-
-        if self.training and self.cfg.noise_scale > 0:
-            hidden_pre = (
-                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
-            )
-
         feature_acts = JumpReLU.apply(hidden_pre, self.threshold, self.bandwidth)
 
         return feature_acts, hidden_pre  # type: ignore
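As the updated docstring above describes, JumpReLU inference keeps a unit's base activation only when its pre-activation exceeds the threshold. A standalone sketch of that behaviour with made-up values (the real training implementation uses JumpReLU.apply with a bandwidth for the gradient estimator):

import torch

hidden_pre = torch.tensor([-0.5, 0.005, 0.02, 1.0])
threshold = torch.tensor(0.01)  # cf. the new jumprelu_init_threshold default

base_acts = torch.relu(hidden_pre)                   # base activation_fn (e.g. ReLU)
feature_acts = base_acts * (hidden_pre > threshold)  # zero out units at or below threshold
print(feature_acts)  # tensor([0.0000, 0.0000, 0.0200, 1.0000])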
|