sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +55 -18
- sae_lens/analysis/hooked_sae_transformer.py +10 -10
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +105 -235
- sae_lens/constants.py +20 -0
- sae_lens/evals.py +34 -31
- sae_lens/{sae_training_runner.py → llm_sae_training_runner.py} +103 -70
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +36 -10
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +70 -59
- sae_lens/saes/jumprelu_sae.py +58 -72
- sae_lens/saes/sae.py +248 -273
- sae_lens/saes/standard_sae.py +75 -57
- sae_lens/saes/topk_sae.py +72 -83
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +105 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +134 -158
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +11 -5
- sae_lens/util.py +47 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc3.dist-info/RECORD +38 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/WHEEL +1 -1
- sae_lens/regsitry.py +0 -34
- sae_lens-6.0.0rc1.dist-info/RECORD +0 -32
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/LICENSE +0 -0
sae_lens/config.py
CHANGED
```diff
@@ -3,7 +3,7 @@ import math
 import os
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
-from typing import Any, Literal, cast
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
 
 import simple_parsing
 import torch
@@ -17,24 +17,17 @@ from datasets import (
 )
 
 from sae_lens import __version__, logger
+from sae_lens.constants import DTYPE_MAP
+from sae_lens.saes.sae import TrainingSAEConfig
 
-DTYPE_MAP = {
-    "float32": torch.float32,
-    "float64": torch.float64,
-    "float16": torch.float16,
-    "bfloat16": torch.bfloat16,
-    "torch.float32": torch.float32,
-    "torch.float64": torch.float64,
-    "torch.float16": torch.float16,
-    "torch.bfloat16": torch.bfloat16,
-}
-
-HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
+if TYPE_CHECKING:
+    pass
 
+T_TRAINING_SAE_CONFIG = TypeVar(
+    "T_TRAINING_SAE_CONFIG", bound=TrainingSAEConfig, covariant=True
+)
 
-
-SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
-SAE_CFG_FILENAME = "cfg.json"
+HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
 
 
 # calling this "json_dict" so error messages will reference "json_dict" being invalid
@@ -101,101 +94,72 @@ class LoggingConfig:
 
 
 @dataclass
-class LanguageModelSAERunnerConfig:
+class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     """
     Configuration for training a sparse autoencoder on a language model.
 
     Args:
-
+        sae (T_TRAINING_SAE_CONFIG): The configuration for the SAE itself (e.g. StandardSAEConfig, GatedSAEConfig).
         model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
         model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
         hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
         hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
-
-        hook_head_index (int, optional): When the hook if for an activatio with a head index, we can specify a specific head to use here.
+        hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
         dataset_path (str): A Hugging Face dataset path.
         dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
        streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
-        is_dataset_tokenized (bool):
+        is_dataset_tokenized (bool): Whether the dataset is already tokenized.
         context_size (int): The context size to use when generating activations on which to train the SAE.
         use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations.
-        cached_activations_path (str, optional): The path to the cached activations.
-        d_in (int): The input dimension of the SAE.
-        d_sae (int, optional): The output dimension of the SAE. If None, defaults to `d_in * expansion_factor`.
-        b_dec_init_method (str): The method to use to initialize the decoder bias. Zeros is likely fine.
-        expansion_factor (int): The expansion factor. Larger is better but more computationally expensive. Default is 4.
-        activation_fn (str): The activation function to use. Relu is standard.
-        normalize_sae_decoder (bool): Whether to normalize the SAE decoder. Unit normed decoder weights used to be preferred.
-        noise_scale (float): Using noise to induce sparsity is supported but not recommended.
+        cached_activations_path (str, optional): The path to the cached activations. Defaults to "activations/{dataset_path}/{model_name}/{hook_name}_{hook_head_index}".
         from_pretrained_path (str, optional): The path to a pretrained SAE. We can finetune an existing SAE if needed.
-
-        decoder_orthogonal_init (bool): Whether to use orthogonal initialization for the decoder. Not currently advised.
-        decoder_heuristic_init (bool): Whether to use heuristic initialization for the decoder. See Anthropic April Update.
-        init_encoder_as_decoder_transpose (bool): Whether to initialize the encoder as the transpose of the decoder. See Anthropic April Update.
-        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in ram is used. The larger it is, the better shuffled the activations will be.
+        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in RAM is used. The larger it is, the better shuffled the activations will be.
         training_tokens (int): The number of training tokens.
-
-
-
-
-        seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
-        device (str): The device to use. Usually cuda.
-        act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
+        store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
+        seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
+        device (str): The device to use. Usually "cuda".
+        act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
         seed (int): The seed to use.
-        dtype (str): The data type to use.
+        dtype (str): The data type to use for the SAE and activations.
         prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
-
-
-
-
-
-
-
-
-
-        adam_beta2 (float): The beta2 parameter for Adam.
-        mse_loss_normalization (str): The normalization to use for the MSE loss.
-        l1_coefficient (float): The L1 coefficient.
-        lp_norm (float): The Lp norm.
-        scale_sparsity_penalty_by_decoder_norm (bool): Whether to scale the sparsity penalty by the decoder norm.
-        l1_warm_up_steps (int): The number of warm-up steps for the L1 loss.
+        autocast (bool): Whether to use autocast (mixed-precision) during SAE training. Saves VRAM.
+        autocast_lm (bool): Whether to use autocast (mixed-precision) during activation fetching. Saves VRAM.
+        compile_llm (bool): Whether to compile the LLM using `torch.compile`.
+        llm_compilation_mode (str, optional): The compilation mode to use for the LLM if `compile_llm` is True.
+        compile_sae (bool): Whether to compile the SAE using `torch.compile`.
+        sae_compilation_mode (str, optional): The compilation mode to use for the SAE if `compile_sae` is True.
+        train_batch_size_tokens (int): The batch size for training, in tokens. This controls the batch size of the SAE training loop.
+        adam_beta1 (float): The beta1 parameter for the Adam optimizer.
+        adam_beta2 (float): The beta2 parameter for the Adam optimizer.
         lr (float): The learning rate.
-        lr_scheduler_name (str): The name of the learning rate scheduler to use.
+        lr_scheduler_name (str): The name of the learning rate scheduler to use (e.g., "constant", "cosineannealing", "cosineannealingwarmrestarts").
         lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
-        lr_end (float): The end learning rate if
-        lr_decay_steps (int): The number of decay steps for the learning rate.
-        n_restart_cycles (int): The number of restart cycles for the cosine annealing warm restarts scheduler.
-
-
-
-
-
-
-
-
-        log_activations_store_to_wandb (bool): NOT CURRENTLY USED. Whether to log the activations store to Weights & Biases.
-        log_optimizer_state_to_wandb (bool): NOT CURRENTLY USED. Whether to log the optimizer state to Weights & Biases.
-        wandb_project (str): The Weights & Biases project to log to.
-        wandb_id (str): The Weights & Biases ID.
-        run_name (str): The name of the run.
-        wandb_entity (str): The Weights & Biases entity.
-        wandb_log_frequency (int): The frequency to log to Weights & Biases.
-        eval_every_n_wandb_logs (int): The frequency to evaluate.
-        resume (bool): Whether to resume training.
-        n_checkpoints (int): The number of checkpoints.
-        checkpoint_path (str): The path to save checkpoints.
+        lr_end (float, optional): The end learning rate if using a scheduler like cosine annealing. Defaults to `lr / 10`.
+        lr_decay_steps (int): The number of decay steps for the learning rate if using a scheduler with decay.
+        n_restart_cycles (int): The number of restart cycles for the cosine annealing with warm restarts scheduler.
+        dead_feature_window (int): The window size (in training steps) for detecting dead features.
+        feature_sampling_window (int): The window size (in training steps) for resampling features (e.g. dead features).
+        dead_feature_threshold (float): The threshold below which a feature's activation frequency is considered dead.
+        n_eval_batches (int): The number of batches to use for evaluation.
+        eval_batch_size_prompts (int, optional): The batch size for evaluation, in prompts. Useful if evals cause OOM.
+        logger (LoggingConfig): Configuration for logging (e.g. W&B).
+        n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
+        checkpoint_path (str): The path to save checkpoints. A unique ID will be appended to this path.
         verbose (bool): Whether to print verbose output.
-        model_kwargs (dict[str, Any]):
-        model_from_pretrained_kwargs (dict[str, Any]): Additional keyword arguments
-
+        model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
+        model_from_pretrained_kwargs (dict[str, Any], optional): Additional keyword arguments to pass to the model's `from_pretrained` method.
+        sae_lens_version (str): The version of the sae_lens library.
+        sae_lens_training_version (str): The version of the sae_lens training library.
+        exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations. If True, excludes all special tokens. If a list of ints, excludes those token IDs.
     """
 
+    sae: T_TRAINING_SAE_CONFIG
+
     # Data Generating Function (Model + Training Distibuion)
     model_name: str = "gelu-2l"
     model_class_name: str = "HookedTransformer"
     hook_name: str = "blocks.0.hook_mlp_out"
     hook_eval: str = "NOT_IN_USE"
-    hook_layer: int = 0
     hook_head_index: int | None = None
     dataset_path: str = ""
     dataset_trust_remote_code: bool = True
@@ -208,29 +172,12 @@ class LanguageModelSAERunnerConfig:
     )
 
     # SAE Parameters
-    architecture: Literal["standard", "gated", "jumprelu", "topk"] = "standard"
-    d_in: int = 512
-    d_sae: int | None = None
-    b_dec_init_method: str = "geometric_median"
-    expansion_factor: int | None = (
-        None  # defaults to 4 if d_sae and expansion_factor is None
-    )
-    activation_fn: str = None  # relu, tanh-relu, topk. Default is relu. # type: ignore
-    activation_fn_kwargs: dict[str, int] = dict_field(default=None)  # for topk
-    normalize_sae_decoder: bool = True
-    noise_scale: float = 0.0
     from_pretrained_path: str | None = None
-    apply_b_dec_to_input: bool = True
-    decoder_orthogonal_init: bool = False
-    decoder_heuristic_init: bool = False
-    init_encoder_as_decoder_transpose: bool = False
 
     # Activation Store Parameters
     n_batches_in_buffer: int = 20
     training_tokens: int = 2_000_000
-    finetuning_tokens: int = 0
     store_batch_size_prompts: int = 32
-    normalize_activations: str = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
     seqpos_slice: tuple[int | None, ...] = (None,)
 
     # Misc
@@ -240,10 +187,6 @@ class LanguageModelSAERunnerConfig:
     dtype: str = "float32"  # type: ignore #
     prepend_bos: bool = True
 
-    # JumpReLU Parameters
-    jumprelu_init_threshold: float = 0.001
-    jumprelu_bandwidth: float = 0.001
-
     # Performance - see compilation section of lm_runner.py for info
     autocast: bool = False  # autocast to autocast_dtype during training
     autocast_lm: bool = False  # autocast lm during activation fetching
@@ -261,13 +204,6 @@ class LanguageModelSAERunnerConfig:
     adam_beta1: float = 0.0
     adam_beta2: float = 0.999
 
-    ## Loss Function
-    mse_loss_normalization: str | None = None
-    l1_coefficient: float = 1e-3
-    lp_norm: float = 1
-    scale_sparsity_penalty_by_decoder_norm: bool = False
-    l1_warm_up_steps: int = 0
-
     ## Learning Rate Schedule
     lr: float = 3e-4
     lr_scheduler_name: str = (
@@ -278,14 +214,9 @@ class LanguageModelSAERunnerConfig:
     lr_decay_steps: int = 0
     n_restart_cycles: int = 1  # used only for cosineannealingwarmrestarts
 
-    ## FineTuning
-    finetuning_method: str | None = None  # scale, decoder or unrotated_decoder
-
     # Resampling protocol args
-    use_ghost_grads: bool = False  # want to change this to true on some timeline.
-    feature_sampling_window: int = 2000
     dead_feature_window: int = 1000  # unless this window is larger feature sampling,
-
+    feature_sampling_window: int = 2000
     dead_feature_threshold: float = 1e-8
 
     # Evals
@@ -295,7 +226,6 @@ class LanguageModelSAERunnerConfig:
     logger: LoggingConfig = field(default_factory=LoggingConfig)
 
     # Misc
-    resume: bool = False
     n_checkpoints: int = 0
     checkpoint_path: str = "checkpoints"
     verbose: bool = True
@@ -306,12 +236,6 @@ class LanguageModelSAERunnerConfig:
     exclude_special_tokens: bool | list[int] = False
 
     def __post_init__(self):
-        if self.resume:
-            raise ValueError(
-                "Resuming is no longer supported. You can finetune a trained SAE using cfg.from_pretrained path."
-                + "If you want to load an SAE with resume=True in the config, please manually set resume=False in that config."
-            )
-
         if self.use_cached_activations and self.cached_activations_path is None:
             self.cached_activations_path = _default_cached_activations_path(
                 self.dataset_path,
@@ -319,37 +243,12 @@ class LanguageModelSAERunnerConfig:
                 self.hook_name,
                 self.hook_head_index,
             )
-
-        if self.activation_fn is None:
-            self.activation_fn = "topk" if self.architecture == "topk" else "relu"
-
-        if self.architecture == "topk" and self.activation_fn != "topk":
-            raise ValueError("If using topk architecture, activation_fn must be topk.")
-
-        if self.activation_fn_kwargs is None:
-            self.activation_fn_kwargs = (
-                {"k": 100} if self.activation_fn == "topk" else {}
-            )
-
-        if self.architecture == "topk" and self.activation_fn_kwargs.get("k") is None:
-            raise ValueError(
-                "activation_fn_kwargs.k must be provided for topk architecture."
-            )
-
-        if self.d_sae is not None and self.expansion_factor is not None:
-            raise ValueError("You can't set both d_sae and expansion_factor.")
-
-        if self.d_sae is None and self.expansion_factor is None:
-            self.expansion_factor = 4
-
-        if self.d_sae is None and self.expansion_factor is not None:
-            self.d_sae = self.d_in * self.expansion_factor
         self.tokens_per_buffer = (
             self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
         )
 
         if self.logger.run_name is None:
-            self.logger.run_name = f"{self.
+            self.logger.run_name = f"{self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
 
         if self.model_from_pretrained_kwargs is None:
             if self.model_class_name == "HookedTransformer":
@@ -357,37 +256,6 @@ class LanguageModelSAERunnerConfig:
             else:
                 self.model_from_pretrained_kwargs = {}
 
-        if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
-            raise ValueError(
-                f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}"
-            )
-
-        if self.normalize_sae_decoder and self.decoder_heuristic_init:
-            raise ValueError(
-                "You can't normalize the decoder and use heuristic initialization."
-            )
-
-        if self.normalize_sae_decoder and self.scale_sparsity_penalty_by_decoder_norm:
-            raise ValueError(
-                "Weighting loss by decoder norm makes no sense if you are normalizing the decoder weight norms to 1"
-            )
-
-        # if we use decoder fine tuning, we can't be applying b_dec to the input
-        if (self.finetuning_method == "decoder") and (self.apply_b_dec_to_input):
-            raise ValueError(
-                "If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False."
-            )
-
-        if self.normalize_activations not in [
-            "none",
-            "expected_average_only_in",
-            "constant_norm_rescale",
-            "layer_norm",
-        ]:
-            raise ValueError(
-                f"normalize_activations must be none, layer_norm, expected_average_only_in, or constant_norm_rescale. Got {self.normalize_activations}"
-            )
-
         if self.act_store_device == "with_model":
             self.act_store_device = self.device
 
@@ -403,7 +271,7 @@ class LanguageModelSAERunnerConfig:
 
         if self.verbose:
             logger.info(
-                f"Run name: {self.
+                f"Run name: {self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
             )
             # Print out some useful info:
             n_tokens_per_buffer = (
@@ -422,7 +290,7 @@ class LanguageModelSAERunnerConfig:
             )
 
             total_training_steps = (
-                self.training_tokens
+                self.training_tokens
            ) // self.train_batch_size_tokens
             logger.info(f"Total training steps: {total_training_steps}")
 
@@ -450,9 +318,6 @@ class LanguageModelSAERunnerConfig:
                 f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
             )
 
-        if self.use_ghost_grads:
-            logger.info("Using Ghost Grads.")
-
         if self.context_size < 0:
             raise ValueError(
                 f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
@@ -467,62 +332,21 @@ class LanguageModelSAERunnerConfig:
 
     @property
     def total_training_tokens(self) -> int:
-        return self.training_tokens
+        return self.training_tokens
 
     @property
     def total_training_steps(self) -> int:
         return self.total_training_tokens // self.train_batch_size_tokens
 
-    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
-        return {
-            # TEMP
-            "architecture": self.architecture,
-            "d_in": self.d_in,
-            "d_sae": self.d_sae,
-            "dtype": self.dtype,
-            "device": self.device,
-            "model_name": self.model_name,
-            "hook_name": self.hook_name,
-            "hook_layer": self.hook_layer,
-            "hook_head_index": self.hook_head_index,
-            "activation_fn": self.activation_fn,
-            "apply_b_dec_to_input": self.apply_b_dec_to_input,
-            "context_size": self.context_size,
-            "prepend_bos": self.prepend_bos,
-            "dataset_path": self.dataset_path,
-            "dataset_trust_remote_code": self.dataset_trust_remote_code,
-            "finetuning_scaling_factor": self.finetuning_method is not None,
-            "sae_lens_training_version": self.sae_lens_training_version,
-            "normalize_activations": self.normalize_activations,
-            "activation_fn_kwargs": self.activation_fn_kwargs,
-            "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
-            "seqpos_slice": self.seqpos_slice,
-        }
-
     def get_training_sae_cfg_dict(self) -> dict[str, Any]:
-        return
-            **self.get_base_sae_cfg_dict(),
-            "l1_coefficient": self.l1_coefficient,
-            "lp_norm": self.lp_norm,
-            "use_ghost_grads": self.use_ghost_grads,
-            "normalize_sae_decoder": self.normalize_sae_decoder,
-            "noise_scale": self.noise_scale,
-            "decoder_orthogonal_init": self.decoder_orthogonal_init,
-            "mse_loss_normalization": self.mse_loss_normalization,
-            "decoder_heuristic_init": self.decoder_heuristic_init,
-            "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
-            "normalize_activations": self.normalize_activations,
-            "jumprelu_init_threshold": self.jumprelu_init_threshold,
-            "jumprelu_bandwidth": self.jumprelu_bandwidth,
-            "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
-        }
+        return self.sae.to_dict()
 
     def to_dict(self) -> dict[str, Any]:
-        # Make a shallow copy of config
+        # Make a shallow copy of config's dictionary
         d = dict(self.__dict__)
 
         d["logger"] = asdict(self.logger)
-
+        d["sae"] = self.sae.to_dict()
         # Overwrite fields that might not be JSON-serializable
         d["dtype"] = str(self.dtype)
         d["device"] = str(self.device)
@@ -537,7 +361,7 @@ class LanguageModelSAERunnerConfig:
             json.dump(self.to_dict(), f, indent=2)
 
     @classmethod
-    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
+    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig[Any]":
         with open(path + "cfg.json") as f:
             cfg = json.load(f)
 
@@ -551,6 +375,28 @@ class LanguageModelSAERunnerConfig:
 
         return cls(**cfg)
 
+    def to_sae_trainer_config(self) -> "SAETrainerConfig":
+        return SAETrainerConfig(
+            n_checkpoints=self.n_checkpoints,
+            checkpoint_path=self.checkpoint_path,
+            total_training_samples=self.total_training_tokens,
+            device=self.device,
+            autocast=self.autocast,
+            lr=self.lr,
+            lr_end=self.lr_end,
+            lr_scheduler_name=self.lr_scheduler_name,
+            lr_warm_up_steps=self.lr_warm_up_steps,
+            adam_beta1=self.adam_beta1,
+            adam_beta2=self.adam_beta2,
+            lr_decay_steps=self.lr_decay_steps,
+            n_restart_cycles=self.n_restart_cycles,
+            total_training_steps=self.total_training_steps,
+            train_batch_size_samples=self.train_batch_size_tokens,
+            dead_feature_window=self.dead_feature_window,
+            feature_sampling_window=self.feature_sampling_window,
+            logger=self.logger,
+        )
+
 
 @dataclass
 class CacheActivationsRunnerConfig:
@@ -562,7 +408,6 @@ class CacheActivationsRunnerConfig:
         model_name (str): The name of the model to use.
         model_batch_size (int): How many prompts are in the batch of the language model when generating activations.
         hook_name (str): The name of the hook to use.
-        hook_layer (int): The layer of the final hook. Currently only support a single hook, so this should be the same as hook_name.
         d_in (int): Dimension of the model.
         total_training_tokens (int): Total number of tokens to process.
         context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized.
@@ -592,7 +437,6 @@ class CacheActivationsRunnerConfig:
     model_name: str
     model_batch_size: int
     hook_name: str
-    hook_layer: int
    d_in: int
     training_tokens: int
 
@@ -720,6 +564,10 @@ def _validate_seqpos(seqpos: tuple[int | None, ...], context_size: int) -> None:
 
 @dataclass
 class PretokenizeRunnerConfig:
+    """
+    Configuration class for pretokenizing a dataset.
+    """
+
     tokenizer_name: str = "gpt2"
     dataset_path: str = ""
     dataset_name: str | None = None
@@ -748,3 +596,25 @@ class PretokenizeRunnerConfig:
     hf_num_shards: int = 64
     hf_revision: str = "main"
     hf_is_private_repo: bool = False
+
+
+@dataclass
+class SAETrainerConfig:
+    n_checkpoints: int
+    checkpoint_path: str
+    total_training_samples: int
+    device: str
+    autocast: bool
+    lr: float
+    lr_end: float | None
+    lr_scheduler_name: str
+    lr_warm_up_steps: int
+    adam_beta1: float
+    adam_beta2: float
+    lr_decay_steps: int
+    n_restart_cycles: int
+    total_training_steps: int
+    train_batch_size_samples: int
+    dead_feature_window: int
+    feature_sampling_window: int
+    logger: LoggingConfig
```
sae_lens/constants.py
ADDED
```diff
@@ -0,0 +1,20 @@
+import torch
+
+DTYPE_MAP = {
+    "float32": torch.float32,
+    "float64": torch.float64,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "torch.float32": torch.float32,
+    "torch.float64": torch.float64,
+    "torch.float16": torch.float16,
+    "torch.bfloat16": torch.bfloat16,
+}
+
+
+SPARSITY_FILENAME = "sparsity.safetensors"
+SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
+SAE_CFG_FILENAME = "cfg.json"
+RUNNER_CFG_FILENAME = "runner_cfg.json"
+ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
+ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
```