sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +55 -18
- sae_lens/analysis/hooked_sae_transformer.py +10 -10
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +105 -235
- sae_lens/constants.py +20 -0
- sae_lens/evals.py +34 -31
- sae_lens/{sae_training_runner.py → llm_sae_training_runner.py} +103 -70
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +36 -10
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +70 -59
- sae_lens/saes/jumprelu_sae.py +58 -72
- sae_lens/saes/sae.py +248 -273
- sae_lens/saes/standard_sae.py +75 -57
- sae_lens/saes/topk_sae.py +72 -83
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +105 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +134 -158
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +11 -5
- sae_lens/util.py +47 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc3.dist-info/RECORD +38 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/WHEEL +1 -1
- sae_lens/regsitry.py +0 -34
- sae_lens-6.0.0rc1.dist-info/RECORD +0 -32
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/LICENSE +0 -0
sae_lens/evals.py
CHANGED
```diff
@@ -20,8 +20,10 @@ from transformer_lens import HookedTransformer
 from transformer_lens.hook_points import HookedRootModule
 
 from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
-from sae_lens.saes.sae import SAE
+from sae_lens.saes.sae import SAE, SAEConfig
+from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
+from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
 
 
 def get_library_version() -> str:
@@ -100,15 +102,16 @@ def get_eval_everything_config(
 
 @torch.no_grad()
 def run_evals(
-    sae: SAE,
+    sae: SAE[Any],
     activation_store: ActivationsStore,
     model: HookedRootModule,
+    activation_scaler: ActivationScaler,
     eval_config: EvalConfig = EvalConfig(),
     model_kwargs: Mapping[str, Any] = {},
     ignore_tokens: set[int | None] = set(),
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
-    hook_name = sae.cfg.hook_name
+    hook_name = sae.cfg.metadata.hook_name
     actual_batch_size = (
         eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
     )
```
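For orientation, a hedged sketch of a caller-side call under the new signature shown above. `sae`, `model`, and `activation_store` are assumed to be already constructed; the bare `ActivationScaler()` mirrors the `multiple_evals` hunk later in this file's diff.

```python
from sae_lens.evals import EvalConfig, run_evals
from sae_lens.training.activation_scaler import ActivationScaler

# run_evals now takes an explicit ActivationScaler instead of reading the
# norm-scaling state off the ActivationsStore. `sae`, `activation_store`,
# and `model` are assumed to exist already.
all_metrics, feature_metrics = run_evals(
    sae=sae,
    activation_store=activation_store,
    model=model,
    activation_scaler=ActivationScaler(),
    eval_config=EvalConfig(),
)
```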
```diff
@@ -140,6 +143,7 @@ def run_evals(
             sae,
             model,
             activation_store,
+            activation_scaler,
             compute_kl=eval_config.compute_kl,
             compute_ce_loss=eval_config.compute_ce_loss,
             n_batches=eval_config.n_eval_reconstruction_batches,
@@ -189,6 +193,7 @@ def run_evals(
             sae,
             model,
             activation_store,
+            activation_scaler,
             compute_l2_norms=eval_config.compute_l2_norms,
             compute_sparsity_metrics=eval_config.compute_sparsity_metrics,
             compute_variance_metrics=eval_config.compute_variance_metrics,
@@ -274,7 +279,7 @@ def run_evals(
     return all_metrics, feature_metrics
 
 
-def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
+def get_featurewise_weight_based_metrics(sae: SAE[Any]) -> dict[str, Any]:
     unit_norm_encoders = (sae.W_enc / sae.W_enc.norm(dim=0, keepdim=True)).cpu()
     unit_norm_decoder = (sae.W_dec.T / sae.W_dec.T.norm(dim=0, keepdim=True)).cpu()
 
@@ -298,9 +303,10 @@ def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
 
 
 def get_downstream_reconstruction_metrics(
-    sae: SAE,
+    sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
+    activation_scaler: ActivationScaler,
     compute_kl: bool,
     compute_ce_loss: bool,
     n_batches: int,
@@ -326,8 +332,8 @@ def get_downstream_reconstruction_metrics(
         for metric_name, metric_value in get_recons_loss(
             sae,
             model,
+            activation_scaler,
             batch_tokens,
-            activation_store,
             compute_kl=compute_kl,
             compute_ce_loss=compute_ce_loss,
             ignore_tokens=ignore_tokens,
@@ -366,9 +372,10 @@ def get_downstream_reconstruction_metrics(
 
 
 def get_sparsity_and_variance_metrics(
-    sae: SAE,
+    sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
+    activation_scaler: ActivationScaler,
     n_batches: int,
     compute_l2_norms: bool,
     compute_sparsity_metrics: bool,
@@ -379,8 +386,8 @@ def get_sparsity_and_variance_metrics(
     ignore_tokens: set[int | None] = set(),
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
-    hook_name = sae.cfg.hook_name
-    hook_head_index = sae.cfg.hook_head_index
+    hook_name = sae.cfg.metadata.hook_name
+    hook_head_index = sae.cfg.metadata.hook_head_index
 
     metric_dict = {}
     feature_metric_dict = {}
@@ -436,7 +443,7 @@ def get_sparsity_and_variance_metrics(
             batch_tokens,
             prepend_bos=False,
             names_filter=[hook_name],
-            stop_at_layer=
+            stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(hook_name),
             **model_kwargs,
         )
 
```
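The new `stop_at_layer` argument is filled in by a helper from `sae_lens.util` whose implementation is not shown in this diff. Purely as an illustration of the likely intent, a sketch assuming TransformerLens-style hook names of the form `blocks.<layer>.<hook>` (the real function may differ):

```python
import re

def extract_stop_at_layer_from_tlens_hook_name_sketch(hook_name: str) -> int | None:
    # Hypothetical stand-in, not the sae_lens implementation: stop one layer past
    # the hooked block so its activations are still computed; None means "run fully".
    match = re.match(r"blocks\.(\d+)\.", hook_name)
    return int(match.group(1)) + 1 if match else None

assert extract_stop_at_layer_from_tlens_hook_name_sketch("blocks.5.hook_resid_pre") == 6
assert extract_stop_at_layer_from_tlens_hook_name_sketch("hook_embed") is None
```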
```diff
@@ -451,16 +458,14 @@ def get_sparsity_and_variance_metrics(
         original_act = cache[hook_name]
 
         # normalise if necessary (necessary in training only, otherwise we should fold the scaling in)
-
-        original_act = activation_store.apply_norm_scaling_factor(original_act)
+        original_act = activation_scaler.scale(original_act)
 
         # send the (maybe normalised) activations into the SAE
         sae_feature_activations = sae.encode(original_act.to(sae.device))
         sae_out = sae.decode(sae_feature_activations).to(original_act.device)
         del cache
 
-
-        sae_out = activation_store.unscale(sae_out)
+        sae_out = activation_scaler.unscale(sae_out)
 
         flattened_sae_input = einops.rearrange(original_act, "b ctx d -> (b ctx) d")
         flattened_sae_feature_acts = einops.rearrange(
```
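The hunk above is the core of the ActivationScaler migration: activations are scaled on the way into the SAE and unscaled on the way out, so metrics stay in the model's original activation space. A minimal sketch of that round trip, assuming `activation_scaler`, `sae`, and an activation tensor `acts` already exist and using only the `scale`/`unscale` methods shown in this diff:

```python
# Scale into the SAE's expected distribution, reconstruct, then map back.
scaled_acts = activation_scaler.scale(acts)
reconstruction = sae.decode(sae.encode(scaled_acts))
reconstruction_in_model_space = activation_scaler.unscale(reconstruction)
```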
```diff
@@ -580,17 +585,21 @@ def get_sparsity_and_variance_metrics(
 
 @torch.no_grad()
 def get_recons_loss(
-    sae: SAE,
+    sae: SAE[SAEConfig],
     model: HookedRootModule,
+    activation_scaler: ActivationScaler,
     batch_tokens: torch.Tensor,
-    activation_store: ActivationsStore,
     compute_kl: bool,
     compute_ce_loss: bool,
     ignore_tokens: set[int | None] = set(),
     model_kwargs: Mapping[str, Any] = {},
+    hook_name: str | None = None,
 ) -> dict[str, Any]:
-    hook_name = sae.cfg.hook_name
-    head_index = sae.cfg.hook_head_index
+    hook_name = hook_name or sae.cfg.metadata.hook_name
+    head_index = sae.cfg.metadata.hook_head_index
+
+    if hook_name is None:
+        raise ValueError("hook_name must be provided")
 
     original_logits, original_ce_loss = model(
         batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
```
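`get_recons_loss` now accepts an optional `hook_name` override and otherwise falls back to `sae.cfg.metadata.hook_name`, raising if neither is set. A hedged usage sketch; `sae`, `model`, `activation_scaler`, and `batch_tokens` are assumed given, and the hook string is hypothetical:

```python
from sae_lens.evals import get_recons_loss

recons_metrics = get_recons_loss(
    sae,
    model,
    activation_scaler,
    batch_tokens,
    compute_kl=True,
    compute_ce_loss=True,
    hook_name="blocks.8.hook_resid_pre",  # hypothetical; omit to use sae.cfg.metadata.hook_name
)
```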
```diff
@@ -614,15 +623,13 @@ def get_recons_loss(
         activations = activations.to(sae.device)
 
         # Handle rescaling if SAE expects it
-
-        activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)
 
         # SAE class agnost forward forward pass.
         new_activations = sae.decode(sae.encode(activations)).to(activations.dtype)
 
         # Unscale if activations were scaled prior to going into the SAE
-
-        new_activations = activation_store.unscale(new_activations)
+        new_activations = activation_scaler.unscale(new_activations)
 
         new_activations = torch.where(mask[..., None], new_activations, activations)
 
@@ -633,8 +640,7 @@ def get_recons_loss(
         activations = activations.to(sae.device)
 
         # Handle rescaling if SAE expects it
-
-        activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)
 
         # SAE class agnost forward forward pass.
         new_activations = sae.decode(sae.encode(activations.flatten(-2, -1))).to(
@@ -646,8 +652,7 @@ def get_recons_loss(
         )  # reshape to match original shape
 
         # Unscale if activations were scaled prior to going into the SAE
-
-        new_activations = activation_store.unscale(new_activations)
+        new_activations = activation_scaler.unscale(new_activations)
 
         return new_activations.to(original_device)
 
@@ -656,8 +661,7 @@ def get_recons_loss(
         activations = activations.to(sae.device)
 
         # Handle rescaling if SAE expects it
-
-        activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)
 
         new_activations = sae.decode(sae.encode(activations[:, :, head_index])).to(
             activations.dtype
@@ -665,8 +669,7 @@ def get_recons_loss(
         activations[:, :, head_index] = new_activations
 
         # Unscale if activations were scaled prior to going into the SAE
-
-        activations = activation_store.unscale(activations)
+        activations = activation_scaler.unscale(activations)
 
         return activations.to(original_device)
 
@@ -806,7 +809,6 @@ def multiple_evals(
 
     current_model = None
     current_model_str = None
-    print(filtered_saes)
    for sae_release_name, sae_id, _, _ in tqdm(filtered_saes):
         sae = SAE.from_pretrained(
             release=sae_release_name, # see other options in sae_lens/pretrained_saes.yaml
@@ -846,6 +848,7 @@ def multiple_evals(
         scalar_metrics, feature_metrics = run_evals(
             sae=sae,
             activation_store=activation_store,
+            activation_scaler=ActivationScaler(),
             model=current_model,
             eval_config=eval_config,
             ignore_tokens={
```
sae_lens/{sae_training_runner.py → llm_sae_training_runner.py}
CHANGED

```diff
@@ -2,21 +2,31 @@ import json
 import signal
 import sys
 from collections.abc import Sequence
+from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, cast
+from typing import Any, Generic, cast
 
 import torch
 import wandb
 from simple_parsing import ArgumentParser
 from transformer_lens.hook_points import HookedRootModule
+from typing_extensions import deprecated
 
 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
+from sae_lens.evals import EvalConfig, run_evals
 from sae_lens.load_model import load_model
-from sae_lens.saes.sae import
+from sae_lens.saes.sae import (
+    T_TRAINING_SAE,
+    T_TRAINING_SAE_CONFIG,
+    TrainingSAE,
+    TrainingSAEConfig,
+)
+from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
-from sae_lens.training.geometric_median import compute_geometric_median
 from sae_lens.training.sae_trainer import SAETrainer
+from sae_lens.training.types import DataProvider
 
 
 class InterruptedException(Exception):
```
```diff
@@ -27,22 +37,73 @@ def interrupt_callback(sig_num: Any, stack_frame: Any): # noqa: ARG001
     raise InterruptedException()
 
 
-class SAETrainingRunner:
+@dataclass
+class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
+    model: HookedRootModule
+    activations_store: ActivationsStore
+    eval_batch_size_prompts: int | None
+    n_eval_batches: int
+    model_kwargs: dict[str, Any]
+
+    def __call__(
+        self,
+        sae: T_TRAINING_SAE,
+        data_provider: DataProvider,
+        activation_scaler: ActivationScaler,
+    ) -> dict[str, Any]:
+        ignore_tokens = set()
+        if self.activations_store.exclude_special_tokens is not None:
+            ignore_tokens = set(self.activations_store.exclude_special_tokens.tolist())
+
+        eval_config = EvalConfig(
+            batch_size_prompts=self.eval_batch_size_prompts,
+            n_eval_reconstruction_batches=self.n_eval_batches,
+            n_eval_sparsity_variance_batches=self.n_eval_batches,
+            compute_ce_loss=True,
+            compute_l2_norms=True,
+            compute_sparsity_metrics=True,
+            compute_variance_metrics=True,
+        )
+
+        eval_metrics, _ = run_evals(
+            sae=sae,
+            activation_store=self.activations_store,
+            model=self.model,
+            activation_scaler=activation_scaler,
+            eval_config=eval_config,
+            ignore_tokens=ignore_tokens,
+            model_kwargs=self.model_kwargs,
+        ) # not calculating featurwise metrics here.
+
+        # Remove eval metrics that are already logged during training
+        eval_metrics.pop("metrics/explained_variance", None)
+        eval_metrics.pop("metrics/explained_variance_std", None)
+        eval_metrics.pop("metrics/l0", None)
+        eval_metrics.pop("metrics/l1", None)
+        eval_metrics.pop("metrics/mse", None)
+
+        # Remove metrics that are not useful for wandb logging
+        eval_metrics.pop("metrics/total_tokens_evaluated", None)
+
+        return eval_metrics
+
+
+class LanguageModelSAETrainingRunner:
     """
     Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
     """
 
-    cfg: LanguageModelSAERunnerConfig
+    cfg: LanguageModelSAERunnerConfig[Any]
     model: HookedRootModule
-    sae: TrainingSAE
+    sae: TrainingSAE[Any]
     activations_store: ActivationsStore
 
     def __init__(
         self,
-        cfg: LanguageModelSAERunnerConfig,
+        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
         override_dataset: HfDataset | None = None,
         override_model: HookedRootModule | None = None,
-        override_sae: TrainingSAE | None = None,
+        override_sae: TrainingSAE[Any] | None = None,
     ):
         if override_dataset is not None:
             logger.warning(
```
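The evaluator introduced above is a plain dataclass whose `__call__` wraps `run_evals` and strips metrics already logged during training. A hedged construction sketch, mirroring the wiring later in this diff; `model`, `activations_store`, and `sae` are assumed to exist, and the batch count is made up:

```python
from sae_lens.llm_sae_training_runner import LLMSaeEvaluator
from sae_lens.training.activation_scaler import ActivationScaler

evaluator = LLMSaeEvaluator(
    model=model,
    activations_store=activations_store,
    eval_batch_size_prompts=None,  # run_evals falls back to the store's batch size
    n_eval_batches=10,             # hypothetical value
    model_kwargs={},
)
# The trainer invokes the evaluator with (sae, data_provider, activation_scaler).
eval_metrics = evaluator(sae, activations_store, ActivationScaler())
```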
```diff
@@ -82,7 +143,6 @@ class SAETrainingRunner:
                     self.cfg.get_training_sae_cfg_dict(),
                 ).to_dict()
             )
-            self._init_sae_group_b_decs()
         else:
             self.sae = override_sae
 
@@ -100,12 +160,20 @@ class SAETrainingRunner:
             id=self.cfg.logger.wandb_id,
         )
 
-
+        evaluator = LLMSaeEvaluator(
             model=self.model,
+            activations_store=self.activations_store,
+            eval_batch_size_prompts=self.cfg.eval_batch_size_prompts,
+            n_eval_batches=self.cfg.n_eval_batches,
+            model_kwargs=self.cfg.model_kwargs,
+        )
+
+        trainer = SAETrainer(
             sae=self.sae,
-
+            data_provider=self.activations_store,
+            evaluator=evaluator,
             save_checkpoint_fn=self.save_checkpoint,
-            cfg=self.cfg,
+            cfg=self.cfg.to_sae_trainer_config(),
         )
 
         self._compile_if_needed()
```
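The hunk above shows the new dependency-injection style: the trainer receives an SAE, a data provider, an evaluator callback, and a trimmed trainer config rather than the whole runner config. A hedged sketch using only the keyword names visible above; `runner` stands for an already-initialised LanguageModelSAETrainingRunner and `evaluator` for the object from the previous sketch:

```python
from sae_lens.training.sae_trainer import SAETrainer

trainer = SAETrainer(
    sae=runner.sae,
    data_provider=runner.activations_store,  # presumably anything satisfying DataProvider
    evaluator=evaluator,
    save_checkpoint_fn=runner.save_checkpoint,
    cfg=runner.cfg.to_sae_trainer_config(),
)
```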
```diff
@@ -141,7 +209,9 @@ class SAETrainingRunner:
             backend=backend,
         ) # type: ignore
 
-    def run_trainer_with_interruption_handling(
+    def run_trainer_with_interruption_handling(
+        self, trainer: SAETrainer[TrainingSAE[TrainingSAEConfig], TrainingSAEConfig]
+    ):
         try:
             # signal handlers (if preempted)
             signal.signal(signal.SIGINT, interrupt_callback)
@@ -152,73 +222,31 @@ class SAETrainingRunner:
 
         except (KeyboardInterrupt, InterruptedException):
             logger.warning("interrupted, saving progress")
-
-
+            checkpoint_path = Path(self.cfg.checkpoint_path) / str(
+                trainer.n_training_samples
+            )
+            self.save_checkpoint(checkpoint_path)
             logger.info("done saving")
             raise
 
         return sae
 
-    # TODO: move this into the SAE trainer or Training SAE class
-    def _init_sae_group_b_decs(
-        self,
-    ) -> None:
-        """
-        extract all activations at a certain layer and use for sae b_dec initialization
-        """
-
-        if self.cfg.b_dec_init_method == "geometric_median":
-            self.activations_store.set_norm_scaling_factor_if_needed()
-            layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
-            # get geometric median of the activations if we're using those.
-            median = compute_geometric_median(
-                layer_acts,
-                maxiter=100,
-            ).median
-            self.sae.initialize_b_dec_with_precalculated(median)
-        elif self.cfg.b_dec_init_method == "mean":
-            self.activations_store.set_norm_scaling_factor_if_needed()
-            layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
-            self.sae.initialize_b_dec_with_mean(layer_acts)  # type: ignore
-
-    @staticmethod
     def save_checkpoint(
-
-
-        wandb_aliases: list[str] | None = None,
+        self,
+        checkpoint_path: Path,
     ) -> None:
-
-
-
-        trainer.activations_store.save(
-            str(base_path / "activations_store_state.safetensors")
-        )
-
-        if trainer.sae.cfg.normalize_sae_decoder:
-            trainer.sae.set_decoder_norm_to_unit_norm()
-
-        weights_path, cfg_path, sparsity_path = trainer.sae.save_model(
-            str(base_path),
-            trainer.log_feature_sparsity,
+        self.activations_store.save(
+            str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
         )
 
-
-
-
-        with open(cfg_path, "w") as f:
-            json.dump(config, f)
-
-        if trainer.cfg.logger.log_to_wandb:
-            trainer.cfg.logger.log(
-                trainer,
-                weights_path,
-                cfg_path,
-                sparsity_path=sparsity_path,
-                wandb_aliases=wandb_aliases,
-            )
+        runner_config = self.cfg.to_dict()
+        with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
+            json.dump(runner_config, f)
 
 
-def _parse_cfg_args(
+def _parse_cfg_args(
+    args: Sequence[str],
+) -> LanguageModelSAERunnerConfig[TrainingSAEConfig]:
     if len(args) == 0:
         args = ["--help"]
     parser = ArgumentParser(exit_on_error=False)
```
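`save_checkpoint` is now an instance method that takes a directory `Path` and writes the activations-store state plus the runner config, using filename constants from the new `sae_lens.constants` module (their exact values are not visible in this diff). A hedged usage sketch; `runner` and `trainer` are assumed to exist:

```python
from pathlib import Path

checkpoint_dir = Path(runner.cfg.checkpoint_path) / str(trainer.n_training_samples)
checkpoint_dir.mkdir(parents=True, exist_ok=True)  # assumption: the caller creates the directory
runner.save_checkpoint(checkpoint_dir)
```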
```diff
@@ -229,8 +257,13 @@ def _parse_cfg_args(args: Sequence[str]) -> LanguageModelSAERunnerConfig:
 # moved into its own function to make it easier to test
 def _run_cli(args: Sequence[str]):
     cfg = _parse_cfg_args(args)
-
+    LanguageModelSAETrainingRunner(cfg=cfg).run()
 
 
 if __name__ == "__main__":
     _run_cli(args=sys.argv[1:])
+
+
+@deprecated("Use LanguageModelSAETrainingRunner instead")
+class SAETrainingRunner(LanguageModelSAETrainingRunner):
+    pass
```
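The old entry point is kept as a thin subclass, so existing imports keep working while `typing_extensions.deprecated` steers users (and type checkers) toward the new name. A minimal sketch, assuming the renamed module path shown in the file list at the top of this diff:

```python
from sae_lens.llm_sae_training_runner import (
    LanguageModelSAETrainingRunner,
    SAETrainingRunner,
)

# The alias is a real subclass; instantiating it should emit a DeprecationWarning
# via typing_extensions.deprecated on recent typing_extensions versions.
assert issubclass(SAETrainingRunner, LanguageModelSAETrainingRunner)
```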
sae_lens/load_model.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-from typing import Any, Literal, cast
+from typing import Any, Callable, Literal, cast
 
 import torch
 from transformer_lens import HookedTransformer
@@ -77,6 +77,7 @@ class HookedProxyLM(HookedRootModule):
     # copied and modified from base HookedRootModule
     def setup(self):
         self.mod_dict = {}
+        self.named_modules_dict = {}
         self.hook_dict: dict[str, HookPoint] = {}
         for name, module in self.model.named_modules():
             if name == "":
@@ -89,14 +90,21 @@ class HookedProxyLM(HookedRootModule):
 
             self.hook_dict[name] = hook_point
             self.mod_dict[name] = hook_point
+            self.named_modules_dict[name] = module
+
+    def run_with_cache(self, *args: Any, **kwargs: Any): # type: ignore
+        if "names_filter" in kwargs:
+            # hacky way to make sure that the names_filter is passed to our forward method
+            kwargs["_names_filter"] = kwargs["names_filter"]
+        return super().run_with_cache(*args, **kwargs)
 
     def forward(
         self,
         tokens: torch.Tensor,
         return_type: Literal["both", "logits"] = "logits",
         loss_per_token: bool = False,
-        # TODO: implement real support for stop_at_layer
         stop_at_layer: int | None = None,
+        _names_filter: list[str] | None = None,
         **kwargs: Any,
     ) -> Output | Loss:
         # This is just what's needed for evals, not everything that HookedTransformer has
```
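The `run_with_cache` override above smuggles `names_filter` through to `forward` as `_names_filter`, so the proxy can attach stop hooks and cut the wrapped model's forward pass short once the requested activations are cached. A hedged sketch; `proxy_model` is assumed to be a `HookedProxyLM` around a Hugging Face model, and the module name is hypothetical:

```python
import torch

tokens = torch.tensor([[1, 2, 3, 4]])  # placeholder token ids
_, cache = proxy_model.run_with_cache(
    tokens,
    names_filter=["transformer.h.5"],  # hypothetical submodule name of the wrapped model
    stop_at_layer=6,                   # any non-None value enables the early-stop path
)
acts = cache["transformer.h.5"]
```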
```diff
@@ -107,8 +115,28 @@ class HookedProxyLM(HookedRootModule):
             raise NotImplementedError(
                 "Only return_type supported is 'both' or 'logits' to match what's in evals.py and ActivationsStore"
             )
-
-
+
+        stop_hooks = []
+        if stop_at_layer is not None and _names_filter is not None:
+            if return_type != "logits":
+                raise NotImplementedError(
+                    "stop_at_layer is not supported for return_type='both'"
+                )
+            stop_manager = StopManager(_names_filter)
+
+            for hook_name in _names_filter:
+                module = self.named_modules_dict[hook_name]
+                stop_fn = stop_manager.get_stop_hook_fn(hook_name)
+                stop_hooks.append(module.register_forward_hook(stop_fn))
+        try:
+            output = self.model(tokens)
+            logits = _extract_logits_from_output(output)
+        except StopForward:
+            # If we stop early, we don't care about the return output
+            return None  # type: ignore
+        finally:
+            for stop_hook in stop_hooks:
+                stop_hook.remove()
 
         if return_type == "logits":
             return logits
@@ -159,7 +187,7 @@ class HookedProxyLM(HookedRootModule):
 
         # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
         if hasattr(self.tokenizer, "add_bos_token") and self.tokenizer.add_bos_token: # type: ignore
-            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)
+            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens) # type: ignore
         return tokens # type: ignore
 
 
@@ -183,3 +211,23 @@ def get_hook_fn(hook_point: HookPoint):
         return output
 
     return hook_fn
+
+
+class StopForward(Exception):
+    pass
+
+
+class StopManager:
+    def __init__(self, hook_names: list[str]):
+        self.hook_names = hook_names
+        self.total_hook_names = len(set(hook_names))
+        self.called_hook_names = set()
+
+    def get_stop_hook_fn(self, hook_name: str) -> Callable[[Any, Any, Any], Any]:
+        def stop_hook_fn(module: Any, input: Any, output: Any) -> Any: # noqa: ARG001
+            self.called_hook_names.add(hook_name)
+            if len(self.called_hook_names) == self.total_hook_names:
+                raise StopForward()
+            return output
+
+        return stop_hook_fn
```
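For readers unfamiliar with the pattern `StopManager` implements: register forward hooks on the modules whose activations you need, raise once every requested activation has been seen, and catch the exception so the rest of the forward pass is skipped. A standalone illustration in plain PyTorch with made-up module names; it is not sae-lens API:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
captured: dict[str, torch.Tensor] = {}

class _Stop(Exception):
    pass

def capture_then_stop(name: str):
    def hook(module, inputs, output):  # capture the output, then abort the forward pass
        captured[name] = output
        raise _Stop()
    return hook

handle = model[0].register_forward_hook(capture_then_stop("layer0"))
try:
    model(torch.randn(3, 4))
except _Stop:
    pass  # layers after the hooked one never run
finally:
    handle.remove()

print(captured["layer0"].shape)  # torch.Size([3, 8])
```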