sae-lens 5.10.3__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +56 -6
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +2 -1
- sae_lens/config.py +121 -252
- sae_lens/constants.py +18 -0
- sae_lens/evals.py +32 -17
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +68 -36
- sae_lens/pretrained_saes.yaml +0 -12
- sae_lens/registry.py +49 -0
- sae_lens/sae_training_runner.py +40 -54
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +258 -0
- sae_lens/saes/jumprelu_sae.py +354 -0
- sae_lens/saes/sae.py +948 -0
- sae_lens/saes/standard_sae.py +185 -0
- sae_lens/saes/topk_sae.py +294 -0
- sae_lens/training/activations_store.py +32 -16
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +55 -86
- sae_lens/training/upload_saes_to_huggingface.py +12 -6
- sae_lens/util.py +28 -0
- {sae_lens-5.10.3.dist-info → sae_lens-6.0.0rc2.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc2.dist-info/RECORD +35 -0
- sae_lens/sae.py +0 -747
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.10.3.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.10.3.dist-info → sae_lens-6.0.0rc2.dist-info}/LICENSE +0 -0
- {sae_lens-5.10.3.dist-info → sae_lens-6.0.0rc2.dist-info}/WHEEL +0 -0
sae_lens/evals.py
CHANGED
@@ -19,8 +19,8 @@ from tqdm import tqdm
 from transformer_lens import HookedTransformer
 from transformer_lens.hook_points import HookedRootModule

-from sae_lens.
-from sae_lens.
+from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
+from sae_lens.saes.sae import SAE, SAEConfig
 from sae_lens.training.activations_store import ActivationsStore


@@ -100,7 +100,7 @@ def get_eval_everything_config(

 @torch.no_grad()
 def run_evals(
-    sae: SAE,
+    sae: SAE[Any],
     activation_store: ActivationsStore,
     model: HookedRootModule,
     eval_config: EvalConfig = EvalConfig(),
@@ -108,7 +108,7 @@ def run_evals(
     ignore_tokens: set[int | None] = set(),
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
-    hook_name = sae.cfg.hook_name
+    hook_name = sae.cfg.metadata.hook_name
     actual_batch_size = (
         eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
     )
@@ -274,12 +274,11 @@ def run_evals(
     return all_metrics, feature_metrics


-def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
+def get_featurewise_weight_based_metrics(sae: SAE[Any]) -> dict[str, Any]:
     unit_norm_encoders = (sae.W_enc / sae.W_enc.norm(dim=0, keepdim=True)).cpu()
     unit_norm_decoder = (sae.W_dec.T / sae.W_dec.T.norm(dim=0, keepdim=True)).cpu()

     encoder_norms = sae.W_enc.norm(dim=-2).cpu().tolist()
-    encoder_bias = sae.b_enc.cpu().tolist()
     encoder_decoder_cosine_sim = (
         torch.nn.functional.cosine_similarity(
             unit_norm_decoder.T,
@@ -289,15 +288,17 @@ def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
         .tolist()
     )

-
-        "encoder_bias": encoder_bias,
+    metrics = {
         "encoder_norm": encoder_norms,
         "encoder_decoder_cosine_sim": encoder_decoder_cosine_sim,
     }
+    if hasattr(sae, "b_enc") and sae.b_enc is not None:
+        metrics["encoder_bias"] = sae.b_enc.cpu().tolist()  # type: ignore
+    return metrics


 def get_downstream_reconstruction_metrics(
-    sae: SAE,
+    sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
     compute_kl: bool,
@@ -365,7 +366,7 @@ def get_downstream_reconstruction_metrics(


 def get_sparsity_and_variance_metrics(
-    sae: SAE,
+    sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
     n_batches: int,
@@ -378,8 +379,8 @@ def get_sparsity_and_variance_metrics(
     ignore_tokens: set[int | None] = set(),
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
-    hook_name = sae.cfg.hook_name
-    hook_head_index = sae.cfg.hook_head_index
+    hook_name = sae.cfg.metadata.hook_name
+    hook_head_index = sae.cfg.metadata.hook_head_index

     metric_dict = {}
     feature_metric_dict = {}
@@ -435,7 +436,7 @@ def get_sparsity_and_variance_metrics(
             batch_tokens,
             prepend_bos=False,
             names_filter=[hook_name],
-            stop_at_layer=sae.cfg.hook_layer + 1,
+            stop_at_layer=sae.cfg.metadata.hook_layer + 1,
             **model_kwargs,
         )

@@ -579,7 +580,7 @@ def get_sparsity_and_variance_metrics(

 @torch.no_grad()
 def get_recons_loss(
-    sae: SAE,
+    sae: SAE[SAEConfig],
     model: HookedRootModule,
     batch_tokens: torch.Tensor,
     activation_store: ActivationsStore,
@@ -587,9 +588,13 @@ def get_recons_loss(
     compute_ce_loss: bool,
     ignore_tokens: set[int | None] = set(),
     model_kwargs: Mapping[str, Any] = {},
+    hook_name: str | None = None,
 ) -> dict[str, Any]:
-    hook_name = sae.cfg.hook_name
-    head_index = sae.cfg.hook_head_index
+    hook_name = hook_name or sae.cfg.metadata.hook_name
+    head_index = sae.cfg.metadata.hook_head_index
+
+    if hook_name is None:
+        raise ValueError("hook_name must be provided")

     original_logits, original_ce_loss = model(
         batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
@@ -764,6 +769,17 @@ def nested_dict() -> defaultdict[Any, Any]:
     return defaultdict(nested_dict)


+def dict_to_nested(flat_dict: dict[str, Any]) -> defaultdict[Any, Any]:
+    nested = nested_dict()
+    for key, value in flat_dict.items():
+        parts = key.split("/")
+        d = nested
+        for part in parts[:-1]:
+            d = d[part]
+        d[parts[-1]] = value
+    return nested
+
+
 def multiple_evals(
     sae_regex_pattern: str,
     sae_block_pattern: str,
@@ -794,7 +810,6 @@ def multiple_evals(

     current_model = None
     current_model_str = None
-    print(filtered_saes)
     for sae_release_name, sae_id, _, _ in tqdm(filtered_saes):
         sae = SAE.from_pretrained(
             release=sae_release_name,  # see other options in sae_lens/pretrained_saes.yaml
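Two call-site notes on the evals.py changes: hook information now lives on sae.cfg.metadata (hook_name, hook_layer, hook_head_index) rather than directly on sae.cfg, and the new dict_to_nested helper turns flat, slash-separated keys into nested dictionaries. A minimal usage sketch for the helper (the metric keys are illustrative, not taken from this diff; requires sae-lens >= 6.0.0rc2):

from sae_lens.evals import dict_to_nested

# illustrative flat keys; dict_to_nested splits each key on "/" and nests the segments
flat = {"sparsity/l0": 40.0, "variance/explained": 0.87}
nested = dict_to_nested(flat)
assert nested["sparsity"]["l0"] == 40.0
assert nested["variance"]["explained"] == 0.87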
sae_lens/{toolkit → loading}/pretrained_sae_loaders.py
RENAMED

@@ -7,21 +7,24 @@ import numpy as np
 import torch
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError
+from packaging.version import Version
 from safetensors import safe_open
 from safetensors.torch import load_file

 from sae_lens import logger
-from sae_lens.
+from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
 )
-from sae_lens.
+from sae_lens.loading.pretrained_saes_directory import (
     get_config_overrides,
     get_pretrained_saes_directory,
     get_repo_id_and_folder_name,
 )
+from sae_lens.registry import get_sae_class
+from sae_lens.util import filter_valid_dataclass_fields


 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
@@ -174,30 +177,68 @@ def get_sae_lens_config_from_disk(


 def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    sae_lens_version = cfg_dict.get("sae_lens_version")
+    if not sae_lens_version and "metadata" in cfg_dict:
+        sae_lens_version = cfg_dict["metadata"].get("sae_lens_version")
+
+    if not sae_lens_version or Version(sae_lens_version) < Version("6.0.0-rc.0"):
+        cfg_dict = handle_pre_6_0_config(cfg_dict)
+    return cfg_dict
+
+
+def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    """
+    Format a config dictionary for a Sparse Autoencoder (SAE) to be compatible with the new 6.0 format.
+    """
+
+    rename_keys_map = {
+        "hook_point": "hook_name",
+        "hook_point_layer": "hook_layer",
+        "hook_point_head_index": "hook_head_index",
+        "activation_fn_str": "activation_fn",
+    }
+    new_cfg = {rename_keys_map.get(k, k): v for k, v in cfg_dict.items()}
+
     # Set default values for backwards compatibility
-
-
-
-
-
-
-
-
-
-    if "normalize_activations" in
-
+    new_cfg.setdefault("prepend_bos", True)
+    new_cfg.setdefault("dataset_trust_remote_code", True)
+    new_cfg.setdefault("apply_b_dec_to_input", True)
+    new_cfg.setdefault("finetuning_scaling_factor", False)
+    new_cfg.setdefault("sae_lens_training_version", None)
+    new_cfg.setdefault("activation_fn", new_cfg.get("activation_fn", "relu"))
+    new_cfg.setdefault("architecture", "standard")
+    new_cfg.setdefault("neuronpedia_id", None)
+
+    if "normalize_activations" in new_cfg and isinstance(
+        new_cfg["normalize_activations"], bool
     ):
         # backwards compatibility
-
+        new_cfg["normalize_activations"] = (
             "none"
-            if not
+            if not new_cfg["normalize_activations"]
             else "expected_average_only_in"
         )

-
-
+    if new_cfg.get("normalize_activations") is None:
+        new_cfg["normalize_activations"] = "none"

-
+    new_cfg.setdefault("device", "cpu")
+
+    architecture = new_cfg.get("architecture", "standard")
+
+    config_class = get_sae_class(architecture)[1]
+
+    sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
+    if architecture == "topk":
+        sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
+
+    # import here to avoid circular import
+    from sae_lens.saes.sae import SAEMetadata
+
+    meta_dict = filter_valid_dataclass_fields(new_cfg, SAEMetadata)
+    sae_cfg_dict["metadata"] = meta_dict
+    sae_cfg_dict["architecture"] = architecture
+    return sae_cfg_dict


 def get_connor_rob_hook_z_config_from_hf(
@@ -223,7 +264,7 @@ def get_connor_rob_hook_z_config_from_hf(
         "hook_name": old_cfg_dict["act_name"],
         "hook_layer": old_cfg_dict["layer"],
         "hook_head_index": None,
-        "
+        "activation_fn": "relu",
         "apply_b_dec_to_input": True,
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
@@ -372,7 +413,7 @@ def get_gemma_2_config_from_hf(
         "hook_name": hook_name,
         "hook_layer": layer,
         "hook_head_index": None,
-        "
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -473,20 +514,11 @@ def get_llama_scope_config_from_hf(
     # Model specific parameters
     model_name, d_in = "meta-llama/Llama-3.1-8B", old_cfg_dict["d_model"]

-    # Get norm scaling factor to rescale jumprelu threshold.
-    # We need this because sae.fold_activation_norm_scaling_factor folds scaling norm into W_enc.
-    # This requires jumprelu threshold to be scaled in the same way
-    norm_scaling_factor = (
-        d_in**0.5 / old_cfg_dict["dataset_average_activation_norm"]["in"]
-    )
-
     cfg_dict = {
         "architecture": "jumprelu",
-        "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"]
-        * norm_scaling_factor,
+        "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"],
         # We use a scalar jump_relu_threshold for all features
         # This is different from Gemma Scope JumpReLU SAEs.
-        # Scaled with norm_scaling_factor to match sae.fold_activation_norm_scaling_factor
         "d_in": d_in,
         "d_sae": old_cfg_dict["d_sae"],
         "dtype": "bfloat16",
@@ -494,7 +526,7 @@ def get_llama_scope_config_from_hf(
         "hook_name": old_cfg_dict["hook_point_in"],
         "hook_layer": int(old_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
-        "
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -606,8 +638,8 @@ def get_dictionary_learning_config_1_from_hf(

     hook_point_name = f"blocks.{trainer['layer']}.hook_resid_post"

-
-    activation_fn_kwargs = {"k": trainer["k"]} if
+    activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
+    activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}

     return {
         "architecture": (
@@ -621,7 +653,7 @@ def get_dictionary_learning_config_1_from_hf(
         "hook_name": hook_point_name,
         "hook_layer": trainer["layer"],
         "hook_head_index": None,
-        "
+        "activation_fn": activation_fn,
         "activation_fn_kwargs": activation_fn_kwargs,
         "apply_b_dec_to_input": True,
         "finetuning_scaling_factor": False,
@@ -664,7 +696,7 @@ def get_deepseek_r1_config_from_hf(
         "dataset_path": "lmsys/lmsys-chat-1m",
         "dataset_trust_remote_code": True,
         "sae_lens_training_version": None,
-        "
+        "activation_fn": "relu",
         "normalize_activations": "none",
         "device": device,
         "apply_b_dec_to_input": False,
@@ -819,7 +851,7 @@ def get_llama_scope_r1_distill_config_from_hf(
         "hook_name": huggingface_cfg_dict["hook_point_in"],
         "hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
-        "
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
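handle_config_defaulting now gates old configs on their recorded sae_lens_version: anything saved before 6.0.0-rc.0 (or with no version at all) is routed through handle_pre_6_0_config, which renames the legacy keys and splits the result into an architecture config plus a metadata dict. A quick sketch of the version comparison it relies on:

from packaging.version import Version

for saved in (None, "5.10.3", "6.0.0rc2"):
    needs_upgrade = not saved or Version(saved) < Version("6.0.0-rc.0")
    print(saved, needs_upgrade)
# None True, 5.10.3 True  -> routed through handle_pre_6_0_config
# 6.0.0rc2 False          -> already in the 6.0 format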
sae_lens/pretrained_saes.yaml
CHANGED
@@ -13634,51 +13634,39 @@ gemma-2-2b-res-matryoshka-dc:
   - id: blocks.13.hook_resid_post
     path: standard/blocks.13.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/13-res-matryoshka-dc
   - id: blocks.14.hook_resid_post
     path: standard/blocks.14.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/14-res-matryoshka-dc
   - id: blocks.15.hook_resid_post
     path: standard/blocks.15.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/15-res-matryoshka-dc
   - id: blocks.16.hook_resid_post
     path: standard/blocks.16.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/16-res-matryoshka-dc
   - id: blocks.17.hook_resid_post
     path: standard/blocks.17.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/17-res-matryoshka-dc
   - id: blocks.18.hook_resid_post
     path: standard/blocks.18.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/18-res-matryoshka-dc
   - id: blocks.19.hook_resid_post
     path: standard/blocks.19.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/19-res-matryoshka-dc
   - id: blocks.20.hook_resid_post
     path: standard/blocks.20.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/20-res-matryoshka-dc
   - id: blocks.21.hook_resid_post
     path: standard/blocks.21.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/21-res-matryoshka-dc
   - id: blocks.22.hook_resid_post
     path: standard/blocks.22.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/22-res-matryoshka-dc
   - id: blocks.23.hook_resid_post
     path: standard/blocks.23.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/23-res-matryoshka-dc
   - id: blocks.24.hook_resid_post
     path: standard/blocks.24.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/24-res-matryoshka-dc
 gemma-2-2b-res-snap-matryoshka-dc:
   conversion_func: null
   links:
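The only change to pretrained_saes.yaml is the removal of the per-SAE neuronpedia field from the matryoshka releases; id, path and l0 are untouched. Tooling that reads the YAML directly should treat the field as optional, as in this sketch (PyYAML assumed):

import yaml

# shape of one saes entry after this change; the neuronpedia key is simply absent
entry = yaml.safe_load(
    """
- id: blocks.13.hook_resid_post
  path: standard/blocks.13.hook_resid_post
  l0: 40.0
"""
)[0]
print(entry.get("neuronpedia"))  # None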
sae_lens/registry.py
ADDED
@@ -0,0 +1,49 @@
+from typing import TYPE_CHECKING, Any
+
+# avoid circular imports
+if TYPE_CHECKING:
+    from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
+
+SAE_CLASS_REGISTRY: dict[str, tuple["type[SAE[Any]]", "type[SAEConfig]"]] = {}
+SAE_TRAINING_CLASS_REGISTRY: dict[
+    str, tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]
+] = {}
+
+
+def register_sae_class(
+    architecture: str,
+    sae_class: "type[SAE[Any]]",
+    sae_config_class: "type[SAEConfig]",
+) -> None:
+    if architecture in SAE_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE class for architecture {architecture} already registered."
+        )
+    SAE_CLASS_REGISTRY[architecture] = (sae_class, sae_config_class)
+
+
+def register_sae_training_class(
+    architecture: str,
+    sae_training_class: "type[TrainingSAE[Any]]",
+    sae_training_config_class: "type[TrainingSAEConfig]",
+) -> None:
+    if architecture in SAE_TRAINING_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE training class for architecture {architecture} already registered."
+        )
+    SAE_TRAINING_CLASS_REGISTRY[architecture] = (
+        sae_training_class,
+        sae_training_config_class,
+    )
+
+
+def get_sae_class(
+    architecture: str,
+) -> tuple["type[SAE[Any]]", "type[SAEConfig]"]:
+    return SAE_CLASS_REGISTRY[architecture]
+
+
+def get_sae_training_class(
+    architecture: str,
+) -> tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]:
+    return SAE_TRAINING_CLASS_REGISTRY[architecture]
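The new registry maps an architecture string to its (SAE class, config class) pair, and likewise for training classes. A hedged usage sketch, assuming the built-in architectures register themselves when sae_lens.saes is imported and use the same architecture names ("standard", "jumprelu", "topk", and presumably "gated") seen elsewhere in this diff:

import sae_lens.saes  # assumed to trigger registration of the built-in architectures
from sae_lens.registry import get_sae_class, get_sae_training_class

sae_cls, sae_cfg_cls = get_sae_class("standard")
train_cls, train_cfg_cls = get_sae_training_class("topk")
print(sae_cls.__name__, train_cls.__name__)  # expected: StandardSAE, TopKTrainingSAE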
sae_lens/sae_training_runner.py
CHANGED
@@ -7,16 +7,18 @@ from typing import Any, cast

 import torch
 import wandb
+from safetensors.torch import save_file
 from simple_parsing import ArgumentParser
 from transformer_lens.hook_points import HookedRootModule

 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
+from sae_lens.constants import RUNNER_CFG_FILENAME, SPARSITY_FILENAME
 from sae_lens.load_model import load_model
+from sae_lens.saes.sae import T_TRAINING_SAE_CONFIG, TrainingSAE, TrainingSAEConfig
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.geometric_median import compute_geometric_median
 from sae_lens.training.sae_trainer import SAETrainer
-from sae_lens.training.training_sae import TrainingSAE, TrainingSAEConfig


 class InterruptedException(Exception):
@@ -32,17 +34,17 @@ class SAETrainingRunner:
     Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
     """

-    cfg: LanguageModelSAERunnerConfig
+    cfg: LanguageModelSAERunnerConfig[Any]
     model: HookedRootModule
-    sae: TrainingSAE
+    sae: TrainingSAE[Any]
     activations_store: ActivationsStore

     def __init__(
         self,
-        cfg: LanguageModelSAERunnerConfig,
+        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
         override_dataset: HfDataset | None = None,
         override_model: HookedRootModule | None = None,
-        override_sae: TrainingSAE | None = None,
+        override_sae: TrainingSAE[Any] | None = None,
     ):
         if override_dataset is not None:
             logger.warning(
@@ -73,14 +75,14 @@ class SAETrainingRunner:

         if override_sae is None:
             if self.cfg.from_pretrained_path is not None:
-                self.sae = TrainingSAE.
+                self.sae = TrainingSAE.load_from_disk(
                     self.cfg.from_pretrained_path, self.cfg.device
                 )
             else:
-                self.sae = TrainingSAE(
+                self.sae = TrainingSAE.from_dict(
                     TrainingSAEConfig.from_dict(
                         self.cfg.get_training_sae_cfg_dict(),
-                    )
+                    ).to_dict()
                 )
                 self._init_sae_group_b_decs()
         else:
@@ -91,13 +93,13 @@ class SAETrainingRunner:
         Run the training of the SAE.
         """

-        if self.cfg.log_to_wandb:
+        if self.cfg.logger.log_to_wandb:
             wandb.init(
-                project=self.cfg.wandb_project,
-                entity=self.cfg.wandb_entity,
+                project=self.cfg.logger.wandb_project,
+                entity=self.cfg.logger.wandb_entity,
                 config=cast(Any, self.cfg),
-                name=self.cfg.run_name,
-                id=self.cfg.wandb_id,
+                name=self.cfg.logger.run_name,
+                id=self.cfg.logger.wandb_id,
             )

         trainer = SAETrainer(
@@ -111,7 +113,7 @@ class SAETrainingRunner:
         self._compile_if_needed()
         sae = self.run_trainer_with_interruption_handling(trainer)

-        if self.cfg.log_to_wandb:
+        if self.cfg.logger.log_to_wandb:
             wandb.finish()

         return sae
@@ -141,7 +143,9 @@ class SAETrainingRunner:
             backend=backend,
         )  # type: ignore

-    def run_trainer_with_interruption_handling(
+    def run_trainer_with_interruption_handling(
+        self, trainer: SAETrainer[TrainingSAE[TrainingSAEConfig], TrainingSAEConfig]
+    ):
         try:
             # signal handlers (if preempted)
             signal.signal(signal.SIGINT, interrupt_callback)
@@ -167,7 +171,7 @@ class SAETrainingRunner:
         extract all activations at a certain layer and use for sae b_dec initialization
         """

-        if self.cfg.b_dec_init_method == "geometric_median":
+        if self.cfg.sae.b_dec_init_method == "geometric_median":
             self.activations_store.set_norm_scaling_factor_if_needed()
             layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
             # get geometric median of the activations if we're using those.
@@ -175,15 +179,15 @@ class SAETrainingRunner:
                 layer_acts,
                 maxiter=100,
             ).median
-            self.sae.initialize_b_dec_with_precalculated(median)
-        elif self.cfg.b_dec_init_method == "mean":
+            self.sae.initialize_b_dec_with_precalculated(median)
+        elif self.cfg.sae.b_dec_init_method == "mean":
             self.activations_store.set_norm_scaling_factor_if_needed()
             layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
             self.sae.initialize_b_dec_with_mean(layer_acts)  # type: ignore

     @staticmethod
     def save_checkpoint(
-        trainer: SAETrainer,
+        trainer: SAETrainer[TrainingSAE[Any], Any],
         checkpoint_name: str,
         wandb_aliases: list[str] | None = None,
     ) -> None:
@@ -194,46 +198,28 @@ class SAETrainingRunner:
             str(base_path / "activations_store_state.safetensors")
         )

-
-        trainer.sae.set_decoder_norm_to_unit_norm()
+        weights_path, cfg_path = trainer.sae.save_model(str(base_path))

-
-
-            trainer.log_feature_sparsity,
-        )
+        sparsity_path = base_path / SPARSITY_FILENAME
+        save_file({"sparsity": trainer.log_feature_sparsity}, sparsity_path)

-
-
-
-
-
-
-
-
-
-
-
-            # save model weights and cfg
-            model_artifact = wandb.Artifact(
-                sae_name,
-                type="model",
-                metadata=dict(trainer.cfg.__dict__),
-            )
-            model_artifact.add_file(str(weights_path))
-            model_artifact.add_file(str(cfg_path))
-            wandb.log_artifact(model_artifact, aliases=wandb_aliases)
-
-            # save log feature sparsity
-            sparsity_artifact = wandb.Artifact(
-                f"{sae_name}_log_feature_sparsity",
-                type="log_feature_sparsity",
-                metadata=dict(trainer.cfg.__dict__),
+        runner_config = trainer.cfg.to_dict()
+        with open(base_path / RUNNER_CFG_FILENAME, "w") as f:
+            json.dump(runner_config, f)
+
+        if trainer.cfg.logger.log_to_wandb:
+            trainer.cfg.logger.log(
+                trainer,
+                weights_path,
+                cfg_path,
+                sparsity_path=sparsity_path,
+                wandb_aliases=wandb_aliases,
             )
-            sparsity_artifact.add_file(str(sparsity_path))
-            wandb.log_artifact(sparsity_artifact)


-def _parse_cfg_args(
+def _parse_cfg_args(
+    args: Sequence[str],
+) -> LanguageModelSAERunnerConfig[TrainingSAEConfig]:
     if len(args) == 0:
         args = ["--help"]
     parser = ArgumentParser(exit_on_error=False)
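save_checkpoint now writes the SAE via trainer.sae.save_model, the feature sparsity as a safetensors file, and the full runner config as JSON, instead of hand-rolling wandb artifacts (artifact logging moves behind cfg.logger.log). A hedged sketch of reading those files back from a checkpoint directory (the directory path is illustrative; the file names come from sae_lens.constants):

import json

from safetensors.torch import load_file

from sae_lens.constants import RUNNER_CFG_FILENAME, SPARSITY_FILENAME

checkpoint_dir = "checkpoints/final_123"  # illustrative path
sparsity = load_file(f"{checkpoint_dir}/{SPARSITY_FILENAME}")["sparsity"]
with open(f"{checkpoint_dir}/{RUNNER_CFG_FILENAME}") as f:
    runner_cfg = json.load(f)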
sae_lens/saes/__init__.py
ADDED

@@ -0,0 +1,48 @@
+from .gated_sae import (
+    GatedSAE,
+    GatedSAEConfig,
+    GatedTrainingSAE,
+    GatedTrainingSAEConfig,
+)
+from .jumprelu_sae import (
+    JumpReLUSAE,
+    JumpReLUSAEConfig,
+    JumpReLUTrainingSAE,
+    JumpReLUTrainingSAEConfig,
+)
+from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
+from .standard_sae import (
+    StandardSAE,
+    StandardSAEConfig,
+    StandardTrainingSAE,
+    StandardTrainingSAEConfig,
+)
+from .topk_sae import (
+    TopKSAE,
+    TopKSAEConfig,
+    TopKTrainingSAE,
+    TopKTrainingSAEConfig,
+)
+
+__all__ = [
+    "SAE",
+    "SAEConfig",
+    "TrainingSAE",
+    "TrainingSAEConfig",
+    "StandardSAE",
+    "StandardSAEConfig",
+    "StandardTrainingSAE",
+    "StandardTrainingSAEConfig",
+    "GatedSAE",
+    "GatedSAEConfig",
+    "GatedTrainingSAE",
+    "GatedTrainingSAEConfig",
+    "JumpReLUSAE",
+    "JumpReLUSAEConfig",
+    "JumpReLUTrainingSAE",
+    "JumpReLUTrainingSAEConfig",
+    "TopKSAE",
+    "TopKSAEConfig",
+    "TopKTrainingSAE",
+    "TopKTrainingSAEConfig",
+]
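sae_lens.saes now exposes a matched inference/training class pair per architecture (Standard, Gated, JumpReLU, TopK) alongside the SAE and TrainingSAE base classes, replacing the old monolithic sae_lens/sae.py and training_sae.py. A small hedged sketch of the new import surface; that the concrete classes subclass the generic bases is an assumption, though it is implied by the registry's type[SAE[Any]] / type[TrainingSAE[Any]] signatures:

from sae_lens.saes import SAE, StandardSAE, StandardTrainingSAE, TrainingSAE

# assumed subclass relationships between the per-architecture classes and the bases
assert issubclass(StandardSAE, SAE)
assert issubclass(StandardTrainingSAE, TrainingSAE)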