sae-lens 6.0.0rc1-py3-none-any.whl → 6.0.0rc3-py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.
sae_lens/config.py CHANGED
@@ -3,7 +3,7 @@ import math
 import os
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
-from typing import Any, Literal, cast
+from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
 
 import simple_parsing
 import torch
@@ -17,24 +17,17 @@ from datasets import (
 )
 
 from sae_lens import __version__, logger
+from sae_lens.constants import DTYPE_MAP
+from sae_lens.saes.sae import TrainingSAEConfig
 
-DTYPE_MAP = {
-    "float32": torch.float32,
-    "float64": torch.float64,
-    "float16": torch.float16,
-    "bfloat16": torch.bfloat16,
-    "torch.float32": torch.float32,
-    "torch.float64": torch.float64,
-    "torch.float16": torch.float16,
-    "torch.bfloat16": torch.bfloat16,
-}
-
-HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
+if TYPE_CHECKING:
+    pass
 
+T_TRAINING_SAE_CONFIG = TypeVar(
+    "T_TRAINING_SAE_CONFIG", bound=TrainingSAEConfig, covariant=True
+)
 
-SPARSITY_FILENAME = "sparsity.safetensors"
-SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
-SAE_CFG_FILENAME = "cfg.json"
+HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
 
 
 # calling this "json_dict" so error messages will reference "json_dict" being invalid
@@ -101,101 +94,72 @@ class LoggingConfig:
 
 
 @dataclass
-class LanguageModelSAERunnerConfig:
+class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     """
     Configuration for training a sparse autoencoder on a language model.
 
     Args:
-        architecture (str): The architecture to use, either "standard", "gated", "topk", or "jumprelu".
+        sae (T_TRAINING_SAE_CONFIG): The configuration for the SAE itself (e.g. StandardSAEConfig, GatedSAEConfig).
         model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
         model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
         hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
         hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
-        hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing.
-        hook_head_index (int, optional): When the hook if for an activatio with a head index, we can specify a specific head to use here.
+        hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
         dataset_path (str): A Hugging Face dataset path.
         dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
         streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
-        is_dataset_tokenized (bool): NOT IN USE. We used to use this but now automatically detect if the dataset is tokenized.
+        is_dataset_tokenized (bool): Whether the dataset is already tokenized.
         context_size (int): The context size to use when generating activations on which to train the SAE.
         use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations.
-        cached_activations_path (str, optional): The path to the cached activations.
-        d_in (int): The input dimension of the SAE.
-        d_sae (int, optional): The output dimension of the SAE. If None, defaults to `d_in * expansion_factor`.
-        b_dec_init_method (str): The method to use to initialize the decoder bias. Zeros is likely fine.
-        expansion_factor (int): The expansion factor. Larger is better but more computationally expensive. Default is 4.
-        activation_fn (str): The activation function to use. Relu is standard.
-        normalize_sae_decoder (bool): Whether to normalize the SAE decoder. Unit normed decoder weights used to be preferred.
-        noise_scale (float): Using noise to induce sparsity is supported but not recommended.
+        cached_activations_path (str, optional): The path to the cached activations. Defaults to "activations/{dataset_path}/{model_name}/{hook_name}_{hook_head_index}".
         from_pretrained_path (str, optional): The path to a pretrained SAE. We can finetune an existing SAE if needed.
-        apply_b_dec_to_input (bool): Whether to apply the decoder bias to the input. Not currently advised.
-        decoder_orthogonal_init (bool): Whether to use orthogonal initialization for the decoder. Not currently advised.
-        decoder_heuristic_init (bool): Whether to use heuristic initialization for the decoder. See Anthropic April Update.
-        init_encoder_as_decoder_transpose (bool): Whether to initialize the encoder as the transpose of the decoder. See Anthropic April Update.
-        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in ram is used. The larger it is, the better shuffled the activations will be.
+        n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in RAM is used. The larger it is, the better shuffled the activations will be.
         training_tokens (int): The number of training tokens.
-        finetuning_tokens (int): The number of finetuning tokens. See [here](https://www.lesswrong.com/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes)
-        store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating actiations.
-        train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop.
-        normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it following Antrhopic April update -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).
-        seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
-        device (str): The device to use. Usually cuda.
-        act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
+        store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
+        seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
+        device (str): The device to use. Usually "cuda".
+        act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
         seed (int): The seed to use.
-        dtype (str): The data type to use.
+        dtype (str): The data type to use for the SAE and activations.
         prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
-        jumprelu_init_threshold (float): The threshold to initialize for training JumpReLU SAEs.
-        jumprelu_bandwidth (float): Bandwidth for training JumpReLU SAEs.
-        autocast (bool): Whether to use autocast during training. Saves vram.
-        autocast_lm (bool): Whether to use autocast during activation fetching.
-        compile_llm (bool): Whether to compile the LLM.
-        llm_compilation_mode (str): The compilation mode to use for the LLM.
-        compile_sae (bool): Whether to compile the SAE.
-        sae_compilation_mode (str): The compilation mode to use for the SAE.
-        adam_beta1 (float): The beta1 parameter for Adam.
-        adam_beta2 (float): The beta2 parameter for Adam.
-        mse_loss_normalization (str): The normalization to use for the MSE loss.
-        l1_coefficient (float): The L1 coefficient.
-        lp_norm (float): The Lp norm.
-        scale_sparsity_penalty_by_decoder_norm (bool): Whether to scale the sparsity penalty by the decoder norm.
-        l1_warm_up_steps (int): The number of warm-up steps for the L1 loss.
+        autocast (bool): Whether to use autocast (mixed-precision) during SAE training. Saves VRAM.
+        autocast_lm (bool): Whether to use autocast (mixed-precision) during activation fetching. Saves VRAM.
+        compile_llm (bool): Whether to compile the LLM using `torch.compile`.
+        llm_compilation_mode (str, optional): The compilation mode to use for the LLM if `compile_llm` is True.
+        compile_sae (bool): Whether to compile the SAE using `torch.compile`.
+        sae_compilation_mode (str, optional): The compilation mode to use for the SAE if `compile_sae` is True.
+        train_batch_size_tokens (int): The batch size for training, in tokens. This controls the batch size of the SAE training loop.
+        adam_beta1 (float): The beta1 parameter for the Adam optimizer.
+        adam_beta2 (float): The beta2 parameter for the Adam optimizer.
         lr (float): The learning rate.
-        lr_scheduler_name (str): The name of the learning rate scheduler to use.
+        lr_scheduler_name (str): The name of the learning rate scheduler to use (e.g., "constant", "cosineannealing", "cosineannealingwarmrestarts").
         lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
-        lr_end (float): The end learning rate if lr_decay_steps is set. Default is lr / 10.
-        lr_decay_steps (int): The number of decay steps for the learning rate.
-        n_restart_cycles (int): The number of restart cycles for the cosine annealing warm restarts scheduler.
-        finetuning_method (str): The method to use for finetuning.
-        use_ghost_grads (bool): Whether to use ghost gradients.
-        feature_sampling_window (int): The feature sampling window.
-        dead_feature_window (int): The dead feature window.
-        dead_feature_threshold (float): The dead feature threshold.
-        n_eval_batches (int): The number of evaluation batches.
-        eval_batch_size_prompts (int): The batch size for evaluation.
-        log_to_wandb (bool): Whether to log to Weights & Biases.
-        log_activations_store_to_wandb (bool): NOT CURRENTLY USED. Whether to log the activations store to Weights & Biases.
-        log_optimizer_state_to_wandb (bool): NOT CURRENTLY USED. Whether to log the optimizer state to Weights & Biases.
-        wandb_project (str): The Weights & Biases project to log to.
-        wandb_id (str): The Weights & Biases ID.
-        run_name (str): The name of the run.
-        wandb_entity (str): The Weights & Biases entity.
-        wandb_log_frequency (int): The frequency to log to Weights & Biases.
-        eval_every_n_wandb_logs (int): The frequency to evaluate.
-        resume (bool): Whether to resume training.
-        n_checkpoints (int): The number of checkpoints.
-        checkpoint_path (str): The path to save checkpoints.
+        lr_end (float, optional): The end learning rate if using a scheduler like cosine annealing. Defaults to `lr / 10`.
+        lr_decay_steps (int): The number of decay steps for the learning rate if using a scheduler with decay.
+        n_restart_cycles (int): The number of restart cycles for the cosine annealing with warm restarts scheduler.
+        dead_feature_window (int): The window size (in training steps) for detecting dead features.
+        feature_sampling_window (int): The window size (in training steps) for resampling features (e.g. dead features).
+        dead_feature_threshold (float): The threshold below which a feature's activation frequency is considered dead.
+        n_eval_batches (int): The number of batches to use for evaluation.
+        eval_batch_size_prompts (int, optional): The batch size for evaluation, in prompts. Useful if evals cause OOM.
+        logger (LoggingConfig): Configuration for logging (e.g. W&B).
+        n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
+        checkpoint_path (str): The path to save checkpoints. A unique ID will be appended to this path.
         verbose (bool): Whether to print verbose output.
-        model_kwargs (dict[str, Any]): Additional keyword arguments for the model.
-        model_from_pretrained_kwargs (dict[str, Any]): Additional keyword arguments for the model from pretrained.
-        exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations.
+        model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
+        model_from_pretrained_kwargs (dict[str, Any], optional): Additional keyword arguments to pass to the model's `from_pretrained` method.
+        sae_lens_version (str): The version of the sae_lens library.
+        sae_lens_training_version (str): The version of the sae_lens training library.
+        exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations. If True, excludes all special tokens. If a list of ints, excludes those token IDs.
     """
 
+    sae: T_TRAINING_SAE_CONFIG
+
     # Data Generating Function (Model + Training Distibuion)
     model_name: str = "gelu-2l"
     model_class_name: str = "HookedTransformer"
     hook_name: str = "blocks.0.hook_mlp_out"
     hook_eval: str = "NOT_IN_USE"
-    hook_layer: int = 0
     hook_head_index: int | None = None
     dataset_path: str = ""
     dataset_trust_remote_code: bool = True
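
The hunk above is the central API change in this release candidate: the flat SAE hyperparameters (architecture, d_in, d_sae, the L1 and JumpReLU options, and so on) are replaced by a single nested `sae` field bound to `TrainingSAEConfig`, and the runner config becomes generic over that type. Below is a minimal sketch of what construction might look like after the change; the `StandardTrainingSAEConfig` name, its import path, and its `d_in`/`d_sae` fields are assumptions, since only the `TrainingSAEConfig` bound appears in this diff.

# Hypothetical sketch of rc3-style construction; names flagged below are assumptions.
from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.saes.standard_sae import StandardTrainingSAEConfig  # assumed class and module

cfg = LanguageModelSAERunnerConfig(
    # SAE geometry and architecture-specific options now live on the nested config ...
    sae=StandardTrainingSAEConfig(d_in=512, d_sae=512 * 16),  # assumed field names
    # ... while model, data, and optimizer settings stay on the runner config.
    model_name="gelu-2l",
    hook_name="blocks.0.hook_mlp_out",
    dataset_path="<hf-dataset-path>",  # placeholder
    training_tokens=2_000_000,
    train_batch_size_tokens=4096,
    lr=3e-4,
)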
@@ -208,29 +172,12 @@ class LanguageModelSAERunnerConfig:
     )
 
     # SAE Parameters
-    architecture: Literal["standard", "gated", "jumprelu", "topk"] = "standard"
-    d_in: int = 512
-    d_sae: int | None = None
-    b_dec_init_method: str = "geometric_median"
-    expansion_factor: int | None = (
-        None  # defaults to 4 if d_sae and expansion_factor is None
-    )
-    activation_fn: str = None  # relu, tanh-relu, topk. Default is relu.  # type: ignore
-    activation_fn_kwargs: dict[str, int] = dict_field(default=None)  # for topk
-    normalize_sae_decoder: bool = True
-    noise_scale: float = 0.0
     from_pretrained_path: str | None = None
-    apply_b_dec_to_input: bool = True
-    decoder_orthogonal_init: bool = False
-    decoder_heuristic_init: bool = False
-    init_encoder_as_decoder_transpose: bool = False
 
     # Activation Store Parameters
     n_batches_in_buffer: int = 20
     training_tokens: int = 2_000_000
-    finetuning_tokens: int = 0
     store_batch_size_prompts: int = 32
-    normalize_activations: str = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
     seqpos_slice: tuple[int | None, ...] = (None,)
 
     # Misc
@@ -240,10 +187,6 @@ class LanguageModelSAERunnerConfig:
     dtype: str = "float32"  # type: ignore #
     prepend_bos: bool = True
 
-    # JumpReLU Parameters
-    jumprelu_init_threshold: float = 0.001
-    jumprelu_bandwidth: float = 0.001
-
     # Performance - see compilation section of lm_runner.py for info
     autocast: bool = False  # autocast to autocast_dtype during training
    autocast_lm: bool = False  # autocast lm during activation fetching
@@ -261,13 +204,6 @@ class LanguageModelSAERunnerConfig:
     adam_beta1: float = 0.0
     adam_beta2: float = 0.999
 
-    ## Loss Function
-    mse_loss_normalization: str | None = None
-    l1_coefficient: float = 1e-3
-    lp_norm: float = 1
-    scale_sparsity_penalty_by_decoder_norm: bool = False
-    l1_warm_up_steps: int = 0
-
     ## Learning Rate Schedule
     lr: float = 3e-4
     lr_scheduler_name: str = (
@@ -278,14 +214,9 @@ class LanguageModelSAERunnerConfig:
     lr_decay_steps: int = 0
     n_restart_cycles: int = 1  # used only for cosineannealingwarmrestarts
 
-    ## FineTuning
-    finetuning_method: str | None = None  # scale, decoder or unrotated_decoder
-
     # Resampling protocol args
-    use_ghost_grads: bool = False  # want to change this to true on some timeline.
-    feature_sampling_window: int = 2000
     dead_feature_window: int = 1000  # unless this window is larger feature sampling,
-
+    feature_sampling_window: int = 2000
     dead_feature_threshold: float = 1e-8
 
     # Evals
@@ -295,7 +226,6 @@ class LanguageModelSAERunnerConfig:
     logger: LoggingConfig = field(default_factory=LoggingConfig)
 
     # Misc
-    resume: bool = False
     n_checkpoints: int = 0
     checkpoint_path: str = "checkpoints"
     verbose: bool = True
@@ -306,12 +236,6 @@ class LanguageModelSAERunnerConfig:
     exclude_special_tokens: bool | list[int] = False
 
     def __post_init__(self):
-        if self.resume:
-            raise ValueError(
-                "Resuming is no longer supported. You can finetune a trained SAE using cfg.from_pretrained path."
-                + "If you want to load an SAE with resume=True in the config, please manually set resume=False in that config."
-            )
-
         if self.use_cached_activations and self.cached_activations_path is None:
             self.cached_activations_path = _default_cached_activations_path(
                 self.dataset_path,
@@ -319,37 +243,12 @@ class LanguageModelSAERunnerConfig:
                 self.hook_name,
                 self.hook_head_index,
             )
-
-        if self.activation_fn is None:
-            self.activation_fn = "topk" if self.architecture == "topk" else "relu"
-
-        if self.architecture == "topk" and self.activation_fn != "topk":
-            raise ValueError("If using topk architecture, activation_fn must be topk.")
-
-        if self.activation_fn_kwargs is None:
-            self.activation_fn_kwargs = (
-                {"k": 100} if self.activation_fn == "topk" else {}
-            )
-
-        if self.architecture == "topk" and self.activation_fn_kwargs.get("k") is None:
-            raise ValueError(
-                "activation_fn_kwargs.k must be provided for topk architecture."
-            )
-
-        if self.d_sae is not None and self.expansion_factor is not None:
-            raise ValueError("You can't set both d_sae and expansion_factor.")
-
-        if self.d_sae is None and self.expansion_factor is None:
-            self.expansion_factor = 4
-
-        if self.d_sae is None and self.expansion_factor is not None:
-            self.d_sae = self.d_in * self.expansion_factor
         self.tokens_per_buffer = (
             self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
         )
 
         if self.logger.run_name is None:
-            self.logger.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
+            self.logger.run_name = f"{self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
 
         if self.model_from_pretrained_kwargs is None:
             if self.model_class_name == "HookedTransformer":
@@ -357,37 +256,6 @@ class LanguageModelSAERunnerConfig:
             else:
                 self.model_from_pretrained_kwargs = {}
 
-        if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
-            raise ValueError(
-                f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}"
-            )
-
-        if self.normalize_sae_decoder and self.decoder_heuristic_init:
-            raise ValueError(
-                "You can't normalize the decoder and use heuristic initialization."
-            )
-
-        if self.normalize_sae_decoder and self.scale_sparsity_penalty_by_decoder_norm:
-            raise ValueError(
-                "Weighting loss by decoder norm makes no sense if you are normalizing the decoder weight norms to 1"
-            )
-
-        # if we use decoder fine tuning, we can't be applying b_dec to the input
-        if (self.finetuning_method == "decoder") and (self.apply_b_dec_to_input):
-            raise ValueError(
-                "If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False."
-            )
-
-        if self.normalize_activations not in [
-            "none",
-            "expected_average_only_in",
-            "constant_norm_rescale",
-            "layer_norm",
-        ]:
-            raise ValueError(
-                f"normalize_activations must be none, layer_norm, expected_average_only_in, or constant_norm_rescale. Got {self.normalize_activations}"
-            )
-
         if self.act_store_device == "with_model":
             self.act_store_device = self.device
 
@@ -403,7 +271,7 @@ class LanguageModelSAERunnerConfig:
 
         if self.verbose:
             logger.info(
-                f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
+                f"Run name: {self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
             )
             # Print out some useful info:
             n_tokens_per_buffer = (
@@ -422,7 +290,7 @@ class LanguageModelSAERunnerConfig:
             )
 
             total_training_steps = (
-                self.training_tokens + self.finetuning_tokens
+                self.training_tokens
            ) // self.train_batch_size_tokens
            logger.info(f"Total training steps: {total_training_steps}")
 
@@ -450,9 +318,6 @@ class LanguageModelSAERunnerConfig:
                 f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
             )
 
-        if self.use_ghost_grads:
-            logger.info("Using Ghost Grads.")
-
         if self.context_size < 0:
             raise ValueError(
                 f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
@@ -467,62 +332,21 @@ class LanguageModelSAERunnerConfig:
 
     @property
     def total_training_tokens(self) -> int:
-        return self.training_tokens + self.finetuning_tokens
+        return self.training_tokens
 
     @property
     def total_training_steps(self) -> int:
         return self.total_training_tokens // self.train_batch_size_tokens
 
-    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
-        return {
-            # TEMP
-            "architecture": self.architecture,
-            "d_in": self.d_in,
-            "d_sae": self.d_sae,
-            "dtype": self.dtype,
-            "device": self.device,
-            "model_name": self.model_name,
-            "hook_name": self.hook_name,
-            "hook_layer": self.hook_layer,
-            "hook_head_index": self.hook_head_index,
-            "activation_fn": self.activation_fn,
-            "apply_b_dec_to_input": self.apply_b_dec_to_input,
-            "context_size": self.context_size,
-            "prepend_bos": self.prepend_bos,
-            "dataset_path": self.dataset_path,
-            "dataset_trust_remote_code": self.dataset_trust_remote_code,
-            "finetuning_scaling_factor": self.finetuning_method is not None,
-            "sae_lens_training_version": self.sae_lens_training_version,
-            "normalize_activations": self.normalize_activations,
-            "activation_fn_kwargs": self.activation_fn_kwargs,
-            "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
-            "seqpos_slice": self.seqpos_slice,
-        }
-
     def get_training_sae_cfg_dict(self) -> dict[str, Any]:
-        return {
-            **self.get_base_sae_cfg_dict(),
-            "l1_coefficient": self.l1_coefficient,
-            "lp_norm": self.lp_norm,
-            "use_ghost_grads": self.use_ghost_grads,
-            "normalize_sae_decoder": self.normalize_sae_decoder,
-            "noise_scale": self.noise_scale,
-            "decoder_orthogonal_init": self.decoder_orthogonal_init,
-            "mse_loss_normalization": self.mse_loss_normalization,
-            "decoder_heuristic_init": self.decoder_heuristic_init,
-            "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
-            "normalize_activations": self.normalize_activations,
-            "jumprelu_init_threshold": self.jumprelu_init_threshold,
-            "jumprelu_bandwidth": self.jumprelu_bandwidth,
-            "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
-        }
+        return self.sae.to_dict()
 
     def to_dict(self) -> dict[str, Any]:
-        # Make a shallow copy of configs dictionary
+        # Make a shallow copy of config's dictionary
         d = dict(self.__dict__)
 
         d["logger"] = asdict(self.logger)
-
+        d["sae"] = self.sae.to_dict()
         # Overwrite fields that might not be JSON-serializable
         d["dtype"] = str(self.dtype)
         d["device"] = str(self.device)
@@ -537,7 +361,7 @@ class LanguageModelSAERunnerConfig:
             json.dump(self.to_dict(), f, indent=2)
 
     @classmethod
-    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
+    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig[Any]":
         with open(path + "cfg.json") as f:
             cfg = json.load(f)
 
@@ -551,6 +375,28 @@
 
         return cls(**cfg)
 
+    def to_sae_trainer_config(self) -> "SAETrainerConfig":
+        return SAETrainerConfig(
+            n_checkpoints=self.n_checkpoints,
+            checkpoint_path=self.checkpoint_path,
+            total_training_samples=self.total_training_tokens,
+            device=self.device,
+            autocast=self.autocast,
+            lr=self.lr,
+            lr_end=self.lr_end,
+            lr_scheduler_name=self.lr_scheduler_name,
+            lr_warm_up_steps=self.lr_warm_up_steps,
+            adam_beta1=self.adam_beta1,
+            adam_beta2=self.adam_beta2,
+            lr_decay_steps=self.lr_decay_steps,
+            n_restart_cycles=self.n_restart_cycles,
+            total_training_steps=self.total_training_steps,
+            train_batch_size_samples=self.train_batch_size_tokens,
+            dead_feature_window=self.dead_feature_window,
+            feature_sampling_window=self.feature_sampling_window,
+            logger=self.logger,
+        )
+
 
 @dataclass
 class CacheActivationsRunnerConfig:
@@ -562,7 +408,6 @@ class CacheActivationsRunnerConfig:
         model_name (str): The name of the model to use.
         model_batch_size (int): How many prompts are in the batch of the language model when generating activations.
         hook_name (str): The name of the hook to use.
-        hook_layer (int): The layer of the final hook. Currently only support a single hook, so this should be the same as hook_name.
         d_in (int): Dimension of the model.
         total_training_tokens (int): Total number of tokens to process.
         context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized.
@@ -592,7 +437,6 @@ class CacheActivationsRunnerConfig:
     model_name: str
     model_batch_size: int
     hook_name: str
-    hook_layer: int
     d_in: int
     training_tokens: int
 
@@ -720,6 +564,10 @@ def _validate_seqpos(seqpos: tuple[int | None, ...], context_size: int) -> None:
 
 @dataclass
 class PretokenizeRunnerConfig:
+    """
+    Configuration class for pretokenizing a dataset.
+    """
+
     tokenizer_name: str = "gpt2"
     dataset_path: str = ""
     dataset_name: str | None = None
@@ -748,3 +596,25 @@ class PretokenizeRunnerConfig:
     hf_num_shards: int = 64
     hf_revision: str = "main"
     hf_is_private_repo: bool = False
+
+
+@dataclass
+class SAETrainerConfig:
+    n_checkpoints: int
+    checkpoint_path: str
+    total_training_samples: int
+    device: str
+    autocast: bool
+    lr: float
+    lr_end: float | None
+    lr_scheduler_name: str
+    lr_warm_up_steps: int
+    adam_beta1: float
+    adam_beta2: float
+    lr_decay_steps: int
+    n_restart_cycles: int
+    total_training_steps: int
+    train_batch_size_samples: int
+    dead_feature_window: int
+    feature_sampling_window: int
+    logger: LoggingConfig
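
The `SAETrainerConfig` dataclass added above holds only what the training loop itself needs, and the new `LanguageModelSAERunnerConfig.to_sae_trainer_config()` method (added earlier in this file) populates it field by field. A short sketch of that mapping, continuing the hypothetical `cfg` from the earlier sketch; note that the token-based runner fields are exposed under sample-based names on the trainer config.

trainer_cfg = cfg.to_sae_trainer_config()

# Field renames visible in to_sae_trainer_config():
assert trainer_cfg.total_training_samples == cfg.total_training_tokens
assert trainer_cfg.train_batch_size_samples == cfg.train_batch_size_tokens
# The LoggingConfig instance is passed through unchanged.
assert trainer_cfg.logger is cfg.logger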
sae_lens/constants.py ADDED
@@ -0,0 +1,20 @@
+import torch
+
+DTYPE_MAP = {
+    "float32": torch.float32,
+    "float64": torch.float64,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+    "torch.float32": torch.float32,
+    "torch.float64": torch.float64,
+    "torch.float16": torch.float16,
+    "torch.bfloat16": torch.bfloat16,
+}
+
+
+SPARSITY_FILENAME = "sparsity.safetensors"
+SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
+SAE_CFG_FILENAME = "cfg.json"
+RUNNER_CFG_FILENAME = "runner_cfg.json"
+ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
+ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
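
With `DTYPE_MAP` and the filename constants moved into this new module, code that previously imported them from `sae_lens.config` now imports them from `sae_lens.constants` (config.py itself does exactly this in the first hunk of this diff). A small usage sketch; the checkpoint directory below is just a placeholder.

import torch

from sae_lens.constants import DTYPE_MAP, SAE_WEIGHTS_FILENAME

# Both bare and "torch."-prefixed dtype strings resolve to the same torch dtype.
assert DTYPE_MAP["bfloat16"] is torch.bfloat16
assert DTYPE_MAP["torch.bfloat16"] is torch.bfloat16

# Filename constants are used to assemble artifact paths, e.g.:
weights_path = f"checkpoints/<run-id>/{SAE_WEIGHTS_FILENAME}"  # -> ".../sae_weights.safetensors"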