sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/config.py CHANGED
@@ -3,7 +3,7 @@ import math
  import os
  from dataclasses import asdict, dataclass, field
  from pathlib import Path
- from typing import Any, Literal, cast
+ from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast

  import simple_parsing
  import torch
@@ -17,24 +17,15 @@ from datasets import (
  )

  from sae_lens import __version__, logger
+ from sae_lens.constants import DTYPE_MAP
+ from sae_lens.saes.sae import TrainingSAEConfig

- DTYPE_MAP = {
- "float32": torch.float32,
- "float64": torch.float64,
- "float16": torch.float16,
- "bfloat16": torch.bfloat16,
- "torch.float32": torch.float32,
- "torch.float64": torch.float64,
- "torch.float16": torch.float16,
- "torch.bfloat16": torch.bfloat16,
- }
-
- HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
+ if TYPE_CHECKING:
+ pass

+ T_TRAINING_SAE_CONFIG = TypeVar("T_TRAINING_SAE_CONFIG", bound=TrainingSAEConfig)

- SPARSITY_FILENAME = "sparsity.safetensors"
- SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
- SAE_CFG_FILENAME = "cfg.json"
+ HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset


  # calling this "json_dict" so error messages will reference "json_dict" being invalid
@@ -101,95 +92,68 @@ class LoggingConfig:


  @dataclass
- class LanguageModelSAERunnerConfig:
+ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  """
  Configuration for training a sparse autoencoder on a language model.

  Args:
- architecture (str): The architecture to use, either "standard", "gated", "topk", or "jumprelu".
+ sae (T_TRAINING_SAE_CONFIG): The configuration for the SAE itself (e.g. StandardSAEConfig, GatedSAEConfig).
  model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub.
  model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
  hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
  hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
  hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing.
- hook_head_index (int, optional): When the hook if for an activatio with a head index, we can specify a specific head to use here.
+ hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
  dataset_path (str): A Hugging Face dataset path.
  dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
  streaming (bool): Whether to stream the dataset. Streaming large datasets is usually practical.
- is_dataset_tokenized (bool): NOT IN USE. We used to use this but now automatically detect if the dataset is tokenized.
+ is_dataset_tokenized (bool): Whether the dataset is already tokenized.
  context_size (int): The context size to use when generating activations on which to train the SAE.
  use_cached_activations (bool): Whether to use cached activations. This is useful when doing sweeps over the same activations.
- cached_activations_path (str, optional): The path to the cached activations.
- d_in (int): The input dimension of the SAE.
- d_sae (int, optional): The output dimension of the SAE. If None, defaults to `d_in * expansion_factor`.
- b_dec_init_method (str): The method to use to initialize the decoder bias. Zeros is likely fine.
- expansion_factor (int): The expansion factor. Larger is better but more computationally expensive. Default is 4.
- activation_fn (str): The activation function to use. Relu is standard.
- normalize_sae_decoder (bool): Whether to normalize the SAE decoder. Unit normed decoder weights used to be preferred.
- noise_scale (float): Using noise to induce sparsity is supported but not recommended.
+ cached_activations_path (str, optional): The path to the cached activations. Defaults to "activations/{dataset_path}/{model_name}/{hook_name}_{hook_head_index}".
  from_pretrained_path (str, optional): The path to a pretrained SAE. We can finetune an existing SAE if needed.
- apply_b_dec_to_input (bool): Whether to apply the decoder bias to the input. Not currently advised.
- decoder_orthogonal_init (bool): Whether to use orthogonal initialization for the decoder. Not currently advised.
- decoder_heuristic_init (bool): Whether to use heuristic initialization for the decoder. See Anthropic April Update.
- init_encoder_as_decoder_transpose (bool): Whether to initialize the encoder as the transpose of the decoder. See Anthropic April Update.
- n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in ram is used. The larger it is, the better shuffled the activations will be.
+ n_batches_in_buffer (int): The number of batches in the buffer. When not using cached activations, a buffer in RAM is used. The larger it is, the better shuffled the activations will be.
  training_tokens (int): The number of training tokens.
- finetuning_tokens (int): The number of finetuning tokens. See [here](https://www.lesswrong.com/posts/3JuSjTZyMzaSeTxKk/addressing-feature-suppression-in-saes)
- store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating actiations.
- train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop.
- normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it following Antrhopic April update -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).
- seqpos_slice (tuple): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
- device (str): The device to use. Usually cuda.
- act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
+ store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
+ seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
+ device (str): The device to use. Usually "cuda".
+ act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
  seed (int): The seed to use.
- dtype (str): The data type to use.
+ dtype (str): The data type to use for the SAE and activations.
  prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
- jumprelu_init_threshold (float): The threshold to initialize for training JumpReLU SAEs.
- jumprelu_bandwidth (float): Bandwidth for training JumpReLU SAEs.
- autocast (bool): Whether to use autocast during training. Saves vram.
- autocast_lm (bool): Whether to use autocast during activation fetching.
- compile_llm (bool): Whether to compile the LLM.
- llm_compilation_mode (str): The compilation mode to use for the LLM.
- compile_sae (bool): Whether to compile the SAE.
- sae_compilation_mode (str): The compilation mode to use for the SAE.
- adam_beta1 (float): The beta1 parameter for Adam.
- adam_beta2 (float): The beta2 parameter for Adam.
- mse_loss_normalization (str): The normalization to use for the MSE loss.
- l1_coefficient (float): The L1 coefficient.
- lp_norm (float): The Lp norm.
- scale_sparsity_penalty_by_decoder_norm (bool): Whether to scale the sparsity penalty by the decoder norm.
- l1_warm_up_steps (int): The number of warm-up steps for the L1 loss.
+ autocast (bool): Whether to use autocast (mixed-precision) during SAE training. Saves VRAM.
+ autocast_lm (bool): Whether to use autocast (mixed-precision) during activation fetching. Saves VRAM.
+ compile_llm (bool): Whether to compile the LLM using `torch.compile`.
+ llm_compilation_mode (str, optional): The compilation mode to use for the LLM if `compile_llm` is True.
+ compile_sae (bool): Whether to compile the SAE using `torch.compile`.
+ sae_compilation_mode (str, optional): The compilation mode to use for the SAE if `compile_sae` is True.
+ train_batch_size_tokens (int): The batch size for training, in tokens. This controls the batch size of the SAE training loop.
+ adam_beta1 (float): The beta1 parameter for the Adam optimizer.
+ adam_beta2 (float): The beta2 parameter for the Adam optimizer.
  lr (float): The learning rate.
- lr_scheduler_name (str): The name of the learning rate scheduler to use.
+ lr_scheduler_name (str): The name of the learning rate scheduler to use (e.g., "constant", "cosineannealing", "cosineannealingwarmrestarts").
  lr_warm_up_steps (int): The number of warm-up steps for the learning rate.
- lr_end (float): The end learning rate if lr_decay_steps is set. Default is lr / 10.
- lr_decay_steps (int): The number of decay steps for the learning rate.
- n_restart_cycles (int): The number of restart cycles for the cosine annealing warm restarts scheduler.
- finetuning_method (str): The method to use for finetuning.
- use_ghost_grads (bool): Whether to use ghost gradients.
- feature_sampling_window (int): The feature sampling window.
- dead_feature_window (int): The dead feature window.
- dead_feature_threshold (float): The dead feature threshold.
- n_eval_batches (int): The number of evaluation batches.
- eval_batch_size_prompts (int): The batch size for evaluation.
- log_to_wandb (bool): Whether to log to Weights & Biases.
- log_activations_store_to_wandb (bool): NOT CURRENTLY USED. Whether to log the activations store to Weights & Biases.
- log_optimizer_state_to_wandb (bool): NOT CURRENTLY USED. Whether to log the optimizer state to Weights & Biases.
- wandb_project (str): The Weights & Biases project to log to.
- wandb_id (str): The Weights & Biases ID.
- run_name (str): The name of the run.
- wandb_entity (str): The Weights & Biases entity.
- wandb_log_frequency (int): The frequency to log to Weights & Biases.
- eval_every_n_wandb_logs (int): The frequency to evaluate.
- resume (bool): Whether to resume training.
- n_checkpoints (int): The number of checkpoints.
- checkpoint_path (str): The path to save checkpoints.
+ lr_end (float, optional): The end learning rate if using a scheduler like cosine annealing. Defaults to `lr / 10`.
+ lr_decay_steps (int): The number of decay steps for the learning rate if using a scheduler with decay.
+ n_restart_cycles (int): The number of restart cycles for the cosine annealing with warm restarts scheduler.
+ dead_feature_window (int): The window size (in training steps) for detecting dead features.
+ feature_sampling_window (int): The window size (in training steps) for resampling features (e.g. dead features).
+ dead_feature_threshold (float): The threshold below which a feature's activation frequency is considered dead.
+ n_eval_batches (int): The number of batches to use for evaluation.
+ eval_batch_size_prompts (int, optional): The batch size for evaluation, in prompts. Useful if evals cause OOM.
+ logger (LoggingConfig): Configuration for logging (e.g. W&B).
+ n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
+ checkpoint_path (str): The path to save checkpoints. A unique ID will be appended to this path.
  verbose (bool): Whether to print verbose output.
- model_kwargs (dict[str, Any]): Additional keyword arguments for the model.
- model_from_pretrained_kwargs (dict[str, Any]): Additional keyword arguments for the model from pretrained.
- exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations.
+ model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
+ model_from_pretrained_kwargs (dict[str, Any], optional): Additional keyword arguments to pass to the model's `from_pretrained` method.
+ sae_lens_version (str): The version of the sae_lens library.
+ sae_lens_training_version (str): The version of the sae_lens training library.
+ exclude_special_tokens (bool | list[int]): Whether to exclude special tokens from the activations. If True, excludes all special tokens. If a list of ints, excludes those token IDs.
  """

+ sae: T_TRAINING_SAE_CONFIG
+
  # Data Generating Function (Model + Training Distibuion)
  model_name: str = "gelu-2l"
  model_class_name: str = "HookedTransformer"
@@ -208,29 +172,12 @@ class LanguageModelSAERunnerConfig:
  )

  # SAE Parameters
- architecture: Literal["standard", "gated", "jumprelu", "topk"] = "standard"
- d_in: int = 512
- d_sae: int | None = None
- b_dec_init_method: str = "geometric_median"
- expansion_factor: int | None = (
- None # defaults to 4 if d_sae and expansion_factor is None
- )
- activation_fn: str = None # relu, tanh-relu, topk. Default is relu. # type: ignore
- activation_fn_kwargs: dict[str, int] = dict_field(default=None) # for topk
- normalize_sae_decoder: bool = True
- noise_scale: float = 0.0
  from_pretrained_path: str | None = None
- apply_b_dec_to_input: bool = True
- decoder_orthogonal_init: bool = False
- decoder_heuristic_init: bool = False
- init_encoder_as_decoder_transpose: bool = False

  # Activation Store Parameters
  n_batches_in_buffer: int = 20
  training_tokens: int = 2_000_000
- finetuning_tokens: int = 0
  store_batch_size_prompts: int = 32
- normalize_activations: str = "none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
  seqpos_slice: tuple[int | None, ...] = (None,)

  # Misc
@@ -240,10 +187,6 @@ class LanguageModelSAERunnerConfig:
  dtype: str = "float32" # type: ignore #
  prepend_bos: bool = True

- # JumpReLU Parameters
- jumprelu_init_threshold: float = 0.001
- jumprelu_bandwidth: float = 0.001
-
  # Performance - see compilation section of lm_runner.py for info
  autocast: bool = False # autocast to autocast_dtype during training
  autocast_lm: bool = False # autocast lm during activation fetching
@@ -261,13 +204,6 @@ class LanguageModelSAERunnerConfig:
  adam_beta1: float = 0.0
  adam_beta2: float = 0.999

- ## Loss Function
- mse_loss_normalization: str | None = None
- l1_coefficient: float = 1e-3
- lp_norm: float = 1
- scale_sparsity_penalty_by_decoder_norm: bool = False
- l1_warm_up_steps: int = 0
-
  ## Learning Rate Schedule
  lr: float = 3e-4
  lr_scheduler_name: str = (
@@ -278,14 +214,9 @@ class LanguageModelSAERunnerConfig:
  lr_decay_steps: int = 0
  n_restart_cycles: int = 1 # used only for cosineannealingwarmrestarts

- ## FineTuning
- finetuning_method: str | None = None # scale, decoder or unrotated_decoder
-
  # Resampling protocol args
- use_ghost_grads: bool = False # want to change this to true on some timeline.
- feature_sampling_window: int = 2000
  dead_feature_window: int = 1000 # unless this window is larger feature sampling,
-
+ feature_sampling_window: int = 2000
  dead_feature_threshold: float = 1e-8

  # Evals
@@ -295,7 +226,6 @@ class LanguageModelSAERunnerConfig:
  logger: LoggingConfig = field(default_factory=LoggingConfig)

  # Misc
- resume: bool = False
  n_checkpoints: int = 0
  checkpoint_path: str = "checkpoints"
  verbose: bool = True
@@ -306,12 +236,6 @@ class LanguageModelSAERunnerConfig:
  exclude_special_tokens: bool | list[int] = False

  def __post_init__(self):
- if self.resume:
- raise ValueError(
- "Resuming is no longer supported. You can finetune a trained SAE using cfg.from_pretrained path."
- + "If you want to load an SAE with resume=True in the config, please manually set resume=False in that config."
- )
-
  if self.use_cached_activations and self.cached_activations_path is None:
  self.cached_activations_path = _default_cached_activations_path(
  self.dataset_path,
@@ -319,37 +243,12 @@ class LanguageModelSAERunnerConfig:
  self.hook_name,
  self.hook_head_index,
  )
-
- if self.activation_fn is None:
- self.activation_fn = "topk" if self.architecture == "topk" else "relu"
-
- if self.architecture == "topk" and self.activation_fn != "topk":
- raise ValueError("If using topk architecture, activation_fn must be topk.")
-
- if self.activation_fn_kwargs is None:
- self.activation_fn_kwargs = (
- {"k": 100} if self.activation_fn == "topk" else {}
- )
-
- if self.architecture == "topk" and self.activation_fn_kwargs.get("k") is None:
- raise ValueError(
- "activation_fn_kwargs.k must be provided for topk architecture."
- )
-
- if self.d_sae is not None and self.expansion_factor is not None:
- raise ValueError("You can't set both d_sae and expansion_factor.")
-
- if self.d_sae is None and self.expansion_factor is None:
- self.expansion_factor = 4
-
- if self.d_sae is None and self.expansion_factor is not None:
- self.d_sae = self.d_in * self.expansion_factor
  self.tokens_per_buffer = (
  self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
  )

  if self.logger.run_name is None:
- self.logger.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
+ self.logger.run_name = f"{self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"

  if self.model_from_pretrained_kwargs is None:
  if self.model_class_name == "HookedTransformer":
@@ -357,37 +256,6 @@ class LanguageModelSAERunnerConfig:
  else:
  self.model_from_pretrained_kwargs = {}

- if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]:
- raise ValueError(
- f"b_dec_init_method must be geometric_median, mean, or zeros. Got {self.b_dec_init_method}"
- )
-
- if self.normalize_sae_decoder and self.decoder_heuristic_init:
- raise ValueError(
- "You can't normalize the decoder and use heuristic initialization."
- )
-
- if self.normalize_sae_decoder and self.scale_sparsity_penalty_by_decoder_norm:
- raise ValueError(
- "Weighting loss by decoder norm makes no sense if you are normalizing the decoder weight norms to 1"
- )
-
- # if we use decoder fine tuning, we can't be applying b_dec to the input
- if (self.finetuning_method == "decoder") and (self.apply_b_dec_to_input):
- raise ValueError(
- "If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False."
- )
-
- if self.normalize_activations not in [
- "none",
- "expected_average_only_in",
- "constant_norm_rescale",
- "layer_norm",
- ]:
- raise ValueError(
- f"normalize_activations must be none, layer_norm, expected_average_only_in, or constant_norm_rescale. Got {self.normalize_activations}"
- )
-
  if self.act_store_device == "with_model":
  self.act_store_device = self.device

@@ -403,7 +271,7 @@ class LanguageModelSAERunnerConfig:

  if self.verbose:
  logger.info(
- f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
+ f"Run name: {self.sae.architecture()}-{self.sae.d_sae}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
  )
  # Print out some useful info:
  n_tokens_per_buffer = (
@@ -422,7 +290,7 @@ class LanguageModelSAERunnerConfig:
  )

  total_training_steps = (
- self.training_tokens + self.finetuning_tokens
+ self.training_tokens
  ) // self.train_batch_size_tokens
  logger.info(f"Total training steps: {total_training_steps}")

@@ -450,9 +318,6 @@ class LanguageModelSAERunnerConfig:
  f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
  )

- if self.use_ghost_grads:
- logger.info("Using Ghost Grads.")
-
  if self.context_size < 0:
  raise ValueError(
  f"The provided context_size is {self.context_size} is negative. Expecting positive context_size."
@@ -467,62 +332,21 @@ class LanguageModelSAERunnerConfig:

  @property
  def total_training_tokens(self) -> int:
- return self.training_tokens + self.finetuning_tokens
+ return self.training_tokens

  @property
  def total_training_steps(self) -> int:
  return self.total_training_tokens // self.train_batch_size_tokens

- def get_base_sae_cfg_dict(self) -> dict[str, Any]:
- return {
- # TEMP
- "architecture": self.architecture,
- "d_in": self.d_in,
- "d_sae": self.d_sae,
- "dtype": self.dtype,
- "device": self.device,
- "model_name": self.model_name,
- "hook_name": self.hook_name,
- "hook_layer": self.hook_layer,
- "hook_head_index": self.hook_head_index,
- "activation_fn": self.activation_fn,
- "apply_b_dec_to_input": self.apply_b_dec_to_input,
- "context_size": self.context_size,
- "prepend_bos": self.prepend_bos,
- "dataset_path": self.dataset_path,
- "dataset_trust_remote_code": self.dataset_trust_remote_code,
- "finetuning_scaling_factor": self.finetuning_method is not None,
- "sae_lens_training_version": self.sae_lens_training_version,
- "normalize_activations": self.normalize_activations,
- "activation_fn_kwargs": self.activation_fn_kwargs,
- "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
- "seqpos_slice": self.seqpos_slice,
- }
-
  def get_training_sae_cfg_dict(self) -> dict[str, Any]:
- return {
- **self.get_base_sae_cfg_dict(),
- "l1_coefficient": self.l1_coefficient,
- "lp_norm": self.lp_norm,
- "use_ghost_grads": self.use_ghost_grads,
- "normalize_sae_decoder": self.normalize_sae_decoder,
- "noise_scale": self.noise_scale,
- "decoder_orthogonal_init": self.decoder_orthogonal_init,
- "mse_loss_normalization": self.mse_loss_normalization,
- "decoder_heuristic_init": self.decoder_heuristic_init,
- "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
- "normalize_activations": self.normalize_activations,
- "jumprelu_init_threshold": self.jumprelu_init_threshold,
- "jumprelu_bandwidth": self.jumprelu_bandwidth,
- "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
- }
+ return self.sae.to_dict()

  def to_dict(self) -> dict[str, Any]:
- # Make a shallow copy of configs dictionary
+ # Make a shallow copy of config's dictionary
  d = dict(self.__dict__)

  d["logger"] = asdict(self.logger)
-
+ d["sae"] = self.sae.to_dict()
  # Overwrite fields that might not be JSON-serializable
  d["dtype"] = str(self.dtype)
  d["device"] = str(self.device)
@@ -537,7 +361,7 @@ class LanguageModelSAERunnerConfig:
  json.dump(self.to_dict(), f, indent=2)

  @classmethod
- def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
+ def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig[Any]":
  with open(path + "cfg.json") as f:
  cfg = json.load(f)

@@ -720,6 +544,10 @@ def _validate_seqpos(seqpos: tuple[int | None, ...], context_size: int) -> None:

  @dataclass
  class PretokenizeRunnerConfig:
+ """
+ Configuration class for pretokenizing a dataset.
+ """
+
  tokenizer_name: str = "gpt2"
  dataset_path: str = ""
  dataset_name: str | None = None
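
The main behavioural change in this file is that SAE-specific hyperparameters move off the runner config and into a nested, architecture-specific `sae` config (the generic `T_TRAINING_SAE_CONFIG` parameter). A minimal sketch of the new shape follows; the class name `StandardTrainingSAEConfig`, its import path, and its `d_in`/`d_sae` arguments are assumptions based on the docstring above, not part of this diff:

```python
# Sketch only: illustrates the nested `sae` field introduced in this diff.
from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.saes.standard_sae import StandardTrainingSAEConfig  # hypothetical import path

cfg = LanguageModelSAERunnerConfig(
    sae=StandardTrainingSAEConfig(d_in=512, d_sae=2048),  # hypothetical SAE config class
    model_name="gelu-2l",
    hook_name="blocks.0.hook_mlp_out",
    hook_layer=0,
    training_tokens=2_000_000,
    lr=3e-4,
)

# get_training_sae_cfg_dict() now simply delegates to the nested config:
print(cfg.get_training_sae_cfg_dict())  # equivalent to cfg.sae.to_dict()
```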
sae_lens/constants.py ADDED
@@ -0,0 +1,18 @@
+ import torch
+
+ DTYPE_MAP = {
+ "float32": torch.float32,
+ "float64": torch.float64,
+ "float16": torch.float16,
+ "bfloat16": torch.bfloat16,
+ "torch.float32": torch.float32,
+ "torch.float64": torch.float64,
+ "torch.float16": torch.float16,
+ "torch.bfloat16": torch.bfloat16,
+ }
+
+
+ SPARSITY_FILENAME = "sparsity.safetensors"
+ SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
+ SAE_CFG_FILENAME = "cfg.json"
+ RUNNER_CFG_FILENAME = "runner_cfg.json"
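
The dtype map and the weights/config filename constants, previously defined in `sae_lens/config.py` (see the removed block above), now live in this dedicated module, along with the new `RUNNER_CFG_FILENAME`. A minimal sketch of the updated import; whether the old `sae_lens.config` location still re-exports these names is not shown in this diff:

```python
# Updated import location for the shared constants added in this release.
import torch

from sae_lens.constants import DTYPE_MAP, SAE_CFG_FILENAME, SAE_WEIGHTS_FILENAME

assert DTYPE_MAP["bfloat16"] is torch.bfloat16
print(SAE_CFG_FILENAME, SAE_WEIGHTS_FILENAME)  # cfg.json sae_weights.safetensors
```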
sae_lens/evals.py CHANGED
@@ -20,7 +20,7 @@ from transformer_lens import HookedTransformer
  from transformer_lens.hook_points import HookedRootModule

  from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
- from sae_lens.saes.sae import SAE
+ from sae_lens.saes.sae import SAE, SAEConfig
  from sae_lens.training.activations_store import ActivationsStore


@@ -100,7 +100,7 @@ def get_eval_everything_config(

  @torch.no_grad()
  def run_evals(
- sae: SAE,
+ sae: SAE[Any],
  activation_store: ActivationsStore,
  model: HookedRootModule,
  eval_config: EvalConfig = EvalConfig(),
@@ -108,7 +108,7 @@ def run_evals(
  ignore_tokens: set[int | None] = set(),
  verbose: bool = False,
  ) -> tuple[dict[str, Any], dict[str, Any]]:
- hook_name = sae.cfg.hook_name
+ hook_name = sae.cfg.metadata.hook_name
  actual_batch_size = (
  eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
  )
@@ -274,7 +274,7 @@ def run_evals(
  return all_metrics, feature_metrics


- def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
+ def get_featurewise_weight_based_metrics(sae: SAE[Any]) -> dict[str, Any]:
  unit_norm_encoders = (sae.W_enc / sae.W_enc.norm(dim=0, keepdim=True)).cpu()
  unit_norm_decoder = (sae.W_dec.T / sae.W_dec.T.norm(dim=0, keepdim=True)).cpu()

@@ -298,7 +298,7 @@ def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:


  def get_downstream_reconstruction_metrics(
- sae: SAE,
+ sae: SAE[Any],
  model: HookedRootModule,
  activation_store: ActivationsStore,
  compute_kl: bool,
@@ -366,7 +366,7 @@ def get_downstream_reconstruction_metrics(


  def get_sparsity_and_variance_metrics(
- sae: SAE,
+ sae: SAE[Any],
  model: HookedRootModule,
  activation_store: ActivationsStore,
  n_batches: int,
@@ -379,8 +379,8 @@ def get_sparsity_and_variance_metrics(
  ignore_tokens: set[int | None] = set(),
  verbose: bool = False,
  ) -> tuple[dict[str, Any], dict[str, Any]]:
- hook_name = sae.cfg.hook_name
- hook_head_index = sae.cfg.hook_head_index
+ hook_name = sae.cfg.metadata.hook_name
+ hook_head_index = sae.cfg.metadata.hook_head_index

  metric_dict = {}
  feature_metric_dict = {}
@@ -436,7 +436,7 @@ def get_sparsity_and_variance_metrics(
  batch_tokens,
  prepend_bos=False,
  names_filter=[hook_name],
- stop_at_layer=sae.cfg.hook_layer + 1,
+ stop_at_layer=sae.cfg.metadata.hook_layer + 1,
  **model_kwargs,
  )

@@ -580,7 +580,7 @@ def get_sparsity_and_variance_metrics(

  @torch.no_grad()
  def get_recons_loss(
- sae: SAE,
+ sae: SAE[SAEConfig],
  model: HookedRootModule,
  batch_tokens: torch.Tensor,
  activation_store: ActivationsStore,
@@ -588,9 +588,13 @@ def get_recons_loss(
  compute_ce_loss: bool,
  ignore_tokens: set[int | None] = set(),
  model_kwargs: Mapping[str, Any] = {},
+ hook_name: str | None = None,
  ) -> dict[str, Any]:
- hook_name = sae.cfg.hook_name
- head_index = sae.cfg.hook_head_index
+ hook_name = hook_name or sae.cfg.metadata.hook_name
+ head_index = sae.cfg.metadata.hook_head_index
+
+ if hook_name is None:
+ raise ValueError("hook_name must be provided")

  original_logits, original_ce_loss = model(
  batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
@@ -806,7 +810,6 @@ def multiple_evals(

  current_model = None
  current_model_str = None
- print(filtered_saes)
  for sae_release_name, sae_id, _, _ in tqdm(filtered_saes):
  sae = SAE.from_pretrained(
  release=sae_release_name, # see other options in sae_lens/pretrained_saes.yaml
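
Across these eval helpers, hook information moves from the SAE config itself onto a nested `metadata` object, SAE type hints become generic (`SAE[Any]`, `SAE[SAEConfig]`), and `get_recons_loss` gains an optional `hook_name` override. A small sketch of adapting downstream code, using only the attributes visible in the hunks above:

```python
# Sketch: pre-6.0 code read sae.cfg.hook_name / sae.cfg.hook_layer directly; in
# 6.0.0rc2 those fields live on sae.cfg.metadata.
from typing import Any

from sae_lens.saes.sae import SAE


def describe_hook(sae: SAE[Any]) -> str:
    """Return a short description of where this SAE attaches to the model."""
    meta = sae.cfg.metadata
    head = "" if meta.hook_head_index is None else f", head {meta.hook_head_index}"
    return f"{meta.hook_name} (layer {meta.hook_layer}{head})"
```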
@@ -7,11 +7,12 @@ import numpy as np
  import torch
  from huggingface_hub import hf_hub_download
  from huggingface_hub.utils import EntryNotFoundError
+ from packaging.version import Version
  from safetensors import safe_open
  from safetensors.torch import load_file

  from sae_lens import logger
- from sae_lens.config import (
+ from sae_lens.constants import (
  DTYPE_MAP,
  SAE_CFG_FILENAME,
  SAE_WEIGHTS_FILENAME,
@@ -22,6 +23,8 @@ from sae_lens.loading.pretrained_saes_directory import (
  get_pretrained_saes_directory,
  get_repo_id_and_folder_name,
  )
+ from sae_lens.registry import get_sae_class
+ from sae_lens.util import filter_valid_dataclass_fields


  # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
@@ -174,6 +177,20 @@ def get_sae_lens_config_from_disk(


  def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+ sae_lens_version = cfg_dict.get("sae_lens_version")
+ if not sae_lens_version and "metadata" in cfg_dict:
+ sae_lens_version = cfg_dict["metadata"].get("sae_lens_version")
+
+ if not sae_lens_version or Version(sae_lens_version) < Version("6.0.0-rc.0"):
+ cfg_dict = handle_pre_6_0_config(cfg_dict)
+ return cfg_dict
+
+
+ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+ """
+ Format a config dictionary for a Sparse Autoencoder (SAE) to be compatible with the new 6.0 format.
+ """
+
  rename_keys_map = {
  "hook_point": "hook_name",
  "hook_point_layer": "hook_layer",
@@ -202,10 +219,26 @@ def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
  else "expected_average_only_in"
  )

- new_cfg.setdefault("normalize_activations", "none")
+ if new_cfg.get("normalize_activations") is None:
+ new_cfg["normalize_activations"] = "none"
+
  new_cfg.setdefault("device", "cpu")

- return new_cfg
+ architecture = new_cfg.get("architecture", "standard")
+
+ config_class = get_sae_class(architecture)[1]
+
+ sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
+ if architecture == "topk":
+ sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
+
+ # import here to avoid circular import
+ from sae_lens.saes.sae import SAEMetadata
+
+ meta_dict = filter_valid_dataclass_fields(new_cfg, SAEMetadata)
+ sae_cfg_dict["metadata"] = meta_dict
+ sae_cfg_dict["architecture"] = architecture
+ return sae_cfg_dict


  def get_connor_rob_hook_z_config_from_hf(
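
`handle_config_defaulting` now dispatches on the stored `sae_lens_version`: configs written before 6.0.0-rc.0 (or with no version at all) are routed through `handle_pre_6_0_config`, which renames legacy keys, splits hook metadata out of the flat dict, and, for topk SAEs, lifts `k` out of `activation_fn_kwargs`. A rough sketch of that conversion for a legacy config; the input values are illustrative, and the module path in the import is an assumption since the file header for this hunk is not visible above:

```python
# Illustrative only: the exact output keys depend on the SAE config classes and
# SAEMetadata fields, which are not part of this diff.
from sae_lens.loading.pretrained_sae_loaders import handle_config_defaulting  # assumed module path

legacy_cfg = {
    "architecture": "topk",
    "d_in": 768,
    "d_sae": 24576,
    "activation_fn_kwargs": {"k": 64},
    "hook_point": "blocks.8.hook_resid_pre",  # legacy key, renamed to hook_name
    "hook_point_layer": 8,                    # legacy key, renamed to hook_layer
    "dtype": "float32",
    "device": "cpu",
}

new_cfg = handle_config_defaulting(legacy_cfg)  # no sae_lens_version -> pre-6.0 path
# Expect roughly: architecture and k at the top level, hook info under "metadata"
print(new_cfg["architecture"], new_cfg.get("k"), new_cfg["metadata"].get("hook_name"))
```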