sae-lens 5.10.7__py3-none-any.whl → 6.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. sae_lens/__init__.py +60 -7
  2. sae_lens/analysis/hooked_sae_transformer.py +12 -12
  3. sae_lens/analysis/neuronpedia_integration.py +16 -14
  4. sae_lens/cache_activations_runner.py +9 -7
  5. sae_lens/config.py +170 -257
  6. sae_lens/constants.py +21 -0
  7. sae_lens/evals.py +59 -44
  8. sae_lens/llm_sae_training_runner.py +377 -0
  9. sae_lens/load_model.py +53 -5
  10. sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +228 -32
  11. sae_lens/registry.py +49 -0
  12. sae_lens/saes/__init__.py +48 -0
  13. sae_lens/saes/gated_sae.py +254 -0
  14. sae_lens/saes/jumprelu_sae.py +348 -0
  15. sae_lens/saes/sae.py +1076 -0
  16. sae_lens/saes/standard_sae.py +178 -0
  17. sae_lens/saes/topk_sae.py +300 -0
  18. sae_lens/training/activation_scaler.py +53 -0
  19. sae_lens/training/activations_store.py +103 -184
  20. sae_lens/training/mixing_buffer.py +56 -0
  21. sae_lens/training/optim.py +60 -36
  22. sae_lens/training/sae_trainer.py +155 -177
  23. sae_lens/training/types.py +5 -0
  24. sae_lens/training/upload_saes_to_huggingface.py +13 -7
  25. sae_lens/util.py +47 -0
  26. {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
  27. sae_lens-6.0.0.dist-info/RECORD +37 -0
  28. sae_lens/sae.py +0 -747
  29. sae_lens/sae_training_runner.py +0 -251
  30. sae_lens/training/geometric_median.py +0 -101
  31. sae_lens/training/training_sae.py +0 -710
  32. sae_lens-5.10.7.dist-info/RECORD +0 -28
  33. /sae_lens/{toolkit → loading}/__init__.py +0 -0
  34. /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
  35. {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
  36. {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
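
Note on import paths: the renames above (sae_lens/toolkit → sae_lens/loading, sae_lens/sae.py and sae_lens/training/training_sae.py replaced by the new sae_lens/saes/ package, L1Scheduler replaced in sae_lens/training/optim.py) imply roughly the following migration for downstream code. This is a hedged sketch based only on the moves visible in this diff; the exact public re-exports of the 6.0.0 package may differ.

    # sae-lens 5.10.7                                   ->  sae-lens 6.0.0
    # from sae_lens.toolkit import pretrained_sae_loaders
    from sae_lens.loading import pretrained_sae_loaders

    # from sae_lens.training.training_sae import TrainingSAE
    from sae_lens.saes.sae import TrainingSAE

    # from sae_lens.training.optim import L1Scheduler
    from sae_lens.training.optim import CoefficientScheduler
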
sae_lens/training/optim.py
@@ -101,61 +101,85 @@ def _get_main_lr_scheduler(
     raise ValueError(f"Unsupported scheduler: {scheduler_name}")
 
 
-class L1Scheduler:
+class CoefficientScheduler:
+    """Linearly warms up a scalar value from 0.0 to a final value."""
+
     def __init__(
         self,
-        l1_warm_up_steps: float,
-        total_steps: int,
-        final_l1_coefficient: float,
+        warm_up_steps: float,
+        final_value: float,
     ):
-        self.l1_warmup_steps = l1_warm_up_steps
-        # assume using warm-up
-        if self.l1_warmup_steps != 0:
-            self.current_l1_coefficient = 0.0
-        else:
-            self.current_l1_coefficient = final_l1_coefficient
-
-        self.final_l1_coefficient = final_l1_coefficient
-
+        self.warm_up_steps = warm_up_steps
+        self.final_value = final_value
         self.current_step = 0
-        self.total_steps = total_steps
-        if not isinstance(self.final_l1_coefficient, (float, int)):
+
+        if not isinstance(self.final_value, (float, int)):
             raise TypeError(
-                f"final_l1_coefficient must be float or int, got {type(self.final_l1_coefficient)}."
+                f"final_value must be float or int, got {type(self.final_value)}."
             )
 
+        # Initialize current_value based on whether warm-up is used
+        if self.warm_up_steps > 0:
+            self.current_value = 0.0
+        else:
+            self.current_value = self.final_value
+
     def __repr__(self) -> str:
         return (
-            f"L1Scheduler(final_l1_value={self.final_l1_coefficient}, "
-            f"l1_warmup_steps={self.l1_warmup_steps}, "
-            f"total_steps={self.total_steps})"
+            f"{self.__class__.__name__}(final_value={self.final_value}, "
+            f"warm_up_steps={self.warm_up_steps})"
         )
 
-    def step(self):
+    def step(self) -> float:
         """
-        Updates the l1 coefficient of the sparse autoencoder.
+        Updates the scalar value based on the current step.
+
+        Returns:
+            The current scalar value after the step.
         """
-        step = self.current_step
-        if step < self.l1_warmup_steps:
-            self.current_l1_coefficient = self.final_l1_coefficient * (
-                (1 + step) / self.l1_warmup_steps
-            )  # type: ignore
+        if self.current_step < self.warm_up_steps:
+            self.current_value = self.final_value * (
+                (self.current_step + 1) / self.warm_up_steps
+            )
         else:
-            self.current_l1_coefficient = self.final_l1_coefficient  # type: ignore
+            # Ensure the value stays at final_value after warm-up
+            self.current_value = self.final_value
 
         self.current_step += 1
+        return self.current_value
 
-    def state_dict(self):
-        """State dict for serializing as part of an SAETrainContext."""
+    @property
+    def value(self) -> float:
+        """Returns the current scalar value."""
+        return self.current_value
+
+    def state_dict(self) -> dict[str, Any]:
+        """State dict for serialization."""
         return {
-            "l1_warmup_steps": self.l1_warmup_steps,
-            "total_steps": self.total_steps,
-            "current_l1_coefficient": self.current_l1_coefficient,
-            "final_l1_coefficient": self.final_l1_coefficient,
+            "warm_up_steps": self.warm_up_steps,
+            "final_value": self.final_value,
             "current_step": self.current_step,
+            "current_value": self.current_value,
         }
 
     def load_state_dict(self, state_dict: dict[str, Any]):
-        """Loads all state apart from attached SAE."""
-        for k in state_dict:
-            setattr(self, k, state_dict[k])
+        """Loads the scheduler state."""
+        self.warm_up_steps = state_dict["warm_up_steps"]
+        self.final_value = state_dict["final_value"]
+        self.current_step = state_dict["current_step"]
+        # Maintain consistency: re-calculate current_value based on loaded step
+        # This handles resuming correctly if stopped mid-warmup.
+        if self.current_step <= self.warm_up_steps and self.warm_up_steps > 0:
+            # Use max(0, ...) to handle case where current_step might be loaded as -1 or similar before first step
+            step_for_calc = max(0, self.current_step)
+            # Recalculate based on the step *before* the one about to be taken
+            # Or simply use the saved current_value if available and consistent
+            if "current_value" in state_dict:
+                self.current_value = state_dict["current_value"]
+            else:  # Legacy state dicts might not have current_value
+                self.current_value = self.final_value * (
+                    step_for_calc / self.warm_up_steps
+                )
+
+        else:
+            self.current_value = self.final_value
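
The L1Scheduler above was hard-wired to a single L1 coefficient; its replacement warms any named coefficient from 0.0 up to final_value and exposes the current value via the value property. A minimal usage sketch, using only the constructor, step(), and value shown in this hunk:

    from sae_lens.training.optim import CoefficientScheduler

    # Warm a sparsity coefficient from 0.0 up to 5.0 over the first 1_000 steps.
    scheduler = CoefficientScheduler(warm_up_steps=1_000, final_value=5.0)

    for _ in range(2_000):
        coefficient = scheduler.value  # read the current value for this step's loss
        scheduler.step()               # then advance the schedule

    assert scheduler.value == 5.0  # stays pinned at final_value after warm-up
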
sae_lens/training/sae_trainer.py
@@ -1,26 +1,28 @@
 import contextlib
 from dataclasses import dataclass
-from typing import Any, Protocol, cast
+from pathlib import Path
+from typing import Any, Callable, Generic, Protocol
 
 import torch
 import wandb
+from safetensors.torch import save_file
 from torch.optim import Adam
-from tqdm import tqdm
-from transformer_lens.hook_points import HookedRootModule
+from tqdm.auto import tqdm
 
 from sae_lens import __version__
-from sae_lens.config import LanguageModelSAERunnerConfig
-from sae_lens.evals import EvalConfig, run_evals
-from sae_lens.training.activations_store import ActivationsStore
-from sae_lens.training.optim import L1Scheduler, get_lr_scheduler
-from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput
-
-# used to map between parameters which are updated during finetuning and the config str.
-FINETUNING_PARAMETERS = {
-    "scale": ["scaling_factor"],
-    "decoder": ["scaling_factor", "W_dec", "b_dec"],
-    "unrotated_decoder": ["scaling_factor", "b_dec"],
-}
+from sae_lens.config import SAETrainerConfig
+from sae_lens.constants import ACTIVATION_SCALER_CFG_FILENAME, SPARSITY_FILENAME
+from sae_lens.saes.sae import (
+    T_TRAINING_SAE,
+    T_TRAINING_SAE_CONFIG,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainStepInput,
+    TrainStepOutput,
+)
+from sae_lens.training.activation_scaler import ActivationScaler
+from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler
+from sae_lens.training.types import DataProvider
 
 
 def _log_feature_sparsity(
@@ -29,7 +31,7 @@ def _log_feature_sparsity(
     return torch.log10(feature_sparsity + eps).detach().cpu()
 
 
-def _update_sae_lens_training_version(sae: TrainingSAE) -> None:
+def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
     """
     Make sure we record the version of SAELens used for the training run
    """
@@ -38,7 +40,7 @@ def _update_sae_lens_training_version(sae: TrainingSAE) -> None:
 
 @dataclass
 class TrainSAEOutput:
-    sae: TrainingSAE
+    sae: TrainingSAE[Any]
     checkpoint_path: str
     log_feature_sparsities: torch.Tensor
 
@@ -46,33 +48,39 @@ class TrainSAEOutput:
 class SaveCheckpointFn(Protocol):
     def __call__(
         self,
-        trainer: "SAETrainer",
-        checkpoint_name: str,
-        wandb_aliases: list[str] | None = None,
+        checkpoint_path: Path,
     ) -> None: ...
 
 
-class SAETrainer:
+Evaluator = Callable[[T_TRAINING_SAE, DataProvider, ActivationScaler], dict[str, Any]]
+
+
+class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
     """
     Core SAE class used for inference. For training, see TrainingSAE.
     """
 
+    data_provider: DataProvider
+    activation_scaler: ActivationScaler
+    evaluator: Evaluator[T_TRAINING_SAE] | None
+
     def __init__(
         self,
-        model: HookedRootModule,
-        sae: TrainingSAE,
-        activation_store: ActivationsStore,
-        save_checkpoint_fn: SaveCheckpointFn,
-        cfg: LanguageModelSAERunnerConfig,
+        cfg: SAETrainerConfig,
+        sae: T_TRAINING_SAE,
+        data_provider: DataProvider,
+        evaluator: Evaluator[T_TRAINING_SAE] | None = None,
+        save_checkpoint_fn: SaveCheckpointFn | None = None,
     ) -> None:
-        self.model = model
         self.sae = sae
-        self.activations_store = activation_store
-        self.save_checkpoint = save_checkpoint_fn
+        self.data_provider = data_provider
+        self.evaluator = evaluator
+        self.activation_scaler = ActivationScaler()
+        self.save_checkpoint_fn = save_checkpoint_fn
         self.cfg = cfg
 
         self.n_training_steps: int = 0
-        self.n_training_tokens: int = 0
+        self.n_training_samples: int = 0
         self.started_fine_tuning: bool = False
 
         _update_sae_lens_training_version(self.sae)
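
With this change the trainer no longer receives a HookedRootModule and an ActivationsStore; it is generic over the SAE type and consumes a plain DataProvider (an iterator of activation tensors, defined at the end of this diff). A rough sketch of the new call shape, assuming cfg and sae are built elsewhere:

    from typing import Iterator

    import torch

    from sae_lens.config import SAETrainerConfig
    from sae_lens.saes.sae import TrainingSAE
    from sae_lens.training.sae_trainer import SAETrainer


    def train_on_activations(
        cfg: SAETrainerConfig,
        sae: TrainingSAE,
        activation_batches: Iterator[torch.Tensor],
    ) -> TrainingSAE:
        # evaluator and save_checkpoint_fn are optional in 6.0.0.
        trainer = SAETrainer(cfg=cfg, sae=sae, data_provider=activation_batches)
        return trainer.fit()
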
@@ -82,20 +90,16 @@ class SAETrainer:
         self.checkpoint_thresholds = list(
             range(
                 0,
-                cfg.total_training_tokens,
-                cfg.total_training_tokens // self.cfg.n_checkpoints,
+                cfg.total_training_samples,
+                cfg.total_training_samples // self.cfg.n_checkpoints,
             )
         )[1:]
 
-        self.act_freq_scores = torch.zeros(
-            cast(int, cfg.d_sae),
-            device=cfg.device,
-        )
+        self.act_freq_scores = torch.zeros(sae.cfg.d_sae, device=cfg.device)
         self.n_forward_passes_since_fired = torch.zeros(
-            cast(int, cfg.d_sae),
-            device=cfg.device,
+            sae.cfg.d_sae, device=cfg.device
         )
-        self.n_frac_active_tokens = 0
+        self.n_frac_active_samples = 0
         # we don't train the scaling factor (initially)
         # set requires grad to false for the scaling factor
         for name, param in self.sae.named_parameters():
@@ -121,14 +125,17 @@ class SAETrainer:
             lr_end=cfg.lr_end,
             num_cycles=cfg.n_restart_cycles,
         )
-        self.l1_scheduler = L1Scheduler(
-            l1_warm_up_steps=cfg.l1_warm_up_steps,
-            total_steps=cfg.total_training_steps,
-            final_l1_coefficient=cfg.l1_coefficient,
-        )
+        self.coefficient_schedulers = {}
+        for name, coeff_cfg in self.sae.get_coefficients().items():
+            if not isinstance(coeff_cfg, TrainCoefficientConfig):
+                coeff_cfg = TrainCoefficientConfig(value=coeff_cfg, warm_up_steps=0)
+            self.coefficient_schedulers[name] = CoefficientScheduler(
+                warm_up_steps=coeff_cfg.warm_up_steps,
+                final_value=coeff_cfg.value,
+            )
 
         # Setup autocast if using
-        self.scaler = torch.amp.GradScaler(
+        self.grad_scaler = torch.amp.GradScaler(
             device=self.cfg.device, enabled=self.cfg.autocast
         )
 
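
Per the loop above, each coefficient reported by the SAE's get_coefficients() gets its own CoefficientScheduler, and a bare float is treated as a constant (no warm-up). A hedged sketch of what such a mapping could look like for a hypothetical custom SAE:

    from sae_lens.saes.sae import TrainCoefficientConfig

    # Hypothetical get_coefficients() return value; key names are illustrative.
    coefficients = {
        "l1": TrainCoefficientConfig(value=5.0, warm_up_steps=1_000),  # warmed up
        "aux": 0.1,  # bare float: wrapped with warm_up_steps=0 by the trainer
    }
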
@@ -141,52 +148,39 @@ class SAETrainer:
         else:
             self.autocast_if_enabled = contextlib.nullcontext()
 
-        # Set up eval config
-
-        self.trainer_eval_config = EvalConfig(
-            batch_size_prompts=self.cfg.eval_batch_size_prompts,
-            n_eval_reconstruction_batches=self.cfg.n_eval_batches,
-            n_eval_sparsity_variance_batches=self.cfg.n_eval_batches,
-            compute_ce_loss=True,
-            compute_l2_norms=True,
-            compute_sparsity_metrics=True,
-            compute_variance_metrics=True,
-            compute_kl=False,
-            compute_featurewise_weight_based_metrics=False,
-        )
-
     @property
     def feature_sparsity(self) -> torch.Tensor:
-        return self.act_freq_scores / self.n_frac_active_tokens
+        return self.act_freq_scores / self.n_frac_active_samples
 
     @property
     def log_feature_sparsity(self) -> torch.Tensor:
         return _log_feature_sparsity(self.feature_sparsity)
 
-    @property
-    def current_l1_coefficient(self) -> float:
-        return self.l1_scheduler.current_l1_coefficient
-
     @property
     def dead_neurons(self) -> torch.Tensor:
         return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()
 
-    def fit(self) -> TrainingSAE:
-        pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE")
+    def fit(self) -> T_TRAINING_SAE:
+        self.sae.to(self.cfg.device)
+        pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE")
 
-        self.activations_store.set_norm_scaling_factor_if_needed()
+        if self.sae.cfg.normalize_activations == "expected_average_only_in":
+            self.activation_scaler.estimate_scaling_factor(
+                d_in=self.sae.cfg.d_in,
+                data_provider=self.data_provider,
+                n_batches_for_norm_estimate=int(1e3),
+            )
 
         # Train loop
-        while self.n_training_tokens < self.cfg.total_training_tokens:
+        while self.n_training_samples < self.cfg.total_training_samples:
             # Do a training step.
-            layer_acts = self.activations_store.next_batch()[:, 0, :].to(
-                self.sae.device
-            )
-            self.n_training_tokens += self.cfg.train_batch_size_tokens
+            batch = next(self.data_provider).to(self.sae.device)
+            self.n_training_samples += batch.shape[0]
+            scaled_batch = self.activation_scaler(batch)
 
-            step_output = self._train_step(sae=self.sae, sae_in=layer_acts)
+            step_output = self._train_step(sae=self.sae, sae_in=scaled_batch)
 
-            if self.cfg.log_to_wandb:
+            if self.cfg.logger.log_to_wandb:
                 self._log_train_step(step_output)
                 self._run_and_log_evals()
 
@@ -194,39 +188,67 @@ class SAETrainer:
             self.n_training_steps += 1
             self._update_pbar(step_output, pbar)
 
-            ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already)
-            self._begin_finetuning_if_needed()
-
         # fold the estimated norm scaling factor into the sae weights
-        if self.activations_store.estimated_norm_scaling_factor is not None:
+        if self.activation_scaler.scaling_factor is not None:
             self.sae.fold_activation_norm_scaling_factor(
-                self.activations_store.estimated_norm_scaling_factor
+                self.activation_scaler.scaling_factor
             )
-            self.activations_store.estimated_norm_scaling_factor = None
+            self.activation_scaler.scaling_factor = None
 
-        # save final sae group to checkpoints folder
+        # save final inference sae group to checkpoints folder
         self.save_checkpoint(
-            trainer=self,
-            checkpoint_name=f"final_{self.n_training_tokens}",
+            checkpoint_name=f"final_{self.n_training_samples}",
            wandb_aliases=["final_model"],
+            save_inference_model=True,
        )
 
         pbar.close()
         return self.sae
 
+    def save_checkpoint(
+        self,
+        checkpoint_name: str,
+        wandb_aliases: list[str] | None = None,
+        save_inference_model: bool = False,
+    ) -> None:
+        checkpoint_path = Path(self.cfg.checkpoint_path) / checkpoint_name
+        checkpoint_path.mkdir(exist_ok=True, parents=True)
+
+        save_fn = (
+            self.sae.save_inference_model
+            if save_inference_model
+            else self.sae.save_model
+        )
+        weights_path, cfg_path = save_fn(str(checkpoint_path))
+
+        sparsity_path = checkpoint_path / SPARSITY_FILENAME
+        save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
+
+        activation_scaler_path = checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
+        self.activation_scaler.save(str(activation_scaler_path))
+
+        if self.cfg.logger.log_to_wandb:
+            self.cfg.logger.log(
+                self,
+                weights_path,
+                cfg_path,
+                sparsity_path=sparsity_path,
+                wandb_aliases=wandb_aliases,
+            )
+
+        if self.save_checkpoint_fn is not None:
+            self.save_checkpoint_fn(checkpoint_path=checkpoint_path)
+
     def _train_step(
         self,
-        sae: TrainingSAE,
+        sae: T_TRAINING_SAE,
         sae_in: torch.Tensor,
     ) -> TrainStepOutput:
         sae.train()
-        # Make sure the W_dec is still zero-norm
-        if self.cfg.normalize_sae_decoder:
-            sae.set_decoder_norm_to_unit_norm()
 
         # log and then reset the feature sparsity every feature_sampling_window steps
         if (self.n_training_steps + 1) % self.cfg.feature_sampling_window == 0:
-            if self.cfg.log_to_wandb:
+            if self.cfg.logger.log_to_wandb:
                 sparsity_log_dict = self._build_sparsity_log_dict()
                 wandb.log(sparsity_log_dict, step=self.n_training_steps)
                 self._reset_running_sparsity_stats()
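
Checkpoint writing now lives in the trainer itself (save_checkpoint above); the optional save_checkpoint_fn is reduced to a post-save hook that receives only the directory that was just written, matching the SaveCheckpointFn protocol earlier in this file. A sketch of such a hook, with the actual post-processing left hypothetical:

    from pathlib import Path


    def after_checkpoint(checkpoint_path: Path) -> None:
        # Called after weights, cfg, sparsity, and activation-scaler files
        # have been written into checkpoint_path.
        print(f"checkpoint written to {checkpoint_path}")


    # SAETrainer(cfg=cfg, sae=sae, data_provider=batches, save_checkpoint_fn=after_checkpoint)
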
@@ -235,9 +257,11 @@ class SAETrainer:
         # https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
         with self.autocast_if_enabled:
             train_step_output = self.sae.training_forward_pass(
-                sae_in=sae_in,
-                dead_neuron_mask=self.dead_neurons,
-                current_l1_coefficient=self.current_l1_coefficient,
+                step_input=TrainStepInput(
+                    sae_in=sae_in,
+                    dead_neuron_mask=self.dead_neurons,
+                    coefficients=self.get_coefficients(),
+                ),
             )
 
         with torch.no_grad():
@@ -247,43 +271,50 @@ class SAETrainer:
             self.act_freq_scores += (
                 (train_step_output.feature_acts.abs() > 0).float().sum(0)
             )
-            self.n_frac_active_tokens += self.cfg.train_batch_size_tokens
+            self.n_frac_active_samples += self.cfg.train_batch_size_samples
 
-        # Scaler will rescale gradients if autocast is enabled
-        self.scaler.scale(
+        # Grad scaler will rescale gradients if autocast is enabled
+        self.grad_scaler.scale(
             train_step_output.loss
         ).backward()  # loss.backward() if not autocasting
-        self.scaler.unscale_(self.optimizer)  # needed to clip correctly
+        self.grad_scaler.unscale_(self.optimizer)  # needed to clip correctly
         # TODO: Work out if grad norm clipping should be in config / how to test it.
         torch.nn.utils.clip_grad_norm_(sae.parameters(), 1.0)
-        self.scaler.step(self.optimizer)  # just ctx.optimizer.step() if not autocasting
-        self.scaler.update()
-
-        if self.cfg.normalize_sae_decoder:
-            sae.remove_gradient_parallel_to_decoder_directions()
+        self.grad_scaler.step(
+            self.optimizer
+        )  # just ctx.optimizer.step() if not autocasting
+        self.grad_scaler.update()
 
         self.optimizer.zero_grad()
         self.lr_scheduler.step()
-        self.l1_scheduler.step()
+        for scheduler in self.coefficient_schedulers.values():
+            scheduler.step()
 
         return train_step_output
 
     @torch.no_grad()
     def _log_train_step(self, step_output: TrainStepOutput):
-        if (self.n_training_steps + 1) % self.cfg.wandb_log_frequency == 0:
+        if (self.n_training_steps + 1) % self.cfg.logger.wandb_log_frequency == 0:
             wandb.log(
                 self._build_train_step_log_dict(
                     output=step_output,
-                    n_training_tokens=self.n_training_tokens,
+                    n_training_samples=self.n_training_samples,
                 ),
                 step=self.n_training_steps,
             )
 
+    @torch.no_grad()
+    def get_coefficients(self) -> dict[str, float]:
+        return {
+            name: scheduler.value
+            for name, scheduler in self.coefficient_schedulers.items()
+        }
+
     @torch.no_grad()
     def _build_train_step_log_dict(
         self,
         output: TrainStepOutput,
-        n_training_tokens: int,
+        n_training_samples: int,
     ) -> dict[str, Any]:
         sae_in = output.sae_in
         sae_out = output.sae_out
@@ -311,19 +342,15 @@ class SAETrainer:
             "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
             "sparsity/dead_features": self.dead_neurons.sum().item(),
             "details/current_learning_rate": current_learning_rate,
-            "details/current_l1_coefficient": self.current_l1_coefficient,
-            "details/n_training_tokens": n_training_tokens,
+            "details/n_training_samples": n_training_samples,
+            **{
+                f"details/{name}_coefficient": scheduler.value
+                for name, scheduler in self.coefficient_schedulers.items()
+            },
         }
         for loss_name, loss_value in output.losses.items():
             loss_item = _unwrap_item(loss_value)
-            # special case for l1 loss, which we normalize by the l1 coefficient
-            if loss_name == "l1_loss":
-                log_dict[f"losses/{loss_name}"] = (
-                    loss_item / self.current_l1_coefficient
-                )
-                log_dict[f"losses/raw_{loss_name}"] = loss_item
-            else:
-                log_dict[f"losses/{loss_name}"] = loss_item
+            log_dict[f"losses/{loss_name}"] = loss_item
 
         return log_dict
 
@@ -331,44 +358,17 @@ class SAETrainer:
     def _run_and_log_evals(self):
         # record loss frequently, but not all the time.
         if (self.n_training_steps + 1) % (
-            self.cfg.wandb_log_frequency * self.cfg.eval_every_n_wandb_logs
+            self.cfg.logger.wandb_log_frequency
+            * self.cfg.logger.eval_every_n_wandb_logs
         ) == 0:
             self.sae.eval()
-            ignore_tokens = set()
-            if self.activations_store.exclude_special_tokens is not None:
-                ignore_tokens = set(
-                    self.activations_store.exclude_special_tokens.tolist()
-                )
-            eval_metrics, _ = run_evals(
-                sae=self.sae,
-                activation_store=self.activations_store,
-                model=self.model,
-                eval_config=self.trainer_eval_config,
-                ignore_tokens=ignore_tokens,
-                model_kwargs=self.cfg.model_kwargs,
-            )  # not calculating featurwise metrics here.
-
-            # Remove eval metrics that are already logged during training
-            eval_metrics.pop("metrics/explained_variance", None)
-            eval_metrics.pop("metrics/explained_variance_std", None)
-            eval_metrics.pop("metrics/l0", None)
-            eval_metrics.pop("metrics/l1", None)
-            eval_metrics.pop("metrics/mse", None)
-
-            # Remove metrics that are not useful for wandb logging
-            eval_metrics.pop("metrics/total_tokens_evaluated", None)
-
-            W_dec_norm_dist = self.sae.W_dec.detach().float().norm(dim=1).cpu().numpy()
-            eval_metrics["weights/W_dec_norms"] = wandb.Histogram(W_dec_norm_dist)  # type: ignore
-
-            if self.sae.cfg.architecture == "standard":
-                b_e_dist = self.sae.b_enc.detach().float().cpu().numpy()
-                eval_metrics["weights/b_e"] = wandb.Histogram(b_e_dist)  # type: ignore
-            elif self.sae.cfg.architecture == "gated":
-                b_gate_dist = self.sae.b_gate.detach().float().cpu().numpy()
-                eval_metrics["weights/b_gate"] = wandb.Histogram(b_gate_dist)  # type: ignore
-                b_mag_dist = self.sae.b_mag.detach().float().cpu().numpy()
-                eval_metrics["weights/b_mag"] = wandb.Histogram(b_mag_dist)  # type: ignore
+            eval_metrics = (
+                self.evaluator(self.sae, self.data_provider, self.activation_scaler)
+                if self.evaluator is not None
+                else {}
+            )
+            for key, value in self.sae.log_histograms().items():
+                eval_metrics[key] = wandb.Histogram(value)  # type: ignore
 
             wandb.log(
                 eval_metrics,
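
Evals are no longer hard-wired to run_evals; the trainer calls an optional Evaluator, i.e. any callable taking (sae, data_provider, activation_scaler) and returning a metrics dict, as declared near the top of this file. A hedged sketch (it assumes the SAE exposes an encode method, which is not shown in this diff):

    from typing import Any

    from sae_lens.saes.sae import TrainingSAE
    from sae_lens.training.activation_scaler import ActivationScaler
    from sae_lens.training.types import DataProvider


    def l0_evaluator(
        sae: TrainingSAE,
        data_provider: DataProvider,
        activation_scaler: ActivationScaler,
    ) -> dict[str, Any]:
        # Mean L0 of feature activations on a single scaled batch.
        batch = activation_scaler(next(data_provider).to(sae.device))
        feature_acts = sae.encode(batch)
        return {"eval/l0": (feature_acts > 0).float().sum(-1).mean().item()}
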
@@ -390,21 +390,18 @@ class SAETrainer:
     @torch.no_grad()
     def _reset_running_sparsity_stats(self) -> None:
         self.act_freq_scores = torch.zeros(
-            self.cfg.d_sae,  # type: ignore
+            self.sae.cfg.d_sae,  # type: ignore
             device=self.cfg.device,
         )
-        self.n_frac_active_tokens = 0
+        self.n_frac_active_samples = 0
 
     @torch.no_grad()
     def _checkpoint_if_needed(self):
         if (
             self.checkpoint_thresholds
-            and self.n_training_tokens > self.checkpoint_thresholds[0]
+            and self.n_training_samples > self.checkpoint_thresholds[0]
         ):
-            self.save_checkpoint(
-                trainer=self,
-                checkpoint_name=str(self.n_training_tokens),
-            )
+            self.save_checkpoint(checkpoint_name=str(self.n_training_samples))
             self.checkpoint_thresholds.pop(0)
 
     @torch.no_grad()
@@ -420,26 +417,7 @@ class SAETrainer:
             for loss_name, loss_value in step_output.losses.items()
         )
         pbar.set_description(f"{self.n_training_steps}| {loss_strs}")
-        pbar.update(update_interval * self.cfg.train_batch_size_tokens)
-
-    def _begin_finetuning_if_needed(self):
-        if (not self.started_fine_tuning) and (
-            self.n_training_tokens > self.cfg.training_tokens
-        ):
-            self.started_fine_tuning = True
-
-            # finetuning method should be set in the config
-            # if not, then we don't finetune
-            if not isinstance(self.cfg.finetuning_method, str):
-                return
-
-            for name, param in self.sae.named_parameters():
-                if name in FINETUNING_PARAMETERS[self.cfg.finetuning_method]:
-                    param.requires_grad = True
-                else:
-                    param.requires_grad = False
-
-            self.finetuning = True
+        pbar.update(update_interval * self.cfg.train_batch_size_samples)
 
 
 def _unwrap_item(item: float | torch.Tensor) -> float:

sae_lens/training/types.py (new file)
@@ -0,0 +1,5 @@
+from typing import Iterator
+
+import torch
+
+DataProvider = Iterator[torch.Tensor]
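
This new module is the whole data contract for 6.0.0 training: a DataProvider is just an iterator of activation tensors, and fit() calls next() on it once per step (see the training loop above). A minimal stand-in provider, with random data in place of real model activations:

    from typing import Iterator

    import torch

    from sae_lens.training.types import DataProvider


    def random_activation_provider(
        d_in: int, batch_size: int, device: str = "cpu"
    ) -> DataProvider:
        # Yields [batch_size, d_in] batches forever; a real provider would stream
        # activations from a model or from cached files instead.
        def _gen() -> Iterator[torch.Tensor]:
            while True:
                yield torch.randn(batch_size, d_in, device=device)

        return _gen()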