sae-lens 5.10.7__py3-none-any.whl → 6.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. sae_lens/__init__.py +60 -7
  2. sae_lens/analysis/hooked_sae_transformer.py +12 -12
  3. sae_lens/analysis/neuronpedia_integration.py +16 -14
  4. sae_lens/cache_activations_runner.py +9 -7
  5. sae_lens/config.py +170 -257
  6. sae_lens/constants.py +21 -0
  7. sae_lens/evals.py +59 -44
  8. sae_lens/llm_sae_training_runner.py +377 -0
  9. sae_lens/load_model.py +53 -5
  10. sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +228 -32
  11. sae_lens/registry.py +49 -0
  12. sae_lens/saes/__init__.py +48 -0
  13. sae_lens/saes/gated_sae.py +254 -0
  14. sae_lens/saes/jumprelu_sae.py +348 -0
  15. sae_lens/saes/sae.py +1076 -0
  16. sae_lens/saes/standard_sae.py +178 -0
  17. sae_lens/saes/topk_sae.py +300 -0
  18. sae_lens/training/activation_scaler.py +53 -0
  19. sae_lens/training/activations_store.py +103 -184
  20. sae_lens/training/mixing_buffer.py +56 -0
  21. sae_lens/training/optim.py +60 -36
  22. sae_lens/training/sae_trainer.py +155 -177
  23. sae_lens/training/types.py +5 -0
  24. sae_lens/training/upload_saes_to_huggingface.py +13 -7
  25. sae_lens/util.py +47 -0
  26. {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
  27. sae_lens-6.0.0.dist-info/RECORD +37 -0
  28. sae_lens/sae.py +0 -747
  29. sae_lens/sae_training_runner.py +0 -251
  30. sae_lens/training/geometric_median.py +0 -101
  31. sae_lens/training/training_sae.py +0 -710
  32. sae_lens-5.10.7.dist-info/RECORD +0 -28
  33. /sae_lens/{toolkit → loading}/__init__.py +0 -0
  34. /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
  35. {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
  36. {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
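Note on the structural change visible in the config.py diff below: the flat SAE hyperparameters (architecture, d_in, d_sae, l1_coefficient, and friends) move out of LanguageModelSAERunnerConfig into a nested, architecture-specific SAE config, and the W&B settings move into a new LoggingConfig. A minimal sketch of what constructing a 6.0.0-style runner config might look like, based only on fields visible in this diff; the StandardSAEConfig import path and constructor arguments are assumptions, not confirmed here:

# Hypothetical sketch of the nested 6.0.0 config layout (the StandardSAEConfig
# location and arguments are assumed, not taken from this diff).
from sae_lens.config import LanguageModelSAERunnerConfig, LoggingConfig
from sae_lens.saes.standard_sae import StandardSAEConfig  # assumed location

cfg = LanguageModelSAERunnerConfig(
    sae=StandardSAEConfig(d_in=512, d_sae=2048),  # assumed constructor arguments
    model_name="gelu-2l",
    hook_name="blocks.0.hook_mlp_out",
    dataset_path="<your-tokenized-dataset>",
    training_tokens=2_000_000,
    train_batch_size_tokens=4096,
    lr=3e-4,
    logger=LoggingConfig(log_to_wandb=True, wandb_project="sae_lens_training"),
    n_checkpoints=0,
    checkpoint_path="checkpoints",
)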
sae_lens/config.py CHANGED
@@ -1,8 +1,9 @@
  import json
  import math
  import os
- from dataclasses import dataclass, field
- from typing import Any, Literal, cast
+ from dataclasses import asdict, dataclass, field
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast

  import simple_parsing
  import torch
@@ -16,24 +17,17 @@ from datasets import (
  )

  from sae_lens import __version__, logger
+ from sae_lens.constants import DTYPE_MAP
+ from sae_lens.saes.sae import TrainingSAEConfig

- DTYPE_MAP = {
- "float32": torch.float32,
- "float64": torch.float64,
- "float16": torch.float16,
- "bfloat16": torch.bfloat16,
- "torch.float32": torch.float32,
- "torch.float64": torch.float64,
- "torch.float16": torch.float16,
- "torch.bfloat16": torch.bfloat16,
- }
-
- HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
+ if TYPE_CHECKING:
+ pass

+ T_TRAINING_SAE_CONFIG = TypeVar(
+ "T_TRAINING_SAE_CONFIG", bound=TrainingSAEConfig, covariant=True
+ )

- SPARSITY_FILENAME = "sparsity.safetensors"
- SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
- SAE_CFG_FILENAME = "cfg.json"
+ HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset


  # calling this "json_dict" so error messages will reference "json_dict" being invalid
@@ -54,101 +48,118 @@ def dict_field(default: dict[str, Any] | None, **kwargs: Any) -> Any: # type: i


  @dataclass
- class LanguageModelSAERunnerConfig:
+ class LoggingConfig:
+ # WANDB
+ log_to_wandb: bool = True
+ log_activations_store_to_wandb: bool = False
+ log_optimizer_state_to_wandb: bool = False
+ wandb_project: str = "sae_lens_training"
+ wandb_id: str | None = None
+ run_name: str | None = None
+ wandb_entity: str | None = None
+ wandb_log_frequency: int = 10
+ eval_every_n_wandb_logs: int = 100 # logs every 100 steps.
+
+ def log(
+ self,
+ trainer: Any, # avoid import cycle from importing SAETrainer
+ weights_path: Path | str,
+ cfg_path: Path | str,
+ sparsity_path: Path | str | None,
+ wandb_aliases: list[str] | None = None,
+ ) -> None:
+ # Avoid wandb saving errors such as:
+ # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
+ sae_name = trainer.sae.get_name().replace("/", "__")
+
+ # save model weights and cfg
+ model_artifact = wandb.Artifact(
+ sae_name,
+ type="model",
+ metadata=dict(trainer.cfg.__dict__),
+ )
+ model_artifact.add_file(str(weights_path))
+ model_artifact.add_file(str(cfg_path))
+ wandb.log_artifact(model_artifact, aliases=wandb_aliases)
+
+ # save log feature sparsity
+ sparsity_artifact = wandb.Artifact(
+ f"{sae_name}_log_feature_sparsity",
+ type="log_feature_sparsity",
+ metadata=dict(trainer.cfg.__dict__),
+ )
+ if sparsity_path is not None:
+ sparsity_artifact.add_file(str(sparsity_path))
+ wandb.log_artifact(sparsity_artifact)
+
+
+ @dataclass
+ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  """
  Configuration for training a sparse autoencoder on a language model.

  Args:
- architecture (str): The architecture to use, either "standard", "gated", "topk", or "jumprelu".
+ sae (T_TRAINING_SAE_CONFIG): The configuration for the SAE itself (e.g. StandardSAEConfig, GatedSAEConfig).
  model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
  model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
  hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
  hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
- hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing.
- hook_head_index (int, optional): When the hook if for an activatio with a head index, we can specify a specific head to use here.
+ hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
  dataset_path (str): A Hugging Face dataset path.
  dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
  streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
- is_dataset_tokenized (bool): NOT IN USE. We used to use this but now automatically detect if the dataset is tokenized.
+ is_dataset_tokenized (bool): Whether the dataset is already tokenized.
  context_size (int): The context size to use when generating activations on which to train the SAE.
  use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations.
- cached_activations_path (str, optional): The path to the cached activations.
- d_in (int): The input dimension of the SAE.
- d_sae (int, optional): The output dimension of the SAE. If None, defaults to `d_in * expansion_factor`.
- b_dec_init_method (str): The method to use to initialize the decoder bias. Zeros is likely fine.
- expansion_factor (int): The expansion factor. Larger is better but more computationally expensive. Default is 4.
- activation_fn (str): The activation function to use. Relu is standard.
- normalize_sae_decoder (bool): Whether to normalize the SAE decoder. Unit normed decoder weights used to be preferred.
- noise_scale (float): Using noise to induce sparsity is supported but not recommended.
+ cached_activations_path (str, optional): The path to the cached activations. Defaults to "activations/{dataset_path}/{model_name}/{hook_name}_{hook_head_index}".
  from_pretrained_path (str, optional): The path to a pretrained SAE. We can finetune an existing SAE if needed.
- apply_b_dec_to_input (bool): Whether to apply the decoder bias to the input. Not currently advised.
- decoder_orthogonal_init (bool): Whether to use orthogonal initialization for the decoder. Not currently advised.
- decoder_heuristic_init (bool): Whether to use heuristic initialization for the decoder. See Anthropic April Update.
- init_encoder_as_decoder_transpose (bool): Whether to initialize the encoder as the transpose of the decoder. See Anthropic April Update.
- n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in ram is used. The larger it is, the better shuffled the activations will be.
+ n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in RAM is used. The larger it is, the better shuffled the activations will be.
  training_tokens (int): The number of training tokens.
- finetuning_tokens (int): The number of finetuning tokens. See [here](https://www.lesswrong.com/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes)
- store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating actiations.
- train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop.
- normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it following Antrhopic April update -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).
- seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
- device (str): The device to use. Usually cuda.
- act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
+ store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
+ seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
+ device (str): The device to use. Usually "cuda".
+ act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
  seed (int): The seed to use.
- dtype (str): The data type to use.
+ dtype (str): The data type to use for the SAE and activations.
  prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
- jumprelu_init_threshold (float): The threshold to initialize for training JumpReLU SAEs.
- jumprelu_bandwidth (float): Bandwidth for training JumpReLU SAEs.
- autocast (bool): Whether to use autocast during training. Saves vram.
- autocast_lm (bool): Whether to use autocast during activation fetching.
- compile_llm (bool): Whether to compile the LLM.
- llm_compilation_mode (str): The compilation mode to use for the LLM.
- compile_sae (bool): Whether to compile the SAE.
- sae_compilation_mode (str): The compilation mode to use for the SAE.
- adam_beta1 (float): The beta1 parameter for Adam.
- adam_beta2 (float): The beta2 parameter for Adam.
- mse_loss_normalization (str): The normalization to use for the MSE loss.
- l1_coefficient (float): The L1 coefficient.
- lp_norm (float): The Lp norm.
- scale_sparsity_penalty_by_decoder_norm (bool): Whether to scale the sparsity penalty by the decoder norm.
- l1_warm_up_steps (int): The number of warm-up steps for the L1 loss.
+ autocast (bool): Whether to use autocast (mixed-precision) during SAE training. Saves VRAM.
+ autocast_lm (bool): Whether to use autocast (mixed-precision) during activation fetching. Saves VRAM.
+ compile_llm (bool): Whether to compile the LLM using `torch.compile`.
+ llm_compilation_mode (str, optional): The compilation mode to use for the LLM if `compile_llm` is True.
+ compile_sae (bool): Whether to compile the SAE using `torch.compile`.
+ sae_compilation_mode (str, optional): The compilation mode to use for the SAE if `compile_sae` is True.
+ train_batch_size_tokens (int): The batch size for training, in tokens. This controls the batch size of the SAE training loop.
+ adam_beta1 (float): The beta1 parameter for the Adam optimizer.
+ adam_beta2 (float): The beta2 parameter for the Adam optimizer.
  lr (float): The learning rate.
- lr_scheduler_name (str): The name of the learning rate scheduler to use.
+ lr_scheduler_name (str): The name of the learning rate scheduler to use (e.g., "constant", "cosineannealing", "cosineannealingwarmrestarts").
  lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
- lr_end (float): The end learning rate if lr_decay_steps is set. Default is lr / 10.
- lr_decay_steps (int): The number of decay steps for the learning rate.
- n_restart_cycles (int): The number of restart cycles for the cosine annealing warm restarts scheduler.
- finetuning_method (str): The method to use for finetuning.
- use_ghost_grads (bool): Whether to use ghost gradients.
- feature_sampling_window (int): The feature sampling window.
- dead_feature_window (int): The dead feature window.
- dead_feature_threshold (float): The dead feature threshold.
- n_eval_batches (int): The number of evaluation batches.
- eval_batch_size_prompts (int): The batch size for evaluation.
- log_to_wandb (bool): Whether to log to Weights & Biases.
- log_activations_store_to_wandb (bool): NOT CURRENTLY USED. Whether to log the activations store to Weights & Biases.
- log_optimizer_state_to_wandb (bool): NOT CURRENTLY USED. Whether to log the optimizer state to Weights & Biases.
- wandb_project (str): The Weights & Biases project to log to.
- wandb_id (str): The Weights & Biases ID.
- run_name (str): The name of the run.
- wandb_entity (str): The Weights & Biases entity.
- wandb_log_frequency (int): The frequency to log to Weights & Biases.
- eval_every_n_wandb_logs (int): The frequency to evaluate.
- resume (bool): Whether to resume training.
- n_checkpoints (int): The number of checkpoints.
- checkpoint_path (str): The path to save checkpoints.
+ lr_end (float, optional): The end learning rate if using a scheduler like cosine annealing. Defaults to `lr / 10`.
+ lr_decay_steps (int): The number of decay steps for the learning rate if using a scheduler with decay.
+ n_restart_cycles (int): The number of restart cycles for the cosine annealing with warm restarts scheduler.
+ dead_feature_window (int): The window size (in training steps) for detecting dead features.
+ feature_sampling_window (int): The window size (in training steps) for resampling features (e.g. dead features).
+ dead_feature_threshold (float): The threshold below which a feature's activation frequency is considered dead.
+ n_eval_batches (int): The number of batches to use for evaluation.
+ eval_batch_size_prompts (int, optional): The batch size for evaluation, in prompts. Useful if evals cause OOM.
+ logger (LoggingConfig): Configuration for logging (e.g. W&B).
+ n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
+ checkpoint_path (str): The path to save checkpoints. A unique ID will be appended to this path.
  verbose (bool): Whether to print verbose output.
- model_kwargs (dict[str, Any]): Additional keyword arguments for the model.
- model_from_pretrained_kwargs (dict[str, Any]): Additional keyword arguments for the model from pretrained.
- exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations.
+ model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
+ model_from_pretrained_kwargs (dict[str, Any], optional): Additional keyword arguments to pass to the model's `from_pretrained` method.
+ sae_lens_version (str): The version of the sae_lens library.
+ sae_lens_training_version (str): The version of the sae_lens training library.
+ exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations. If True, excludes all special tokens. If a list of ints, excludes those token IDs.
  """

+ sae: T_TRAINING_SAE_CONFIG
+
  # Data Generating Function (Model + Training Distibuion)
  model_name: str = "gelu-2l"
  model_class_name: str = "HookedTransformer"
  hook_name: str = "blocks.0.hook_mlp_out"
  hook_eval: str = "NOT_IN_USE"
- hook_layer: int = 0
  hook_head_index: int | None = None
  dataset_path: str = ""
  dataset_trust_remote_code: bool = True
@@ -161,30 +172,12 @@ class LanguageModelSAERunnerConfig:
  )

  # SAE Parameters
- architecture: Literal["standard", "gated", "jumprelu", "topk"] = "standard"
- d_in: int = 512
- d_sae: int | None = None
- b_dec_init_method: str = "geometric_median"
- expansion_factor: int | None = (
- None # defaults to 4 if d_sae and expansion_factor is None
- )
- activation_fn: str = None # relu, tanh-relu, topk. Default is relu. # type: ignore
- activation_fn_kwargs: dict[str, int] = dict_field(default=None) # for topk
- normalize_sae_decoder: bool = True
- noise_scale: float = 0.0
  from_pretrained_path: str | None = None
- apply_b_dec_to_input: bool = True
- decoder_orthogonal_init: bool = False
- decoder_heuristic_init: bool = False
- decoder_heuristic_init_norm: float = 0.1
- init_encoder_as_decoder_transpose: bool = False

  # Activation Store Parameters
  n_batches_in_buffer: int = 20
  training_tokens: int = 2_000_000
- finetuning_tokens: int = 0
  store_batch_size_prompts: int = 32
- normalize_activations: str = "none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
  seqpos_slice: tuple[int | None, ...] = (None,)

  # Misc
@@ -194,10 +187,6 @@ class LanguageModelSAERunnerConfig:
  dtype: str = "float32" # type: ignore #
  prepend_bos: bool = True

- # JumpReLU Parameters
- jumprelu_init_threshold: float = 0.001
- jumprelu_bandwidth: float = 0.001
-
  # Performance - see compilation section of lm_runner.py for info
  autocast: bool = False # autocast to autocast_dtype during training
  autocast_lm: bool = False # autocast lm during activation fetching
@@ -212,16 +201,9 @@ class LanguageModelSAERunnerConfig:
  train_batch_size_tokens: int = 4096

  ## Adam
- adam_beta1: float = 0.0
+ adam_beta1: float = 0.9
  adam_beta2: float = 0.999

- ## Loss Function
- mse_loss_normalization: str | None = None
- l1_coefficient: float = 1e-3
- lp_norm: float = 1
- scale_sparsity_penalty_by_decoder_norm: bool = False
- l1_warm_up_steps: int = 0
-
  ## Learning Rate Schedule
  lr: float = 3e-4
  lr_scheduler_name: str = (
@@ -232,33 +214,18 @@ class LanguageModelSAERunnerConfig:
  lr_decay_steps: int = 0
  n_restart_cycles: int = 1 # used only for cosineannealingwarmrestarts

- ## FineTuning
- finetuning_method: str | None = None # scale, decoder or unrotated_decoder
-
  # Resampling protocol args
- use_ghost_grads: bool = False # want to change this to true on some timeline.
- feature_sampling_window: int = 2000
  dead_feature_window: int = 1000 # unless this window is larger feature sampling,
-
+ feature_sampling_window: int = 2000
  dead_feature_threshold: float = 1e-8

  # Evals
  n_eval_batches: int = 10
  eval_batch_size_prompts: int | None = None # useful if evals cause OOM

- # WANDB
- log_to_wandb: bool = True
- log_activations_store_to_wandb: bool = False
- log_optimizer_state_to_wandb: bool = False
- wandb_project: str = "mats_sae_training_language_model"
- wandb_id: str | None = None
- run_name: str | None = None
- wandb_entity: str | None = None
- wandb_log_frequency: int = 10
- eval_every_n_wandb_logs: int = 100 # logs every 1000 steps.
+ logger: LoggingConfig = field(default_factory=LoggingConfig)

  # Misc
- resume: bool = False
  n_checkpoints: int = 0
  checkpoint_path: str = "checkpoints"
  verbose: bool = True
@@ -269,12 +236,6 @@ class LanguageModelSAERunnerConfig:
  exclude_special_tokens: bool | list[int] = False

  def __post_init__(self):
- if self.resume:
- raise ValueError(
- "Resuming is no longer supported. You can finetune a trained SAE using cfg.from_pretrained path."
- + "If you want to load an SAE with resume=True in the config, please manually set resume=False in that config."
- )
-
  if self.use_cached_activations and self.cached_activations_path is None:
  self.cached_activations_path = _default_cached_activations_path(
  self.dataset_path,
@@ -282,37 +243,12 @@ class LanguageModelSAERunnerConfig:
  self.hook_name,
  self.hook_head_index,
  )
-
- if self.activation_fn is None:
- self.activation_fn = "topk" if self.architecture == "topk" else "relu"
-
- if self.architecture == "topk" and self.activation_fn != "topk":
- raise ValueError("If using topk architecture, activation_fn must be topk.")
-
- if self.activation_fn_kwargs is None:
- self.activation_fn_kwargs = (
- {"k": 100} if self.activation_fn == "topk" else {}
- )
-
- if self.architecture == "topk" and self.activation_fn_kwargs.get("k") is None:
- raise ValueError(
- "activation_fn_kwargs.k must be provided for topk architecture."
- )
-
- if self.d_sae is not None and self.expansion_factor is not None:
- raise ValueError("You can't set both d_sae and expansion_factor.")
-
- if self.d_sae is None and self.expansion_factor is None:
- self.expansion_factor = 4
-
- if self.d_sae is None and self.expansion_factor is not None:
- self.d_sae = self.d_in * self.expansion_factor
  self.tokens_per_buffer = (
  self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
  )

- if self.run_name is None:
- self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
+ if self.logger.run_name is None:
+ self.logger.run_name = f"{self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"

  if self.model_from_pretrained_kwargs is None:
  if self.model_class_name == "HookedTransformer":
@@ -320,44 +256,13 @@ class LanguageModelSAERunnerConfig:
  else:
  self.model_from_pretrained_kwargs = {}

- if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
- raise ValueError(
- f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}"
- )
-
- if self.normalize_sae_decoder and self.decoder_heuristic_init:
- raise ValueError(
- "You can't normalize the decoder and use heuristic initialization."
- )
-
- if self.normalize_sae_decoder and self.scale_sparsity_penalty_by_decoder_norm:
- raise ValueError(
- "Weighting loss by decoder norm makes no sense if you are normalizing the decoder weight norms to 1"
- )
-
- # if we use decoder fine tuning, we can't be applying b_dec to the input
- if (self.finetuning_method == "decoder") and (self.apply_b_dec_to_input):
- raise ValueError(
- "If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False."
- )
-
- if self.normalize_activations not in [
- "none",
- "expected_average_only_in",
- "constant_norm_rescale",
- "layer_norm",
- ]:
- raise ValueError(
- f"normalize_activations must be none, layer_norm, expected_average_only_in, or constant_norm_rescale. Got {self.normalize_activations}"
- )
-
  if self.act_store_device == "with_model":
  self.act_store_device = self.device

  if self.lr_end is None:
  self.lr_end = self.lr / 10

- unique_id = self.wandb_id
+ unique_id = self.logger.wandb_id
  if unique_id is None:
  unique_id = cast(
  Any, wandb
@@ -366,7 +271,7 @@ class LanguageModelSAERunnerConfig:

  if self.verbose:
  logger.info(
- f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
+ f"Run name: {self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
  )
  # Print out some useful info:
  n_tokens_per_buffer = (
@@ -385,11 +290,13 @@ class LanguageModelSAERunnerConfig:
  )

  total_training_steps = (
- self.training_tokens + self.finetuning_tokens
+ self.training_tokens
  ) // self.train_batch_size_tokens
  logger.info(f"Total training steps: {total_training_steps}")

- total_wandb_updates = total_training_steps // self.wandb_log_frequency
+ total_wandb_updates = (
+ total_training_steps // self.logger.wandb_log_frequency
+ )
  logger.info(f"Total wandb updates: {total_wandb_updates}")

  # how many times will we sample dead neurons?
@@ -411,9 +318,6 @@ class LanguageModelSAERunnerConfig:
  f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
  )

- if self.use_ghost_grads:
- logger.info("Using Ghost Grads.")
-
  if self.context_size < 0:
  raise ValueError(
  f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
@@ -428,65 +332,26 @@ class LanguageModelSAERunnerConfig:

  @property
  def total_training_tokens(self) -> int:
- return self.training_tokens + self.finetuning_tokens
+ return self.training_tokens

  @property
  def total_training_steps(self) -> int:
  return self.total_training_tokens // self.train_batch_size_tokens

- def get_base_sae_cfg_dict(self) -> dict[str, Any]:
- return {
- # TEMP
- "architecture": self.architecture,
- "d_in": self.d_in,
- "d_sae": self.d_sae,
- "dtype": self.dtype,
- "device": self.device,
- "model_name": self.model_name,
- "hook_name": self.hook_name,
- "hook_layer": self.hook_layer,
- "hook_head_index": self.hook_head_index,
- "activation_fn_str": self.activation_fn,
- "apply_b_dec_to_input": self.apply_b_dec_to_input,
- "context_size": self.context_size,
- "prepend_bos": self.prepend_bos,
- "dataset_path": self.dataset_path,
- "dataset_trust_remote_code": self.dataset_trust_remote_code,
- "finetuning_scaling_factor": self.finetuning_method is not None,
- "sae_lens_training_version": self.sae_lens_training_version,
- "normalize_activations": self.normalize_activations,
- "activation_fn_kwargs": self.activation_fn_kwargs,
- "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
- "seqpos_slice": self.seqpos_slice,
- }
-
  def get_training_sae_cfg_dict(self) -> dict[str, Any]:
- return {
- **self.get_base_sae_cfg_dict(),
- "l1_coefficient": self.l1_coefficient,
- "lp_norm": self.lp_norm,
- "use_ghost_grads": self.use_ghost_grads,
- "normalize_sae_decoder": self.normalize_sae_decoder,
- "noise_scale": self.noise_scale,
- "decoder_orthogonal_init": self.decoder_orthogonal_init,
- "mse_loss_normalization": self.mse_loss_normalization,
- "decoder_heuristic_init": self.decoder_heuristic_init,
- "decoder_heuristic_init_norm": self.decoder_heuristic_init_norm,
- "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
- "normalize_activations": self.normalize_activations,
- "jumprelu_init_threshold": self.jumprelu_init_threshold,
- "jumprelu_bandwidth": self.jumprelu_bandwidth,
- "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
- }
+ return self.sae.to_dict()

  def to_dict(self) -> dict[str, Any]:
- return {
- **self.__dict__,
- # some args may not be serializable by default
- "dtype": str(self.dtype),
- "device": str(self.device),
- "act_store_device": str(self.act_store_device),
- }
+ # Make a shallow copy of config's dictionary
+ d = dict(self.__dict__)
+
+ d["logger"] = asdict(self.logger)
+ d["sae"] = self.sae.to_dict()
+ # Overwrite fields that might not be JSON-serializable
+ d["dtype"] = str(self.dtype)
+ d["device"] = str(self.device)
+ d["act_store_device"] = str(self.act_store_device)
+ return d

  def to_json(self, path: str) -> None:
  if not os.path.exists(os.path.dirname(path)):
@@ -496,7 +361,7 @@ class LanguageModelSAERunnerConfig:
  json.dump(self.to_dict(), f, indent=2)

  @classmethod
- def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
+ def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig[Any]":
  with open(path + "cfg.json") as f:
  cfg = json.load(f)

@@ -510,6 +375,27 @@ class LanguageModelSAERunnerConfig:

  return cls(**cfg)

+ def to_sae_trainer_config(self) -> "SAETrainerConfig":
+ return SAETrainerConfig(
+ n_checkpoints=self.n_checkpoints,
+ checkpoint_path=self.checkpoint_path,
+ total_training_samples=self.total_training_tokens,
+ device=self.device,
+ autocast=self.autocast,
+ lr=self.lr,
+ lr_end=self.lr_end,
+ lr_scheduler_name=self.lr_scheduler_name,
+ lr_warm_up_steps=self.lr_warm_up_steps,
+ adam_beta1=self.adam_beta1,
+ adam_beta2=self.adam_beta2,
+ lr_decay_steps=self.lr_decay_steps,
+ n_restart_cycles=self.n_restart_cycles,
+ train_batch_size_samples=self.train_batch_size_tokens,
+ dead_feature_window=self.dead_feature_window,
+ feature_sampling_window=self.feature_sampling_window,
+ logger=self.logger,
+ )
+

  @dataclass
  class CacheActivationsRunnerConfig:
@@ -521,7 +407,6 @@ class CacheActivationsRunnerConfig:
  model_name (str): The name of the model to use.
  model_batch_size (int): How many prompts are in the batch of the language model when generating activations.
  hook_name (str): The name of the hook to use.
- hook_layer (int): The layer of the final hook. Currently only support a single hook, so this should be the same as hook_name.
  d_in (int): Dimension of the model.
  total_training_tokens (int): Total number of tokens to process.
  context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized.
@@ -551,7 +436,6 @@ class CacheActivationsRunnerConfig:
  model_name: str
  model_batch_size: int
  hook_name: str
- hook_layer: int
  d_in: int
  training_tokens: int

@@ -679,6 +563,10 @@ def _validate_seqpos(seqpos: tuple[int | None, ...], context_size: int) -> None:

  @dataclass
  class PretokenizeRunnerConfig:
+ """
+ Configuration class for pretokenizing a dataset.
+ """
+
  tokenizer_name: str = "gpt2"
  dataset_path: str = ""
  dataset_name: str | None = None
@@ -707,3 +595,28 @@ class PretokenizeRunnerConfig:
  hf_num_shards: int = 64
  hf_revision: str = "main"
  hf_is_private_repo: bool = False
+
+
+ @dataclass
+ class SAETrainerConfig:
+ n_checkpoints: int
+ checkpoint_path: str
+ total_training_samples: int
+ device: str
+ autocast: bool
+ lr: float
+ lr_end: float | None
+ lr_scheduler_name: str
+ lr_warm_up_steps: int
+ adam_beta1: float
+ adam_beta2: float
+ lr_decay_steps: int
+ n_restart_cycles: int
+ train_batch_size_samples: int
+ dead_feature_window: int
+ feature_sampling_window: int
+ logger: LoggingConfig
+
+ @property
+ def total_training_steps(self) -> int:
+ return self.total_training_samples // self.train_batch_size_samples
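Two usage notes on the methods added in the hunks above, continuing the hypothetical cfg from the sketch near the top of this diff (the behavior shown follows directly from the added code; the cfg object itself is illustrative):

# Serialization now nests the sub-configs instead of flattening SAE hyperparameters.
d = cfg.to_dict()
assert isinstance(d["sae"], dict)     # nested via self.sae.to_dict()
assert isinstance(d["logger"], dict)  # nested via dataclasses.asdict(self.logger)
assert d["dtype"] == "float32"        # stringified via str(self.dtype); "float32" with the default dtype

# Training-loop settings are split out into the new SAETrainerConfig.
trainer_cfg = cfg.to_sae_trainer_config()
steps = trainer_cfg.total_training_steps  # total_training_samples // train_batch_size_samples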
sae_lens/constants.py ADDED
@@ -0,0 +1,21 @@
+ import torch
+
+ DTYPE_MAP = {
+ "float32": torch.float32,
+ "float64": torch.float64,
+ "float16": torch.float16,
+ "bfloat16": torch.bfloat16,
+ "torch.float32": torch.float32,
+ "torch.float64": torch.float64,
+ "torch.float16": torch.float16,
+ "torch.bfloat16": torch.bfloat16,
+ }
+
+
+ SPARSITY_FILENAME = "sparsity.safetensors"
+ SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
+ SAE_CFG_FILENAME = "cfg.json"
+ RUNNER_CFG_FILENAME = "runner_cfg.json"
+ SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+ ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
+ ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
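The new constants module centralizes the dtype lookup table and the on-disk filenames that config.py previously defined inline. A minimal usage sketch, using only the names defined in the file above (the checkpoint directory name is a placeholder):

import torch

from sae_lens.constants import DTYPE_MAP, SAE_WEIGHTS_FILENAME

# Both bare and "torch."-prefixed dtype strings resolve to the same torch dtype.
assert DTYPE_MAP["bfloat16"] is torch.bfloat16
assert DTYPE_MAP["torch.bfloat16"] is torch.bfloat16

# Checkpoint artifacts are addressed through the shared filename constants,
# e.g. <checkpoint_dir>/sae_weights.safetensors.
weights_file = f"my_checkpoint_dir/{SAE_WEIGHTS_FILENAME}"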