sae-lens 6.25.1__py3-none-any.whl → 6.26.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.25.1"
2
+ __version__ = "6.26.1"
3
3
 
4
4
  import logging
5
5
 
@@ -21,6 +21,10 @@ from sae_lens.saes import (
21
21
  JumpReLUTrainingSAEConfig,
22
22
  JumpReLUTranscoder,
23
23
  JumpReLUTranscoderConfig,
24
+ MatchingPursuitSAE,
25
+ MatchingPursuitSAEConfig,
26
+ MatchingPursuitTrainingSAE,
27
+ MatchingPursuitTrainingSAEConfig,
24
28
  MatryoshkaBatchTopKTrainingSAE,
25
29
  MatryoshkaBatchTopKTrainingSAEConfig,
26
30
  SAEConfig,
@@ -113,6 +117,10 @@ __all__ = [
113
117
  "MatryoshkaBatchTopKTrainingSAEConfig",
114
118
  "TemporalSAE",
115
119
  "TemporalSAEConfig",
120
+ "MatchingPursuitSAE",
121
+ "MatchingPursuitTrainingSAE",
122
+ "MatchingPursuitSAEConfig",
123
+ "MatchingPursuitTrainingSAEConfig",
116
124
  ]
117
125
 
118
126
 
@@ -139,3 +147,7 @@ register_sae_class(
139
147
  "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
140
148
  )
141
149
  register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
150
+ register_sae_class("matching_pursuit", MatchingPursuitSAE, MatchingPursuitSAEConfig)
151
+ register_sae_training_class(
152
+ "matching_pursuit", MatchingPursuitTrainingSAE, MatchingPursuitTrainingSAEConfig
153
+ )
sae_lens/config.py CHANGED
@@ -17,6 +17,11 @@ from datasets import (
17
17
  )
18
18
 
19
19
  from sae_lens import __version__, logger
20
+
21
+ # keep this unused import: some packages that depend on SAELens import DTYPE_MAP from sae_lens.config
22
+ from sae_lens.constants import (
23
+ DTYPE_MAP, # noqa: F401 # pyright: ignore[reportUnusedImport]
24
+ )
20
25
  from sae_lens.registry import get_sae_training_class
21
26
  from sae_lens.saes.sae import TrainingSAEConfig
22
27
  from sae_lens.util import str_to_dtype
@@ -9072,150 +9072,150 @@ gemma-scope-2-27b-it-transcoders-all:
9072
9072
  - id: layer_5_width_262k_l0_small_affine
9073
9073
  path: transcoder_all/layer_5_width_262k_l0_small_affine
9074
9074
  l0: 12
9075
- # - id: layer_60_width_16k_l0_big
9076
- # path: transcoder_all/layer_60_width_16k_l0_big
9077
- # l0: 120
9078
- # - id: layer_60_width_16k_l0_big_affine
9079
- # path: transcoder_all/layer_60_width_16k_l0_big_affine
9080
- # l0: 120
9081
- # - id: layer_60_width_16k_l0_small
9082
- # path: transcoder_all/layer_60_width_16k_l0_small
9083
- # l0: 20
9084
- # - id: layer_60_width_16k_l0_small_affine
9085
- # path: transcoder_all/layer_60_width_16k_l0_small_affine
9086
- # l0: 20
9087
- # - id: layer_60_width_262k_l0_big
9088
- # path: transcoder_all/layer_60_width_262k_l0_big
9089
- # l0: 120
9090
- # - id: layer_60_width_262k_l0_big_affine
9091
- # path: transcoder_all/layer_60_width_262k_l0_big_affine
9092
- # l0: 120
9093
- # - id: layer_60_width_262k_l0_small
9094
- # path: transcoder_all/layer_60_width_262k_l0_small
9095
- # l0: 20
9096
- # - id: layer_60_width_262k_l0_small_affine
9097
- # path: transcoder_all/layer_60_width_262k_l0_small_affine
9098
- # l0: 20
9099
- # - id: layer_61_width_16k_l0_big
9100
- # path: transcoder_all/layer_61_width_16k_l0_big
9101
- # l0: 120
9102
- # - id: layer_61_width_16k_l0_big_affine
9103
- # path: transcoder_all/layer_61_width_16k_l0_big_affine
9104
- # l0: 120
9105
- # - id: layer_61_width_16k_l0_small
9106
- # path: transcoder_all/layer_61_width_16k_l0_small
9107
- # l0: 20
9108
- # - id: layer_61_width_16k_l0_small_affine
9109
- # path: transcoder_all/layer_61_width_16k_l0_small_affine
9110
- # l0: 20
9111
- # - id: layer_61_width_262k_l0_big
9112
- # path: transcoder_all/layer_61_width_262k_l0_big
9113
- # l0: 120
9114
- # - id: layer_61_width_262k_l0_big_affine
9115
- # path: transcoder_all/layer_61_width_262k_l0_big_affine
9116
- # l0: 120
9117
- # - id: layer_61_width_262k_l0_small
9118
- # path: transcoder_all/layer_61_width_262k_l0_small
9119
- # l0: 20
9120
- # - id: layer_61_width_262k_l0_small_affine
9121
- # path: transcoder_all/layer_61_width_262k_l0_small_affine
9122
- # l0: 20
9123
- # - id: layer_6_width_16k_l0_big
9124
- # path: transcoder_all/layer_6_width_16k_l0_big
9125
- # l0: 77
9126
- # - id: layer_6_width_16k_l0_big_affine
9127
- # path: transcoder_all/layer_6_width_16k_l0_big_affine
9128
- # l0: 77
9129
- # - id: layer_6_width_16k_l0_small
9130
- # path: transcoder_all/layer_6_width_16k_l0_small
9131
- # l0: 12
9132
- # - id: layer_6_width_16k_l0_small_affine
9133
- # path: transcoder_all/layer_6_width_16k_l0_small_affine
9134
- # l0: 12
9135
- # - id: layer_6_width_262k_l0_big
9136
- # path: transcoder_all/layer_6_width_262k_l0_big
9137
- # l0: 77
9138
- # - id: layer_6_width_262k_l0_big_affine
9139
- # path: transcoder_all/layer_6_width_262k_l0_big_affine
9140
- # l0: 77
9141
- # - id: layer_6_width_262k_l0_small
9142
- # path: transcoder_all/layer_6_width_262k_l0_small
9143
- # l0: 12
9144
- # - id: layer_6_width_262k_l0_small_affine
9145
- # path: transcoder_all/layer_6_width_262k_l0_small_affine
9146
- # l0: 12
9147
- # - id: layer_7_width_16k_l0_big
9148
- # path: transcoder_all/layer_7_width_16k_l0_big
9149
- # l0: 80
9150
- # - id: layer_7_width_16k_l0_big_affine
9151
- # path: transcoder_all/layer_7_width_16k_l0_big_affine
9152
- # l0: 80
9153
- # - id: layer_7_width_16k_l0_small
9154
- # path: transcoder_all/layer_7_width_16k_l0_small
9155
- # l0: 13
9156
- # - id: layer_7_width_16k_l0_small_affine
9157
- # path: transcoder_all/layer_7_width_16k_l0_small_affine
9158
- # l0: 13
9159
- # - id: layer_7_width_262k_l0_big
9160
- # path: transcoder_all/layer_7_width_262k_l0_big
9161
- # l0: 80
9162
- # - id: layer_7_width_262k_l0_big_affine
9163
- # path: transcoder_all/layer_7_width_262k_l0_big_affine
9164
- # l0: 80
9165
- # - id: layer_7_width_262k_l0_small
9166
- # path: transcoder_all/layer_7_width_262k_l0_small
9167
- # l0: 13
9168
- # - id: layer_7_width_262k_l0_small_affine
9169
- # path: transcoder_all/layer_7_width_262k_l0_small_affine
9170
- # l0: 13
9171
- # - id: layer_8_width_16k_l0_big
9172
- # path: transcoder_all/layer_8_width_16k_l0_big
9173
- # l0: 83
9174
- # - id: layer_8_width_16k_l0_big_affine
9175
- # path: transcoder_all/layer_8_width_16k_l0_big_affine
9176
- # l0: 83
9177
- # - id: layer_8_width_16k_l0_small
9178
- # path: transcoder_all/layer_8_width_16k_l0_small
9179
- # l0: 13
9180
- # - id: layer_8_width_16k_l0_small_affine
9181
- # path: transcoder_all/layer_8_width_16k_l0_small_affine
9182
- # l0: 13
9183
- # - id: layer_8_width_262k_l0_big
9184
- # path: transcoder_all/layer_8_width_262k_l0_big
9185
- # l0: 83
9186
- # - id: layer_8_width_262k_l0_big_affine
9187
- # path: transcoder_all/layer_8_width_262k_l0_big_affine
9188
- # l0: 83
9189
- # - id: layer_8_width_262k_l0_small
9190
- # path: transcoder_all/layer_8_width_262k_l0_small
9191
- # l0: 13
9192
- # - id: layer_8_width_262k_l0_small_affine
9193
- # path: transcoder_all/layer_8_width_262k_l0_small_affine
9194
- # l0: 13
9195
- # - id: layer_9_width_16k_l0_big
9196
- # path: transcoder_all/layer_9_width_16k_l0_big
9197
- # l0: 86
9198
- # - id: layer_9_width_16k_l0_big_affine
9199
- # path: transcoder_all/layer_9_width_16k_l0_big_affine
9200
- # l0: 86
9201
- # - id: layer_9_width_16k_l0_small
9202
- # path: transcoder_all/layer_9_width_16k_l0_small
9203
- # l0: 14
9204
- # - id: layer_9_width_16k_l0_small_affine
9205
- # path: transcoder_all/layer_9_width_16k_l0_small_affine
9206
- # l0: 14
9207
- # - id: layer_9_width_262k_l0_big
9208
- # path: transcoder_all/layer_9_width_262k_l0_big
9209
- # l0: 86
9210
- # - id: layer_9_width_262k_l0_big_affine
9211
- # path: transcoder_all/layer_9_width_262k_l0_big_affine
9212
- # l0: 86
9213
- # - id: layer_9_width_262k_l0_small
9214
- # path: transcoder_all/layer_9_width_262k_l0_small
9215
- # l0: 14
9216
- # - id: layer_9_width_262k_l0_small_affine
9217
- # path: transcoder_all/layer_9_width_262k_l0_small_affine
9218
- # l0: 14
9075
+ - id: layer_60_width_16k_l0_big
9076
+ path: transcoder_all/layer_60_width_16k_l0_big
9077
+ l0: 120
9078
+ - id: layer_60_width_16k_l0_big_affine
9079
+ path: transcoder_all/layer_60_width_16k_l0_big_affine
9080
+ l0: 120
9081
+ - id: layer_60_width_16k_l0_small
9082
+ path: transcoder_all/layer_60_width_16k_l0_small
9083
+ l0: 20
9084
+ - id: layer_60_width_16k_l0_small_affine
9085
+ path: transcoder_all/layer_60_width_16k_l0_small_affine
9086
+ l0: 20
9087
+ - id: layer_60_width_262k_l0_big
9088
+ path: transcoder_all/layer_60_width_262k_l0_big
9089
+ l0: 120
9090
+ - id: layer_60_width_262k_l0_big_affine
9091
+ path: transcoder_all/layer_60_width_262k_l0_big_affine
9092
+ l0: 120
9093
+ - id: layer_60_width_262k_l0_small
9094
+ path: transcoder_all/layer_60_width_262k_l0_small
9095
+ l0: 20
9096
+ - id: layer_60_width_262k_l0_small_affine
9097
+ path: transcoder_all/layer_60_width_262k_l0_small_affine
9098
+ l0: 20
9099
+ - id: layer_61_width_16k_l0_big
9100
+ path: transcoder_all/layer_61_width_16k_l0_big
9101
+ l0: 120
9102
+ - id: layer_61_width_16k_l0_big_affine
9103
+ path: transcoder_all/layer_61_width_16k_l0_big_affine
9104
+ l0: 120
9105
+ - id: layer_61_width_16k_l0_small
9106
+ path: transcoder_all/layer_61_width_16k_l0_small
9107
+ l0: 20
9108
+ - id: layer_61_width_16k_l0_small_affine
9109
+ path: transcoder_all/layer_61_width_16k_l0_small_affine
9110
+ l0: 20
9111
+ - id: layer_61_width_262k_l0_big
9112
+ path: transcoder_all/layer_61_width_262k_l0_big
9113
+ l0: 120
9114
+ - id: layer_61_width_262k_l0_big_affine
9115
+ path: transcoder_all/layer_61_width_262k_l0_big_affine
9116
+ l0: 120
9117
+ - id: layer_61_width_262k_l0_small
9118
+ path: transcoder_all/layer_61_width_262k_l0_small
9119
+ l0: 20
9120
+ - id: layer_61_width_262k_l0_small_affine
9121
+ path: transcoder_all/layer_61_width_262k_l0_small_affine
9122
+ l0: 20
9123
+ - id: layer_6_width_16k_l0_big
9124
+ path: transcoder_all/layer_6_width_16k_l0_big
9125
+ l0: 77
9126
+ - id: layer_6_width_16k_l0_big_affine
9127
+ path: transcoder_all/layer_6_width_16k_l0_big_affine
9128
+ l0: 77
9129
+ - id: layer_6_width_16k_l0_small
9130
+ path: transcoder_all/layer_6_width_16k_l0_small
9131
+ l0: 12
9132
+ - id: layer_6_width_16k_l0_small_affine
9133
+ path: transcoder_all/layer_6_width_16k_l0_small_affine
9134
+ l0: 12
9135
+ - id: layer_6_width_262k_l0_big
9136
+ path: transcoder_all/layer_6_width_262k_l0_big
9137
+ l0: 77
9138
+ - id: layer_6_width_262k_l0_big_affine
9139
+ path: transcoder_all/layer_6_width_262k_l0_big_affine
9140
+ l0: 77
9141
+ - id: layer_6_width_262k_l0_small
9142
+ path: transcoder_all/layer_6_width_262k_l0_small
9143
+ l0: 12
9144
+ - id: layer_6_width_262k_l0_small_affine
9145
+ path: transcoder_all/layer_6_width_262k_l0_small_affine
9146
+ l0: 12
9147
+ - id: layer_7_width_16k_l0_big
9148
+ path: transcoder_all/layer_7_width_16k_l0_big
9149
+ l0: 80
9150
+ - id: layer_7_width_16k_l0_big_affine
9151
+ path: transcoder_all/layer_7_width_16k_l0_big_affine
9152
+ l0: 80
9153
+ - id: layer_7_width_16k_l0_small
9154
+ path: transcoder_all/layer_7_width_16k_l0_small
9155
+ l0: 13
9156
+ - id: layer_7_width_16k_l0_small_affine
9157
+ path: transcoder_all/layer_7_width_16k_l0_small_affine
9158
+ l0: 13
9159
+ - id: layer_7_width_262k_l0_big
9160
+ path: transcoder_all/layer_7_width_262k_l0_big
9161
+ l0: 80
9162
+ - id: layer_7_width_262k_l0_big_affine
9163
+ path: transcoder_all/layer_7_width_262k_l0_big_affine
9164
+ l0: 80
9165
+ - id: layer_7_width_262k_l0_small
9166
+ path: transcoder_all/layer_7_width_262k_l0_small
9167
+ l0: 13
9168
+ - id: layer_7_width_262k_l0_small_affine
9169
+ path: transcoder_all/layer_7_width_262k_l0_small_affine
9170
+ l0: 13
9171
+ - id: layer_8_width_16k_l0_big
9172
+ path: transcoder_all/layer_8_width_16k_l0_big
9173
+ l0: 83
9174
+ - id: layer_8_width_16k_l0_big_affine
9175
+ path: transcoder_all/layer_8_width_16k_l0_big_affine
9176
+ l0: 83
9177
+ - id: layer_8_width_16k_l0_small
9178
+ path: transcoder_all/layer_8_width_16k_l0_small
9179
+ l0: 13
9180
+ - id: layer_8_width_16k_l0_small_affine
9181
+ path: transcoder_all/layer_8_width_16k_l0_small_affine
9182
+ l0: 13
9183
+ - id: layer_8_width_262k_l0_big
9184
+ path: transcoder_all/layer_8_width_262k_l0_big
9185
+ l0: 83
9186
+ - id: layer_8_width_262k_l0_big_affine
9187
+ path: transcoder_all/layer_8_width_262k_l0_big_affine
9188
+ l0: 83
9189
+ - id: layer_8_width_262k_l0_small
9190
+ path: transcoder_all/layer_8_width_262k_l0_small
9191
+ l0: 13
9192
+ - id: layer_8_width_262k_l0_small_affine
9193
+ path: transcoder_all/layer_8_width_262k_l0_small_affine
9194
+ l0: 13
9195
+ - id: layer_9_width_16k_l0_big
9196
+ path: transcoder_all/layer_9_width_16k_l0_big
9197
+ l0: 86
9198
+ - id: layer_9_width_16k_l0_big_affine
9199
+ path: transcoder_all/layer_9_width_16k_l0_big_affine
9200
+ l0: 86
9201
+ - id: layer_9_width_16k_l0_small
9202
+ path: transcoder_all/layer_9_width_16k_l0_small
9203
+ l0: 14
9204
+ - id: layer_9_width_16k_l0_small_affine
9205
+ path: transcoder_all/layer_9_width_16k_l0_small_affine
9206
+ l0: 14
9207
+ - id: layer_9_width_262k_l0_big
9208
+ path: transcoder_all/layer_9_width_262k_l0_big
9209
+ l0: 86
9210
+ - id: layer_9_width_262k_l0_big_affine
9211
+ path: transcoder_all/layer_9_width_262k_l0_big_affine
9212
+ l0: 86
9213
+ - id: layer_9_width_262k_l0_small
9214
+ path: transcoder_all/layer_9_width_262k_l0_small
9215
+ l0: 14
9216
+ - id: layer_9_width_262k_l0_small_affine
9217
+ path: transcoder_all/layer_9_width_262k_l0_small_affine
9218
+ l0: 14
9219
9219
  gemma-scope-2-27b-it-transcoders:
9220
9220
  conversion_func: gemma_3
9221
9221
  model: google/gemma-3-27b-it
sae_lens/saes/__init__.py CHANGED
@@ -14,6 +14,12 @@ from .jumprelu_sae import (
14
14
  JumpReLUTrainingSAE,
15
15
  JumpReLUTrainingSAEConfig,
16
16
  )
17
+ from .matching_pursuit_sae import (
18
+ MatchingPursuitSAE,
19
+ MatchingPursuitSAEConfig,
20
+ MatchingPursuitTrainingSAE,
21
+ MatchingPursuitTrainingSAEConfig,
22
+ )
17
23
  from .matryoshka_batchtopk_sae import (
18
24
  MatryoshkaBatchTopKTrainingSAE,
19
25
  MatryoshkaBatchTopKTrainingSAEConfig,
@@ -78,4 +84,8 @@ __all__ = [
78
84
  "MatryoshkaBatchTopKTrainingSAEConfig",
79
85
  "TemporalSAE",
80
86
  "TemporalSAEConfig",
87
+ "MatchingPursuitSAE",
88
+ "MatchingPursuitTrainingSAE",
89
+ "MatchingPursuitSAEConfig",
90
+ "MatchingPursuitTrainingSAEConfig",
81
91
  ]
@@ -0,0 +1,334 @@
1
+ """Matching Pursuit SAE"""
2
+
3
+ import warnings
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+ import torch
8
+ from typing_extensions import override
9
+
10
+ from sae_lens.saes.sae import (
11
+ SAE,
12
+ SAEConfig,
13
+ TrainCoefficientConfig,
14
+ TrainingSAE,
15
+ TrainingSAEConfig,
16
+ TrainStepInput,
17
+ TrainStepOutput,
18
+ )
19
+
20
+ # --- inference ---
21
+
22
+
23
@dataclass
class MatchingPursuitSAEConfig(SAEConfig):
    """
    Inference-time configuration for a MatchingPursuitSAE.

    Matching-pursuit-specific fields:
        residual_threshold (float): A sample stops selecting latents once the
            L2 norm of its residual drops below this value. Any value <= 0
            turns this stopping rule off. Defaults to 1e-2.
        max_iterations (int | None): Upper bound on greedy selection steps;
            None means "use d_in". Defaults to None.
        stop_on_duplicate_support (bool): Stop selecting latents for a sample
            once an iteration leaves its support set (non-zero latents)
            unchanged. Defaults to True.

    Fields inherited from SAEConfig:
        d_in (int): Input dimension (dimensionality of the activations being encoded).
        d_sae (int): SAE latent dimension (number of features in the SAE).
        dtype (str): Data type for the SAE parameters. Defaults to "float32".
        device (str): Device to place the SAE on. Defaults to "cpu".
        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
            before encoding. Defaults to True.
        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
            Normalization strategy for input activations. Defaults to "none".
        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
            (useful for attention head outputs). Defaults to "none".
        metadata (SAEMetadata): Metadata about the SAE (model name, hook name, etc.).
    """

    # residual-norm stopping threshold (<= 0 disables it)
    residual_threshold: float = 1e-2
    # cap on greedy steps; None -> d_in
    max_iterations: int | None = None
    # stop once the active-latent set stops changing
    stop_on_duplicate_support: bool = True

    @override
    @classmethod
    def architecture(cls) -> str:
        # registry key shared with the training config
        return "matching_pursuit"
61
+
62
+
63
class MatchingPursuitSAE(SAE[MatchingPursuitSAEConfig]):
    """
    Inference-only tied-weight sparse autoencoder. Instead of a fixed activation
    function, its encoder runs greedy matching pursuit over the rows of W_dec.
    """

    # Tied weights: there is no separate encoder matrix, W_enc is simply W_dec^T.
    @property
    def W_enc(self) -> torch.Tensor:  # pyright: ignore[reportIncompatibleVariableOverride]
        return self.W_dec.T

    # Workaround for the base class expecting a W_enc attribute: any assignment
    # to W_enc is silently dropped so the tied property above stays in effect.
    # TODO: harmonize with the base class in next major release
    @override
    def __setattr__(self, name: str, value: Any):
        if name != "W_enc":
            super().__setattr__(name, value)

    @override
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map input activations x to sparse feature activations via matching pursuit.
        """
        centered = self.process_sae_in(x)
        cfg = self.cfg
        return _encode_matching_pursuit(
            centered,
            self.W_dec,
            cfg.residual_threshold,
            max_iterations=cfg.max_iterations,
            stop_on_duplicate_support=cfg.stop_on_duplicate_support,
        )

    @override
    @torch.no_grad()
    def fold_W_dec_norm(self) -> None:
        # Rescaling decoder rows would alter which atoms matching pursuit picks,
        # so folding is refused rather than silently changing activations.
        raise NotImplementedError(
            "Folding W_dec_norm is not safe for MatchingPursuit SAEs, as this may change the resulting activations"
        )

    @override
    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """
        Reconstruct input-space activations from feature_acts.

        b_dec is added back only when it was applied at the input. The result
        then passes through the reconstruction hook and output normalization,
        and finally through reshape_fn_out, which undoes hook_z flattening when
        that reshaping is enabled.
        """
        recon = feature_acts @ self.W_dec
        # Tied SAE: only re-add b_dec if it was applied at the input.
        if self.cfg.apply_b_dec_to_input:
            recon = recon + self.b_dec
        recon = self.hook_sae_recons(recon)
        recon = self.run_time_activation_norm_fn_out(recon)
        return self.reshape_fn_out(recon, self.d_head)
115
+
116
+
117
+ # --- training ---
118
+
119
+
120
@dataclass
class MatchingPursuitTrainingSAEConfig(TrainingSAEConfig):
    """
    Configuration class for training a MatchingPursuitTrainingSAE.

    Args:
        residual_threshold (float): A sample stops selecting latents once its
            residual L2 norm falls below this value; values <= 0 disable this
            stopping rule. Default 1e-2.
        max_iterations (int | None): Maximum iterations (default: d_in if set to None).
            Defaults to None.
        stop_on_duplicate_support (bool): Whether to stop selecting latents if the
            support set has not changed from the previous iteration. Defaults to True.
        decoder_init_norm (float | None): Norm used to initialize decoder weights.
            Matching pursuit needs this to be 1.0, because the greedy residual
            update only removes the exact projection onto an atom when that atom
            has unit norm. It therefore defaults to 1.0 here. Any other value,
            including None, is replaced with 1.0 and a warning is emitted.
        d_in (int): Input dimension (dimensionality of the activations being encoded).
            Inherited from SAEConfig.
        d_sae (int): SAE latent dimension (number of features in the SAE).
            Inherited from SAEConfig.
        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
            Defaults to "float32".
        device (str): Device to place the SAE on. Inherited from SAEConfig.
            Defaults to "cpu".
        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
            before encoding. Inherited from SAEConfig. Defaults to True.
        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
            Normalization strategy for input activations. Inherited from SAEConfig.
            Defaults to "none".
        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
            (useful for attention head outputs). Inherited from SAEConfig.
            Defaults to "none".
        metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.).
            Inherited from SAEConfig.
    """

    # Redeclared only to change the inherited default (documented as 0.1) to the
    # one value matching pursuit accepts. Without this, every default-constructed
    # config would hit the warning in __post_init__. Redeclaring a dataclass
    # field keeps its original position in the field order.
    decoder_init_norm: float | None = 1.0
    residual_threshold: float = 1e-2
    max_iterations: int | None = None
    stop_on_duplicate_support: bool = True

    @override
    @classmethod
    def architecture(cls) -> str:
        return "matching_pursuit"

    @override
    def __post_init__(self):
        super().__post_init__()
        # An explicitly passed non-unit value (or None) is coerced to 1.0.
        if self.decoder_init_norm != 1.0:
            self.decoder_init_norm = 1.0
            warnings.warn(
                "decoder_init_norm must be set to 1.0 for MatchingPursuitTrainingSAE, setting to 1.0"
            )
170
+
171
+
172
class MatchingPursuitTrainingSAE(TrainingSAE[MatchingPursuitTrainingSAEConfig]):
    """
    Trainable tied-weight SAE whose encoder runs greedy matching pursuit over
    the rows of W_dec. It uses no auxiliary losses or training coefficients.
    training_forward_pass logs extra matching-pursuit diagnostics.
    """

    # Matching pursuit is a tied SAE, so we use W_enc as the decoder transposed
    @property
    def W_enc(self) -> torch.Tensor:  # pyright: ignore[reportIncompatibleVariableOverride]
        return self.W_dec.T

    # hacky way to get around the base class having W_enc: assignments to W_enc
    # are silently ignored so the tied property above stays in effect.
    # TODO: harmonize with the base class in next major release
    @override
    def __setattr__(self, name: str, value: Any):
        if name == "W_enc":
            return
        super().__setattr__(name, value)

    @override
    def encode_with_hidden_pre(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode x with matching pursuit.

        hidden_pre doesn't make sense for matching pursuit, since there is not a
        single pre-activation. An all-zeros tensor shaped like the feature
        activations is returned in its place.

        Returns:
            (feature_acts, zeros_like(feature_acts))
        """
        sae_in = self.process_sae_in(x)
        acts = _encode_matching_pursuit(
            sae_in,
            self.W_dec,
            self.cfg.residual_threshold,
            max_iterations=self.cfg.max_iterations,
            stop_on_duplicate_support=self.cfg.stop_on_duplicate_support,
        )
        return acts, torch.zeros_like(acts)

    @override
    @torch.no_grad()
    def fold_W_dec_norm(self) -> None:
        # Rescaling decoder rows may change the resulting activations, so refuse.
        raise NotImplementedError(
            "Folding W_dec_norm is not safe for MatchingPursuit SAEs, as this may change the resulting activations"
        )

    @override
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
        # No loss coefficients are used by this architecture.
        return {}

    @override
    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        # No auxiliary losses.
        return {}

    @override
    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
        """
        Run the base training forward pass and attach matching-pursuit metrics.

        Added metrics:
            max_l0 / min_l0: largest / smallest per-sample count of active latents.
            residual_norm: mean L2 norm of (sae_in - sae_out).
            residual_threshold_converged_portion: fraction of samples whose
                residual norm is below cfg.residual_threshold.
        """
        output = super().training_forward_pass(step_input)
        # These values are for logging only. Without no_grad, the .norm() node
        # would save a residual-sized tensor for backward and keep it alive for
        # as long as the metric tensors exist.
        with torch.no_grad():
            l0 = output.feature_acts.bool().float().sum(-1).to_dense()
            residual_norm = (step_input.sae_in - output.sae_out).norm(dim=-1)
            output.metrics["max_l0"] = l0.max()
            output.metrics["min_l0"] = l0.min()
            output.metrics["residual_norm"] = residual_norm.mean()
            output.metrics["residual_threshold_converged_portion"] = (
                (residual_norm < self.cfg.residual_threshold).float().mean()
            )
        return output

    @override
    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """
        Decode the feature activations back to the input space.

        b_dec is added only when it was applied at the input. If hook_z
        reshaping is turned on, the flattening is reversed.
        """
        sae_out_pre = feature_acts @ self.W_dec
        # since this is a tied SAE, we need to make sure b_dec is only applied if applied at input
        if self.cfg.apply_b_dec_to_input:
            sae_out_pre = sae_out_pre + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)
252
+
253
+
254
+ # --- shared ---
255
+
256
+
257
+ def _encode_matching_pursuit(
258
+ sae_in_centered: torch.Tensor,
259
+ W_dec: torch.Tensor,
260
+ residual_threshold: float,
261
+ max_iterations: int | None,
262
+ stop_on_duplicate_support: bool,
263
+ ) -> torch.Tensor:
264
+ """
265
+ Matching pursuit encoding.
266
+
267
+ Args:
268
+ sae_in_centered: Input activations, centered by b_dec. Shape [..., d_in].
269
+ W_dec: Decoder weight matrix. Shape [d_sae, d_in].
270
+ residual_threshold: Stop when residual norm falls below this.
271
+ max_iterations: Maximum iterations (default: d_in). Prevents infinite loops.
272
+ stop_on_duplicate_support: Whether to stop selecting latents if the support set has not changed from the previous iteration.
273
+ """
274
+ residual = sae_in_centered.clone()
275
+
276
+ stop_on_residual_threshold = residual_threshold > 0
277
+
278
+ # Handle multi-dimensional inputs by flattening all but the last dimension
279
+ original_shape = residual.shape
280
+ if residual.ndim > 2:
281
+ residual = residual.reshape(-1, residual.shape[-1])
282
+
283
+ batch_size = residual.shape[0]
284
+ d_sae, d_in = W_dec.shape
285
+
286
+ if max_iterations is None:
287
+ max_iterations = d_in # Sensible upper bound
288
+
289
+ acts = torch.zeros(batch_size, d_sae, device=W_dec.device, dtype=residual.dtype)
290
+ prev_support = torch.zeros(batch_size, d_sae, dtype=torch.bool, device=W_dec.device)
291
+ done = torch.zeros(batch_size, dtype=torch.bool, device=W_dec.device)
292
+
293
+ for _ in range(max_iterations):
294
+ # Find indices without gradients - the full [batch, d_sae] matmul result
295
+ # doesn't need to be saved for backward since max indices don't need gradients
296
+ with torch.no_grad():
297
+ indices = (residual @ W_dec.T).relu().max(dim=1, keepdim=True).indices
298
+ indices_flat = indices.squeeze(1) # [batch_size]
299
+
300
+ # Compute values with gradients using only the selected decoder rows.
301
+ # This stores [batch, d_in] for backward instead of [batch, d_sae].
302
+ selected_dec = W_dec[indices_flat] # [batch_size, d_in]
303
+ values = (residual * selected_dec).sum(dim=-1, keepdim=True).relu()
304
+
305
+ # Mask values for samples that are already done
306
+ active_mask = (~done).unsqueeze(1)
307
+ masked_values = (values * active_mask.to(values.dtype)).to(acts.dtype)
308
+
309
+ acts.scatter_add_(1, indices, masked_values)
310
+
311
+ # Update residual
312
+ residual = residual - masked_values * selected_dec
313
+
314
+ if stop_on_duplicate_support or stop_on_residual_threshold:
315
+ with torch.no_grad():
316
+ support = acts != 0
317
+
318
+ # A sample is considered converged if:
319
+ # (1) the support set hasn't changed from the previous iteration (stability), or
320
+ # (2) the residual norm is below a given threshold (good enough reconstruction)
321
+ if stop_on_duplicate_support:
322
+ done = done | (support == prev_support).all(dim=1)
323
+ prev_support = support
324
+ if stop_on_residual_threshold:
325
+ done = done | (residual.norm(dim=-1) < residual_threshold)
326
+
327
+ if done.all():
328
+ break
329
+
330
+ # Reshape acts back to original shape (replacing last dimension with d_sae)
331
+ if len(original_shape) > 2:
332
+ acts = acts.reshape(*original_shape[:-1], acts.shape[-1])
333
+
334
+ return acts
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.25.1
3
+ Version: 6.26.1
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -1,9 +1,9 @@
1
- sae_lens/__init__.py,sha256=vWuA8EbynIJadj666RoFNCTIvoH9-HFpUxuHwoYt8Ks,4268
1
+ sae_lens/__init__.py,sha256=zRp1nmb41W1Pt1rvlKvRWw73UxjGyz1iHAzH9_X6_WQ,4725
2
2
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
4
4
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
5
5
  sae_lens/cache_activations_runner.py,sha256=Lvlz-k5-3XxVRtUdC4b1CiKyx5s0ckLa8GDGv9_kcxs,12566
6
- sae_lens/config.py,sha256=JmcrXT4orJV2OulbEZAciz8RQmYv7DrtUtRbOLsNQ2Y,30330
6
+ sae_lens/config.py,sha256=C982bUELhGHcfTwzeMTtXIf2hPtc946thYpUyctLiBo,30516
7
7
  sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
8
8
  sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
9
9
  sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
@@ -12,12 +12,13 @@ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
12
12
  sae_lens/loading/pretrained_sae_loaders.py,sha256=hq-dhxsEdUmlAnZEiZBqX7lNyQQwZ6KXmXZWpzAc5FY,63638
13
13
  sae_lens/loading/pretrained_saes_directory.py,sha256=hejNfLUepYCSGPalRfQwxxCEUqMMUPsn1tufwvwct5k,3820
14
14
  sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
15
- sae_lens/pretrained_saes.yaml,sha256=Hy9mk4Liy50B0CIBD4ER1ETcho2drFFiIy-bPVCN_lc,1510210
15
+ sae_lens/pretrained_saes.yaml,sha256=Hn8jXwZ7V6QQxzgu41LFEP-LAzuDxwYL5vhoar-pPX8,1509922
16
16
  sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
- sae_lens/saes/__init__.py,sha256=fYVujOzNnUgpzLL0MBLBt_DNX2CPcTaheukzCd2bEPo,1906
17
+ sae_lens/saes/__init__.py,sha256=SBqPaP6Gl5uPFwHlumAZATC4Wd26xKIYLAAAo4MSa5Q,2200
18
18
  sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
19
19
  sae_lens/saes/gated_sae.py,sha256=mHnmw-RD7hqIbP9_EBj3p2SK0OqQIkZivdOKRygeRgw,8825
20
20
  sae_lens/saes/jumprelu_sae.py,sha256=udjGHp3WTABQSL2Qq57j-bINWX61GCmo68EmdjMOXoo,13310
21
+ sae_lens/saes/matching_pursuit_sae.py,sha256=08_G9p1YMLnE5qZVCPp6gll-iG6nHRbMMASf4_bkFt8,13207
21
22
  sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
22
23
  sae_lens/saes/sae.py,sha256=fzXv8lwHskSxsf8hm_wlKPkpq50iafmBjBNQzwZ6a00,40050
23
24
  sae_lens/saes/standard_sae.py,sha256=nEVETwAmRD2tyX7ESIic1fij48gAq1Dh7s_GQ2fqCZ4,5747
@@ -35,7 +36,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
35
36
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
36
37
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
37
38
  sae_lens/util.py,sha256=spkcmQUsjVYFn5H2032nQYr1CKGVnv3tAdfIpY59-Mg,3919
38
- sae_lens-6.25.1.dist-info/METADATA,sha256=gClFVWzEWNNjrXsGqvCY6ry6ehXIFwp8PB0jIOhmQvc,5361
39
- sae_lens-6.25.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
40
- sae_lens-6.25.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
41
- sae_lens-6.25.1.dist-info/RECORD,,
39
+ sae_lens-6.26.1.dist-info/METADATA,sha256=yoE6CFgQ9L5SLzI3Zgr8H8CfUBgSimihGyEIvKd8TW8,5361
40
+ sae_lens-6.26.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
41
+ sae_lens-6.26.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
42
+ sae_lens-6.26.1.dist-info/RECORD,,