sae-lens 6.16.3__py3-none-any.whl → 6.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,372 @@
+ """TemporalSAE: A Sparse Autoencoder with temporal attention mechanism.
+
+ TemporalSAE decomposes activations into:
+ 1. Predicted codes (from attention over context)
+ 2. Novel codes (sparse features of the residual)
+
+ See: https://arxiv.org/abs/2410.04185
+ """
+
+ import math
+ from dataclasses import dataclass
+ from typing import Literal
+
+ import torch
+ import torch.nn.functional as F
+ from jaxtyping import Float
+ from torch import nn
+ from typing_extensions import override
+
+ from sae_lens import logger
+ from sae_lens.saes.sae import SAE, SAEConfig
+
+
+ def get_attention(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
+     """Compute causal attention weights."""
+     L, S = query.size(-2), key.size(-2)
+     scale_factor = 1 / math.sqrt(query.size(-1))
+     attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+     temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+     attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+     attn_bias.to(query.dtype)
+
+     attn_weight = query @ key.transpose(-2, -1) * scale_factor
+     attn_weight += attn_bias
+     return torch.softmax(attn_weight, dim=-1)
+
+
+ class ManualAttention(nn.Module):
+     """Manual attention implementation for TemporalSAE."""
+
+     def __init__(
+         self,
+         dimin: int,
+         n_heads: int = 4,
+         bottleneck_factor: int = 64,
+         bias_k: bool = True,
+         bias_q: bool = True,
+         bias_v: bool = True,
+         bias_o: bool = True,
+     ):
+         super().__init__()
+         assert dimin % (bottleneck_factor * n_heads) == 0
+
+         self.n_heads = n_heads
+         self.n_embds = dimin // bottleneck_factor
+         self.dimin = dimin
+
+         # Key, query, value projections
+         self.k_ctx = nn.Linear(dimin, self.n_embds, bias=bias_k)
+         self.q_target = nn.Linear(dimin, self.n_embds, bias=bias_q)
+         self.v_ctx = nn.Linear(dimin, dimin, bias=bias_v)
+         self.c_proj = nn.Linear(dimin, dimin, bias=bias_o)
+
+         # Normalize to match scale with representations
+         with torch.no_grad():
+             scaling = 1 / math.sqrt(self.n_embds // self.n_heads)
+             self.k_ctx.weight.copy_(
+                 scaling
+                 * self.k_ctx.weight
+                 / (1e-6 + torch.linalg.norm(self.k_ctx.weight, dim=1, keepdim=True))
+             )
+             self.q_target.weight.copy_(
+                 scaling
+                 * self.q_target.weight
+                 / (1e-6 + torch.linalg.norm(self.q_target.weight, dim=1, keepdim=True))
+             )
+
+             scaling = 1 / math.sqrt(self.dimin // self.n_heads)
+             self.v_ctx.weight.copy_(
+                 scaling
+                 * self.v_ctx.weight
+                 / (1e-6 + torch.linalg.norm(self.v_ctx.weight, dim=1, keepdim=True))
+             )
+
+             scaling = 1 / math.sqrt(self.dimin)
+             self.c_proj.weight.copy_(
+                 scaling
+                 * self.c_proj.weight
+                 / (1e-6 + torch.linalg.norm(self.c_proj.weight, dim=1, keepdim=True))
+             )
+
+     def forward(
+         self, x_ctx: torch.Tensor, x_target: torch.Tensor, get_attn_map: bool = False
+     ) -> tuple[torch.Tensor, torch.Tensor | None]:
+         """Compute projective attention output."""
+         k = self.k_ctx(x_ctx)
+         v = self.v_ctx(x_ctx)
+         q = self.q_target(x_target)
+
+         # Split into heads
+         B, T, _ = x_ctx.size()
+         k = k.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
+         q = q.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
+         v = v.view(B, T, self.n_heads, self.dimin // self.n_heads).transpose(1, 2)
+
+         # Attention map (optional)
+         attn_map = None
+         if get_attn_map:
+             attn_map = get_attention(query=q, key=k)
+
+         # Scaled dot-product attention
+         attn_output = torch.nn.functional.scaled_dot_product_attention(
+             q, k, v, attn_mask=None, dropout_p=0, is_causal=True
+         )
+
+         # Reshape and project
+         d_target = self.c_proj(
+             attn_output.transpose(1, 2).contiguous().view(B, T, self.dimin)
+         )
+
+         return d_target, attn_map
+
+
+ @dataclass
+ class TemporalSAEConfig(SAEConfig):
+     """Configuration for TemporalSAE inference.
+
+     Args:
+         d_in: Input dimension (dimensionality of the activations being encoded)
+         d_sae: SAE latent dimension (number of features)
+         n_heads: Number of attention heads in temporal attention
+         n_attn_layers: Number of attention layers
+         bottleneck_factor: Bottleneck factor for attention dimension
+         sae_diff_type: Type of SAE for novel codes ('relu' or 'topk')
+         kval_topk: K value for top-k sparsity (if sae_diff_type='topk')
+         tied_weights: Whether to tie encoder and decoder weights
+         activation_normalization_factor: Scalar factor for rescaling activations (used with normalize_activations='constant_scalar_rescale')
+     """
+
+     n_heads: int = 8
+     n_attn_layers: int = 1
+     bottleneck_factor: int = 64
+     sae_diff_type: Literal["relu", "topk"] = "topk"
+     kval_topk: int | None = None
+     tied_weights: bool = True
+     activation_normalization_factor: float = 1.0
+
+     def __post_init__(self):
+         # Call parent's __post_init__ first, but allow constant_scalar_rescale
+         if self.normalize_activations not in [
+             "none",
+             "expected_average_only_in",
+             "constant_norm_rescale",
+             "constant_scalar_rescale",  # Temporal SAEs support this
+             "layer_norm",
+         ]:
+             raise ValueError(
+                 f"normalize_activations must be none, expected_average_only_in, layer_norm, constant_norm_rescale, or constant_scalar_rescale. Got {self.normalize_activations}"
+             )
+
+     @override
+     @classmethod
+     def architecture(cls) -> str:
+         return "temporal"
+
+
+ class TemporalSAE(SAE[TemporalSAEConfig]):
+     """TemporalSAE: Sparse Autoencoder with temporal attention.
+
+     This SAE decomposes each activation x_t into:
+     - x_pred: Information aggregated from context {x_0, ..., x_{t-1}}
+     - x_novel: Novel information at position t (encoded sparsely)
+
+     The forward pass:
+     1. Uses attention layers to predict x_t from context
+     2. Encodes the residual (novel part) with a sparse SAE
+     3. Combines both for reconstruction
+     """
+
+     # Custom parameters (in addition to W_enc, W_dec, b_dec from base)
+     attn_layers: nn.ModuleList  # Attention layers
+     eps: float
+     lam: float
+
+     def __init__(self, cfg: TemporalSAEConfig, use_error_term: bool = False):
+         # Call parent init first
+         super().__init__(cfg, use_error_term)
+
+         # Initialize attention layers after parent init and move to correct device
+         self.attn_layers = nn.ModuleList(
+             [
+                 ManualAttention(
+                     dimin=cfg.d_sae,
+                     n_heads=cfg.n_heads,
+                     bottleneck_factor=cfg.bottleneck_factor,
+                     bias_k=True,
+                     bias_q=True,
+                     bias_v=True,
+                     bias_o=True,
+                 ).to(device=self.device, dtype=self.dtype)
+                 for _ in range(cfg.n_attn_layers)
+             ]
+         )
+
+         self.eps = 1e-6
+         self.lam = 1 / (4 * self.cfg.d_in)
+
+     @override
+     def _setup_activation_normalization(self):
+         """Set up activation normalization functions for TemporalSAE.
+
+         Overrides the base implementation to handle constant_scalar_rescale
+         using the temporal-specific activation_normalization_factor.
+         """
+         if self.cfg.normalize_activations == "constant_scalar_rescale":
+             # Handle constant scalar rescaling for temporal SAEs
+             def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
+                 return x * self.cfg.activation_normalization_factor
+
+             def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
+                 return x / self.cfg.activation_normalization_factor
+
+             self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
+             self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+         else:
+             # Delegate to parent for all other normalization types
+             super()._setup_activation_normalization()
+
+     @override
+     def initialize_weights(self) -> None:
+         """Initialize TemporalSAE weights."""
+         # Initialize D (decoder) and b (bias)
+         self.W_dec = nn.Parameter(
+             torch.randn(
+                 (self.cfg.d_sae, self.cfg.d_in), dtype=self.dtype, device=self.device
+             )
+         )
+         self.b_dec = nn.Parameter(
+             torch.zeros((self.cfg.d_in), dtype=self.dtype, device=self.device)
+         )
+
+         # Initialize E (encoder) if not tied
+         if not self.cfg.tied_weights:
+             self.W_enc = nn.Parameter(
+                 torch.randn(
+                     (self.cfg.d_in, self.cfg.d_sae),
+                     dtype=self.dtype,
+                     device=self.device,
+                 )
+             )
+
+     def encode_with_predictions(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+         """Encode input into sparse novel codes and attention-predicted codes.
+
+         Returns a tuple (z_novel, z_pred). The novel codes are the main
+         interpretable feature representation for TemporalSAE.
+         """
+         # Process input through SAELens preprocessing
+         x = self.process_sae_in(x)
+
+         B, L, _ = x.shape
+
+         if self.cfg.tied_weights:  # noqa: SIM108
+             W_enc = self.W_dec.T
+         else:
+             W_enc = self.W_enc
+
+         # Compute predicted codes using attention
+         x_residual = x
+         z_pred = torch.zeros((B, L, self.cfg.d_sae), device=x.device, dtype=x.dtype)
+
+         for attn_layer in self.attn_layers:
+             # Encode input to latent space
+             z_input = F.relu(torch.matmul(x_residual * self.lam, W_enc))
+
+             # Shift context (causal masking)
+             z_ctx = torch.cat(
+                 (torch.zeros_like(z_input[:, :1, :]), z_input[:, :-1, :].clone()), dim=1
+             )
+
+             # Apply attention to get predicted codes
+             z_pred_, _ = attn_layer(z_ctx, z_input, get_attn_map=False)
+             z_pred_ = F.relu(z_pred_)
+
+             # Project predicted codes back to input space
+             Dz_pred_ = torch.matmul(z_pred_, self.W_dec)
+             Dz_norm_ = Dz_pred_.norm(dim=-1, keepdim=True) + self.eps
+
+             # Compute projection scale
+             proj_scale = (Dz_pred_ * x_residual).sum(
+                 dim=-1, keepdim=True
+             ) / Dz_norm_.pow(2)
+
+             # Accumulate predicted codes
+             z_pred = z_pred + (z_pred_ * proj_scale)
+
+             # Remove prediction from residual
+             x_residual = x_residual - proj_scale * Dz_pred_
+
+         # Encode residual (novel part) with sparse SAE
+         z_novel = F.relu(torch.matmul(x_residual * self.lam, W_enc))
+         if self.cfg.sae_diff_type == "topk":
+             kval = self.cfg.kval_topk
+             if kval is not None:
+                 _, topk_indices = torch.topk(z_novel, kval, dim=-1)
+                 mask = torch.zeros_like(z_novel)
+                 mask.scatter_(-1, topk_indices, 1)
+                 z_novel = z_novel * mask
+
+         # Return novel codes (the interpretable features) alongside predicted codes
+         return z_novel, z_pred
+
+     def encode(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> Float[torch.Tensor, "... d_sae"]:
+         return self.encode_with_predictions(x)[0]
+
+     def decode(
+         self, feature_acts: Float[torch.Tensor, "... d_sae"]
+     ) -> Float[torch.Tensor, "... d_in"]:
+         """Decode novel codes to reconstruction.
+
+         Note: This only decodes the novel codes. For full reconstruction,
+         use forward() which includes predicted codes.
+         """
+         # Decode novel codes
+         sae_out = torch.matmul(feature_acts, self.W_dec)
+         sae_out = sae_out + self.b_dec
+
+         # Apply hook
+         sae_out = self.hook_sae_recons(sae_out)
+
+         # Apply output activation normalization (reverses input normalization)
+         sae_out = self.run_time_activation_norm_fn_out(sae_out)
+
+         # Warn that x_pred is not included, so this is only a partial reconstruction
+         logger.warning(
+             "NOTE this only decodes x_novel. The x_pred is missing, so we're not reconstructing the full x."
+         )
+         return sae_out
+
+     @override
+     def forward(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> Float[torch.Tensor, "... d_in"]:
+         """Full forward pass through TemporalSAE.
+
+         Returns complete reconstruction (predicted + novel).
+         """
+         # Encode
+         z_novel, z_pred = self.encode_with_predictions(x)
+
+         # Decode the sum of predicted and novel codes.
+         x_recons = torch.matmul(z_novel + z_pred, self.W_dec) + self.b_dec
+
+         # Apply output activation normalization (reverses input normalization)
+         x_recons = self.run_time_activation_norm_fn_out(x_recons)
+
+         return self.hook_sae_output(x_recons)
+
+     @override
+     def fold_W_dec_norm(self) -> None:
+         raise NotImplementedError("Folding W_dec_norm is not supported for TemporalSAE")
+
+     @override
+     @torch.no_grad()
+     def fold_activation_norm_scaling_factor(self, scaling_factor: float) -> None:
+         raise NotImplementedError(
+             "Folding activation norm scaling factor is not supported for TemporalSAE"
+         )
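For orientation, a minimal usage sketch of the new module (illustrative commentary, not part of the diff): the import path, the example dimensions, and the assumption that the remaining SAEConfig fields can stay at their defaults are all guesses; TemporalSAE, TemporalSAEConfig, encode_with_predictions, and forward are the APIs defined above.

import torch
from sae_lens.saes.temporal_sae import TemporalSAE, TemporalSAEConfig  # module path assumed

# d_sae must be divisible by n_heads * bottleneck_factor (8 * 64 = 512 here).
cfg = TemporalSAEConfig(d_in=768, d_sae=16 * 768, kval_topk=32)
sae = TemporalSAE(cfg)

acts = torch.randn(2, 128, cfg.d_in)  # (batch, seq, d_in) activations
z_novel, z_pred = sae.encode_with_predictions(acts)  # sparse novel codes + predicted codes
recon = sae(acts)  # forward() decodes z_novel + z_pred (plus b_dec)
print(z_novel.shape, z_pred.shape, recon.shape)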
@@ -1,5 +1,6 @@
  import json
  from dataclasses import dataclass
+ from pathlib import Path
  from statistics import mean

  import torch
@@ -51,3 +52,9 @@ class ActivationScaler:

          with open(file_path, "w") as f:
              json.dump({"scaling_factor": self.scaling_factor}, f)
+
+     def load(self, file_path: str | Path):
+         """load the state dict from a file in json format"""
+         with open(file_path) as f:
+             data = json.load(f)
+         self.scaling_factor = data["scaling_factor"]
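A small round trip with the new method, for reference; this sketch assumes ActivationScaler can be constructed directly with a default scaling_factor and that the import path below is correct.

from sae_lens.training.activation_scaler import ActivationScaler  # import path assumed

scaler = ActivationScaler()  # assumed: constructible with scaling_factor defaulting to None
scaler.scaling_factor = 2.5
scaler.save("activation_scaler.json")

restored = ActivationScaler()
restored.load("activation_scaler.json")  # reads back {"scaling_factor": ...} written by save()
assert restored.scaling_factor == 2.5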
@@ -4,6 +4,7 @@ import json
  import os
  import warnings
  from collections.abc import Generator, Iterator, Sequence
+ from pathlib import Path
  from typing import Any, Literal, cast

  import datasets
@@ -13,8 +14,8 @@ from huggingface_hub import hf_hub_download
  from huggingface_hub.utils import HfHubHTTPError
  from jaxtyping import Float, Int
  from requests import HTTPError
- from safetensors.torch import save_file
- from tqdm import tqdm
+ from safetensors.torch import load_file, save_file
+ from tqdm.auto import tqdm
  from transformer_lens.hook_points import HookedRootModule
  from transformers import AutoTokenizer, PreTrainedTokenizerBase

@@ -24,7 +25,7 @@ from sae_lens.config import (
      HfDataset,
      LanguageModelSAERunnerConfig,
  )
- from sae_lens.constants import DTYPE_MAP
+ from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, DTYPE_MAP
  from sae_lens.pretokenize_runner import get_special_token_from_cfg
  from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
  from sae_lens.tokenization_and_batching import concat_and_batch_sequences
@@ -318,7 +319,7 @@ class ActivationsStore:
              )
          else:
              warnings.warn(
-                 "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
+                 "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://decoderesearch.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
              )

          self.iterable_sequences = self._iterate_tokenized_sequences()
@@ -729,6 +730,48 @@ class ActivationsStore:
          """save the state dict to a file in safetensors format"""
          save_file(self.state_dict(), file_path)

+     def save_to_checkpoint(self, checkpoint_path: str | Path):
+         """Save the state dict to a checkpoint path"""
+         self.save(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+     def load_from_checkpoint(self, checkpoint_path: str | Path):
+         """Load the state dict from a checkpoint path"""
+         self.load(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+     def load(self, file_path: str):
+         """Load the state dict from a file in safetensors format"""
+
+         state_dict = load_file(file_path)
+
+         if "n_dataset_processed" in state_dict:
+             target_n_dataset_processed = state_dict["n_dataset_processed"].item()
+
+             # Only fast-forward if needed
+
+             if target_n_dataset_processed > self.n_dataset_processed:
+                 logger.info(
+                     "Fast-forwarding through dataset samples to match checkpoint position"
+                 )
+                 samples_to_skip = target_n_dataset_processed - self.n_dataset_processed
+
+                 pbar = tqdm(
+                     total=samples_to_skip,
+                     desc="Fast-forwarding through dataset",
+                     leave=False,
+                 )
+                 while target_n_dataset_processed > self.n_dataset_processed:
+                     start = self.n_dataset_processed
+                     try:
+                         # Just consume and ignore the values to fast-forward
+                         next(self.iterable_sequences)
+                     except StopIteration:
+                         logger.warning(
+                             "Dataset exhausted during fast-forward. Resetting dataset."
+                         )
+                         self.iterable_sequences = self._iterate_tokenized_sequences()
+                     pbar.update(self.n_dataset_processed - start)
+                 pbar.close()
+

  def validate_pretokenized_dataset_tokenizer(
      dataset_path: str, model_tokenizer: PreTrainedTokenizerBase
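The two checkpoint helpers and the new load method are meant to be used together when resuming a run. A hedged sketch follows (the store construction is elided and the checkpoint directory name is illustrative; only the methods added above are exercised).

from pathlib import Path

def checkpoint_and_resume(store, checkpoint_dir: str | Path) -> None:
    ckpt = Path(checkpoint_dir)
    # Writes the store state (including n_dataset_processed) into the checkpoint directory.
    store.save_to_checkpoint(ckpt)
    # On load, a store that is behind the checkpoint consumes its dataset iterator until
    # n_dataset_processed catches up, resetting the iterator if the dataset is exhausted.
    store.load_from_checkpoint(ckpt)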
@@ -2,6 +2,8 @@
  Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
  """

+ from typing import Any
+
  import torch.optim as optim
  import torch.optim.lr_scheduler as lr_scheduler

@@ -150,3 +152,12 @@ class CoefficientScheduler:
      def value(self) -> float:
          """Returns the current scalar value."""
          return self.current_value
+
+     def state_dict(self) -> dict[str, Any]:
+         return {
+             "current_step": self.current_step,
+         }
+
+     def load_state_dict(self, state_dict: dict[str, Any]):
+         for k in state_dict:
+             setattr(self, k, state_dict[k])
@@ -1,4 +1,5 @@
  import contextlib
+ import math
  from pathlib import Path
  from typing import Any, Callable, Generic, Protocol

@@ -10,7 +11,11 @@ from tqdm.auto import tqdm

  from sae_lens import __version__
  from sae_lens.config import SAETrainerConfig
- from sae_lens.constants import ACTIVATION_SCALER_CFG_FILENAME, SPARSITY_FILENAME
+ from sae_lens.constants import (
+     ACTIVATION_SCALER_CFG_FILENAME,
+     SPARSITY_FILENAME,
+     TRAINER_STATE_FILENAME,
+ )
  from sae_lens.saes.sae import (
      T_TRAINING_SAE,
      T_TRAINING_SAE_CONFIG,
@@ -56,6 +61,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
      data_provider: DataProvider
      activation_scaler: ActivationScaler
      evaluator: Evaluator[T_TRAINING_SAE] | None
+     coefficient_schedulers: dict[str, CoefficientScheduler]

      def __init__(
          self,
@@ -84,7 +90,9 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
                  range(
                      0,
                      cfg.total_training_samples,
-                     cfg.total_training_samples // self.cfg.n_checkpoints,
+                     math.ceil(
+                         cfg.total_training_samples / (self.cfg.n_checkpoints + 1)
+                     ),
                  )
              )[1:]

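Switching from floor division by n_checkpoints to math.ceil over n_checkpoints + 1 changes how many thresholds survive the trailing [1:] slice; a worked example with illustrative numbers:

import math

total, n_checkpoints = 100, 4
old = list(range(0, total, total // n_checkpoints))[1:]  # [25, 50, 75] -> only 3 checkpoints
new = list(range(0, total, math.ceil(total / (n_checkpoints + 1))))[1:]  # [20, 40, 60, 80] -> 4 checkpoints
print(old, new)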
@@ -93,11 +101,6 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
              sae.cfg.d_sae, device=cfg.device
          )
          self.n_frac_active_samples = 0
-         # we don't train the scaling factor (initially)
-         # set requires grad to false for the scaling factor
-         for name, param in self.sae.named_parameters():
-             if "scaling_factor" in name:
-                 param.requires_grad = False

          self.optimizer = Adam(
              sae.parameters(),
@@ -210,10 +213,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
          sparsity_path = checkpoint_path / SPARSITY_FILENAME
          save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)

-         activation_scaler_path = (
-             checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
-         )
-         self.activation_scaler.save(str(activation_scaler_path))
+         self.save_trainer_state(checkpoint_path)

          if self.cfg.logger.log_to_wandb:
              self.cfg.logger.log(
@@ -227,6 +227,44 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
          if self.save_checkpoint_fn is not None:
              self.save_checkpoint_fn(checkpoint_path=checkpoint_path)

+     def save_trainer_state(self, checkpoint_path: Path) -> None:
+         checkpoint_path.mkdir(exist_ok=True, parents=True)
+         scheduler_state_dicts = {
+             name: scheduler.state_dict()
+             for name, scheduler in self.coefficient_schedulers.items()
+         }
+         torch.save(
+             {
+                 "optimizer": self.optimizer.state_dict(),
+                 "lr_scheduler": self.lr_scheduler.state_dict(),
+                 "n_training_samples": self.n_training_samples,
+                 "n_training_steps": self.n_training_steps,
+                 "act_freq_scores": self.act_freq_scores,
+                 "n_forward_passes_since_fired": self.n_forward_passes_since_fired,
+                 "n_frac_active_samples": self.n_frac_active_samples,
+                 "started_fine_tuning": self.started_fine_tuning,
+                 "coefficient_schedulers": scheduler_state_dicts,
+             },
+             str(checkpoint_path / TRAINER_STATE_FILENAME),
+         )
+         activation_scaler_path = checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
+         self.activation_scaler.save(str(activation_scaler_path))
+
+     def load_trainer_state(self, checkpoint_path: Path | str) -> None:
+         checkpoint_path = Path(checkpoint_path)
+         self.activation_scaler.load(checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME)
+         state_dict = torch.load(checkpoint_path / TRAINER_STATE_FILENAME)
+         self.optimizer.load_state_dict(state_dict["optimizer"])
+         self.lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
+         self.n_training_samples = state_dict["n_training_samples"]
+         self.n_training_steps = state_dict["n_training_steps"]
+         self.act_freq_scores = state_dict["act_freq_scores"]
+         self.n_forward_passes_since_fired = state_dict["n_forward_passes_since_fired"]
+         self.n_frac_active_samples = state_dict["n_frac_active_samples"]
+         self.started_fine_tuning = state_dict["started_fine_tuning"]
+         for name, scheduler_state_dict in state_dict["coefficient_schedulers"].items():
+             self.coefficient_schedulers[name].load_state_dict(scheduler_state_dict)
+

      def _train_step(
          self,
          sae: T_TRAINING_SAE,
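Together with the ActivationsStore changes earlier in this diff, these methods make it possible to resume a run from a checkpoint directory. A hedged sketch, assuming the caller has already rebuilt an identically configured trainer and store by whatever means the training script normally uses:

from pathlib import Path

def resume_from_checkpoint(trainer, activations_store, checkpoint_dir: str | Path) -> None:
    ckpt = Path(checkpoint_dir)
    trainer.load_trainer_state(ckpt)  # optimizer, LR and coefficient schedulers, counters
    activations_store.load_from_checkpoint(ckpt)  # dataset position (fast-forwards if needed)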