sae-lens 5.10.7__py3-none-any.whl → 6.0.0__py3-none-any.whl
- sae_lens/__init__.py +60 -7
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +16 -14
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +170 -257
- sae_lens/constants.py +21 -0
- sae_lens/evals.py +59 -44
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +53 -5
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +228 -32
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +254 -0
- sae_lens/saes/jumprelu_sae.py +348 -0
- sae_lens/saes/sae.py +1076 -0
- sae_lens/saes/standard_sae.py +178 -0
- sae_lens/saes/topk_sae.py +300 -0
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +103 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +155 -177
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +13 -7
- sae_lens/util.py +47 -0
- {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
- sae_lens-6.0.0.dist-info/RECORD +37 -0
- sae_lens/sae.py +0 -747
- sae_lens/sae_training_runner.py +0 -251
- sae_lens/training/geometric_median.py +0 -101
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.10.7.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
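The list above amounts to a repackaging of the library: the monolithic `sae_lens/sae.py` and `sae_lens/training/training_sae.py` are replaced by per-architecture modules under `sae_lens/saes/`, the `toolkit` package is renamed to `loading`, and the training entry point moves from `sae_training_runner.py` to `llm_sae_training_runner.py`. A minimal import-migration sketch follows; only the `sae_lens.loading` path and `TrainingSAEConfig` (imported in the config.py diff below) are confirmed by this diff, so the remaining names are assumptions.

```python
# 5.10.7 layout (removed in 6.0.0)
# from sae_lens.toolkit import pretrained_sae_loaders
# from sae_lens.sae_training_runner import SAETrainingRunner  # assumed 5.x export

# 6.0.0 layout
from sae_lens.loading import pretrained_sae_loaders  # toolkit/ -> loading/
from sae_lens.saes.sae import TrainingSAEConfig      # sae.py split into saes/*.py
# from sae_lens import llm_sae_training_runner       # new training entry point; exports not shown here
```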
sae_lens/config.py
CHANGED
```diff
@@ -1,8 +1,9 @@
 import json
 import math
 import os
-from dataclasses import dataclass, field
-from
+from dataclasses import asdict, dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
 
 import simple_parsing
 import torch
@@ -16,24 +17,17 @@ from datasets import (
 )
 
 from sae_lens import __version__, logger
+from sae_lens.constants import DTYPE_MAP
+from sae_lens.saes.sae import TrainingSAEConfig
 
-
-
-    "float64": torch.float64,
-    "float16": torch.float16,
-    "bfloat16": torch.bfloat16,
-    "torch.float32": torch.float32,
-    "torch.float64": torch.float64,
-    "torch.float16": torch.float16,
-    "torch.bfloat16": torch.bfloat16,
-}
-
-HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
+if TYPE_CHECKING:
+    pass
 
+T_TRAINING_SAE_CONFIG = TypeVar(
+    "T_TRAINING_SAE_CONFIG", bound=TrainingSAEConfig, covariant=True
+)
 
-
-SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
-SAE_CFG_FILENAME = "cfg.json"
+HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
 
 
 # calling this "json_dict" so error messages will reference "json_dict" being invalid
@@ -54,101 +48,118 @@ def dict_field(default: dict[str, Any] | None, **kwargs: Any) -> Any: # type: i
 
 
 @dataclass
-class
+class LoggingConfig:
+    # WANDB
+    log_to_wandb: bool = True
+    log_activations_store_to_wandb: bool = False
+    log_optimizer_state_to_wandb: bool = False
+    wandb_project: str = "sae_lens_training"
+    wandb_id: str | None = None
+    run_name: str | None = None
+    wandb_entity: str | None = None
+    wandb_log_frequency: int = 10
+    eval_every_n_wandb_logs: int = 100  # logs every 100 steps.
+
+    def log(
+        self,
+        trainer: Any,  # avoid import cycle from importing SAETrainer
+        weights_path: Path | str,
+        cfg_path: Path | str,
+        sparsity_path: Path | str | None,
+        wandb_aliases: list[str] | None = None,
+    ) -> None:
+        # Avoid wandb saving errors such as:
+        # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
+        sae_name = trainer.sae.get_name().replace("/", "__")
+
+        # save model weights and cfg
+        model_artifact = wandb.Artifact(
+            sae_name,
+            type="model",
+            metadata=dict(trainer.cfg.__dict__),
+        )
+        model_artifact.add_file(str(weights_path))
+        model_artifact.add_file(str(cfg_path))
+        wandb.log_artifact(model_artifact, aliases=wandb_aliases)
+
+        # save log feature sparsity
+        sparsity_artifact = wandb.Artifact(
+            f"{sae_name}_log_feature_sparsity",
+            type="log_feature_sparsity",
+            metadata=dict(trainer.cfg.__dict__),
+        )
+        if sparsity_path is not None:
+            sparsity_artifact.add_file(str(sparsity_path))
+        wandb.log_artifact(sparsity_artifact)
+
+
+@dataclass
+class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     """
     Configuration for training a sparse autoencoder on a language model.
 
     Args:
-
+        sae (T_TRAINING_SAE_CONFIG): The configuration for the SAE itself (e.g. StandardSAEConfig, GatedSAEConfig).
         model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
         model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
         hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
         hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
-
-        hook_head_index (int, optional): When the hook if for an activatio with a head index, we can specify a specific head to use here.
+        hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
        dataset_path (str): A Hugging Face dataset path.
        dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
        streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
-        is_dataset_tokenized (bool):
+        is_dataset_tokenized (bool): Whether the dataset is already tokenized.
        context_size (int): The context size to use when generating activations on which to train the SAE.
        use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations.
-        cached_activations_path (str, optional): The path to the cached activations.
-        d_in (int): The input dimension of the SAE.
-        d_sae (int, optional): The output dimension of the SAE. If None, defaults to `d_in * expansion_factor`.
-        b_dec_init_method (str): The method to use to initialize the decoder bias. Zeros is likely fine.
-        expansion_factor (int): The expansion factor. Larger is better but more computationally expensive. Default is 4.
-        activation_fn (str): The activation function to use. Relu is standard.
-        normalize_sae_decoder (bool): Whether to normalize the SAE decoder. Unit normed decoder weights used to be preferred.
-        noise_scale (float): Using noise to induce sparsity is supported but not recommended.
+        cached_activations_path (str, optional): The path to the cached activations. Defaults to "activations/{dataset_path}/{model_name}/{hook_name}_{hook_head_index}".
        from_pretrained_path (str, optional): The path to a pretrained SAE. We can finetune an existing SAE if needed.
-
-        decoder_orthogonal_init (bool): Whether to use orthogonal initialization for the decoder. Not currently advised.
-        decoder_heuristic_init (bool): Whether to use heuristic initialization for the decoder. See Anthropic April Update.
-        init_encoder_as_decoder_transpose (bool): Whether to initialize the encoder as the transpose of the decoder. See Anthropic April Update.
-        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in ram is used. The larger it is, the better shuffled the activations will be.
+        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in RAM is used. The larger it is, the better shuffled the activations will be.
        training_tokens (int): The number of training tokens.
-
-
-
-
-        seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
-        device (str): The device to use. Usually cuda.
-        act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
+        store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
+        seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
+        device (str): The device to use. Usually "cuda".
+        act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
        seed (int): The seed to use.
-        dtype (str): The data type to use.
+        dtype (str): The data type to use for the SAE and activations.
        prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
-
-
-
-
-
-
-
-
-
-        adam_beta2 (float): The beta2 parameter for Adam.
-        mse_loss_normalization (str): The normalization to use for the MSE loss.
-        l1_coefficient (float): The L1 coefficient.
-        lp_norm (float): The Lp norm.
-        scale_sparsity_penalty_by_decoder_norm (bool): Whether to scale the sparsity penalty by the decoder norm.
-        l1_warm_up_steps (int): The number of warm-up steps for the L1 loss.
+        autocast (bool): Whether to use autocast (mixed-precision) during SAE training. Saves VRAM.
+        autocast_lm (bool): Whether to use autocast (mixed-precision) during activation fetching. Saves VRAM.
+        compile_llm (bool): Whether to compile the LLM using `torch.compile`.
+        llm_compilation_mode (str, optional): The compilation mode to use for the LLM if `compile_llm` is True.
+        compile_sae (bool): Whether to compile the SAE using `torch.compile`.
+        sae_compilation_mode (str, optional): The compilation mode to use for the SAE if `compile_sae` is True.
+        train_batch_size_tokens (int): The batch size for training, in tokens. This controls the batch size of the SAE training loop.
+        adam_beta1 (float): The beta1 parameter for the Adam optimizer.
+        adam_beta2 (float): The beta2 parameter for the Adam optimizer.
        lr (float): The learning rate.
-        lr_scheduler_name (str): The name of the learning rate scheduler to use.
+        lr_scheduler_name (str): The name of the learning rate scheduler to use (e.g., "constant", "cosineannealing", "cosineannealingwarmrestarts").
        lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
-        lr_end (float): The end learning rate if
-        lr_decay_steps (int): The number of decay steps for the learning rate.
-        n_restart_cycles (int): The number of restart cycles for the cosine annealing warm restarts scheduler.
-
-
-
-
-
-
-
-
-        log_activations_store_to_wandb (bool): NOT CURRENTLY USED. Whether to log the activations store to Weights & Biases.
-        log_optimizer_state_to_wandb (bool): NOT CURRENTLY USED. Whether to log the optimizer state to Weights & Biases.
-        wandb_project (str): The Weights & Biases project to log to.
-        wandb_id (str): The Weights & Biases ID.
-        run_name (str): The name of the run.
-        wandb_entity (str): The Weights & Biases entity.
-        wandb_log_frequency (int): The frequency to log to Weights & Biases.
-        eval_every_n_wandb_logs (int): The frequency to evaluate.
-        resume (bool): Whether to resume training.
-        n_checkpoints (int): The number of checkpoints.
-        checkpoint_path (str): The path to save checkpoints.
+        lr_end (float, optional): The end learning rate if using a scheduler like cosine annealing. Defaults to `lr / 10`.
+        lr_decay_steps (int): The number of decay steps for the learning rate if using a scheduler with decay.
+        n_restart_cycles (int): The number of restart cycles for the cosine annealing with warm restarts scheduler.
+        dead_feature_window (int): The window size (in training steps) for detecting dead features.
+        feature_sampling_window (int): The window size (in training steps) for resampling features (e.g. dead features).
+        dead_feature_threshold (float): The threshold below which a feature's activation frequency is considered dead.
+        n_eval_batches (int): The number of batches to use for evaluation.
+        eval_batch_size_prompts (int, optional): The batch size for evaluation, in prompts. Useful if evals cause OOM.
+        logger (LoggingConfig): Configuration for logging (e.g. W&B).
+        n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
+        checkpoint_path (str): The path to save checkpoints. A unique ID will be appended to this path.
        verbose (bool): Whether to print verbose output.
-        model_kwargs (dict[str, Any]):
-        model_from_pretrained_kwargs (dict[str, Any]): Additional keyword arguments
-
+        model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
+        model_from_pretrained_kwargs (dict[str, Any], optional): Additional keyword arguments to pass to the model's `from_pretrained` method.
+        sae_lens_version (str): The version of the sae_lens library.
+        sae_lens_training_version (str): The version of the sae_lens training library.
+        exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations. If True, excludes all special tokens. If a list of ints, excludes those token IDs.
    """
 
+    sae: T_TRAINING_SAE_CONFIG
+
    # Data Generating Function (Model + Training Distibuion)
    model_name: str = "gelu-2l"
    model_class_name: str = "HookedTransformer"
    hook_name: str = "blocks.0.hook_mlp_out"
    hook_eval: str = "NOT_IN_USE"
-    hook_layer: int = 0
    hook_head_index: int | None = None
    dataset_path: str = ""
    dataset_trust_remote_code: bool = True
@@ -161,30 +172,12 @@ class LanguageModelSAERunnerConfig:
    )
 
    # SAE Parameters
-    architecture: Literal["standard", "gated", "jumprelu", "topk"] = "standard"
-    d_in: int = 512
-    d_sae: int | None = None
-    b_dec_init_method: str = "geometric_median"
-    expansion_factor: int | None = (
-        None  # defaults to 4 if d_sae and expansion_factor is None
-    )
-    activation_fn: str = None  # relu, tanh-relu, topk. Default is relu.  # type: ignore
-    activation_fn_kwargs: dict[str, int] = dict_field(default=None)  # for topk
-    normalize_sae_decoder: bool = True
-    noise_scale: float = 0.0
    from_pretrained_path: str | None = None
-    apply_b_dec_to_input: bool = True
-    decoder_orthogonal_init: bool = False
-    decoder_heuristic_init: bool = False
-    decoder_heuristic_init_norm: float = 0.1
-    init_encoder_as_decoder_transpose: bool = False
 
    # Activation Store Parameters
    n_batches_in_buffer: int = 20
    training_tokens: int = 2_000_000
-    finetuning_tokens: int = 0
    store_batch_size_prompts: int = 32
-    normalize_activations: str = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
    seqpos_slice: tuple[int | None, ...] = (None,)
 
    # Misc
@@ -194,10 +187,6 @@ class LanguageModelSAERunnerConfig:
    dtype: str = "float32"  # type: ignore #
    prepend_bos: bool = True
 
-    # JumpReLU Parameters
-    jumprelu_init_threshold: float = 0.001
-    jumprelu_bandwidth: float = 0.001
-
    # Performance - see compilation section of lm_runner.py for info
    autocast: bool = False  # autocast to autocast_dtype during training
    autocast_lm: bool = False  # autocast lm during activation fetching
@@ -212,16 +201,9 @@ class LanguageModelSAERunnerConfig:
    train_batch_size_tokens: int = 4096
 
    ## Adam
-    adam_beta1: float = 0.
+    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
 
-    ## Loss Function
-    mse_loss_normalization: str | None = None
-    l1_coefficient: float = 1e-3
-    lp_norm: float = 1
-    scale_sparsity_penalty_by_decoder_norm: bool = False
-    l1_warm_up_steps: int = 0
-
    ## Learning Rate Schedule
    lr: float = 3e-4
    lr_scheduler_name: str = (
@@ -232,33 +214,18 @@ class LanguageModelSAERunnerConfig:
    lr_decay_steps: int = 0
    n_restart_cycles: int = 1  # used only for cosineannealingwarmrestarts
 
-    ## FineTuning
-    finetuning_method: str | None = None  # scale, decoder or unrotated_decoder
-
    # Resampling protocol args
-    use_ghost_grads: bool = False  # want to change this to true on some timeline.
-    feature_sampling_window: int = 2000
    dead_feature_window: int = 1000  # unless this window is larger feature sampling,
-
+    feature_sampling_window: int = 2000
    dead_feature_threshold: float = 1e-8
 
    # Evals
    n_eval_batches: int = 10
    eval_batch_size_prompts: int | None = None  # useful if evals cause OOM
 
-
-    log_to_wandb: bool = True
-    log_activations_store_to_wandb: bool = False
-    log_optimizer_state_to_wandb: bool = False
-    wandb_project: str = "mats_sae_training_language_model"
-    wandb_id: str | None = None
-    run_name: str | None = None
-    wandb_entity: str | None = None
-    wandb_log_frequency: int = 10
-    eval_every_n_wandb_logs: int = 100  # logs every 1000 steps.
+    logger: LoggingConfig = field(default_factory=LoggingConfig)
 
    # Misc
-    resume: bool = False
    n_checkpoints: int = 0
    checkpoint_path: str = "checkpoints"
    verbose: bool = True
@@ -269,12 +236,6 @@ class LanguageModelSAERunnerConfig:
    exclude_special_tokens: bool | list[int] = False
 
    def __post_init__(self):
-        if self.resume:
-            raise ValueError(
-                "Resuming is no longer supported. You can finetune a trained SAE using cfg.from_pretrained path."
-                + "If you want to load an SAE with resume=True in the config, please manually set resume=False in that config."
-            )
-
        if self.use_cached_activations and self.cached_activations_path is None:
            self.cached_activations_path = _default_cached_activations_path(
                self.dataset_path,
@@ -282,37 +243,12 @@ class LanguageModelSAERunnerConfig:
                self.hook_name,
                self.hook_head_index,
            )
-
-        if self.activation_fn is None:
-            self.activation_fn = "topk" if self.architecture == "topk" else "relu"
-
-        if self.architecture == "topk" and self.activation_fn != "topk":
-            raise ValueError("If using topk architecture, activation_fn must be topk.")
-
-        if self.activation_fn_kwargs is None:
-            self.activation_fn_kwargs = (
-                {"k": 100} if self.activation_fn == "topk" else {}
-            )
-
-        if self.architecture == "topk" and self.activation_fn_kwargs.get("k") is None:
-            raise ValueError(
-                "activation_fn_kwargs.k must be provided for topk architecture."
-            )
-
-        if self.d_sae is not None and self.expansion_factor is not None:
-            raise ValueError("You can't set both d_sae and expansion_factor.")
-
-        if self.d_sae is None and self.expansion_factor is None:
-            self.expansion_factor = 4
-
-        if self.d_sae is None and self.expansion_factor is not None:
-            self.d_sae = self.d_in * self.expansion_factor
        self.tokens_per_buffer = (
            self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
        )
 
-        if self.run_name is None:
-            self.run_name = f"{self.
+        if self.logger.run_name is None:
+            self.logger.run_name = f"{self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
 
        if self.model_from_pretrained_kwargs is None:
            if self.model_class_name == "HookedTransformer":
@@ -320,44 +256,13 @@ class LanguageModelSAERunnerConfig:
            else:
                self.model_from_pretrained_kwargs = {}
 
-        if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
-            raise ValueError(
-                f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}"
-            )
-
-        if self.normalize_sae_decoder and self.decoder_heuristic_init:
-            raise ValueError(
-                "You can't normalize the decoder and use heuristic initialization."
-            )
-
-        if self.normalize_sae_decoder and self.scale_sparsity_penalty_by_decoder_norm:
-            raise ValueError(
-                "Weighting loss by decoder norm makes no sense if you are normalizing the decoder weight norms to 1"
-            )
-
-        # if we use decoder fine tuning, we can't be applying b_dec to the input
-        if (self.finetuning_method == "decoder") and (self.apply_b_dec_to_input):
-            raise ValueError(
-                "If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False."
-            )
-
-        if self.normalize_activations not in [
-            "none",
-            "expected_average_only_in",
-            "constant_norm_rescale",
-            "layer_norm",
-        ]:
-            raise ValueError(
-                f"normalize_activations must be none, layer_norm, expected_average_only_in, or constant_norm_rescale. Got {self.normalize_activations}"
-            )
-
        if self.act_store_device == "with_model":
            self.act_store_device = self.device
 
        if self.lr_end is None:
            self.lr_end = self.lr / 10
 
-        unique_id = self.wandb_id
+        unique_id = self.logger.wandb_id
        if unique_id is None:
            unique_id = cast(
                Any, wandb
@@ -366,7 +271,7 @@ class LanguageModelSAERunnerConfig:
 
        if self.verbose:
            logger.info(
-                f"Run name: {self.
+                f"Run name: {self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
            )
            # Print out some useful info:
            n_tokens_per_buffer = (
@@ -385,11 +290,13 @@ class LanguageModelSAERunnerConfig:
            )
 
            total_training_steps = (
-                self.training_tokens
+                self.training_tokens
            ) // self.train_batch_size_tokens
            logger.info(f"Total training steps: {total_training_steps}")
 
-            total_wandb_updates =
+            total_wandb_updates = (
+                total_training_steps // self.logger.wandb_log_frequency
+            )
            logger.info(f"Total wandb updates: {total_wandb_updates}")
 
            # how many times will we sample dead neurons?
@@ -411,9 +318,6 @@ class LanguageModelSAERunnerConfig:
                f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
            )
 
-        if self.use_ghost_grads:
-            logger.info("Using Ghost Grads.")
-
        if self.context_size < 0:
            raise ValueError(
                f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
@@ -428,65 +332,26 @@ class LanguageModelSAERunnerConfig:
 
    @property
    def total_training_tokens(self) -> int:
-        return self.training_tokens
+        return self.training_tokens
 
    @property
    def total_training_steps(self) -> int:
        return self.total_training_tokens // self.train_batch_size_tokens
 
-    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
-        return {
-            # TEMP
-            "architecture": self.architecture,
-            "d_in": self.d_in,
-            "d_sae": self.d_sae,
-            "dtype": self.dtype,
-            "device": self.device,
-            "model_name": self.model_name,
-            "hook_name": self.hook_name,
-            "hook_layer": self.hook_layer,
-            "hook_head_index": self.hook_head_index,
-            "activation_fn_str": self.activation_fn,
-            "apply_b_dec_to_input": self.apply_b_dec_to_input,
-            "context_size": self.context_size,
-            "prepend_bos": self.prepend_bos,
-            "dataset_path": self.dataset_path,
-            "dataset_trust_remote_code": self.dataset_trust_remote_code,
-            "finetuning_scaling_factor": self.finetuning_method is not None,
-            "sae_lens_training_version": self.sae_lens_training_version,
-            "normalize_activations": self.normalize_activations,
-            "activation_fn_kwargs": self.activation_fn_kwargs,
-            "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
-            "seqpos_slice": self.seqpos_slice,
-        }
-
    def get_training_sae_cfg_dict(self) -> dict[str, Any]:
-        return
-            **self.get_base_sae_cfg_dict(),
-            "l1_coefficient": self.l1_coefficient,
-            "lp_norm": self.lp_norm,
-            "use_ghost_grads": self.use_ghost_grads,
-            "normalize_sae_decoder": self.normalize_sae_decoder,
-            "noise_scale": self.noise_scale,
-            "decoder_orthogonal_init": self.decoder_orthogonal_init,
-            "mse_loss_normalization": self.mse_loss_normalization,
-            "decoder_heuristic_init": self.decoder_heuristic_init,
-            "decoder_heuristic_init_norm": self.decoder_heuristic_init_norm,
-            "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
-            "normalize_activations": self.normalize_activations,
-            "jumprelu_init_threshold": self.jumprelu_init_threshold,
-            "jumprelu_bandwidth": self.jumprelu_bandwidth,
-            "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
-        }
+        return self.sae.to_dict()
 
    def to_dict(self) -> dict[str, Any]:
-
-
-
-
-
-
-
+        # Make a shallow copy of config's dictionary
+        d = dict(self.__dict__)
+
+        d["logger"] = asdict(self.logger)
+        d["sae"] = self.sae.to_dict()
+        # Overwrite fields that might not be JSON-serializable
+        d["dtype"] = str(self.dtype)
+        d["device"] = str(self.device)
+        d["act_store_device"] = str(self.act_store_device)
+        return d
 
    def to_json(self, path: str) -> None:
        if not os.path.exists(os.path.dirname(path)):
@@ -496,7 +361,7 @@ class LanguageModelSAERunnerConfig:
            json.dump(self.to_dict(), f, indent=2)
 
    @classmethod
-    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
+    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig[Any]":
        with open(path + "cfg.json") as f:
            cfg = json.load(f)
 
@@ -510,6 +375,27 @@ class LanguageModelSAERunnerConfig:
 
        return cls(**cfg)
 
+    def to_sae_trainer_config(self) -> "SAETrainerConfig":
+        return SAETrainerConfig(
+            n_checkpoints=self.n_checkpoints,
+            checkpoint_path=self.checkpoint_path,
+            total_training_samples=self.total_training_tokens,
+            device=self.device,
+            autocast=self.autocast,
+            lr=self.lr,
+            lr_end=self.lr_end,
+            lr_scheduler_name=self.lr_scheduler_name,
+            lr_warm_up_steps=self.lr_warm_up_steps,
+            adam_beta1=self.adam_beta1,
+            adam_beta2=self.adam_beta2,
+            lr_decay_steps=self.lr_decay_steps,
+            n_restart_cycles=self.n_restart_cycles,
+            train_batch_size_samples=self.train_batch_size_tokens,
+            dead_feature_window=self.dead_feature_window,
+            feature_sampling_window=self.feature_sampling_window,
+            logger=self.logger,
+        )
+
 
 @dataclass
 class CacheActivationsRunnerConfig:
@@ -521,7 +407,6 @@ class CacheActivationsRunnerConfig:
        model_name (str): The name of the model to use.
        model_batch_size (int): How many prompts are in the batch of the language model when generating activations.
        hook_name (str): The name of the hook to use.
-        hook_layer (int): The layer of the final hook. Currently only support a single hook, so this should be the same as hook_name.
        d_in (int): Dimension of the model.
        total_training_tokens (int): Total number of tokens to process.
        context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized.
@@ -551,7 +436,6 @@ class CacheActivationsRunnerConfig:
    model_name: str
    model_batch_size: int
    hook_name: str
-    hook_layer: int
    d_in: int
    training_tokens: int
 
@@ -679,6 +563,10 @@ def _validate_seqpos(seqpos: tuple[int | None, ...], context_size: int) -> None:
 
 @dataclass
 class PretokenizeRunnerConfig:
+    """
+    Configuration class for pretokenizing a dataset.
+    """
+
    tokenizer_name: str = "gpt2"
    dataset_path: str = ""
    dataset_name: str | None = None
@@ -707,3 +595,28 @@ class PretokenizeRunnerConfig:
    hf_num_shards: int = 64
    hf_revision: str = "main"
    hf_is_private_repo: bool = False
+
+
+@dataclass
+class SAETrainerConfig:
+    n_checkpoints: int
+    checkpoint_path: str
+    total_training_samples: int
+    device: str
+    autocast: bool
+    lr: float
+    lr_end: float | None
+    lr_scheduler_name: str
+    lr_warm_up_steps: int
+    adam_beta1: float
+    adam_beta2: float
+    lr_decay_steps: int
+    n_restart_cycles: int
+    train_batch_size_samples: int
+    dead_feature_window: int
+    feature_sampling_window: int
+    logger: LoggingConfig
+
+    @property
+    def total_training_steps(self) -> int:
+        return self.total_training_samples // self.train_batch_size_samples
```
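Taken together, these changes replace the flat SAE and loss fields (`d_in`, `d_sae`, `l1_coefficient`, `normalize_sae_decoder`, ...) with a nested, generically typed `sae` sub-config, and fold the wandb settings into a `LoggingConfig` passed as `logger`. A minimal sketch of constructing the 6.0.0-style config, assuming a `StandardSAEConfig`-style class as named in the docstring above (its import path and constructor arguments are not shown in this diff, so they are placeholders):

```python
from sae_lens.config import LanguageModelSAERunnerConfig, LoggingConfig
from sae_lens.saes.standard_sae import StandardSAEConfig  # assumed location/export

cfg = LanguageModelSAERunnerConfig(
    # The SAE architecture now lives in its own config object (first field, no default).
    sae=StandardSAEConfig(d_in=512, d_sae=2048),  # hypothetical constructor arguments
    model_name="gelu-2l",
    hook_name="blocks.0.hook_mlp_out",
    dataset_path="my-org/my-tokenized-dataset",  # hypothetical dataset path
    training_tokens=2_000_000,
    train_batch_size_tokens=4096,
    lr=3e-4,
    # All wandb settings move from top-level fields into LoggingConfig.
    logger=LoggingConfig(log_to_wandb=True, wandb_project="sae_lens_training"),
)

trainer_cfg = cfg.to_sae_trainer_config()       # new in 6.0.0: derive the SAETrainerConfig
sae_cfg_dict = cfg.get_training_sae_cfg_dict()  # now just delegates to cfg.sae.to_dict()
```

Serialization now routes through the sub-configs as well: `to_dict()` writes `d["sae"] = self.sae.to_dict()` and `d["logger"] = asdict(self.logger)`, so adding a new SAE architecture no longer requires new fields on the runner config.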
sae_lens/constants.py
ADDED
```diff
@@ -0,0 +1,21 @@
+import torch
+
+DTYPE_MAP = {
+    "float32": torch.float32,
+    "float64": torch.float64,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "torch.float32": torch.float32,
+    "torch.float64": torch.float64,
+    "torch.float16": torch.float16,
+    "torch.bfloat16": torch.bfloat16,
+}
+
+
+SPARSITY_FILENAME = "sparsity.safetensors"
+SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
+SAE_CFG_FILENAME = "cfg.json"
+RUNNER_CFG_FILENAME = "runner_cfg.json"
+SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
+ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
```