sae-lens 6.16.0__py3-none-any.whl → 6.16.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.16.0"
2
+ __version__ = "6.16.3"
3
3
 
4
4
  import logging
5
5
 
@@ -23,7 +23,9 @@ class BatchTopK(nn.Module):
23
23
  def forward(self, x: torch.Tensor) -> torch.Tensor:
24
24
  acts = x.relu()
25
25
  flat_acts = acts.flatten()
26
- acts_topk_flat = torch.topk(flat_acts, int(self.k * acts.shape[0]), dim=-1)
26
+ # Calculate total number of samples across all non-feature dimensions
27
+ num_samples = acts.shape[:-1].numel()
28
+ acts_topk_flat = torch.topk(flat_acts, int(self.k * num_samples), dim=-1)
27
29
  return (
28
30
  torch.zeros_like(flat_acts)
29
31
  .scatter(-1, acts_topk_flat.indices, acts_topk_flat.values)
@@ -78,14 +78,11 @@ class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
78
78
  @override
79
79
  def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
80
80
  base_output = super().training_forward_pass(step_input)
81
- hidden_pre = base_output.hidden_pre
82
81
  inv_W_dec_norm = 1 / self.W_dec.norm(dim=-1)
83
82
  # the outer matryoshka level is the base SAE, so we don't need to add an extra loss for it
84
83
  for width in self.cfg.matryoshka_widths[:-1]:
85
- inner_hidden_pre = hidden_pre[:, :width]
86
- inner_feat_acts = self.activation_fn(inner_hidden_pre)
87
84
  inner_reconstruction = self._decode_matryoshka_level(
88
- inner_feat_acts, width, inv_W_dec_norm
85
+ base_output.feature_acts, width, inv_W_dec_norm
89
86
  )
90
87
  inner_mse_loss = (
91
88
  self.mse_loss_fn(inner_reconstruction, step_input.sae_in)
@@ -105,16 +102,17 @@ class MatryoshkaBatchTopKTrainingSAE(BatchTopKTrainingSAE):
105
102
  """
106
103
  Decodes feature activations back into input space for a matryoshka level
107
104
  """
105
+ inner_feature_acts = feature_acts[:, :width]
108
106
  # Handle sparse tensors using efficient sparse matrix multiplication
109
107
  if self.cfg.rescale_acts_by_decoder_norm:
110
108
  # need to multiply by the inverse of the norm because division is illegal with sparse tensors
111
- feature_acts = feature_acts * inv_W_dec_norm[:width]
112
- if feature_acts.is_sparse:
109
+ inner_feature_acts = inner_feature_acts * inv_W_dec_norm[:width]
110
+ if inner_feature_acts.is_sparse:
113
111
  sae_out_pre = (
114
- _sparse_matmul_nd(feature_acts, self.W_dec[:width]) + self.b_dec
112
+ _sparse_matmul_nd(inner_feature_acts, self.W_dec[:width]) + self.b_dec
115
113
  )
116
114
  else:
117
- sae_out_pre = feature_acts @ self.W_dec[:width] + self.b_dec
115
+ sae_out_pre = inner_feature_acts @ self.W_dec[:width] + self.b_dec
118
116
  sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
119
117
  return self.reshape_fn_out(sae_out_pre, self.d_head)
120
118
 
@@ -137,7 +135,3 @@ def _validate_matryoshka_config(cfg: MatryoshkaBatchTopKTrainingSAEConfig) -> No
137
135
  warnings.warn(
138
136
  "WARNING: You have only set one matryoshka level. This is equivalent to using a standard BatchTopK SAE and is likely not what you want."
139
137
  )
140
- if cfg.matryoshka_widths[0] < cfg.k:
141
- raise ValueError(
142
- "The smallest matryoshka level width cannot be smaller than cfg.k."
143
- )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.16.0
3
+ Version: 6.16.3
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -1,4 +1,4 @@
1
- sae_lens/__init__.py,sha256=8-kJQbVwfBSm9fla6sxHlPwLxQ9ghneJanQJVL1i_gc,3886
1
+ sae_lens/__init__.py,sha256=c1rxG64QCdP4n1LI8Du_dxEn30E1fXXbEfZ0kaZ2JiI,3886
2
2
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  sae_lens/analysis/hooked_sae_transformer.py,sha256=vRu6JseH1lZaEeILD5bEkQEQ1wYHHDcxD-f2olKmE9Y,14275
4
4
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
@@ -15,10 +15,10 @@ sae_lens/pretokenize_runner.py,sha256=x-reJzVPFDS9iRFbZtrFYSzNguJYki9gd0pbHjYJ3r
15
15
  sae_lens/pretrained_saes.yaml,sha256=6ca3geEB6NyhULUrmdtPDK8ea0YdpLp8_au78vIFC5w,602553
16
16
  sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
17
  sae_lens/saes/__init__.py,sha256=sIfZUxZ4m3HPtPymCJlpBEofiCrL8_QziE6ChS-v4lE,1677
18
- sae_lens/saes/batchtopk_sae.py,sha256=zxIke8lOBKkQEMVFk6sSW6q_s6F9RKhysLqfqG9ecwI,5300
18
+ sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
19
19
  sae_lens/saes/gated_sae.py,sha256=qcmM9JwBA8aZR8z_IRHV1_gQX-q_63tKewWXRnhdXuo,8986
20
20
  sae_lens/saes/jumprelu_sae.py,sha256=HHBF1sJ95lZvxwP5vwLSQFKdnJN2KKYK0WAEaLTrta0,13399
21
- sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=zrS4MksbxdhhftmU3UWjRCWjR7iEBpAk6N00c6GrXks,6291
21
+ sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=4_1cVaxk6c6jgJEbxqebtG-cjQNIzaMAfjSPGfR7_VU,6062
22
22
  sae_lens/saes/sae.py,sha256=vABlwyZ0JtL896xxBGIoqfiByoszIf-e4ggPgz34RL0,38300
23
23
  sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
24
24
  sae_lens/saes/topk_sae.py,sha256=tzQM5eQFifMe--8_8NUBYWY7hpjQa6A_olNe6U71FE8,21275
@@ -34,7 +34,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
34
34
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
35
35
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
36
36
  sae_lens/util.py,sha256=tCovQ-eZa1L7thPpNDL6PGOJrIMML2yLI5e0EHCOpS8,3309
37
- sae_lens-6.16.0.dist-info/METADATA,sha256=VI8vndYrn1pKxiHUDjyBNx74_S8mHwiaTK1nfo7t0BU,5318
38
- sae_lens-6.16.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
39
- sae_lens-6.16.0.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
40
- sae_lens-6.16.0.dist-info/RECORD,,
37
+ sae_lens-6.16.3.dist-info/METADATA,sha256=sZBPyQRO8rpyB21LtAfOtfxiXhBQ7RiG8F3IluD_8mw,5318
38
+ sae_lens-6.16.3.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
39
+ sae_lens-6.16.3.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
40
+ sae_lens-6.16.3.dist-info/RECORD,,