sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +50 -16
- sae_lens/analysis/hooked_sae_transformer.py +10 -10
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +2 -1
- sae_lens/config.py +59 -231
- sae_lens/constants.py +18 -0
- sae_lens/evals.py +16 -13
- sae_lens/loading/pretrained_sae_loaders.py +36 -3
- sae_lens/registry.py +49 -0
- sae_lens/sae_training_runner.py +22 -21
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +70 -59
- sae_lens/saes/jumprelu_sae.py +58 -72
- sae_lens/saes/sae.py +250 -272
- sae_lens/saes/standard_sae.py +75 -57
- sae_lens/saes/topk_sae.py +72 -83
- sae_lens/training/activations_store.py +31 -15
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +44 -69
- sae_lens/training/upload_saes_to_huggingface.py +11 -5
- sae_lens/util.py +28 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc2.dist-info/RECORD +35 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/WHEEL +1 -1
- sae_lens/regsitry.py +0 -34
- sae_lens-6.0.0rc1.dist-info/RECORD +0 -32
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/LICENSE +0 -0
sae_lens/saes/standard_sae.py
CHANGED
@@ -1,13 +1,37 @@
+from dataclasses import dataclass
+from typing import Any
+
 import numpy as np
 import torch
 from jaxtyping import Float
 from numpy.typing import NDArray
 from torch import nn
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+)
+from sae_lens.util import filter_valid_dataclass_fields
+
+
+@dataclass
+class StandardSAEConfig(SAEConfig):
+    """
+    Configuration class for a StandardSAE.
+    """

-
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "standard"


-class StandardSAE(SAE):
+class StandardSAE(SAE[StandardSAEConfig]):
     """
     StandardSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
     using a simple linear encoder and decoder.
@@ -23,31 +47,14 @@ class StandardSAE(SAE):

     b_enc: nn.Parameter

-    def __init__(self, cfg:
+    def __init__(self, cfg: StandardSAEConfig, use_error_term: bool = False):
         super().__init__(cfg, use_error_term)

+    @override
     def initialize_weights(self) -> None:
         # Initialize encoder weights and bias.
-
-
-        )
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
-
-        # Use Kaiming Uniform for W_enc
-        w_enc_data = torch.empty(
-            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_enc_data)
-        self.W_enc = nn.Parameter(w_enc_data)
-
-        # Use Kaiming Uniform for W_dec
-        w_dec_data = torch.empty(
-            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_dec_data)
-        self.W_dec = nn.Parameter(w_dec_data)
+        super().initialize_weights()
+        _init_weights_standard(self)

     def encode(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -70,11 +77,9 @@
         Decode the feature activations back to the input space.
         Now, if hook_z reshaping is turned on, we reverse the flattening.
         """
-        # 1)
-
-        # 2)
-        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
-        # 3) hook reconstruction
+        # 1) linear transform
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        # 2) hook reconstruction
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
         # 4) optional out-normalization (e.g. constant_norm_rescale or layer_norm)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
@@ -82,7 +87,23 @@
         return self.reshape_fn_out(sae_out_pre, self.d_head)


-
+@dataclass
+class StandardTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a StandardTrainingSAE.
+    """
+
+    l1_coefficient: float = 1.0
+    lp_norm: float = 1.0
+    l1_warm_up_steps: int = 0
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "standard"
+
+
+class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
     """
     StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture.
     It implements:
@@ -96,31 +117,17 @@ class StandardTrainingSAE(TrainingSAE):
     b_enc: nn.Parameter

     def initialize_weights(self) -> None:
-
-
-        StandardSAE.initialize_weights(self)  # type: ignore
-
-        # Complex init logic from original TrainingSAE
-        if self.cfg.decoder_orthogonal_init:
-            self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
-
-        elif self.cfg.decoder_heuristic_init:
-            self.W_dec.data = torch.rand(  # Changed from Parameter to data assignment
-                self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-            )
-            self.initialize_decoder_norm_constant_norm()
-
-        if self.cfg.init_encoder_as_decoder_transpose:
-            self.W_enc.data = self.W_dec.data.T.clone().contiguous()  # type: ignore
+        super().initialize_weights()
+        _init_weights_standard(self)

-
-
-
-
-
-
-
-
+    @override
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+        return {
+            "l1": TrainCoefficientConfig(
+                value=self.cfg.l1_coefficient,
+                warm_up_steps=self.cfg.l1_warm_up_steps,
+            ),
+        }

     def encode_with_hidden_pre(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -148,13 +155,11 @@ class StandardTrainingSAE(TrainingSAE):
         sae_out: torch.Tensor,
     ) -> dict[str, torch.Tensor]:
         # The "standard" auxiliary loss is a sparsity penalty on the feature activations
-        weighted_feature_acts = feature_acts
-        if self.cfg.scale_sparsity_penalty_by_decoder_norm:
-            weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
+        weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)

         # Compute the p-norm (set by cfg.lp_norm) over the feature dimension
         sparsity = weighted_feature_acts.norm(p=self.cfg.lp_norm, dim=-1)
-        l1_loss = (step_input.
+        l1_loss = (step_input.coefficients["l1"] * sparsity).mean()

         return {"l1_loss": l1_loss}

@@ -165,3 +170,16 @@ class StandardTrainingSAE(TrainingSAE):
             **super().log_histograms(),
             "weights/b_e": b_e_dist,
         }
+
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), StandardSAEConfig, ["architecture"]
+        )
+
+
+def _init_weights_standard(
+    sae: SAE[StandardSAEConfig] | TrainingSAE[StandardTrainingSAEConfig],
+) -> None:
+    sae.b_enc = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
sae_lens/saes/topk_sae.py
CHANGED
@@ -1,18 +1,22 @@
 """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""

-from
+from dataclasses import dataclass
+from typing import Any, Callable

 import torch
 from jaxtyping import Float
 from torch import nn
+from typing_extensions import override

 from sae_lens.saes.sae import (
     SAE,
     SAEConfig,
+    TrainCoefficientConfig,
     TrainingSAE,
     TrainingSAEConfig,
     TrainStepInput,
 )
+from sae_lens.util import filter_valid_dataclass_fields


 class TopK(nn.Module):
@@ -45,14 +49,30 @@ class TopK(nn.Module):
         return result


-
+@dataclass
+class TopKSAEConfig(SAEConfig):
+    """
+    Configuration class for a TopKSAE.
+    """
+
+    k: int = 100
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "topk"
+
+
+class TopKSAE(SAE[TopKSAEConfig]):
     """
     An inference-only sparse autoencoder using a "topk" activation function.
     It uses linear encoder and decoder layers, applying the TopK activation
     to the hidden pre-activation in its encode step.
     """

-
+    b_enc: nn.Parameter
+
+    def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
         """
         Args:
             cfg: SAEConfig defining model size and behavior.
@@ -60,38 +80,11 @@ class TopKSAE(SAE):
         """
         super().__init__(cfg, use_error_term)

-
-            raise ValueError("TopKSAE must use a TopK activation function.")
-
+    @override
     def initialize_weights(self) -> None:
-
-
-
-        - b_enc, b_dec are zero-initialized
-        - W_enc, W_dec are Kaiming Uniform
-        """
-        # encoder bias
-        self.b_enc = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-        # decoder bias
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
-
-        # encoder weight
-        w_enc_data = torch.empty(
-            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_enc_data)
-        self.W_enc = nn.Parameter(w_enc_data)
-
-        # decoder weight
-        w_dec_data = torch.empty(
-            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_dec_data)
-        self.W_dec = nn.Parameter(w_dec_data)
+        # Initialize encoder weights and bias.
+        super().initialize_weights()
+        _init_weights_topk(self)

     def encode(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -114,28 +107,31 @@ class TopKSAE(SAE):
         Applies optional finetuning scaling, hooking to recons, out normalization,
         and optional head reshaping.
         """
-
-        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
         return self.reshape_fn_out(sae_out_pre, self.d_head)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    @override
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return TopK(self.cfg.k)
+
+
+@dataclass
+class TopKTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a TopKTrainingSAE.
+    """
+
+    k: int = 100
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "topk"
+
+
+class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     """
     TopK variant with training functionality. Injects noise during training, optionally
     calculates a topk-related auxiliary loss, etc.
@@ -143,32 +139,13 @@ class TopKTrainingSAE(TrainingSAE):

     b_enc: nn.Parameter

-    def __init__(self, cfg:
+    def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
         super().__init__(cfg, use_error_term)

-
-            raise ValueError("TopKSAE must use a TopK activation function.")
-
+    @override
     def initialize_weights(self) -> None:
-
-        self
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
-
-        w_enc_data = torch.empty(
-            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_enc_data)
-        self.W_enc = nn.Parameter(w_enc_data)
-
-        w_dec_data = torch.empty(
-            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_dec_data)
-        self.W_dec = nn.Parameter(w_dec_data)
+        super().initialize_weights()
+        _init_weights_topk(self)

     def encode_with_hidden_pre(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -207,14 +184,13 @@ class TopKTrainingSAE(TrainingSAE):
         )
         return {"auxiliary_reconstruction_loss": topk_loss}

-
-
-
-
-
-
-
-        return super()._get_activation_fn()
+    @override
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return TopK(self.cfg.k)
+
+    @override
+    def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
+        return {}

     def calculate_topk_aux_loss(
         self,
@@ -288,6 +264,11 @@ class TopKTrainingSAE(TrainingSAE):

         return auxk_acts

+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), TopKSAEConfig, ["architecture"]
+        )
+

 def _calculate_topk_aux_acts(
     k_aux: int,
@@ -303,3 +284,11 @@ def _calculate_topk_aux_acts(
     auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
     # Set activations to zero for all but top k_aux dead latents
     return auxk_acts
+
+
+def _init_weights_topk(
+    sae: SAE[TopKSAEConfig] | TrainingSAE[TopKTrainingSAEConfig],
+) -> None:
+    sae.b_enc = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
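Note on the TopK change above: both TopKSAE and TopKTrainingSAE now return TopK(cfg.k) from get_activation_fn, with k moving into the config dataclasses. The TopK module's forward pass is not shown in this diff, but the intended behaviour (keep the k largest values along the last dimension, zero everything else) can be sketched with the same topk/scatter pattern used in _calculate_topk_aux_acts. Details of the real implementation, such as a ReLU on the kept values, may differ; this is only an illustration.

    # Illustrative top-k activation: keep the k largest values per row,
    # zero the rest, mirroring the topk/scatter pattern in the diff.
    import torch


    def topk_activation(x: torch.Tensor, k: int) -> torch.Tensor:
        topk = torch.topk(x, k=k, dim=-1)
        result = torch.zeros_like(x)
        result.scatter_(-1, topk.indices, topk.values)
        return result


    acts = torch.randn(2, 10)
    print(topk_activation(acts, k=3))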
sae_lens/training/activations_store.py
CHANGED

@@ -23,12 +23,12 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase

 from sae_lens import logger
 from sae_lens.config import (
-    DTYPE_MAP,
     CacheActivationsRunnerConfig,
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.
+from sae_lens.constants import DTYPE_MAP
+from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences


@@ -91,7 +91,8 @@ class ActivationsStore:
     def from_config(
         cls,
         model: HookedRootModule,
-        cfg: LanguageModelSAERunnerConfig
+        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]
+        | CacheActivationsRunnerConfig,
         override_dataset: HfDataset | None = None,
     ) -> ActivationsStore:
         if isinstance(cfg, CacheActivationsRunnerConfig):
@@ -128,13 +129,15 @@ class ActivationsStore:
             hook_layer=cfg.hook_layer,
             hook_head_index=cfg.hook_head_index,
             context_size=cfg.context_size,
-            d_in=cfg.d_in
+            d_in=cfg.d_in
+            if isinstance(cfg, CacheActivationsRunnerConfig)
+            else cfg.sae.d_in,
             n_batches_in_buffer=cfg.n_batches_in_buffer,
             total_training_tokens=cfg.training_tokens,
             store_batch_size_prompts=cfg.store_batch_size_prompts,
             train_batch_size_tokens=cfg.train_batch_size_tokens,
             prepend_bos=cfg.prepend_bos,
-            normalize_activations=cfg.normalize_activations,
+            normalize_activations=cfg.sae.normalize_activations,
             device=device,
             dtype=cfg.dtype,
             cached_activations_path=cached_activations_path,
@@ -149,9 +152,10 @@ class ActivationsStore:
     def from_sae(
         cls,
         model: HookedRootModule,
-        sae: SAE,
+        sae: SAE[T_SAE_CONFIG],
+        dataset: HfDataset | str,
+        dataset_trust_remote_code: bool = False,
         context_size: int | None = None,
-        dataset: HfDataset | str | None = None,
         streaming: bool = True,
         store_batch_size_prompts: int = 8,
         n_batches_in_buffer: int = 8,
@@ -159,25 +163,37 @@ class ActivationsStore:
         total_tokens: int = 10**9,
         device: str = "cpu",
     ) -> ActivationsStore:
+        if sae.cfg.metadata.hook_name is None:
+            raise ValueError("hook_name is required")
+        if sae.cfg.metadata.hook_layer is None:
+            raise ValueError("hook_layer is required")
+        if sae.cfg.metadata.hook_head_index is None:
+            raise ValueError("hook_head_index is required")
+        if sae.cfg.metadata.context_size is None:
+            raise ValueError("context_size is required")
+        if sae.cfg.metadata.prepend_bos is None:
+            raise ValueError("prepend_bos is required")
         return cls(
             model=model,
-            dataset=
+            dataset=dataset,
             d_in=sae.cfg.d_in,
-            hook_name=sae.cfg.hook_name,
-            hook_layer=sae.cfg.hook_layer,
-            hook_head_index=sae.cfg.hook_head_index,
-            context_size=sae.cfg.context_size
-
+            hook_name=sae.cfg.metadata.hook_name,
+            hook_layer=sae.cfg.metadata.hook_layer,
+            hook_head_index=sae.cfg.metadata.hook_head_index,
+            context_size=sae.cfg.metadata.context_size
+            if context_size is None
+            else context_size,
+            prepend_bos=sae.cfg.metadata.prepend_bos,
             streaming=streaming,
             store_batch_size_prompts=store_batch_size_prompts,
             train_batch_size_tokens=train_batch_size_tokens,
             n_batches_in_buffer=n_batches_in_buffer,
             total_training_tokens=total_tokens,
             normalize_activations=sae.cfg.normalize_activations,
-            dataset_trust_remote_code=
+            dataset_trust_remote_code=dataset_trust_remote_code,
             dtype=sae.cfg.dtype,
             device=torch.device(device),
-            seqpos_slice=sae.cfg.seqpos_slice or (None,),
+            seqpos_slice=sae.cfg.metadata.seqpos_slice or (None,),
         )

     def __init__(
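Note on the API change above: ActivationsStore.from_sae now requires the dataset to be passed explicitly and reads hook information from sae.cfg.metadata, raising immediately when a required metadata field is unset. The sketch below reproduces a subset of that fail-fast check against a mock metadata object; the mock class and the example values are hypothetical, and only the field names and the validation pattern come from the diff.

    # Mock of the required-field checks that from_sae performs on sae.cfg.metadata.
    from dataclasses import dataclass


    @dataclass
    class MockMetadata:
        hook_name: str | None = None
        hook_layer: int | None = None
        context_size: int | None = None
        prepend_bos: bool | None = None


    def check_metadata(metadata: MockMetadata) -> None:
        # Mirror the fail-fast pattern: every required field must be set.
        for field in ("hook_name", "hook_layer", "context_size", "prepend_bos"):
            if getattr(metadata, field) is None:
                raise ValueError(f"{field} is required")


    check_metadata(MockMetadata("blocks.0.hook_resid_pre", 0, 128, True))  # passes
    check_metadata(MockMetadata())  # raises ValueError("hook_name is required")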
sae_lens/training/optim.py
CHANGED
@@ -101,61 +101,85 @@ def _get_main_lr_scheduler(
         raise ValueError(f"Unsupported scheduler: {scheduler_name}")


-class
+class CoefficientScheduler:
+    """Linearly warms up a scalar value from 0.0 to a final value."""
+
     def __init__(
         self,
-
-
-        final_l1_coefficient: float,
+        warm_up_steps: float,
+        final_value: float,
     ):
-        self.
-
-        if self.l1_warmup_steps != 0:
-            self.current_l1_coefficient = 0.0
-        else:
-            self.current_l1_coefficient = final_l1_coefficient
-
-        self.final_l1_coefficient = final_l1_coefficient
-
+        self.warm_up_steps = warm_up_steps
+        self.final_value = final_value
         self.current_step = 0
-
-        if not isinstance(self.
+
+        if not isinstance(self.final_value, (float, int)):
             raise TypeError(
-                f"
+                f"final_value must be float or int, got {type(self.final_value)}."
             )

+        # Initialize current_value based on whether warm-up is used
+        if self.warm_up_steps > 0:
+            self.current_value = 0.0
+        else:
+            self.current_value = self.final_value
+
     def __repr__(self) -> str:
         return (
-            f"
-            f"
-            f"total_steps={self.total_steps})"
+            f"{self.__class__.__name__}(final_value={self.final_value}, "
+            f"warm_up_steps={self.warm_up_steps})"
         )

-    def step(self):
+    def step(self) -> float:
         """
-        Updates the
+        Updates the scalar value based on the current step.
+
+        Returns:
+            The current scalar value after the step.
         """
-
-
-
-
-        )  # type: ignore
+        if self.current_step < self.warm_up_steps:
+            self.current_value = self.final_value * (
+                (self.current_step + 1) / self.warm_up_steps
+            )
         else:
-
+            # Ensure the value stays at final_value after warm-up
+            self.current_value = self.final_value

         self.current_step += 1
+        return self.current_value

-
-
+    @property
+    def value(self) -> float:
+        """Returns the current scalar value."""
+        return self.current_value
+
+    def state_dict(self) -> dict[str, Any]:
+        """State dict for serialization."""
         return {
-            "
-            "
-            "current_l1_coefficient": self.current_l1_coefficient,
-            "final_l1_coefficient": self.final_l1_coefficient,
+            "warm_up_steps": self.warm_up_steps,
+            "final_value": self.final_value,
             "current_step": self.current_step,
+            "current_value": self.current_value,
         }

     def load_state_dict(self, state_dict: dict[str, Any]):
-        """Loads
-
-
+        """Loads the scheduler state."""
+        self.warm_up_steps = state_dict["warm_up_steps"]
+        self.final_value = state_dict["final_value"]
+        self.current_step = state_dict["current_step"]
+        # Maintain consistency: re-calculate current_value based on loaded step
+        # This handles resuming correctly if stopped mid-warmup.
+        if self.current_step <= self.warm_up_steps and self.warm_up_steps > 0:
+            # Use max(0, ...) to handle case where current_step might be loaded as -1 or similar before first step
+            step_for_calc = max(0, self.current_step)
+            # Recalculate based on the step *before* the one about to be taken
+            # Or simply use the saved current_value if available and consistent
+            if "current_value" in state_dict:
+                self.current_value = state_dict["current_value"]
+            else:  # Legacy state dicts might not have current_value
+                self.current_value = self.final_value * (
+                    step_for_calc / self.warm_up_steps
+                )
+
+        else:
+            self.current_value = self.final_value