sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl

This diff compares the contents of two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registry.
@@ -101,61 +101,85 @@ def _get_main_lr_scheduler(
101
101
  raise ValueError(f"Unsupported scheduler: {scheduler_name}")
102
102
 
103
103
 
104
- class L1Scheduler:
104
+ class CoefficientScheduler:
105
+ """Linearly warms up a scalar value from 0.0 to a final value."""
106
+
105
107
  def __init__(
106
108
  self,
107
- l1_warm_up_steps: float,
108
- total_steps: int,
109
- final_l1_coefficient: float,
109
+ warm_up_steps: float,
110
+ final_value: float,
110
111
  ):
111
- self.l1_warmup_steps = l1_warm_up_steps
112
- # assume using warm-up
113
- if self.l1_warmup_steps != 0:
114
- self.current_l1_coefficient = 0.0
115
- else:
116
- self.current_l1_coefficient = final_l1_coefficient
117
-
118
- self.final_l1_coefficient = final_l1_coefficient
119
-
112
+ self.warm_up_steps = warm_up_steps
113
+ self.final_value = final_value
120
114
  self.current_step = 0
121
- self.total_steps = total_steps
122
- if not isinstance(self.final_l1_coefficient, (float, int)):
115
+
116
+ if not isinstance(self.final_value, (float, int)):
123
117
  raise TypeError(
124
- f"final_l1_coefficient must be float or int, got {type(self.final_l1_coefficient)}."
118
+ f"final_value must be float or int, got {type(self.final_value)}."
125
119
  )
126
120
 
121
+ # Initialize current_value based on whether warm-up is used
122
+ if self.warm_up_steps > 0:
123
+ self.current_value = 0.0
124
+ else:
125
+ self.current_value = self.final_value
126
+
127
127
  def __repr__(self) -> str:
128
128
  return (
129
- f"L1Scheduler(final_l1_value={self.final_l1_coefficient}, "
130
- f"l1_warmup_steps={self.l1_warmup_steps}, "
131
- f"total_steps={self.total_steps})"
129
+ f"{self.__class__.__name__}(final_value={self.final_value}, "
130
+ f"warm_up_steps={self.warm_up_steps})"
132
131
  )
133
132
 
134
- def step(self):
133
+ def step(self) -> float:
135
134
  """
136
- Updates the l1 coefficient of the sparse autoencoder.
135
+ Updates the scalar value based on the current step.
136
+
137
+ Returns:
138
+ The current scalar value after the step.
137
139
  """
138
- step = self.current_step
139
- if step < self.l1_warmup_steps:
140
- self.current_l1_coefficient = self.final_l1_coefficient * (
141
- (1 + step) / self.l1_warmup_steps
142
- ) # type: ignore
140
+ if self.current_step < self.warm_up_steps:
141
+ self.current_value = self.final_value * (
142
+ (self.current_step + 1) / self.warm_up_steps
143
+ )
143
144
  else:
144
- self.current_l1_coefficient = self.final_l1_coefficient # type: ignore
145
+ # Ensure the value stays at final_value after warm-up
146
+ self.current_value = self.final_value
145
147
 
146
148
  self.current_step += 1
149
+ return self.current_value
147
150
 
148
- def state_dict(self):
149
- """State dict for serializing as part of an SAETrainContext."""
151
+ @property
152
+ def value(self) -> float:
153
+ """Returns the current scalar value."""
154
+ return self.current_value
155
+
156
+ def state_dict(self) -> dict[str, Any]:
157
+ """State dict for serialization."""
150
158
  return {
151
- "l1_warmup_steps": self.l1_warmup_steps,
152
- "total_steps": self.total_steps,
153
- "current_l1_coefficient": self.current_l1_coefficient,
154
- "final_l1_coefficient": self.final_l1_coefficient,
159
+ "warm_up_steps": self.warm_up_steps,
160
+ "final_value": self.final_value,
155
161
  "current_step": self.current_step,
162
+ "current_value": self.current_value,
156
163
  }
157
164
 
158
165
  def load_state_dict(self, state_dict: dict[str, Any]):
159
- """Loads all state apart from attached SAE."""
160
- for k in state_dict:
161
- setattr(self, k, state_dict[k])
166
+ """Loads the scheduler state."""
167
+ self.warm_up_steps = state_dict["warm_up_steps"]
168
+ self.final_value = state_dict["final_value"]
169
+ self.current_step = state_dict["current_step"]
170
+ # Maintain consistency: re-calculate current_value based on loaded step
171
+ # This handles resuming correctly if stopped mid-warmup.
172
+ if self.current_step <= self.warm_up_steps and self.warm_up_steps > 0:
173
+ # Use max(0, ...) to handle case where current_step might be loaded as -1 or similar before first step
174
+ step_for_calc = max(0, self.current_step)
175
+ # Recalculate based on the step *before* the one about to be taken
176
+ # Or simply use the saved current_value if available and consistent
177
+ if "current_value" in state_dict:
178
+ self.current_value = state_dict["current_value"]
179
+ else: # Legacy state dicts might not have current_value
180
+ self.current_value = self.final_value * (
181
+ step_for_calc / self.warm_up_steps
182
+ )
183
+
184
+ else:
185
+ self.current_value = self.final_value
@@ -1,26 +1,28 @@
1
1
  import contextlib
2
2
  from dataclasses import dataclass
3
- from typing import Any, Protocol, cast
3
+ from pathlib import Path
4
+ from typing import Any, Callable, Generic, Protocol
4
5
 
5
6
  import torch
6
7
  import wandb
8
+ from safetensors.torch import save_file
7
9
  from torch.optim import Adam
8
10
  from tqdm import tqdm
9
- from transformer_lens.hook_points import HookedRootModule
10
11
 
11
12
  from sae_lens import __version__
12
- from sae_lens.config import LanguageModelSAERunnerConfig
13
- from sae_lens.evals import EvalConfig, run_evals
14
- from sae_lens.saes.sae import TrainingSAE, TrainStepInput, TrainStepOutput
15
- from sae_lens.training.activations_store import ActivationsStore
16
- from sae_lens.training.optim import L1Scheduler, get_lr_scheduler
17
-
18
- # used to map between parameters which are updated during finetuning and the config str.
19
- FINETUNING_PARAMETERS = {
20
- "scale": ["scaling_factor"],
21
- "decoder": ["scaling_factor", "W_dec", "b_dec"],
22
- "unrotated_decoder": ["scaling_factor", "b_dec"],
23
- }
13
+ from sae_lens.config import SAETrainerConfig
14
+ from sae_lens.constants import ACTIVATION_SCALER_CFG_FILENAME, SPARSITY_FILENAME
15
+ from sae_lens.saes.sae import (
16
+ T_TRAINING_SAE,
17
+ T_TRAINING_SAE_CONFIG,
18
+ TrainCoefficientConfig,
19
+ TrainingSAE,
20
+ TrainStepInput,
21
+ TrainStepOutput,
22
+ )
23
+ from sae_lens.training.activation_scaler import ActivationScaler
24
+ from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler
25
+ from sae_lens.training.types import DataProvider
24
26
 
25
27
 
26
28
  def _log_feature_sparsity(
@@ -29,7 +31,7 @@ def _log_feature_sparsity(
29
31
  return torch.log10(feature_sparsity + eps).detach().cpu()
30
32
 
31
33
 
32
- def _update_sae_lens_training_version(sae: TrainingSAE) -> None:
34
+ def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
33
35
  """
34
36
  Make sure we record the version of SAELens used for the training run
35
37
  """
@@ -38,7 +40,7 @@ def _update_sae_lens_training_version(sae: TrainingSAE) -> None:
38
40
 
39
41
  @dataclass
40
42
  class TrainSAEOutput:
41
- sae: TrainingSAE
43
+ sae: TrainingSAE[Any]
42
44
  checkpoint_path: str
43
45
  log_feature_sparsities: torch.Tensor
44
46
 
@@ -46,33 +48,39 @@ class TrainSAEOutput:
46
48
  class SaveCheckpointFn(Protocol):
47
49
  def __call__(
48
50
  self,
49
- trainer: "SAETrainer",
50
- checkpoint_name: str,
51
- wandb_aliases: list[str] | None = None,
51
+ checkpoint_path: Path,
52
52
  ) -> None: ...
53
53
 
54
54
 
55
- class SAETrainer:
55
+ Evaluator = Callable[[T_TRAINING_SAE, DataProvider, ActivationScaler], dict[str, Any]]
56
+
57
+
58
+ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
56
59
  """
57
60
  Core SAE class used for inference. For training, see TrainingSAE.
58
61
  """
59
62
 
63
+ data_provider: DataProvider
64
+ activation_scaler: ActivationScaler
65
+ evaluator: Evaluator[T_TRAINING_SAE] | None
66
+
60
67
  def __init__(
61
68
  self,
62
- model: HookedRootModule,
63
- sae: TrainingSAE,
64
- activation_store: ActivationsStore,
65
- save_checkpoint_fn: SaveCheckpointFn,
66
- cfg: LanguageModelSAERunnerConfig,
69
+ cfg: SAETrainerConfig,
70
+ sae: T_TRAINING_SAE,
71
+ data_provider: DataProvider,
72
+ evaluator: Evaluator[T_TRAINING_SAE] | None = None,
73
+ save_checkpoint_fn: SaveCheckpointFn | None = None,
67
74
  ) -> None:
68
- self.model = model
69
75
  self.sae = sae
70
- self.activations_store = activation_store
71
- self.save_checkpoint = save_checkpoint_fn
76
+ self.data_provider = data_provider
77
+ self.evaluator = evaluator
78
+ self.activation_scaler = ActivationScaler()
79
+ self.save_checkpoint_fn = save_checkpoint_fn
72
80
  self.cfg = cfg
73
81
 
74
82
  self.n_training_steps: int = 0
75
- self.n_training_tokens: int = 0
83
+ self.n_training_samples: int = 0
76
84
  self.started_fine_tuning: bool = False
77
85
 
78
86
  _update_sae_lens_training_version(self.sae)
@@ -82,20 +90,16 @@ class SAETrainer:
82
90
  self.checkpoint_thresholds = list(
83
91
  range(
84
92
  0,
85
- cfg.total_training_tokens,
86
- cfg.total_training_tokens // self.cfg.n_checkpoints,
93
+ cfg.total_training_samples,
94
+ cfg.total_training_samples // self.cfg.n_checkpoints,
87
95
  )
88
96
  )[1:]
89
97
 
90
- self.act_freq_scores = torch.zeros(
91
- cast(int, cfg.d_sae),
92
- device=cfg.device,
93
- )
98
+ self.act_freq_scores = torch.zeros(sae.cfg.d_sae, device=cfg.device)
94
99
  self.n_forward_passes_since_fired = torch.zeros(
95
- cast(int, cfg.d_sae),
96
- device=cfg.device,
100
+ sae.cfg.d_sae, device=cfg.device
97
101
  )
98
- self.n_frac_active_tokens = 0
102
+ self.n_frac_active_samples = 0
99
103
  # we don't train the scaling factor (initially)
100
104
  # set requires grad to false for the scaling factor
101
105
  for name, param in self.sae.named_parameters():
@@ -121,14 +125,17 @@ class SAETrainer:
121
125
  lr_end=cfg.lr_end,
122
126
  num_cycles=cfg.n_restart_cycles,
123
127
  )
124
- self.l1_scheduler = L1Scheduler(
125
- l1_warm_up_steps=cfg.l1_warm_up_steps,
126
- total_steps=cfg.total_training_steps,
127
- final_l1_coefficient=cfg.l1_coefficient,
128
- )
128
+ self.coefficient_schedulers = {}
129
+ for name, coeff_cfg in self.sae.get_coefficients().items():
130
+ if not isinstance(coeff_cfg, TrainCoefficientConfig):
131
+ coeff_cfg = TrainCoefficientConfig(value=coeff_cfg, warm_up_steps=0)
132
+ self.coefficient_schedulers[name] = CoefficientScheduler(
133
+ warm_up_steps=coeff_cfg.warm_up_steps,
134
+ final_value=coeff_cfg.value,
135
+ )
129
136
 
130
137
  # Setup autocast if using
131
- self.scaler = torch.amp.GradScaler(
138
+ self.grad_scaler = torch.amp.GradScaler(
132
139
  device=self.cfg.device, enabled=self.cfg.autocast
133
140
  )
134
141
 
@@ -141,50 +148,36 @@ class SAETrainer:
141
148
  else:
142
149
  self.autocast_if_enabled = contextlib.nullcontext()
143
150
 
144
- # Set up eval config
145
-
146
- self.trainer_eval_config = EvalConfig(
147
- batch_size_prompts=self.cfg.eval_batch_size_prompts,
148
- n_eval_reconstruction_batches=self.cfg.n_eval_batches,
149
- n_eval_sparsity_variance_batches=self.cfg.n_eval_batches,
150
- compute_ce_loss=True,
151
- compute_l2_norms=True,
152
- compute_sparsity_metrics=True,
153
- compute_variance_metrics=True,
154
- compute_kl=False,
155
- compute_featurewise_weight_based_metrics=False,
156
- )
157
-
158
151
  @property
159
152
  def feature_sparsity(self) -> torch.Tensor:
160
- return self.act_freq_scores / self.n_frac_active_tokens
153
+ return self.act_freq_scores / self.n_frac_active_samples
161
154
 
162
155
  @property
163
156
  def log_feature_sparsity(self) -> torch.Tensor:
164
157
  return _log_feature_sparsity(self.feature_sparsity)
165
158
 
166
- @property
167
- def current_l1_coefficient(self) -> float:
168
- return self.l1_scheduler.current_l1_coefficient
169
-
170
159
  @property
171
160
  def dead_neurons(self) -> torch.Tensor:
172
161
  return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()
173
162
 
174
- def fit(self) -> TrainingSAE:
175
- pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE")
163
+ def fit(self) -> T_TRAINING_SAE:
164
+ pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE")
176
165
 
177
- self.activations_store.set_norm_scaling_factor_if_needed()
166
+ if self.sae.cfg.normalize_activations == "expected_average_only_in":
167
+ self.activation_scaler.estimate_scaling_factor(
168
+ d_in=self.sae.cfg.d_in,
169
+ data_provider=self.data_provider,
170
+ n_batches_for_norm_estimate=int(1e3),
171
+ )
178
172
 
179
173
  # Train loop
180
- while self.n_training_tokens < self.cfg.total_training_tokens:
174
+ while self.n_training_samples < self.cfg.total_training_samples:
181
175
  # Do a training step.
182
- layer_acts = self.activations_store.next_batch()[:, 0, :].to(
183
- self.sae.device
184
- )
185
- self.n_training_tokens += self.cfg.train_batch_size_tokens
176
+ batch = next(self.data_provider).to(self.sae.device)
177
+ self.n_training_samples += batch.shape[0]
178
+ scaled_batch = self.activation_scaler(batch)
186
179
 
187
- step_output = self._train_step(sae=self.sae, sae_in=layer_acts)
180
+ step_output = self._train_step(sae=self.sae, sae_in=scaled_batch)
188
181
 
189
182
  if self.cfg.logger.log_to_wandb:
190
183
  self._log_train_step(step_output)
@@ -194,35 +187,56 @@ class SAETrainer:
194
187
  self.n_training_steps += 1
195
188
  self._update_pbar(step_output, pbar)
196
189
 
197
- ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already)
198
- self._begin_finetuning_if_needed()
199
-
200
190
  # fold the estimated norm scaling factor into the sae weights
201
- if self.activations_store.estimated_norm_scaling_factor is not None:
191
+ if self.activation_scaler.scaling_factor is not None:
202
192
  self.sae.fold_activation_norm_scaling_factor(
203
- self.activations_store.estimated_norm_scaling_factor
193
+ self.activation_scaler.scaling_factor
204
194
  )
205
- self.activations_store.estimated_norm_scaling_factor = None
195
+ self.activation_scaler.scaling_factor = None
206
196
 
207
197
  # save final sae group to checkpoints folder
208
198
  self.save_checkpoint(
209
- trainer=self,
210
- checkpoint_name=f"final_{self.n_training_tokens}",
199
+ checkpoint_name=f"final_{self.n_training_samples}",
211
200
  wandb_aliases=["final_model"],
212
201
  )
213
202
 
214
203
  pbar.close()
215
204
  return self.sae
216
205
 
206
+ def save_checkpoint(
207
+ self,
208
+ checkpoint_name: str,
209
+ wandb_aliases: list[str] | None = None,
210
+ ) -> None:
211
+ checkpoint_path = Path(self.cfg.checkpoint_path) / checkpoint_name
212
+ checkpoint_path.mkdir(exist_ok=True, parents=True)
213
+
214
+ weights_path, cfg_path = self.sae.save_model(str(checkpoint_path))
215
+
216
+ sparsity_path = checkpoint_path / SPARSITY_FILENAME
217
+ save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
218
+
219
+ activation_scaler_path = checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
220
+ self.activation_scaler.save(str(activation_scaler_path))
221
+
222
+ if self.cfg.logger.log_to_wandb:
223
+ self.cfg.logger.log(
224
+ self,
225
+ weights_path,
226
+ cfg_path,
227
+ sparsity_path=sparsity_path,
228
+ wandb_aliases=wandb_aliases,
229
+ )
230
+
231
+ if self.save_checkpoint_fn is not None:
232
+ self.save_checkpoint_fn(checkpoint_path=checkpoint_path)
233
+
217
234
  def _train_step(
218
235
  self,
219
- sae: TrainingSAE,
236
+ sae: T_TRAINING_SAE,
220
237
  sae_in: torch.Tensor,
221
238
  ) -> TrainStepOutput:
222
239
  sae.train()
223
- # Make sure the W_dec is still zero-norm
224
- if self.cfg.normalize_sae_decoder:
225
- sae.set_decoder_norm_to_unit_norm()
226
240
 
227
241
  # log and then reset the feature sparsity every feature_sampling_window steps
228
242
  if (self.n_training_steps + 1) % self.cfg.feature_sampling_window == 0:
@@ -238,7 +252,7 @@ class SAETrainer:
238
252
  step_input=TrainStepInput(
239
253
  sae_in=sae_in,
240
254
  dead_neuron_mask=self.dead_neurons,
241
- current_l1_coefficient=self.current_l1_coefficient,
255
+ coefficients=self.get_coefficients(),
242
256
  ),
243
257
  )
244
258
 
@@ -249,24 +263,24 @@ class SAETrainer:
249
263
  self.act_freq_scores += (
250
264
  (train_step_output.feature_acts.abs() > 0).float().sum(0)
251
265
  )
252
- self.n_frac_active_tokens += self.cfg.train_batch_size_tokens
266
+ self.n_frac_active_samples += self.cfg.train_batch_size_samples
253
267
 
254
- # Scaler will rescale gradients if autocast is enabled
255
- self.scaler.scale(
268
+ # Grad scaler will rescale gradients if autocast is enabled
269
+ self.grad_scaler.scale(
256
270
  train_step_output.loss
257
271
  ).backward() # loss.backward() if not autocasting
258
- self.scaler.unscale_(self.optimizer) # needed to clip correctly
272
+ self.grad_scaler.unscale_(self.optimizer) # needed to clip correctly
259
273
  # TODO: Work out if grad norm clipping should be in config / how to test it.
260
274
  torch.nn.utils.clip_grad_norm_(sae.parameters(), 1.0)
261
- self.scaler.step(self.optimizer) # just ctx.optimizer.step() if not autocasting
262
- self.scaler.update()
263
-
264
- if self.cfg.normalize_sae_decoder:
265
- sae.remove_gradient_parallel_to_decoder_directions()
275
+ self.grad_scaler.step(
276
+ self.optimizer
277
+ ) # just ctx.optimizer.step() if not autocasting
278
+ self.grad_scaler.update()
266
279
 
267
280
  self.optimizer.zero_grad()
268
281
  self.lr_scheduler.step()
269
- self.l1_scheduler.step()
282
+ for scheduler in self.coefficient_schedulers.values():
283
+ scheduler.step()
270
284
 
271
285
  return train_step_output
272
286
 
@@ -276,16 +290,23 @@ class SAETrainer:
276
290
  wandb.log(
277
291
  self._build_train_step_log_dict(
278
292
  output=step_output,
279
- n_training_tokens=self.n_training_tokens,
293
+ n_training_samples=self.n_training_samples,
280
294
  ),
281
295
  step=self.n_training_steps,
282
296
  )
283
297
 
298
+ @torch.no_grad()
299
+ def get_coefficients(self) -> dict[str, float]:
300
+ return {
301
+ name: scheduler.value
302
+ for name, scheduler in self.coefficient_schedulers.items()
303
+ }
304
+
284
305
  @torch.no_grad()
285
306
  def _build_train_step_log_dict(
286
307
  self,
287
308
  output: TrainStepOutput,
288
- n_training_tokens: int,
309
+ n_training_samples: int,
289
310
  ) -> dict[str, Any]:
290
311
  sae_in = output.sae_in
291
312
  sae_out = output.sae_out
@@ -313,19 +334,15 @@ class SAETrainer:
313
334
  "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
314
335
  "sparsity/dead_features": self.dead_neurons.sum().item(),
315
336
  "details/current_learning_rate": current_learning_rate,
316
- "details/current_l1_coefficient": self.current_l1_coefficient,
317
- "details/n_training_tokens": n_training_tokens,
337
+ "details/n_training_samples": n_training_samples,
338
+ **{
339
+ f"details/{name}_coefficient": scheduler.value
340
+ for name, scheduler in self.coefficient_schedulers.items()
341
+ },
318
342
  }
319
343
  for loss_name, loss_value in output.losses.items():
320
344
  loss_item = _unwrap_item(loss_value)
321
- # special case for l1 loss, which we normalize by the l1 coefficient
322
- if loss_name == "l1_loss":
323
- log_dict[f"losses/{loss_name}"] = (
324
- loss_item / self.current_l1_coefficient
325
- )
326
- log_dict[f"losses/raw_{loss_name}"] = loss_item
327
- else:
328
- log_dict[f"losses/{loss_name}"] = loss_item
345
+ log_dict[f"losses/{loss_name}"] = loss_item
329
346
 
330
347
  return log_dict
331
348
 
@@ -337,30 +354,11 @@ class SAETrainer:
337
354
  * self.cfg.logger.eval_every_n_wandb_logs
338
355
  ) == 0:
339
356
  self.sae.eval()
340
- ignore_tokens = set()
341
- if self.activations_store.exclude_special_tokens is not None:
342
- ignore_tokens = set(
343
- self.activations_store.exclude_special_tokens.tolist()
344
- )
345
- eval_metrics, _ = run_evals(
346
- sae=self.sae,
347
- activation_store=self.activations_store,
348
- model=self.model,
349
- eval_config=self.trainer_eval_config,
350
- ignore_tokens=ignore_tokens,
351
- model_kwargs=self.cfg.model_kwargs,
352
- ) # not calculating featurwise metrics here.
353
-
354
- # Remove eval metrics that are already logged during training
355
- eval_metrics.pop("metrics/explained_variance", None)
356
- eval_metrics.pop("metrics/explained_variance_std", None)
357
- eval_metrics.pop("metrics/l0", None)
358
- eval_metrics.pop("metrics/l1", None)
359
- eval_metrics.pop("metrics/mse", None)
360
-
361
- # Remove metrics that are not useful for wandb logging
362
- eval_metrics.pop("metrics/total_tokens_evaluated", None)
363
-
357
+ eval_metrics = (
358
+ self.evaluator(self.sae, self.data_provider, self.activation_scaler)
359
+ if self.evaluator is not None
360
+ else {}
361
+ )
364
362
  for key, value in self.sae.log_histograms().items():
365
363
  eval_metrics[key] = wandb.Histogram(value) # type: ignore
366
364
 
@@ -384,21 +382,18 @@ class SAETrainer:
384
382
  @torch.no_grad()
385
383
  def _reset_running_sparsity_stats(self) -> None:
386
384
  self.act_freq_scores = torch.zeros(
387
- self.cfg.d_sae, # type: ignore
385
+ self.sae.cfg.d_sae, # type: ignore
388
386
  device=self.cfg.device,
389
387
  )
390
- self.n_frac_active_tokens = 0
388
+ self.n_frac_active_samples = 0
391
389
 
392
390
  @torch.no_grad()
393
391
  def _checkpoint_if_needed(self):
394
392
  if (
395
393
  self.checkpoint_thresholds
396
- and self.n_training_tokens > self.checkpoint_thresholds[0]
394
+ and self.n_training_samples > self.checkpoint_thresholds[0]
397
395
  ):
398
- self.save_checkpoint(
399
- trainer=self,
400
- checkpoint_name=str(self.n_training_tokens),
401
- )
396
+ self.save_checkpoint(checkpoint_name=str(self.n_training_samples))
402
397
  self.checkpoint_thresholds.pop(0)
403
398
 
404
399
  @torch.no_grad()
@@ -414,26 +409,7 @@ class SAETrainer:
414
409
  for loss_name, loss_value in step_output.losses.items()
415
410
  )
416
411
  pbar.set_description(f"{self.n_training_steps}| {loss_strs}")
417
- pbar.update(update_interval * self.cfg.train_batch_size_tokens)
418
-
419
- def _begin_finetuning_if_needed(self):
420
- if (not self.started_fine_tuning) and (
421
- self.n_training_tokens > self.cfg.training_tokens
422
- ):
423
- self.started_fine_tuning = True
424
-
425
- # finetuning method should be set in the config
426
- # if not, then we don't finetune
427
- if not isinstance(self.cfg.finetuning_method, str):
428
- return
429
-
430
- for name, param in self.sae.named_parameters():
431
- if name in FINETUNING_PARAMETERS[self.cfg.finetuning_method]:
432
- param.requires_grad = True
433
- else:
434
- param.requires_grad = False
435
-
436
- self.finetuning = True
412
+ pbar.update(update_interval * self.cfg.train_batch_size_samples)
437
413
 
438
414
 
439
415
  def _unwrap_item(item: float | torch.Tensor) -> float:
@@ -0,0 +1,5 @@
1
+ from typing import Iterator
2
+
3
+ import torch
4
+
5
+ DataProvider = Iterator[torch.Tensor]
@@ -2,14 +2,15 @@ import io
2
2
  from pathlib import Path
3
3
  from tempfile import TemporaryDirectory
4
4
  from textwrap import dedent
5
- from typing import Iterable
5
+ from typing import Any, Iterable
6
6
 
7
7
  from huggingface_hub import HfApi, create_repo, get_hf_file_metadata, hf_hub_url
8
8
  from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
9
9
  from tqdm.autonotebook import tqdm
10
10
 
11
11
  from sae_lens import logger
12
- from sae_lens.config import (
12
+ from sae_lens.constants import (
13
+ RUNNER_CFG_FILENAME,
13
14
  SAE_CFG_FILENAME,
14
15
  SAE_WEIGHTS_FILENAME,
15
16
  SPARSITY_FILENAME,
@@ -18,7 +19,7 @@ from sae_lens.saes.sae import SAE
18
19
 
19
20
 
20
21
  def upload_saes_to_huggingface(
21
- saes_dict: dict[str, SAE | Path | str],
22
+ saes_dict: dict[str, SAE[Any] | Path | str],
22
23
  hf_repo_id: str,
23
24
  hf_revision: str = "main",
24
25
  show_progress: bool = True,
@@ -119,11 +120,16 @@ def _upload_sae(api: HfApi, sae_path: Path, repo_id: str, sae_id: str, revision:
119
120
  revision=revision,
120
121
  repo_type="model",
121
122
  commit_message=f"Upload SAE {sae_id}",
122
- allow_patterns=[SAE_CFG_FILENAME, SAE_WEIGHTS_FILENAME, SPARSITY_FILENAME],
123
+ allow_patterns=[
124
+ SAE_CFG_FILENAME,
125
+ SAE_WEIGHTS_FILENAME,
126
+ SPARSITY_FILENAME,
127
+ RUNNER_CFG_FILENAME,
128
+ ],
123
129
  )
124
130
 
125
131
 
126
- def _build_sae_path(sae_ref: SAE | Path | str, tmp_dir: str) -> Path:
132
+ def _build_sae_path(sae_ref: SAE[Any] | Path | str, tmp_dir: str) -> Path:
127
133
  if isinstance(sae_ref, SAE):
128
134
  sae_ref.save_model(tmp_dir)
129
135
  return Path(tmp_dir)