sae-lens 6.14.2__py3-none-any.whl → 6.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.14.2"
2
+ __version__ = "6.15.0"
3
3
 
4
4
  import logging
5
5
 
@@ -19,6 +19,8 @@ from sae_lens.saes import (
19
19
  JumpReLUTrainingSAEConfig,
20
20
  JumpReLUTranscoder,
21
21
  JumpReLUTranscoderConfig,
22
+ MatryoshkaBatchTopKTrainingSAE,
23
+ MatryoshkaBatchTopKTrainingSAEConfig,
22
24
  SAEConfig,
23
25
  SkipTranscoder,
24
26
  SkipTranscoderConfig,
@@ -101,6 +103,8 @@ __all__ = [
101
103
  "SkipTranscoderConfig",
102
104
  "JumpReLUTranscoder",
103
105
  "JumpReLUTranscoderConfig",
106
+ "MatryoshkaBatchTopKTrainingSAE",
107
+ "MatryoshkaBatchTopKTrainingSAEConfig",
104
108
  ]
105
109
 
106
110
 
@@ -115,6 +119,11 @@ register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAE
115
119
  register_sae_training_class(
116
120
  "batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
117
121
  )
122
+ register_sae_training_class(
123
+ "matryoshka_batchtopk",
124
+ MatryoshkaBatchTopKTrainingSAE,
125
+ MatryoshkaBatchTopKTrainingSAEConfig,
126
+ )
118
127
  register_sae_class("transcoder", Transcoder, TranscoderConfig)
119
128
  register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
120
129
  register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
@@ -22,17 +22,13 @@ from sae_lens.constants import (
22
22
  )
23
23
  from sae_lens.evals import EvalConfig, run_evals
24
24
  from sae_lens.load_model import load_model
25
- from sae_lens.saes.batchtopk_sae import BatchTopKTrainingSAEConfig
26
- from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
27
- from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
25
+ from sae_lens.registry import SAE_TRAINING_CLASS_REGISTRY
28
26
  from sae_lens.saes.sae import (
29
27
  T_TRAINING_SAE,
30
28
  T_TRAINING_SAE_CONFIG,
31
29
  TrainingSAE,
32
30
  TrainingSAEConfig,
33
31
  )
34
- from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
35
- from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
36
32
  from sae_lens.training.activation_scaler import ActivationScaler
37
33
  from sae_lens.training.activations_store import ActivationsStore
38
34
  from sae_lens.training.sae_trainer import SAETrainer
@@ -395,12 +391,8 @@ def _parse_cfg_args(
395
391
  )
396
392
 
397
393
  # Map architecture to concrete config class
398
- sae_config_map = {
399
- "standard": StandardTrainingSAEConfig,
400
- "gated": GatedTrainingSAEConfig,
401
- "jumprelu": JumpReLUTrainingSAEConfig,
402
- "topk": TopKTrainingSAEConfig,
403
- "batchtopk": BatchTopKTrainingSAEConfig,
394
+ sae_config_map: dict[str, type[TrainingSAEConfig]] = {
395
+ name: cfg for name, (_, cfg) in SAE_TRAINING_CLASS_REGISTRY.items()
404
396
  }
405
397
 
406
398
  sae_config_type = sae_config_map[architecture]
sae_lens/saes/__init__.py CHANGED
@@ -14,6 +14,10 @@ from .jumprelu_sae import (
14
14
  JumpReLUTrainingSAE,
15
15
  JumpReLUTrainingSAEConfig,
16
16
  )
17
+ from .matryoshka_batchtopk_sae import (
18
+ MatryoshkaBatchTopKTrainingSAE,
19
+ MatryoshkaBatchTopKTrainingSAEConfig,
20
+ )
17
21
  from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
18
22
  from .standard_sae import (
19
23
  StandardSAE,
@@ -65,4 +69,6 @@ __all__ = [
65
69
  "SkipTranscoderConfig",
66
70
  "JumpReLUTranscoder",
67
71
  "JumpReLUTranscoderConfig",
72
+ "MatryoshkaBatchTopKTrainingSAE",
73
+ "MatryoshkaBatchTopKTrainingSAEConfig",
68
74
  ]
@@ -0,0 +1,143 @@
1
+ import warnings
2
+ from dataclasses import dataclass, field
3
+
4
+ import torch
5
+ from jaxtyping import Float
6
+ from typing_extensions import override
7
+
8
+ from sae_lens.saes.batchtopk_sae import (
9
+ BatchTopKTrainingSAE,
10
+ BatchTopKTrainingSAEConfig,
11
+ )
12
+ from sae_lens.saes.sae import TrainStepInput, TrainStepOutput
13
+ from sae_lens.saes.topk_sae import _sparse_matmul_nd
14
+
15
+
16
@dataclass
class MatryoshkaBatchTopKTrainingSAEConfig(BatchTopKTrainingSAEConfig):
    """
    Configuration class for training a MatryoshkaBatchTopKTrainingSAE.

    [Matryoshka SAEs](https://arxiv.org/pdf/2503.17547) use a series of nested reconstruction
    losses of different widths during training to avoid feature absorption. This also has a
    nice side-effect of encouraging higher-frequency features to be learned in earlier levels.
    However, this SAE has more hyperparameters to tune than standard BatchTopK SAEs, and takes
    longer to train due to requiring multiple forward passes per training step.

    After training, MatryoshkaBatchTopK SAEs are saved as JumpReLU SAEs.

    Args:
        matryoshka_widths (list[int]): The widths of the matryoshka levels. Each width is the
            number of leading SAE latents used by that level's reconstruction loss.
            Defaults to an empty list. The list is checked when the SAE is constructed:
            widths must be strictly increasing and the smallest width must be at least ``k``.
            If the last width is not ``d_sae``, a final level of width ``d_sae`` is appended
            to this list in place and a warning is emitted.
        k (float): The number of features to keep active. Inherited from BatchTopKTrainingSAEConfig.
            Defaults to 100.
        topk_threshold_lr (float): Learning rate for updating the global topk threshold.
            The threshold is updated using an exponential moving average of the minimum
            positive activation value. Defaults to 0.01.
        aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages
            dead neurons to learn useful features. Inherited from TopKTrainingSAEConfig.
            Defaults to 1.0.
        rescale_acts_by_decoder_norm (bool): Treat the decoder as if it was already normalized.
            Inherited from TopKTrainingSAEConfig. Defaults to True.
        decoder_init_norm (float | None): Norm to initialize decoder weights to.
            Inherited from TrainingSAEConfig. Defaults to 0.1.
        d_in (int): Input dimension (dimensionality of the activations being encoded).
            Inherited from SAEConfig.
        d_sae (int): SAE latent dimension (number of features in the SAE).
            Inherited from SAEConfig.
        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
            Defaults to "float32".
        device (str): Device to place the SAE on. Inherited from SAEConfig.
            Defaults to "cpu".
    """

    # default_factory gives every config instance its own list; the list is
    # later mutated in place when the SAE validates its config.
    matryoshka_widths: list[int] = field(default_factory=list)

    @override
    @classmethod
    def architecture(cls) -> str:
        """Return the architecture name, "matryoshka_batchtopk".

        This is the same string the package passes to
        ``register_sae_training_class`` for this SAE/config pair.
        """
        return "matryoshka_batchtopk"
59
+
60
+
61
class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
    """
    Matryoshka BatchTopK Training SAE

    A BatchTopKTrainingSAE whose training loss is extended with nested
    ("matryoshka") reconstruction terms. For every width in
    ``cfg.matryoshka_widths`` except the last one, the first ``width``
    pre-activations are passed through ``self.activation_fn`` again, decoded
    with the first ``width`` decoder rows, and the resulting MSE against the
    SAE input is added to the total loss. The last width is the full SAE, whose
    reconstruction loss is already produced by the parent class.

    The config is validated (and possibly extended with a final ``d_sae``
    level) at construction time.
    """

    cfg: MatryoshkaBatchTopKTrainingSAEConfig  # type: ignore[assignment]

    def __init__(
        self, cfg: MatryoshkaBatchTopKTrainingSAEConfig, use_error_term: bool = False
    ):
        """Build the underlying BatchTopK SAE, then validate the matryoshka widths.

        Args:
            cfg: Training config; its ``matryoshka_widths`` list may be extended
                in place by the validation step.
            use_error_term: Forwarded unchanged to the parent constructor.
        """
        super().__init__(cfg, use_error_term)
        _validate_matryoshka_config(cfg)

    @override
    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
        """Run the parent training step and add one MSE loss per inner level.

        Each inner loss is stored in ``losses`` under ``inner_mse_loss_{width}``
        and also summed into the returned ``loss``.
        """
        output = super().training_forward_pass(step_input)
        full_hidden_pre = output.hidden_pre
        # Computed once and sliced per level inside the decode helper.
        inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)

        # Skip the final width: it is the full SAE, already covered by the
        # parent's reconstruction loss.
        inner_widths = self.cfg.matryoshka_widths[:-1]
        for level_width in inner_widths:
            level_acts = self.activation_fn(full_hidden_pre[:, :level_width])
            level_recon = self._decode_matryoshka_level(
                level_acts, level_width, inv_W_dec_norm
            )
            per_sample_error = self.mse_loss_fn(level_recon, step_input.sae_in).sum(
                dim=-1
            )
            level_loss = per_sample_error.mean()
            output.losses[f"inner_mse_loss_{level_width}"] = level_loss
            output.loss = output.loss + level_loss
        return output

    def _decode_matryoshka_level(
        self,
        feature_acts: Float[torch.Tensor, "... d_sae"],
        width: int,
        inv_W_dec_norm: torch.Tensor,
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Decodes feature activations back into input space for a matryoshka level

        Only the first ``width`` rows of ``W_dec`` (and the first ``width``
        entries of ``inv_W_dec_norm``) take part in the reconstruction.
        """
        level_decoder = self.W_dec[:width]
        scaled_acts = feature_acts
        if self.cfg.rescale_acts_by_decoder_norm:
            # Multiply by the reciprocal norm: division is illegal with sparse tensors.
            scaled_acts = scaled_acts * inv_W_dec_norm[:width]

        # Sparse activations go through the dedicated sparse matmul helper.
        if scaled_acts.is_sparse:
            projected = _sparse_matmul_nd(scaled_acts, level_decoder)
        else:
            projected = scaled_acts @ level_decoder

        sae_out_pre = self.run_time_activation_norm_fn_out(projected + self.b_dec)
        return self.reshape_fn_out(sae_out_pre, self.d_head)
120
+
121
+
122
def _validate_matryoshka_config(cfg: "MatryoshkaBatchTopKTrainingSAEConfig") -> None:
    """Validate ``cfg.matryoshka_widths`` and make sure it ends with ``cfg.d_sae``.

    If the list is empty (the dataclass default) or its last entry is not
    ``cfg.d_sae``, a warning is emitted and ``cfg.d_sae`` is appended to the
    list in place. A second warning is emitted when only a single level
    remains, since no inner loss would then be added.

    Args:
        cfg: Config whose ``matryoshka_widths``, ``d_sae`` and ``k`` are read.
            ``matryoshka_widths`` may be mutated in place.

    Raises:
        ValueError: If the widths are not strictly increasing, or if the
            smallest width is smaller than ``cfg.k``.
    """
    widths = cfg.matryoshka_widths

    # An empty list has no last element; treat it like a missing final level
    # instead of failing with IndexError on widths[-1].
    if not widths or widths[-1] != cfg.d_sae:
        # warn the users that we will add a final matryoshka level
        warnings.warn(
            "WARNING: The final matryoshka level width is not set to cfg.d_sae. "
            "A final matryoshka level of width=cfg.d_sae will be added."
        )
        # Append in place (rather than rebinding) so every holder of this list
        # observes the added level.
        widths.append(cfg.d_sae)

    if any(prev >= curr for prev, curr in zip(widths[:-1], widths[1:])):
        raise ValueError("cfg.matryoshka_widths must be strictly increasing.")

    if len(widths) == 1:
        warnings.warn(
            "WARNING: You have only set one matryoshka level. This is equivalent to using a standard BatchTopK SAE and is likely not what you want."
        )

    if widths[0] < cfg.k:
        raise ValueError(
            "The smallest matryoshka level width cannot be smaller than cfg.k."
        )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.14.2
3
+ Version: 6.15.0
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -1,4 +1,4 @@
1
- sae_lens/__init__.py,sha256=U6PI8XxNzEqTakvBsTnn6i8EvoMbpcviRffWBke2frk,3589
1
+ sae_lens/__init__.py,sha256=ab8Lj2QJE3i1uOP_4B9LLh_vCgi__3XXx66_eO8rcrA,3886
2
2
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  sae_lens/analysis/hooked_sae_transformer.py,sha256=vRu6JseH1lZaEeILD5bEkQEQ1wYHHDcxD-f2olKmE9Y,14275
4
4
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
@@ -6,7 +6,7 @@ sae_lens/cache_activations_runner.py,sha256=cNeAtp2JQ_vKbeddZVM-tcPLYyyfTWL8NDna
6
6
  sae_lens/config.py,sha256=IdRXSKPfYY3hwUovj-u83eep8z52gkJHII0mY0KseYY,28739
7
7
  sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
8
8
  sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
9
- sae_lens/llm_sae_training_runner.py,sha256=8Km519LH080RZnUBeaG2T1trq5UqxoAqokNmpX4xMTM,15200
9
+ sae_lens/llm_sae_training_runner.py,sha256=UHRcLqvtnORsZ7u7ymbrv-Ib2BD84czHBvu03jNbtcE,14834
10
10
  sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
11
11
  sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  sae_lens/loading/pretrained_sae_loaders.py,sha256=SM4aT8NM6ezYix5c2u7p72Fz2RfvTtf7gw5RdOSKXhc,49846
@@ -14,10 +14,11 @@ sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gk
14
14
  sae_lens/pretokenize_runner.py,sha256=x-reJzVPFDS9iRFbZtrFYSzNguJYki9gd0pbHjYJ3r4,7085
15
15
  sae_lens/pretrained_saes.yaml,sha256=6ca3geEB6NyhULUrmdtPDK8ea0YdpLp8_au78vIFC5w,602553
16
16
  sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
- sae_lens/saes/__init__.py,sha256=jVwazK8Q6dW5J6_zFXPoNAuBvSxgziQ8eMOjGM3t-X8,1475
17
+ sae_lens/saes/__init__.py,sha256=sIfZUxZ4m3HPtPymCJlpBEofiCrL8_QziE6ChS-v4lE,1677
18
18
  sae_lens/saes/batchtopk_sae.py,sha256=zxIke8lOBKkQEMVFk6sSW6q_s6F9RKhysLqfqG9ecwI,5300
19
19
  sae_lens/saes/gated_sae.py,sha256=qcmM9JwBA8aZR8z_IRHV1_gQX-q_63tKewWXRnhdXuo,8986
20
20
  sae_lens/saes/jumprelu_sae.py,sha256=HHBF1sJ95lZvxwP5vwLSQFKdnJN2KKYK0WAEaLTrta0,13399
21
+ sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=zrS4MksbxdhhftmU3UWjRCWjR7iEBpAk6N00c6GrXks,6291
21
22
  sae_lens/saes/sae.py,sha256=nuII6ZmaVtJWhPjyhasHQyiv_Wj-zdAtRQqJRYbVBQs,38274
22
23
  sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
23
24
  sae_lens/saes/topk_sae.py,sha256=tzQM5eQFifMe--8_8NUBYWY7hpjQa6A_olNe6U71FE8,21275
@@ -33,7 +34,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
33
34
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
34
35
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
35
36
  sae_lens/util.py,sha256=tCovQ-eZa1L7thPpNDL6PGOJrIMML2yLI5e0EHCOpS8,3309
36
- sae_lens-6.14.2.dist-info/METADATA,sha256=WDlgsdDyQT4jmu5hxMU-pqm5PfBh0h65MTEbyuMuE3c,5318
37
- sae_lens-6.14.2.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
38
- sae_lens-6.14.2.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
39
- sae_lens-6.14.2.dist-info/RECORD,,
37
+ sae_lens-6.15.0.dist-info/METADATA,sha256=UmBQ8quUJBWyLclhbnDcXAkL-6jnOW4SbT8_X3rrcbE,5318
38
+ sae_lens-6.15.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
39
+ sae_lens-6.15.0.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
40
+ sae_lens-6.15.0.dist-info/RECORD,,