sae-lens 6.0.0rc3__tar.gz → 6.0.0rc5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/PKG-INFO +2 -2
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/README.md +1 -1
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/pyproject.toml +2 -1
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/__init__.py +1 -1
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/analysis/neuronpedia_integration.py +3 -3
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/config.py +5 -3
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/constants.py +1 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/evals.py +20 -20
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/llm_sae_training_runner.py +113 -5
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/loading/pretrained_sae_loaders.py +178 -7
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/pretrained_saes.yaml +12 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/saes/gated_sae.py +0 -4
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/saes/jumprelu_sae.py +4 -10
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/saes/sae.py +179 -48
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/saes/standard_sae.py +4 -11
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/saes/topk_sae.py +18 -12
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/training/activation_scaler.py +1 -1
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/training/activations_store.py +1 -3
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/training/sae_trainer.py +11 -3
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/training/upload_saes_to_huggingface.py +1 -1
- sae_lens-6.0.0rc3/sae_lens/training/geometric_median.py +0 -101
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/LICENSE +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/load_model.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/registry.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/training/types.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/util.py +0 -0
{sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: sae-lens
-Version: 6.0.0rc3
+Version: 6.0.0rc5
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -80,7 +80,7 @@ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page
 
 ## Join the Slack!
 
-Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-
+Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
 
 ## Citation
 
{sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/README.md

@@ -40,7 +40,7 @@ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page
 
 ## Join the Slack!
 
-Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-
+Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
 
 ## Citation
 
{sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.0.0-rc.3"
+version = "6.0.0-rc.5"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
@@ -60,6 +60,7 @@ tabulate = "^0.9.0"
 ruff = "^0.7.4"
 eai-sparsify = "^1.1.1"
 mike = "^2.0.0"
+trio = "^0.30.0"
 
 [tool.poetry.extras]
 mamba = ["mamba-lens"]
{sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/analysis/neuronpedia_integration.py

@@ -59,7 +59,7 @@ def NanAndInfReplacer(value: str):
 
 
 def open_neuronpedia_feature_dashboard(sae: SAE[Any], index: int):
-    sae_id = sae.cfg.neuronpedia_id
+    sae_id = sae.cfg.metadata.neuronpedia_id
     if sae_id is None:
         logger.warning(
             "SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
@@ -74,7 +74,7 @@ def get_neuronpedia_quick_list(
     features: list[int],
     name: str = "temporary_list",
 ):
-    sae_id = sae.cfg.neuronpedia_id
+    sae_id = sae.cfg.metadata.neuronpedia_id
     if sae_id is None:
         logger.warning(
             "SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
@@ -86,7 +86,7 @@ get_neuronpedia_quick_list(
     url = url + "?name=" + name
     list_feature = [
         {
-            "modelId": sae.cfg.model_name,
+            "modelId": sae.cfg.metadata.model_name,
            "layer": sae_id.split("/")[1],
            "index": str(feature),
        }
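In 6.0 the LLM-specific fields referenced here (`neuronpedia_id`, `model_name`, and friends) live on a nested `metadata` object rather than directly on `sae.cfg`. A minimal sketch of the new access pattern, assuming `from_pretrained` returns the SAE object directly in 6.0 and using a release name that already exists in `pretrained_saes.yaml`:

```python
from sae_lens import SAE

# "gpt2-small-res-jb" is an existing release in pretrained_saes.yaml; any
# release/sae_id pair listed there should work the same way.
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cpu",
)

# Pre-6.0 code read sae.cfg.model_name / sae.cfg.neuronpedia_id; these now
# live on the nested metadata object.
print(sae.cfg.metadata.model_name)      # e.g. "gpt2"
print(sae.cfg.metadata.hook_name)       # e.g. "blocks.8.hook_resid_pre"
print(sae.cfg.metadata.neuronpedia_id)  # None if no dashboards exist yet
```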
{sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/config.py

@@ -201,7 +201,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     train_batch_size_tokens: int = 4096
 
     ## Adam
-    adam_beta1: float = 0.
+    adam_beta1: float = 0.9
     adam_beta2: float = 0.999
 
     ## Learning Rate Schedule
@@ -390,7 +390,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
             adam_beta2=self.adam_beta2,
             lr_decay_steps=self.lr_decay_steps,
             n_restart_cycles=self.n_restart_cycles,
-            total_training_steps=self.total_training_steps,
             train_batch_size_samples=self.train_batch_size_tokens,
             dead_feature_window=self.dead_feature_window,
             feature_sampling_window=self.feature_sampling_window,
@@ -613,8 +612,11 @@ class SAETrainerConfig:
     adam_beta2: float
     lr_decay_steps: int
     n_restart_cycles: int
-    total_training_steps: int
     train_batch_size_samples: int
     dead_feature_window: int
     feature_sampling_window: int
     logger: LoggingConfig
+
+    @property
+    def total_training_steps(self) -> int:
+        return self.total_training_samples // self.train_batch_size_samples
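`SAETrainerConfig` now derives `total_training_steps` from the sample budget instead of storing it as a separate field. A quick sketch of the arithmetic (the numbers are illustrative, not package defaults):

```python
# Illustrative only: total_training_steps is now a derived property.
total_training_samples = 2_000_000   # hypothetical token budget
train_batch_size_samples = 4_096     # mirrors train_batch_size_tokens above

total_training_steps = total_training_samples // train_batch_size_samples
print(total_training_steps)  # 488
```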
{sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/constants.py

@@ -16,5 +16,6 @@ SPARSITY_FILENAME = "sparsity.safetensors"
 SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
 SAE_CFG_FILENAME = "cfg.json"
 RUNNER_CFG_FILENAME = "runner_cfg.json"
+SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
 ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
 ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
{sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/evals.py

@@ -4,6 +4,7 @@ import json
 import math
 import re
 import subprocess
+import sys
 from collections import defaultdict
 from collections.abc import Mapping
 from dataclasses import dataclass, field
@@ -15,7 +16,7 @@ from typing import Any
 import einops
 import pandas as pd
 import torch
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from transformer_lens import HookedTransformer
 from transformer_lens.hook_points import HookedRootModule
 
@@ -768,17 +769,6 @@ def nested_dict() -> defaultdict[Any, Any]:
     return defaultdict(nested_dict)
 
 
-def dict_to_nested(flat_dict: dict[str, Any]) -> defaultdict[Any, Any]:
-    nested = nested_dict()
-    for key, value in flat_dict.items():
-        parts = key.split("/")
-        d = nested
-        for part in parts[:-1]:
-            d = d[part]
-        d[parts[-1]] = value
-    return nested
-
-
 def multiple_evals(
     sae_regex_pattern: str,
     sae_block_pattern: str,
@@ -814,16 +804,18 @@ def multiple_evals(
             release=sae_release_name,  # see other options in sae_lens/pretrained_saes.yaml
             sae_id=sae_id,  # won't always be a hook point
             device=device,
-        )
+        )
 
         # move SAE to device if not there already
         sae.to(device)
 
-        if current_model_str != sae.cfg.model_name:
+        if current_model_str != sae.cfg.metadata.model_name:
             del current_model  # potentially saves GPU memory
-            current_model_str = sae.cfg.model_name
+            current_model_str = sae.cfg.metadata.model_name
             current_model = HookedTransformer.from_pretrained_no_processing(
-                current_model_str,
+                current_model_str,
+                device=device,
+                **sae.cfg.metadata.model_from_pretrained_kwargs,
             )
             assert current_model is not None
 
@@ -941,7 +933,7 @@ def process_results(
     }
 
 
-
+def process_args(args: list[str]) -> argparse.Namespace:
     arg_parser = argparse.ArgumentParser(description="Run evaluations on SAEs")
     arg_parser.add_argument(
         "sae_regex_pattern",
@@ -1031,11 +1023,19 @@ if __name__ == "__main__":
         help="Enable verbose output with tqdm loaders.",
     )
 
-
-
-
+    return arg_parser.parse_args(args)
+
+
+def run_evals_cli(args: list[str]) -> None:
+    opts = process_args(args)
+    eval_results = run_evaluations(opts)
+    output_files = process_results(eval_results, opts.output_dir)
 
     print("Evaluation complete. Output files:")
     print(f"Individual JSONs: {len(output_files['individual_jsons'])}")  # type: ignore
     print(f"Combined JSON: {output_files['combined_json']}")
     print(f"CSV: {output_files['csv']}")
+
+
+if __name__ == "__main__":
+    run_evals_cli(sys.argv[1:])
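`evals.py` now splits argument handling into `process_args` and a `run_evals_cli` entry point, so the evaluation CLI can also be driven programmatically. A hedged sketch: only `sae_regex_pattern` is confirmed by this diff as a positional argument, and the second positional plus the `--output_dir` flag are assumptions inferred from the attributes used above:

```python
from sae_lens.evals import run_evals_cli

# Equivalent to running the module as a script. Argument names other than
# "sae_regex_pattern" are assumptions, not confirmed flag spellings.
run_evals_cli(
    [
        "gpt2-small-res-jb",            # sae_regex_pattern (positional)
        "blocks.*.hook_resid_pre",      # assumed second positional: sae_block_pattern
        "--output_dir", "eval_output",  # assumed flag backing opts.output_dir
    ]
)
```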
{sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/llm_sae_training_runner.py

@@ -4,7 +4,7 @@ import sys
 from collections.abc import Sequence
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Generic
+from typing import Any, Generic
 
 import torch
 import wandb
@@ -17,12 +17,16 @@ from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
 from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
 from sae_lens.evals import EvalConfig, run_evals
 from sae_lens.load_model import load_model
+from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
+from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
 from sae_lens.saes.sae import (
     T_TRAINING_SAE,
     T_TRAINING_SAE_CONFIG,
     TrainingSAE,
     TrainingSAEConfig,
 )
+from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
+from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
 from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.sae_trainer import SAETrainer
@@ -145,17 +149,18 @@ class LanguageModelSAETrainingRunner:
             )
         else:
             self.sae = override_sae
+        self.sae.to(self.cfg.device)
 
     def run(self):
         """
         Run the training of the SAE.
         """
-
+        self._set_sae_metadata()
         if self.cfg.logger.log_to_wandb:
             wandb.init(
                 project=self.cfg.logger.wandb_project,
                 entity=self.cfg.logger.wandb_entity,
-                config=
+                config=self.cfg.to_dict(),
                 name=self.cfg.logger.run_name,
                 id=self.cfg.logger.wandb_id,
             )
@@ -184,6 +189,20 @@ class LanguageModelSAETrainingRunner:
 
         return sae
 
+    def _set_sae_metadata(self):
+        self.sae.cfg.metadata.dataset_path = self.cfg.dataset_path
+        self.sae.cfg.metadata.hook_name = self.cfg.hook_name
+        self.sae.cfg.metadata.model_name = self.cfg.model_name
+        self.sae.cfg.metadata.model_class_name = self.cfg.model_class_name
+        self.sae.cfg.metadata.hook_head_index = self.cfg.hook_head_index
+        self.sae.cfg.metadata.context_size = self.cfg.context_size
+        self.sae.cfg.metadata.seqpos_slice = self.cfg.seqpos_slice
+        self.sae.cfg.metadata.model_from_pretrained_kwargs = (
+            self.cfg.model_from_pretrained_kwargs
+        )
+        self.sae.cfg.metadata.prepend_bos = self.cfg.prepend_bos
+        self.sae.cfg.metadata.exclude_special_tokens = self.cfg.exclude_special_tokens
+
     def _compile_if_needed(self):
         # Compile model and SAE
         # torch.compile can provide significant speedups (10-20% in testing)
@@ -247,11 +266,100 @@ class LanguageModelSAETrainingRunner:
 def _parse_cfg_args(
     args: Sequence[str],
 ) -> LanguageModelSAERunnerConfig[TrainingSAEConfig]:
+    """
+    Parse command line arguments into a LanguageModelSAERunnerConfig.
+
+    This function first parses the architecture argument to determine which
+    concrete SAE config class to use, then parses the full configuration
+    with that concrete type.
+    """
     if len(args) == 0:
         args = ["--help"]
+
+    # First, parse only the architecture to determine which concrete class to use
+    architecture_parser = ArgumentParser(
+        description="Parse architecture to determine SAE config class",
+        exit_on_error=False,
+        add_help=False,  # Don't add help to avoid conflicts
+    )
+    architecture_parser.add_argument(
+        "--architecture",
+        type=str,
+        choices=["standard", "gated", "jumprelu", "topk"],
+        default="standard",
+        help="SAE architecture to use",
+    )
+
+    # Parse known args to extract architecture, ignore unknown args for now
+    arch_args, remaining_args = architecture_parser.parse_known_args(args)
+    architecture = arch_args.architecture
+
+    # Remove architecture from remaining args if it exists
+    filtered_args = []
+    skip_next = False
+    for arg in remaining_args:
+        if skip_next:
+            skip_next = False
+            continue
+        if arg == "--architecture":
+            skip_next = True  # Skip the next argument (the architecture value)
+            continue
+        filtered_args.append(arg)
+
+    # Create a custom wrapper class that simple_parsing can handle
+    def create_config_class(
+        sae_config_type: type[TrainingSAEConfig],
+    ) -> type[LanguageModelSAERunnerConfig[TrainingSAEConfig]]:
+        """Create a concrete config class for the given SAE config type."""
+
+        # Create the base config without the sae field
+        from dataclasses import field as dataclass_field
+        from dataclasses import fields, make_dataclass
+
+        # Get all fields from LanguageModelSAERunnerConfig except the generic sae field
+        base_fields = []
+        for field_obj in fields(LanguageModelSAERunnerConfig):
+            if field_obj.name != "sae":
+                base_fields.append((field_obj.name, field_obj.type, field_obj))
+
+        # Add the concrete sae field
+        base_fields.append(
+            (
+                "sae",
+                sae_config_type,
+                dataclass_field(
+                    default_factory=lambda: sae_config_type(d_in=512, d_sae=1024)
+                ),
+            )
+        )
+
+        # Create the concrete class
+        return make_dataclass(
+            f"{sae_config_type.__name__}RunnerConfig",
+            base_fields,
+            bases=(LanguageModelSAERunnerConfig,),
+        )
+
+    # Map architecture to concrete config class
+    sae_config_map = {
+        "standard": StandardTrainingSAEConfig,
+        "gated": GatedTrainingSAEConfig,
+        "jumprelu": JumpReLUTrainingSAEConfig,
+        "topk": TopKTrainingSAEConfig,
+    }
+
+    sae_config_type = sae_config_map[architecture]
+    concrete_config_class = create_config_class(sae_config_type)
+
+    # Now parse the full configuration with the concrete type
     parser = ArgumentParser(exit_on_error=False)
-    parser.add_arguments(
-
+    parser.add_arguments(concrete_config_class, dest="cfg")
+
+    # Parse the filtered arguments (without --architecture)
+    parsed_args = parser.parse_args(filtered_args)
+
+    # Return the parsed configuration
+    return parsed_args.cfg
 
 
 # moved into its own function to make it easier to test
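The runner's CLI now reads an `--architecture` flag first, picks the matching concrete `TrainingSAEConfig` subclass, and only then parses the full configuration with simple_parsing. A hedged sketch of exercising the parser directly; the `--model_name`/`--dataset_path` spellings mirror the config fields above, and the nested `--sae.<field>` spelling is an assumption about how simple_parsing exposes the nested dataclass, not something confirmed by this diff:

```python
from sae_lens.llm_sae_training_runner import _parse_cfg_args

# Assumed flag names for illustration; dataset path is a placeholder.
cfg = _parse_cfg_args(
    [
        "--architecture", "topk",
        "--model_name", "gpt2",
        "--dataset_path", "apollo-research/roneneldan-TinyStories-tokenizer-gpt2",
        "--sae.d_in", "768",
        "--sae.d_sae", "16384",
        "--sae.k", "32",
    ]
)
print(type(cfg.sae).__name__)  # expected: TopKTrainingSAEConfig
```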
{sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/loading/pretrained_sae_loaders.py

@@ -16,6 +16,7 @@ from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
+    SPARSIFY_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
 )
 from sae_lens.loading.pretrained_saes_directory import (
@@ -26,6 +27,22 @@ from sae_lens.loading.pretrained_saes_directory import (
 from sae_lens.registry import get_sae_class
 from sae_lens.util import filter_valid_dataclass_fields
 
+LLM_METADATA_KEYS = {
+    "model_name",
+    "hook_name",
+    "model_class_name",
+    "hook_head_index",
+    "model_from_pretrained_kwargs",
+    "prepend_bos",
+    "exclude_special_tokens",
+    "neuronpedia_id",
+    "context_size",
+    "seqpos_slice",
+    "dataset_path",
+    "sae_lens_version",
+    "sae_lens_training_version",
+}
+
 
 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
 class PretrainedSaeHuggingfaceLoader(Protocol):
@@ -207,6 +224,10 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
     new_cfg.setdefault("activation_fn", new_cfg.get("activation_fn", "relu"))
     new_cfg.setdefault("architecture", "standard")
     new_cfg.setdefault("neuronpedia_id", None)
+    new_cfg.setdefault(
+        "reshape_activations",
+        "hook_z" if "hook_z" in new_cfg.get("hook_name", "") else "none",
+    )
 
     if "normalize_activations" in new_cfg and isinstance(
         new_cfg["normalize_activations"], bool
@@ -228,14 +249,12 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
     config_class = get_sae_class(architecture)[1]
 
     sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
-    if architecture == "topk":
+    if architecture == "topk" and "activation_fn_kwargs" in new_cfg:
         sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
 
-
-
-
-    meta_dict = filter_valid_dataclass_fields(new_cfg, SAEMetadata)
-    sae_cfg_dict["metadata"] = meta_dict
+    sae_cfg_dict["metadata"] = {
+        k: v for k, v in new_cfg.items() if k in LLM_METADATA_KEYS
+    }
     sae_cfg_dict["architecture"] = architecture
     return sae_cfg_dict
 
@@ -271,6 +290,7 @@ def get_connor_rob_hook_z_config_from_hf(
         "context_size": 128,
         "normalize_activations": "none",
         "dataset_trust_remote_code": True,
+        "reshape_activations": "hook_z",
         **(cfg_overrides or {}),
     }
 
@@ -511,11 +531,20 @@ def get_llama_scope_config_from_hf(
     # Model specific parameters
     model_name, d_in = "meta-llama/Llama-3.1-8B", old_cfg_dict["d_model"]
 
+    # Get norm scaling factor to rescale jumprelu threshold.
+    # We need this because sae.fold_activation_norm_scaling_factor folds scaling norm into W_enc.
+    # This requires jumprelu threshold to be scaled in the same way
+    norm_scaling_factor = (
+        d_in**0.5 / old_cfg_dict["dataset_average_activation_norm"]["in"]
+    )
+
     cfg_dict = {
         "architecture": "jumprelu",
-        "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"]
+        "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"]
+        * norm_scaling_factor,
         # We use a scalar jump_relu_threshold for all features
         # This is different from Gemma Scope JumpReLU SAEs.
+        # Scaled with norm_scaling_factor to match sae.fold_activation_norm_scaling_factor
         "d_in": d_in,
         "d_sae": old_cfg_dict["d_sae"],
         "dtype": "bfloat16",
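The Llama Scope loader now rescales the scalar JumpReLU threshold by the same factor that later gets folded into W_enc, so the threshold stays comparable to the rescaled pre-activations. A minimal toy sketch of the invariance the comments above describe (toy numbers, not the library's actual fold implementation):

```python
import torch

torch.manual_seed(0)
d_in = 16
x = torch.randn(d_in)         # toy raw activation vector
W_enc = torch.randn(d_in, 4)  # toy encoder weights
threshold = 0.5               # toy JumpReLU threshold
scale = d_in**0.5 / 3.2       # stand-in for sqrt(d_in) / dataset-average activation norm

# Folding `scale` into W_enc multiplies pre-activations by `scale` (scale > 0),
# so scaling the threshold identically leaves the JumpReLU gate unchanged.
gate_before = (x @ W_enc) > threshold
gate_after = (x @ (W_enc * scale)) > (threshold * scale)
assert torch.equal(gate_before, gate_after)
```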
@@ -923,6 +952,146 @@ def llama_scope_r1_distill_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity
 
 
+def get_sparsify_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    cfg_filename = f"{folder_name}/{SAE_CFG_FILENAME}"
+    cfg_path = hf_hub_download(
+        repo_id,
+        filename=cfg_filename,
+        force_download=force_download,
+    )
+    sae_path = Path(cfg_path).parent
+    return get_sparsify_config_from_disk(
+        sae_path, device=device, cfg_overrides=cfg_overrides
+    )
+
+
+def get_sparsify_config_from_disk(
+    path: str | Path,
+    device: str | None = None,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    path = Path(path)
+
+    with open(path / SAE_CFG_FILENAME) as f:
+        old_cfg_dict = json.load(f)
+
+    config_path = path.parent / "config.json"
+    if config_path.exists():
+        with open(config_path) as f:
+            config_dict = json.load(f)
+    else:
+        config_dict = {}
+
+    folder_name = path.name
+    if folder_name == "embed_tokens":
+        hook_name, layer = "hook_embed", 0
+    else:
+        match = re.search(r"layers[._](\d+)", folder_name)
+        if match is None:
+            raise ValueError(f"Unrecognized Sparsify folder: {folder_name}")
+        layer = int(match.group(1))
+        hook_name = f"blocks.{layer}.hook_resid_post"
+
+    cfg_dict: dict[str, Any] = {
+        "architecture": "standard",
+        "d_in": old_cfg_dict["d_in"],
+        "d_sae": old_cfg_dict["d_in"] * old_cfg_dict["expansion_factor"],
+        "dtype": "bfloat16",
+        "device": device or "cpu",
+        "model_name": config_dict.get("model", path.parts[-2]),
+        "hook_name": hook_name,
+        "hook_layer": layer,
+        "hook_head_index": None,
+        "activation_fn_str": "topk",
+        "activation_fn_kwargs": {
+            "k": old_cfg_dict["k"],
+            "signed": old_cfg_dict.get("signed", False),
+        },
+        "apply_b_dec_to_input": not old_cfg_dict.get("normalize_decoder", False),
+        "dataset_path": config_dict.get(
+            "dataset", "togethercomputer/RedPajama-Data-1T-Sample"
+        ),
+        "context_size": config_dict.get("ctx_len", 2048),
+        "finetuning_scaling_factor": False,
+        "sae_lens_training_version": None,
+        "prepend_bos": True,
+        "dataset_trust_remote_code": True,
+        "normalize_activations": "none",
+        "neuronpedia_id": None,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def sparsify_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], None]:
+    weights_filename = f"{folder_name}/{SPARSIFY_WEIGHTS_FILENAME}"
+    sae_path = hf_hub_download(
+        repo_id,
+        filename=weights_filename,
+        force_download=force_download,
+    )
+    cfg_dict, state_dict = sparsify_disk_loader(
+        Path(sae_path).parent, device=device, cfg_overrides=cfg_overrides
+    )
+    return cfg_dict, state_dict, None
+
+
+def sparsify_disk_loader(
+    path: str | Path,
+    device: str = "cpu",
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
+    cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)
+
+    weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
+    state_dict_loaded = load_file(weight_path, device=device)
+
+    dtype = DTYPE_MAP[cfg_dict["dtype"]]
+
+    W_enc = (
+        state_dict_loaded["W_enc"]
+        if "W_enc" in state_dict_loaded
+        else state_dict_loaded["encoder.weight"].T
+    ).to(dtype)
+
+    if "W_dec" in state_dict_loaded:
+        W_dec = state_dict_loaded["W_dec"].T.to(dtype)
+    else:
+        W_dec = state_dict_loaded["decoder.weight"].T.to(dtype)
+
+    if "b_enc" in state_dict_loaded:
+        b_enc = state_dict_loaded["b_enc"].to(dtype)
+    elif "encoder.bias" in state_dict_loaded:
+        b_enc = state_dict_loaded["encoder.bias"].to(dtype)
+    else:
+        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+
+    if "b_dec" in state_dict_loaded:
+        b_dec = state_dict_loaded["b_dec"].to(dtype)
+    elif "decoder.bias" in state_dict_loaded:
+        b_dec = state_dict_loaded["decoder.bias"].to(dtype)
+    else:
+        b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+
+    state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
+    return cfg_dict, state_dict
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -931,6 +1100,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
     "deepseek_r1": deepseek_r1_sae_huggingface_loader,
+    "sparsify": sparsify_huggingface_loader,
 }
 
 
@@ -942,4 +1112,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
     "deepseek_r1": get_deepseek_r1_config_from_hf,
+    "sparsify": get_sparsify_config_from_hf,
 }
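The new "sparsify" loader maps EleutherAI sparsify checkpoints (a `cfg.json` plus `sae.safetensors` per hook folder) onto SAE Lens's standard `W_enc`/`W_dec`/`b_enc`/`b_dec` layout. A hedged sketch of calling it directly on a local checkpoint; the directory layout shown is an assumption based on the folder-name parsing above, and the path is hypothetical:

```python
from pathlib import Path

from sae_lens.loading.pretrained_sae_loaders import sparsify_disk_loader

# Assumed layout: <run_dir>/config.json and <run_dir>/layers.6/sae.safetensors,
# mirroring the "layers[._](\d+)" folder parsing in get_sparsify_config_from_disk.
sae_dir = Path("checkpoints/my-sparsify-run/layers.6")  # hypothetical path

cfg_dict, state_dict = sparsify_disk_loader(sae_dir, device="cpu")
print(cfg_dict["hook_name"])      # blocks.6.hook_resid_post
print(state_dict["W_enc"].shape)  # (d_in, d_sae)
```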
{sae_lens-6.0.0rc3 → sae_lens-6.0.0rc5}/sae_lens/pretrained_saes.yaml

@@ -13634,39 +13634,51 @@ gemma-2-2b-res-matryoshka-dc:
   - id: blocks.13.hook_resid_post
     path: standard/blocks.13.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/13-res-matryoshka-dc
   - id: blocks.14.hook_resid_post
     path: standard/blocks.14.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/14-res-matryoshka-dc
   - id: blocks.15.hook_resid_post
     path: standard/blocks.15.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/15-res-matryoshka-dc
   - id: blocks.16.hook_resid_post
     path: standard/blocks.16.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/16-res-matryoshka-dc
   - id: blocks.17.hook_resid_post
     path: standard/blocks.17.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/17-res-matryoshka-dc
   - id: blocks.18.hook_resid_post
     path: standard/blocks.18.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/18-res-matryoshka-dc
   - id: blocks.19.hook_resid_post
     path: standard/blocks.19.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/19-res-matryoshka-dc
   - id: blocks.20.hook_resid_post
     path: standard/blocks.20.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/20-res-matryoshka-dc
   - id: blocks.21.hook_resid_post
     path: standard/blocks.21.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/21-res-matryoshka-dc
   - id: blocks.22.hook_resid_post
     path: standard/blocks.22.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/22-res-matryoshka-dc
   - id: blocks.23.hook_resid_post
     path: standard/blocks.23.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/23-res-matryoshka-dc
   - id: blocks.24.hook_resid_post
     path: standard/blocks.24.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/24-res-matryoshka-dc
 gemma-2-2b-res-snap-matryoshka-dc:
   conversion_func: null
   links: