sae-lens 6.25.1__py3-none-any.whl → 6.26.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +13 -1
- sae_lens/config.py +5 -0
- sae_lens/pretrained_saes.yaml +144 -144
- sae_lens/saes/__init__.py +10 -0
- sae_lens/saes/matching_pursuit_sae.py +334 -0
- {sae_lens-6.25.1.dist-info → sae_lens-6.26.1.dist-info}/METADATA +1 -1
- {sae_lens-6.25.1.dist-info → sae_lens-6.26.1.dist-info}/RECORD +9 -8
- {sae_lens-6.25.1.dist-info → sae_lens-6.26.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.25.1.dist-info → sae_lens-6.26.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
# ruff: noqa: E402
|
|
2
|
-
__version__ = "6.
|
|
2
|
+
__version__ = "6.26.1"
|
|
3
3
|
|
|
4
4
|
import logging
|
|
5
5
|
|
|
@@ -21,6 +21,10 @@ from sae_lens.saes import (
|
|
|
21
21
|
JumpReLUTrainingSAEConfig,
|
|
22
22
|
JumpReLUTranscoder,
|
|
23
23
|
JumpReLUTranscoderConfig,
|
|
24
|
+
MatchingPursuitSAE,
|
|
25
|
+
MatchingPursuitSAEConfig,
|
|
26
|
+
MatchingPursuitTrainingSAE,
|
|
27
|
+
MatchingPursuitTrainingSAEConfig,
|
|
24
28
|
MatryoshkaBatchTopKTrainingSAE,
|
|
25
29
|
MatryoshkaBatchTopKTrainingSAEConfig,
|
|
26
30
|
SAEConfig,
|
|
@@ -113,6 +117,10 @@ __all__ = [
|
|
|
113
117
|
"MatryoshkaBatchTopKTrainingSAEConfig",
|
|
114
118
|
"TemporalSAE",
|
|
115
119
|
"TemporalSAEConfig",
|
|
120
|
+
"MatchingPursuitSAE",
|
|
121
|
+
"MatchingPursuitTrainingSAE",
|
|
122
|
+
"MatchingPursuitSAEConfig",
|
|
123
|
+
"MatchingPursuitTrainingSAEConfig",
|
|
116
124
|
]
|
|
117
125
|
|
|
118
126
|
|
|
@@ -139,3 +147,7 @@ register_sae_class(
|
|
|
139
147
|
"jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
|
|
140
148
|
)
|
|
141
149
|
register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
|
|
150
|
+
register_sae_class("matching_pursuit", MatchingPursuitSAE, MatchingPursuitSAEConfig)
|
|
151
|
+
register_sae_training_class(
|
|
152
|
+
"matching_pursuit", MatchingPursuitTrainingSAE, MatchingPursuitTrainingSAEConfig
|
|
153
|
+
)
|
sae_lens/config.py
CHANGED
|
@@ -17,6 +17,11 @@ from datasets import (
|
|
|
17
17
|
)
|
|
18
18
|
|
|
19
19
|
from sae_lens import __version__, logger
|
|
20
|
+
|
|
21
|
+
# keeping this unused import since some SAELens deps import DTYPE_MAP from config
|
|
22
|
+
from sae_lens.constants import (
|
|
23
|
+
DTYPE_MAP, # noqa: F401 # pyright: ignore[reportUnusedImport]
|
|
24
|
+
)
|
|
20
25
|
from sae_lens.registry import get_sae_training_class
|
|
21
26
|
from sae_lens.saes.sae import TrainingSAEConfig
|
|
22
27
|
from sae_lens.util import str_to_dtype
|
sae_lens/pretrained_saes.yaml
CHANGED
|
@@ -9072,150 +9072,150 @@ gemma-scope-2-27b-it-transcoders-all:
|
|
|
9072
9072
|
- id: layer_5_width_262k_l0_small_affine
|
|
9073
9073
|
path: transcoder_all/layer_5_width_262k_l0_small_affine
|
|
9074
9074
|
l0: 12
|
|
9075
|
-
|
|
9076
|
-
|
|
9077
|
-
|
|
9078
|
-
|
|
9079
|
-
|
|
9080
|
-
|
|
9081
|
-
|
|
9082
|
-
|
|
9083
|
-
|
|
9084
|
-
|
|
9085
|
-
|
|
9086
|
-
|
|
9087
|
-
|
|
9088
|
-
|
|
9089
|
-
|
|
9090
|
-
|
|
9091
|
-
|
|
9092
|
-
|
|
9093
|
-
|
|
9094
|
-
|
|
9095
|
-
|
|
9096
|
-
|
|
9097
|
-
|
|
9098
|
-
|
|
9099
|
-
|
|
9100
|
-
|
|
9101
|
-
|
|
9102
|
-
|
|
9103
|
-
|
|
9104
|
-
|
|
9105
|
-
|
|
9106
|
-
|
|
9107
|
-
|
|
9108
|
-
|
|
9109
|
-
|
|
9110
|
-
|
|
9111
|
-
|
|
9112
|
-
|
|
9113
|
-
|
|
9114
|
-
|
|
9115
|
-
|
|
9116
|
-
|
|
9117
|
-
|
|
9118
|
-
|
|
9119
|
-
|
|
9120
|
-
|
|
9121
|
-
|
|
9122
|
-
|
|
9123
|
-
|
|
9124
|
-
|
|
9125
|
-
|
|
9126
|
-
|
|
9127
|
-
|
|
9128
|
-
|
|
9129
|
-
|
|
9130
|
-
|
|
9131
|
-
|
|
9132
|
-
|
|
9133
|
-
|
|
9134
|
-
|
|
9135
|
-
|
|
9136
|
-
|
|
9137
|
-
|
|
9138
|
-
|
|
9139
|
-
|
|
9140
|
-
|
|
9141
|
-
|
|
9142
|
-
|
|
9143
|
-
|
|
9144
|
-
|
|
9145
|
-
|
|
9146
|
-
|
|
9147
|
-
|
|
9148
|
-
|
|
9149
|
-
|
|
9150
|
-
|
|
9151
|
-
|
|
9152
|
-
|
|
9153
|
-
|
|
9154
|
-
|
|
9155
|
-
|
|
9156
|
-
|
|
9157
|
-
|
|
9158
|
-
|
|
9159
|
-
|
|
9160
|
-
|
|
9161
|
-
|
|
9162
|
-
|
|
9163
|
-
|
|
9164
|
-
|
|
9165
|
-
|
|
9166
|
-
|
|
9167
|
-
|
|
9168
|
-
|
|
9169
|
-
|
|
9170
|
-
|
|
9171
|
-
|
|
9172
|
-
|
|
9173
|
-
|
|
9174
|
-
|
|
9175
|
-
|
|
9176
|
-
|
|
9177
|
-
|
|
9178
|
-
|
|
9179
|
-
|
|
9180
|
-
|
|
9181
|
-
|
|
9182
|
-
|
|
9183
|
-
|
|
9184
|
-
|
|
9185
|
-
|
|
9186
|
-
|
|
9187
|
-
|
|
9188
|
-
|
|
9189
|
-
|
|
9190
|
-
|
|
9191
|
-
|
|
9192
|
-
|
|
9193
|
-
|
|
9194
|
-
|
|
9195
|
-
|
|
9196
|
-
|
|
9197
|
-
|
|
9198
|
-
|
|
9199
|
-
|
|
9200
|
-
|
|
9201
|
-
|
|
9202
|
-
|
|
9203
|
-
|
|
9204
|
-
|
|
9205
|
-
|
|
9206
|
-
|
|
9207
|
-
|
|
9208
|
-
|
|
9209
|
-
|
|
9210
|
-
|
|
9211
|
-
|
|
9212
|
-
|
|
9213
|
-
|
|
9214
|
-
|
|
9215
|
-
|
|
9216
|
-
|
|
9217
|
-
|
|
9218
|
-
|
|
9075
|
+
- id: layer_60_width_16k_l0_big
|
|
9076
|
+
path: transcoder_all/layer_60_width_16k_l0_big
|
|
9077
|
+
l0: 120
|
|
9078
|
+
- id: layer_60_width_16k_l0_big_affine
|
|
9079
|
+
path: transcoder_all/layer_60_width_16k_l0_big_affine
|
|
9080
|
+
l0: 120
|
|
9081
|
+
- id: layer_60_width_16k_l0_small
|
|
9082
|
+
path: transcoder_all/layer_60_width_16k_l0_small
|
|
9083
|
+
l0: 20
|
|
9084
|
+
- id: layer_60_width_16k_l0_small_affine
|
|
9085
|
+
path: transcoder_all/layer_60_width_16k_l0_small_affine
|
|
9086
|
+
l0: 20
|
|
9087
|
+
- id: layer_60_width_262k_l0_big
|
|
9088
|
+
path: transcoder_all/layer_60_width_262k_l0_big
|
|
9089
|
+
l0: 120
|
|
9090
|
+
- id: layer_60_width_262k_l0_big_affine
|
|
9091
|
+
path: transcoder_all/layer_60_width_262k_l0_big_affine
|
|
9092
|
+
l0: 120
|
|
9093
|
+
- id: layer_60_width_262k_l0_small
|
|
9094
|
+
path: transcoder_all/layer_60_width_262k_l0_small
|
|
9095
|
+
l0: 20
|
|
9096
|
+
- id: layer_60_width_262k_l0_small_affine
|
|
9097
|
+
path: transcoder_all/layer_60_width_262k_l0_small_affine
|
|
9098
|
+
l0: 20
|
|
9099
|
+
- id: layer_61_width_16k_l0_big
|
|
9100
|
+
path: transcoder_all/layer_61_width_16k_l0_big
|
|
9101
|
+
l0: 120
|
|
9102
|
+
- id: layer_61_width_16k_l0_big_affine
|
|
9103
|
+
path: transcoder_all/layer_61_width_16k_l0_big_affine
|
|
9104
|
+
l0: 120
|
|
9105
|
+
- id: layer_61_width_16k_l0_small
|
|
9106
|
+
path: transcoder_all/layer_61_width_16k_l0_small
|
|
9107
|
+
l0: 20
|
|
9108
|
+
- id: layer_61_width_16k_l0_small_affine
|
|
9109
|
+
path: transcoder_all/layer_61_width_16k_l0_small_affine
|
|
9110
|
+
l0: 20
|
|
9111
|
+
- id: layer_61_width_262k_l0_big
|
|
9112
|
+
path: transcoder_all/layer_61_width_262k_l0_big
|
|
9113
|
+
l0: 120
|
|
9114
|
+
- id: layer_61_width_262k_l0_big_affine
|
|
9115
|
+
path: transcoder_all/layer_61_width_262k_l0_big_affine
|
|
9116
|
+
l0: 120
|
|
9117
|
+
- id: layer_61_width_262k_l0_small
|
|
9118
|
+
path: transcoder_all/layer_61_width_262k_l0_small
|
|
9119
|
+
l0: 20
|
|
9120
|
+
- id: layer_61_width_262k_l0_small_affine
|
|
9121
|
+
path: transcoder_all/layer_61_width_262k_l0_small_affine
|
|
9122
|
+
l0: 20
|
|
9123
|
+
- id: layer_6_width_16k_l0_big
|
|
9124
|
+
path: transcoder_all/layer_6_width_16k_l0_big
|
|
9125
|
+
l0: 77
|
|
9126
|
+
- id: layer_6_width_16k_l0_big_affine
|
|
9127
|
+
path: transcoder_all/layer_6_width_16k_l0_big_affine
|
|
9128
|
+
l0: 77
|
|
9129
|
+
- id: layer_6_width_16k_l0_small
|
|
9130
|
+
path: transcoder_all/layer_6_width_16k_l0_small
|
|
9131
|
+
l0: 12
|
|
9132
|
+
- id: layer_6_width_16k_l0_small_affine
|
|
9133
|
+
path: transcoder_all/layer_6_width_16k_l0_small_affine
|
|
9134
|
+
l0: 12
|
|
9135
|
+
- id: layer_6_width_262k_l0_big
|
|
9136
|
+
path: transcoder_all/layer_6_width_262k_l0_big
|
|
9137
|
+
l0: 77
|
|
9138
|
+
- id: layer_6_width_262k_l0_big_affine
|
|
9139
|
+
path: transcoder_all/layer_6_width_262k_l0_big_affine
|
|
9140
|
+
l0: 77
|
|
9141
|
+
- id: layer_6_width_262k_l0_small
|
|
9142
|
+
path: transcoder_all/layer_6_width_262k_l0_small
|
|
9143
|
+
l0: 12
|
|
9144
|
+
- id: layer_6_width_262k_l0_small_affine
|
|
9145
|
+
path: transcoder_all/layer_6_width_262k_l0_small_affine
|
|
9146
|
+
l0: 12
|
|
9147
|
+
- id: layer_7_width_16k_l0_big
|
|
9148
|
+
path: transcoder_all/layer_7_width_16k_l0_big
|
|
9149
|
+
l0: 80
|
|
9150
|
+
- id: layer_7_width_16k_l0_big_affine
|
|
9151
|
+
path: transcoder_all/layer_7_width_16k_l0_big_affine
|
|
9152
|
+
l0: 80
|
|
9153
|
+
- id: layer_7_width_16k_l0_small
|
|
9154
|
+
path: transcoder_all/layer_7_width_16k_l0_small
|
|
9155
|
+
l0: 13
|
|
9156
|
+
- id: layer_7_width_16k_l0_small_affine
|
|
9157
|
+
path: transcoder_all/layer_7_width_16k_l0_small_affine
|
|
9158
|
+
l0: 13
|
|
9159
|
+
- id: layer_7_width_262k_l0_big
|
|
9160
|
+
path: transcoder_all/layer_7_width_262k_l0_big
|
|
9161
|
+
l0: 80
|
|
9162
|
+
- id: layer_7_width_262k_l0_big_affine
|
|
9163
|
+
path: transcoder_all/layer_7_width_262k_l0_big_affine
|
|
9164
|
+
l0: 80
|
|
9165
|
+
- id: layer_7_width_262k_l0_small
|
|
9166
|
+
path: transcoder_all/layer_7_width_262k_l0_small
|
|
9167
|
+
l0: 13
|
|
9168
|
+
- id: layer_7_width_262k_l0_small_affine
|
|
9169
|
+
path: transcoder_all/layer_7_width_262k_l0_small_affine
|
|
9170
|
+
l0: 13
|
|
9171
|
+
- id: layer_8_width_16k_l0_big
|
|
9172
|
+
path: transcoder_all/layer_8_width_16k_l0_big
|
|
9173
|
+
l0: 83
|
|
9174
|
+
- id: layer_8_width_16k_l0_big_affine
|
|
9175
|
+
path: transcoder_all/layer_8_width_16k_l0_big_affine
|
|
9176
|
+
l0: 83
|
|
9177
|
+
- id: layer_8_width_16k_l0_small
|
|
9178
|
+
path: transcoder_all/layer_8_width_16k_l0_small
|
|
9179
|
+
l0: 13
|
|
9180
|
+
- id: layer_8_width_16k_l0_small_affine
|
|
9181
|
+
path: transcoder_all/layer_8_width_16k_l0_small_affine
|
|
9182
|
+
l0: 13
|
|
9183
|
+
- id: layer_8_width_262k_l0_big
|
|
9184
|
+
path: transcoder_all/layer_8_width_262k_l0_big
|
|
9185
|
+
l0: 83
|
|
9186
|
+
- id: layer_8_width_262k_l0_big_affine
|
|
9187
|
+
path: transcoder_all/layer_8_width_262k_l0_big_affine
|
|
9188
|
+
l0: 83
|
|
9189
|
+
- id: layer_8_width_262k_l0_small
|
|
9190
|
+
path: transcoder_all/layer_8_width_262k_l0_small
|
|
9191
|
+
l0: 13
|
|
9192
|
+
- id: layer_8_width_262k_l0_small_affine
|
|
9193
|
+
path: transcoder_all/layer_8_width_262k_l0_small_affine
|
|
9194
|
+
l0: 13
|
|
9195
|
+
- id: layer_9_width_16k_l0_big
|
|
9196
|
+
path: transcoder_all/layer_9_width_16k_l0_big
|
|
9197
|
+
l0: 86
|
|
9198
|
+
- id: layer_9_width_16k_l0_big_affine
|
|
9199
|
+
path: transcoder_all/layer_9_width_16k_l0_big_affine
|
|
9200
|
+
l0: 86
|
|
9201
|
+
- id: layer_9_width_16k_l0_small
|
|
9202
|
+
path: transcoder_all/layer_9_width_16k_l0_small
|
|
9203
|
+
l0: 14
|
|
9204
|
+
- id: layer_9_width_16k_l0_small_affine
|
|
9205
|
+
path: transcoder_all/layer_9_width_16k_l0_small_affine
|
|
9206
|
+
l0: 14
|
|
9207
|
+
- id: layer_9_width_262k_l0_big
|
|
9208
|
+
path: transcoder_all/layer_9_width_262k_l0_big
|
|
9209
|
+
l0: 86
|
|
9210
|
+
- id: layer_9_width_262k_l0_big_affine
|
|
9211
|
+
path: transcoder_all/layer_9_width_262k_l0_big_affine
|
|
9212
|
+
l0: 86
|
|
9213
|
+
- id: layer_9_width_262k_l0_small
|
|
9214
|
+
path: transcoder_all/layer_9_width_262k_l0_small
|
|
9215
|
+
l0: 14
|
|
9216
|
+
- id: layer_9_width_262k_l0_small_affine
|
|
9217
|
+
path: transcoder_all/layer_9_width_262k_l0_small_affine
|
|
9218
|
+
l0: 14
|
|
9219
9219
|
gemma-scope-2-27b-it-transcoders:
|
|
9220
9220
|
conversion_func: gemma_3
|
|
9221
9221
|
model: google/gemma-3-27b-it
|
sae_lens/saes/__init__.py
CHANGED
|
@@ -14,6 +14,12 @@ from .jumprelu_sae import (
|
|
|
14
14
|
JumpReLUTrainingSAE,
|
|
15
15
|
JumpReLUTrainingSAEConfig,
|
|
16
16
|
)
|
|
17
|
+
from .matching_pursuit_sae import (
|
|
18
|
+
MatchingPursuitSAE,
|
|
19
|
+
MatchingPursuitSAEConfig,
|
|
20
|
+
MatchingPursuitTrainingSAE,
|
|
21
|
+
MatchingPursuitTrainingSAEConfig,
|
|
22
|
+
)
|
|
17
23
|
from .matryoshka_batchtopk_sae import (
|
|
18
24
|
MatryoshkaBatchTopKTrainingSAE,
|
|
19
25
|
MatryoshkaBatchTopKTrainingSAEConfig,
|
|
@@ -78,4 +84,8 @@ __all__ = [
|
|
|
78
84
|
"MatryoshkaBatchTopKTrainingSAEConfig",
|
|
79
85
|
"TemporalSAE",
|
|
80
86
|
"TemporalSAEConfig",
|
|
87
|
+
"MatchingPursuitSAE",
|
|
88
|
+
"MatchingPursuitTrainingSAE",
|
|
89
|
+
"MatchingPursuitSAEConfig",
|
|
90
|
+
"MatchingPursuitTrainingSAEConfig",
|
|
81
91
|
]
|
|
@@ -0,0 +1,334 @@
|
|
|
1
|
+
"""Matching Pursuit SAE"""
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from typing_extensions import override
|
|
9
|
+
|
|
10
|
+
from sae_lens.saes.sae import (
|
|
11
|
+
SAE,
|
|
12
|
+
SAEConfig,
|
|
13
|
+
TrainCoefficientConfig,
|
|
14
|
+
TrainingSAE,
|
|
15
|
+
TrainingSAEConfig,
|
|
16
|
+
TrainStepInput,
|
|
17
|
+
TrainStepOutput,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
# --- inference ---
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class MatchingPursuitSAEConfig(SAEConfig):
    """
    Inference-time configuration for a MatchingPursuitSAE.

    Matching-pursuit specific fields:
        residual_threshold (float): A sample stops selecting latents once the L2
            norm of its residual drops below this value. A value <= 0 turns this
            stopping rule off. Defaults to 1e-2.
        max_iterations (int | None): Upper bound on the number of selection
            iterations (and therefore on the number of active latents) per
            sample. ``None`` means "use d_in". Defaults to None.
        stop_on_duplicate_support (bool): When True, a sample stops selecting
            latents as soon as an iteration leaves its set of active latents
            unchanged from the previous iteration. Defaults to True.

    Fields inherited from SAEConfig:
        d_in (int): Input dimension (dimensionality of the activations being encoded).
        d_sae (int): SAE latent dimension (number of features in the SAE).
        dtype (str): Data type for the SAE parameters. Defaults to "float32".
        device (str): Device to place the SAE on. Defaults to "cpu".
        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
            before encoding. Defaults to True.
        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
            Normalization strategy for input activations. Defaults to "none".
        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
            (useful for attention head outputs). Defaults to "none".
        metadata (SAEMetadata): Metadata about the SAE (model name, hook name, etc.).
    """

    # --- matching pursuit stopping rules ---
    residual_threshold: float = 1e-2
    max_iterations: int | None = None
    stop_on_duplicate_support: bool = True

    @override
    @classmethod
    def architecture(cls) -> str:
        # Registry key shared with the inference and training SAE classes.
        return "matching_pursuit"
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class MatchingPursuitSAE(SAE[MatchingPursuitSAEConfig]):
    """
    Inference-only sparse autoencoder whose encoder runs matching pursuit.

    The weights are tied: there is no independent encoder matrix, and
    ``W_enc`` is exposed as ``W_dec`` transposed.
    """

    @property
    def W_enc(self) -> torch.Tensor:  # pyright: ignore[reportIncompatibleVariableOverride]
        """Encoder weights, derived from the decoder (tied SAE)."""
        return self.W_dec.T

    # The base class has its own ``W_enc``; any assignment to that name is
    # silently discarded so the tied property above stays authoritative.
    # TODO: harmonize with the base class in next major release
    @override
    def __setattr__(self, name: str, value: Any):
        if name != "W_enc":
            super().__setattr__(name, value)

    @override
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode ``x`` into sparse feature activations via matching pursuit.

        The input goes through ``process_sae_in`` first; the stopping rules
        come from ``self.cfg``.
        """
        cfg = self.cfg
        return _encode_matching_pursuit(
            self.process_sae_in(x),
            self.W_dec,
            cfg.residual_threshold,
            max_iterations=cfg.max_iterations,
            stop_on_duplicate_support=cfg.stop_on_duplicate_support,
        )

    @override
    @torch.no_grad()
    def fold_W_dec_norm(self) -> None:
        # Rescaling decoder rows would change which latents matching pursuit
        # selects, so folding is refused outright.
        raise NotImplementedError(
            "Folding W_dec_norm is not safe for MatchingPursuit SAEs, as this may change the resulting activations"
        )

    @override
    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """
        Map feature activations back to the input space.

        Because the SAE is tied, ``b_dec`` is added back only when
        ``cfg.apply_b_dec_to_input`` is set. The result then passes through the
        reconstruction hook, the output activation-norm function, and the
        output reshape (which reverses any hook_z flattening).
        """
        recons = torch.matmul(feature_acts, self.W_dec)
        if self.cfg.apply_b_dec_to_input:
            recons = recons + self.b_dec
        recons = self.run_time_activation_norm_fn_out(self.hook_sae_recons(recons))
        return self.reshape_fn_out(recons, self.d_head)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# --- training ---
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@dataclass
class MatchingPursuitTrainingSAEConfig(TrainingSAEConfig):
    """
    Configuration class for training a MatchingPursuitTrainingSAE.

    Args:
        residual_threshold (float): A sample stops selecting latents once the L2
            norm of its residual drops below this value. A value <= 0 disables
            this stopping rule. Defaults to 1e-2.
        max_iterations (int | None): Maximum iterations (default: d_in if set to None).
            Defaults to None.
        stop_on_duplicate_support (bool): Whether to stop selecting latents if the
            support set has not changed from the previous iteration. Defaults to True.
        decoder_init_norm (float | None): Norm to initialize decoder weights to.
            Matching pursuit requires this to be 1.0: the encoder uses each
            selected decoder row's dot product with the residual directly as
            the coefficient, which is an exact projection only for unit-norm
            rows. Defaults to 1.0 for this class; any other value (including
            None) is replaced by 1.0 with a warning.
        d_in (int): Input dimension (dimensionality of the activations being encoded).
            Inherited from SAEConfig.
        d_sae (int): SAE latent dimension (number of features in the SAE).
            Inherited from SAEConfig.
        dtype (str): Data type for the SAE parameters. Inherited from SAEConfig.
            Defaults to "float32".
        device (str): Device to place the SAE on. Inherited from SAEConfig.
            Defaults to "cpu".
        apply_b_dec_to_input (bool): Whether to apply decoder bias to the input
            before encoding. Inherited from SAEConfig. Defaults to True.
        normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]):
            Normalization strategy for input activations. Inherited from SAEConfig.
            Defaults to "none".
        reshape_activations (Literal["none", "hook_z"]): How to reshape activations
            (useful for attention head outputs). Inherited from SAEConfig.
            Defaults to "none".
        metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.).
            Inherited from SAEConfig.
    """

    residual_threshold: float = 1e-2
    max_iterations: int | None = None
    stop_on_duplicate_support: bool = True
    # Override the inherited default (0.1) so a default-constructed config is
    # already valid and does not trigger the warning in __post_init__.
    decoder_init_norm: float | None = 1.0

    @override
    @classmethod
    def architecture(cls) -> str:
        return "matching_pursuit"

    @override
    def __post_init__(self):
        super().__post_init__()
        # Still enforce unit-norm init for values passed in explicitly.
        if self.decoder_init_norm != 1.0:
            self.decoder_init_norm = 1.0
            warnings.warn(
                "decoder_init_norm must be set to 1.0 for MatchingPursuitTrainingSAE, setting to 1.0",
                # __post_init__ <- generated __init__ <- caller constructing the config
                stacklevel=3,
            )
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class MatchingPursuitTrainingSAE(TrainingSAE[MatchingPursuitTrainingSAEConfig]):
    """
    Trainable sparse autoencoder whose encoder runs matching pursuit.

    Weights are tied: ``W_enc`` is ``W_dec`` transposed. ``get_coefficients``
    and ``calculate_aux_loss`` both return empty dicts, so this class adds no
    coefficients or auxiliary losses. The number of active latents per sample
    is bounded by the matching-pursuit stopping rules in the config
    (``max_iterations``, ``residual_threshold``, ``stop_on_duplicate_support``).
    """

    # Matching pursuit is a tied SAE, so we use W_enc as the decoder transposed
    @property
    def W_enc(self) -> torch.Tensor:  # pyright: ignore[reportIncompatibleVariableOverride]
        return self.W_dec.T

    # hacky way to get around the base class having W_enc: assignments to
    # ``W_enc`` are silently dropped so the tied property above is used instead.
    # TODO: harmonize with the base class in next major release
    @override
    def __setattr__(self, name: str, value: Any):
        if name == "W_enc":
            return
        super().__setattr__(name, value)

    @override
    def encode_with_hidden_pre(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Encode ``x`` with matching pursuit.

        hidden_pre doesn't make sense for matching pursuit, since there is not a
        single pre-activation, so an all-zeros tensor shaped like the feature
        activations is returned in its place.

        Returns:
            ``(feature_acts, hidden_pre)``.
        """
        sae_in = self.process_sae_in(x)
        acts = _encode_matching_pursuit(
            sae_in,
            self.W_dec,
            self.cfg.residual_threshold,
            max_iterations=self.cfg.max_iterations,
            stop_on_duplicate_support=self.cfg.stop_on_duplicate_support,
        )
        return acts, torch.zeros_like(acts)

    @override
    @torch.no_grad()
    def fold_W_dec_norm(self) -> None:
        raise NotImplementedError(
            "Folding W_dec_norm is not safe for MatchingPursuit SAEs, as this may change the resulting activations"
        )

    @override
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
        return {}

    @override
    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        return {}

    @override
    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
        """
        Run the base training step, then attach matching-pursuit diagnostics.

        Added metrics:
            max_l0 / min_l0: largest / smallest number of active latents per
                sample in the batch.
            residual_norm: mean L2 norm of ``sae_in - sae_out``.
            residual_threshold_converged_portion: fraction of samples whose
                residual norm is below ``cfg.residual_threshold``.
        """
        output = super().training_forward_pass(step_input)
        # The metrics are for logging only. Computing them without autograd
        # keeps these ops off the graph; sae_out requires grad, so they would
        # otherwise be recorded for no reason.
        with torch.no_grad():
            l0 = output.feature_acts.bool().float().sum(-1).to_dense()
            residual_norm = (step_input.sae_in - output.sae_out).norm(dim=-1)
            output.metrics["max_l0"] = l0.max()
            output.metrics["min_l0"] = l0.min()
            output.metrics["residual_norm"] = residual_norm.mean()
            output.metrics["residual_threshold_converged_portion"] = (
                (residual_norm < self.cfg.residual_threshold).float().mean()
            )
        return output

    @override
    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        """
        Decode the feature activations back to the input space.

        If hook_z reshaping is turned on, the output reshape reverses the
        flattening.
        """
        sae_out_pre = feature_acts @ self.W_dec
        # since this is a tied SAE, we need to make sure b_dec is only applied if applied at input
        if self.cfg.apply_b_dec_to_input:
            sae_out_pre = sae_out_pre + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
# --- shared ---
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def _encode_matching_pursuit(
|
|
258
|
+
sae_in_centered: torch.Tensor,
|
|
259
|
+
W_dec: torch.Tensor,
|
|
260
|
+
residual_threshold: float,
|
|
261
|
+
max_iterations: int | None,
|
|
262
|
+
stop_on_duplicate_support: bool,
|
|
263
|
+
) -> torch.Tensor:
|
|
264
|
+
"""
|
|
265
|
+
Matching pursuit encoding.
|
|
266
|
+
|
|
267
|
+
Args:
|
|
268
|
+
sae_in_centered: Input activations, centered by b_dec. Shape [..., d_in].
|
|
269
|
+
W_dec: Decoder weight matrix. Shape [d_sae, d_in].
|
|
270
|
+
residual_threshold: Stop when residual norm falls below this.
|
|
271
|
+
max_iterations: Maximum iterations (default: d_in). Prevents infinite loops.
|
|
272
|
+
stop_on_duplicate_support: Whether to stop selecting latents if the support set has not changed from the previous iteration.
|
|
273
|
+
"""
|
|
274
|
+
residual = sae_in_centered.clone()
|
|
275
|
+
|
|
276
|
+
stop_on_residual_threshold = residual_threshold > 0
|
|
277
|
+
|
|
278
|
+
# Handle multi-dimensional inputs by flattening all but the last dimension
|
|
279
|
+
original_shape = residual.shape
|
|
280
|
+
if residual.ndim > 2:
|
|
281
|
+
residual = residual.reshape(-1, residual.shape[-1])
|
|
282
|
+
|
|
283
|
+
batch_size = residual.shape[0]
|
|
284
|
+
d_sae, d_in = W_dec.shape
|
|
285
|
+
|
|
286
|
+
if max_iterations is None:
|
|
287
|
+
max_iterations = d_in # Sensible upper bound
|
|
288
|
+
|
|
289
|
+
acts = torch.zeros(batch_size, d_sae, device=W_dec.device, dtype=residual.dtype)
|
|
290
|
+
prev_support = torch.zeros(batch_size, d_sae, dtype=torch.bool, device=W_dec.device)
|
|
291
|
+
done = torch.zeros(batch_size, dtype=torch.bool, device=W_dec.device)
|
|
292
|
+
|
|
293
|
+
for _ in range(max_iterations):
|
|
294
|
+
# Find indices without gradients - the full [batch, d_sae] matmul result
|
|
295
|
+
# doesn't need to be saved for backward since max indices don't need gradients
|
|
296
|
+
with torch.no_grad():
|
|
297
|
+
indices = (residual @ W_dec.T).relu().max(dim=1, keepdim=True).indices
|
|
298
|
+
indices_flat = indices.squeeze(1) # [batch_size]
|
|
299
|
+
|
|
300
|
+
# Compute values with gradients using only the selected decoder rows.
|
|
301
|
+
# This stores [batch, d_in] for backward instead of [batch, d_sae].
|
|
302
|
+
selected_dec = W_dec[indices_flat] # [batch_size, d_in]
|
|
303
|
+
values = (residual * selected_dec).sum(dim=-1, keepdim=True).relu()
|
|
304
|
+
|
|
305
|
+
# Mask values for samples that are already done
|
|
306
|
+
active_mask = (~done).unsqueeze(1)
|
|
307
|
+
masked_values = (values * active_mask.to(values.dtype)).to(acts.dtype)
|
|
308
|
+
|
|
309
|
+
acts.scatter_add_(1, indices, masked_values)
|
|
310
|
+
|
|
311
|
+
# Update residual
|
|
312
|
+
residual = residual - masked_values * selected_dec
|
|
313
|
+
|
|
314
|
+
if stop_on_duplicate_support or stop_on_residual_threshold:
|
|
315
|
+
with torch.no_grad():
|
|
316
|
+
support = acts != 0
|
|
317
|
+
|
|
318
|
+
# A sample is considered converged if:
|
|
319
|
+
# (1) the support set hasn't changed from the previous iteration (stability), or
|
|
320
|
+
# (2) the residual norm is below a given threshold (good enough reconstruction)
|
|
321
|
+
if stop_on_duplicate_support:
|
|
322
|
+
done = done | (support == prev_support).all(dim=1)
|
|
323
|
+
prev_support = support
|
|
324
|
+
if stop_on_residual_threshold:
|
|
325
|
+
done = done | (residual.norm(dim=-1) < residual_threshold)
|
|
326
|
+
|
|
327
|
+
if done.all():
|
|
328
|
+
break
|
|
329
|
+
|
|
330
|
+
# Reshape acts back to original shape (replacing last dimension with d_sae)
|
|
331
|
+
if len(original_shape) > 2:
|
|
332
|
+
acts = acts.reshape(*original_shape[:-1], acts.shape[-1])
|
|
333
|
+
|
|
334
|
+
return acts
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
sae_lens/__init__.py,sha256=
|
|
1
|
+
sae_lens/__init__.py,sha256=zRp1nmb41W1Pt1rvlKvRWw73UxjGyz1iHAzH9_X6_WQ,4725
|
|
2
2
|
sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
|
|
4
4
|
sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
|
|
5
5
|
sae_lens/cache_activations_runner.py,sha256=Lvlz-k5-3XxVRtUdC4b1CiKyx5s0ckLa8GDGv9_kcxs,12566
|
|
6
|
-
sae_lens/config.py,sha256=
|
|
6
|
+
sae_lens/config.py,sha256=C982bUELhGHcfTwzeMTtXIf2hPtc946thYpUyctLiBo,30516
|
|
7
7
|
sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
|
|
8
8
|
sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
|
|
9
9
|
sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
|
|
@@ -12,12 +12,13 @@ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
|
|
|
12
12
|
sae_lens/loading/pretrained_sae_loaders.py,sha256=hq-dhxsEdUmlAnZEiZBqX7lNyQQwZ6KXmXZWpzAc5FY,63638
|
|
13
13
|
sae_lens/loading/pretrained_saes_directory.py,sha256=hejNfLUepYCSGPalRfQwxxCEUqMMUPsn1tufwvwct5k,3820
|
|
14
14
|
sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
|
|
15
|
-
sae_lens/pretrained_saes.yaml,sha256=
|
|
15
|
+
sae_lens/pretrained_saes.yaml,sha256=Hn8jXwZ7V6QQxzgu41LFEP-LAzuDxwYL5vhoar-pPX8,1509922
|
|
16
16
|
sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
|
|
17
|
-
sae_lens/saes/__init__.py,sha256=
|
|
17
|
+
sae_lens/saes/__init__.py,sha256=SBqPaP6Gl5uPFwHlumAZATC4Wd26xKIYLAAAo4MSa5Q,2200
|
|
18
18
|
sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
|
|
19
19
|
sae_lens/saes/gated_sae.py,sha256=mHnmw-RD7hqIbP9_EBj3p2SK0OqQIkZivdOKRygeRgw,8825
|
|
20
20
|
sae_lens/saes/jumprelu_sae.py,sha256=udjGHp3WTABQSL2Qq57j-bINWX61GCmo68EmdjMOXoo,13310
|
|
21
|
+
sae_lens/saes/matching_pursuit_sae.py,sha256=08_G9p1YMLnE5qZVCPp6gll-iG6nHRbMMASf4_bkFt8,13207
|
|
21
22
|
sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
|
|
22
23
|
sae_lens/saes/sae.py,sha256=fzXv8lwHskSxsf8hm_wlKPkpq50iafmBjBNQzwZ6a00,40050
|
|
23
24
|
sae_lens/saes/standard_sae.py,sha256=nEVETwAmRD2tyX7ESIic1fij48gAq1Dh7s_GQ2fqCZ4,5747
|
|
@@ -35,7 +36,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
|
|
|
35
36
|
sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
|
|
36
37
|
sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
|
|
37
38
|
sae_lens/util.py,sha256=spkcmQUsjVYFn5H2032nQYr1CKGVnv3tAdfIpY59-Mg,3919
|
|
38
|
-
sae_lens-6.
|
|
39
|
-
sae_lens-6.
|
|
40
|
-
sae_lens-6.
|
|
41
|
-
sae_lens-6.
|
|
39
|
+
sae_lens-6.26.1.dist-info/METADATA,sha256=yoE6CFgQ9L5SLzI3Zgr8H8CfUBgSimihGyEIvKd8TW8,5361
|
|
40
|
+
sae_lens-6.26.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
|
|
41
|
+
sae_lens-6.26.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
|
|
42
|
+
sae_lens-6.26.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|