sae-lens 6.15.0__py3-none-any.whl → 6.22.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,6 @@ import warnings
 from dataclasses import dataclass, field

 import torch
-from jaxtyping import Float
 from typing_extensions import override

 from sae_lens.saes.batchtopk_sae import (
@@ -78,14 +77,11 @@ class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
     @override
     def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
         base_output = super().training_forward_pass(step_input)
-        hidden_pre = base_output.hidden_pre
         inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
         # the outer matryoshka level is the base SAE, so we don't need to add an extra loss for it
         for width in self.cfg.matryoshka_widths[:-1]:
-            inner_hidden_pre = hidden_pre[:, :width]
-            inner_feat_acts = self.activation_fn(inner_hidden_pre)
             inner_reconstruction = self._decode_matryoshka_level(
-                inner_feat_acts, width, inv_W_dec_norm
+                base_output.feature_acts, width, inv_W_dec_norm
             )
             inner_mse_loss = (
                 self.mse_loss_fn(inner_reconstruction, step_input.sae_in)
@@ -98,23 +94,24 @@ class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):

     def _decode_matryoshka_level(
         self,
-        feature_acts: Float[torch.Tensor, "... d_sae"],
+        feature_acts: torch.Tensor,
         width: int,
         inv_W_dec_norm: torch.Tensor,
-    ) -> Float[torch.Tensor, "... d_in"]:
+    ) -> torch.Tensor:
         """
         Decodes feature activations back into input space for a matryoshka level
         """
+        inner_feature_acts = feature_acts[:, :width]
         # Handle sparse tensors using efficient sparse matrix multiplication
         if self.cfg.rescale_acts_by_decoder_norm:
             # need to multiply by the inverse of the norm because division is illegal with sparse tensors
-            feature_acts = feature_acts * inv_W_dec_norm[:width]
-        if feature_acts.is_sparse:
+            inner_feature_acts = inner_feature_acts * inv_W_dec_norm[:width]
+        if inner_feature_acts.is_sparse:
             sae_out_pre = (
-                _sparse_matmul_nd(feature_acts, self.W_dec[:width]) + self.b_dec
+                _sparse_matmul_nd(inner_feature_acts, self.W_dec[:width]) + self.b_dec
             )
         else:
-            sae_out_pre = feature_acts @ self.W_dec[:width] + self.b_dec
+            sae_out_pre = inner_feature_acts @ self.W_dec[:width] + self.b_dec
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
         return self.reshape_fn_out(sae_out_pre, self.d_head)

@@ -137,7 +134,3 @@ def _validate_matryoshka_config(cfg: MatryoshkaBatchTopKTrainingSAEConfig) -> No
         warnings.warn(
             "WARNING: You have only set one matryoshka level. This is equivalent to using a standard BatchTopK SAE and is likely not what you want."
         )
-    if cfg.matryoshka_widths[0] < cfg.k:
-        raise ValueError(
-            "The smallest matryoshka level width cannot be smaller than cfg.k."
-        )
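Note on the hunks above: `_decode_matryoshka_level` now receives the full-width `feature_acts` from the base forward pass and slices them internally, instead of re-activating a slice of `hidden_pre` at the call site. A self-contained sketch of the prefix ("matryoshka") decoding under an assumed plain linear decoder, with the sparse-tensor and decoder-norm-rescaling branches omitted and illustrative sizes:

    import torch

    d_in, d_sae, width = 16, 64, 8               # illustrative sizes; `width` is one inner matryoshka level
    feature_acts = torch.randn(4, d_sae).relu()  # stand-in for base_output.feature_acts
    W_dec = torch.randn(d_sae, d_in)
    b_dec = torch.zeros(d_in)

    inner_feature_acts = feature_acts[:, :width]               # keep only the first `width` latents
    inner_recon = inner_feature_acts @ W_dec[:width] + b_dec   # decode with the matching decoder rows
    assert inner_recon.shape == (4, d_in)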
sae_lens/saes/sae.py CHANGED
@@ -19,9 +19,8 @@ from typing import (

 import einops
 import torch
-from jaxtyping import Float
 from numpy.typing import NDArray
-from safetensors.torch import save_file
+from safetensors.torch import load_file, save_file
 from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
 from typing_extensions import deprecated, overload, override
@@ -155,9 +154,9 @@ class SAEConfig(ABC):
     dtype: str = "float32"
     device: str = "cpu"
     apply_b_dec_to_input: bool = True
-    normalize_activations: Literal[
-        "none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"
-    ] = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
+    normalize_activations: Literal["none", "expected_average_only_in", "layer_norm"] = (
+        "none"  # none, expected_average_only_in (Anthropic April Update)
+    )
     reshape_activations: Literal["none", "hook_z"] = "none"
     metadata: SAEMetadata = field(default_factory=SAEMetadata)

@@ -217,6 +216,7 @@ class TrainStepInput:
     sae_in: torch.Tensor
     coefficients: dict[str, float]
     dead_neuron_mask: torch.Tensor | None
+    n_training_steps: int


 class TrainCoefficientConfig(NamedTuple):
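Note on the `TrainStepInput` hunk above: code that builds this object by hand must now also pass `n_training_steps`. A minimal sketch, assuming `TrainStepInput` is constructed by keyword like a dataclass; the tensor shape and coefficient key are illustrative:

    import torch
    from sae_lens.saes.sae import TrainStepInput

    step_input = TrainStepInput(
        sae_in=torch.randn(32, 768),   # illustrative batch of activations
        coefficients={"l1": 5.0},      # coefficient name is an assumption
        dead_neuron_mask=None,
        n_training_steps=0,            # new required field in 6.22.1
    )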
@@ -308,6 +308,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

             self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
             self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+
         elif self.cfg.normalize_activations == "layer_norm":
             # we need to scale the norm of the input and store the scaling factor
             def run_time_activation_ln_in(
@@ -349,16 +350,12 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         self.W_enc = nn.Parameter(w_enc_data)

     @abstractmethod
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """Encode input tensor to feature space."""
         pass

     @abstractmethod
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """Decode feature activations back to input space."""
         pass

@@ -448,26 +445,15 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

         return super().to(*args, **kwargs)

-    def process_sae_in(
-        self, sae_in: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_in"]:
-        # print(f"Input shape to process_sae_in: {sae_in.shape}")
-        # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
-        # print(f"self.b_dec shape: {self.b_dec.shape}")
-        # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")
-
+    def process_sae_in(self, sae_in: torch.Tensor) -> torch.Tensor:
         sae_in = sae_in.to(self.dtype)
-
-        # print(f"Shape before reshape_fn_in: {sae_in.shape}")
         sae_in = self.reshape_fn_in(sae_in)
-        # print(f"Shape after reshape_fn_in: {sae_in.shape}")

         sae_in = self.hook_sae_input(sae_in)
         sae_in = self.run_time_activation_norm_fn_in(sae_in)

         # Here's where the error happens
         bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
-        # print(f"Bias term shape: {bias_term.shape}")

         return sae_in - bias_term

@@ -866,14 +852,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):

     @abstractmethod
     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Encode with access to pre-activation values for training."""
         ...

-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         For inference, just encode without returning hidden_pre.
         (training_forward_pass calls encode_with_hidden_pre).
@@ -881,9 +865,7 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
         feature_acts, _ = self.encode_with_hidden_pre(x)
         return feature_acts

-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decodes feature activations back into input space,
         applying optional finetuning scale, hooking, out normalization, etc.
@@ -1017,6 +999,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
     ) -> type[TrainingSAEConfig]:
         return get_sae_training_class(architecture)[1]

+    def load_weights_from_checkpoint(self, checkpoint_path: Path | str) -> None:
+        checkpoint_path = Path(checkpoint_path)
+        state_dict = load_file(checkpoint_path / SAE_WEIGHTS_FILENAME)
+        self.process_state_dict_for_loading(state_dict)
+        self.load_state_dict(state_dict)
+

 _blank_hook = nn.Identity()

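Note on the last hunk above: the new `TrainingSAE.load_weights_from_checkpoint` reads the standard safetensors weights file from a checkpoint directory and loads it into an existing SAE instance. A hedged usage sketch; the architecture, config fields, and checkpoint path are illustrative assumptions, not values from this diff:

    from sae_lens.saes.standard_sae import StandardTrainingSAE, StandardTrainingSAEConfig

    cfg = StandardTrainingSAEConfig(d_in=768, d_sae=16384)        # illustrative sizes; other fields left at defaults
    sae = StandardTrainingSAE(cfg)
    sae.load_weights_from_checkpoint("checkpoints/run_01/final")  # hypothetical checkpoint directory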
@@ -2,7 +2,6 @@ from dataclasses import dataclass

 import numpy as np
 import torch
-from jaxtyping import Float
 from numpy.typing import NDArray
 from torch import nn
 from typing_extensions import override
@@ -54,9 +53,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
         super().initialize_weights()
         _init_weights_standard(self)

-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Encode the input tensor into the feature space.
         """
@@ -67,9 +64,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
         # Apply the activation function (e.g., ReLU, depending on config)
         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decode the feature activations back to the input space.
         Now, if hook_z reshaping is turned on, we reverse the flattening.
@@ -127,8 +122,8 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
         }

     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # Process the input (including dtype conversion, hook call, and any activation normalization)
         sae_in = self.process_sae_in(x)
         # Compute the pre-activation (and allow for a hook if desired)
@@ -0,0 +1,365 @@
+"""TemporalSAE: A Sparse Autoencoder with temporal attention mechanism.
+
+TemporalSAE decomposes activations into:
+1. Predicted codes (from attention over context)
+2. Novel codes (sparse features of the residual)
+
+See: https://arxiv.org/abs/2410.04185
+"""
+
+import math
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from typing_extensions import override
+
+from sae_lens import logger
+from sae_lens.saes.sae import SAE, SAEConfig
+
+
+def get_attention(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
+    """Compute causal attention weights."""
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1))
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+    temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+    attn_bias.to(query.dtype)
+
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    return torch.softmax(attn_weight, dim=-1)
+
+
+class ManualAttention(nn.Module):
+    """Manual attention implementation for TemporalSAE."""
+
+    def __init__(
+        self,
+        dimin: int,
+        n_heads: int = 4,
+        bottleneck_factor: int = 64,
+        bias_k: bool = True,
+        bias_q: bool = True,
+        bias_v: bool = True,
+        bias_o: bool = True,
+    ):
+        super().__init__()
+        assert dimin % (bottleneck_factor * n_heads) == 0
+
+        self.n_heads = n_heads
+        self.n_embds = dimin // bottleneck_factor
+        self.dimin = dimin
+
+        # Key, query, value projections
+        self.k_ctx = nn.Linear(dimin, self.n_embds, bias=bias_k)
+        self.q_target = nn.Linear(dimin, self.n_embds, bias=bias_q)
+        self.v_ctx = nn.Linear(dimin, dimin, bias=bias_v)
+        self.c_proj = nn.Linear(dimin, dimin, bias=bias_o)
+
+        # Normalize to match scale with representations
+        with torch.no_grad():
+            scaling = 1 / math.sqrt(self.n_embds // self.n_heads)
+            self.k_ctx.weight.copy_(
+                scaling
+                * self.k_ctx.weight
+                / (1e-6 + torch.linalg.norm(self.k_ctx.weight, dim=1, keepdim=True))
+            )
+            self.q_target.weight.copy_(
+                scaling
+                * self.q_target.weight
+                / (1e-6 + torch.linalg.norm(self.q_target.weight, dim=1, keepdim=True))
+            )
+
+            scaling = 1 / math.sqrt(self.dimin // self.n_heads)
+            self.v_ctx.weight.copy_(
+                scaling
+                * self.v_ctx.weight
+                / (1e-6 + torch.linalg.norm(self.v_ctx.weight, dim=1, keepdim=True))
+            )
+
+            scaling = 1 / math.sqrt(self.dimin)
+            self.c_proj.weight.copy_(
+                scaling
+                * self.c_proj.weight
+                / (1e-6 + torch.linalg.norm(self.c_proj.weight, dim=1, keepdim=True))
+            )
+
+    def forward(
+        self, x_ctx: torch.Tensor, x_target: torch.Tensor, get_attn_map: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Compute projective attention output."""
+        k = self.k_ctx(x_ctx)
+        v = self.v_ctx(x_ctx)
+        q = self.q_target(x_target)
+
+        # Split into heads
+        B, T, _ = x_ctx.size()
+        k = k.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
+        q = q.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
+        v = v.view(B, T, self.n_heads, self.dimin // self.n_heads).transpose(1, 2)
+
+        # Attention map (optional)
+        attn_map = None
+        if get_attn_map:
+            attn_map = get_attention(query=q, key=k)
+
+        # Scaled dot-product attention
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=None, dropout_p=0, is_causal=True
+        )
+
+        # Reshape and project
+        d_target = self.c_proj(
+            attn_output.transpose(1, 2).contiguous().view(B, T, self.dimin)
+        )
+
+        return d_target, attn_map
+
+
+@dataclass
+class TemporalSAEConfig(SAEConfig):
+    """Configuration for TemporalSAE inference.
+
+    Args:
+        d_in: Input dimension (dimensionality of the activations being encoded)
+        d_sae: SAE latent dimension (number of features)
+        n_heads: Number of attention heads in temporal attention
+        n_attn_layers: Number of attention layers
+        bottleneck_factor: Bottleneck factor for attention dimension
+        sae_diff_type: Type of SAE for novel codes ('relu' or 'topk')
+        kval_topk: K value for top-k sparsity (if sae_diff_type='topk')
+        tied_weights: Whether to tie encoder and decoder weights
+        activation_normalization_factor: Scalar factor for rescaling activations (used with normalize_activations='constant_scalar_rescale')
+    """
+
+    n_heads: int = 8
+    n_attn_layers: int = 1
+    bottleneck_factor: int = 64
+    sae_diff_type: Literal["relu", "topk"] = "topk"
+    kval_topk: int | None = None
+    tied_weights: bool = True
+    activation_normalization_factor: float = 1.0
+
+    def __post_init__(self):
+        # Call parent's __post_init__ first, but allow constant_scalar_rescale
+        if self.normalize_activations not in [
+            "none",
+            "expected_average_only_in",
+            "constant_norm_rescale",
+            "constant_scalar_rescale",  # Temporal SAEs support this
+            "layer_norm",
+        ]:
+            raise ValueError(
+                f"normalize_activations must be none, expected_average_only_in, layer_norm, constant_norm_rescale, or constant_scalar_rescale. Got {self.normalize_activations}"
+            )
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "temporal"
+
+
+class TemporalSAE(SAE[TemporalSAEConfig]):
+    """TemporalSAE: Sparse Autoencoder with temporal attention.
+
+    This SAE decomposes each activation x_t into:
+    - x_pred: Information aggregated from context {x_0, ..., x_{t-1}}
+    - x_novel: Novel information at position t (encoded sparsely)
+
+    The forward pass:
+    1. Uses attention layers to predict x_t from context
+    2. Encodes the residual (novel part) with a sparse SAE
+    3. Combines both for reconstruction
+    """
+
+    # Custom parameters (in addition to W_enc, W_dec, b_dec from base)
+    attn_layers: nn.ModuleList  # Attention layers
+    eps: float
+    lam: float
+
+    def __init__(self, cfg: TemporalSAEConfig, use_error_term: bool = False):
+        # Call parent init first
+        super().__init__(cfg, use_error_term)
+
+        # Initialize attention layers after parent init and move to correct device
+        self.attn_layers = nn.ModuleList(
+            [
+                ManualAttention(
+                    dimin=cfg.d_sae,
+                    n_heads=cfg.n_heads,
+                    bottleneck_factor=cfg.bottleneck_factor,
+                    bias_k=True,
+                    bias_q=True,
+                    bias_v=True,
+                    bias_o=True,
+                ).to(device=self.device, dtype=self.dtype)
+                for _ in range(cfg.n_attn_layers)
+            ]
+        )
+
+        self.eps = 1e-6
+        self.lam = 1 / (4 * self.cfg.d_in)
+
+    @override
+    def _setup_activation_normalization(self):
+        """Set up activation normalization functions for TemporalSAE.
+
+        Overrides the base implementation to handle constant_scalar_rescale
+        using the temporal-specific activation_normalization_factor.
+        """
+        if self.cfg.normalize_activations == "constant_scalar_rescale":
+            # Handle constant scalar rescaling for temporal SAEs
+            def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
+                return x * self.cfg.activation_normalization_factor
+
+            def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
+                return x / self.cfg.activation_normalization_factor
+
+            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
+            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+        else:
+            # Delegate to parent for all other normalization types
+            super()._setup_activation_normalization()
+
+    @override
+    def initialize_weights(self) -> None:
+        """Initialize TemporalSAE weights."""
+        # Initialize D (decoder) and b (bias)
+        self.W_dec = nn.Parameter(
+            torch.randn(
+                (self.cfg.d_sae, self.cfg.d_in), dtype=self.dtype, device=self.device
+            )
+        )
+        self.b_dec = nn.Parameter(
+            torch.zeros((self.cfg.d_in), dtype=self.dtype, device=self.device)
+        )
+
+        # Initialize E (encoder) if not tied
+        if not self.cfg.tied_weights:
+            self.W_enc = nn.Parameter(
+                torch.randn(
+                    (self.cfg.d_in, self.cfg.d_sae),
+                    dtype=self.dtype,
+                    device=self.device,
+                )
+            )
+
+    def encode_with_predictions(
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Encode input to novel codes only.
+
+        Returns only the sparse novel codes (not predicted codes).
+        This is the main feature representation for TemporalSAE.
+        """
+        # Process input through SAELens preprocessing
+        x = self.process_sae_in(x)
+
+        B, L, _ = x.shape
+
+        if self.cfg.tied_weights:  # noqa: SIM108
+            W_enc = self.W_dec.T
+        else:
+            W_enc = self.W_enc
+
+        # Compute predicted codes using attention
+        x_residual = x
+        z_pred = torch.zeros((B, L, self.cfg.d_sae), device=x.device, dtype=x.dtype)
+
+        for attn_layer in self.attn_layers:
+            # Encode input to latent space
+            z_input = F.relu(torch.matmul(x_residual * self.lam, W_enc))
+
+            # Shift context (causal masking)
+            z_ctx = torch.cat(
+                (torch.zeros_like(z_input[:, :1, :]), z_input[:, :-1, :].clone()), dim=1
+            )
+
+            # Apply attention to get predicted codes
+            z_pred_, _ = attn_layer(z_ctx, z_input, get_attn_map=False)
+            z_pred_ = F.relu(z_pred_)
+
+            # Project predicted codes back to input space
+            Dz_pred_ = torch.matmul(z_pred_, self.W_dec)
+            Dz_norm_ = Dz_pred_.norm(dim=-1, keepdim=True) + self.eps
+
+            # Compute projection scale
+            proj_scale = (Dz_pred_ * x_residual).sum(
+                dim=-1, keepdim=True
+            ) / Dz_norm_.pow(2)
+
+            # Accumulate predicted codes
+            z_pred = z_pred + (z_pred_ * proj_scale)
+
+            # Remove prediction from residual
+            x_residual = x_residual - proj_scale * Dz_pred_
+
+        # Encode residual (novel part) with sparse SAE
+        z_novel = F.relu(torch.matmul(x_residual * self.lam, W_enc))
+        if self.cfg.sae_diff_type == "topk":
+            kval = self.cfg.kval_topk
+            if kval is not None:
+                _, topk_indices = torch.topk(z_novel, kval, dim=-1)
+                mask = torch.zeros_like(z_novel)
+                mask.scatter_(-1, topk_indices, 1)
+                z_novel = z_novel * mask
+
+        # Return only novel codes (these are the interpretable features)
+        return z_novel, z_pred
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        return self.encode_with_predictions(x)[0]
+
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
+        """Decode novel codes to reconstruction.
+
+        Note: This only decodes the novel codes. For full reconstruction,
+        use forward() which includes predicted codes.
+        """
+        # Decode novel codes
+        sae_out = torch.matmul(feature_acts, self.W_dec)
+        sae_out = sae_out + self.b_dec
+
+        # Apply hook
+        sae_out = self.hook_sae_recons(sae_out)
+
+        # Apply output activation normalization (reverses input normalization)
+        sae_out = self.run_time_activation_norm_fn_out(sae_out)
+
+        # Add bias (already removed in process_sae_in)
+        logger.warning(
+            "NOTE this only decodes x_novel. The x_pred is missing, so we're not reconstructing the full x."
+        )
+        return sae_out
+
+    @override
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Full forward pass through TemporalSAE.
+
+        Returns complete reconstruction (predicted + novel).
+        """
+        # Encode
+        z_novel, z_pred = self.encode_with_predictions(x)
+
+        # Decode the sum of predicted and novel codes.
+        x_recons = torch.matmul(z_novel + z_pred, self.W_dec) + self.b_dec
+
+        # Apply output activation normalization (reverses input normalization)
+        x_recons = self.run_time_activation_norm_fn_out(x_recons)
+
+        return self.hook_sae_output(x_recons)
+
+    @override
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError("Folding W_dec_norm is not supported for TemporalSAE")
+
+    @override
+    @torch.no_grad()
+    def fold_activation_norm_scaling_factor(self, scaling_factor: float) -> None:
+        raise NotImplementedError(
+            "Folding activation norm scaling factor is not supported for TemporalSAE"
+        )
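For orientation, a minimal usage sketch of the new TemporalSAE. The module path, config values, and tensor shapes are assumptions for illustration; the constraints taken from the code above are that `d_sae` must be divisible by `n_heads * bottleneck_factor` (8 * 64 by default) and that inputs carry a (batch, seq, d_in) shape:

    import torch
    from sae_lens.saes.temporal_sae import TemporalSAE, TemporalSAEConfig  # module path assumed

    cfg = TemporalSAEConfig(d_in=768, d_sae=12288, kval_topk=32)  # 12288 = 24 * 512, so divisible by 8 * 64
    sae = TemporalSAE(cfg)

    acts = torch.randn(2, 128, 768)                      # (batch, seq, d_in); a sequence dimension is required
    z_novel, z_pred = sae.encode_with_predictions(acts)  # sparse novel codes + predicted codes
    recon = sae(acts)                                    # forward() decodes z_novel + z_pred
    partial = sae.decode(z_novel)                        # decode() uses the novel codes only (and logs a warning)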
sae_lens/saes/topk_sae.py CHANGED
@@ -4,7 +4,6 @@ from dataclasses import dataclass
 from typing import Any, Callable

 import torch
-from jaxtyping import Float
 from torch import nn
 from transformer_lens.hook_points import HookPoint
 from typing_extensions import override
@@ -235,9 +234,7 @@ class TopKSAE(SAE[TopKSAEConfig]):
         super().initialize_weights()
         _init_weights_topk(self)

-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Converts input x into feature activations.
         Uses topk activation under the hood.
@@ -251,8 +248,8 @@ class TopKSAE(SAE[TopKSAEConfig]):

     def decode(
         self,
-        feature_acts: Float[torch.Tensor, "... d_sae"],
-    ) -> Float[torch.Tensor, "... d_in"]:
+        feature_acts: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Reconstructs the input from topk feature activations.
         Applies optional finetuning scaling, hooking to recons, out normalization,
@@ -354,8 +351,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         _init_weights_topk(self)

     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Similar to the base training method: calculate pre-activations, then apply TopK.
         """
@@ -372,8 +369,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     @override
     def decode(
         self,
-        feature_acts: Float[torch.Tensor, "... d_sae"],
-    ) -> Float[torch.Tensor, "... d_in"]:
+        feature_acts: torch.Tensor,
+    ) -> torch.Tensor:
         """
         Decodes feature activations back into input space,
         applying optional finetuning scale, hooking, out normalization, etc.
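The TopK signature changes above only drop the jaxtyping annotations; the top-k sparsification itself is unchanged. As a standalone illustration of that operation (mirroring the scatter-mask pattern used in the TemporalSAE file earlier in this diff; sizes are illustrative):

    import torch

    k = 8
    acts = torch.randn(4, 64).relu()                  # pre-sparsification feature activations
    _, topk_idx = torch.topk(acts, k, dim=-1)
    mask = torch.zeros_like(acts).scatter_(-1, topk_idx, 1.0)
    sparse_acts = acts * mask                         # at most k nonzero features per row
    assert (sparse_acts != 0).sum(dim=-1).le(k).all()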
@@ -1,5 +1,6 @@
 import json
 from dataclasses import dataclass
+from pathlib import Path
 from statistics import mean

 import torch
@@ -51,3 +52,9 @@ class ActivationScaler:

         with open(file_path, "w") as f:
             json.dump({"scaling_factor": self.scaling_factor}, f)
+
+    def load(self, file_path: str | Path):
+        """load the state dict from a file in json format"""
+        with open(file_path) as f:
+            data = json.load(f)
+        self.scaling_factor = data["scaling_factor"]
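The new `ActivationScaler.load` mirrors the one-key JSON written by the save path shown above. A hedged round-trip sketch; the module path, the `save` method name, and the no-argument constructor are assumptions inferred from the surrounding context:

    from sae_lens.training.activation_scaler import ActivationScaler  # module path assumed

    scaler = ActivationScaler()
    scaler.scaling_factor = 0.137             # illustrative value
    scaler.save("activation_scaler.json")     # existing writer shown above (name assumed)

    restored = ActivationScaler()
    restored.load("activation_scaler.json")   # new in this release
    assert restored.scaling_factor == 0.137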