sae-lens 6.0.0rc3__tar.gz → 6.0.0rc4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/PKG-INFO +1 -1
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/pyproject.toml +2 -1
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/__init__.py +1 -1
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/analysis/neuronpedia_integration.py +3 -3
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/config.py +5 -3
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/evals.py +20 -9
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/llm_sae_training_runner.py +113 -5
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/loading/pretrained_sae_loaders.py +24 -5
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/saes/gated_sae.py +0 -4
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/saes/jumprelu_sae.py +4 -10
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/saes/sae.py +121 -48
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/saes/standard_sae.py +4 -11
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/saes/topk_sae.py +18 -12
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/activation_scaler.py +1 -1
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/activations_store.py +0 -2
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/sae_trainer.py +11 -3
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/upload_saes_to_huggingface.py +1 -1
- sae_lens-6.0.0rc3/sae_lens/training/geometric_median.py +0 -101
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/LICENSE +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/README.md +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/constants.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/load_model.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/registry.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/training/types.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.0.0rc3 → sae_lens-6.0.0rc4}/sae_lens/util.py +0 -0
pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.0.0-rc.3"
+version = "6.0.0-rc.4"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
@@ -60,6 +60,7 @@ tabulate = "^0.9.0"
 ruff = "^0.7.4"
 eai-sparsify = "^1.1.1"
 mike = "^2.0.0"
+trio = "^0.30.0"

 [tool.poetry.extras]
 mamba = ["mamba-lens"]
sae_lens/analysis/neuronpedia_integration.py
@@ -59,7 +59,7 @@ def NanAndInfReplacer(value: str):


 def open_neuronpedia_feature_dashboard(sae: SAE[Any], index: int):
-    sae_id = sae.cfg.neuronpedia_id
+    sae_id = sae.cfg.metadata.neuronpedia_id
     if sae_id is None:
         logger.warning(
             "SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
@@ -74,7 +74,7 @@ def get_neuronpedia_quick_list(
     features: list[int],
     name: str = "temporary_list",
 ):
-    sae_id = sae.cfg.neuronpedia_id
+    sae_id = sae.cfg.metadata.neuronpedia_id
     if sae_id is None:
         logger.warning(
             "SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
@@ -86,7 +86,7 @@ def get_neuronpedia_quick_list(
     url = url + "?name=" + name
     list_feature = [
         {
-            "modelId": sae.cfg.model_name,
+            "modelId": sae.cfg.metadata.model_name,
            "layer": sae_id.split("/")[1],
            "index": str(feature),
        }
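Illustrative note (not part of the diff): after this change, LLM-related fields such as `model_name` and `neuronpedia_id` are read from `sae.cfg.metadata` rather than directly from `sae.cfg`. A minimal sketch, assuming the release and id strings are placeholders chosen only for illustration:

```python
from sae_lens import SAE

# Hypothetical release / SAE id, used only to illustrate the attribute path.
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

# 6.0.0rc4 style: LLM-specific fields live on cfg.metadata.
model_name = sae.cfg.metadata.model_name
neuronpedia_id = sae.cfg.metadata.neuronpedia_id  # may be None if not set
```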
sae_lens/config.py
@@ -201,7 +201,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     train_batch_size_tokens: int = 4096

     ## Adam
-    adam_beta1: float = 0.
+    adam_beta1: float = 0.9
     adam_beta2: float = 0.999

     ## Learning Rate Schedule
@@ -390,7 +390,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
             adam_beta2=self.adam_beta2,
             lr_decay_steps=self.lr_decay_steps,
             n_restart_cycles=self.n_restart_cycles,
-            total_training_steps=self.total_training_steps,
             train_batch_size_samples=self.train_batch_size_tokens,
             dead_feature_window=self.dead_feature_window,
             feature_sampling_window=self.feature_sampling_window,
@@ -613,8 +612,11 @@ class SAETrainerConfig:
     adam_beta2: float
     lr_decay_steps: int
     n_restart_cycles: int
-    total_training_steps: int
     train_batch_size_samples: int
     dead_feature_window: int
     feature_sampling_window: int
     logger: LoggingConfig
+
+    @property
+    def total_training_steps(self) -> int:
+        return self.total_training_samples // self.train_batch_size_samples
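A quick arithmetic sketch of the derived property above; the sample values are made up, not taken from the package:

```python
# Hypothetical values, purely to show how total_training_steps is now derived.
total_training_samples = 2_000_000
train_batch_size_samples = 4096

total_training_steps = total_training_samples // train_batch_size_samples
print(total_training_steps)  # 488
```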
sae_lens/evals.py
@@ -4,6 +4,7 @@ import json
 import math
 import re
 import subprocess
+import sys
 from collections import defaultdict
 from collections.abc import Mapping
 from dataclasses import dataclass, field
@@ -15,7 +16,7 @@ from typing import Any
 import einops
 import pandas as pd
 import torch
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from transformer_lens import HookedTransformer
 from transformer_lens.hook_points import HookedRootModule

@@ -814,16 +815,18 @@ def multiple_evals(
            release=sae_release_name,  # see other options in sae_lens/pretrained_saes.yaml
            sae_id=sae_id,  # won't always be a hook point
            device=device,
-        )
+        )

        # move SAE to device if not there already
        sae.to(device)

-        if current_model_str != sae.cfg.model_name:
+        if current_model_str != sae.cfg.metadata.model_name:
            del current_model  # potentially saves GPU memory
-            current_model_str = sae.cfg.model_name
+            current_model_str = sae.cfg.metadata.model_name
            current_model = HookedTransformer.from_pretrained_no_processing(
-                current_model_str,
+                current_model_str,
+                device=device,
+                **sae.cfg.metadata.model_from_pretrained_kwargs,
            )
            assert current_model is not None

@@ -941,7 +944,7 @@ def process_results(
    }


-
+def process_args(args: list[str]) -> argparse.Namespace:
    arg_parser = argparse.ArgumentParser(description="Run evaluations on SAEs")
    arg_parser.add_argument(
        "sae_regex_pattern",
@@ -1031,11 +1034,19 @@ if __name__ == "__main__":
        help="Enable verbose output with tqdm loaders.",
    )

-
-
-
+    return arg_parser.parse_args(args)
+
+
+def run_evals_cli(args: list[str]) -> None:
+    opts = process_args(args)
+    eval_results = run_evaluations(opts)
+    output_files = process_results(eval_results, opts.output_dir)

    print("Evaluation complete. Output files:")
    print(f"Individual JSONs: {len(output_files['individual_jsons'])}")  # type: ignore
    print(f"Combined JSON: {output_files['combined_json']}")
    print(f"CSV: {output_files['csv']}")
+
+
+if __name__ == "__main__":
+    run_evals_cli(sys.argv[1:])
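A minimal sketch of the new CLI entry point added above; it simply restates the `__main__` guard from the diff, so calling `run_evals_cli` programmatically is equivalent to running the module as a script (the actual flags are whatever `process_args` defines, e.g. the positional `sae_regex_pattern`):

```python
import sys

from sae_lens.evals import run_evals_cli

# Equivalent to running sae_lens/evals.py as a script with the same arguments.
run_evals_cli(sys.argv[1:])
```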
sae_lens/llm_sae_training_runner.py
@@ -4,7 +4,7 @@ import sys
 from collections.abc import Sequence
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Generic
+from typing import Any, Generic

 import torch
 import wandb
@@ -17,12 +17,16 @@ from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
 from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
 from sae_lens.evals import EvalConfig, run_evals
 from sae_lens.load_model import load_model
+from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
+from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
 from sae_lens.saes.sae import (
     T_TRAINING_SAE,
     T_TRAINING_SAE_CONFIG,
     TrainingSAE,
     TrainingSAEConfig,
 )
+from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
+from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
 from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.sae_trainer import SAETrainer
@@ -145,17 +149,18 @@ class LanguageModelSAETrainingRunner:
            )
        else:
            self.sae = override_sae
+            self.sae.to(self.cfg.device)

    def run(self):
        """
        Run the training of the SAE.
        """
-
+        self._set_sae_metadata()
        if self.cfg.logger.log_to_wandb:
            wandb.init(
                project=self.cfg.logger.wandb_project,
                entity=self.cfg.logger.wandb_entity,
-                config=
+                config=self.cfg.to_dict(),
                name=self.cfg.logger.run_name,
                id=self.cfg.logger.wandb_id,
            )
@@ -184,6 +189,20 @@ class LanguageModelSAETrainingRunner:

        return sae

+    def _set_sae_metadata(self):
+        self.sae.cfg.metadata.dataset_path = self.cfg.dataset_path
+        self.sae.cfg.metadata.hook_name = self.cfg.hook_name
+        self.sae.cfg.metadata.model_name = self.cfg.model_name
+        self.sae.cfg.metadata.model_class_name = self.cfg.model_class_name
+        self.sae.cfg.metadata.hook_head_index = self.cfg.hook_head_index
+        self.sae.cfg.metadata.context_size = self.cfg.context_size
+        self.sae.cfg.metadata.seqpos_slice = self.cfg.seqpos_slice
+        self.sae.cfg.metadata.model_from_pretrained_kwargs = (
+            self.cfg.model_from_pretrained_kwargs
+        )
+        self.sae.cfg.metadata.prepend_bos = self.cfg.prepend_bos
+        self.sae.cfg.metadata.exclude_special_tokens = self.cfg.exclude_special_tokens
+
    def _compile_if_needed(self):
        # Compile model and SAE
        # torch.compile can provide significant speedups (10-20% in testing)
@@ -247,11 +266,100 @@ class LanguageModelSAETrainingRunner:
 def _parse_cfg_args(
     args: Sequence[str],
 ) -> LanguageModelSAERunnerConfig[TrainingSAEConfig]:
+    """
+    Parse command line arguments into a LanguageModelSAERunnerConfig.
+
+    This function first parses the architecture argument to determine which
+    concrete SAE config class to use, then parses the full configuration
+    with that concrete type.
+    """
     if len(args) == 0:
         args = ["--help"]
+
+    # First, parse only the architecture to determine which concrete class to use
+    architecture_parser = ArgumentParser(
+        description="Parse architecture to determine SAE config class",
+        exit_on_error=False,
+        add_help=False,  # Don't add help to avoid conflicts
+    )
+    architecture_parser.add_argument(
+        "--architecture",
+        type=str,
+        choices=["standard", "gated", "jumprelu", "topk"],
+        default="standard",
+        help="SAE architecture to use",
+    )
+
+    # Parse known args to extract architecture, ignore unknown args for now
+    arch_args, remaining_args = architecture_parser.parse_known_args(args)
+    architecture = arch_args.architecture
+
+    # Remove architecture from remaining args if it exists
+    filtered_args = []
+    skip_next = False
+    for arg in remaining_args:
+        if skip_next:
+            skip_next = False
+            continue
+        if arg == "--architecture":
+            skip_next = True  # Skip the next argument (the architecture value)
+            continue
+        filtered_args.append(arg)
+
+    # Create a custom wrapper class that simple_parsing can handle
+    def create_config_class(
+        sae_config_type: type[TrainingSAEConfig],
+    ) -> type[LanguageModelSAERunnerConfig[TrainingSAEConfig]]:
+        """Create a concrete config class for the given SAE config type."""
+
+        # Create the base config without the sae field
+        from dataclasses import field as dataclass_field
+        from dataclasses import fields, make_dataclass
+
+        # Get all fields from LanguageModelSAERunnerConfig except the generic sae field
+        base_fields = []
+        for field_obj in fields(LanguageModelSAERunnerConfig):
+            if field_obj.name != "sae":
+                base_fields.append((field_obj.name, field_obj.type, field_obj))
+
+        # Add the concrete sae field
+        base_fields.append(
+            (
+                "sae",
+                sae_config_type,
+                dataclass_field(
+                    default_factory=lambda: sae_config_type(d_in=512, d_sae=1024)
+                ),
+            )
+        )
+
+        # Create the concrete class
+        return make_dataclass(
+            f"{sae_config_type.__name__}RunnerConfig",
+            base_fields,
+            bases=(LanguageModelSAERunnerConfig,),
+        )
+
+    # Map architecture to concrete config class
+    sae_config_map = {
+        "standard": StandardTrainingSAEConfig,
+        "gated": GatedTrainingSAEConfig,
+        "jumprelu": JumpReLUTrainingSAEConfig,
+        "topk": TopKTrainingSAEConfig,
+    }
+
+    sae_config_type = sae_config_map[architecture]
+    concrete_config_class = create_config_class(sae_config_type)
+
+    # Now parse the full configuration with the concrete type
     parser = ArgumentParser(exit_on_error=False)
-    parser.add_arguments(
-
+    parser.add_arguments(concrete_config_class, dest="cfg")
+
+    # Parse the filtered arguments (without --architecture)
+    parsed_args = parser.parse_args(filtered_args)
+
+    # Return the parsed configuration
+    return parsed_args.cfg


 # moved into its own function to make it easier to test
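A hedged sketch of how the two-stage parsing above can be exercised. `_parse_cfg_args` and the `--architecture` flag come from the diff, but the other flag shown below is an assumption about what simple_parsing derives from the runner config field names, not a value taken from this diff:

```python
from sae_lens.llm_sae_training_runner import _parse_cfg_args

# "--architecture topk" is consumed by the first-pass parser; the rest is parsed
# against the TopK-specific runner config. "--model_name" is a guessed flag name.
cfg = _parse_cfg_args(["--architecture", "topk", "--model_name", "gpt2"])
print(type(cfg.sae).__name__)  # expected: TopKTrainingSAEConfig
```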
sae_lens/loading/pretrained_sae_loaders.py
@@ -26,6 +26,22 @@ from sae_lens.loading.pretrained_saes_directory import (
 from sae_lens.registry import get_sae_class
 from sae_lens.util import filter_valid_dataclass_fields

+LLM_METADATA_KEYS = {
+    "model_name",
+    "hook_name",
+    "model_class_name",
+    "hook_head_index",
+    "model_from_pretrained_kwargs",
+    "prepend_bos",
+    "exclude_special_tokens",
+    "neuronpedia_id",
+    "context_size",
+    "seqpos_slice",
+    "dataset_path",
+    "sae_lens_version",
+    "sae_lens_training_version",
+}
+

 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
 class PretrainedSaeHuggingfaceLoader(Protocol):
@@ -207,6 +223,10 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
     new_cfg.setdefault("activation_fn", new_cfg.get("activation_fn", "relu"))
     new_cfg.setdefault("architecture", "standard")
     new_cfg.setdefault("neuronpedia_id", None)
+    new_cfg.setdefault(
+        "reshape_activations",
+        "hook_z" if "hook_z" in new_cfg.get("hook_name", "") else "none",
+    )

     if "normalize_activations" in new_cfg and isinstance(
         new_cfg["normalize_activations"], bool
@@ -231,11 +251,9 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
     if architecture == "topk":
         sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]

-
-
-
-    meta_dict = filter_valid_dataclass_fields(new_cfg, SAEMetadata)
-    sae_cfg_dict["metadata"] = meta_dict
+    sae_cfg_dict["metadata"] = {
+        k: v for k, v in new_cfg.items() if k in LLM_METADATA_KEYS
+    }
     sae_cfg_dict["architecture"] = architecture
     return sae_cfg_dict

@@ -271,6 +289,7 @@ def get_connor_rob_hook_z_config_from_hf(
         "context_size": 128,
         "normalize_activations": "none",
         "dataset_trust_remote_code": True,
+        "reshape_activations": "hook_z",
         **(cfg_overrides or {}),
     }

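A small self-contained sketch of what the metadata split above does to a legacy (pre-6.0) config dict; the legacy values and the reduced key set are made up for illustration:

```python
# Subset of the LLM_METADATA_KEYS set from the diff, plus a fake legacy config.
LLM_METADATA_KEYS = {"model_name", "hook_name", "context_size", "dataset_path"}

legacy_cfg = {
    "d_in": 768,
    "d_sae": 24576,
    "model_name": "gpt2",
    "hook_name": "blocks.8.hook_resid_pre",
    "context_size": 128,
}

# Non-SAE fields are routed into the new metadata dict instead of the SAE config.
metadata = {k: v for k, v in legacy_cfg.items() if k in LLM_METADATA_KEYS}
print(metadata)
# {'model_name': 'gpt2', 'hook_name': 'blocks.8.hook_resid_pre', 'context_size': 128}
```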
sae_lens/saes/gated_sae.py
@@ -168,10 +168,6 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):

         # Magnitude path
         magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
-        if self.training and self.cfg.noise_scale > 0:
-            magnitude_pre_activation += (
-                torch.randn_like(magnitude_pre_activation) * self.cfg.noise_scale
-            )
         magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)

         feature_magnitudes = self.activation_fn(magnitude_pre_activation)
sae_lens/saes/jumprelu_sae.py
@@ -105,7 +105,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
     JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
     using a JumpReLU activation. For each unit, if its pre-activation is
     <= threshold, that unit is zeroed out; otherwise, it follows a user-specified
-    activation function (e.g., ReLU
+    activation function (e.g., ReLU etc.).

     It implements:
       - initialize_weights: sets up parameters, including a threshold.
@@ -142,7 +142,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

-        # 1) Apply the base "activation_fn" from config (e.g., ReLU
+        # 1) Apply the base "activation_fn" from config (e.g., ReLU).
         base_acts = self.activation_fn(hidden_pre)

         # 2) Zero out any unit whose (hidden_pre <= threshold).
@@ -191,8 +191,8 @@ class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
     Configuration class for training a JumpReLUTrainingSAE.
     """

-    jumprelu_init_threshold: float = 0.
-    jumprelu_bandwidth: float = 0.
+    jumprelu_init_threshold: float = 0.01
+    jumprelu_bandwidth: float = 0.05
     l0_coefficient: float = 1.0
     l0_warm_up_steps: int = 0

@@ -257,12 +257,6 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         sae_in = self.process_sae_in(x)

         hidden_pre = sae_in @ self.W_enc + self.b_enc
-
-        if self.training and self.cfg.noise_scale > 0:
-            hidden_pre = (
-                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
-            )
-
         feature_acts = JumpReLU.apply(hidden_pre, self.threshold, self.bandwidth)

         return feature_acts, hidden_pre  # type: ignore
sae_lens/saes/sae.py
@@ -1,5 +1,6 @@
 """Base classes for Sparse Autoencoders (SAEs)."""

+import copy
 import json
 import warnings
 from abc import ABC, abstractmethod
@@ -59,23 +60,91 @@ T_SAE = TypeVar("T_SAE", bound="SAE")  # type: ignore
 T_TRAINING_SAE = TypeVar("T_TRAINING_SAE", bound="TrainingSAE")  # type: ignore


-@dataclass
 class SAEMetadata:
     """Core metadata about how this SAE should be used, if known."""

-
-
-
-
-
-
-
-
-
-
-
-
-
+    def __init__(self, **kwargs: Any):
+        # Set default version fields with their current behavior
+        self.sae_lens_version = kwargs.pop("sae_lens_version", __version__)
+        self.sae_lens_training_version = kwargs.pop(
+            "sae_lens_training_version", __version__
+        )
+
+        # Set all other attributes dynamically
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+    def __getattr__(self, name: str) -> None:
+        """Return None for any missing attribute (like defaultdict)"""
+        return
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        """Allow setting any attribute"""
+        super().__setattr__(name, value)
+
+    def __getitem__(self, key: str) -> Any:
+        """Allow dictionary-style access: metadata['key'] - returns None for missing keys"""
+        return getattr(self, key)
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        """Allow dictionary-style assignment: metadata['key'] = value"""
+        setattr(self, key, value)
+
+    def __contains__(self, key: str) -> bool:
+        """Allow 'in' operator: 'key' in metadata"""
+        # Only return True if the attribute was explicitly set (not just defaulting to None)
+        return key in self.__dict__
+
+    def get(self, key: str, default: Any = None) -> Any:
+        """Dictionary-style get with default"""
+        value = getattr(self, key)
+        # If the attribute wasn't explicitly set and we got None from __getattr__,
+        # use the provided default instead
+        if key not in self.__dict__ and value is None:
+            return default
+        return value
+
+    def keys(self):
+        """Return all explicitly set attribute names"""
+        return self.__dict__.keys()
+
+    def values(self):
+        """Return all explicitly set attribute values"""
+        return self.__dict__.values()
+
+    def items(self):
+        """Return all explicitly set attribute name-value pairs"""
+        return self.__dict__.items()
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary for serialization"""
+        return self.__dict__.copy()
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "SAEMetadata":
+        """Create from dictionary"""
+        return cls(**data)
+
+    def __repr__(self) -> str:
+        return f"SAEMetadata({self.__dict__})"
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, SAEMetadata):
+            return False
+        return self.__dict__ == other.__dict__
+
+    def __deepcopy__(self, memo: dict[int, Any]) -> "SAEMetadata":
+        """Support for deep copying"""
+
+        return SAEMetadata(**copy.deepcopy(self.__dict__, memo))
+
+    def __getstate__(self) -> dict[str, Any]:
+        """Support for pickling"""
+        return self.__dict__
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        """Support for unpickling"""
+        self.__dict__.update(state)


 @dataclass
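A short sketch of the behaviour the new `SAEMetadata` above provides (attribute-style and dict-style access, with `None` for anything not explicitly set); constructed standalone here purely for illustration:

```python
from sae_lens.saes.sae import SAEMetadata

meta = SAEMetadata(model_name="gpt2", hook_name="blocks.8.hook_resid_pre")

print(meta.model_name)                # "gpt2" (attribute access)
print(meta["hook_name"])              # "blocks.8.hook_resid_pre" (dict-style access)
print(meta.neuronpedia_id)            # None: unset attributes fall back to None
print("model_name" in meta)           # True: only explicitly set keys count
print(meta.get("context_size", 128))  # 128: dict-style get with a default

round_trip = SAEMetadata.from_dict(meta.to_dict())
print(round_trip == meta)             # True: to_dict/from_dict round-trips
```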
@@ -99,7 +168,7 @@ class SAEConfig(ABC):

     def to_dict(self) -> dict[str, Any]:
         res = {field.name: getattr(self, field.name) for field in fields(self)}
-        res["metadata"] =
+        res["metadata"] = self.metadata.to_dict()
         res["architecture"] = self.architecture()
         return res

@@ -124,7 +193,7 @@ class SAEConfig(ABC):
            "layer_norm",
        ]:
            raise ValueError(
-                f"normalize_activations must be none, expected_average_only_in,
+                f"normalize_activations must be none, expected_average_only_in, layer_norm, or constant_norm_rescale. Got {self.normalize_activations}"
            )


@@ -238,9 +307,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
-
        elif self.cfg.normalize_activations == "layer_norm":
-
+            # we need to scale the norm of the input and store the scaling factor
            def run_time_activation_ln_in(
                x: torch.Tensor, eps: float = 1e-5
            ) -> torch.Tensor:
@@ -522,7 +590,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
        device: str = "cpu",
        force_download: bool = False,
        converter: PretrainedSaeHuggingfaceLoader | None = None,
-    ) ->
+    ) -> T_SAE:
        """
        Load a pretrained SAE from the Hugging Face model hub.

@@ -530,7 +598,28 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
            release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
            id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
            device: The device to load the SAE on.
-
+        """
+        return cls.from_pretrained_with_cfg_and_sparsity(
+            release, sae_id, device, force_download, converter=converter
+        )[0]
+
+    @classmethod
+    def from_pretrained_with_cfg_and_sparsity(
+        cls: Type[T_SAE],
+        release: str,
+        sae_id: str,
+        device: str = "cpu",
+        force_download: bool = False,
+        converter: PretrainedSaeHuggingfaceLoader | None = None,
+    ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
+        """
+        Load a pretrained SAE from the Hugging Face model hub, along with its config dict and sparsity, if present.
+        In SAELens <= 5.x.x, this was called SAE.from_pretrained().
+
+        Args:
+            release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
+            id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
+            device: The device to load the SAE on.
        """

        # get sae directory
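Illustrative usage of the split above: `from_pretrained` now returns just the SAE, while the old 3-tuple behaviour lives in `from_pretrained_with_cfg_and_sparsity`. The release and id strings are placeholders, not values taken from this diff:

```python
from sae_lens import SAE

# New-style: just the SAE object.
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre", device="cpu")

# Old-style 3-tuple (what SAE.from_pretrained returned in SAELens <= 5.x):
sae2, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(
    "gpt2-small-res-jb", "blocks.8.hook_resid_pre", device="cpu"
)
```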
@@ -646,8 +735,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

 @dataclass(kw_only=True)
 class TrainingSAEConfig(SAEConfig, ABC):
-    noise_scale: float = 0.0
-    mse_loss_normalization: str | None = None
     # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
     # 0.1 corresponds to the "heuristic" initialization, use None to disable
     decoder_init_norm: float | None = 0.1
@@ -680,9 +767,6 @@ class TrainingSAEConfig(SAEConfig, ABC):
     def from_dict(
         cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
     ) -> T_TRAINING_SAE_CONFIG:
-        # remove any keys that are not in the dataclass
-        # since we sometimes enhance the config with the whole LM runner config
-        valid_config_dict = filter_valid_dataclass_fields(config_dict, cls)
         cfg_class = cls
         if "architecture" in config_dict:
             cfg_class = get_sae_training_class(config_dict["architecture"])[1]
@@ -690,6 +774,9 @@ class TrainingSAEConfig(SAEConfig, ABC):
            raise ValueError(
                f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
            )
+        # remove any keys that are not in the dataclass
+        # since we sometimes enhance the config with the whole LM runner config
+        valid_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
        if "metadata" in config_dict:
            valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
        return cfg_class(**valid_config_dict)
@@ -698,6 +785,7 @@ class TrainingSAEConfig(SAEConfig, ABC):
        return {
            **super().to_dict(),
            **asdict(self),
+            "metadata": self.metadata.to_dict(),
            "architecture": self.architecture(),
        }

@@ -708,12 +796,14 @@ class TrainingSAEConfig(SAEConfig, ABC):
        Creates a dictionary containing attributes corresponding to the fields
        defined in the base SAEConfig class.
        """
-
+        base_sae_cfg_class = get_sae_class(self.architecture())[1]
+        base_config_field_names = {f.name for f in fields(base_sae_cfg_class)}
        result_dict = {
            field_name: getattr(self, field_name)
            for field_name in base_config_field_names
        }
        result_dict["architecture"] = self.architecture()
+        result_dict["metadata"] = self.metadata.to_dict()
        return result_dict


@@ -726,7 +816,7 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
        # Turn off hook_z reshaping for training mode - the activation store
        # is expected to handle reshaping before passing data to the SAE
        self.turn_off_forward_pass_hook_z_reshaping()
-        self.mse_loss_fn =
+        self.mse_loss_fn = mse_loss

    @abstractmethod
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...
@@ -861,27 +951,6 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
        """
        return self.process_state_dict_for_saving(state_dict)

-    def _get_mse_loss_fn(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
-        """Get the MSE loss function based on config."""
-
-        def standard_mse_loss_fn(
-            preds: torch.Tensor, target: torch.Tensor
-        ) -> torch.Tensor:
-            return torch.nn.functional.mse_loss(preds, target, reduction="none")
-
-        def batch_norm_mse_loss_fn(
-            preds: torch.Tensor, target: torch.Tensor
-        ) -> torch.Tensor:
-            target_centered = target - target.mean(dim=0, keepdim=True)
-            normalization = target_centered.norm(dim=-1, keepdim=True)
-            return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
-                normalization + 1e-6
-            )
-
-        if self.cfg.mse_loss_normalization == "dense_batch":
-            return batch_norm_mse_loss_fn
-        return standard_mse_loss_fn
-
    @torch.no_grad()
    def remove_gradient_parallel_to_decoder_directions(self) -> None:
        """Remove gradient components parallel to decoder directions."""
@@ -943,3 +1012,7 @@ def _disable_hooks(sae: SAE[Any]):
    finally:
        for hook_name, hook in sae.hook_dict.items():
            setattr(sae, hook_name, hook)
+
+
+def mse_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    return torch.nn.functional.mse_loss(preds, target, reduction="none")
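The module-level `mse_loss` added above is a plain unreduced (element-wise) MSE; a tiny sanity check, written directly against torch rather than the package:

```python
import torch
import torch.nn.functional as F

preds = torch.tensor([[1.0, 2.0]])
target = torch.tensor([[0.0, 2.0]])

# reduction="none" keeps per-element squared errors instead of averaging them.
print(F.mse_loss(preds, target, reduction="none"))  # tensor([[1., 0.]])
```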
sae_lens/saes/standard_sae.py
@@ -67,7 +67,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
         sae_in = self.process_sae_in(x)
         # Compute the pre-activation values
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
-        # Apply the activation function (e.g., ReLU,
+        # Apply the activation function (e.g., ReLU, depending on config)
         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

     def decode(
@@ -81,7 +81,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
         sae_out_pre = feature_acts @ self.W_dec + self.b_dec
         # 2) hook reconstruction
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
-        # 4) optional out-normalization (e.g. constant_norm_rescale
+        # 4) optional out-normalization (e.g. constant_norm_rescale)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
         # 5) if hook_z is enabled, rearrange back to (..., n_heads, d_head).
         return self.reshape_fn_out(sae_out_pre, self.d_head)
@@ -136,16 +136,9 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
         sae_in = self.process_sae_in(x)
         # Compute the pre-activation (and allow for a hook if desired)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)  # type: ignore
-        # Add noise during training for robustness (scaled by noise_scale from the configuration)
-        if self.training and self.cfg.noise_scale > 0:
-            hidden_pre_noised = (
-                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
-            )
-        else:
-            hidden_pre_noised = hidden_pre
         # Apply the activation function (and any post-activation hook)
-        feature_acts = self.hook_sae_acts_post(self.activation_fn(
-        return feature_acts,
+        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+        return feature_acts, hidden_pre

     def calculate_aux_loss(
         self,
sae_lens/saes/topk_sae.py
@@ -91,8 +91,7 @@ class TopKSAE(SAE[TopKSAEConfig]):
     ) -> Float[torch.Tensor, "... d_sae"]:
         """
         Converts input x into feature activations.
-        Uses topk activation
-        under the hood.
+        Uses topk activation under the hood.
         """
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
@@ -116,6 +115,13 @@ class TopKSAE(SAE[TopKSAEConfig]):
     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
         return TopK(self.cfg.k)

+    @override
+    @torch.no_grad()
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError(
+            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
+        )
+

 @dataclass
 class TopKTrainingSAEConfig(TrainingSAEConfig):
@@ -156,18 +162,11 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

-        # Inject noise if training
-        if self.training and self.cfg.noise_scale > 0:
-            hidden_pre_noised = (
-                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
-            )
-        else:
-            hidden_pre_noised = hidden_pre
-
         # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
-        feature_acts = self.hook_sae_acts_post(self.activation_fn(
-        return feature_acts,
+        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+        return feature_acts, hidden_pre

+    @override
     def calculate_aux_loss(
         self,
         step_input: TrainStepInput,
@@ -184,6 +183,13 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         )
         return {"auxiliary_reconstruction_loss": topk_loss}

+    @override
+    @torch.no_grad()
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError(
+            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
+        )
+
     @override
     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
         return TopK(self.cfg.k)
sae_lens/training/activations_store.py
@@ -161,8 +161,6 @@ class ActivationsStore:
     ) -> ActivationsStore:
         if sae.cfg.metadata.hook_name is None:
             raise ValueError("hook_name is required")
-        if sae.cfg.metadata.hook_head_index is None:
-            raise ValueError("hook_head_index is required")
         if sae.cfg.metadata.context_size is None:
             raise ValueError("context_size is required")
         if sae.cfg.metadata.prepend_bos is None:
sae_lens/training/sae_trainer.py
@@ -7,7 +7,7 @@ import torch
 import wandb
 from safetensors.torch import save_file
 from torch.optim import Adam
-from tqdm import tqdm
+from tqdm.auto import tqdm

 from sae_lens import __version__
 from sae_lens.config import SAETrainerConfig
@@ -161,6 +161,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
         return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()

     def fit(self) -> T_TRAINING_SAE:
+        self.sae.to(self.cfg.device)
         pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE")

         if self.sae.cfg.normalize_activations == "expected_average_only_in":
@@ -194,10 +195,11 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
            )
            self.activation_scaler.scaling_factor = None

-        # save final sae group to checkpoints folder
+        # save final inference sae group to checkpoints folder
        self.save_checkpoint(
            checkpoint_name=f"final_{self.n_training_samples}",
            wandb_aliases=["final_model"],
+            save_inference_model=True,
        )

        pbar.close()
@@ -207,11 +209,17 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
        self,
        checkpoint_name: str,
        wandb_aliases: list[str] | None = None,
+        save_inference_model: bool = False,
    ) -> None:
        checkpoint_path = Path(self.cfg.checkpoint_path) / checkpoint_name
        checkpoint_path.mkdir(exist_ok=True, parents=True)

-
+        save_fn = (
+            self.sae.save_inference_model
+            if save_inference_model
+            else self.sae.save_model
+        )
+        weights_path, cfg_path = save_fn(str(checkpoint_path))

        sparsity_path = checkpoint_path / SPARSITY_FILENAME
        save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
sae_lens/training/upload_saes_to_huggingface.py
@@ -88,7 +88,7 @@ def _create_default_readme(repo_id: str, sae_ids: Iterable[str]) -> str:
         ```python
         from sae_lens import SAE

-        sae
+        sae = SAE.from_pretrained("{repo_id}", "<sae_id>")
         ```
         """
     )
sae_lens-6.0.0rc3/sae_lens/training/geometric_median.py (removed in 6.0.0rc4)
@@ -1,101 +0,0 @@
-from types import SimpleNamespace
-
-import torch
-import tqdm
-
-
-def weighted_average(points: torch.Tensor, weights: torch.Tensor):
-    weights = weights / weights.sum()
-    return (points * weights.view(-1, 1)).sum(dim=0)
-
-
-@torch.no_grad()
-def geometric_median_objective(
-    median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
-) -> torch.Tensor:
-    norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
-
-    return (norms * weights).sum()
-
-
-def compute_geometric_median(
-    points: torch.Tensor,
-    weights: torch.Tensor | None = None,
-    eps: float = 1e-6,
-    maxiter: int = 100,
-    ftol: float = 1e-20,
-    do_log: bool = False,
-):
-    """
-    :param points: ``torch.Tensor`` of shape ``(n, d)``
-    :param weights: Optional ``torch.Tensor`` of shape :math:``(n,)``.
-    :param eps: Smallest allowed value of denominator, to avoid divide by zero.
-        Equivalently, this is a smoothing parameter. Default 1e-6.
-    :param maxiter: Maximum number of Weiszfeld iterations. Default 100
-    :param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
-    :param do_log: If true will return a log of function values encountered through the course of the algorithm
-    :return: SimpleNamespace object with fields
-        - `median`: estimate of the geometric median, which is a ``torch.Tensor`` object of shape :math:``(d,)``
-        - `termination`: string explaining how the algorithm terminated.
-        - `logs`: function values encountered through the course of the algorithm in a list (None if do_log is false).
-    """
-    with torch.no_grad():
-        if weights is None:
-            weights = torch.ones((points.shape[0],), device=points.device)
-        # initialize median estimate at mean
-        new_weights = weights
-        median = weighted_average(points, weights)
-        objective_value = geometric_median_objective(median, points, weights)
-        logs = [objective_value] if do_log else None
-
-        # Weiszfeld iterations
-        early_termination = False
-        pbar = tqdm.tqdm(range(maxiter))
-        for _ in pbar:
-            prev_obj_value = objective_value
-
-            norms = torch.linalg.norm(points - median.view(1, -1), dim=1)  # type: ignore
-            new_weights = weights / torch.clamp(norms, min=eps)
-            median = weighted_average(points, new_weights)
-            objective_value = geometric_median_objective(median, points, weights)
-
-            if logs is not None:
-                logs.append(objective_value)
-            if abs(prev_obj_value - objective_value) <= ftol * objective_value:
-                early_termination = True
-                break
-
-            pbar.set_description(f"Objective value: {objective_value:.4f}")
-
-    median = weighted_average(points, new_weights)  # allow autodiff to track it
-    return SimpleNamespace(
-        median=median,
-        new_weights=new_weights,
-        termination=(
-            "function value converged within tolerance"
-            if early_termination
-            else "maximum iterations reached"
-        ),
-        logs=logs,
-    )
-
-
-if __name__ == "__main__":
-    import time
-
-    TOLERANCE = 1e-2
-
-    dim1 = 10000
-    dim2 = 768
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    sample = (
-        torch.randn((dim1, dim2), device=device) * 100
-    )  # seems to be the order of magnitude of the actual use case
-    weights = torch.randn((dim1,), device=device)
-
-    torch.tensor(weights, device=device)
-
-    tic = time.perf_counter()
-    new = compute_geometric_median(sample, weights=weights, maxiter=100)
-    print(f"new code takes {time.perf_counter()-tic} seconds!")  # noqa: T201