sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +6 -3
- sae_lens/cache_activations_runner.py +7 -6
- sae_lens/config.py +47 -5
- sae_lens/constants.py +2 -0
- sae_lens/evals.py +19 -19
- sae_lens/{sae_training_runner.py → llm_sae_training_runner.py} +92 -60
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +0 -7
- sae_lens/saes/sae.py +0 -3
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +77 -172
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/sae_trainer.py +96 -95
- sae_lens/training/types.py +5 -0
- sae_lens/util.py +19 -0
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc3.dist-info}/METADATA +1 -1
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc3.dist-info}/RECORD +19 -16
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc3.dist-info}/LICENSE +0 -0
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc3.dist-info}/WHEEL +0 -0
sae_lens/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.0.0-rc.2"
+__version__ = "6.0.0-rc.3"

 import logging

@@ -33,16 +33,17 @@ from .cache_activations_runner import CacheActivationsRunner
 from .config import (
     CacheActivationsRunnerConfig,
     LanguageModelSAERunnerConfig,
+    LoggingConfig,
     PretokenizeRunnerConfig,
 )
 from .evals import run_evals
+from .llm_sae_training_runner import LanguageModelSAETrainingRunner, SAETrainingRunner
 from .loading.pretrained_sae_loaders import (
     PretrainedSaeDiskLoader,
     PretrainedSaeHuggingfaceLoader,
 )
 from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
 from .registry import register_sae_class, register_sae_training_class
-from .sae_training_runner import SAETrainingRunner
 from .training.activations_store import ActivationsStore
 from .training.upload_saes_to_huggingface import upload_saes_to_huggingface

@@ -54,7 +55,7 @@ __all__ = [
     "HookedSAETransformer",
     "ActivationsStore",
     "LanguageModelSAERunnerConfig",
-    "SAETrainingRunner",
+    "LanguageModelSAETrainingRunner",
     "CacheActivationsRunnerConfig",
     "CacheActivationsRunner",
     "PretokenizeRunnerConfig",
@@ -82,6 +83,8 @@ __all__ = [
     "JumpReLUSAEConfig",
     "JumpReLUTrainingSAE",
     "JumpReLUTrainingSAEConfig",
+    "SAETrainingRunner",
+    "LoggingConfig",
 ]

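Taken together, the `__init__.py` changes re-export the renamed runner while keeping the old name importable. A minimal sketch of the new import surface (assumes sae-lens 6.0.0rc3 is installed):

```python
# Assumes sae-lens 6.0.0rc3 is installed.
from sae_lens import (
    LanguageModelSAETrainingRunner,
    LoggingConfig,
    SAETrainingRunner,
)

# The old name is now a thin, deprecated subclass of the renamed runner
# (see llm_sae_training_runner.py below), so existing imports keep working.
assert issubclass(SAETrainingRunner, LanguageModelSAETrainingRunner)
print(LoggingConfig)  # newly re-exported alongside the other config classes
```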
sae_lens/cache_activations_runner.py
CHANGED

@@ -34,7 +34,6 @@ def _mk_activations_store(
         dataset=override_dataset or cfg.dataset_path,
         streaming=cfg.streaming,
         hook_name=cfg.hook_name,
-        hook_layer=cfg.hook_layer,
         hook_head_index=None,
         context_size=cfg.context_size,
         d_in=cfg.d_in,
@@ -265,7 +264,7 @@ class CacheActivationsRunner:

         for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
             try:
-                buffer = self.activations_store.
+                buffer = self.activations_store.get_raw_buffer(
                     self.cfg.n_batches_in_buffer, shuffle=False
                 )
                 shard = self._create_shard(buffer)
@@ -319,7 +318,7 @@ class CacheActivationsRunner:
     def _create_shard(
         self,
         buffer: tuple[
-            Float[torch.Tensor, "(bs context_size)
+            Float[torch.Tensor, "(bs context_size) d_in"],
             Int[torch.Tensor, "(bs context_size)"] | None,
         ],
     ) -> Dataset:
@@ -327,13 +326,15 @@
         acts, token_ids = buffer
         acts = einops.rearrange(
             acts,
-            "(bs context_size)
+            "(bs context_size) d_in -> bs context_size d_in",
             bs=self.cfg.n_seq_in_buffer,
             context_size=self.context_size,
             d_in=self.cfg.d_in,
-            num_layers=len(hook_names),
         )
-        shard_dict
+        shard_dict: dict[str, object] = {
+            hook_name: act_batch
+            for hook_name, act_batch in zip(hook_names, [acts], strict=True)
+        }

         if token_ids is not None:
             token_ids = einops.rearrange(
sae_lens/config.py
CHANGED
@@ -23,7 +23,9 @@ from sae_lens.saes.sae import TrainingSAEConfig
 if TYPE_CHECKING:
     pass

-T_TRAINING_SAE_CONFIG = TypeVar(
+T_TRAINING_SAE_CONFIG = TypeVar(
+    "T_TRAINING_SAE_CONFIG", bound=TrainingSAEConfig, covariant=True
+)

 HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset

@@ -102,7 +104,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
         hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
         hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
-        hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing.
         hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
         dataset_path (str): A Hugging Face dataset path.
         dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
@@ -159,7 +160,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     model_class_name: str = "HookedTransformer"
     hook_name: str = "blocks.0.hook_mlp_out"
     hook_eval: str = "NOT_IN_USE"
-    hook_layer: int = 0
     hook_head_index: int | None = None
     dataset_path: str = ""
     dataset_trust_remote_code: bool = True
@@ -375,6 +375,28 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):

         return cls(**cfg)

+    def to_sae_trainer_config(self) -> "SAETrainerConfig":
+        return SAETrainerConfig(
+            n_checkpoints=self.n_checkpoints,
+            checkpoint_path=self.checkpoint_path,
+            total_training_samples=self.total_training_tokens,
+            device=self.device,
+            autocast=self.autocast,
+            lr=self.lr,
+            lr_end=self.lr_end,
+            lr_scheduler_name=self.lr_scheduler_name,
+            lr_warm_up_steps=self.lr_warm_up_steps,
+            adam_beta1=self.adam_beta1,
+            adam_beta2=self.adam_beta2,
+            lr_decay_steps=self.lr_decay_steps,
+            n_restart_cycles=self.n_restart_cycles,
+            total_training_steps=self.total_training_steps,
+            train_batch_size_samples=self.train_batch_size_tokens,
+            dead_feature_window=self.dead_feature_window,
+            feature_sampling_window=self.feature_sampling_window,
+            logger=self.logger,
+        )
+

 @dataclass
 class CacheActivationsRunnerConfig:
@@ -386,7 +408,6 @@ class CacheActivationsRunnerConfig:
         model_name (str): The name of the model to use.
         model_batch_size (int): How many prompts are in the batch of the language model when generating activations.
         hook_name (str): The name of the hook to use.
-        hook_layer (int): The layer of the final hook. Currently only support a single hook, so this should be the same as hook_name.
         d_in (int): Dimension of the model.
         total_training_tokens (int): Total number of tokens to process.
         context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized.
@@ -416,7 +437,6 @@ class CacheActivationsRunnerConfig:
     model_name: str
     model_batch_size: int
     hook_name: str
-    hook_layer: int
     d_in: int
     training_tokens: int

@@ -576,3 +596,25 @@ class PretokenizeRunnerConfig:
     hf_num_shards: int = 64
     hf_revision: str = "main"
     hf_is_private_repo: bool = False
+
+
+@dataclass
+class SAETrainerConfig:
+    n_checkpoints: int
+    checkpoint_path: str
+    total_training_samples: int
+    device: str
+    autocast: bool
+    lr: float
+    lr_end: float | None
+    lr_scheduler_name: str
+    lr_warm_up_steps: int
+    adam_beta1: float
+    adam_beta2: float
+    lr_decay_steps: int
+    n_restart_cycles: int
+    total_training_steps: int
+    train_batch_size_samples: int
+    dead_feature_window: int
+    feature_sampling_window: int
+    logger: LoggingConfig
sae_lens/constants.py
CHANGED
@@ -16,3 +16,5 @@ SPARSITY_FILENAME = "sparsity.safetensors"
 SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
 SAE_CFG_FILENAME = "cfg.json"
 RUNNER_CFG_FILENAME = "runner_cfg.json"
+ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
+ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
sae_lens/evals.py
CHANGED
@@ -21,7 +21,9 @@ from transformer_lens.hook_points import HookedRootModule

 from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
 from sae_lens.saes.sae import SAE, SAEConfig
+from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
+from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name


 def get_library_version() -> str:
@@ -103,6 +105,7 @@ def run_evals(
     sae: SAE[Any],
     activation_store: ActivationsStore,
     model: HookedRootModule,
+    activation_scaler: ActivationScaler,
     eval_config: EvalConfig = EvalConfig(),
     model_kwargs: Mapping[str, Any] = {},
     ignore_tokens: set[int | None] = set(),
@@ -140,6 +143,7 @@
         sae,
         model,
         activation_store,
+        activation_scaler,
         compute_kl=eval_config.compute_kl,
         compute_ce_loss=eval_config.compute_ce_loss,
         n_batches=eval_config.n_eval_reconstruction_batches,
@@ -189,6 +193,7 @@
         sae,
         model,
         activation_store,
+        activation_scaler,
         compute_l2_norms=eval_config.compute_l2_norms,
         compute_sparsity_metrics=eval_config.compute_sparsity_metrics,
         compute_variance_metrics=eval_config.compute_variance_metrics,
@@ -301,6 +306,7 @@ def get_downstream_reconstruction_metrics(
     sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
+    activation_scaler: ActivationScaler,
     compute_kl: bool,
     compute_ce_loss: bool,
     n_batches: int,
@@ -326,8 +332,8 @@
     for metric_name, metric_value in get_recons_loss(
         sae,
         model,
+        activation_scaler,
         batch_tokens,
-        activation_store,
         compute_kl=compute_kl,
         compute_ce_loss=compute_ce_loss,
         ignore_tokens=ignore_tokens,
@@ -369,6 +375,7 @@ def get_sparsity_and_variance_metrics(
     sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
+    activation_scaler: ActivationScaler,
     n_batches: int,
     compute_l2_norms: bool,
     compute_sparsity_metrics: bool,
@@ -436,7 +443,7 @@
             batch_tokens,
             prepend_bos=False,
             names_filter=[hook_name],
-            stop_at_layer=
+            stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(hook_name),
             **model_kwargs,
         )

@@ -451,16 +458,14 @@
         original_act = cache[hook_name]

         # normalise if necessary (necessary in training only, otherwise we should fold the scaling in)
-
-        original_act = activation_store.apply_norm_scaling_factor(original_act)
+        original_act = activation_scaler.scale(original_act)

         # send the (maybe normalised) activations into the SAE
         sae_feature_activations = sae.encode(original_act.to(sae.device))
         sae_out = sae.decode(sae_feature_activations).to(original_act.device)
         del cache

-
-        sae_out = activation_store.unscale(sae_out)
+        sae_out = activation_scaler.unscale(sae_out)

         flattened_sae_input = einops.rearrange(original_act, "b ctx d -> (b ctx) d")
         flattened_sae_feature_acts = einops.rearrange(
@@ -582,8 +587,8 @@
 def get_recons_loss(
     sae: SAE[SAEConfig],
     model: HookedRootModule,
+    activation_scaler: ActivationScaler,
     batch_tokens: torch.Tensor,
-    activation_store: ActivationsStore,
     compute_kl: bool,
     compute_ce_loss: bool,
     ignore_tokens: set[int | None] = set(),
@@ -618,15 +623,13 @@
         activations = activations.to(sae.device)

         # Handle rescaling if SAE expects it
-
-        activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)

         # SAE class agnost forward forward pass.
         new_activations = sae.decode(sae.encode(activations)).to(activations.dtype)

         # Unscale if activations were scaled prior to going into the SAE
-
-        new_activations = activation_store.unscale(new_activations)
+        new_activations = activation_scaler.unscale(new_activations)

         new_activations = torch.where(mask[..., None], new_activations, activations)

@@ -637,8 +640,7 @@
         activations = activations.to(sae.device)

         # Handle rescaling if SAE expects it
-
-        activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)

         # SAE class agnost forward forward pass.
         new_activations = sae.decode(sae.encode(activations.flatten(-2, -1))).to(
@@ -650,8 +652,7 @@
         ) # reshape to match original shape

         # Unscale if activations were scaled prior to going into the SAE
-
-        new_activations = activation_store.unscale(new_activations)
+        new_activations = activation_scaler.unscale(new_activations)

         return new_activations.to(original_device)

@@ -660,8 +661,7 @@
         activations = activations.to(sae.device)

         # Handle rescaling if SAE expects it
-
-        activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)

         new_activations = sae.decode(sae.encode(activations[:, :, head_index])).to(
             activations.dtype
@@ -669,8 +669,7 @@
         activations[:, :, head_index] = new_activations

         # Unscale if activations were scaled prior to going into the SAE
-
-        activations = activation_store.unscale(activations)
+        activations = activation_scaler.unscale(activations)

         return activations.to(original_device)

@@ -849,6 +848,7 @@ def multiple_evals(
         scalar_metrics, feature_metrics = run_evals(
             sae=sae,
             activation_store=activation_store,
+            activation_scaler=ActivationScaler(),
             model=current_model,
             eval_config=eval_config,
             ignore_tokens={
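Throughout evals.py, activation normalization now goes through the new `ActivationScaler` passed in by the caller rather than through `ActivationsStore` (multiple_evals simply passes a default `ActivationScaler()`). The diff only shows the `scale`/`unscale` call sites, so the class below is a hypothetical stand-in that illustrates the round-trip pattern, not the real implementation in `sae_lens/training/activation_scaler.py`:

```python
import torch


class ToyActivationScaler:
    """Hypothetical stand-in for sae-lens' ActivationScaler (internals not shown in this diff)."""

    def __init__(self, scaling_factor: float | None = None):
        self.scaling_factor = scaling_factor

    def scale(self, acts: torch.Tensor) -> torch.Tensor:
        # Acts as a no-op until a scaling factor has been estimated.
        return acts if self.scaling_factor is None else acts * self.scaling_factor

    def unscale(self, acts: torch.Tensor) -> torch.Tensor:
        return acts if self.scaling_factor is None else acts / self.scaling_factor


scaler = ToyActivationScaler(scaling_factor=2.0)
original = torch.randn(4, 16)
scaled = scaler.scale(original)    # what would be fed to sae.encode(...)
restored = scaler.unscale(scaled)  # applied to the SAE output before computing metrics
assert torch.allclose(restored, original)
```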
sae_lens/{sae_training_runner.py → llm_sae_training_runner.py}
CHANGED

@@ -2,23 +2,31 @@ import json
 import signal
 import sys
 from collections.abc import Sequence
+from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, cast
+from typing import Any, Generic, cast

 import torch
 import wandb
-from safetensors.torch import save_file
 from simple_parsing import ArgumentParser
 from transformer_lens.hook_points import HookedRootModule
+from typing_extensions import deprecated

 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
-from sae_lens.constants import
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
+from sae_lens.evals import EvalConfig, run_evals
 from sae_lens.load_model import load_model
-from sae_lens.saes.sae import
+from sae_lens.saes.sae import (
+    T_TRAINING_SAE,
+    T_TRAINING_SAE_CONFIG,
+    TrainingSAE,
+    TrainingSAEConfig,
+)
+from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
-from sae_lens.training.geometric_median import compute_geometric_median
 from sae_lens.training.sae_trainer import SAETrainer
+from sae_lens.training.types import DataProvider


 class InterruptedException(Exception):
@@ -29,7 +37,58 @@ def interrupt_callback(sig_num: Any, stack_frame: Any): # noqa: ARG001
     raise InterruptedException()


-
+@dataclass
+class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
+    model: HookedRootModule
+    activations_store: ActivationsStore
+    eval_batch_size_prompts: int | None
+    n_eval_batches: int
+    model_kwargs: dict[str, Any]
+
+    def __call__(
+        self,
+        sae: T_TRAINING_SAE,
+        data_provider: DataProvider,
+        activation_scaler: ActivationScaler,
+    ) -> dict[str, Any]:
+        ignore_tokens = set()
+        if self.activations_store.exclude_special_tokens is not None:
+            ignore_tokens = set(self.activations_store.exclude_special_tokens.tolist())
+
+        eval_config = EvalConfig(
+            batch_size_prompts=self.eval_batch_size_prompts,
+            n_eval_reconstruction_batches=self.n_eval_batches,
+            n_eval_sparsity_variance_batches=self.n_eval_batches,
+            compute_ce_loss=True,
+            compute_l2_norms=True,
+            compute_sparsity_metrics=True,
+            compute_variance_metrics=True,
+        )
+
+        eval_metrics, _ = run_evals(
+            sae=sae,
+            activation_store=self.activations_store,
+            model=self.model,
+            activation_scaler=activation_scaler,
+            eval_config=eval_config,
+            ignore_tokens=ignore_tokens,
+            model_kwargs=self.model_kwargs,
+        ) # not calculating featurwise metrics here.
+
+        # Remove eval metrics that are already logged during training
+        eval_metrics.pop("metrics/explained_variance", None)
+        eval_metrics.pop("metrics/explained_variance_std", None)
+        eval_metrics.pop("metrics/l0", None)
+        eval_metrics.pop("metrics/l1", None)
+        eval_metrics.pop("metrics/mse", None)
+
+        # Remove metrics that are not useful for wandb logging
+        eval_metrics.pop("metrics/total_tokens_evaluated", None)
+
+        return eval_metrics
+
+
+class LanguageModelSAETrainingRunner:
     """
     Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
     """
@@ -84,7 +143,6 @@ class SAETrainingRunner:
                     self.cfg.get_training_sae_cfg_dict(),
                 ).to_dict()
             )
-            self._init_sae_group_b_decs()
         else:
             self.sae = override_sae

@@ -102,12 +160,20 @@ class SAETrainingRunner:
             id=self.cfg.logger.wandb_id,
         )

-
+        evaluator = LLMSaeEvaluator(
             model=self.model,
+            activations_store=self.activations_store,
+            eval_batch_size_prompts=self.cfg.eval_batch_size_prompts,
+            n_eval_batches=self.cfg.n_eval_batches,
+            model_kwargs=self.cfg.model_kwargs,
+        )
+
+        trainer = SAETrainer(
             sae=self.sae,
-
+            data_provider=self.activations_store,
+            evaluator=evaluator,
             save_checkpoint_fn=self.save_checkpoint,
-            cfg=self.cfg,
+            cfg=self.cfg.to_sae_trainer_config(),
         )

         self._compile_if_needed()
@@ -156,66 +222,27 @@ class SAETrainingRunner:

         except (KeyboardInterrupt, InterruptedException):
             logger.warning("interrupted, saving progress")
-
-
+            checkpoint_path = Path(self.cfg.checkpoint_path) / str(
+                trainer.n_training_samples
+            )
+            self.save_checkpoint(checkpoint_path)
             logger.info("done saving")
             raise

         return sae

-    # TODO: move this into the SAE trainer or Training SAE class
-    def _init_sae_group_b_decs(
-        self,
-    ) -> None:
-        """
-        extract all activations at a certain layer and use for sae b_dec initialization
-        """
-
-        if self.cfg.sae.b_dec_init_method == "geometric_median":
-            self.activations_store.set_norm_scaling_factor_if_needed()
-            layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
-            # get geometric median of the activations if we're using those.
-            median = compute_geometric_median(
-                layer_acts,
-                maxiter=100,
-            ).median
-            self.sae.initialize_b_dec_with_precalculated(median)
-        elif self.cfg.sae.b_dec_init_method == "mean":
-            self.activations_store.set_norm_scaling_factor_if_needed()
-            layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
-            self.sae.initialize_b_dec_with_mean(layer_acts) # type: ignore
-
-    @staticmethod
     def save_checkpoint(
-
-
-        wandb_aliases: list[str] | None = None,
+        self,
+        checkpoint_path: Path,
     ) -> None:
-
-
-
-        trainer.activations_store.save(
-            str(base_path / "activations_store_state.safetensors")
+        self.activations_store.save(
+            str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
         )

-
-
-        sparsity_path = base_path / SPARSITY_FILENAME
-        save_file({"sparsity": trainer.log_feature_sparsity}, sparsity_path)
-
-        runner_config = trainer.cfg.to_dict()
-        with open(base_path / RUNNER_CFG_FILENAME, "w") as f:
+        runner_config = self.cfg.to_dict()
+        with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
             json.dump(runner_config, f)

-        if trainer.cfg.logger.log_to_wandb:
-            trainer.cfg.logger.log(
-                trainer,
-                weights_path,
-                cfg_path,
-                sparsity_path=sparsity_path,
-                wandb_aliases=wandb_aliases,
-            )
-

 def _parse_cfg_args(
     args: Sequence[str],
@@ -230,8 +257,13 @@
 # moved into its own function to make it easier to test
 def _run_cli(args: Sequence[str]):
     cfg = _parse_cfg_args(args)
-
+    LanguageModelSAETrainingRunner(cfg=cfg).run()


 if __name__ == "__main__":
     _run_cli(args=sys.argv[1:])
+
+
+@deprecated("Use LanguageModelSAETrainingRunner instead")
+class SAETrainingRunner(LanguageModelSAETrainingRunner):
+    pass
sae_lens/load_model.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, Literal, cast
+from typing import Any, Callable, Literal, cast

 import torch
 from transformer_lens import HookedTransformer
@@ -77,6 +77,7 @@ class HookedProxyLM(HookedRootModule):
     # copied and modified from base HookedRootModule
     def setup(self):
         self.mod_dict = {}
+        self.named_modules_dict = {}
         self.hook_dict: dict[str, HookPoint] = {}
         for name, module in self.model.named_modules():
             if name == "":
@@ -89,14 +90,21 @@

             self.hook_dict[name] = hook_point
             self.mod_dict[name] = hook_point
+            self.named_modules_dict[name] = module
+
+    def run_with_cache(self, *args: Any, **kwargs: Any): # type: ignore
+        if "names_filter" in kwargs:
+            # hacky way to make sure that the names_filter is passed to our forward method
+            kwargs["_names_filter"] = kwargs["names_filter"]
+        return super().run_with_cache(*args, **kwargs)

     def forward(
         self,
         tokens: torch.Tensor,
         return_type: Literal["both", "logits"] = "logits",
         loss_per_token: bool = False,
-        # TODO: implement real support for stop_at_layer
         stop_at_layer: int | None = None,
+        _names_filter: list[str] | None = None,
         **kwargs: Any,
     ) -> Output | Loss:
         # This is just what's needed for evals, not everything that HookedTransformer has
@@ -107,8 +115,28 @@
             raise NotImplementedError(
                 "Only return_type supported is 'both' or 'logits' to match what's in evals.py and ActivationsStore"
             )
-
-
+
+        stop_hooks = []
+        if stop_at_layer is not None and _names_filter is not None:
+            if return_type != "logits":
+                raise NotImplementedError(
+                    "stop_at_layer is not supported for return_type='both'"
+                )
+            stop_manager = StopManager(_names_filter)
+
+            for hook_name in _names_filter:
+                module = self.named_modules_dict[hook_name]
+                stop_fn = stop_manager.get_stop_hook_fn(hook_name)
+                stop_hooks.append(module.register_forward_hook(stop_fn))
+        try:
+            output = self.model(tokens)
+            logits = _extract_logits_from_output(output)
+        except StopForward:
+            # If we stop early, we don't care about the return output
+            return None # type: ignore
+        finally:
+            for stop_hook in stop_hooks:
+                stop_hook.remove()

         if return_type == "logits":
             return logits
@@ -159,7 +187,7 @@

         # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
         if hasattr(self.tokenizer, "add_bos_token") and self.tokenizer.add_bos_token: # type: ignore
-            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)
+            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens) # type: ignore
         return tokens # type: ignore


@@ -183,3 +211,23 @@ def get_hook_fn(hook_point: HookPoint):
         return output

     return hook_fn
+
+
+class StopForward(Exception):
+    pass
+
+
+class StopManager:
+    def __init__(self, hook_names: list[str]):
+        self.hook_names = hook_names
+        self.total_hook_names = len(set(hook_names))
+        self.called_hook_names = set()
+
+    def get_stop_hook_fn(self, hook_name: str) -> Callable[[Any, Any, Any], Any]:
+        def stop_hook_fn(module: Any, input: Any, output: Any) -> Any: # noqa: ARG001
+            self.called_hook_names.add(hook_name)
+            if len(self.called_hook_names) == self.total_hook_names:
+                raise StopForward()
+            return output
+
+        return stop_hook_fn