sae-lens 6.16.0__py3-none-any.whl → 6.16.3__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the contents of those package versions as they appear in their respective public registries.
- sae_lens/__init__.py +1 -1
- sae_lens/saes/batchtopk_sae.py +3 -1
- sae_lens/saes/matryoshka_batchtopk_sae.py +6 -12
- {sae_lens-6.16.0.dist-info → sae_lens-6.16.3.dist-info}/METADATA +1 -1
- {sae_lens-6.16.0.dist-info → sae_lens-6.16.3.dist-info}/RECORD +7 -7
- {sae_lens-6.16.0.dist-info → sae_lens-6.16.3.dist-info}/WHEEL +0 -0
- {sae_lens-6.16.0.dist-info → sae_lens-6.16.3.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
sae_lens/saes/batchtopk_sae.py
CHANGED
|
@@ -23,7 +23,9 @@ class BatchTopK(nn.Module):
|
|
|
23
23
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
24
24
|
acts = x.relu()
|
|
25
25
|
flat_acts = acts.flatten()
|
|
26
|
-
|
|
26
|
+
# Calculate total number of samples across all non-feature dimensions
|
|
27
|
+
num_samples = acts.shape[:-1].numel()
|
|
28
|
+
acts_topk_flat = torch.topk(flat_acts, int(self.k * num_samples), dim=-1)
|
|
27
29
|
return (
|
|
28
30
|
torch.zeros_like(flat_acts)
|
|
29
31
|
.scatter(-1, acts_topk_flat.indices, acts_topk_flat.values)
|
|
@@ -78,14 +78,11 @@ class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
|
|
|
78
78
|
@override
|
|
79
79
|
def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
|
|
80
80
|
base_output = super().training_forward_pass(step_input)
|
|
81
|
-
hidden_pre = base_output.hidden_pre
|
|
82
81
|
inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
|
|
83
82
|
# the outer matryoshka level is the base SAE, so we don't need to add an extra loss for it
|
|
84
83
|
for width in self.cfg.matryoshka_widths[:-1]:
|
|
85
|
-
inner_hidden_pre = hidden_pre[:, :width]
|
|
86
|
-
inner_feat_acts = self.activation_fn(inner_hidden_pre)
|
|
87
84
|
inner_reconstruction = self._decode_matryoshka_level(
|
|
88
|
-
|
|
85
|
+
base_output.feature_acts, width, inv_W_dec_norm
|
|
89
86
|
)
|
|
90
87
|
inner_mse_loss = (
|
|
91
88
|
self.mse_loss_fn(inner_reconstruction, step_input.sae_in)
|
|
@@ -105,16 +102,17 @@ class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
|
|
|
105
102
|
"""
|
|
106
103
|
Decodes feature activations back into input space for a matryoshka level
|
|
107
104
|
"""
|
|
105
|
+
inner_feature_acts = feature_acts[:, :width]
|
|
108
106
|
# Handle sparse tensors using efficient sparse matrix multiplication
|
|
109
107
|
if self.cfg.rescale_acts_by_decoder_norm:
|
|
110
108
|
# need to multiply by the inverse of the norm because division is illegal with sparse tensors
|
|
111
|
-
|
|
112
|
-
if
|
|
109
|
+
inner_feature_acts = inner_feature_acts * inv_W_dec_norm[:width]
|
|
110
|
+
if inner_feature_acts.is_sparse:
|
|
113
111
|
sae_out_pre = (
|
|
114
|
-
_sparse_matmul_nd(
|
|
112
|
+
_sparse_matmul_nd(inner_feature_acts, self.W_dec[:width]) + self.b_dec
|
|
115
113
|
)
|
|
116
114
|
else:
|
|
117
|
-
sae_out_pre =
|
|
115
|
+
sae_out_pre = inner_feature_acts @ self.W_dec[:width] + self.b_dec
|
|
118
116
|
sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
|
|
119
117
|
return self.reshape_fn_out(sae_out_pre, self.d_head)
|
|
120
118
|
|
|
@@ -137,7 +135,3 @@ def _validate_matryoshka_config(cfg: MatryoshkaBatchTopKTrainingSAEConfig) -> No
|
|
|
137
135
|
warnings.warn(
|
|
138
136
|
"WARNING: You have only set one matryoshka level. This is equivalent to using a standard BatchTopK SAE and is likely not what you want."
|
|
139
137
|
)
|
|
140
|
-
if cfg.matryoshka_widths[0] < cfg.k:
|
|
141
|
-
raise ValueError(
|
|
142
|
-
"The smallest matryoshka level width cannot be smaller than cfg.k."
|
|
143
|
-
)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
sae_lens/__init__.py,sha256=
|
|
1
|
+
sae_lens/__init__.py,sha256=c1rxG64QCdP4n1LI8Du_dxEn30E1fXXbEfZ0kaZ2JiI,3886
|
|
2
2
|
sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
sae_lens/analysis/hooked_sae_transformer.py,sha256=vRu6JseH1lZaEeILD5bEkQEQ1wYHHDcxD-f2olKmE9Y,14275
|
|
4
4
|
sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
|
|
@@ -15,10 +15,10 @@ sae_lens/pretokenize_runner.py,sha256=x-reJzVPFDS9iRFbZtrFYSzNguJYki9gd0pbHjYJ3r
|
|
|
15
15
|
sae_lens/pretrained_saes.yaml,sha256=6ca3geEB6NyhULUrmdtPDK8ea0YdpLp8_au78vIFC5w,602553
|
|
16
16
|
sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
|
|
17
17
|
sae_lens/saes/__init__.py,sha256=sIfZUxZ4m3HPtPymCJlpBEofiCrL8_QziE6ChS-v4lE,1677
|
|
18
|
-
sae_lens/saes/batchtopk_sae.py,sha256=
|
|
18
|
+
sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
|
|
19
19
|
sae_lens/saes/gated_sae.py,sha256=qcmM9JwBA8aZR8z_IRHV1_gQX-q_63tKewWXRnhdXuo,8986
|
|
20
20
|
sae_lens/saes/jumprelu_sae.py,sha256=HHBF1sJ95lZvxwP5vwLSQFKdnJN2KKYK0WAEaLTrta0,13399
|
|
21
|
-
sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=
|
|
21
|
+
sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=4_1cVaxk6c6jgJEbxqebtG-cjQNIzaMAfjSPGfR7_VU,6062
|
|
22
22
|
sae_lens/saes/sae.py,sha256=vABlwyZ0JtL896xxBGIoqfiByoszIf-e4ggPgz34RL0,38300
|
|
23
23
|
sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
|
|
24
24
|
sae_lens/saes/topk_sae.py,sha256=tzQM5eQFifMe--8_8NUBYWY7hpjQa6A_olNe6U71FE8,21275
|
|
@@ -34,7 +34,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
|
|
|
34
34
|
sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
|
|
35
35
|
sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
|
|
36
36
|
sae_lens/util.py,sha256=tCovQ-eZa1L7thPpNDL6PGOJrIMML2yLI5e0EHCOpS8,3309
|
|
37
|
-
sae_lens-6.16.
|
|
38
|
-
sae_lens-6.16.
|
|
39
|
-
sae_lens-6.16.
|
|
40
|
-
sae_lens-6.16.
|
|
37
|
+
sae_lens-6.16.3.dist-info/METADATA,sha256=sZBPyQRO8rpyB21LtAfOtfxiXhBQ7RiG8F3IluD_8mw,5318
|
|
38
|
+
sae_lens-6.16.3.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
|
|
39
|
+
sae_lens-6.16.3.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
|
|
40
|
+
sae_lens-6.16.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|