sae-lens 6.15.0__py3-none-any.whl → 6.24.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,365 @@
+"""TemporalSAE: A Sparse Autoencoder with temporal attention mechanism.
+
+TemporalSAE decomposes activations into:
+1. Predicted codes (from attention over context)
+2. Novel codes (sparse features of the residual)
+
+See: https://arxiv.org/abs/2410.04185
+"""
+
+import math
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from typing_extensions import override
+
+from sae_lens import logger
+from sae_lens.saes.sae import SAE, SAEConfig
+
+
+def get_attention(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
+    """Compute causal attention weights."""
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1))
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+    temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+    attn_bias.to(query.dtype)
+
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    return torch.softmax(attn_weight, dim=-1)
+
+
+class ManualAttention(nn.Module):
+    """Manual attention implementation for TemporalSAE."""
+
+    def __init__(
+        self,
+        dimin: int,
+        n_heads: int = 4,
+        bottleneck_factor: int = 64,
+        bias_k: bool = True,
+        bias_q: bool = True,
+        bias_v: bool = True,
+        bias_o: bool = True,
+    ):
+        super().__init__()
+        assert dimin % (bottleneck_factor * n_heads) == 0
+
+        self.n_heads = n_heads
+        self.n_embds = dimin // bottleneck_factor
+        self.dimin = dimin
+
+        # Key, query, value projections
+        self.k_ctx = nn.Linear(dimin, self.n_embds, bias=bias_k)
+        self.q_target = nn.Linear(dimin, self.n_embds, bias=bias_q)
+        self.v_ctx = nn.Linear(dimin, dimin, bias=bias_v)
+        self.c_proj = nn.Linear(dimin, dimin, bias=bias_o)
+
+        # Normalize to match scale with representations
+        with torch.no_grad():
+            scaling = 1 / math.sqrt(self.n_embds // self.n_heads)
+            self.k_ctx.weight.copy_(
+                scaling
+                * self.k_ctx.weight
+                / (1e-6 + torch.linalg.norm(self.k_ctx.weight, dim=1, keepdim=True))
+            )
+            self.q_target.weight.copy_(
+                scaling
+                * self.q_target.weight
+                / (1e-6 + torch.linalg.norm(self.q_target.weight, dim=1, keepdim=True))
+            )
+
+            scaling = 1 / math.sqrt(self.dimin // self.n_heads)
+            self.v_ctx.weight.copy_(
+                scaling
+                * self.v_ctx.weight
+                / (1e-6 + torch.linalg.norm(self.v_ctx.weight, dim=1, keepdim=True))
+            )
+
+            scaling = 1 / math.sqrt(self.dimin)
+            self.c_proj.weight.copy_(
+                scaling
+                * self.c_proj.weight
+                / (1e-6 + torch.linalg.norm(self.c_proj.weight, dim=1, keepdim=True))
+            )
+
+    def forward(
+        self, x_ctx: torch.Tensor, x_target: torch.Tensor, get_attn_map: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Compute projective attention output."""
+        k = self.k_ctx(x_ctx)
+        v = self.v_ctx(x_ctx)
+        q = self.q_target(x_target)
+
+        # Split into heads
+        B, T, _ = x_ctx.size()
+        k = k.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
+        q = q.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
+        v = v.view(B, T, self.n_heads, self.dimin // self.n_heads).transpose(1, 2)
+
+        # Attention map (optional)
+        attn_map = None
+        if get_attn_map:
+            attn_map = get_attention(query=q, key=k)
+
+        # Scaled dot-product attention
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=None, dropout_p=0, is_causal=True
+        )
+
+        # Reshape and project
+        d_target = self.c_proj(
+            attn_output.transpose(1, 2).contiguous().view(B, T, self.dimin)
+        )
+
+        return d_target, attn_map
+
+
+@dataclass
+class TemporalSAEConfig(SAEConfig):
+    """Configuration for TemporalSAE inference.
+
+    Args:
+        d_in: Input dimension (dimensionality of the activations being encoded)
+        d_sae: SAE latent dimension (number of features)
+        n_heads: Number of attention heads in temporal attention
+        n_attn_layers: Number of attention layers
+        bottleneck_factor: Bottleneck factor for attention dimension
+        sae_diff_type: Type of SAE for novel codes ('relu' or 'topk')
+        kval_topk: K value for top-k sparsity (if sae_diff_type='topk')
+        tied_weights: Whether to tie encoder and decoder weights
+        activation_normalization_factor: Scalar factor for rescaling activations (used with normalize_activations='constant_scalar_rescale')
+    """
+
+    n_heads: int = 8
+    n_attn_layers: int = 1
+    bottleneck_factor: int = 64
+    sae_diff_type: Literal["relu", "topk"] = "topk"
+    kval_topk: int | None = None
+    tied_weights: bool = True
+    activation_normalization_factor: float = 1.0
+
+    def __post_init__(self):
+        # Call parent's __post_init__ first, but allow constant_scalar_rescale
+        if self.normalize_activations not in [
+            "none",
+            "expected_average_only_in",
+            "constant_norm_rescale",
+            "constant_scalar_rescale",  # Temporal SAEs support this
+            "layer_norm",
+        ]:
+            raise ValueError(
+                f"normalize_activations must be none, expected_average_only_in, layer_norm, constant_norm_rescale, or constant_scalar_rescale. Got {self.normalize_activations}"
+            )
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "temporal"
+
+
+class TemporalSAE(SAE[TemporalSAEConfig]):
+    """TemporalSAE: Sparse Autoencoder with temporal attention.
+
+    This SAE decomposes each activation x_t into:
+    - x_pred: Information aggregated from context {x_0, ..., x_{t-1}}
+    - x_novel: Novel information at position t (encoded sparsely)
+
+    The forward pass:
+    1. Uses attention layers to predict x_t from context
+    2. Encodes the residual (novel part) with a sparse SAE
+    3. Combines both for reconstruction
+    """
+
+    # Custom parameters (in addition to W_enc, W_dec, b_dec from base)
+    attn_layers: nn.ModuleList  # Attention layers
+    eps: float
+    lam: float
+
+    def __init__(self, cfg: TemporalSAEConfig, use_error_term: bool = False):
+        # Call parent init first
+        super().__init__(cfg, use_error_term)
+
+        # Initialize attention layers after parent init and move to correct device
+        self.attn_layers = nn.ModuleList(
+            [
+                ManualAttention(
+                    dimin=cfg.d_sae,
+                    n_heads=cfg.n_heads,
+                    bottleneck_factor=cfg.bottleneck_factor,
+                    bias_k=True,
+                    bias_q=True,
+                    bias_v=True,
+                    bias_o=True,
+                ).to(device=self.device, dtype=self.dtype)
+                for _ in range(cfg.n_attn_layers)
+            ]
+        )
+
+        self.eps = 1e-6
+        self.lam = 1 / (4 * self.cfg.d_in)
+
+    @override
+    def _setup_activation_normalization(self):
+        """Set up activation normalization functions for TemporalSAE.
+
+        Overrides the base implementation to handle constant_scalar_rescale
+        using the temporal-specific activation_normalization_factor.
+        """
+        if self.cfg.normalize_activations == "constant_scalar_rescale":
+            # Handle constant scalar rescaling for temporal SAEs
+            def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
+                return x * self.cfg.activation_normalization_factor
+
+            def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
+                return x / self.cfg.activation_normalization_factor
+
+            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
+            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+        else:
+            # Delegate to parent for all other normalization types
+            super()._setup_activation_normalization()
+
+    @override
+    def initialize_weights(self) -> None:
+        """Initialize TemporalSAE weights."""
+        # Initialize D (decoder) and b (bias)
+        self.W_dec = nn.Parameter(
+            torch.randn(
+                (self.cfg.d_sae, self.cfg.d_in), dtype=self.dtype, device=self.device
+            )
+        )
+        self.b_dec = nn.Parameter(
+            torch.zeros((self.cfg.d_in), dtype=self.dtype, device=self.device)
+        )
+
+        # Initialize E (encoder) if not tied
+        if not self.cfg.tied_weights:
+            self.W_enc = nn.Parameter(
+                torch.randn(
+                    (self.cfg.d_in, self.cfg.d_sae),
+                    dtype=self.dtype,
+                    device=self.device,
+                )
+            )
+
+    def encode_with_predictions(
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Encode input to novel codes only.
+
+        Returns only the sparse novel codes (not predicted codes).
+        This is the main feature representation for TemporalSAE.
+        """
+        # Process input through SAELens preprocessing
+        x = self.process_sae_in(x)
+
+        B, L, _ = x.shape
+
+        if self.cfg.tied_weights:  # noqa: SIM108
+            W_enc = self.W_dec.T
+        else:
+            W_enc = self.W_enc
+
+        # Compute predicted codes using attention
+        x_residual = x
+        z_pred = torch.zeros((B, L, self.cfg.d_sae), device=x.device, dtype=x.dtype)
+
+        for attn_layer in self.attn_layers:
+            # Encode input to latent space
+            z_input = F.relu(torch.matmul(x_residual * self.lam, W_enc))
+
+            # Shift context (causal masking)
+            z_ctx = torch.cat(
+                (torch.zeros_like(z_input[:, :1, :]), z_input[:, :-1, :].clone()), dim=1
+            )
+
+            # Apply attention to get predicted codes
+            z_pred_, _ = attn_layer(z_ctx, z_input, get_attn_map=False)
+            z_pred_ = F.relu(z_pred_)
+
+            # Project predicted codes back to input space
+            Dz_pred_ = torch.matmul(z_pred_, self.W_dec)
+            Dz_norm_ = Dz_pred_.norm(dim=-1, keepdim=True) + self.eps
+
+            # Compute projection scale
+            proj_scale = (Dz_pred_ * x_residual).sum(
+                dim=-1, keepdim=True
+            ) / Dz_norm_.pow(2)
+
+            # Accumulate predicted codes
+            z_pred = z_pred + (z_pred_ * proj_scale)
+
+            # Remove prediction from residual
+            x_residual = x_residual - proj_scale * Dz_pred_
+
+        # Encode residual (novel part) with sparse SAE
+        z_novel = F.relu(torch.matmul(x_residual * self.lam, W_enc))
+        if self.cfg.sae_diff_type == "topk":
+            kval = self.cfg.kval_topk
+            if kval is not None:
+                _, topk_indices = torch.topk(z_novel, kval, dim=-1)
+                mask = torch.zeros_like(z_novel)
+                mask.scatter_(-1, topk_indices, 1)
+                z_novel = z_novel * mask
+
+        # Return only novel codes (these are the interpretable features)
+        return z_novel, z_pred
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        return self.encode_with_predictions(x)[0]
+
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
+        """Decode novel codes to reconstruction.
+
+        Note: This only decodes the novel codes. For full reconstruction,
+        use forward() which includes predicted codes.
+        """
+        # Decode novel codes
+        sae_out = torch.matmul(feature_acts, self.W_dec)
+        sae_out = sae_out + self.b_dec
+
+        # Apply hook
+        sae_out = self.hook_sae_recons(sae_out)
+
+        # Apply output activation normalization (reverses input normalization)
+        sae_out = self.run_time_activation_norm_fn_out(sae_out)
+
+        # Add bias (already removed in process_sae_in)
+        logger.warning(
+            "NOTE this only decodes x_novel. The x_pred is missing, so we're not reconstructing the full x."
+        )
+        return sae_out
+
+    @override
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Full forward pass through TemporalSAE.
+
+        Returns complete reconstruction (predicted + novel).
+        """
+        # Encode
+        z_novel, z_pred = self.encode_with_predictions(x)
+
+        # Decode the sum of predicted and novel codes.
+        x_recons = torch.matmul(z_novel + z_pred, self.W_dec) + self.b_dec
+
+        # Apply output activation normalization (reverses input normalization)
+        x_recons = self.run_time_activation_norm_fn_out(x_recons)
+
+        return self.hook_sae_output(x_recons)
+
+    @override
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError("Folding W_dec_norm is not supported for TemporalSAE")
+
+    @override
+    @torch.no_grad()
+    def fold_activation_norm_scaling_factor(self, scaling_factor: float) -> None:
+        raise NotImplementedError(
+            "Folding activation norm scaling factor is not supported for TemporalSAE"
+        )
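
For orientation, here is a hedged usage sketch of the TemporalSAE added above: it builds the SAE straight from its config and runs a forward pass on random activations shaped (batch, seq_len, d_in). The import path, the chosen sizes, and any config fields beyond those visible in this diff are assumptions, not something this release documents.

    # Hedged sketch, not part of the package: exercise TemporalSAE end to end.
    import torch
    from sae_lens.saes.temporal_sae import TemporalSAE, TemporalSAEConfig  # module path assumed

    # d_sae must be divisible by n_heads * bottleneck_factor (8 * 64 = 512 with the defaults).
    cfg = TemporalSAEConfig(d_in=768, d_sae=768 * 16, kval_topk=32)
    sae = TemporalSAE(cfg)

    x = torch.randn(2, 128, cfg.d_in)                 # (batch, seq_len, d_in)
    z_novel, z_pred = sae.encode_with_predictions(x)  # sparse novel codes and predicted codes
    x_recons = sae(x)                                 # forward(): decodes z_novel + z_pred
    print(z_novel.shape, z_pred.shape, x_recons.shape)
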
sae_lens/saes/topk_sae.py CHANGED
@@ -4,7 +4,6 @@ from dataclasses import dataclass
 from typing import Any, Callable
 
 import torch
-from jaxtyping import Float
 from torch import nn
 from transformer_lens.hook_points import HookPoint
 from typing_extensions import override
@@ -235,9 +234,7 @@ class TopKSAE(SAE[TopKSAEConfig]):
         super().initialize_weights()
         _init_weights_topk(self)
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Converts input x into feature activations.
         Uses topk activation under the hood.
@@ -251,8 +248,8 @@ class TopKSAE(SAE[TopKSAEConfig]):
 
     def decode(
         self,
-        feature_acts: Float[torch.Tensor, "... d_sae"],
-    ) -> Float[torch.Tensor, "... d_in"]:
+        feature_acts: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Reconstructs the input from topk feature activations.
         Applies optional finetuning scaling, hooking to recons, out normalization,
@@ -354,8 +351,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         _init_weights_topk(self)
 
     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Similar to the base training method: calculate pre-activations, then apply TopK.
         """
@@ -372,8 +369,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     @override
     def decode(
        self,
-        feature_acts: Float[torch.Tensor, "... d_sae"],
-    ) -> Float[torch.Tensor, "... d_in"]:
+        feature_acts: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Decodes feature activations back into input space,
         applying optional finetuning scale, hooking, out normalization, etc.
@@ -534,7 +531,7 @@ def _fold_norm_topk(
     b_enc: torch.Tensor,
     W_dec: torch.Tensor,
 ) -> None:
-    W_dec_norm = W_dec.norm(dim=-1)
+    W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8)
     b_enc.data = b_enc.data * W_dec_norm
     W_dec_norms = W_dec_norm.unsqueeze(1)
     W_dec.data = W_dec.data / W_dec_norms
@@ -368,3 +368,44 @@ class JumpReLUTranscoder(Transcoder):
     def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUTranscoder":
         cfg = JumpReLUTranscoderConfig.from_dict(config_dict)
         return cls(cfg)
+
+
+@dataclass
+class JumpReLUSkipTranscoderConfig(JumpReLUTranscoderConfig):
+    """Configuration for JumpReLU transcoder."""
+
+    @classmethod
+    def architecture(cls) -> str:
+        """Return the architecture name for this config."""
+        return "jumprelu_skip_transcoder"
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUSkipTranscoderConfig":
+        """Create a JumpReLUSkipTranscoderConfig from a dictionary."""
+        # Filter to only include valid dataclass fields
+        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)
+
+        # Create the config instance
+        res = cls(**filtered_config_dict)
+
+        # Handle metadata if present
+        if "metadata" in config_dict:
+            res.metadata = SAEMetadata(**config_dict["metadata"])
+
+        return res
+
+
+class JumpReLUSkipTranscoder(JumpReLUTranscoder, SkipTranscoder):
+    """
+    A transcoder with a learnable skip connection and JumpReLU activation function.
+    """
+
+    cfg: JumpReLUSkipTranscoderConfig  # type: ignore[assignment]
+
+    def __init__(self, cfg: JumpReLUSkipTranscoderConfig):
+        super().__init__(cfg)
+
+    @classmethod
+    def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUSkipTranscoder":
+        cfg = JumpReLUSkipTranscoderConfig.from_dict(config_dict)
+        return cls(cfg)
@@ -1,5 +1,6 @@
 import json
 from dataclasses import dataclass
+from pathlib import Path
 from statistics import mean
 
 import torch
@@ -51,3 +52,9 @@ class ActivationScaler:
 
         with open(file_path, "w") as f:
             json.dump({"scaling_factor": self.scaling_factor}, f)
+
+    def load(self, file_path: str | Path):
+        """load the state dict from a file in json format"""
+        with open(file_path) as f:
+            data = json.load(f)
+        self.scaling_factor = data["scaling_factor"]
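
A small round-trip sketch for the new ActivationScaler.load, assuming an existing scaler instance with a numeric scaling_factor (how the scaler is constructed is outside this diff):

    # Hedged sketch: persist the scaling factor, then restore it via the new load().
    scaler.scaling_factor = 2.5
    scaler.save("activation_scaler.json")  # writes {"scaling_factor": 2.5}
    scaler.load("activation_scaler.json")  # reads it back into scaler.scaling_factor
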
@@ -4,6 +4,7 @@ import json
 import os
 import warnings
 from collections.abc import Generator, Iterator, Sequence
+from pathlib import Path
 from typing import Any, Literal, cast
 
 import datasets
@@ -11,10 +12,9 @@ import torch
 from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import HfHubHTTPError
-from jaxtyping import Float, Int
 from requests import HTTPError
-from safetensors.torch import save_file
-from tqdm import tqdm
+from safetensors.torch import load_file, save_file
+from tqdm.auto import tqdm
 from transformer_lens.hook_points import HookedRootModule
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
@@ -24,7 +24,7 @@ from sae_lens.config import (
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.constants import DTYPE_MAP
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, DTYPE_MAP
 from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
@@ -166,9 +166,11 @@ class ActivationsStore:
         disable_concat_sequences: bool = False,
         sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
     ) -> ActivationsStore:
+        if context_size is None:
+            context_size = sae.cfg.metadata.context_size
         if sae.cfg.metadata.hook_name is None:
             raise ValueError("hook_name is required")
-        if sae.cfg.metadata.context_size is None:
+        if context_size is None:
             raise ValueError("context_size is required")
         if sae.cfg.metadata.prepend_bos is None:
             raise ValueError("prepend_bos is required")
@@ -178,9 +180,7 @@ class ActivationsStore:
             d_in=sae.cfg.d_in,
             hook_name=sae.cfg.metadata.hook_name,
             hook_head_index=sae.cfg.metadata.hook_head_index,
-            context_size=sae.cfg.metadata.context_size
-            if context_size is None
-            else context_size,
+            context_size=context_size,
             prepend_bos=sae.cfg.metadata.prepend_bos,
             streaming=streaming,
             store_batch_size_prompts=store_batch_size_prompts,
@@ -230,7 +230,7 @@ class ActivationsStore:
             load_dataset(
                 dataset,
                 split="train",
-                streaming=streaming,
+                streaming=streaming,  # type: ignore
                 trust_remote_code=dataset_trust_remote_code,  # type: ignore
             )
             if isinstance(dataset, str)
@@ -318,7 +318,7 @@ class ActivationsStore:
             )
         else:
             warnings.warn(
-                "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
+                "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://decoderesearch.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
             )
 
         self.iterable_sequences = self._iterate_tokenized_sequences()
@@ -541,8 +541,8 @@ class ActivationsStore:
         d_in: int,
         raise_on_epoch_end: bool,
     ) -> tuple[
-        Float[torch.Tensor, "(total_size context_size) num_layers d_in"],
-        Int[torch.Tensor, "(total_size context_size)"] | None,
+        torch.Tensor,
+        torch.Tensor | None,
     ]:
         """
         Loads `total_size` activations from `cached_activation_dataset`
@@ -729,6 +729,48 @@ class ActivationsStore:
         """save the state dict to a file in safetensors format"""
         save_file(self.state_dict(), file_path)
 
+    def save_to_checkpoint(self, checkpoint_path: str | Path):
+        """Save the state dict to a checkpoint path"""
+        self.save(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+    def load_from_checkpoint(self, checkpoint_path: str | Path):
+        """Load the state dict from a checkpoint path"""
+        self.load(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+    def load(self, file_path: str):
+        """Load the state dict from a file in safetensors format"""
+
+        state_dict = load_file(file_path)
+
+        if "n_dataset_processed" in state_dict:
+            target_n_dataset_processed = state_dict["n_dataset_processed"].item()
+
+            # Only fast-forward if needed
+
+            if target_n_dataset_processed > self.n_dataset_processed:
+                logger.info(
+                    "Fast-forwarding through dataset samples to match checkpoint position"
+                )
+                samples_to_skip = target_n_dataset_processed - self.n_dataset_processed
+
+                pbar = tqdm(
+                    total=samples_to_skip,
+                    desc="Fast-forwarding through dataset",
+                    leave=False,
+                )
+                while target_n_dataset_processed > self.n_dataset_processed:
+                    start = self.n_dataset_processed
+                    try:
+                        # Just consume and ignore the values to fast-forward
+                        next(self.iterable_sequences)
+                    except StopIteration:
+                        logger.warning(
+                            "Dataset exhausted during fast-forward. Resetting dataset."
+                        )
+                        self.iterable_sequences = self._iterate_tokenized_sequences()
+                    pbar.update(self.n_dataset_processed - start)
+                pbar.close()
+
 
 def validate_pretokenized_dataset_tokenizer(
     dataset_path: str, model_tokenizer: PreTrainedTokenizerBase
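
A hedged sketch of how the new ActivationsStore checkpoint helpers fit together, assuming an already-constructed store (e.g. built via ActivationsStore.from_sae, which is not shown here):

    from pathlib import Path

    ckpt_dir = Path("checkpoints/step_01000")
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Writes the store state (ACTIVATIONS_STORE_STATE_FILENAME) inside the checkpoint directory.
    store.save_to_checkpoint(ckpt_dir)

    # On resume, load() fast-forwards the tokenized-sequence iterator until
    # n_dataset_processed matches the value saved in the checkpoint.
    store.load_from_checkpoint(ckpt_dir)
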
@@ -2,6 +2,8 @@
 Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
 """
 
+from typing import Any
+
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
 
@@ -150,3 +152,12 @@ class CoefficientScheduler:
     def value(self) -> float:
         """Returns the current scalar value."""
         return self.current_value
+
+    def state_dict(self) -> dict[str, Any]:
+        return {
+            "current_step": self.current_step,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]):
+        for k in state_dict:
+            setattr(self, k, state_dict[k])
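
And a matching sketch for resuming a CoefficientScheduler, assuming two schedulers built with the same (unshown) constructor arguments:

    # Hedged sketch: carry the scheduler position across a restart.
    state = scheduler.state_dict()            # {"current_step": ...}
    resumed_scheduler.load_state_dict(state)  # restores current_step via setattr
    assert resumed_scheduler.current_step == scheduler.current_step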