sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +50 -16
- sae_lens/analysis/hooked_sae_transformer.py +10 -10
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +2 -1
- sae_lens/config.py +59 -231
- sae_lens/constants.py +18 -0
- sae_lens/evals.py +16 -13
- sae_lens/loading/pretrained_sae_loaders.py +36 -3
- sae_lens/registry.py +49 -0
- sae_lens/sae_training_runner.py +22 -21
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +70 -59
- sae_lens/saes/jumprelu_sae.py +58 -72
- sae_lens/saes/sae.py +250 -272
- sae_lens/saes/standard_sae.py +75 -57
- sae_lens/saes/topk_sae.py +72 -83
- sae_lens/training/activations_store.py +31 -15
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +44 -69
- sae_lens/training/upload_saes_to_huggingface.py +11 -5
- sae_lens/util.py +28 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc2.dist-info/RECORD +35 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/WHEEL +1 -1
- sae_lens/regsitry.py +0 -34
- sae_lens-6.0.0rc1.dist-info/RECORD +0 -32
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/LICENSE +0 -0
sae_lens/config.py
CHANGED
@@ -3,7 +3,7 @@ import math
 import os
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
-from typing import Any, Literal, cast
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast

 import simple_parsing
 import torch
@@ -17,24 +17,15 @@ from datasets import (
 )

 from sae_lens import __version__, logger
+from sae_lens.constants import DTYPE_MAP
+from sae_lens.saes.sae import TrainingSAEConfig

-DTYPE_MAP = {
-    "float32": torch.float32,
-    "float64": torch.float64,
-    "float16": torch.float16,
-    "bfloat16": torch.bfloat16,
-    "torch.float32": torch.float32,
-    "torch.float64": torch.float64,
-    "torch.float16": torch.float16,
-    "torch.bfloat16": torch.bfloat16,
-}
-
-HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
+if TYPE_CHECKING:
+    pass

+T_TRAINING_SAE_CONFIG = TypeVar("T_TRAINING_SAE_CONFIG", bound=TrainingSAEConfig)

-
-SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
-SAE_CFG_FILENAME = "cfg.json"
+HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset


 # calling this "json_dict" so error messages will reference "json_dict" being invalid
@@ -101,95 +92,68 @@ class LoggingConfig:


 @dataclass
-class LanguageModelSAERunnerConfig:
+class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     """
     Configuration for training a sparse autoencoder on a language model.

     Args:
-
+        sae (T_TRAINING_SAE_CONFIG): The configuration for the SAE itself (e.g. StandardSAEConfig, GatedSAEConfig).
         model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
         model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
         hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
         hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
         hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing.
-        hook_head_index (int, optional): When the hook
+        hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
         dataset_path (str): A Hugging Face dataset path.
         dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
         streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
-        is_dataset_tokenized (bool):
+        is_dataset_tokenized (bool): Whether the dataset is already tokenized.
         context_size (int): The context size to use when generating activations on which to train the SAE.
         use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations.
-        cached_activations_path (str, optional): The path to the cached activations.
-        d_in (int): The input dimension of the SAE.
-        d_sae (int, optional): The output dimension of the SAE. If None, defaults to `d_in * expansion_factor`.
-        b_dec_init_method (str): The method to use to initialize the decoder bias. Zeros is likely fine.
-        expansion_factor (int): The expansion factor. Larger is better but more computationally expensive. Default is 4.
-        activation_fn (str): The activation function to use. Relu is standard.
-        normalize_sae_decoder (bool): Whether to normalize the SAE decoder. Unit normed decoder weights used to be preferred.
-        noise_scale (float): Using noise to induce sparsity is supported but not recommended.
+        cached_activations_path (str, optional): The path to the cached activations. Defaults to "activations/{dataset_path}/{model_name}/{hook_name}_{hook_head_index}".
         from_pretrained_path (str, optional): The path to a pretrained SAE. We can finetune an existing SAE if needed.
-
-        decoder_orthogonal_init (bool): Whether to use orthogonal initialization for the decoder. Not currently advised.
-        decoder_heuristic_init (bool): Whether to use heuristic initialization for the decoder. See Anthropic April Update.
-        init_encoder_as_decoder_transpose (bool): Whether to initialize the encoder as the transpose of the decoder. See Anthropic April Update.
-        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in ram is used. The larger it is, the better shuffled the activations will be.
+        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in RAM is used. The larger it is, the better shuffled the activations will be.
         training_tokens (int): The number of training tokens.
-
-
-
-
-        seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
-        device (str): The device to use. Usually cuda.
-        act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
+        store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
+        seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
+        device (str): The device to use. Usually "cuda".
+        act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
         seed (int): The seed to use.
-        dtype (str): The data type to use.
+        dtype (str): The data type to use for the SAE and activations.
         prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
-
-
-
-
-
-
-
-
-
-        adam_beta2 (float): The beta2 parameter for Adam.
-        mse_loss_normalization (str): The normalization to use for the MSE loss.
-        l1_coefficient (float): The L1 coefficient.
-        lp_norm (float): The Lp norm.
-        scale_sparsity_penalty_by_decoder_norm (bool): Whether to scale the sparsity penalty by the decoder norm.
-        l1_warm_up_steps (int): The number of warm-up steps for the L1 loss.
+        autocast (bool): Whether to use autocast (mixed-precision) during SAE training. Saves VRAM.
+        autocast_lm (bool): Whether to use autocast (mixed-precision) during activation fetching. Saves VRAM.
+        compile_llm (bool): Whether to compile the LLM using `torch.compile`.
+        llm_compilation_mode (str, optional): The compilation mode to use for the LLM if `compile_llm` is True.
+        compile_sae (bool): Whether to compile the SAE using `torch.compile`.
+        sae_compilation_mode (str, optional): The compilation mode to use for the SAE if `compile_sae` is True.
+        train_batch_size_tokens (int): The batch size for training, in tokens. This controls the batch size of the SAE training loop.
+        adam_beta1 (float): The beta1 parameter for the Adam optimizer.
+        adam_beta2 (float): The beta2 parameter for the Adam optimizer.
         lr (float): The learning rate.
-        lr_scheduler_name (str): The name of the learning rate scheduler to use.
+        lr_scheduler_name (str): The name of the learning rate scheduler to use (e.g., "constant", "cosineannealing", "cosineannealingwarmrestarts").
         lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
-        lr_end (float): The end learning rate if
-        lr_decay_steps (int): The number of decay steps for the learning rate.
-        n_restart_cycles (int): The number of restart cycles for the cosine annealing warm restarts scheduler.
-
-
-
-
-
-
-
-
-        log_activations_store_to_wandb (bool): NOT CURRENTLY USED. Whether to log the activations store to Weights & Biases.
-        log_optimizer_state_to_wandb (bool): NOT CURRENTLY USED. Whether to log the optimizer state to Weights & Biases.
-        wandb_project (str): The Weights & Biases project to log to.
-        wandb_id (str): The Weights & Biases ID.
-        run_name (str): The name of the run.
-        wandb_entity (str): The Weights & Biases entity.
-        wandb_log_frequency (int): The frequency to log to Weights & Biases.
-        eval_every_n_wandb_logs (int): The frequency to evaluate.
-        resume (bool): Whether to resume training.
-        n_checkpoints (int): The number of checkpoints.
-        checkpoint_path (str): The path to save checkpoints.
+        lr_end (float, optional): The end learning rate if using a scheduler like cosine annealing. Defaults to `lr / 10`.
+        lr_decay_steps (int): The number of decay steps for the learning rate if using a scheduler with decay.
+        n_restart_cycles (int): The number of restart cycles for the cosine annealing with warm restarts scheduler.
+        dead_feature_window (int): The window size (in training steps) for detecting dead features.
+        feature_sampling_window (int): The window size (in training steps) for resampling features (e.g. dead features).
+        dead_feature_threshold (float): The threshold below which a feature's activation frequency is considered dead.
+        n_eval_batches (int): The number of batches to use for evaluation.
+        eval_batch_size_prompts (int, optional): The batch size for evaluation, in prompts. Useful if evals cause OOM.
+        logger (LoggingConfig): Configuration for logging (e.g. W&B).
+        n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
+        checkpoint_path (str): The path to save checkpoints. A unique ID will be appended to this path.
         verbose (bool): Whether to print verbose output.
-        model_kwargs (dict[str, Any]):
-        model_from_pretrained_kwargs (dict[str, Any]): Additional keyword arguments
-
+        model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
+        model_from_pretrained_kwargs (dict[str, Any], optional): Additional keyword arguments to pass to the model's `from_pretrained` method.
+        sae_lens_version (str): The version of the sae_lens library.
+        sae_lens_training_version (str): The version of the sae_lens training library.
+        exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations. If True, excludes all special tokens. If a list of ints, excludes those token IDs.
     """

+    sae: T_TRAINING_SAE_CONFIG
+
     # Data Generating Function (Model + Training Distibuion)
     model_name: str = "gelu-2l"
     model_class_name: str = "HookedTransformer"
@@ -208,29 +172,12 @@ class LanguageModelSAERunnerConfig:
     )

     # SAE Parameters
-    architecture: Literal["standard", "gated", "jumprelu", "topk"] = "standard"
-    d_in: int = 512
-    d_sae: int | None = None
-    b_dec_init_method: str = "geometric_median"
-    expansion_factor: int | None = (
-        None  # defaults to 4 if d_sae and expansion_factor is None
-    )
-    activation_fn: str = None  # relu, tanh-relu, topk. Default is relu. # type: ignore
-    activation_fn_kwargs: dict[str, int] = dict_field(default=None)  # for topk
-    normalize_sae_decoder: bool = True
-    noise_scale: float = 0.0
     from_pretrained_path: str | None = None
-    apply_b_dec_to_input: bool = True
-    decoder_orthogonal_init: bool = False
-    decoder_heuristic_init: bool = False
-    init_encoder_as_decoder_transpose: bool = False

     # Activation Store Parameters
     n_batches_in_buffer: int = 20
     training_tokens: int = 2_000_000
-    finetuning_tokens: int = 0
     store_batch_size_prompts: int = 32
-    normalize_activations: str = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
     seqpos_slice: tuple[int | None, ...] = (None,)

     # Misc
@@ -240,10 +187,6 @@ class LanguageModelSAERunnerConfig:
     dtype: str = "float32"  # type: ignore #
     prepend_bos: bool = True

-    # JumpReLU Parameters
-    jumprelu_init_threshold: float = 0.001
-    jumprelu_bandwidth: float = 0.001
-
     # Performance - see compilation section of lm_runner.py for info
     autocast: bool = False  # autocast to autocast_dtype during training
     autocast_lm: bool = False  # autocast lm during activation fetching
@@ -261,13 +204,6 @@ class LanguageModelSAERunnerConfig:
     adam_beta1: float = 0.0
     adam_beta2: float = 0.999

-    ## Loss Function
-    mse_loss_normalization: str | None = None
-    l1_coefficient: float = 1e-3
-    lp_norm: float = 1
-    scale_sparsity_penalty_by_decoder_norm: bool = False
-    l1_warm_up_steps: int = 0
-
     ## Learning Rate Schedule
     lr: float = 3e-4
     lr_scheduler_name: str = (
@@ -278,14 +214,9 @@ class LanguageModelSAERunnerConfig:
     lr_decay_steps: int = 0
     n_restart_cycles: int = 1  # used only for cosineannealingwarmrestarts

-    ## FineTuning
-    finetuning_method: str | None = None  # scale, decoder or unrotated_decoder
-
     # Resampling protocol args
-    use_ghost_grads: bool = False  # want to change this to true on some timeline.
-    feature_sampling_window: int = 2000
     dead_feature_window: int = 1000  # unless this window is larger feature sampling,
-
+    feature_sampling_window: int = 2000
     dead_feature_threshold: float = 1e-8

     # Evals
@@ -295,7 +226,6 @@ class LanguageModelSAERunnerConfig:
     logger: LoggingConfig = field(default_factory=LoggingConfig)

     # Misc
-    resume: bool = False
     n_checkpoints: int = 0
     checkpoint_path: str = "checkpoints"
     verbose: bool = True
@@ -306,12 +236,6 @@ class LanguageModelSAERunnerConfig:
     exclude_special_tokens: bool | list[int] = False

     def __post_init__(self):
-        if self.resume:
-            raise ValueError(
-                "Resuming is no longer supported. You can finetune a trained SAE using cfg.from_pretrained path."
-                + "If you want to load an SAE with resume=True in the config, please manually set resume=False in that config."
-            )
-
         if self.use_cached_activations and self.cached_activations_path is None:
             self.cached_activations_path = _default_cached_activations_path(
                 self.dataset_path,
@@ -319,37 +243,12 @@ class LanguageModelSAERunnerConfig:
                 self.hook_name,
                 self.hook_head_index,
             )
-
-        if self.activation_fn is None:
-            self.activation_fn = "topk" if self.architecture == "topk" else "relu"
-
-        if self.architecture == "topk" and self.activation_fn != "topk":
-            raise ValueError("If using topk architecture, activation_fn must be topk.")
-
-        if self.activation_fn_kwargs is None:
-            self.activation_fn_kwargs = (
-                {"k": 100} if self.activation_fn == "topk" else {}
-            )
-
-        if self.architecture == "topk" and self.activation_fn_kwargs.get("k") is None:
-            raise ValueError(
-                "activation_fn_kwargs.k must be provided for topk architecture."
-            )
-
-        if self.d_sae is not None and self.expansion_factor is not None:
-            raise ValueError("You can't set both d_sae and expansion_factor.")
-
-        if self.d_sae is None and self.expansion_factor is None:
-            self.expansion_factor = 4
-
-        if self.d_sae is None and self.expansion_factor is not None:
-            self.d_sae = self.d_in * self.expansion_factor
         self.tokens_per_buffer = (
             self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
         )

         if self.logger.run_name is None:
-            self.logger.run_name = f"{self.
+            self.logger.run_name = f"{self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"

         if self.model_from_pretrained_kwargs is None:
             if self.model_class_name == "HookedTransformer":
@@ -357,37 +256,6 @@ class LanguageModelSAERunnerConfig:
             else:
                 self.model_from_pretrained_kwargs = {}

-        if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
-            raise ValueError(
-                f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}"
-            )
-
-        if self.normalize_sae_decoder and self.decoder_heuristic_init:
-            raise ValueError(
-                "You can't normalize the decoder and use heuristic initialization."
-            )
-
-        if self.normalize_sae_decoder and self.scale_sparsity_penalty_by_decoder_norm:
-            raise ValueError(
-                "Weighting loss by decoder norm makes no sense if you are normalizing the decoder weight norms to 1"
-            )
-
-        # if we use decoder fine tuning, we can't be applying b_dec to the input
-        if (self.finetuning_method == "decoder") and (self.apply_b_dec_to_input):
-            raise ValueError(
-                "If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False."
-            )
-
-        if self.normalize_activations not in [
-            "none",
-            "expected_average_only_in",
-            "constant_norm_rescale",
-            "layer_norm",
-        ]:
-            raise ValueError(
-                f"normalize_activations must be none, layer_norm, expected_average_only_in, or constant_norm_rescale. Got {self.normalize_activations}"
-            )
-
         if self.act_store_device == "with_model":
             self.act_store_device = self.device

@@ -403,7 +271,7 @@ class LanguageModelSAERunnerConfig:

         if self.verbose:
             logger.info(
-                f"Run name: {self.
+                f"Run name: {self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
             )
             # Print out some useful info:
             n_tokens_per_buffer = (
@@ -422,7 +290,7 @@ class LanguageModelSAERunnerConfig:
             )

             total_training_steps = (
-                self.training_tokens
+                self.training_tokens
             ) // self.train_batch_size_tokens
             logger.info(f"Total training steps: {total_training_steps}")

@@ -450,9 +318,6 @@ class LanguageModelSAERunnerConfig:
                 f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
             )

-        if self.use_ghost_grads:
-            logger.info("Using Ghost Grads.")
-
         if self.context_size < 0:
             raise ValueError(
                 f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
@@ -467,62 +332,21 @@ class LanguageModelSAERunnerConfig:

     @property
     def total_training_tokens(self) -> int:
-        return self.training_tokens
+        return self.training_tokens

     @property
     def total_training_steps(self) -> int:
         return self.total_training_tokens // self.train_batch_size_tokens

-    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
-        return {
-            # TEMP
-            "architecture": self.architecture,
-            "d_in": self.d_in,
-            "d_sae": self.d_sae,
-            "dtype": self.dtype,
-            "device": self.device,
-            "model_name": self.model_name,
-            "hook_name": self.hook_name,
-            "hook_layer": self.hook_layer,
-            "hook_head_index": self.hook_head_index,
-            "activation_fn": self.activation_fn,
-            "apply_b_dec_to_input": self.apply_b_dec_to_input,
-            "context_size": self.context_size,
-            "prepend_bos": self.prepend_bos,
-            "dataset_path": self.dataset_path,
-            "dataset_trust_remote_code": self.dataset_trust_remote_code,
-            "finetuning_scaling_factor": self.finetuning_method is not None,
-            "sae_lens_training_version": self.sae_lens_training_version,
-            "normalize_activations": self.normalize_activations,
-            "activation_fn_kwargs": self.activation_fn_kwargs,
-            "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
-            "seqpos_slice": self.seqpos_slice,
-        }
-
     def get_training_sae_cfg_dict(self) -> dict[str, Any]:
-        return {
-            **self.get_base_sae_cfg_dict(),
-            "l1_coefficient": self.l1_coefficient,
-            "lp_norm": self.lp_norm,
-            "use_ghost_grads": self.use_ghost_grads,
-            "normalize_sae_decoder": self.normalize_sae_decoder,
-            "noise_scale": self.noise_scale,
-            "decoder_orthogonal_init": self.decoder_orthogonal_init,
-            "mse_loss_normalization": self.mse_loss_normalization,
-            "decoder_heuristic_init": self.decoder_heuristic_init,
-            "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
-            "normalize_activations": self.normalize_activations,
-            "jumprelu_init_threshold": self.jumprelu_init_threshold,
-            "jumprelu_bandwidth": self.jumprelu_bandwidth,
-            "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
-        }
+        return self.sae.to_dict()

     def to_dict(self) -> dict[str, Any]:
-        # Make a shallow copy of config
+        # Make a shallow copy of config's dictionary
        d = dict(self.__dict__)

         d["logger"] = asdict(self.logger)
-
+        d["sae"] = self.sae.to_dict()
         # Overwrite fields that might not be JSON-serializable
         d["dtype"] = str(self.dtype)
         d["device"] = str(self.device)
@@ -537,7 +361,7 @@ class LanguageModelSAERunnerConfig:
             json.dump(self.to_dict(), f, indent=2)

     @classmethod
-    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
+    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig[Any]":
         with open(path + "cfg.json") as f:
             cfg = json.load(f)

@@ -720,6 +544,10 @@ def _validate_seqpos(seqpos: tuple[int | None, ...], context_size: int) -> None:

 @dataclass
 class PretokenizeRunnerConfig:
+    """
+    Configuration class for pretokenizing a dataset.
+    """
+
     tokenizer_name: str = "gpt2"
     dataset_path: str = ""
     dataset_name: str | None = None
sae_lens/constants.py
ADDED
@@ -0,0 +1,18 @@
+import torch
+
+DTYPE_MAP = {
+    "float32": torch.float32,
+    "float64": torch.float64,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "torch.float32": torch.float32,
+    "torch.float64": torch.float64,
+    "torch.float16": torch.float16,
+    "torch.bfloat16": torch.bfloat16,
+}
+
+
+SPARSITY_FILENAME = "sparsity.safetensors"
+SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
+SAE_CFG_FILENAME = "cfg.json"
+RUNNER_CFG_FILENAME = "runner_cfg.json"
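The new module simply centralizes the dtype map and the on-disk filenames that previously lived in config.py. A small usage sketch (the consuming code here is illustrative, not taken from the diff; "my_sae" is a placeholder directory):

import torch

from sae_lens.constants import DTYPE_MAP, SAE_CFG_FILENAME, SAE_WEIGHTS_FILENAME

# DTYPE_MAP accepts both the bare and the "torch."-prefixed spellings.
assert DTYPE_MAP["bfloat16"] is torch.bfloat16
assert DTYPE_MAP["torch.bfloat16"] is torch.bfloat16

# The filename constants describe the layout of a saved SAE directory.
weights_path = f"my_sae/{SAE_WEIGHTS_FILENAME}"  # my_sae/sae_weights.safetensors
cfg_path = f"my_sae/{SAE_CFG_FILENAME}"          # my_sae/cfg.json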
sae_lens/evals.py
CHANGED
@@ -20,7 +20,7 @@ from transformer_lens import HookedTransformer
 from transformer_lens.hook_points import HookedRootModule

 from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
-from sae_lens.saes.sae import SAE
+from sae_lens.saes.sae import SAE, SAEConfig
 from sae_lens.training.activations_store import ActivationsStore


@@ -100,7 +100,7 @@ def get_eval_everything_config(

 @torch.no_grad()
 def run_evals(
-    sae: SAE,
+    sae: SAE[Any],
     activation_store: ActivationsStore,
     model: HookedRootModule,
     eval_config: EvalConfig = EvalConfig(),
@@ -108,7 +108,7 @@ def run_evals(
     ignore_tokens: set[int | None] = set(),
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
-    hook_name = sae.cfg.hook_name
+    hook_name = sae.cfg.metadata.hook_name
     actual_batch_size = (
         eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
     )
@@ -274,7 +274,7 @@
     return all_metrics, feature_metrics


-def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
+def get_featurewise_weight_based_metrics(sae: SAE[Any]) -> dict[str, Any]:
     unit_norm_encoders = (sae.W_enc / sae.W_enc.norm(dim=0, keepdim=True)).cpu()
     unit_norm_decoder = (sae.W_dec.T / sae.W_dec.T.norm(dim=0, keepdim=True)).cpu()

@@ -298,7 +298,7 @@ def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:


 def get_downstream_reconstruction_metrics(
-    sae: SAE,
+    sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
     compute_kl: bool,
@@ -366,7 +366,7 @@ def get_downstream_reconstruction_metrics(


 def get_sparsity_and_variance_metrics(
-    sae: SAE,
+    sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
     n_batches: int,
@@ -379,8 +379,8 @@ def get_sparsity_and_variance_metrics(
     ignore_tokens: set[int | None] = set(),
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
-    hook_name = sae.cfg.hook_name
-    hook_head_index = sae.cfg.hook_head_index
+    hook_name = sae.cfg.metadata.hook_name
+    hook_head_index = sae.cfg.metadata.hook_head_index

     metric_dict = {}
     feature_metric_dict = {}
@@ -436,7 +436,7 @@
             batch_tokens,
             prepend_bos=False,
             names_filter=[hook_name],
-            stop_at_layer=sae.cfg.hook_layer + 1,
+            stop_at_layer=sae.cfg.metadata.hook_layer + 1,
             **model_kwargs,
         )

@@ -580,7 +580,7 @@

 @torch.no_grad()
 def get_recons_loss(
-    sae: SAE,
+    sae: SAE[SAEConfig],
     model: HookedRootModule,
     batch_tokens: torch.Tensor,
     activation_store: ActivationsStore,
@@ -588,9 +588,13 @@
     compute_ce_loss: bool,
     ignore_tokens: set[int | None] = set(),
     model_kwargs: Mapping[str, Any] = {},
+    hook_name: str | None = None,
 ) -> dict[str, Any]:
-    hook_name = sae.cfg.hook_name
-    head_index = sae.cfg.hook_head_index
+    hook_name = hook_name or sae.cfg.metadata.hook_name
+    head_index = sae.cfg.metadata.hook_head_index
+
+    if hook_name is None:
+        raise ValueError("hook_name must be provided")

     original_logits, original_ce_loss = model(
         batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
@@ -806,7 +810,6 @@ def multiple_evals(

     current_model = None
     current_model_str = None
-    print(filtered_saes)
     for sae_release_name, sae_id, _, _ in tqdm(filtered_saes):
         sae = SAE.from_pretrained(
             release=sae_release_name,  # see other options in sae_lens/pretrained_saes.yaml
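The recurring change in evals.py is that hook information is now read from `sae.cfg.metadata` rather than directly from `sae.cfg`, and `get_recons_loss` gains an optional `hook_name` override that falls back to the metadata value. A self-contained sketch of that fallback logic (the helper name and the dummy objects below are illustrative, not part of sae_lens):

from types import SimpleNamespace
from typing import Any


def resolve_hook_name(sae: Any, hook_name: str | None = None) -> str:
    # Mirrors the rc2 behaviour: an explicit hook_name wins, otherwise fall back to
    # the hook recorded on the SAE's metadata, and error out if neither is set.
    resolved = hook_name or sae.cfg.metadata.hook_name
    if resolved is None:
        raise ValueError("hook_name must be provided")
    return resolved


fake_sae = SimpleNamespace(
    cfg=SimpleNamespace(metadata=SimpleNamespace(hook_name="blocks.0.hook_resid_pre"))
)
print(resolve_hook_name(fake_sae))                           # blocks.0.hook_resid_pre
print(resolve_hook_name(fake_sae, "blocks.3.hook_mlp_out"))  # explicit override wins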
sae_lens/loading/pretrained_sae_loaders.py
CHANGED
@@ -7,11 +7,12 @@ import numpy as np
 import torch
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError
+from packaging.version import Version
 from safetensors import safe_open
 from safetensors.torch import load_file

 from sae_lens import logger
-from sae_lens.config import (
+from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
@@ -22,6 +23,8 @@ from sae_lens.loading.pretrained_saes_directory import (
     get_pretrained_saes_directory,
     get_repo_id_and_folder_name,
 )
+from sae_lens.registry import get_sae_class
+from sae_lens.util import filter_valid_dataclass_fields


 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
@@ -174,6 +177,20 @@ def get_sae_lens_config_from_disk(


 def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    sae_lens_version = cfg_dict.get("sae_lens_version")
+    if not sae_lens_version and "metadata" in cfg_dict:
+        sae_lens_version = cfg_dict["metadata"].get("sae_lens_version")
+
+    if not sae_lens_version or Version(sae_lens_version) < Version("6.0.0-rc.0"):
+        cfg_dict = handle_pre_6_0_config(cfg_dict)
+    return cfg_dict
+
+
+def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    """
+    Format a config dictionary for a Sparse Autoencoder (SAE) to be compatible with the new 6.0 format.
+    """
+
     rename_keys_map = {
         "hook_point": "hook_name",
         "hook_point_layer": "hook_layer",
@@ -202,10 +219,26 @@ def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
         else "expected_average_only_in"
     )

-    new_cfg.
+    if new_cfg.get("normalize_activations") is None:
+        new_cfg["normalize_activations"] = "none"
+
     new_cfg.setdefault("device", "cpu")

-
+    architecture = new_cfg.get("architecture", "standard")
+
+    config_class = get_sae_class(architecture)[1]
+
+    sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
+    if architecture == "topk":
+        sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
+
+    # import here to avoid circular import
+    from sae_lens.saes.sae import SAEMetadata
+
+    meta_dict = filter_valid_dataclass_fields(new_cfg, SAEMetadata)
+    sae_cfg_dict["metadata"] = meta_dict
+    sae_cfg_dict["architecture"] = architecture
+    return sae_cfg_dict


 def get_connor_rob_hook_z_config_from_hf(
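In rc2, `handle_config_defaulting` gates the legacy upgrade path on the sae_lens version recorded in the config: anything without a version, or older than 6.0.0-rc.0, is rewritten by `handle_pre_6_0_config` into the new architecture/metadata shape. A standalone sketch of just that version check (the helper name and sample dicts below are illustrative, not part of the library):

from packaging.version import Version


def needs_pre_6_0_upgrade(cfg_dict: dict) -> bool:
    # Mirrors the rc2 gate: missing version info, or anything older than
    # 6.0.0-rc.0, is treated as a pre-6.0 config and routed to the upgrade path.
    version = cfg_dict.get("sae_lens_version")
    if not version and "metadata" in cfg_dict:
        version = cfg_dict["metadata"].get("sae_lens_version")
    return not version or Version(version) < Version("6.0.0-rc.0")


print(needs_pre_6_0_upgrade({"d_in": 512}))                     # True (no version recorded)
print(needs_pre_6_0_upgrade({"sae_lens_version": "5.10.0"}))    # True
print(needs_pre_6_0_upgrade({"sae_lens_version": "6.0.0rc2"}))  # False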
|