sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +60 -7
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +16 -14
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +170 -258
- sae_lens/constants.py +21 -0
- sae_lens/evals.py +59 -44
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +52 -4
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +254 -0
- sae_lens/saes/jumprelu_sae.py +348 -0
- sae_lens/saes/sae.py +1076 -0
- sae_lens/saes/standard_sae.py +178 -0
- sae_lens/saes/topk_sae.py +300 -0
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +103 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +155 -177
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +13 -7
- sae_lens/util.py +47 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
- sae_lens-6.0.0.dist-info/RECORD +37 -0
- sae_lens/sae.py +0 -747
- sae_lens/sae_training_runner.py +0 -251
- sae_lens/training/geometric_median.py +0 -101
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.11.0.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
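The module moves in the file list above imply new import paths in 6.0.0. A minimal sketch of the corresponding import changes, inferred only from the renamed and added files (the exact re-exported symbols may differ):

```python
# 5.11.0 layout (modules removed in 6.0.0):
#   from sae_lens.sae import SAE
#   from sae_lens.toolkit import pretrained_sae_loaders

# 6.0.0 layout suggested by the file moves above:
from sae_lens.saes.sae import TrainingSAEConfig      # sae.py now lives under sae_lens/saes/
from sae_lens.loading import pretrained_sae_loaders  # toolkit/ was renamed to loading/
```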
sae_lens/config.py
CHANGED
@@ -1,8 +1,9 @@
 import json
 import math
 import os
-from dataclasses import dataclass, field
-from
+from dataclasses import asdict, dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
 
 import simple_parsing
 import torch
@@ -16,25 +17,17 @@ from datasets import (
 )
 
 from sae_lens import __version__, logger
+from sae_lens.constants import DTYPE_MAP
+from sae_lens.saes.sae import TrainingSAEConfig
 
-
-
-    "float64": torch.float64,
-    "float16": torch.float16,
-    "bfloat16": torch.bfloat16,
-    "torch.float32": torch.float32,
-    "torch.float64": torch.float64,
-    "torch.float16": torch.float16,
-    "torch.bfloat16": torch.bfloat16,
-}
-
-HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
+if TYPE_CHECKING:
+    pass
 
+T_TRAINING_SAE_CONFIG = TypeVar(
+    "T_TRAINING_SAE_CONFIG", bound=TrainingSAEConfig, covariant=True
+)
 
-
-SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
-SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
-SAE_CFG_FILENAME = "cfg.json"
+HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
 
 
 # calling this "json_dict" so error messages will reference "json_dict" being invalid
@@ -55,101 +48,118 @@ def dict_field(default: dict[str, Any] | None, **kwargs: Any) -> Any: # type: i
 
 
 @dataclass
-class
+class LoggingConfig:
+    # WANDB
+    log_to_wandb: bool = True
+    log_activations_store_to_wandb: bool = False
+    log_optimizer_state_to_wandb: bool = False
+    wandb_project: str = "sae_lens_training"
+    wandb_id: str | None = None
+    run_name: str | None = None
+    wandb_entity: str | None = None
+    wandb_log_frequency: int = 10
+    eval_every_n_wandb_logs: int = 100  # logs every 100 steps.
+
+    def log(
+        self,
+        trainer: Any,  # avoid import cycle from importing SAETrainer
+        weights_path: Path | str,
+        cfg_path: Path | str,
+        sparsity_path: Path | str | None,
+        wandb_aliases: list[str] | None = None,
+    ) -> None:
+        # Avoid wandb saving errors such as:
+        # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
+        sae_name = trainer.sae.get_name().replace("/", "__")
+
+        # save model weights and cfg
+        model_artifact = wandb.Artifact(
+            sae_name,
+            type="model",
+            metadata=dict(trainer.cfg.__dict__),
+        )
+        model_artifact.add_file(str(weights_path))
+        model_artifact.add_file(str(cfg_path))
+        wandb.log_artifact(model_artifact, aliases=wandb_aliases)
+
+        # save log feature sparsity
+        sparsity_artifact = wandb.Artifact(
+            f"{sae_name}_log_feature_sparsity",
+            type="log_feature_sparsity",
+            metadata=dict(trainer.cfg.__dict__),
+        )
+        if sparsity_path is not None:
+            sparsity_artifact.add_file(str(sparsity_path))
+        wandb.log_artifact(sparsity_artifact)
+
+
+@dataclass
+class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     """
     Configuration for training a sparse autoencoder on a language model.
 
     Args:
-
+        sae (T_TRAINING_SAE_CONFIG): The configuration for the SAE itself (e.g. StandardSAEConfig, GatedSAEConfig).
         model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
         model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
         hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
         hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
-
-        hook_head_index (int, optional): When the hook if for an activatio with a head index, we can specify a specific head to use here.
+        hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
         dataset_path (str): A Hugging Face dataset path.
         dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
         streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
-        is_dataset_tokenized (bool):
+        is_dataset_tokenized (bool): Whether the dataset is already tokenized.
        context_size (int): The context size to use when generating activations on which to train the SAE.
        use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations.
-        cached_activations_path (str, optional): The path to the cached activations.
-        d_in (int): The input dimension of the SAE.
-        d_sae (int, optional): The output dimension of the SAE. If None, defaults to `d_in * expansion_factor`.
-        b_dec_init_method (str): The method to use to initialize the decoder bias. Zeros is likely fine.
-        expansion_factor (int): The expansion factor. Larger is better but more computationally expensive. Default is 4.
-        activation_fn (str): The activation function to use. Relu is standard.
-        normalize_sae_decoder (bool): Whether to normalize the SAE decoder. Unit normed decoder weights used to be preferred.
-        noise_scale (float): Using noise to induce sparsity is supported but not recommended.
+        cached_activations_path (str, optional): The path to the cached activations. Defaults to "activations/{dataset_path}/{model_name}/{hook_name}_{hook_head_index}".
         from_pretrained_path (str, optional): The path to a pretrained SAE. We can finetune an existing SAE if needed.
-
-        decoder_orthogonal_init (bool): Whether to use orthogonal initialization for the decoder. Not currently advised.
-        decoder_heuristic_init (bool): Whether to use heuristic initialization for the decoder. See Anthropic April Update.
-        init_encoder_as_decoder_transpose (bool): Whether to initialize the encoder as the transpose of the decoder. See Anthropic April Update.
-        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in ram is used. The larger it is, the better shuffled the activations will be.
+        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in RAM is used. The larger it is, the better shuffled the activations will be.
        training_tokens (int): The number of training tokens.
-
-
-
-
-        seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
-        device (str): The device to use. Usually cuda.
-        act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
+        store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
+        seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
+        device (str): The device to use. Usually "cuda".
+        act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
        seed (int): The seed to use.
-        dtype (str): The data type to use.
+        dtype (str): The data type to use for the SAE and activations.
        prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
-
-
-
-
-
-
-
-
-
-        adam_beta2 (float): The beta2 parameter for Adam.
-        mse_loss_normalization (str): The normalization to use for the MSE loss.
-        l1_coefficient (float): The L1 coefficient.
-        lp_norm (float): The Lp norm.
-        scale_sparsity_penalty_by_decoder_norm (bool): Whether to scale the sparsity penalty by the decoder norm.
-        l1_warm_up_steps (int): The number of warm-up steps for the L1 loss.
+        autocast (bool): Whether to use autocast (mixed-precision) during SAE training. Saves VRAM.
+        autocast_lm (bool): Whether to use autocast (mixed-precision) during activation fetching. Saves VRAM.
+        compile_llm (bool): Whether to compile the LLM using `torch.compile`.
+        llm_compilation_mode (str, optional): The compilation mode to use for the LLM if `compile_llm` is True.
+        compile_sae (bool): Whether to compile the SAE using `torch.compile`.
+        sae_compilation_mode (str, optional): The compilation mode to use for the SAE if `compile_sae` is True.
+        train_batch_size_tokens (int): The batch size for training, in tokens. This controls the batch size of the SAE training loop.
+        adam_beta1 (float): The beta1 parameter for the Adam optimizer.
+        adam_beta2 (float): The beta2 parameter for the Adam optimizer.
        lr (float): The learning rate.
-        lr_scheduler_name (str): The name of the learning rate scheduler to use.
+        lr_scheduler_name (str): The name of the learning rate scheduler to use (e.g., "constant", "cosineannealing", "cosineannealingwarmrestarts").
        lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
-        lr_end (float): The end learning rate if
-        lr_decay_steps (int): The number of decay steps for the learning rate.
-        n_restart_cycles (int): The number of restart cycles for the cosine annealing warm restarts scheduler.
-
-
-
-
-
-
-
-
-        log_activations_store_to_wandb (bool): NOT CURRENTLY USED. Whether to log the activations store to Weights & Biases.
-        log_optimizer_state_to_wandb (bool): NOT CURRENTLY USED. Whether to log the optimizer state to Weights & Biases.
-        wandb_project (str): The Weights & Biases project to log to.
-        wandb_id (str): The Weights & Biases ID.
-        run_name (str): The name of the run.
-        wandb_entity (str): The Weights & Biases entity.
-        wandb_log_frequency (int): The frequency to log to Weights & Biases.
-        eval_every_n_wandb_logs (int): The frequency to evaluate.
-        resume (bool): Whether to resume training.
-        n_checkpoints (int): The number of checkpoints.
-        checkpoint_path (str): The path to save checkpoints.
+        lr_end (float, optional): The end learning rate if using a scheduler like cosine annealing. Defaults to `lr / 10`.
+        lr_decay_steps (int): The number of decay steps for the learning rate if using a scheduler with decay.
+        n_restart_cycles (int): The number of restart cycles for the cosine annealing with warm restarts scheduler.
+        dead_feature_window (int): The window size (in training steps) for detecting dead features.
+        feature_sampling_window (int): The window size (in training steps) for resampling features (e.g. dead features).
+        dead_feature_threshold (float): The threshold below which a feature's activation frequency is considered dead.
+        n_eval_batches (int): The number of batches to use for evaluation.
+        eval_batch_size_prompts (int, optional): The batch size for evaluation, in prompts. Useful if evals cause OOM.
+        logger (LoggingConfig): Configuration for logging (e.g. W&B).
+        n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
+        checkpoint_path (str): The path to save checkpoints. A unique ID will be appended to this path.
        verbose (bool): Whether to print verbose output.
-        model_kwargs (dict[str, Any]):
-        model_from_pretrained_kwargs (dict[str, Any]): Additional keyword arguments
-
+        model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
+        model_from_pretrained_kwargs (dict[str, Any], optional): Additional keyword arguments to pass to the model's `from_pretrained` method.
+        sae_lens_version (str): The version of the sae_lens library.
+        sae_lens_training_version (str): The version of the sae_lens training library.
+        exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations. If True, excludes all special tokens. If a list of ints, excludes those token IDs.
    """
 
+    sae: T_TRAINING_SAE_CONFIG
+
    # Data Generating Function (Model + Training Distibuion)
    model_name: str = "gelu-2l"
    model_class_name: str = "HookedTransformer"
    hook_name: str = "blocks.0.hook_mlp_out"
    hook_eval: str = "NOT_IN_USE"
-    hook_layer: int = 0
    hook_head_index: int | None = None
    dataset_path: str = ""
    dataset_trust_remote_code: bool = True
@@ -162,30 +172,12 @@ class LanguageModelSAERunnerConfig:
     )
 
     # SAE Parameters
-    architecture: Literal["standard", "gated", "jumprelu", "topk"] = "standard"
-    d_in: int = 512
-    d_sae: int | None = None
-    b_dec_init_method: str = "geometric_median"
-    expansion_factor: int | None = (
-        None  # defaults to 4 if d_sae and expansion_factor is None
-    )
-    activation_fn: str = None  # relu, tanh-relu, topk. Default is relu.  # type: ignore
-    activation_fn_kwargs: dict[str, int] = dict_field(default=None)  # for topk
-    normalize_sae_decoder: bool = True
-    noise_scale: float = 0.0
     from_pretrained_path: str | None = None
-    apply_b_dec_to_input: bool = True
-    decoder_orthogonal_init: bool = False
-    decoder_heuristic_init: bool = False
-    decoder_heuristic_init_norm: float = 0.1
-    init_encoder_as_decoder_transpose: bool = False
 
     # Activation Store Parameters
     n_batches_in_buffer: int = 20
     training_tokens: int = 2_000_000
-    finetuning_tokens: int = 0
     store_batch_size_prompts: int = 32
-    normalize_activations: str = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
     seqpos_slice: tuple[int | None, ...] = (None,)
 
     # Misc
@@ -195,10 +187,6 @@ class LanguageModelSAERunnerConfig:
     dtype: str = "float32"  # type: ignore #
     prepend_bos: bool = True
 
-    # JumpReLU Parameters
-    jumprelu_init_threshold: float = 0.001
-    jumprelu_bandwidth: float = 0.001
-
     # Performance - see compilation section of lm_runner.py for info
     autocast: bool = False  # autocast to autocast_dtype during training
     autocast_lm: bool = False  # autocast lm during activation fetching
@@ -213,16 +201,9 @@ class LanguageModelSAERunnerConfig:
     train_batch_size_tokens: int = 4096
 
     ## Adam
-    adam_beta1: float = 0.
+    adam_beta1: float = 0.9
     adam_beta2: float = 0.999
 
-    ## Loss Function
-    mse_loss_normalization: str | None = None
-    l1_coefficient: float = 1e-3
-    lp_norm: float = 1
-    scale_sparsity_penalty_by_decoder_norm: bool = False
-    l1_warm_up_steps: int = 0
-
     ## Learning Rate Schedule
     lr: float = 3e-4
     lr_scheduler_name: str = (
@@ -233,33 +214,18 @@ class LanguageModelSAERunnerConfig:
     lr_decay_steps: int = 0
     n_restart_cycles: int = 1  # used only for cosineannealingwarmrestarts
 
-    ## FineTuning
-    finetuning_method: str | None = None  # scale, decoder or unrotated_decoder
-
     # Resampling protocol args
-    use_ghost_grads: bool = False  # want to change this to true on some timeline.
-    feature_sampling_window: int = 2000
     dead_feature_window: int = 1000  # unless this window is larger feature sampling,
-
+    feature_sampling_window: int = 2000
     dead_feature_threshold: float = 1e-8
 
     # Evals
     n_eval_batches: int = 10
     eval_batch_size_prompts: int | None = None  # useful if evals cause OOM
 
-
-    log_to_wandb: bool = True
-    log_activations_store_to_wandb: bool = False
-    log_optimizer_state_to_wandb: bool = False
-    wandb_project: str = "mats_sae_training_language_model"
-    wandb_id: str | None = None
-    run_name: str | None = None
-    wandb_entity: str | None = None
-    wandb_log_frequency: int = 10
-    eval_every_n_wandb_logs: int = 100  # logs every 1000 steps.
+    logger: LoggingConfig = field(default_factory=LoggingConfig)
 
     # Misc
-    resume: bool = False
     n_checkpoints: int = 0
     checkpoint_path: str = "checkpoints"
     verbose: bool = True
@@ -270,12 +236,6 @@ class LanguageModelSAERunnerConfig:
     exclude_special_tokens: bool | list[int] = False
 
     def __post_init__(self):
-        if self.resume:
-            raise ValueError(
-                "Resuming is no longer supported. You can finetune a trained SAE using cfg.from_pretrained path."
-                + "If you want to load an SAE with resume=True in the config, please manually set resume=False in that config."
-            )
-
         if self.use_cached_activations and self.cached_activations_path is None:
             self.cached_activations_path = _default_cached_activations_path(
                 self.dataset_path,
@@ -283,37 +243,12 @@ class LanguageModelSAERunnerConfig:
                 self.hook_name,
                 self.hook_head_index,
             )
-
-        if self.activation_fn is None:
-            self.activation_fn = "topk" if self.architecture == "topk" else "relu"
-
-        if self.architecture == "topk" and self.activation_fn != "topk":
-            raise ValueError("If using topk architecture, activation_fn must be topk.")
-
-        if self.activation_fn_kwargs is None:
-            self.activation_fn_kwargs = (
-                {"k": 100} if self.activation_fn == "topk" else {}
-            )
-
-        if self.architecture == "topk" and self.activation_fn_kwargs.get("k") is None:
-            raise ValueError(
-                "activation_fn_kwargs.k must be provided for topk architecture."
-            )
-
-        if self.d_sae is not None and self.expansion_factor is not None:
-            raise ValueError("You can't set both d_sae and expansion_factor.")
-
-        if self.d_sae is None and self.expansion_factor is None:
-            self.expansion_factor = 4
-
-        if self.d_sae is None and self.expansion_factor is not None:
-            self.d_sae = self.d_in * self.expansion_factor
         self.tokens_per_buffer = (
             self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
         )
 
-        if self.run_name is None:
-            self.run_name = f"{self.
+        if self.logger.run_name is None:
+            self.logger.run_name = f"{self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
 
         if self.model_from_pretrained_kwargs is None:
             if self.model_class_name == "HookedTransformer":
@@ -321,44 +256,13 @@ class LanguageModelSAERunnerConfig:
             else:
                 self.model_from_pretrained_kwargs = {}
 
-        if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
-            raise ValueError(
-                f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}"
-            )
-
-        if self.normalize_sae_decoder and self.decoder_heuristic_init:
-            raise ValueError(
-                "You can't normalize the decoder and use heuristic initialization."
-            )
-
-        if self.normalize_sae_decoder and self.scale_sparsity_penalty_by_decoder_norm:
-            raise ValueError(
-                "Weighting loss by decoder norm makes no sense if you are normalizing the decoder weight norms to 1"
-            )
-
-        # if we use decoder fine tuning, we can't be applying b_dec to the input
-        if (self.finetuning_method == "decoder") and (self.apply_b_dec_to_input):
-            raise ValueError(
-                "If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False."
-            )
-
-        if self.normalize_activations not in [
-            "none",
-            "expected_average_only_in",
-            "constant_norm_rescale",
-            "layer_norm",
-        ]:
-            raise ValueError(
-                f"normalize_activations must be none, layer_norm, expected_average_only_in, or constant_norm_rescale. Got {self.normalize_activations}"
-            )
-
         if self.act_store_device == "with_model":
             self.act_store_device = self.device
 
         if self.lr_end is None:
             self.lr_end = self.lr / 10
 
-        unique_id = self.wandb_id
+        unique_id = self.logger.wandb_id
         if unique_id is None:
            unique_id = cast(
                Any, wandb
@@ -367,7 +271,7 @@ class LanguageModelSAERunnerConfig:
 
         if self.verbose:
             logger.info(
-                f"Run name: {self.
+                f"Run name: {self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
             )
             # Print out some useful info:
             n_tokens_per_buffer = (
@@ -386,11 +290,13 @@ class LanguageModelSAERunnerConfig:
             )
 
             total_training_steps = (
-                self.training_tokens
+                self.training_tokens
             ) // self.train_batch_size_tokens
             logger.info(f"Total training steps: {total_training_steps}")
 
-            total_wandb_updates =
+            total_wandb_updates = (
+                total_training_steps // self.logger.wandb_log_frequency
+            )
             logger.info(f"Total wandb updates: {total_wandb_updates}")
 
             # how many times will we sample dead neurons?
@@ -412,9 +318,6 @@ class LanguageModelSAERunnerConfig:
                 f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
             )
 
-        if self.use_ghost_grads:
-            logger.info("Using Ghost Grads.")
-
         if self.context_size < 0:
             raise ValueError(
                 f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
@@ -429,65 +332,26 @@ class LanguageModelSAERunnerConfig:
 
     @property
     def total_training_tokens(self) -> int:
-        return self.training_tokens
+        return self.training_tokens
 
     @property
     def total_training_steps(self) -> int:
         return self.total_training_tokens // self.train_batch_size_tokens
 
-    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
-        return {
-            # TEMP
-            "architecture": self.architecture,
-            "d_in": self.d_in,
-            "d_sae": self.d_sae,
-            "dtype": self.dtype,
-            "device": self.device,
-            "model_name": self.model_name,
-            "hook_name": self.hook_name,
-            "hook_layer": self.hook_layer,
-            "hook_head_index": self.hook_head_index,
-            "activation_fn_str": self.activation_fn,
-            "apply_b_dec_to_input": self.apply_b_dec_to_input,
-            "context_size": self.context_size,
-            "prepend_bos": self.prepend_bos,
-            "dataset_path": self.dataset_path,
-            "dataset_trust_remote_code": self.dataset_trust_remote_code,
-            "finetuning_scaling_factor": self.finetuning_method is not None,
-            "sae_lens_training_version": self.sae_lens_training_version,
-            "normalize_activations": self.normalize_activations,
-            "activation_fn_kwargs": self.activation_fn_kwargs,
-            "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
-            "seqpos_slice": self.seqpos_slice,
-        }
-
     def get_training_sae_cfg_dict(self) -> dict[str, Any]:
-        return
-            **self.get_base_sae_cfg_dict(),
-            "l1_coefficient": self.l1_coefficient,
-            "lp_norm": self.lp_norm,
-            "use_ghost_grads": self.use_ghost_grads,
-            "normalize_sae_decoder": self.normalize_sae_decoder,
-            "noise_scale": self.noise_scale,
-            "decoder_orthogonal_init": self.decoder_orthogonal_init,
-            "mse_loss_normalization": self.mse_loss_normalization,
-            "decoder_heuristic_init": self.decoder_heuristic_init,
-            "decoder_heuristic_init_norm": self.decoder_heuristic_init_norm,
-            "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
-            "normalize_activations": self.normalize_activations,
-            "jumprelu_init_threshold": self.jumprelu_init_threshold,
-            "jumprelu_bandwidth": self.jumprelu_bandwidth,
-            "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
-        }
+        return self.sae.to_dict()
 
     def to_dict(self) -> dict[str, Any]:
-
-
-
-
-
-
-
+        # Make a shallow copy of config's dictionary
+        d = dict(self.__dict__)
+
+        d["logger"] = asdict(self.logger)
+        d["sae"] = self.sae.to_dict()
+        # Overwrite fields that might not be JSON-serializable
+        d["dtype"] = str(self.dtype)
+        d["device"] = str(self.device)
+        d["act_store_device"] = str(self.act_store_device)
+        return d
 
     def to_json(self, path: str) -> None:
         if not os.path.exists(os.path.dirname(path)):
@@ -497,7 +361,7 @@ class LanguageModelSAERunnerConfig:
             json.dump(self.to_dict(), f, indent=2)
 
     @classmethod
-    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
+    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig[Any]":
         with open(path + "cfg.json") as f:
             cfg = json.load(f)
 
@@ -511,6 +375,27 @@ class LanguageModelSAERunnerConfig:
 
         return cls(**cfg)
 
+    def to_sae_trainer_config(self) -> "SAETrainerConfig":
+        return SAETrainerConfig(
+            n_checkpoints=self.n_checkpoints,
+            checkpoint_path=self.checkpoint_path,
+            total_training_samples=self.total_training_tokens,
+            device=self.device,
+            autocast=self.autocast,
+            lr=self.lr,
+            lr_end=self.lr_end,
+            lr_scheduler_name=self.lr_scheduler_name,
+            lr_warm_up_steps=self.lr_warm_up_steps,
+            adam_beta1=self.adam_beta1,
+            adam_beta2=self.adam_beta2,
+            lr_decay_steps=self.lr_decay_steps,
+            n_restart_cycles=self.n_restart_cycles,
+            train_batch_size_samples=self.train_batch_size_tokens,
+            dead_feature_window=self.dead_feature_window,
+            feature_sampling_window=self.feature_sampling_window,
+            logger=self.logger,
+        )
+
 
 @dataclass
 class CacheActivationsRunnerConfig:
@@ -522,7 +407,6 @@ class CacheActivationsRunnerConfig:
         model_name (str): The name of the model to use.
         model_batch_size (int): How many prompts are in the batch of the language model when generating activations.
         hook_name (str): The name of the hook to use.
-        hook_layer (int): The layer of the final hook. Currently only support a single hook, so this should be the same as hook_name.
         d_in (int): Dimension of the model.
         total_training_tokens (int): Total number of tokens to process.
         context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized.
@@ -552,7 +436,6 @@ class CacheActivationsRunnerConfig:
     model_name: str
     model_batch_size: int
     hook_name: str
-    hook_layer: int
     d_in: int
     training_tokens: int
 
@@ -680,6 +563,10 @@ def _validate_seqpos(seqpos: tuple[int | None, ...], context_size: int) -> None:
 
 @dataclass
 class PretokenizeRunnerConfig:
+    """
+    Configuration class for pretokenizing a dataset.
+    """
+
     tokenizer_name: str = "gpt2"
     dataset_path: str = ""
     dataset_name: str | None = None
@@ -708,3 +595,28 @@ class PretokenizeRunnerConfig:
     hf_num_shards: int = 64
     hf_revision: str = "main"
     hf_is_private_repo: bool = False
+
+
+@dataclass
+class SAETrainerConfig:
+    n_checkpoints: int
+    checkpoint_path: str
+    total_training_samples: int
+    device: str
+    autocast: bool
+    lr: float
+    lr_end: float | None
+    lr_scheduler_name: str
+    lr_warm_up_steps: int
+    adam_beta1: float
+    adam_beta2: float
+    lr_decay_steps: int
+    n_restart_cycles: int
+    train_batch_size_samples: int
+    dead_feature_window: int
+    feature_sampling_window: int
+    logger: LoggingConfig
+
+    @property
+    def total_training_steps(self) -> int:
+        return self.total_training_samples // self.train_batch_size_samples
sae_lens/constants.py
ADDED
@@ -0,0 +1,21 @@
+import torch
+
+DTYPE_MAP = {
+    "float32": torch.float32,
+    "float64": torch.float64,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "torch.float32": torch.float32,
+    "torch.float64": torch.float64,
+    "torch.float16": torch.float16,
+    "torch.bfloat16": torch.bfloat16,
+}
+
+
+SPARSITY_FILENAME = "sparsity.safetensors"
+SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
+SAE_CFG_FILENAME = "cfg.json"
+RUNNER_CFG_FILENAME = "runner_cfg.json"
+SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
+ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"