sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +55 -18
- sae_lens/analysis/hooked_sae_transformer.py +10 -10
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +105 -235
- sae_lens/constants.py +20 -0
- sae_lens/evals.py +34 -31
- sae_lens/{sae_training_runner.py → llm_sae_training_runner.py} +103 -70
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +36 -10
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +70 -59
- sae_lens/saes/jumprelu_sae.py +58 -72
- sae_lens/saes/sae.py +248 -273
- sae_lens/saes/standard_sae.py +75 -57
- sae_lens/saes/topk_sae.py +72 -83
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +105 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +134 -158
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +11 -5
- sae_lens/util.py +47 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc3.dist-info/RECORD +38 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/WHEEL +1 -1
- sae_lens/regsitry.py +0 -34
- sae_lens-6.0.0rc1.dist-info/RECORD +0 -32
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/LICENSE +0 -0
sae_lens/saes/standard_sae.py
CHANGED
@@ -1,13 +1,37 @@
+from dataclasses import dataclass
+from typing import Any
+
 import numpy as np
 import torch
 from jaxtyping import Float
 from numpy.typing import NDArray
 from torch import nn
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+)
+from sae_lens.util import filter_valid_dataclass_fields
+
+
+@dataclass
+class StandardSAEConfig(SAEConfig):
+    """
+    Configuration class for a StandardSAE.
+    """
 
-
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "standard"
 
 
-class StandardSAE(SAE):
+class StandardSAE(SAE[StandardSAEConfig]):
     """
     StandardSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
     using a simple linear encoder and decoder.
@@ -23,31 +47,14 @@ class StandardSAE(SAE):
 
     b_enc: nn.Parameter
 
-    def __init__(self, cfg:
+    def __init__(self, cfg: StandardSAEConfig, use_error_term: bool = False):
         super().__init__(cfg, use_error_term)
 
+    @override
     def initialize_weights(self) -> None:
         # Initialize encoder weights and bias.
-
-
-        )
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
-
-        # Use Kaiming Uniform for W_enc
-        w_enc_data = torch.empty(
-            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_enc_data)
-        self.W_enc = nn.Parameter(w_enc_data)
-
-        # Use Kaiming Uniform for W_dec
-        w_dec_data = torch.empty(
-            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_dec_data)
-        self.W_dec = nn.Parameter(w_dec_data)
+        super().initialize_weights()
+        _init_weights_standard(self)
 
     def encode(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -70,11 +77,9 @@ class StandardSAE(SAE):
         Decode the feature activations back to the input space.
         Now, if hook_z reshaping is turned on, we reverse the flattening.
         """
-        # 1)
-
-        # 2)
-        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
-        # 3) hook reconstruction
+        # 1) linear transform
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        # 2) hook reconstruction
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
         # 4) optional out-normalization (e.g. constant_norm_rescale or layer_norm)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
@@ -82,7 +87,23 @@
         return self.reshape_fn_out(sae_out_pre, self.d_head)
 
 
-
+@dataclass
+class StandardTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a StandardTrainingSAE.
+    """
+
+    l1_coefficient: float = 1.0
+    lp_norm: float = 1.0
+    l1_warm_up_steps: int = 0
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "standard"
+
+
+class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
     """
     StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture.
     It implements:
@@ -96,31 +117,17 @@ class StandardTrainingSAE(TrainingSAE):
     b_enc: nn.Parameter
 
     def initialize_weights(self) -> None:
-
-
-        StandardSAE.initialize_weights(self)  # type: ignore
-
-        # Complex init logic from original TrainingSAE
-        if self.cfg.decoder_orthogonal_init:
-            self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
-
-        elif self.cfg.decoder_heuristic_init:
-            self.W_dec.data = torch.rand(  # Changed from Parameter to data assignment
-                self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-            )
-            self.initialize_decoder_norm_constant_norm()
-
-        if self.cfg.init_encoder_as_decoder_transpose:
-            self.W_enc.data = self.W_dec.data.T.clone().contiguous()  # type: ignore
+        super().initialize_weights()
+        _init_weights_standard(self)
 
-
-
-
-
-
-
-
-
+    @override
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+        return {
+            "l1": TrainCoefficientConfig(
+                value=self.cfg.l1_coefficient,
+                warm_up_steps=self.cfg.l1_warm_up_steps,
+            ),
+        }
 
     def encode_with_hidden_pre(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -148,13 +155,11 @@ class StandardTrainingSAE(TrainingSAE):
         sae_out: torch.Tensor,
     ) -> dict[str, torch.Tensor]:
         # The "standard" auxiliary loss is a sparsity penalty on the feature activations
-        weighted_feature_acts = feature_acts
-        if self.cfg.scale_sparsity_penalty_by_decoder_norm:
-            weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
+        weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
 
         # Compute the p-norm (set by cfg.lp_norm) over the feature dimension
         sparsity = weighted_feature_acts.norm(p=self.cfg.lp_norm, dim=-1)
-        l1_loss = (step_input.
+        l1_loss = (step_input.coefficients["l1"] * sparsity).mean()
 
         return {"l1_loss": l1_loss}
 
@@ -165,3 +170,16 @@ class StandardTrainingSAE(TrainingSAE):
             **super().log_histograms(),
             "weights/b_e": b_e_dist,
         }
+
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), StandardSAEConfig, ["architecture"]
+        )
+
+
+def _init_weights_standard(
+    sae: SAE[StandardSAEConfig] | TrainingSAE[StandardTrainingSAEConfig],
+) -> None:
+    sae.b_enc = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
sae_lens/saes/topk_sae.py
CHANGED
@@ -1,18 +1,22 @@
 """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
 
-from
+from dataclasses import dataclass
+from typing import Any, Callable
 
 import torch
 from jaxtyping import Float
 from torch import nn
+from typing_extensions import override
 
 from sae_lens.saes.sae import (
     SAE,
     SAEConfig,
+    TrainCoefficientConfig,
     TrainingSAE,
     TrainingSAEConfig,
     TrainStepInput,
 )
+from sae_lens.util import filter_valid_dataclass_fields
 
 
 class TopK(nn.Module):
@@ -45,14 +49,30 @@ class TopK(nn.Module):
         return result
 
 
-
+@dataclass
+class TopKSAEConfig(SAEConfig):
+    """
+    Configuration class for a TopKSAE.
+    """
+
+    k: int = 100
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "topk"
+
+
+class TopKSAE(SAE[TopKSAEConfig]):
     """
     An inference-only sparse autoencoder using a "topk" activation function.
     It uses linear encoder and decoder layers, applying the TopK activation
     to the hidden pre-activation in its encode step.
     """
 
-
+    b_enc: nn.Parameter
+
+    def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
         """
         Args:
             cfg: SAEConfig defining model size and behavior.
@@ -60,38 +80,11 @@ class TopKSAE(SAE):
         """
         super().__init__(cfg, use_error_term)
 
-
-            raise ValueError("TopKSAE must use a TopK activation function.")
-
+    @override
     def initialize_weights(self) -> None:
-
-
-
-        - b_enc, b_dec are zero-initialized
-        - W_enc, W_dec are Kaiming Uniform
-        """
-        # encoder bias
-        self.b_enc = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-        # decoder bias
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
-
-        # encoder weight
-        w_enc_data = torch.empty(
-            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_enc_data)
-        self.W_enc = nn.Parameter(w_enc_data)
-
-        # decoder weight
-        w_dec_data = torch.empty(
-            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_dec_data)
-        self.W_dec = nn.Parameter(w_dec_data)
+        # Initialize encoder weights and bias.
+        super().initialize_weights()
+        _init_weights_topk(self)
 
     def encode(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -114,28 +107,31 @@ class TopKSAE(SAE):
         Applies optional finetuning scaling, hooking to recons, out normalization,
         and optional head reshaping.
         """
-
-        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
         return self.reshape_fn_out(sae_out_pre, self.d_head)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    @override
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return TopK(self.cfg.k)
+
+
+@dataclass
+class TopKTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a TopKTrainingSAE.
+    """
+
+    k: int = 100
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "topk"
+
+
+class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     """
     TopK variant with training functionality. Injects noise during training, optionally
     calculates a topk-related auxiliary loss, etc.
@@ -143,32 +139,13 @@ class TopKTrainingSAE(TrainingSAE):
 
     b_enc: nn.Parameter
 
-    def __init__(self, cfg:
+    def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
         super().__init__(cfg, use_error_term)
 
-
-            raise ValueError("TopKSAE must use a TopK activation function.")
-
+    @override
     def initialize_weights(self) -> None:
-
-        self
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
-
-        w_enc_data = torch.empty(
-            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_enc_data)
-        self.W_enc = nn.Parameter(w_enc_data)
-
-        w_dec_data = torch.empty(
-            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_dec_data)
-        self.W_dec = nn.Parameter(w_dec_data)
+        super().initialize_weights()
+        _init_weights_topk(self)
 
     def encode_with_hidden_pre(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -207,14 +184,13 @@ class TopKTrainingSAE(TrainingSAE):
         )
         return {"auxiliary_reconstruction_loss": topk_loss}
 
-
-
-
-
-
-
-
-        return super()._get_activation_fn()
+    @override
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return TopK(self.cfg.k)
+
+    @override
+    def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
+        return {}
 
     def calculate_topk_aux_loss(
         self,
@@ -288,6 +264,11 @@ class TopKTrainingSAE(TrainingSAE):
 
         return auxk_acts
 
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), TopKSAEConfig, ["architecture"]
+        )
+
 
 def _calculate_topk_aux_acts(
     k_aux: int,
@@ -303,3 +284,11 @@ def _calculate_topk_aux_acts(
     auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
     # Set activations to zero for all but top k_aux dead latents
    return auxk_acts
+
+
+def _init_weights_topk(
+    sae: SAE[TopKSAEConfig] | TrainingSAE[TopKTrainingSAEConfig],
+) -> None:
+    sae.b_enc = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
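Both TopK classes now return TopK(self.cfg.k) from get_activation_fn. The body of the TopK module itself is not shown in this diff, but its effect is to keep the k largest pre-activations per example and zero the rest. The plain-torch sketch below illustrates that behavior; it is not the sae_lens module (which may also apply a post-activation function such as ReLU to the kept values), and the function name is illustrative only.

import torch

def topk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    # Keep the k largest pre-activations along the feature dimension
    # and zero out everything else.
    values, indices = pre_acts.topk(k, dim=-1)
    out = torch.zeros_like(pre_acts)
    out.scatter_(-1, indices, values)
    return out

hidden_pre = torch.randn(2, 10)
print(topk_activation(hidden_pre, k=3))  # at most 3 nonzero entries per row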
sae_lens/training/activation_scaler.py
ADDED
@@ -0,0 +1,53 @@
+import json
+from dataclasses import dataclass
+from statistics import mean
+
+import torch
+from tqdm import tqdm
+
+from sae_lens.training.types import DataProvider
+
+
+@dataclass
+class ActivationScaler:
+    scaling_factor: float | None = None
+
+    def scale(self, acts: torch.Tensor) -> torch.Tensor:
+        return acts if self.scaling_factor is None else acts * self.scaling_factor
+
+    def unscale(self, acts: torch.Tensor) -> torch.Tensor:
+        return acts if self.scaling_factor is None else acts / self.scaling_factor
+
+    def __call__(self, acts: torch.Tensor) -> torch.Tensor:
+        return self.scale(acts)
+
+    @torch.no_grad()
+    def _calculate_mean_norm(
+        self, data_provider: DataProvider, n_batches_for_norm_estimate: int = int(1e3)
+    ) -> float:
+        norms_per_batch: list[float] = []
+        for _ in tqdm(
+            range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
+        ):
+            acts = next(data_provider)
+            norms_per_batch.append(acts.norm(dim=-1).mean().item())
+        return mean(norms_per_batch)
+
+    def estimate_scaling_factor(
+        self,
+        d_in: int,
+        data_provider: DataProvider,
+        n_batches_for_norm_estimate: int = int(1e3),
+    ):
+        mean_norm = self._calculate_mean_norm(
+            data_provider, n_batches_for_norm_estimate
+        )
+        self.scaling_factor = (d_in**0.5) / mean_norm
+
+    def save(self, file_path: str):
+        """save the state dict to a file in json format"""
+        if not file_path.endswith(".json"):
+            raise ValueError("file_path must end with .json")
+
+        with open(file_path, "w") as f:
+            json.dump({"scaling_factor": self.scaling_factor}, f)