sae-lens 6.12.3__tar.gz → 6.13.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {sae_lens-6.12.3 → sae_lens-6.13.1}/PKG-INFO +1 -1
  2. {sae_lens-6.12.3 → sae_lens-6.13.1}/pyproject.toml +1 -1
  3. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/evals.py +2 -0
  5. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/pretokenize_runner.py +2 -1
  6. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/sae.py +9 -10
  7. sae_lens-6.13.1/sae_lens/saes/topk_sae.py +473 -0
  8. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/tokenization_and_batching.py +21 -6
  9. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/sae_trainer.py +7 -5
  10. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/types.py +1 -1
  11. sae_lens-6.12.3/sae_lens/saes/topk_sae.py +0 -271
  12. {sae_lens-6.12.3 → sae_lens-6.13.1}/LICENSE +0 -0
  13. {sae_lens-6.12.3 → sae_lens-6.13.1}/README.md +0 -0
  14. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/analysis/__init__.py +0 -0
  15. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  16. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  17. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/cache_activations_runner.py +0 -0
  18. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/config.py +0 -0
  19. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/constants.py +0 -0
  20. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/llm_sae_training_runner.py +0 -0
  21. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/load_model.py +0 -0
  22. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/loading/__init__.py +0 -0
  23. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
  24. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  25. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/pretrained_saes.yaml +0 -0
  26. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/registry.py +0 -0
  27. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/__init__.py +0 -0
  28. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/batchtopk_sae.py +0 -0
  29. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/gated_sae.py +0 -0
  30. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/jumprelu_sae.py +0 -0
  31. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/standard_sae.py +0 -0
  32. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/transcoder.py +0 -0
  33. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/__init__.py +0 -0
  34. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/activation_scaler.py +0 -0
  35. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/activations_store.py +0 -0
  36. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/mixing_buffer.py +0 -0
  37. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/optim.py +0 -0
  38. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  39. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/tutorial/tsea.py +0 -0
  40. {sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/util.py +0 -0
{sae_lens-6.12.3 → sae_lens-6.13.1}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sae-lens
- Version: 6.12.3
+ Version: 6.13.1
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
  License: MIT
  License-File: LICENSE
{sae_lens-6.12.3 → sae_lens-6.13.1}/pyproject.toml
@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "sae-lens"
- version = "6.12.3"
+ version = "6.13.1"
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
  authors = ["Joseph Bloom"]
  readme = "README.md"
{sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/__init__.py
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.12.3"
+ __version__ = "6.13.1"

  import logging

{sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/evals.py
@@ -466,6 +466,8 @@ def get_sparsity_and_variance_metrics(
  sae_out_scaled = sae.decode(sae_feature_activations).to(
      original_act_scaled.device
  )
+ if sae_feature_activations.is_sparse:
+     sae_feature_activations = sae_feature_activations.to_dense()
  del cache

  sae_out = activation_scaler.unscale(sae_out_scaled)
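For readers skimming the diff: the guard added here simply densifies sparse feature activations before the dense metric math that follows. A minimal, self-contained sketch of the same pattern (toy tensors, not the evals code):

    import torch

    feature_acts = torch.tensor([[0.0, 2.0, 0.0], [1.0, 0.0, 0.0]]).to_sparse()
    if feature_acts.is_sparse:  # same check as in the hunk above
        feature_acts = feature_acts.to_dense()
    l0_per_token = (feature_acts > 0).sum(-1)  # tensor([1, 1])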
{sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/pretokenize_runner.py
@@ -1,9 +1,10 @@
  import io
  import json
  import sys
+ from collections.abc import Iterator
  from dataclasses import dataclass
  from pathlib import Path
- from typing import Iterator, Literal, cast
+ from typing import Literal, cast

  import torch
  from datasets import Dataset, DatasetDict, load_dataset
{sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/saes/sae.py
@@ -14,7 +14,6 @@ from typing import (
      Generic,
      Literal,
      NamedTuple,
-     Type,
      TypeVar,
  )

@@ -534,7 +533,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  @classmethod
  @deprecated("Use load_from_disk instead")
  def load_from_pretrained(
-     cls: Type[T_SAE],
+     cls: type[T_SAE],
      path: str | Path,
      device: str = "cpu",
      dtype: str | None = None,
@@ -543,7 +542,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

  @classmethod
  def load_from_disk(
-     cls: Type[T_SAE],
+     cls: type[T_SAE],
      path: str | Path,
      device: str = "cpu",
      dtype: str | None = None,
@@ -564,7 +563,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

  @classmethod
  def from_pretrained(
-     cls: Type[T_SAE],
+     cls: type[T_SAE],
      release: str,
      sae_id: str,
      device: str = "cpu",
@@ -585,7 +584,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

  @classmethod
  def from_pretrained_with_cfg_and_sparsity(
-     cls: Type[T_SAE],
+     cls: type[T_SAE],
      release: str,
      sae_id: str,
      device: str = "cpu",
@@ -684,7 +683,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
      return sae, cfg_dict, log_sparsities

  @classmethod
- def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
+ def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
      """Create an SAE from a config dictionary."""
      sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
      sae_config_cls = cls.get_sae_config_class_for_architecture(
@@ -694,8 +693,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

  @classmethod
  def get_sae_class_for_architecture(
-     cls: Type[T_SAE], architecture: str
- ) -> Type[T_SAE]:
+     cls: type[T_SAE], architecture: str
+ ) -> type[T_SAE]:
      """Get the SAE class for a given architecture."""
      sae_cls, _ = get_sae_class(architecture)
      if not issubclass(sae_cls, cls):
@@ -1000,8 +999,8 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):

  @classmethod
  def get_sae_class_for_architecture(
-     cls: Type[T_TRAINING_SAE], architecture: str
- ) -> Type[T_TRAINING_SAE]:
+     cls: type[T_TRAINING_SAE], architecture: str
+ ) -> type[T_TRAINING_SAE]:
      """Get the SAE class for a given architecture."""
      sae_cls, _ = get_sae_training_class(architecture)
      if not issubclass(sae_cls, cls):
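These hunks are a typing cleanup: the deprecated typing.Type[T] spelling is replaced by the builtin generic type[T]. A small illustrative sketch (the function name here is made up, not a sae-lens API):

    from typing import TypeVar

    T = TypeVar("T")

    def build(cls: type[T]) -> T:
        # `type[T]` means "the class object itself", same as the old typing.Type[T]
        return cls()

    build(dict)  # returns {}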
sae_lens-6.13.1/sae_lens/saes/topk_sae.py (new file)
@@ -0,0 +1,473 @@
+ """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
+
+ from dataclasses import dataclass
+ from typing import Callable
+
+ import torch
+ from jaxtyping import Float
+ from torch import nn
+ from transformer_lens.hook_points import HookPoint
+ from typing_extensions import override
+
+ from sae_lens.saes.sae import (
+     SAE,
+     SAEConfig,
+     TrainCoefficientConfig,
+     TrainingSAE,
+     TrainingSAEConfig,
+     TrainStepInput,
+     _disable_hooks,
+ )
+
+
+ class SparseHookPoint(HookPoint):
+     """
+     A HookPoint that takes in a sparse tensor.
+     Overrides TransformerLens's HookPoint.
+     """
+
+     def __init__(self, d_sae: int):
+         super().__init__()
+         self.d_sae = d_sae
+
+     @override
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         using_hooks = (
+             self._forward_hooks is not None and len(self._forward_hooks) > 0
+         ) or (self._backward_hooks is not None and len(self._backward_hooks) > 0)
+         if using_hooks and x.is_sparse:
+             return x.to_dense()
+         return x  # if no hooks are being used, use passthrough
+
+
+ class TopK(nn.Module):
+     """
+     A simple TopK activation that zeroes out all but the top K elements along the last dimension,
+     and applies ReLU to the top K elements.
+     """
+
+     use_sparse_activations: bool
+
+     def __init__(
+         self,
+         k: int,
+         use_sparse_activations: bool = False,
+     ):
+         super().__init__()
+         self.k = k
+         self.use_sparse_activations = use_sparse_activations
+
+     def forward(
+         self,
+         x: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         1) Select top K elements along the last dimension.
+         2) Apply ReLU.
+         3) Zero out all other entries.
+         """
+         topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1, sorted=False)
+         values = topk_values.relu()
+         if self.use_sparse_activations:
+             # Produce a COO sparse tensor (use sparse matrix multiply in decode)
+             original_shape = x.shape
+
+             # Create indices for all dimensions
+             # For each element in topk_indices, we need to map it back to the original tensor coordinates
+             batch_dims = original_shape[:-1]  # All dimensions except the last one
+             num_batch_elements = torch.prod(torch.tensor(batch_dims)).item()
+
+             # Create batch indices - each batch element repeated k times
+             batch_indices_flat = torch.arange(
+                 num_batch_elements, device=x.device
+             ).repeat_interleave(self.k)
+
+             # Convert flat batch indices back to multi-dimensional indices
+             if len(batch_dims) == 1:
+                 # 2D case: [batch, features]
+                 sparse_indices = torch.stack(
+                     [
+                         batch_indices_flat,
+                         topk_indices.flatten(),
+                     ]
+                 )
+             else:
+                 # 3D+ case: need to unravel the batch indices
+                 batch_indices_multi = []
+                 remaining = batch_indices_flat
+                 for dim_size in reversed(batch_dims):
+                     batch_indices_multi.append(remaining % dim_size)
+                     remaining = remaining // dim_size
+                 batch_indices_multi.reverse()
+
+                 sparse_indices = torch.stack(
+                     [
+                         *batch_indices_multi,
+                         topk_indices.flatten(),
+                     ]
+                 )
+
+             return torch.sparse_coo_tensor(
+                 sparse_indices, values.flatten(), original_shape
+             )
+         result = torch.zeros_like(x)
+         result.scatter_(-1, topk_indices, values)
+         return result
+
+
+ @dataclass
+ class TopKSAEConfig(SAEConfig):
+     """
+     Configuration class for a TopKSAE.
+     """
+
+     k: int = 100
+
+     @override
+     @classmethod
+     def architecture(cls) -> str:
+         return "topk"
+
+
+ def _sparse_matmul_nd(
+     sparse_tensor: torch.Tensor, dense_matrix: torch.Tensor
+ ) -> torch.Tensor:
+     """
+     Multiply a sparse tensor of shape [..., d_sae] with a dense matrix of shape [d_sae, d_out]
+     to get a result of shape [..., d_out].
+
+     This function handles sparse tensors with arbitrary batch dimensions by flattening
+     the batch dimensions, performing 2D sparse matrix multiplication, and reshaping back.
+     """
+     original_shape = sparse_tensor.shape
+     batch_dims = original_shape[:-1]
+     d_sae = original_shape[-1]
+     d_out = dense_matrix.shape[-1]
+
+     if sparse_tensor.ndim == 2:
+         # Simple 2D case - use torch.sparse.mm directly
+         # sparse.mm errors with bfloat16 :(
+         with torch.autocast(device_type=sparse_tensor.device.type, enabled=False):
+             return torch.sparse.mm(sparse_tensor, dense_matrix)
+
+     # For 3D+ case, reshape to 2D, multiply, then reshape back
+     batch_size = int(torch.prod(torch.tensor(batch_dims)).item())
+
+     # Ensure tensor is coalesced for efficient access to indices/values
+     if not sparse_tensor.is_coalesced():
+         sparse_tensor = sparse_tensor.coalesce()
+
+     # Get indices and values
+     indices = sparse_tensor.indices()  # [ndim, nnz]
+     values = sparse_tensor.values()  # [nnz]
+
+     # Convert multi-dimensional batch indices to flat indices
+     flat_batch_indices = torch.zeros_like(indices[0])
+     multiplier = 1
+     for i in reversed(range(len(batch_dims))):
+         flat_batch_indices += indices[i] * multiplier
+         multiplier *= batch_dims[i]
+
+     # Create 2D sparse tensor indices [batch_flat, feature]
+     sparse_2d_indices = torch.stack([flat_batch_indices, indices[-1]])
+
+     # Create 2D sparse tensor
+     sparse_2d = torch.sparse_coo_tensor(
+         sparse_2d_indices, values, (batch_size, d_sae)
+     ).coalesce()
+
+     # sparse.mm errors with bfloat16 :(
+     with torch.autocast(device_type=sparse_tensor.device.type, enabled=False):
+         # Do the matrix multiplication
+         result_2d = torch.sparse.mm(sparse_2d, dense_matrix)  # [batch_size, d_out]
+
+     # Reshape back to original batch dimensions
+     result_shape = tuple(batch_dims) + (d_out,)
+     return result_2d.view(result_shape)
+
+
+ class TopKSAE(SAE[TopKSAEConfig]):
+     """
+     An inference-only sparse autoencoder using a "topk" activation function.
+     It uses linear encoder and decoder layers, applying the TopK activation
+     to the hidden pre-activation in its encode step.
+     """
+
+     b_enc: nn.Parameter
+
+     def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
+         """
+         Args:
+             cfg: SAEConfig defining model size and behavior.
+             use_error_term: Whether to apply the error-term approach in the forward pass.
+         """
+         super().__init__(cfg, use_error_term)
+
+     @override
+     def initialize_weights(self) -> None:
+         # Initialize encoder weights and bias.
+         super().initialize_weights()
+         _init_weights_topk(self)
+
+     def encode(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> Float[torch.Tensor, "... d_sae"]:
+         """
+         Converts input x into feature activations.
+         Uses topk activation under the hood.
+         """
+         sae_in = self.process_sae_in(x)
+         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+         # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
+         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+
+     def decode(
+         self,
+         feature_acts: Float[torch.Tensor, "... d_sae"],
+     ) -> Float[torch.Tensor, "... d_in"]:
+         """
+         Reconstructs the input from topk feature activations.
+         Applies optional finetuning scaling, hooking to recons, out normalization,
+         and optional head reshaping.
+         """
+         # Handle sparse tensors using efficient sparse matrix multiplication
+         if feature_acts.is_sparse:
+             sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
+         else:
+             sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+         sae_out_pre = self.hook_sae_recons(sae_out_pre)
+         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+         return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+     @override
+     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+         return TopK(self.cfg.k, use_sparse_activations=False)
+
+     @override
+     @torch.no_grad()
+     def fold_W_dec_norm(self) -> None:
+         raise NotImplementedError(
+             "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
+         )
+
+
+ @dataclass
+ class TopKTrainingSAEConfig(TrainingSAEConfig):
+     """
+     Configuration class for training a TopKTrainingSAE.
+
+     Args:
+         k (int): Number of top features to keep active. Only the top k features
+             with the highest pre-activations will be non-zero. Defaults to 100.
+         use_sparse_activations (bool): Whether to use sparse tensor representations
+             for activations during training. This can reduce memory usage and improve
+             performance when k is small relative to d_sae, but is only worthwhile if
+             using float32 and not using autocast. Defaults to False.
+         aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
+             dead neurons to learn useful features. This loss helps prevent neuron death
+             in TopK SAEs by having dead neurons reconstruct the residual error from
+             live neurons. Defaults to 1.0.
+         decoder_init_norm (float | None): Norm to initialize decoder weights to.
+             0.1 corresponds to the "heuristic" initialization from Anthropic's April update.
+             Use None to disable. Inherited from TrainingSAEConfig. Defaults to 0.1.
+         d_in (int): Input dimension (dimensionality of the activations being encoded).
+             Inherited from SAEConfig.
+         d_sae (int): SAE latent dimension (number of features in the SAE).
+             Inherited from SAEConfig.
+         dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+             Defaults to "float32".
+         device (str): Device to place the SAE on. Inherited from SAEConfig.
+             Defaults to "cpu".
+         apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
+             before encoding. Inherited from SAEConfig. Defaults to True.
+         normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
+             Normalization strategy for input activations. Inherited from SAEConfig.
+             Defaults to "none".
+         reshape_activations (Literal["none", "hook_z"]): How to reshape activations
+             (useful for attention head outputs). Inherited from SAEConfig.
+             Defaults to "none".
+         metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.).
+             Inherited from SAEConfig.
+     """
+
+     k: int = 100
+     use_sparse_activations: bool = False
+     aux_loss_coefficient: float = 1.0
+
+     @override
+     @classmethod
+     def architecture(cls) -> str:
+         return "topk"
+
+
+ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
+     """
+     TopK variant with training functionality. Calculates a topk-related auxiliary loss, etc.
+     """
+
+     b_enc: nn.Parameter
+
+     def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
+         super().__init__(cfg, use_error_term)
+         self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae)
+         self.setup()
+
+     @override
+     def initialize_weights(self) -> None:
+         super().initialize_weights()
+         _init_weights_topk(self)
+
+     def encode_with_hidden_pre(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+         """
+         Similar to the base training method: calculate pre-activations, then apply TopK.
+         """
+         sae_in = self.process_sae_in(x)
+         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
+
+         # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
+         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+         return feature_acts, hidden_pre
+
+     @override
+     def decode(
+         self,
+         feature_acts: Float[torch.Tensor, "... d_sae"],
+     ) -> Float[torch.Tensor, "... d_in"]:
+         """
+         Decodes feature activations back into input space,
+         applying optional finetuning scale, hooking, out normalization, etc.
+         """
+         # Handle sparse tensors using efficient sparse matrix multiplication
+         if feature_acts.is_sparse:
+             sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
+         else:
+             sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+         sae_out_pre = self.hook_sae_recons(sae_out_pre)
+         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+         return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+     @override
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass through the SAE."""
+         feature_acts = self.encode(x)
+         sae_out = self.decode(feature_acts)
+
+         if self.use_error_term:
+             with torch.no_grad():
+                 # Recompute without hooks for true error term
+                 with _disable_hooks(self):
+                     feature_acts_clean = self.encode(x)
+                     x_reconstruct_clean = self.decode(feature_acts_clean)
+                 sae_error = self.hook_sae_error(x - x_reconstruct_clean)
+             sae_out = sae_out + sae_error
+
+         return self.hook_sae_output(sae_out)
+
+     @override
+     def calculate_aux_loss(
+         self,
+         step_input: TrainStepInput,
+         feature_acts: torch.Tensor,
+         hidden_pre: torch.Tensor,
+         sae_out: torch.Tensor,
+     ) -> dict[str, torch.Tensor]:
+         # Calculate the auxiliary loss for dead neurons
+         topk_loss = self.calculate_topk_aux_loss(
+             sae_in=step_input.sae_in,
+             sae_out=sae_out,
+             hidden_pre=hidden_pre,
+             dead_neuron_mask=step_input.dead_neuron_mask,
+         )
+         return {"auxiliary_reconstruction_loss": topk_loss}
+
+     @override
+     @torch.no_grad()
+     def fold_W_dec_norm(self) -> None:
+         raise NotImplementedError(
+             "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
+         )
+
+     @override
+     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+         return TopK(self.cfg.k, use_sparse_activations=self.cfg.use_sparse_activations)
+
+     @override
+     def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
+         return {}
+
+     def calculate_topk_aux_loss(
+         self,
+         sae_in: torch.Tensor,
+         sae_out: torch.Tensor,
+         hidden_pre: torch.Tensor,
+         dead_neuron_mask: torch.Tensor | None,
+     ) -> torch.Tensor:
+         """
+         Calculate TopK auxiliary loss.
+
+         This auxiliary loss encourages dead neurons to learn useful features by having
+         them reconstruct the residual error from the live neurons. It's a key part of
+         preventing neuron death in TopK SAEs.
+         """
+         # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
+         # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
+         if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
+             return sae_out.new_tensor(0.0)
+         residual = (sae_in - sae_out).detach()
+
+         # Heuristic from Appendix B.1 in the paper
+         k_aux = sae_in.shape[-1] // 2
+
+         # Reduce the scale of the loss if there are a small number of dead latents
+         scale = min(num_dead / k_aux, 1.0)
+         k_aux = min(k_aux, num_dead)
+
+         auxk_acts = _calculate_topk_aux_acts(
+             k_aux=k_aux,
+             hidden_pre=hidden_pre,
+             dead_neuron_mask=dead_neuron_mask,
+         )
+
+         # Encourage the top ~50% of dead latents to predict the residual of the
+         # top k living latents
+         recons = self.decode(auxk_acts)
+         auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
+         return self.cfg.aux_loss_coefficient * scale * auxk_loss
+
+
+ def _calculate_topk_aux_acts(
+     k_aux: int,
+     hidden_pre: torch.Tensor,
+     dead_neuron_mask: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Helper method to calculate activations for the auxiliary loss.
+
+     Args:
+         k_aux: Number of top dead neurons to select
+         hidden_pre: Pre-activation values from encoder
+         dead_neuron_mask: Boolean mask indicating which neurons are dead
+
+     Returns:
+         Tensor with activations for only the top-k dead neurons, zeros elsewhere
+     """
+
+     # Don't include living latents in this loss
+     auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
+     # Top-k dead latents
+     auxk_topk = auxk_latents.topk(k_aux, sorted=False)
+     # Set the activations to zero for all but the top k_aux dead latents
+     auxk_acts = torch.zeros_like(hidden_pre)
+     auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
+     # Set activations to zero for all but top k_aux dead latents
+     return auxk_acts
+
+
+ def _init_weights_topk(
+     sae: SAE[TopKSAEConfig] | TrainingSAE[TopKTrainingSAEConfig],
+ ) -> None:
+     sae.b_enc = nn.Parameter(
+         torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+     )
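The core idea of the new sparse path above: when use_sparse_activations is enabled, the TopK activation emits a COO sparse tensor and decode multiplies it with torch.sparse.mm instead of a dense matmul. A standalone sketch of the 2D case under assumed toy shapes (not the sae_lens API itself):

    import torch

    batch, d_sae, d_in, k = 4, 16, 8, 2
    hidden_pre = torch.randn(batch, d_sae)
    W_dec = torch.randn(d_sae, d_in)

    # Keep only the top-k pre-activations per row, ReLU'd.
    topk_values, topk_indices = torch.topk(hidden_pre, k=k, dim=-1, sorted=False)
    values = topk_values.relu()

    # (row, feature) coordinates for every kept value -> COO tensor of shape (batch, d_sae).
    rows = torch.arange(batch).repeat_interleave(k)
    indices = torch.stack([rows, topk_indices.flatten()])
    feature_acts = torch.sparse_coo_tensor(indices, values.flatten(), (batch, d_sae))

    # Sparse x dense matmul recovers the dense reconstruction (pre-bias).
    recon = torch.sparse.mm(feature_acts.coalesce(), W_dec)  # (batch, d_in)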
{sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/tokenization_and_batching.py
@@ -1,4 +1,4 @@
- from typing import Generator, Iterator
+ from collections.abc import Generator, Iterator

  import torch

@@ -68,7 +68,7 @@ def concat_and_batch_sequences(
  ) -> Generator[torch.Tensor, None, None]:
  """
  Generator to concat token sequences together from the tokens_interator, yielding
- batches of size `context_size`.
+ sequences of size `context_size`. Batching across the batch dimension is handled by the caller.

  Args:
      tokens_iterator: An iterator which returns a 1D tensors of tokens
@@ -76,13 +76,28 @@ def concat_and_batch_sequences(
      begin_batch_token_id: If provided, this token will be at position 0 of each batch
      begin_sequence_token_id: If provided, this token will be the first token of each sequence
      sequence_separator_token_id: If provided, this token will be inserted between concatenated sequences
-     disable_concat_sequences: If True, disable concatenating sequences and ignore sequences shorter than context_size
+     disable_concat_sequences: If True, disable concatenating sequences and ignore sequences shorter than context_size (including BOS token if present)
      max_batches: If not provided, the iterator will be run to completion.
  """
  if disable_concat_sequences:
-     for tokens in tokens_iterator:
-         if len(tokens) >= context_size:
-             yield tokens[:context_size]
+     if begin_batch_token_id and not begin_sequence_token_id:
+         begin_sequence_token_id = begin_batch_token_id
+     for sequence in tokens_iterator:
+         if (
+             begin_sequence_token_id is not None
+             and sequence[0] != begin_sequence_token_id
+             and len(sequence) >= context_size - 1
+         ):
+             begin_sequence_token_id_tensor = torch.tensor(
+                 [begin_sequence_token_id],
+                 dtype=torch.long,
+                 device=sequence.device,
+             )
+             sequence = torch.cat(
+                 [begin_sequence_token_id_tensor, sequence[: context_size - 1]]
+             )
+         if len(sequence) >= context_size:
+             yield sequence[:context_size]
      return

  batch: torch.Tensor | None = None
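The behavioural change in the disable_concat_sequences branch: a sequence that does not already start with the BOS/begin-sequence token gets it prepended and is truncated to context_size, and anything still shorter than context_size is dropped. A toy illustration of that logic (the token IDs are arbitrary, not taken from the library):

    import torch

    context_size, bos = 4, 50256
    sequence = torch.tensor([11, 12, 13, 14, 15])  # no BOS at position 0

    if sequence[0] != bos and len(sequence) >= context_size - 1:
        sequence = torch.cat(
            [torch.tensor([bos], dtype=torch.long), sequence[: context_size - 1]]
        )
    if len(sequence) >= context_size:
        print(sequence[:context_size])  # tensor([50256, 11, 12, 13])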
{sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/sae_trainer.py
@@ -253,12 +253,14 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
  )

  with torch.no_grad():
-     did_fire = (train_step_output.feature_acts > 0).float().sum(-2) > 0
+     # calling .bool() should be equivalent to .abs() > 0, and work with coo tensors
+     firing_feats = train_step_output.feature_acts.bool().float()
+     did_fire = firing_feats.sum(-2).bool()
+     if did_fire.is_sparse:
+         did_fire = did_fire.to_dense()
      self.n_forward_passes_since_fired += 1
      self.n_forward_passes_since_fired[did_fire] = 0
-     self.act_freq_scores += (
-         (train_step_output.feature_acts.abs() > 0).float().sum(0)
-     )
+     self.act_freq_scores += firing_feats.sum(0)
      self.n_frac_active_samples += self.cfg.train_batch_size_samples

  # Grad scaler will rescale gradients if autocast is enabled
@@ -310,7 +312,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
  loss = output.loss.item()

  # metrics for currents acts
- l0 = (feature_acts > 0).float().sum(-1).mean()
+ l0 = feature_acts.bool().float().sum(-1).to_dense().mean()
  current_learning_rate = self.optimizer.param_groups[0]["lr"]

  per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=-1).squeeze()
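The trainer now derives both firing statistics from one boolean mask, which also works when feature_acts is a sparse COO tensor (the sparse case is densified right after, as in the hunk above). A dense-only sketch with assumed toy shapes:

    import torch

    feature_acts = torch.tensor([[0.0, 1.5, 0.0],
                                 [0.0, 0.0, 2.0]])  # (batch, d_sae)

    firing_feats = feature_acts.bool().float()  # 1.0 wherever a feature fired
    did_fire = firing_feats.sum(-2).bool()      # per-feature: fired anywhere in the batch?
    act_freq_scores = firing_feats.sum(0)       # per-feature firing counts

    print(did_fire)         # tensor([False,  True,  True])
    print(act_freq_scores)  # tensor([0., 1., 1.])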
{sae_lens-6.12.3 → sae_lens-6.13.1}/sae_lens/training/types.py
@@ -1,4 +1,4 @@
- from typing import Iterator
+ from collections.abc import Iterator

  import torch

sae_lens-6.12.3/sae_lens/saes/topk_sae.py (removed)
@@ -1,271 +0,0 @@
- """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
-
- from dataclasses import dataclass
- from typing import Callable
-
- import torch
- from jaxtyping import Float
- from torch import nn
- from typing_extensions import override
-
- from sae_lens.saes.sae import (
-     SAE,
-     SAEConfig,
-     TrainCoefficientConfig,
-     TrainingSAE,
-     TrainingSAEConfig,
-     TrainStepInput,
- )
-
-
- class TopK(nn.Module):
-     """
-     A simple TopK activation that zeroes out all but the top K elements along the last dimension,
-     and applies ReLU to the top K elements.
-     """
-
-     b_enc: nn.Parameter
-
-     def __init__(
-         self,
-         k: int,
-     ):
-         super().__init__()
-         self.k = k
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         1) Select top K elements along the last dimension.
-         2) Apply ReLU.
-         3) Zero out all other entries.
-         """
-         topk = torch.topk(x, k=self.k, dim=-1)
-         values = topk.values.relu()
-         result = torch.zeros_like(x)
-         result.scatter_(-1, topk.indices, values)
-         return result
-
-
- @dataclass
- class TopKSAEConfig(SAEConfig):
-     """
-     Configuration class for a TopKSAE.
-     """
-
-     k: int = 100
-
-     @override
-     @classmethod
-     def architecture(cls) -> str:
-         return "topk"
-
-
- class TopKSAE(SAE[TopKSAEConfig]):
-     """
-     An inference-only sparse autoencoder using a "topk" activation function.
-     It uses linear encoder and decoder layers, applying the TopK activation
-     to the hidden pre-activation in its encode step.
-     """
-
-     b_enc: nn.Parameter
-
-     def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
-         """
-         Args:
-             cfg: SAEConfig defining model size and behavior.
-             use_error_term: Whether to apply the error-term approach in the forward pass.
-         """
-         super().__init__(cfg, use_error_term)
-
-     @override
-     def initialize_weights(self) -> None:
-         # Initialize encoder weights and bias.
-         super().initialize_weights()
-         _init_weights_topk(self)
-
-     def encode(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> Float[torch.Tensor, "... d_sae"]:
-         """
-         Converts input x into feature activations.
-         Uses topk activation under the hood.
-         """
-         sae_in = self.process_sae_in(x)
-         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
-         # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk.
-         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
-
-     def decode(
-         self, feature_acts: Float[torch.Tensor, "... d_sae"]
-     ) -> Float[torch.Tensor, "... d_in"]:
-         """
-         Reconstructs the input from topk feature activations.
-         Applies optional finetuning scaling, hooking to recons, out normalization,
-         and optional head reshaping.
-         """
-         sae_out_pre = feature_acts @ self.W_dec + self.b_dec
-         sae_out_pre = self.hook_sae_recons(sae_out_pre)
-         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
-         return self.reshape_fn_out(sae_out_pre, self.d_head)
-
-     @override
-     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
-         return TopK(self.cfg.k)
-
-     @override
-     @torch.no_grad()
-     def fold_W_dec_norm(self) -> None:
-         raise NotImplementedError(
-             "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
-         )
-
-
- @dataclass
- class TopKTrainingSAEConfig(TrainingSAEConfig):
-     """
-     Configuration class for training a TopKTrainingSAE.
-     """
-
-     k: int = 100
-     aux_loss_coefficient: float = 1.0
-
-     @override
-     @classmethod
-     def architecture(cls) -> str:
-         return "topk"
-
-
- class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
-     """
-     TopK variant with training functionality. Calculates a topk-related auxiliary loss, etc.
-     """
-
-     b_enc: nn.Parameter
-
-     def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
-         super().__init__(cfg, use_error_term)
-
-     @override
-     def initialize_weights(self) -> None:
-         super().initialize_weights()
-         _init_weights_topk(self)
-
-     def encode_with_hidden_pre(
-         self, x: Float[torch.Tensor, "... d_in"]
-     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
-         """
-         Similar to the base training method: calculate pre-activations, then apply TopK.
-         """
-         sae_in = self.process_sae_in(x)
-         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
-
-         # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
-         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
-         return feature_acts, hidden_pre
-
-     @override
-     def calculate_aux_loss(
-         self,
-         step_input: TrainStepInput,
-         feature_acts: torch.Tensor,
-         hidden_pre: torch.Tensor,
-         sae_out: torch.Tensor,
-     ) -> dict[str, torch.Tensor]:
-         # Calculate the auxiliary loss for dead neurons
-         topk_loss = self.calculate_topk_aux_loss(
-             sae_in=step_input.sae_in,
-             sae_out=sae_out,
-             hidden_pre=hidden_pre,
-             dead_neuron_mask=step_input.dead_neuron_mask,
-         )
-         return {"auxiliary_reconstruction_loss": topk_loss}
-
-     @override
-     @torch.no_grad()
-     def fold_W_dec_norm(self) -> None:
-         raise NotImplementedError(
-             "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
-         )
-
-     @override
-     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
-         return TopK(self.cfg.k)
-
-     @override
-     def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
-         return {}
-
-     def calculate_topk_aux_loss(
-         self,
-         sae_in: torch.Tensor,
-         sae_out: torch.Tensor,
-         hidden_pre: torch.Tensor,
-         dead_neuron_mask: torch.Tensor | None,
-     ) -> torch.Tensor:
-         """
-         Calculate TopK auxiliary loss.
-
-         This auxiliary loss encourages dead neurons to learn useful features by having
-         them reconstruct the residual error from the live neurons. It's a key part of
-         preventing neuron death in TopK SAEs.
-         """
-         # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
-         # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
-         if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
-             return sae_out.new_tensor(0.0)
-         residual = (sae_in - sae_out).detach()
-
-         # Heuristic from Appendix B.1 in the paper
-         k_aux = sae_in.shape[-1] // 2
-
-         # Reduce the scale of the loss if there are a small number of dead latents
-         scale = min(num_dead / k_aux, 1.0)
-         k_aux = min(k_aux, num_dead)
-
-         auxk_acts = _calculate_topk_aux_acts(
-             k_aux=k_aux,
-             hidden_pre=hidden_pre,
-             dead_neuron_mask=dead_neuron_mask,
-         )
-
-         # Encourage the top ~50% of dead latents to predict the residual of the
-         # top k living latents
-         recons = self.decode(auxk_acts)
-         auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
-         return self.cfg.aux_loss_coefficient * scale * auxk_loss
-
-
- def _calculate_topk_aux_acts(
-     k_aux: int,
-     hidden_pre: torch.Tensor,
-     dead_neuron_mask: torch.Tensor,
- ) -> torch.Tensor:
-     """
-     Helper method to calculate activations for the auxiliary loss.
-
-     Args:
-         k_aux: Number of top dead neurons to select
-         hidden_pre: Pre-activation values from encoder
-         dead_neuron_mask: Boolean mask indicating which neurons are dead
-
-     Returns:
-         Tensor with activations for only the top-k dead neurons, zeros elsewhere
-     """
-
-     # Don't include living latents in this loss
-     auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
-     # Top-k dead latents
-     auxk_topk = auxk_latents.topk(k_aux, sorted=False)
-     # Set the activations to zero for all but the top k_aux dead latents
-     auxk_acts = torch.zeros_like(hidden_pre)
-     auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
-     # Set activations to zero for all but top k_aux dead latents
-     return auxk_acts
-
-
- def _init_weights_topk(
-     sae: SAE[TopKSAEConfig] | TrainingSAE[TopKTrainingSAEConfig],
- ) -> None:
-     sae.b_enc = nn.Parameter(
-         torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
-     )