sae-lens 6.12.1__py3-none-any.whl → 6.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,35 @@
+temporal-sae-gemma-2-2b:
+  conversion_func: temporal
+  model: gemma-2-2b
+  repo_id: canrager/temporalSAEs
+  config_overrides:
+    model_name: gemma-2-2b
+    hook_name: blocks.12.hook_resid_post
+    dataset_path: monology/pile-uncopyrighted
+  saes:
+    - id: blocks.12.hook_resid_post
+      l0: 192
+      norm_scaling_factor: 0.00666666667
+      path: gemma-2-2B/layer_12/temporal
+      neuronpedia: gemma-2-2b/12-temporal-res
+temporal-sae-llama-3.1-8b:
+  conversion_func: temporal
+  model: meta-llama/Llama-3.1-8B
+  repo_id: canrager/temporalSAEs
+  config_overrides:
+    model_name: meta-llama/Llama-3.1-8B
+    dataset_path: monology/pile-uncopyrighted
+  saes:
+    - id: blocks.15.hook_resid_post
+      l0: 256
+      norm_scaling_factor: 0.029
+      path: llama-3.1-8B/layer_15/temporal
+      neuronpedia: llama3.1-8b/15-temporal-res
+    - id: blocks.26.hook_resid_post
+      l0: 256
+      norm_scaling_factor: 0.029
+      path: llama-3.1-8B/layer_26/temporal
+      neuronpedia: llama3.1-8b/26-temporal-res
 deepseek-r1-distill-llama-8b-qresearch:
   conversion_func: deepseek_r1
   model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
@@ -14882,4 +14914,46 @@ qwen2.5-7b-instruct-andyrdt:
       neuronpedia: qwen2.5-7b-it/23-resid-post-aa
     - id: resid_post_layer_27_trainer_1
       path: resid_post_layer_27/trainer_1
-      neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+      neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+
+gpt-oss-20b-andyrdt:
+  conversion_func: dictionary_learning_1
+  model: openai/gpt-oss-20b
+  repo_id: andyrdt/saes-gpt-oss-20b
+  saes:
+    - id: resid_post_layer_3_trainer_0
+      path: resid_post_layer_3/trainer_0
+      neuronpedia: gpt-oss-20b/3-resid-post-aa
+    - id: resid_post_layer_7_trainer_0
+      path: resid_post_layer_7/trainer_0
+      neuronpedia: gpt-oss-20b/7-resid-post-aa
+    - id: resid_post_layer_11_trainer_0
+      path: resid_post_layer_11/trainer_0
+      neuronpedia: gpt-oss-20b/11-resid-post-aa
+    - id: resid_post_layer_15_trainer_0
+      path: resid_post_layer_15/trainer_0
+      neuronpedia: gpt-oss-20b/15-resid-post-aa
+    - id: resid_post_layer_19_trainer_0
+      path: resid_post_layer_19/trainer_0
+      neuronpedia: gpt-oss-20b/19-resid-post-aa
+    - id: resid_post_layer_23_trainer_0
+      path: resid_post_layer_23/trainer_0
+      neuronpedia: gpt-oss-20b/23-resid-post-aa
+
+goodfire-llama-3.3-70b-instruct:
+  conversion_func: goodfire
+  model: meta-llama/Llama-3.3-70B-Instruct
+  repo_id: Goodfire/Llama-3.3-70B-Instruct-SAE-l50
+  saes:
+    - id: layer_50
+      path: Llama-3.3-70B-Instruct-SAE-l50.pt
+      l0: 121
+
+goodfire-llama-3.1-8b-instruct:
+  conversion_func: goodfire
+  model: meta-llama/Llama-3.1-8B-Instruct
+  repo_id: Goodfire/Llama-3.1-8B-Instruct-SAE-l19
+  saes:
+    - id: layer_19
+      path: Llama-3.1-8B-Instruct-SAE-l19.pth
+      l0: 91
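
These registry entries make the new SAEs loadable by release name through the standard loading path. A minimal sketch, assuming the usual `SAE.from_pretrained(release, sae_id)` interface (its signature appears later in this diff); the release and id strings come from the YAML entries above:

```python
# Minimal sketch: load one of the newly registered SAEs by release name and SAE id.
# The strings below are taken from the registry entries added in this version.
from sae_lens import SAE

sae = SAE.from_pretrained(
    release="temporal-sae-gemma-2-2b",   # new TemporalSAE registry entry
    sae_id="blocks.12.hook_resid_post",  # id listed under `saes:` for that entry
    device="cpu",
)
```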
sae_lens/saes/__init__.py CHANGED
@@ -14,6 +14,10 @@ from .jumprelu_sae import (
     JumpReLUTrainingSAE,
     JumpReLUTrainingSAEConfig,
 )
+from .matryoshka_batchtopk_sae import (
+    MatryoshkaBatchTopKTrainingSAE,
+    MatryoshkaBatchTopKTrainingSAEConfig,
+)
 from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
 from .standard_sae import (
     StandardSAE,
@@ -21,6 +25,7 @@ from .standard_sae import (
     StandardTrainingSAE,
     StandardTrainingSAEConfig,
 )
+from .temporal_sae import TemporalSAE, TemporalSAEConfig
 from .topk_sae import (
     TopKSAE,
     TopKSAEConfig,
@@ -65,4 +70,8 @@ __all__ = [
     "SkipTranscoderConfig",
     "JumpReLUTranscoder",
     "JumpReLUTranscoderConfig",
+    "MatryoshkaBatchTopKTrainingSAE",
+    "MatryoshkaBatchTopKTrainingSAEConfig",
+    "TemporalSAE",
+    "TemporalSAEConfig",
 ]
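
With these exports in place, the new classes are importable directly from the `sae_lens.saes` subpackage; a short sketch of the resulting public surface:

```python
# The new SAE variants added in this release, as exported from sae_lens.saes.
from sae_lens.saes import (
    MatryoshkaBatchTopKTrainingSAE,
    MatryoshkaBatchTopKTrainingSAEConfig,
    TemporalSAE,
    TemporalSAEConfig,
)
```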
sae_lens/saes/batchtopk_sae.py CHANGED
@@ -23,7 +23,9 @@ class BatchTopK(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         acts = x.relu()
         flat_acts = acts.flatten()
-        acts_topk_flat = torch.topk(flat_acts, int(self.k * acts.shape[0]), dim=-1)
+        # Calculate total number of samples across all non-feature dimensions
+        num_samples = acts.shape[:-1].numel()
+        acts_topk_flat = torch.topk(flat_acts, int(self.k * num_samples), dim=-1)
         return (
             torch.zeros_like(flat_acts)
             .scatter(-1, acts_topk_flat.indices, acts_topk_flat.values)
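
The fix above matters when the activations have more than one leading dimension. `acts.shape[0]` counts only the batch dimension, so for a `(batch, seq, d_sae)` input the batch-level top-k budget was far too small; `acts.shape[:-1].numel()` counts every non-feature position. An illustrative sketch (shapes are hypothetical):

```python
# Illustration of the BatchTopK budget fix with hypothetical shapes.
import torch

acts = torch.relu(torch.randn(8, 128, 4096))  # (batch=8, seq=128, d_sae=4096)
k = 64

num_samples_old = acts.shape[0]            # 8: ignores the sequence dimension
num_samples_new = acts.shape[:-1].numel()  # 8 * 128 = 1024 positions competing
# The corrected call keeps k active features per position on average:
topk = torch.topk(acts.flatten(), int(k * num_samples_new), dim=-1)
print(num_samples_old, num_samples_new, topk.values.shape)  # 8 1024 torch.Size([65536])
```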
@@ -35,6 +37,35 @@ class BatchTopK(nn.Module):
 class BatchTopKTrainingSAEConfig(TopKTrainingSAEConfig):
     """
     Configuration class for training a BatchTopKTrainingSAE.
+
+    BatchTopK SAEs maintain k active features on average across the entire batch,
+    rather than enforcing k features per sample like standard TopK SAEs. During training,
+    the SAE learns a global threshold that is updated based on the minimum positive
+    activation value. After training, BatchTopK SAEs are saved as JumpReLU SAEs.
+
+    Args:
+        k (float): Average number of features to keep active across the batch. Unlike
+            standard TopK SAEs where k is an integer per sample, this is a float
+            representing the average number of active features across all samples in
+            the batch. Defaults to 100.
+        topk_threshold_lr (float): Learning rate for updating the global topk threshold.
+            The threshold is updated using an exponential moving average of the minimum
+            positive activation value. Defaults to 0.01.
+        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
+            dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
+            Defaults to 1.0.
+        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
+            Inherited from TopKTrainingSAEConfig. Defaults to True.
+        decoder_init_norm (float | None): Norm to initialize decoder weights to.
+            Inherited from TrainingSAEConfig. Defaults to 0.1.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
     """
 
     k: float = 100  # type: ignore[assignment]
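
A minimal configuration sketch based on the fields documented above (values are illustrative, and keyword construction assumes the usual dataclass behaviour of these configs):

```python
# Illustrative BatchTopK training config using the documented fields.
from sae_lens.saes.batchtopk_sae import BatchTopKTrainingSAEConfig

cfg = BatchTopKTrainingSAEConfig(
    d_in=768,                # dimensionality of the hooked activations
    d_sae=16_384,            # number of SAE latents
    k=100.0,                 # average active features per sample across the batch
    topk_threshold_lr=0.01,  # EMA learning rate for the global threshold
)
```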
sae_lens/saes/matryoshka_batchtopk_sae.py ADDED
@@ -0,0 +1,137 @@
+import warnings
+from dataclasses import dataclass, field
+
+import torch
+from jaxtyping import Float
+from typing_extensions import override
+
+from sae_lens.saes.batchtopk_sae import (
+    BatchTopKTrainingSAE,
+    BatchTopKTrainingSAEConfig,
+)
+from sae_lens.saes.sae import TrainStepInput, TrainStepOutput
+from sae_lens.saes.topk_sae import _sparse_matmul_nd
+
+
+@dataclass
+class MatryoshkaBatchTopKTrainingSAEConfig(BatchTopKTrainingSAEConfig):
+    """
+    Configuration class for training a MatryoshkaBatchTopKTrainingSAE.
+
+    [Matryoshka SAEs](https://arxiv.org/pdf/2503.17547) use a series of nested reconstruction
+    losses of different widths during training to avoid feature absorption. This also has a
+    nice side-effect of encouraging higher-frequency features to be learned in earlier levels.
+    However, this SAE has more hyperparameters to tune than standard BatchTopK SAEs, and takes
+    longer to train due to requiring multiple forward passes per training step.
+
+    After training, MatryoshkaBatchTopK SAEs are saved as JumpReLU SAEs.
+
+    Args:
+        matryoshka_widths (list[int]): The widths of the matryoshka levels. Defaults to an empty list.
+        k (float): The number of features to keep active. Inherited from BatchTopKTrainingSAEConfig.
+            Defaults to 100.
+        topk_threshold_lr (float): Learning rate for updating the global topk threshold.
+            The threshold is updated using an exponential moving average of the minimum
+            positive activation value. Defaults to 0.01.
+        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
+            dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
+            Defaults to 1.0.
+        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
+            Inherited from TopKTrainingSAEConfig. Defaults to True.
+        decoder_init_norm (float | None): Norm to initialize decoder weights to.
+            Inherited from TrainingSAEConfig. Defaults to 0.1.
+        d_in (int): Input dimension (dimensionality of the activations being encoded).
+            Inherited from SAEConfig.
+        d_sae (int): SAE latent dimension (number of features in the SAE).
+            Inherited from SAEConfig.
+        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
+            Defaults to "float32".
+        device (str): Device to place the SAE on. Inherited from SAEConfig.
+            Defaults to "cpu".
+    """
+
+    matryoshka_widths: list[int] = field(default_factory=list)
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "matryoshka_batchtopk"
+
+
+class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
+    """
+    Global Batch TopK Training SAE
+
+    This SAE will maintain the k on average across the batch, rather than enforcing the k per-sample as in standard TopK.
+
+    BatchTopK SAEs are saved as JumpReLU SAEs after training.
+    """
+
+    cfg: MatryoshkaBatchTopKTrainingSAEConfig  # type: ignore[assignment]
+
+    def __init__(
+        self, cfg: MatryoshkaBatchTopKTrainingSAEConfig, use_error_term: bool = False
+    ):
+        super().__init__(cfg, use_error_term)
+        _validate_matryoshka_config(cfg)
+
+    @override
+    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
+        base_output = super().training_forward_pass(step_input)
+        inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
+        # the outer matryoshka level is the base SAE, so we don't need to add an extra loss for it
+        for width in self.cfg.matryoshka_widths[:-1]:
+            inner_reconstruction = self._decode_matryoshka_level(
+                base_output.feature_acts, width, inv_W_dec_norm
+            )
+            inner_mse_loss = (
+                self.mse_loss_fn(inner_reconstruction, step_input.sae_in)
+                .sum(dim=-1)
+                .mean()
+            )
+            base_output.losses[f"inner_mse_loss_{width}"] = inner_mse_loss
+            base_output.loss = base_output.loss + inner_mse_loss
+        return base_output
+
+    def _decode_matryoshka_level(
+        self,
+        feature_acts: Float[torch.Tensor, "... d_sae"],
+        width: int,
+        inv_W_dec_norm: torch.Tensor,
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """
+        Decodes feature activations back into input space for a matryoshka level
+        """
+        inner_feature_acts = feature_acts[:, :width]
+        # Handle sparse tensors using efficient sparse matrix multiplication
+        if self.cfg.rescale_acts_by_decoder_norm:
+            # need to multiply by the inverse of the norm because division is illegal with sparse tensors
+            inner_feature_acts = inner_feature_acts * inv_W_dec_norm[:width]
+        if inner_feature_acts.is_sparse:
+            sae_out_pre = (
+                _sparse_matmul_nd(inner_feature_acts, self.W_dec[:width]) + self.b_dec
+            )
+        else:
+            sae_out_pre = inner_feature_acts @ self.W_dec[:width] + self.b_dec
+        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
+        return self.reshape_fn_out(sae_out_pre, self.d_head)
+
+
+def _validate_matryoshka_config(cfg: MatryoshkaBatchTopKTrainingSAEConfig) -> None:
+    if cfg.matryoshka_widths[-1] != cfg.d_sae:
+        # warn the users that we will add a final matryoshka level
+        warnings.warn(
+            "WARNING: The final matryoshka level width is not set to cfg.d_sae. "
+            "A final matryoshka level of width=cfg.d_sae will be added."
+        )
+        cfg.matryoshka_widths.append(cfg.d_sae)
+
+    for prev_width, curr_width in zip(
+        cfg.matryoshka_widths[:-1], cfg.matryoshka_widths[1:]
+    ):
+        if prev_width >= curr_width:
+            raise ValueError("cfg.matryoshka_widths must be strictly increasing.")
+    if len(cfg.matryoshka_widths) == 1:
+        warnings.warn(
+            "WARNING: You have only set one matryoshka level. This is equivalent to using a standard BatchTopK SAE and is likely not what you want."
+        )
sae_lens/saes/sae.py CHANGED
@@ -14,7 +14,6 @@ from typing import (
     Generic,
     Literal,
     NamedTuple,
-    Type,
     TypeVar,
 )
 
@@ -22,7 +21,7 @@ import einops
 import torch
 from jaxtyping import Float
 from numpy.typing import NDArray
-from safetensors.torch import save_file
+from safetensors.torch import load_file, save_file
 from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
 from typing_extensions import deprecated, overload, override
@@ -156,9 +155,9 @@ class SAEConfig(ABC):
     dtype: str = "float32"
     device: str = "cpu"
     apply_b_dec_to_input: bool = True
-    normalize_activations: Literal[
-        "none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"
-    ] = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
+    normalize_activations: Literal["none", "expected_average_only_in", "layer_norm"] = (
+        "none"  # none, expected_average_only_in (Anthropic April Update)
+    )
     reshape_activations: Literal["none", "hook_z"] = "none"
     metadata: SAEMetadata = field(default_factory=SAEMetadata)
 
@@ -218,6 +217,7 @@ class TrainStepInput:
     sae_in: torch.Tensor
     coefficients: dict[str, float]
     dead_neuron_mask: torch.Tensor | None
+    n_training_steps: int
 
 
 class TrainCoefficientConfig(NamedTuple):
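
`TrainStepInput` now also carries the current training step, so a training forward pass can see where it is in the run. A hedged construction sketch (it assumes the container accepts its documented fields as keyword arguments; the tensor and coefficient values are illustrative):

```python
# Illustrative TrainStepInput with the new n_training_steps field.
import torch
from sae_lens.saes.sae import TrainStepInput

step_input = TrainStepInput(
    sae_in=torch.randn(32, 768),  # hypothetical batch of activations
    coefficients={"l1": 1e-3},    # illustrative coefficient name and value
    dead_neuron_mask=None,
    n_training_steps=1_000,       # new field added in this version
)
```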
@@ -245,7 +245,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
         self.cfg = cfg
 
-        if cfg.metadata and cfg.metadata:
+        if cfg.metadata and cfg.metadata.model_from_pretrained_kwargs:
             warnings.warn(
                 "\nThis SAE has non-empty model_from_pretrained_kwargs. "
                 "\nFor optimal performance, load the model like so:\n"
@@ -309,6 +309,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
             self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
             self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+
         elif self.cfg.normalize_activations == "layer_norm":
             # we need to scale the norm of the input and store the scaling factor
             def run_time_activation_ln_in(
@@ -452,23 +453,14 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     def process_sae_in(
         self, sae_in: Float[torch.Tensor, "... d_in"]
     ) -> Float[torch.Tensor, "... d_in"]:
-        # print(f"Input shape to process_sae_in: {sae_in.shape}")
-        # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
-        # print(f"self.b_dec shape: {self.b_dec.shape}")
-        # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")
-
         sae_in = sae_in.to(self.dtype)
-
-        # print(f"Shape before reshape_fn_in: {sae_in.shape}")
         sae_in = self.reshape_fn_in(sae_in)
-        # print(f"Shape after reshape_fn_in: {sae_in.shape}")
 
         sae_in = self.hook_sae_input(sae_in)
         sae_in = self.run_time_activation_norm_fn_in(sae_in)
 
         # Here's where the error happens
         bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
-        # print(f"Bias term shape: {bias_term.shape}")
 
         return sae_in - bias_term
 
@@ -534,7 +526,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     @classmethod
     @deprecated("Use load_from_disk instead")
     def load_from_pretrained(
-        cls: Type[T_SAE],
+        cls: type[T_SAE],
         path: str | Path,
         device: str = "cpu",
         dtype: str | None = None,
@@ -543,7 +535,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
     @classmethod
     def load_from_disk(
-        cls: Type[T_SAE],
+        cls: type[T_SAE],
         path: str | Path,
         device: str = "cpu",
         dtype: str | None = None,
@@ -564,7 +556,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
     @classmethod
     def from_pretrained(
-        cls: Type[T_SAE],
+        cls: type[T_SAE],
         release: str,
         sae_id: str,
         device: str = "cpu",
@@ -585,7 +577,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
     @classmethod
     def from_pretrained_with_cfg_and_sparsity(
-        cls: Type[T_SAE],
+        cls: type[T_SAE],
         release: str,
         sae_id: str,
         device: str = "cpu",
@@ -684,7 +676,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         return sae, cfg_dict, log_sparsities
 
     @classmethod
-    def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
+    def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
         """Create an SAE from a config dictionary."""
         sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
         sae_config_cls = cls.get_sae_config_class_for_architecture(
@@ -694,8 +686,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
     @classmethod
     def get_sae_class_for_architecture(
-        cls: Type[T_SAE], architecture: str
-    ) -> Type[T_SAE]:
+        cls: type[T_SAE], architecture: str
+    ) -> type[T_SAE]:
         """Get the SAE class for a given architecture."""
         sae_cls, _ = get_sae_class(architecture)
         if not issubclass(sae_cls, cls):
@@ -1000,8 +992,8 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
 
     @classmethod
     def get_sae_class_for_architecture(
-        cls: Type[T_TRAINING_SAE], architecture: str
-    ) -> Type[T_TRAINING_SAE]:
+        cls: type[T_TRAINING_SAE], architecture: str
+    ) -> type[T_TRAINING_SAE]:
         """Get the SAE class for a given architecture."""
         sae_cls, _ = get_sae_training_class(architecture)
         if not issubclass(sae_cls, cls):
@@ -1018,6 +1010,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
     ) -> type[TrainingSAEConfig]:
         return get_sae_training_class(architecture)[1]
 
+    def load_weights_from_checkpoint(self, checkpoint_path: Path | str) -> None:
+        checkpoint_path = Path(checkpoint_path)
+        state_dict = load_file(checkpoint_path / SAE_WEIGHTS_FILENAME)
+        self.process_state_dict_for_loading(state_dict)
+        self.load_state_dict(state_dict)
+
 
 _blank_hook = nn.Identity()
 
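
The new `TrainingSAE.load_weights_from_checkpoint` helper reads the saved safetensors weights from a checkpoint directory and loads them into an existing training SAE. A short usage sketch (`training_sae` and the checkpoint path are hypothetical; the directory is expected to contain the file named by `SAE_WEIGHTS_FILENAME`):

```python
# Resume weights from a previously written checkpoint directory (sketch).
training_sae.load_weights_from_checkpoint("checkpoints/step_5000")
```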