sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) (see the import-path sketch after this list)
  1. sae_lens/__init__.py +60 -7
  2. sae_lens/analysis/hooked_sae_transformer.py +12 -12
  3. sae_lens/analysis/neuronpedia_integration.py +16 -14
  4. sae_lens/cache_activations_runner.py +9 -7
  5. sae_lens/config.py +170 -258
  6. sae_lens/constants.py +21 -0
  7. sae_lens/evals.py +59 -44
  8. sae_lens/llm_sae_training_runner.py +377 -0
  9. sae_lens/load_model.py +52 -4
  10. sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
  11. sae_lens/registry.py +49 -0
  12. sae_lens/saes/__init__.py +48 -0
  13. sae_lens/saes/gated_sae.py +254 -0
  14. sae_lens/saes/jumprelu_sae.py +348 -0
  15. sae_lens/saes/sae.py +1076 -0
  16. sae_lens/saes/standard_sae.py +178 -0
  17. sae_lens/saes/topk_sae.py +300 -0
  18. sae_lens/training/activation_scaler.py +53 -0
  19. sae_lens/training/activations_store.py +103 -184
  20. sae_lens/training/mixing_buffer.py +56 -0
  21. sae_lens/training/optim.py +60 -36
  22. sae_lens/training/sae_trainer.py +155 -177
  23. sae_lens/training/types.py +5 -0
  24. sae_lens/training/upload_saes_to_huggingface.py +13 -7
  25. sae_lens/util.py +47 -0
  26. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
  27. sae_lens-6.0.0.dist-info/RECORD +37 -0
  28. sae_lens/sae.py +0 -747
  29. sae_lens/sae_training_runner.py +0 -251
  30. sae_lens/training/geometric_median.py +0 -101
  31. sae_lens/training/training_sae.py +0 -710
  32. sae_lens-5.11.0.dist-info/RECORD +0 -28
  33. /sae_lens/{toolkit → loading}/__init__.py +0 -0
  34. /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
  35. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
  36. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
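Note: the 6.0.0 layout moves the SAE implementations into sae_lens/saes/, renames sae_lens/toolkit/ to sae_lens/loading/, and splits shared constants out of config.py into sae_lens/constants.py. A minimal sketch of the import paths this implies, using only module paths visible in the list above (anything not shown there is an assumption and would need checking against sae_lens 6.0.0 itself):

    # Import paths implied by the file list (a sketch, not an exhaustive migration guide)
    from sae_lens.saes.sae import TrainingSAEConfig      # SAE classes now live under sae_lens/saes/
    from sae_lens.loading import pretrained_sae_loaders  # renamed from sae_lens.toolkit
    from sae_lens import constants                       # new home for DTYPE_MAP and the filename constants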
sae_lens/config.py CHANGED
@@ -1,8 +1,9 @@
 import json
 import math
 import os
-from dataclasses import dataclass, field
-from typing import Any, Literal, cast
+from dataclasses import asdict, dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
 
 import simple_parsing
 import torch
@@ -16,25 +17,17 @@ from datasets import (
 )
 
 from sae_lens import __version__, logger
+from sae_lens.constants import DTYPE_MAP
+from sae_lens.saes.sae import TrainingSAEConfig
 
-DTYPE_MAP = {
-    "float32": torch.float32,
-    "float64": torch.float64,
-    "float16": torch.float16,
-    "bfloat16": torch.bfloat16,
-    "torch.float32": torch.float32,
-    "torch.float64": torch.float64,
-    "torch.float16": torch.float16,
-    "torch.bfloat16": torch.bfloat16,
-}
-
-HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
+if TYPE_CHECKING:
+    pass
 
+T_TRAINING_SAE_CONFIG = TypeVar(
+    "T_TRAINING_SAE_CONFIG", bound=TrainingSAEConfig, covariant=True
+)
 
-SPARSITY_FILENAME = "sparsity.safetensors"
-SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
-SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
-SAE_CFG_FILENAME = "cfg.json"
+HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
 
 
 # calling this "json_dict" so error messages will reference "json_dict" being invalid
@@ -55,101 +48,118 @@ def dict_field(default: dict[str, Any] | None, **kwargs: Any) -> Any: # type: i
 
 
 @dataclass
-class LanguageModelSAERunnerConfig:
+class LoggingConfig:
+    # WANDB
+    log_to_wandb: bool = True
+    log_activations_store_to_wandb: bool = False
+    log_optimizer_state_to_wandb: bool = False
+    wandb_project: str = "sae_lens_training"
+    wandb_id: str | None = None
+    run_name: str | None = None
+    wandb_entity: str | None = None
+    wandb_log_frequency: int = 10
+    eval_every_n_wandb_logs: int = 100  # logs every 100 steps.
+
+    def log(
+        self,
+        trainer: Any,  # avoid import cycle from importing SAETrainer
+        weights_path: Path | str,
+        cfg_path: Path | str,
+        sparsity_path: Path | str | None,
+        wandb_aliases: list[str] | None = None,
+    ) -> None:
+        # Avoid wandb saving errors such as:
+        # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
+        sae_name = trainer.sae.get_name().replace("/", "__")
+
+        # save model weights and cfg
+        model_artifact = wandb.Artifact(
+            sae_name,
+            type="model",
+            metadata=dict(trainer.cfg.__dict__),
+        )
+        model_artifact.add_file(str(weights_path))
+        model_artifact.add_file(str(cfg_path))
+        wandb.log_artifact(model_artifact, aliases=wandb_aliases)
+
+        # save log feature sparsity
+        sparsity_artifact = wandb.Artifact(
+            f"{sae_name}_log_feature_sparsity",
+            type="log_feature_sparsity",
+            metadata=dict(trainer.cfg.__dict__),
+        )
+        if sparsity_path is not None:
+            sparsity_artifact.add_file(str(sparsity_path))
+        wandb.log_artifact(sparsity_artifact)
+
+
+@dataclass
+class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     """
     Configuration for training a sparse autoencoder on a language model.
 
     Args:
-        architecture (str): The architecture to use, either "standard", "gated", "topk", or "jumprelu".
+        sae (T_TRAINING_SAE_CONFIG): The configuration for the SAE itself (e.g. StandardSAEConfig, GatedSAEConfig).
         model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
         model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
         hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
         hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
-        hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing.
-        hook_head_index (int, optional): When the hook if for an activatio with a head index, we can specify a specific head to use here.
+        hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
         dataset_path (str): A Hugging Face dataset path.
         dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
         streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
-        is_dataset_tokenized (bool): NOT IN USE. We used to use this but now automatically detect if the dataset is tokenized.
+        is_dataset_tokenized (bool): Whether the dataset is already tokenized.
        context_size (int): The context size to use when generating activations on which to train the SAE.
         use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations.
-        cached_activations_path (str, optional): The path to the cached activations.
-        d_in (int): The input dimension of the SAE.
-        d_sae (int, optional): The output dimension of the SAE. If None, defaults to `d_in * expansion_factor`.
-        b_dec_init_method (str): The method to use to initialize the decoder bias. Zeros is likely fine.
-        expansion_factor (int): The expansion factor. Larger is better but more computationally expensive. Default is 4.
-        activation_fn (str): The activation function to use. Relu is standard.
-        normalize_sae_decoder (bool): Whether to normalize the SAE decoder. Unit normed decoder weights used to be preferred.
-        noise_scale (float): Using noise to induce sparsity is supported but not recommended.
+        cached_activations_path (str, optional): The path to the cached activations. Defaults to "activations/{dataset_path}/{model_name}/{hook_name}_{hook_head_index}".
         from_pretrained_path (str, optional): The path to a pretrained SAE. We can finetune an existing SAE if needed.
-        apply_b_dec_to_input (bool): Whether to apply the decoder bias to the input. Not currently advised.
-        decoder_orthogonal_init (bool): Whether to use orthogonal initialization for the decoder. Not currently advised.
-        decoder_heuristic_init (bool): Whether to use heuristic initialization for the decoder. See Anthropic April Update.
-        init_encoder_as_decoder_transpose (bool): Whether to initialize the encoder as the transpose of the decoder. See Anthropic April Update.
-        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in ram is used. The larger it is, the better shuffled the activations will be.
+        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in RAM is used. The larger it is, the better shuffled the activations will be.
         training_tokens (int): The number of training tokens.
-        finetuning_tokens (int): The number of finetuning tokens. See [here](https://www.lesswrong.com/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes)
-        store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating actiations.
-        train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop.
-        normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it following Antrhopic April update -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).
-        seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
-        device (str): The device to use. Usually cuda.
-        act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
+        store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
+        seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
+        device (str): The device to use. Usually "cuda".
+        act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
         seed (int): The seed to use.
-        dtype (str): The data type to use.
+        dtype (str): The data type to use for the SAE and activations.
         prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
-        jumprelu_init_threshold (float): The threshold to initialize for training JumpReLU SAEs.
-        jumprelu_bandwidth (float): Bandwidth for training JumpReLU SAEs.
-        autocast (bool): Whether to use autocast during training. Saves vram.
-        autocast_lm (bool): Whether to use autocast during activation fetching.
-        compile_llm (bool): Whether to compile the LLM.
-        llm_compilation_mode (str): The compilation mode to use for the LLM.
-        compile_sae (bool): Whether to compile the SAE.
-        sae_compilation_mode (str): The compilation mode to use for the SAE.
-        adam_beta1 (float): The beta1 parameter for Adam.
-        adam_beta2 (float): The beta2 parameter for Adam.
-        mse_loss_normalization (str): The normalization to use for the MSE loss.
-        l1_coefficient (float): The L1 coefficient.
-        lp_norm (float): The Lp norm.
-        scale_sparsity_penalty_by_decoder_norm (bool): Whether to scale the sparsity penalty by the decoder norm.
-        l1_warm_up_steps (int): The number of warm-up steps for the L1 loss.
+        autocast (bool): Whether to use autocast (mixed-precision) during SAE training. Saves VRAM.
+        autocast_lm (bool): Whether to use autocast (mixed-precision) during activation fetching. Saves VRAM.
+        compile_llm (bool): Whether to compile the LLM using `torch.compile`.
+        llm_compilation_mode (str, optional): The compilation mode to use for the LLM if `compile_llm` is True.
+        compile_sae (bool): Whether to compile the SAE using `torch.compile`.
+        sae_compilation_mode (str, optional): The compilation mode to use for the SAE if `compile_sae` is True.
+        train_batch_size_tokens (int): The batch size for training, in tokens. This controls the batch size of the SAE training loop.
+        adam_beta1 (float): The beta1 parameter for the Adam optimizer.
+        adam_beta2 (float): The beta2 parameter for the Adam optimizer.
         lr (float): The learning rate.
-        lr_scheduler_name (str): The name of the learning rate scheduler to use.
+        lr_scheduler_name (str): The name of the learning rate scheduler to use (e.g., "constant", "cosineannealing", "cosineannealingwarmrestarts").
         lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
-        lr_end (float): The end learning rate if lr_decay_steps is set. Default is lr / 10.
-        lr_decay_steps (int): The number of decay steps for the learning rate.
-        n_restart_cycles (int): The number of restart cycles for the cosine annealing warm restarts scheduler.
-        finetuning_method (str): The method to use for finetuning.
-        use_ghost_grads (bool): Whether to use ghost gradients.
-        feature_sampling_window (int): The feature sampling window.
-        dead_feature_window (int): The dead feature window.
-        dead_feature_threshold (float): The dead feature threshold.
-        n_eval_batches (int): The number of evaluation batches.
-        eval_batch_size_prompts (int): The batch size for evaluation.
-        log_to_wandb (bool): Whether to log to Weights & Biases.
-        log_activations_store_to_wandb (bool): NOT CURRENTLY USED. Whether to log the activations store to Weights & Biases.
-        log_optimizer_state_to_wandb (bool): NOT CURRENTLY USED. Whether to log the optimizer state to Weights & Biases.
-        wandb_project (str): The Weights & Biases project to log to.
-        wandb_id (str): The Weights & Biases ID.
-        run_name (str): The name of the run.
-        wandb_entity (str): The Weights & Biases entity.
-        wandb_log_frequency (int): The frequency to log to Weights & Biases.
-        eval_every_n_wandb_logs (int): The frequency to evaluate.
-        resume (bool): Whether to resume training.
-        n_checkpoints (int): The number of checkpoints.
-        checkpoint_path (str): The path to save checkpoints.
+        lr_end (float, optional): The end learning rate if using a scheduler like cosine annealing. Defaults to `lr / 10`.
+        lr_decay_steps (int): The number of decay steps for the learning rate if using a scheduler with decay.
+        n_restart_cycles (int): The number of restart cycles for the cosine annealing with warm restarts scheduler.
+        dead_feature_window (int): The window size (in training steps) for detecting dead features.
+        feature_sampling_window (int): The window size (in training steps) for resampling features (e.g. dead features).
+        dead_feature_threshold (float): The threshold below which a feature's activation frequency is considered dead.
+        n_eval_batches (int): The number of batches to use for evaluation.
+        eval_batch_size_prompts (int, optional): The batch size for evaluation, in prompts. Useful if evals cause OOM.
+        logger (LoggingConfig): Configuration for logging (e.g. W&B).
+        n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
+        checkpoint_path (str): The path to save checkpoints. A unique ID will be appended to this path.
         verbose (bool): Whether to print verbose output.
-        model_kwargs (dict[str, Any]): Additional keyword arguments for the model.
-        model_from_pretrained_kwargs (dict[str, Any]): Additional keyword arguments for the model from pretrained.
-        exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations.
+        model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
+        model_from_pretrained_kwargs (dict[str, Any], optional): Additional keyword arguments to pass to the model's `from_pretrained` method.
+        sae_lens_version (str): The version of the sae_lens library.
+        sae_lens_training_version (str): The version of the sae_lens training library.
+        exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations. If True, excludes all special tokens. If a list of ints, excludes those token IDs.
     """
 
+    sae: T_TRAINING_SAE_CONFIG
+
     # Data Generating Function (Model + Training Distibuion)
     model_name: str = "gelu-2l"
     model_class_name: str = "HookedTransformer"
     hook_name: str = "blocks.0.hook_mlp_out"
     hook_eval: str = "NOT_IN_USE"
-    hook_layer: int = 0
     hook_head_index: int | None = None
     dataset_path: str = ""
     dataset_trust_remote_code: bool = True
@@ -162,30 +172,12 @@ class LanguageModelSAERunnerConfig:
     )
 
     # SAE Parameters
-    architecture: Literal["standard", "gated", "jumprelu", "topk"] = "standard"
-    d_in: int = 512
-    d_sae: int | None = None
-    b_dec_init_method: str = "geometric_median"
-    expansion_factor: int | None = (
-        None  # defaults to 4 if d_sae and expansion_factor is None
-    )
-    activation_fn: str = None  # relu, tanh-relu, topk. Default is relu. # type: ignore
-    activation_fn_kwargs: dict[str, int] = dict_field(default=None)  # for topk
-    normalize_sae_decoder: bool = True
-    noise_scale: float = 0.0
     from_pretrained_path: str | None = None
-    apply_b_dec_to_input: bool = True
-    decoder_orthogonal_init: bool = False
-    decoder_heuristic_init: bool = False
-    decoder_heuristic_init_norm: float = 0.1
-    init_encoder_as_decoder_transpose: bool = False
 
     # Activation Store Parameters
     n_batches_in_buffer: int = 20
     training_tokens: int = 2_000_000
-    finetuning_tokens: int = 0
     store_batch_size_prompts: int = 32
-    normalize_activations: str = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
     seqpos_slice: tuple[int | None, ...] = (None,)
 
     # Misc
@@ -195,10 +187,6 @@ class LanguageModelSAERunnerConfig:
     dtype: str = "float32"  # type: ignore #
     prepend_bos: bool = True
 
-    # JumpReLU Parameters
-    jumprelu_init_threshold: float = 0.001
-    jumprelu_bandwidth: float = 0.001
-
     # Performance - see compilation section of lm_runner.py for info
     autocast: bool = False  # autocast to autocast_dtype during training
     autocast_lm: bool = False  # autocast lm during activation fetching
@@ -213,16 +201,9 @@ class LanguageModelSAERunnerConfig:
     train_batch_size_tokens: int = 4096
 
     ## Adam
-    adam_beta1: float = 0.0
+    adam_beta1: float = 0.9
     adam_beta2: float = 0.999
 
-    ## Loss Function
-    mse_loss_normalization: str | None = None
-    l1_coefficient: float = 1e-3
-    lp_norm: float = 1
-    scale_sparsity_penalty_by_decoder_norm: bool = False
-    l1_warm_up_steps: int = 0
-
     ## Learning Rate Schedule
     lr: float = 3e-4
     lr_scheduler_name: str = (
@@ -233,33 +214,18 @@ class LanguageModelSAERunnerConfig:
     )
     lr_decay_steps: int = 0
    n_restart_cycles: int = 1  # used only for cosineannealingwarmrestarts
 
-    ## FineTuning
-    finetuning_method: str | None = None  # scale, decoder or unrotated_decoder
-
     # Resampling protocol args
-    use_ghost_grads: bool = False  # want to change this to true on some timeline.
-    feature_sampling_window: int = 2000
     dead_feature_window: int = 1000  # unless this window is larger feature sampling,
-
+    feature_sampling_window: int = 2000
     dead_feature_threshold: float = 1e-8
 
     # Evals
     n_eval_batches: int = 10
     eval_batch_size_prompts: int | None = None  # useful if evals cause OOM
 
-    # WANDB
-    log_to_wandb: bool = True
-    log_activations_store_to_wandb: bool = False
-    log_optimizer_state_to_wandb: bool = False
-    wandb_project: str = "mats_sae_training_language_model"
-    wandb_id: str | None = None
-    run_name: str | None = None
-    wandb_entity: str | None = None
-    wandb_log_frequency: int = 10
-    eval_every_n_wandb_logs: int = 100  # logs every 1000 steps.
+    logger: LoggingConfig = field(default_factory=LoggingConfig)
 
     # Misc
-    resume: bool = False
     n_checkpoints: int = 0
     checkpoint_path: str = "checkpoints"
     verbose: bool = True
@@ -270,12 +236,6 @@ class LanguageModelSAERunnerConfig:
     exclude_special_tokens: bool | list[int] = False
 
     def __post_init__(self):
-        if self.resume:
-            raise ValueError(
-                "Resuming is no longer supported. You can finetune a trained SAE using cfg.from_pretrained path."
-                + "If you want to load an SAE with resume=True in the config, please manually set resume=False in that config."
-            )
-
         if self.use_cached_activations and self.cached_activations_path is None:
             self.cached_activations_path = _default_cached_activations_path(
                 self.dataset_path,
@@ -283,37 +243,12 @@ class LanguageModelSAERunnerConfig:
                 self.hook_name,
                 self.hook_head_index,
             )
-
-        if self.activation_fn is None:
-            self.activation_fn = "topk" if self.architecture == "topk" else "relu"
-
-        if self.architecture == "topk" and self.activation_fn != "topk":
-            raise ValueError("If using topk architecture, activation_fn must be topk.")
-
-        if self.activation_fn_kwargs is None:
-            self.activation_fn_kwargs = (
-                {"k": 100} if self.activation_fn == "topk" else {}
-            )
-
-        if self.architecture == "topk" and self.activation_fn_kwargs.get("k") is None:
-            raise ValueError(
-                "activation_fn_kwargs.k must be provided for topk architecture."
-            )
-
-        if self.d_sae is not None and self.expansion_factor is not None:
-            raise ValueError("You can't set both d_sae and expansion_factor.")
-
-        if self.d_sae is None and self.expansion_factor is None:
-            self.expansion_factor = 4
-
-        if self.d_sae is None and self.expansion_factor is not None:
-            self.d_sae = self.d_in * self.expansion_factor
         self.tokens_per_buffer = (
             self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
         )
 
-        if self.run_name is None:
-            self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
+        if self.logger.run_name is None:
+            self.logger.run_name = f"{self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
 
         if self.model_from_pretrained_kwargs is None:
            if self.model_class_name == "HookedTransformer":
@@ -321,44 +256,13 @@ class LanguageModelSAERunnerConfig:
             else:
                 self.model_from_pretrained_kwargs = {}
 
-        if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
-            raise ValueError(
-                f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}"
-            )
-
-        if self.normalize_sae_decoder and self.decoder_heuristic_init:
-            raise ValueError(
-                "You can't normalize the decoder and use heuristic initialization."
-            )
-
-        if self.normalize_sae_decoder and self.scale_sparsity_penalty_by_decoder_norm:
-            raise ValueError(
-                "Weighting loss by decoder norm makes no sense if you are normalizing the decoder weight norms to 1"
-            )
-
-        # if we use decoder fine tuning, we can't be applying b_dec to the input
-        if (self.finetuning_method == "decoder") and (self.apply_b_dec_to_input):
-            raise ValueError(
-                "If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False."
-            )
-
-        if self.normalize_activations not in [
-            "none",
-            "expected_average_only_in",
-            "constant_norm_rescale",
-            "layer_norm",
-        ]:
-            raise ValueError(
-                f"normalize_activations must be none, layer_norm, expected_average_only_in, or constant_norm_rescale. Got {self.normalize_activations}"
-            )
-
         if self.act_store_device == "with_model":
             self.act_store_device = self.device
 
         if self.lr_end is None:
             self.lr_end = self.lr / 10
 
-        unique_id = self.wandb_id
+        unique_id = self.logger.wandb_id
         if unique_id is None:
             unique_id = cast(
                 Any, wandb
@@ -367,7 +271,7 @@ class LanguageModelSAERunnerConfig:
 
         if self.verbose:
             logger.info(
-                f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
+                f"Run name: {self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
             )
             # Print out some useful info:
             n_tokens_per_buffer = (
@@ -386,11 +290,13 @@ class LanguageModelSAERunnerConfig:
             )
 
             total_training_steps = (
-                self.training_tokens + self.finetuning_tokens
+                self.training_tokens
             ) // self.train_batch_size_tokens
             logger.info(f"Total training steps: {total_training_steps}")
 
-            total_wandb_updates = total_training_steps // self.wandb_log_frequency
+            total_wandb_updates = (
+                total_training_steps // self.logger.wandb_log_frequency
+            )
             logger.info(f"Total wandb updates: {total_wandb_updates}")
 
             # how many times will we sample dead neurons?
@@ -412,9 +318,6 @@ class LanguageModelSAERunnerConfig:
                 f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
             )
 
-        if self.use_ghost_grads:
-            logger.info("Using Ghost Grads.")
-
         if self.context_size < 0:
             raise ValueError(
                 f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
@@ -429,65 +332,26 @@ class LanguageModelSAERunnerConfig:
 
     @property
     def total_training_tokens(self) -> int:
-        return self.training_tokens + self.finetuning_tokens
+        return self.training_tokens
 
     @property
     def total_training_steps(self) -> int:
         return self.total_training_tokens // self.train_batch_size_tokens
 
-    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
-        return {
-            # TEMP
-            "architecture": self.architecture,
-            "d_in": self.d_in,
-            "d_sae": self.d_sae,
-            "dtype": self.dtype,
-            "device": self.device,
-            "model_name": self.model_name,
-            "hook_name": self.hook_name,
-            "hook_layer": self.hook_layer,
-            "hook_head_index": self.hook_head_index,
-            "activation_fn_str": self.activation_fn,
-            "apply_b_dec_to_input": self.apply_b_dec_to_input,
-            "context_size": self.context_size,
-            "prepend_bos": self.prepend_bos,
-            "dataset_path": self.dataset_path,
-            "dataset_trust_remote_code": self.dataset_trust_remote_code,
-            "finetuning_scaling_factor": self.finetuning_method is not None,
-            "sae_lens_training_version": self.sae_lens_training_version,
-            "normalize_activations": self.normalize_activations,
-            "activation_fn_kwargs": self.activation_fn_kwargs,
-            "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
-            "seqpos_slice": self.seqpos_slice,
-        }
-
     def get_training_sae_cfg_dict(self) -> dict[str, Any]:
-        return {
-            **self.get_base_sae_cfg_dict(),
-            "l1_coefficient": self.l1_coefficient,
-            "lp_norm": self.lp_norm,
-            "use_ghost_grads": self.use_ghost_grads,
-            "normalize_sae_decoder": self.normalize_sae_decoder,
-            "noise_scale": self.noise_scale,
-            "decoder_orthogonal_init": self.decoder_orthogonal_init,
-            "mse_loss_normalization": self.mse_loss_normalization,
-            "decoder_heuristic_init": self.decoder_heuristic_init,
-            "decoder_heuristic_init_norm": self.decoder_heuristic_init_norm,
-            "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
-            "normalize_activations": self.normalize_activations,
-            "jumprelu_init_threshold": self.jumprelu_init_threshold,
-            "jumprelu_bandwidth": self.jumprelu_bandwidth,
-            "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
-        }
+        return self.sae.to_dict()
 
     def to_dict(self) -> dict[str, Any]:
-        return {
-            **self.__dict__,
-            # some args may not be serializable by default
-            "dtype": str(self.dtype),
-            "device": str(self.device),
-            "act_store_device": str(self.act_store_device),
-        }
+        # Make a shallow copy of config's dictionary
+        d = dict(self.__dict__)
+
+        d["logger"] = asdict(self.logger)
+        d["sae"] = self.sae.to_dict()
+        # Overwrite fields that might not be JSON-serializable
+        d["dtype"] = str(self.dtype)
+        d["device"] = str(self.device)
+        d["act_store_device"] = str(self.act_store_device)
+        return d
 
     def to_json(self, path: str) -> None:
         if not os.path.exists(os.path.dirname(path)):
@@ -497,7 +361,7 @@ class LanguageModelSAERunnerConfig:
             json.dump(self.to_dict(), f, indent=2)
 
     @classmethod
-    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
+    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig[Any]":
         with open(path + "cfg.json") as f:
             cfg = json.load(f)
 
@@ -511,6 +375,27 @@ class LanguageModelSAERunnerConfig:
 
         return cls(**cfg)
 
+    def to_sae_trainer_config(self) -> "SAETrainerConfig":
+        return SAETrainerConfig(
+            n_checkpoints=self.n_checkpoints,
+            checkpoint_path=self.checkpoint_path,
+            total_training_samples=self.total_training_tokens,
+            device=self.device,
+            autocast=self.autocast,
+            lr=self.lr,
+            lr_end=self.lr_end,
+            lr_scheduler_name=self.lr_scheduler_name,
+            lr_warm_up_steps=self.lr_warm_up_steps,
+            adam_beta1=self.adam_beta1,
+            adam_beta2=self.adam_beta2,
+            lr_decay_steps=self.lr_decay_steps,
+            n_restart_cycles=self.n_restart_cycles,
+            train_batch_size_samples=self.train_batch_size_tokens,
+            dead_feature_window=self.dead_feature_window,
+            feature_sampling_window=self.feature_sampling_window,
+            logger=self.logger,
+        )
+

 @dataclass
 class CacheActivationsRunnerConfig:
@@ -522,7 +407,6 @@ class CacheActivationsRunnerConfig:
         model_name (str): The name of the model to use.
         model_batch_size (int): How many prompts are in the batch of the language model when generating activations.
         hook_name (str): The name of the hook to use.
-        hook_layer (int): The layer of the final hook. Currently only support a single hook, so this should be the same as hook_name.
         d_in (int): Dimension of the model.
         total_training_tokens (int): Total number of tokens to process.
         context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized.
@@ -552,7 +436,6 @@ class CacheActivationsRunnerConfig:
     model_name: str
     model_batch_size: int
     hook_name: str
-    hook_layer: int
     d_in: int
     training_tokens: int
 
@@ -680,6 +563,10 @@ def _validate_seqpos(seqpos: tuple[int | None, ...], context_size: int) -> None:
 
 @dataclass
 class PretokenizeRunnerConfig:
+    """
+    Configuration class for pretokenizing a dataset.
+    """
+
     tokenizer_name: str = "gpt2"
     dataset_path: str = ""
     dataset_name: str | None = None
@@ -708,3 +595,28 @@ class PretokenizeRunnerConfig:
     hf_num_shards: int = 64
     hf_revision: str = "main"
     hf_is_private_repo: bool = False
+
+
+@dataclass
+class SAETrainerConfig:
+    n_checkpoints: int
+    checkpoint_path: str
+    total_training_samples: int
+    device: str
+    autocast: bool
+    lr: float
+    lr_end: float | None
+    lr_scheduler_name: str
+    lr_warm_up_steps: int
+    adam_beta1: float
+    adam_beta2: float
+    lr_decay_steps: int
+    n_restart_cycles: int
+    train_batch_size_samples: int
+    dead_feature_window: int
+    feature_sampling_window: int
+    logger: LoggingConfig
+
+    @property
+    def total_training_steps(self) -> int:
+        return self.total_training_samples // self.train_batch_size_samples
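Note: the runner config no longer carries flat SAE hyperparameters (architecture, d_in, d_sae, expansion_factor, the L1/JumpReLU settings, the wandb_* fields); the SAE is configured through the nested `sae` field and W&B through the nested `logger` field. A minimal construction sketch under stated assumptions: `StandardSAEConfig` is only named in the docstring above, and its constructor arguments (`d_in`, `d_sae` here) and the package-root imports are not verified by this diff:

    # Hypothetical 6.0.0 usage sketch; names marked "assumed" are not confirmed by this diff.
    from sae_lens import LanguageModelSAERunnerConfig, LoggingConfig, StandardSAEConfig  # assumed re-exports

    cfg = LanguageModelSAERunnerConfig(
        sae=StandardSAEConfig(d_in=512, d_sae=4096),  # assumed kwargs; replaces the old flat d_in/d_sae fields
        model_name="gelu-2l",
        hook_name="blocks.0.hook_mlp_out",
        training_tokens=2_000_000,
        train_batch_size_tokens=4096,
        lr=3e-4,
        logger=LoggingConfig(wandb_project="sae_lens_training"),  # replaces the old flat wandb_* fields
    )
    trainer_cfg = cfg.to_sae_trainer_config()  # new in 6.0.0, shown in the diff above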
sae_lens/constants.py ADDED
@@ -0,0 +1,21 @@
+import torch
+
+DTYPE_MAP = {
+    "float32": torch.float32,
+    "float64": torch.float64,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "torch.float32": torch.float32,
+    "torch.float64": torch.float64,
+    "torch.float16": torch.float16,
+    "torch.bfloat16": torch.bfloat16,
+}
+
+
+SPARSITY_FILENAME = "sparsity.safetensors"
+SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
+SAE_CFG_FILENAME = "cfg.json"
+RUNNER_CFG_FILENAME = "runner_cfg.json"
+SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
+ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
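Note: DTYPE_MAP and the filename constants previously lived in sae_lens/config.py, so downstream code that imported them from there may need updating. A short example grounded in the module contents above:

    import torch
    from sae_lens.constants import DTYPE_MAP, SAE_WEIGHTS_FILENAME

    assert DTYPE_MAP["bfloat16"] is torch.bfloat16
    print(SAE_WEIGHTS_FILENAME)  # "sae_weights.safetensors"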