sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +60 -7
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +16 -14
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +170 -258
- sae_lens/constants.py +21 -0
- sae_lens/evals.py +59 -44
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +52 -4
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +254 -0
- sae_lens/saes/jumprelu_sae.py +348 -0
- sae_lens/saes/sae.py +1076 -0
- sae_lens/saes/standard_sae.py +178 -0
- sae_lens/saes/topk_sae.py +300 -0
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +103 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +155 -177
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +13 -7
- sae_lens/util.py +47 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
- sae_lens-6.0.0.dist-info/RECORD +37 -0
- sae_lens/sae.py +0 -747
- sae_lens/sae_training_runner.py +0 -251
- sae_lens/training/geometric_median.py +0 -101
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.11.0.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
sae_lens/evals.py
CHANGED
|
@@ -4,6 +4,7 @@ import json
|
|
|
4
4
|
import math
|
|
5
5
|
import re
|
|
6
6
|
import subprocess
|
|
7
|
+
import sys
|
|
7
8
|
from collections import defaultdict
|
|
8
9
|
from collections.abc import Mapping
|
|
9
10
|
from dataclasses import dataclass, field
|
|
@@ -15,13 +16,15 @@ from typing import Any
|
|
|
15
16
|
import einops
|
|
16
17
|
import pandas as pd
|
|
17
18
|
import torch
|
|
18
|
-
from tqdm import tqdm
|
|
19
|
+
from tqdm.auto import tqdm
|
|
19
20
|
from transformer_lens import HookedTransformer
|
|
20
21
|
from transformer_lens.hook_points import HookedRootModule
|
|
21
22
|
|
|
22
|
-
from sae_lens.
|
|
23
|
-
from sae_lens.
|
|
23
|
+
from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
|
|
24
|
+
from sae_lens.saes.sae import SAE, SAEConfig
|
|
25
|
+
from sae_lens.training.activation_scaler import ActivationScaler
|
|
24
26
|
from sae_lens.training.activations_store import ActivationsStore
|
|
27
|
+
from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
|
|
25
28
|
|
|
26
29
|
|
|
27
30
|
def get_library_version() -> str:
|
|
@@ -100,15 +103,16 @@ def get_eval_everything_config(
|
|
|
100
103
|
|
|
101
104
|
@torch.no_grad()
|
|
102
105
|
def run_evals(
|
|
103
|
-
sae: SAE,
|
|
106
|
+
sae: SAE[Any],
|
|
104
107
|
activation_store: ActivationsStore,
|
|
105
108
|
model: HookedRootModule,
|
|
109
|
+
activation_scaler: ActivationScaler,
|
|
106
110
|
eval_config: EvalConfig = EvalConfig(),
|
|
107
111
|
model_kwargs: Mapping[str, Any] = {},
|
|
108
112
|
ignore_tokens: set[int | None] = set(),
|
|
109
113
|
verbose: bool = False,
|
|
110
114
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
111
|
-
hook_name = sae.cfg.hook_name
|
|
115
|
+
hook_name = sae.cfg.metadata.hook_name
|
|
112
116
|
actual_batch_size = (
|
|
113
117
|
eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
|
|
114
118
|
)
|
|
@@ -140,6 +144,7 @@ def run_evals(
|
|
|
140
144
|
sae,
|
|
141
145
|
model,
|
|
142
146
|
activation_store,
|
|
147
|
+
activation_scaler,
|
|
143
148
|
compute_kl=eval_config.compute_kl,
|
|
144
149
|
compute_ce_loss=eval_config.compute_ce_loss,
|
|
145
150
|
n_batches=eval_config.n_eval_reconstruction_batches,
|
|
@@ -189,6 +194,7 @@ def run_evals(
|
|
|
189
194
|
sae,
|
|
190
195
|
model,
|
|
191
196
|
activation_store,
|
|
197
|
+
activation_scaler,
|
|
192
198
|
compute_l2_norms=eval_config.compute_l2_norms,
|
|
193
199
|
compute_sparsity_metrics=eval_config.compute_sparsity_metrics,
|
|
194
200
|
compute_variance_metrics=eval_config.compute_variance_metrics,
|
|
@@ -274,12 +280,11 @@ def run_evals(
|
|
|
274
280
|
return all_metrics, feature_metrics
|
|
275
281
|
|
|
276
282
|
|
|
277
|
-
def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
|
|
283
|
+
def get_featurewise_weight_based_metrics(sae: SAE[Any]) -> dict[str, Any]:
|
|
278
284
|
unit_norm_encoders = (sae.W_enc / sae.W_enc.norm(dim=0, keepdim=True)).cpu()
|
|
279
285
|
unit_norm_decoder = (sae.W_dec.T / sae.W_dec.T.norm(dim=0, keepdim=True)).cpu()
|
|
280
286
|
|
|
281
287
|
encoder_norms = sae.W_enc.norm(dim=-2).cpu().tolist()
|
|
282
|
-
encoder_bias = sae.b_enc.cpu().tolist()
|
|
283
288
|
encoder_decoder_cosine_sim = (
|
|
284
289
|
torch.nn.functional.cosine_similarity(
|
|
285
290
|
unit_norm_decoder.T,
|
|
@@ -289,17 +294,20 @@ def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
|
|
|
289
294
|
.tolist()
|
|
290
295
|
)
|
|
291
296
|
|
|
292
|
-
|
|
293
|
-
"encoder_bias": encoder_bias,
|
|
297
|
+
metrics = {
|
|
294
298
|
"encoder_norm": encoder_norms,
|
|
295
299
|
"encoder_decoder_cosine_sim": encoder_decoder_cosine_sim,
|
|
296
300
|
}
|
|
301
|
+
if hasattr(sae, "b_enc") and sae.b_enc is not None:
|
|
302
|
+
metrics["encoder_bias"] = sae.b_enc.cpu().tolist() # type: ignore
|
|
303
|
+
return metrics
|
|
297
304
|
|
|
298
305
|
|
|
299
306
|
def get_downstream_reconstruction_metrics(
|
|
300
|
-
sae: SAE,
|
|
307
|
+
sae: SAE[Any],
|
|
301
308
|
model: HookedRootModule,
|
|
302
309
|
activation_store: ActivationsStore,
|
|
310
|
+
activation_scaler: ActivationScaler,
|
|
303
311
|
compute_kl: bool,
|
|
304
312
|
compute_ce_loss: bool,
|
|
305
313
|
n_batches: int,
|
|
@@ -325,8 +333,8 @@ def get_downstream_reconstruction_metrics(
|
|
|
325
333
|
for metric_name, metric_value in get_recons_loss(
|
|
326
334
|
sae,
|
|
327
335
|
model,
|
|
336
|
+
activation_scaler,
|
|
328
337
|
batch_tokens,
|
|
329
|
-
activation_store,
|
|
330
338
|
compute_kl=compute_kl,
|
|
331
339
|
compute_ce_loss=compute_ce_loss,
|
|
332
340
|
ignore_tokens=ignore_tokens,
|
|
@@ -365,9 +373,10 @@ def get_downstream_reconstruction_metrics(
|
|
|
365
373
|
|
|
366
374
|
|
|
367
375
|
def get_sparsity_and_variance_metrics(
|
|
368
|
-
sae: SAE,
|
|
376
|
+
sae: SAE[Any],
|
|
369
377
|
model: HookedRootModule,
|
|
370
378
|
activation_store: ActivationsStore,
|
|
379
|
+
activation_scaler: ActivationScaler,
|
|
371
380
|
n_batches: int,
|
|
372
381
|
compute_l2_norms: bool,
|
|
373
382
|
compute_sparsity_metrics: bool,
|
|
@@ -378,8 +387,8 @@ def get_sparsity_and_variance_metrics(
|
|
|
378
387
|
ignore_tokens: set[int | None] = set(),
|
|
379
388
|
verbose: bool = False,
|
|
380
389
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
381
|
-
hook_name = sae.cfg.hook_name
|
|
382
|
-
hook_head_index = sae.cfg.hook_head_index
|
|
390
|
+
hook_name = sae.cfg.metadata.hook_name
|
|
391
|
+
hook_head_index = sae.cfg.metadata.hook_head_index
|
|
383
392
|
|
|
384
393
|
metric_dict = {}
|
|
385
394
|
feature_metric_dict = {}
|
|
@@ -435,7 +444,7 @@ def get_sparsity_and_variance_metrics(
|
|
|
435
444
|
batch_tokens,
|
|
436
445
|
prepend_bos=False,
|
|
437
446
|
names_filter=[hook_name],
|
|
438
|
-
stop_at_layer=
|
|
447
|
+
stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(hook_name),
|
|
439
448
|
**model_kwargs,
|
|
440
449
|
)
|
|
441
450
|
|
|
@@ -450,16 +459,14 @@ def get_sparsity_and_variance_metrics(
|
|
|
450
459
|
original_act = cache[hook_name]
|
|
451
460
|
|
|
452
461
|
# normalise if necessary (necessary in training only, otherwise we should fold the scaling in)
|
|
453
|
-
|
|
454
|
-
original_act = activation_store.apply_norm_scaling_factor(original_act)
|
|
462
|
+
original_act = activation_scaler.scale(original_act)
|
|
455
463
|
|
|
456
464
|
# send the (maybe normalised) activations into the SAE
|
|
457
465
|
sae_feature_activations = sae.encode(original_act.to(sae.device))
|
|
458
466
|
sae_out = sae.decode(sae_feature_activations).to(original_act.device)
|
|
459
467
|
del cache
|
|
460
468
|
|
|
461
|
-
|
|
462
|
-
sae_out = activation_store.unscale(sae_out)
|
|
469
|
+
sae_out = activation_scaler.unscale(sae_out)
|
|
463
470
|
|
|
464
471
|
flattened_sae_input = einops.rearrange(original_act, "b ctx d -> (b ctx) d")
|
|
465
472
|
flattened_sae_feature_acts = einops.rearrange(
|
|
@@ -579,17 +586,21 @@ def get_sparsity_and_variance_metrics(
|
|
|
579
586
|
|
|
580
587
|
@torch.no_grad()
|
|
581
588
|
def get_recons_loss(
|
|
582
|
-
sae: SAE,
|
|
589
|
+
sae: SAE[SAEConfig],
|
|
583
590
|
model: HookedRootModule,
|
|
591
|
+
activation_scaler: ActivationScaler,
|
|
584
592
|
batch_tokens: torch.Tensor,
|
|
585
|
-
activation_store: ActivationsStore,
|
|
586
593
|
compute_kl: bool,
|
|
587
594
|
compute_ce_loss: bool,
|
|
588
595
|
ignore_tokens: set[int | None] = set(),
|
|
589
596
|
model_kwargs: Mapping[str, Any] = {},
|
|
597
|
+
hook_name: str | None = None,
|
|
590
598
|
) -> dict[str, Any]:
|
|
591
|
-
hook_name = sae.cfg.hook_name
|
|
592
|
-
head_index = sae.cfg.hook_head_index
|
|
599
|
+
hook_name = hook_name or sae.cfg.metadata.hook_name
|
|
600
|
+
head_index = sae.cfg.metadata.hook_head_index
|
|
601
|
+
|
|
602
|
+
if hook_name is None:
|
|
603
|
+
raise ValueError("hook_name must be provided")
|
|
593
604
|
|
|
594
605
|
original_logits, original_ce_loss = model(
|
|
595
606
|
batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
|
|
@@ -613,15 +624,13 @@ def get_recons_loss(
|
|
|
613
624
|
activations = activations.to(sae.device)
|
|
614
625
|
|
|
615
626
|
# Handle rescaling if SAE expects it
|
|
616
|
-
|
|
617
|
-
activations = activation_store.apply_norm_scaling_factor(activations)
|
|
627
|
+
activations = activation_scaler.scale(activations)
|
|
618
628
|
|
|
619
629
|
# SAE class agnost forward forward pass.
|
|
620
630
|
new_activations = sae.decode(sae.encode(activations)).to(activations.dtype)
|
|
621
631
|
|
|
622
632
|
# Unscale if activations were scaled prior to going into the SAE
|
|
623
|
-
|
|
624
|
-
new_activations = activation_store.unscale(new_activations)
|
|
633
|
+
new_activations = activation_scaler.unscale(new_activations)
|
|
625
634
|
|
|
626
635
|
new_activations = torch.where(mask[..., None], new_activations, activations)
|
|
627
636
|
|
|
@@ -632,8 +641,7 @@ def get_recons_loss(
|
|
|
632
641
|
activations = activations.to(sae.device)
|
|
633
642
|
|
|
634
643
|
# Handle rescaling if SAE expects it
|
|
635
|
-
|
|
636
|
-
activations = activation_store.apply_norm_scaling_factor(activations)
|
|
644
|
+
activations = activation_scaler.scale(activations)
|
|
637
645
|
|
|
638
646
|
# SAE class agnost forward forward pass.
|
|
639
647
|
new_activations = sae.decode(sae.encode(activations.flatten(-2, -1))).to(
|
|
@@ -645,8 +653,7 @@ def get_recons_loss(
|
|
|
645
653
|
) # reshape to match original shape
|
|
646
654
|
|
|
647
655
|
# Unscale if activations were scaled prior to going into the SAE
|
|
648
|
-
|
|
649
|
-
new_activations = activation_store.unscale(new_activations)
|
|
656
|
+
new_activations = activation_scaler.unscale(new_activations)
|
|
650
657
|
|
|
651
658
|
return new_activations.to(original_device)
|
|
652
659
|
|
|
@@ -655,8 +662,7 @@ def get_recons_loss(
|
|
|
655
662
|
activations = activations.to(sae.device)
|
|
656
663
|
|
|
657
664
|
# Handle rescaling if SAE expects it
|
|
658
|
-
|
|
659
|
-
activations = activation_store.apply_norm_scaling_factor(activations)
|
|
665
|
+
activations = activation_scaler.scale(activations)
|
|
660
666
|
|
|
661
667
|
new_activations = sae.decode(sae.encode(activations[:, :, head_index])).to(
|
|
662
668
|
activations.dtype
|
|
@@ -664,8 +670,7 @@ def get_recons_loss(
|
|
|
664
670
|
activations[:, :, head_index] = new_activations
|
|
665
671
|
|
|
666
672
|
# Unscale if activations were scaled prior to going into the SAE
|
|
667
|
-
|
|
668
|
-
activations = activation_store.unscale(activations)
|
|
673
|
+
activations = activation_scaler.unscale(activations)
|
|
669
674
|
|
|
670
675
|
return activations.to(original_device)
|
|
671
676
|
|
|
@@ -794,22 +799,23 @@ def multiple_evals(
|
|
|
794
799
|
|
|
795
800
|
current_model = None
|
|
796
801
|
current_model_str = None
|
|
797
|
-
print(filtered_saes)
|
|
798
802
|
for sae_release_name, sae_id, _, _ in tqdm(filtered_saes):
|
|
799
803
|
sae = SAE.from_pretrained(
|
|
800
804
|
release=sae_release_name, # see other options in sae_lens/pretrained_saes.yaml
|
|
801
805
|
sae_id=sae_id, # won't always be a hook point
|
|
802
806
|
device=device,
|
|
803
|
-
)
|
|
807
|
+
)
|
|
804
808
|
|
|
805
809
|
# move SAE to device if not there already
|
|
806
810
|
sae.to(device)
|
|
807
811
|
|
|
808
|
-
if current_model_str != sae.cfg.model_name:
|
|
812
|
+
if current_model_str != sae.cfg.metadata.model_name:
|
|
809
813
|
del current_model # potentially saves GPU memory
|
|
810
|
-
current_model_str = sae.cfg.model_name
|
|
814
|
+
current_model_str = sae.cfg.metadata.model_name
|
|
811
815
|
current_model = HookedTransformer.from_pretrained_no_processing(
|
|
812
|
-
current_model_str,
|
|
816
|
+
current_model_str,
|
|
817
|
+
device=device,
|
|
818
|
+
**sae.cfg.metadata.model_from_pretrained_kwargs,
|
|
813
819
|
)
|
|
814
820
|
assert current_model is not None
|
|
815
821
|
|
|
@@ -834,6 +840,7 @@ def multiple_evals(
|
|
|
834
840
|
scalar_metrics, feature_metrics = run_evals(
|
|
835
841
|
sae=sae,
|
|
836
842
|
activation_store=activation_store,
|
|
843
|
+
activation_scaler=ActivationScaler(),
|
|
837
844
|
model=current_model,
|
|
838
845
|
eval_config=eval_config,
|
|
839
846
|
ignore_tokens={
|
|
@@ -926,7 +933,7 @@ def process_results(
|
|
|
926
933
|
}
|
|
927
934
|
|
|
928
935
|
|
|
929
|
-
|
|
936
|
+
def process_args(args: list[str]) -> argparse.Namespace:
|
|
930
937
|
arg_parser = argparse.ArgumentParser(description="Run evaluations on SAEs")
|
|
931
938
|
arg_parser.add_argument(
|
|
932
939
|
"sae_regex_pattern",
|
|
@@ -1016,11 +1023,19 @@ if __name__ == "__main__":
|
|
|
1016
1023
|
help="Enable verbose output with tqdm loaders.",
|
|
1017
1024
|
)
|
|
1018
1025
|
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1026
|
+
return arg_parser.parse_args(args)
|
|
1027
|
+
|
|
1028
|
+
|
|
1029
|
+
def run_evals_cli(args: list[str]) -> None:
    """Run the evaluation command line flow end to end.

    Parses ``args`` with ``process_args``, runs the evaluations, hands the
    results to ``process_results`` together with the requested output
    directory, and prints a short summary of the files that were produced.

    Args:
        args: Command line arguments, without the program name.
    """
    parsed = process_args(args)
    output_files = process_results(run_evaluations(parsed), parsed.output_dir)

    print("Evaluation complete. Output files:")
    print(f"Individual JSONs: {len(output_files['individual_jsons'])}")  # type: ignore
    print(f"Combined JSON: {output_files['combined_json']}")
    print(f"CSV: {output_files['csv']}")
|
|
1038
|
+
|
|
1039
|
+
|
|
1040
|
+
if __name__ == "__main__":
|
|
1041
|
+
run_evals_cli(sys.argv[1:])
|
|
@@ -0,0 +1,377 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import signal
|
|
3
|
+
import sys
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Generic
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import wandb
|
|
11
|
+
from simple_parsing import ArgumentParser
|
|
12
|
+
from transformer_lens.hook_points import HookedRootModule
|
|
13
|
+
from typing_extensions import deprecated
|
|
14
|
+
|
|
15
|
+
from sae_lens import logger
|
|
16
|
+
from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
|
|
17
|
+
from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
|
|
18
|
+
from sae_lens.evals import EvalConfig, run_evals
|
|
19
|
+
from sae_lens.load_model import load_model
|
|
20
|
+
from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
|
|
21
|
+
from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
|
|
22
|
+
from sae_lens.saes.sae import (
|
|
23
|
+
T_TRAINING_SAE,
|
|
24
|
+
T_TRAINING_SAE_CONFIG,
|
|
25
|
+
TrainingSAE,
|
|
26
|
+
TrainingSAEConfig,
|
|
27
|
+
)
|
|
28
|
+
from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
|
|
29
|
+
from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
|
|
30
|
+
from sae_lens.training.activation_scaler import ActivationScaler
|
|
31
|
+
from sae_lens.training.activations_store import ActivationsStore
|
|
32
|
+
from sae_lens.training.sae_trainer import SAETrainer
|
|
33
|
+
from sae_lens.training.types import DataProvider
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class InterruptedException(Exception):
    """Signals that a run was interrupted (raised by ``interrupt_callback``)."""
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def interrupt_callback(sig_num: Any, stack_frame: Any):  # noqa: ARG001
    """Signal handler that turns a delivered signal into an exception.

    Both arguments are accepted only to satisfy the signal-handler calling
    convention; neither is inspected.

    Raises:
        InterruptedException: always.
    """
    raise InterruptedException
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
    """Callable evaluator that scores an SAE against a language model.

    Calling an instance delegates to ``run_evals`` using the stored model and
    activations store, then strips the metrics that should not be reported
    from the result.
    """

    # Model the SAE is evaluated on.
    model: HookedRootModule
    # Store handed to ``run_evals``; also the source of the ignored tokens.
    activations_store: ActivationsStore
    # Forwarded as ``EvalConfig.batch_size_prompts``.
    eval_batch_size_prompts: int | None
    # Used for both the reconstruction and the sparsity/variance batch counts.
    n_eval_batches: int
    # Extra keyword arguments forwarded to ``run_evals`` as ``model_kwargs``.
    model_kwargs: dict[str, Any]

    def __call__(
        self,
        sae: T_TRAINING_SAE,
        data_provider: DataProvider,
        activation_scaler: ActivationScaler,
    ) -> dict[str, Any]:
        """Evaluate ``sae`` and return the scalar metrics.

        Args:
            sae: The SAE to evaluate.
            data_provider: Not used by this implementation.
            activation_scaler: Scaler passed through to ``run_evals``.

        Returns:
            The scalar metrics from ``run_evals`` with the entries listed
            below removed.
        """
        special_tokens = self.activations_store.exclude_special_tokens
        ignore_tokens = (
            set() if special_tokens is None else set(special_tokens.tolist())
        )

        eval_config = EvalConfig(
            batch_size_prompts=self.eval_batch_size_prompts,
            n_eval_reconstruction_batches=self.n_eval_batches,
            n_eval_sparsity_variance_batches=self.n_eval_batches,
            compute_ce_loss=True,
            compute_l2_norms=True,
            compute_sparsity_metrics=True,
            compute_variance_metrics=True,
        )

        # The second return value (feature-wise metrics) is discarded.
        scalar_metrics, _ = run_evals(
            sae=sae,
            activation_store=self.activations_store,
            model=self.model,
            activation_scaler=activation_scaler,
            eval_config=eval_config,
            ignore_tokens=ignore_tokens,
            model_kwargs=self.model_kwargs,
        )

        dropped_keys = (
            # Already logged during training.
            "metrics/explained_variance",
            "metrics/explained_variance_std",
            "metrics/l0",
            "metrics/l1",
            "metrics/mse",
            # Not useful for wandb logging.
            "metrics/total_tokens_evaluated",
        )
        for key in dropped_keys:
            scalar_metrics.pop(key, None)

        return scalar_metrics
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class LanguageModelSAETrainingRunner:
    """
    Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.

    Construction loads (or accepts) the model, builds the activations store
    from the config and creates (or accepts) the SAE. ``run`` then trains the
    SAE and returns it.
    """

    cfg: LanguageModelSAERunnerConfig[Any]
    model: HookedRootModule
    sae: TrainingSAE[Any]
    activations_store: ActivationsStore

    def __init__(
        self,
        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
        override_dataset: HfDataset | None = None,
        override_model: HookedRootModule | None = None,
        override_sae: TrainingSAE[Any] | None = None,
    ):
        """
        Set up the model, activations store and SAE for a training run.

        Args:
            cfg: Runner configuration.
            override_dataset: Dataset to use instead of the one named in
                ``cfg``. A warning is logged when given.
            override_model: Model to use instead of loading the one named in
                ``cfg``. A warning is logged when given.
            override_sae: SAE to train instead of loading one from
                ``cfg.from_pretrained_path`` or building one from ``cfg``.
        """
        if override_dataset is not None:
            logger.warning(
                f"You just passed in a dataset which will override the one specified in your configuration: {cfg.dataset_path}. As a consequence this run will not be reproducible via configuration alone."
            )
        if override_model is not None:
            logger.warning(
                f"You just passed in a model which will override the one specified in your configuration: {cfg.model_name}. As a consequence this run will not be reproducible via configuration alone."
            )

        self.cfg = cfg

        if override_model is None:
            self.model = load_model(
                self.cfg.model_class_name,
                self.cfg.model_name,
                device=self.cfg.device,
                model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
            )
        else:
            self.model = override_model

        self.activations_store = ActivationsStore.from_config(
            self.model,
            self.cfg,
            override_dataset=override_dataset,
        )

        if override_sae is None:
            if self.cfg.from_pretrained_path is not None:
                self.sae = TrainingSAE.load_from_disk(
                    self.cfg.from_pretrained_path, self.cfg.device
                )
            else:
                self.sae = TrainingSAE.from_dict(
                    TrainingSAEConfig.from_dict(
                        self.cfg.get_training_sae_cfg_dict(),
                    ).to_dict()
                )
        else:
            self.sae = override_sae
        self.sae.to(self.cfg.device)

    def run(self):
        """
        Run the training of the SAE.

        Copies run metadata onto the SAE config, optionally starts a wandb
        run, builds the evaluator and trainer, optionally compiles the model
        and SAE, and trains.

        Returns:
            The SAE returned by the trainer.
        """
        self._set_sae_metadata()
        if self.cfg.logger.log_to_wandb:
            wandb.init(
                project=self.cfg.logger.wandb_project,
                entity=self.cfg.logger.wandb_entity,
                config=self.cfg.to_dict(),
                name=self.cfg.logger.run_name,
                id=self.cfg.logger.wandb_id,
            )

        evaluator = LLMSaeEvaluator(
            model=self.model,
            activations_store=self.activations_store,
            eval_batch_size_prompts=self.cfg.eval_batch_size_prompts,
            n_eval_batches=self.cfg.n_eval_batches,
            model_kwargs=self.cfg.model_kwargs,
        )

        trainer = SAETrainer(
            sae=self.sae,
            data_provider=self.activations_store,
            evaluator=evaluator,
            save_checkpoint_fn=self.save_checkpoint,
            cfg=self.cfg.to_sae_trainer_config(),
        )

        self._compile_if_needed()
        sae = self.run_trainer_with_interruption_handling(trainer)

        if self.cfg.logger.log_to_wandb:
            wandb.finish()

        return sae

    def _set_sae_metadata(self):
        """Copy the run's dataset, model and hook settings onto ``sae.cfg.metadata``."""
        self.sae.cfg.metadata.dataset_path = self.cfg.dataset_path
        self.sae.cfg.metadata.hook_name = self.cfg.hook_name
        self.sae.cfg.metadata.model_name = self.cfg.model_name
        self.sae.cfg.metadata.model_class_name = self.cfg.model_class_name
        self.sae.cfg.metadata.hook_head_index = self.cfg.hook_head_index
        self.sae.cfg.metadata.context_size = self.cfg.context_size
        self.sae.cfg.metadata.seqpos_slice = self.cfg.seqpos_slice
        self.sae.cfg.metadata.model_from_pretrained_kwargs = (
            self.cfg.model_from_pretrained_kwargs
        )
        self.sae.cfg.metadata.prepend_bos = self.cfg.prepend_bos
        self.sae.cfg.metadata.exclude_special_tokens = self.cfg.exclude_special_tokens

    def _compile_if_needed(self):
        """Wrap the model and/or the SAE training forward pass in ``torch.compile`` when configured."""
        # Compile model and SAE
        #  torch.compile can provide significant speedups (10-20% in testing)
        # using max-autotune gives the best speedups but:
        # (a) increases VRAM usage,
        # (b) can't be used on both SAE and LM (some issue with cudagraphs), and
        # (c) takes some time to compile
        # optimal settings seem to be:
        # use max-autotune on SAE and max-autotune-no-cudagraphs on LM
        # (also pylance seems to really hate this)
        if self.cfg.compile_llm:
            self.model = torch.compile(
                self.model,
                mode=self.cfg.llm_compilation_mode,
            )  # type: ignore

        if self.cfg.compile_sae:
            # "inductor" is not used when the device is "mps".
            backend = "aot_eager" if self.cfg.device == "mps" else "inductor"

            self.sae.training_forward_pass = torch.compile(  # type: ignore
                self.sae.training_forward_pass,
                mode=self.cfg.sae_compilation_mode,
                backend=backend,
            )  # type: ignore

    def run_trainer_with_interruption_handling(
        self, trainer: SAETrainer[TrainingSAE[TrainingSAEConfig], TrainingSAEConfig]
    ):
        """
        Run ``trainer.fit()``, saving progress if the run is interrupted.

        Handlers for SIGINT and SIGTERM are installed so that either signal
        raises ``InterruptedException``. On ``KeyboardInterrupt`` or
        ``InterruptedException`` a checkpoint is written to
        ``<cfg.checkpoint_path>/<trainer.n_training_samples>`` and the
        exception is re-raised.

        Returns:
            The SAE returned by ``trainer.fit()``.
        """
        try:
            # signal handlers (if preempted)
            signal.signal(signal.SIGINT, interrupt_callback)
            signal.signal(signal.SIGTERM, interrupt_callback)

            # train SAE
            sae = trainer.fit()

        except (KeyboardInterrupt, InterruptedException):
            logger.warning("interrupted, saving progress")
            checkpoint_path = Path(self.cfg.checkpoint_path) / str(
                trainer.n_training_samples
            )
            self.save_checkpoint(checkpoint_path)
            logger.info("done saving")
            raise

        return sae

    def save_checkpoint(
        self,
        checkpoint_path: Path,
    ) -> None:
        """
        Write the activations store state and the runner config to a directory.

        Args:
            checkpoint_path: Directory to write into. It is created (with any
                missing parents) if it does not exist yet.
        """
        # The interruption path passes a freshly built directory name that
        # nothing in this class has created, so make sure it exists before
        # anything is written into it.
        checkpoint_path.mkdir(parents=True, exist_ok=True)

        self.activations_store.save(
            str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
        )

        runner_config = self.cfg.to_dict()
        with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
            json.dump(runner_config, f)
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _parse_cfg_args(
    args: Sequence[str],
) -> LanguageModelSAERunnerConfig[TrainingSAEConfig]:
    """
    Build a LanguageModelSAERunnerConfig from command line arguments.

    Parsing happens in two passes. The first pass reads only
    ``--architecture`` to pick the concrete SAE config class; the second
    parses everything else against a runner config class whose ``sae`` field
    has that concrete type. With no arguments at all, ``--help`` is used.
    """
    if not args:
        args = ["--help"]

    # Architecture name -> concrete training config class.
    sae_config_map = {
        "standard": StandardTrainingSAEConfig,
        "gated": GatedTrainingSAEConfig,
        "jumprelu": JumpReLUTrainingSAEConfig,
        "topk": TopKTrainingSAEConfig,
    }

    # Pass 1: extract the architecture and leave every other argument alone.
    architecture_parser = ArgumentParser(
        description="Parse architecture to determine SAE config class",
        exit_on_error=False,
        add_help=False,  # Don't add help to avoid conflicts
    )
    architecture_parser.add_argument(
        "--architecture",
        type=str,
        choices=list(sae_config_map),
        default="standard",
        help="SAE architecture to use",
    )
    arch_args, remaining_args = architecture_parser.parse_known_args(args)

    # Drop any leftover "--architecture <value>" pair from the remaining args.
    filtered_args = []
    pending = iter(remaining_args)
    for token in pending:
        if token == "--architecture":
            next(pending, None)  # discard the value that follows the flag
            continue
        filtered_args.append(token)

    def create_config_class(
        sae_config_type: type[TrainingSAEConfig],
    ) -> type[LanguageModelSAERunnerConfig[TrainingSAEConfig]]:
        """Create a concrete config class for the given SAE config type."""
        from dataclasses import field as dataclass_field
        from dataclasses import fields, make_dataclass

        # Every runner config field except the generic ``sae`` one ...
        concrete_fields = [
            (existing.name, existing.type, existing)
            for existing in fields(LanguageModelSAERunnerConfig)
            if existing.name != "sae"
        ]
        # ... followed by ``sae`` re-declared with the concrete type.
        sae_field = dataclass_field(
            default_factory=lambda: sae_config_type(d_in=512, d_sae=1024)
        )
        concrete_fields.append(("sae", sae_config_type, sae_field))

        return make_dataclass(
            f"{sae_config_type.__name__}RunnerConfig",
            concrete_fields,
            bases=(LanguageModelSAERunnerConfig,),
        )

    concrete_config_class = create_config_class(
        sae_config_map[arch_args.architecture]
    )

    # Pass 2: parse the full configuration (without --architecture).
    parser = ArgumentParser(exit_on_error=False)
    parser.add_arguments(concrete_config_class, dest="cfg")
    return parser.parse_args(filtered_args).cfg
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
# moved into its own function to make it easier to test
|
|
366
|
+
def _run_cli(args: Sequence[str]):
    """Parse ``args`` into a runner config, build the training runner and run it."""
    runner = LanguageModelSAETrainingRunner(cfg=_parse_cfg_args(args))
    runner.run()
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
if __name__ == "__main__":
|
|
372
|
+
_run_cli(args=sys.argv[1:])
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
@deprecated("Use LanguageModelSAETrainingRunner instead")
class SAETrainingRunner(LanguageModelSAETrainingRunner):
    """Deprecated subclass of LanguageModelSAETrainingRunner; adds no behavior of its own."""
|