sae-lens 6.12.1__py3-none-any.whl → 6.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/saes/topk_sae.py CHANGED
@@ -1,11 +1,12 @@
 """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""

 from dataclasses import dataclass
-from typing import Callable
+from typing import Any, Callable

 import torch
 from jaxtyping import Float
 from torch import nn
+from transformer_lens.hook_points import HookPoint
 from typing_extensions import override

 from sae_lens.saes.sae import (
@@ -15,44 +16,138 @@ from sae_lens.saes.sae import (
     TrainingSAE,
     TrainingSAEConfig,
     TrainStepInput,
+    _disable_hooks,
 )


+class SparseHookPoint(HookPoint):
+    """
+    A HookPoint that takes in a sparse tensor.
+    Overrides TransformerLens's HookPoint.
+    """
+
+    def __init__(self, d_sae: int):
+        super().__init__()
+        self.d_sae = d_sae
+
+    @override
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        using_hooks = (
+            self._forward_hooks is not None and len(self._forward_hooks) > 0
+        ) or (self._backward_hooks is not None and len(self._backward_hooks) > 0)
+        if using_hooks and x.is_sparse:
+            return x.to_dense()
+        return x  # if no hooks are being used, use passthrough
+
+
 class TopK(nn.Module):
     """
     A simple TopK activation that zeroes out all but the top K elements along the last dimension,
     and applies ReLU to the top K elements.
     """

-    b_enc: nn.Parameter
+    use_sparse_activations: bool

     def __init__(
         self,
         k: int,
+        use_sparse_activations: bool = False,
     ):
         super().__init__()
         self.k = k
+        self.use_sparse_activations = use_sparse_activations

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
         """
         1) Select top K elements along the last dimension.
         2) Apply ReLU.
         3) Zero out all other entries.
         """
-        topk = torch.topk(x, k=self.k, dim=-1)
-        values = topk.values.relu()
+        topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1, sorted=False)
+        values = topk_values.relu()
+        if self.use_sparse_activations:
+            # Produce a COO sparse tensor (use sparse matrix multiply in decode)
+            original_shape = x.shape
+
+            # Create indices for all dimensions
+            # For each element in topk_indices, we need to map it back to the original tensor coordinates
+            batch_dims = original_shape[:-1]  # All dimensions except the last one
+            num_batch_elements = torch.prod(torch.tensor(batch_dims)).item()
+
+            # Create batch indices - each batch element repeated k times
+            batch_indices_flat = torch.arange(
+                num_batch_elements, device=x.device
+            ).repeat_interleave(self.k)
+
+            # Convert flat batch indices back to multi-dimensional indices
+            if len(batch_dims) == 1:
+                # 2D case: [batch, features]
+                sparse_indices = torch.stack(
+                    [
+                        batch_indices_flat,
+                        topk_indices.flatten(),
+                    ]
+                )
+            else:
+                # 3D+ case: need to unravel the batch indices
+                batch_indices_multi = []
+                remaining = batch_indices_flat
+                for dim_size in reversed(batch_dims):
+                    batch_indices_multi.append(remaining % dim_size)
+                    remaining = remaining // dim_size
+                batch_indices_multi.reverse()
+
+                sparse_indices = torch.stack(
+                    [
+                        *batch_indices_multi,
+                        topk_indices.flatten(),
+                    ]
+                )
+
+            return torch.sparse_coo_tensor(
+                sparse_indices, values.flatten(), original_shape
+            )
         result = torch.zeros_like(x)
-        result.scatter_(-1, topk.indices, values)
+        result.scatter_(-1, topk_indices, values)
         return result


 @dataclass
 class TopKSAEConfig(SAEConfig):
     """
-    Configuration class for a TopKSAE.
+    Configuration class for TopKSAE inference.
+
+    Args:
+        k (int): Number of top features to keep active during inference. Only the top k
+            features with the highest pre-activations will be non-zero. Defaults to 100.
+        rescale_acts_by_decoder_norm (bool): Whether to treat the decoder as if it was
+            already normalized. This affects the topk selection by rescaling pre-activations
+            by decoder norms. Requires that the SAE was trained this way. Defaults to False.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
+            before encoding. Inherited from SAEConfig. Defaults to True.
+        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
+            Normalization strategy for input activations. Inherited from SAEConfig.
+            Defaults to "none".
+        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
+            (useful for attention head outputs). Inherited from SAEConfig.
+            Defaults to "none".
+        metadata (SAEMetadata): Metadata about the SAE (model name, hook name, etc.).
+            Inherited from SAEConfig.
     """

     k: int = 100
+    rescale_acts_by_decoder_norm: bool = False

     @override
     @classmethod
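
Note on the sparse path above: when `use_sparse_activations=True`, `TopK.forward` returns the same values as the dense scatter path, just packed into a COO tensor whose only stored entries are the k winners per position. A minimal sketch of that equivalence, with made-up tensor sizes (the import path is the file shown in this diff):

    import torch
    from sae_lens.saes.topk_sae import TopK

    acts = torch.randn(2, 3, 16)  # hypothetical [batch, seq, d_sae] pre-activations

    sparse_out = TopK(k=4, use_sparse_activations=True)(acts)
    dense_out = TopK(k=4, use_sparse_activations=False)(acts)

    assert sparse_out.is_sparse  # COO tensor with 2 * 3 * 4 stored values
    assert torch.allclose(sparse_out.to_dense(), dense_out)
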
@@ -60,6 +155,63 @@ class TopKSAEConfig(SAEConfig):
         return "topk"


+def _sparse_matmul_nd(
+    sparse_tensor: torch.Tensor, dense_matrix: torch.Tensor
+) -> torch.Tensor:
+    """
+    Multiply a sparse tensor of shape [..., d_sae] with a dense matrix of shape [d_sae, d_out]
+    to get a result of shape [..., d_out].
+
+    This function handles sparse tensors with arbitrary batch dimensions by flattening
+    the batch dimensions, performing 2D sparse matrix multiplication, and reshaping back.
+    """
+    original_shape = sparse_tensor.shape
+    batch_dims = original_shape[:-1]
+    d_sae = original_shape[-1]
+    d_out = dense_matrix.shape[-1]
+
+    if sparse_tensor.ndim == 2:
+        # Simple 2D case - use torch.sparse.mm directly
+        # sparse.mm errors with bfloat16 :(
+        with torch.autocast(device_type=sparse_tensor.device.type, enabled=False):
+            return torch.sparse.mm(sparse_tensor, dense_matrix)
+
+    # For 3D+ case, reshape to 2D, multiply, then reshape back
+    batch_size = int(torch.prod(torch.tensor(batch_dims)).item())
+
+    # Ensure tensor is coalesced for efficient access to indices/values
+    if not sparse_tensor.is_coalesced():
+        sparse_tensor = sparse_tensor.coalesce()
+
+    # Get indices and values
+    indices = sparse_tensor.indices()  # [ndim, nnz]
+    values = sparse_tensor.values()  # [nnz]
+
+    # Convert multi-dimensional batch indices to flat indices
+    flat_batch_indices = torch.zeros_like(indices[0])
+    multiplier = 1
+    for i in reversed(range(len(batch_dims))):
+        flat_batch_indices += indices[i] * multiplier
+        multiplier *= batch_dims[i]
+
+    # Create 2D sparse tensor indices [batch_flat, feature]
+    sparse_2d_indices = torch.stack([flat_batch_indices, indices[-1]])
+
+    # Create 2D sparse tensor
+    sparse_2d = torch.sparse_coo_tensor(
+        sparse_2d_indices, values, (batch_size, d_sae)
+    ).coalesce()
+
+    # sparse.mm errors with bfloat16 :(
+    with torch.autocast(device_type=sparse_tensor.device.type, enabled=False):
+        # Do the matrix multiplication
+        result_2d = torch.sparse.mm(sparse_2d, dense_matrix)  # [batch_size, d_out]
+
+    # Reshape back to original batch dimensions
+    result_shape = tuple(batch_dims) + (d_out,)
+    return result_2d.view(result_shape)
+
+
 class TopKSAE(SAE[TopKSAEConfig]):
     """
     An inference-only sparse autoencoder using a "topk" activation function.
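
A quick way to sanity-check `_sparse_matmul_nd` is to compare it against the dense matmul it replaces on a small 3D batch; the shapes below are made up for illustration:

    import torch
    from sae_lens.saes.topk_sae import _sparse_matmul_nd

    dense_acts = torch.relu(torch.randn(2, 3, 8)) * (torch.rand(2, 3, 8) > 0.7)
    W_dec = torch.randn(8, 4)  # hypothetical [d_sae, d_in] decoder

    # Flattens [2, 3, 8] -> [6, 8], runs torch.sparse.mm, reshapes back to [2, 3, 4].
    out_sparse = _sparse_matmul_nd(dense_acts.to_sparse(), W_dec)

    assert torch.allclose(out_sparse, dense_acts @ W_dec, atol=1e-5)
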
@@ -92,42 +244,91 @@ class TopKSAE(SAE[TopKSAEConfig]):
         """
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+        if self.cfg.rescale_acts_by_decoder_norm:
+            hidden_pre = hidden_pre * self.W_dec.norm(dim=-1)
         # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

     def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+        self,
+        feature_acts: Float[torch.Tensor, "... d_sae"],
     ) -> Float[torch.Tensor, "... d_in"]:
         """
         Reconstructs the input from topk feature activations.
         Applies optional finetuning scaling, hooking to recons, out normalization,
         and optional head reshaping.
         """
-        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        # Handle sparse tensors using efficient sparse matrix multiplication
+        if self.cfg.rescale_acts_by_decoder_norm:
+            feature_acts = feature_acts / self.W_dec.norm(dim=-1)
+        if feature_acts.is_sparse:
+            sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
+        else:
+            sae_out_pre = feature_acts @ self.W_dec + self.b_dec
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
         return self.reshape_fn_out(sae_out_pre, self.d_head)

     @override
     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
-        return TopK(self.cfg.k)
+        return TopK(self.cfg.k, use_sparse_activations=False)

     @override
     @torch.no_grad()
     def fold_W_dec_norm(self) -> None:
-        raise NotImplementedError(
-            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
-        )
+        if not self.cfg.rescale_acts_by_decoder_norm:
+            raise NotImplementedError(
+                "Folding W_dec_norm is not safe for TopKSAEs when rescale_acts_by_decoder_norm is False, as this may change the topk activations"
+            )
+        _fold_norm_topk(W_dec=self.W_dec, b_enc=self.b_enc, W_enc=self.W_enc)


 @dataclass
 class TopKTrainingSAEConfig(TrainingSAEConfig):
     """
     Configuration class for training a TopKTrainingSAE.
+
+    Args:
+        k (int): Number of top features to keep active. Only the top k features
+            with the highest pre-activations will be non-zero. Defaults to 100.
+        use_sparse_activations (bool): Whether to use sparse tensor representations
+            for activations during training. This can reduce memory usage and improve
+            performance when k is small relative to d_sae, but is only worthwhile if
+            using float32 and not using autocast. Defaults to False.
+        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
+            dead neurons to learn useful features. This loss helps prevent neuron death
+            in TopK SAEs by having dead neurons reconstruct the residual error from
+            live neurons. Defaults to 1.0.
+        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
+            This is a good idea since decoder norm can randomly drift during training, and this
+            affects what the topk activations will be. Defaults to True.
+        decoder_init_norm (float | None): Norm to initialize decoder weights to.
+            0.1 corresponds to the "heuristic" initialization from Anthropic's April update.
+            Use None to disable. Inherited from TrainingSAEConfig. Defaults to 0.1.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
+            before encoding. Inherited from SAEConfig. Defaults to True.
+        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
+            Normalization strategy for input activations. Inherited from SAEConfig.
+            Defaults to "none".
+        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
+            (useful for attention head outputs). Inherited from SAEConfig.
+            Defaults to "none".
+        metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.).
+            Inherited from SAEConfig.
     """

     k: int = 100
+    use_sparse_activations: bool = False
     aux_loss_coefficient: float = 1.0
+    rescale_acts_by_decoder_norm: bool = True

     @override
     @classmethod
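
The interaction between `rescale_acts_by_decoder_norm` and `fold_W_dec_norm` comes down to a simple identity: multiplying the pre-activations by the decoder row norms in `encode` and dividing them back out in `decode` reproduces exactly what you would get by folding those norms into the weights. A hedged sketch of that algebra on plain tensors (hypothetical shapes; the topk/ReLU step is omitted because it sees identical pre-activations either way):

    import torch

    d_in, d_sae = 4, 8
    W_enc, b_enc = torch.randn(d_in, d_sae), torch.randn(d_sae)
    W_dec, b_dec = torch.randn(d_sae, d_in), torch.randn(d_in)
    x = torch.randn(d_in)
    norms = W_dec.norm(dim=-1)

    # What encode/decode do when rescale_acts_by_decoder_norm=True:
    acts_rescaled = (x @ W_enc + b_enc) * norms
    recon = (acts_rescaled / norms) @ W_dec + b_dec

    # Equivalent form with the norms folded into the weights (what _fold_norm_topk produces):
    acts_folded = x @ (W_enc * norms) + b_enc * norms
    recon_folded = acts_folded @ (W_dec / norms.unsqueeze(1)) + b_dec

    assert torch.allclose(acts_rescaled, acts_folded, atol=1e-5)
    assert torch.allclose(recon, recon_folded, atol=1e-5)
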
@@ -144,6 +345,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):

     def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
         super().__init__(cfg, use_error_term)
+        self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae)
+        self.setup()

     @override
     def initialize_weights(self) -> None:
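
`SparseHookPoint` only densifies when someone is actually listening: with no hooks registered, sparse activations pass through untouched, and once a hook is attached the hook function receives an ordinary dense tensor. A hedged sketch of that behavior on a standalone hook point with made-up shapes:

    import torch
    from sae_lens.saes.topk_sae import SparseHookPoint

    hook_point = SparseHookPoint(d_sae=16)
    sparse_acts = torch.randn(2, 16).to_sparse()

    # No hooks registered: passthrough, stays sparse.
    assert hook_point(sparse_acts).is_sparse

    seen = {}

    def remember_shape(tensor, hook):  # TransformerLens-style hook signature
        seen["shape"] = tuple(tensor.shape)

    hook_point.add_hook(remember_shape)

    # With a hook attached, the input is densified before the hook runs.
    assert not hook_point(sparse_acts).is_sparse
    assert seen["shape"] == (2, 16)
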
@@ -159,10 +362,51 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

+        if self.cfg.rescale_acts_by_decoder_norm:
+            hidden_pre = hidden_pre * self.W_dec.norm(dim=-1)
+
         # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
         return feature_acts, hidden_pre

+    @override
+    def decode(
+        self,
+        feature_acts: Float[torch.Tensor, "... d_sae"],
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decodes feature activations back into input space,
+        applying optional finetuning scale, hooking, out normalization, etc.
+        """
+        # Handle sparse tensors using efficient sparse matrix multiplication
+        if self.cfg.rescale_acts_by_decoder_norm:
+            # need to multiply by the inverse of the norm because division is illegal with sparse tensors
+            feature_acts = feature_acts * (1 / self.W_dec.norm(dim=-1))
+        if feature_acts.is_sparse:
+            sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
+        else:
+            sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        sae_out_pre = self.hook_sae_recons(sae_out_pre)
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+    @override
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass through the SAE."""
+        feature_acts = self.encode(x)
+        sae_out = self.decode(feature_acts)
+
+        if self.use_error_term:
+            with torch.no_grad():
+                # Recompute without hooks for true error term
+                with _disable_hooks(self):
+                    feature_acts_clean = self.encode(x)
+                    x_reconstruct_clean = self.decode(feature_acts_clean)
+                sae_error = self.hook_sae_error(x - x_reconstruct_clean)
+            sae_out = sae_out + sae_error
+
+        return self.hook_sae_output(sae_out)
+
     @override
     def calculate_aux_loss(
         self,
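
The overridden `forward` follows the same error-term recipe as the base `SAE`: the clean reconstruction is recomputed with hooks disabled, and the difference to the input is added back outside the gradient graph. One observable consequence, sketched below with hypothetical config values (and assuming the remaining config fields all have defaults): with `use_error_term=True` and no hook interventions, the output simply equals the input.

    import torch

    cfg = TopKTrainingSAEConfig(d_in=16, d_sae=64, k=8)  # hypothetical sizes
    sae = TopKTrainingSAE(cfg, use_error_term=True)

    x = torch.randn(4, 16)
    out = sae(x)  # decode(encode(x)) + (x - clean reconstruction) == x
    assert torch.allclose(out, x, atol=1e-4)
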
@@ -183,13 +427,15 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     @override
     @torch.no_grad()
     def fold_W_dec_norm(self) -> None:
-        raise NotImplementedError(
-            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
-        )
+        if not self.cfg.rescale_acts_by_decoder_norm:
+            raise NotImplementedError(
+                "Folding W_dec_norm is not safe for TopKSAEs when rescale_acts_by_decoder_norm is False, as this may change the topk activations"
+            )
+        _fold_norm_topk(W_dec=self.W_dec, b_enc=self.b_enc, W_enc=self.W_enc)

     @override
     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
-        return TopK(self.cfg.k)
+        return TopK(self.cfg.k, use_sparse_activations=self.cfg.use_sparse_activations)

     @override
     def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
@@ -234,6 +480,18 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
         return self.cfg.aux_loss_coefficient * scale * auxk_loss

+    @override
+    def process_state_dict_for_saving_inference(
+        self, state_dict: dict[str, Any]
+    ) -> None:
+        super().process_state_dict_for_saving_inference(state_dict)
+        if self.cfg.rescale_acts_by_decoder_norm:
+            _fold_norm_topk(
+                W_enc=state_dict["W_enc"],
+                b_enc=state_dict["b_enc"],
+                W_dec=state_dict["W_dec"],
+            )
+

 def _calculate_topk_aux_acts(
     k_aux: int,
@@ -269,3 +527,15 @@ def _init_weights_topk(
     sae.b_enc = nn.Parameter(
         torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
     )
+
+
+def _fold_norm_topk(
+    W_enc: torch.Tensor,
+    b_enc: torch.Tensor,
+    W_dec: torch.Tensor,
+) -> None:
+    W_dec_norm = W_dec.norm(dim=-1)
+    b_enc.data = b_enc.data * W_dec_norm
+    W_dec_norms = W_dec_norm.unsqueeze(1)
+    W_dec.data = W_dec.data / W_dec_norms
+    W_enc.data = W_enc.data * W_dec_norms.T
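
`_fold_norm_topk` is the in-place counterpart of the runtime rescaling: it moves the decoder row norms into `W_enc` and `b_enc` and leaves `W_dec` with unit-norm rows. A small numeric check on plain tensors with hypothetical sizes:

    import torch
    from sae_lens.saes.topk_sae import _fold_norm_topk

    d_in, d_sae = 4, 8
    W_enc, b_enc, W_dec = torch.randn(d_in, d_sae), torch.randn(d_sae), torch.randn(d_sae, d_in)
    x = torch.randn(d_in)

    # Rescaled pre-activations before folding (what topk sees when rescaling is on)...
    before = (x @ W_enc + b_enc) * W_dec.norm(dim=-1)

    _fold_norm_topk(W_enc=W_enc, b_enc=b_enc, W_dec=W_dec)

    # ...match the raw pre-activations afterwards, and the decoder rows are unit norm.
    assert torch.allclose(before, x @ W_enc + b_enc, atol=1e-5)
    assert torch.allclose(W_dec.norm(dim=-1), torch.ones(d_sae), atol=1e-5)
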
@@ -1,4 +1,4 @@
-from typing import Generator, Iterator
+from collections.abc import Generator, Iterator

 import torch

@@ -68,7 +68,7 @@ def concat_and_batch_sequences(
 ) -> Generator[torch.Tensor, None, None]:
     """
     Generator to concat token sequences together from the tokens_interator, yielding
-    batches of size `context_size`.
+    sequences of size `context_size`. Batching across the batch dimension is handled by the caller.

     Args:
         tokens_iterator: An iterator which returns a 1D tensors of tokens
@@ -76,13 +76,28 @@ def concat_and_batch_sequences(
         begin_batch_token_id: If provided, this token will be at position 0 of each batch
         begin_sequence_token_id: If provided, this token will be the first token of each sequence
         sequence_separator_token_id: If provided, this token will be inserted between concatenated sequences
-        disable_concat_sequences: If True, disable concatenating sequences and ignore sequences shorter than context_size
+        disable_concat_sequences: If True, disable concatenating sequences and ignore sequences shorter than context_size (including BOS token if present)
         max_batches: If not provided, the iterator will be run to completion.
     """
     if disable_concat_sequences:
-        for tokens in tokens_iterator:
-            if len(tokens) >= context_size:
-                yield tokens[:context_size]
+        if begin_batch_token_id and not begin_sequence_token_id:
+            begin_sequence_token_id = begin_batch_token_id
+        for sequence in tokens_iterator:
+            if (
+                begin_sequence_token_id is not None
+                and sequence[0] != begin_sequence_token_id
+                and len(sequence) >= context_size - 1
+            ):
+                begin_sequence_token_id_tensor = torch.tensor(
+                    [begin_sequence_token_id],
+                    dtype=torch.long,
+                    device=sequence.device,
+                )
+                sequence = torch.cat(
+                    [begin_sequence_token_id_tensor, sequence[: context_size - 1]]
+                )
+            if len(sequence) >= context_size:
+                yield sequence[:context_size]
         return

     batch: torch.Tensor | None = None
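
The reworked `disable_concat_sequences` branch does three things: it falls back to `begin_batch_token_id` when no `begin_sequence_token_id` is given, prepends that token to sequences that lack it (truncating to make room), and still drops anything that ends up shorter than `context_size`. A hedged sketch with made-up token ids:

    import torch
    from sae_lens.tokenization_and_batching import concat_and_batch_sequences

    sequences = iter(
        [
            torch.tensor([1, 10, 11, 12, 13]),  # already starts with BOS: only truncated
            torch.tensor([20, 21, 22]),         # no BOS but long enough: BOS prepended
            torch.tensor([30, 31]),             # too short even with BOS: dropped
        ]
    )
    batches = list(
        concat_and_batch_sequences(
            tokens_iterator=sequences,
            context_size=4,
            begin_batch_token_id=1,  # hypothetical BOS id
            begin_sequence_token_id=None,
            sequence_separator_token_id=None,
            disable_concat_sequences=True,
        )
    )
    assert len(batches) == 2
    assert batches[0].tolist() == [1, 10, 11, 12]
    assert batches[1].tolist() == [1, 20, 21, 22]
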
@@ -1,5 +1,6 @@
 import json
 from dataclasses import dataclass
+from pathlib import Path
 from statistics import mean

 import torch
@@ -51,3 +52,9 @@ class ActivationScaler:

         with open(file_path, "w") as f:
             json.dump({"scaling_factor": self.scaling_factor}, f)
+
+    def load(self, file_path: str | Path):
+        """load the state dict from a file in json format"""
+        with open(file_path) as f:
+            data = json.load(f)
+        self.scaling_factor = data["scaling_factor"]
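
The new `load` mirrors `save`, so a scaling factor can round-trip through JSON. A hedged sketch, assuming `ActivationScaler()` can be constructed with defaults and that the import path below is where the class lives (neither is shown in this diff):

    from pathlib import Path
    from tempfile import TemporaryDirectory

    from sae_lens.training.activation_scaler import ActivationScaler  # assumed path

    with TemporaryDirectory() as tmp_dir:
        path = Path(tmp_dir) / "activation_scaler.json"

        scaler = ActivationScaler()
        scaler.scaling_factor = 0.5
        scaler.save(str(path))

        restored = ActivationScaler()
        restored.load(path)  # accepts str or Path
        assert restored.scaling_factor == 0.5
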
@@ -4,6 +4,7 @@ import json
 import os
 import warnings
 from collections.abc import Generator, Iterator, Sequence
+from pathlib import Path
 from typing import Any, Literal, cast

 import datasets
@@ -13,8 +14,8 @@ from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import HfHubHTTPError
 from jaxtyping import Float, Int
 from requests import HTTPError
-from safetensors.torch import save_file
-from tqdm import tqdm
+from safetensors.torch import load_file, save_file
+from tqdm.auto import tqdm
 from transformer_lens.hook_points import HookedRootModule
 from transformers import AutoTokenizer, PreTrainedTokenizerBase

@@ -24,12 +25,15 @@ from sae_lens.config import (
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.constants import DTYPE_MAP
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, DTYPE_MAP
 from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
 from sae_lens.training.mixing_buffer import mixing_buffer
-from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
+from sae_lens.util import (
+    extract_stop_at_layer_from_tlens_hook_name,
+    get_special_token_ids,
+)


 # TODO: Make an activation store config class to be consistent with the rest of the code.
@@ -113,7 +117,7 @@ class ActivationsStore:
         if exclude_special_tokens is False:
             exclude_special_tokens = None
         if exclude_special_tokens is True:
-            exclude_special_tokens = _get_special_token_ids(model.tokenizer)  # type: ignore
+            exclude_special_tokens = get_special_token_ids(model.tokenizer)  # type: ignore
         if exclude_special_tokens is not None:
             exclude_special_tokens = torch.tensor(
                 exclude_special_tokens, dtype=torch.long, device=device
@@ -315,7 +319,7 @@ class ActivationsStore:
             )
         else:
             warnings.warn(
-                "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
+                "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://decoderesearch.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
             )

         self.iterable_sequences = self._iterate_tokenized_sequences()
@@ -726,6 +730,48 @@ class ActivationsStore:
         """save the state dict to a file in safetensors format"""
         save_file(self.state_dict(), file_path)

+    def save_to_checkpoint(self, checkpoint_path: str | Path):
+        """Save the state dict to a checkpoint path"""
+        self.save(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+    def load_from_checkpoint(self, checkpoint_path: str | Path):
+        """Load the state dict from a checkpoint path"""
+        self.load(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+    def load(self, file_path: str):
+        """Load the state dict from a file in safetensors format"""
+
+        state_dict = load_file(file_path)
+
+        if "n_dataset_processed" in state_dict:
+            target_n_dataset_processed = state_dict["n_dataset_processed"].item()
+
+            # Only fast-forward if needed
+
+            if target_n_dataset_processed > self.n_dataset_processed:
+                logger.info(
+                    "Fast-forwarding through dataset samples to match checkpoint position"
+                )
+                samples_to_skip = target_n_dataset_processed - self.n_dataset_processed
+
+                pbar = tqdm(
+                    total=samples_to_skip,
+                    desc="Fast-forwarding through dataset",
+                    leave=False,
+                )
+                while target_n_dataset_processed > self.n_dataset_processed:
+                    start = self.n_dataset_processed
+                    try:
+                        # Just consume and ignore the values to fast-forward
+                        next(self.iterable_sequences)
+                    except StopIteration:
+                        logger.warning(
+                            "Dataset exhausted during fast-forward. Resetting dataset."
+                        )
+                        self.iterable_sequences = self._iterate_tokenized_sequences()
+                    pbar.update(self.n_dataset_processed - start)
+                pbar.close()
+


 def validate_pretokenized_dataset_tokenizer(
     dataset_path: str, model_tokenizer: PreTrainedTokenizerBase
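
`save_to_checkpoint` / `load_from_checkpoint` pin the store's state file to a fixed name (`ACTIVATIONS_STORE_STATE_FILENAME`) inside a checkpoint directory, and `load` fast-forwards the token iterator until `n_dataset_processed` matches the saved value, so a resumed run continues from the same dataset position. A hedged usage sketch, assuming `store` is an already-constructed `ActivationsStore`:

    from pathlib import Path

    checkpoint_dir = Path("checkpoints/step_5000")  # hypothetical location
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # During training: persist the store's position alongside the SAE weights.
    store.save_to_checkpoint(checkpoint_dir)

    # On resume: a freshly built store replays (and discards) sequences until it
    # reaches the saved n_dataset_processed, showing progress with tqdm.
    store.load_from_checkpoint(checkpoint_dir)
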
@@ -763,31 +809,6 @@ def _get_model_device(model: HookedRootModule) -> torch.device:
     return next(model.parameters()).device  # type: ignore


-def _get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
-    """Get all special token IDs from a tokenizer."""
-    special_tokens = set()
-
-    # Get special tokens from tokenizer attributes
-    for attr in dir(tokenizer):
-        if attr.endswith("_token_id"):
-            token_id = getattr(tokenizer, attr)
-            if token_id is not None:
-                special_tokens.add(token_id)
-
-    # Get any additional special tokens from the tokenizer's special tokens map
-    if hasattr(tokenizer, "special_tokens_map"):
-        for token in tokenizer.special_tokens_map.values():
-            if isinstance(token, str):
-                token_id = tokenizer.convert_tokens_to_ids(token)  # type: ignore
-                special_tokens.add(token_id)
-            elif isinstance(token, list):
-                for t in token:
-                    token_id = tokenizer.convert_tokens_to_ids(t)  # type: ignore
-                    special_tokens.add(token_id)
-
-    return list(special_tokens)
-
-
 def _filter_buffer_acts(
     buffer: tuple[torch.Tensor, torch.Tensor | None],
     exclude_tokens: torch.Tensor | None,
@@ -2,6 +2,8 @@
 Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
 """

+from typing import Any
+
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler

@@ -150,3 +152,12 @@ class CoefficientScheduler:
     def value(self) -> float:
         """Returns the current scalar value."""
         return self.current_value
+
+    def state_dict(self) -> dict[str, Any]:
+        return {
+            "current_step": self.current_step,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]):
+        for k in state_dict:
+            setattr(self, k, state_dict[k])
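
`state_dict` / `load_state_dict` let the coefficient schedule be checkpointed next to the optimizer: only `current_step` is persisted, and `load_state_dict` simply `setattr`s whatever keys it is given back onto the scheduler. A minimal sketch, assuming `scheduler` and `resumed_scheduler` are existing `CoefficientScheduler` instances built from the same training config:

    saved = scheduler.state_dict()  # e.g. {"current_step": 1200}

    resumed_scheduler.load_state_dict(saved)
    assert resumed_scheduler.current_step == saved["current_step"]
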