sae-lens 6.15.0__py3-none-any.whl → 6.16.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.15.0"
2
+ __version__ = "6.16.2"
3
3
 
4
4
  import logging
5
5
 
@@ -23,7 +23,9 @@ class BatchTopK(nn.Module):
23
23
  def forward(self, x: torch.Tensor) -> torch.Tensor:
24
24
  acts = x.relu()
25
25
  flat_acts = acts.flatten()
26
- acts_topk_flat = torch.topk(flat_acts, int(self.k * acts.shape[0]), dim=-1)
26
+ # Calculate total number of samples across all non-feature dimensions
27
+ num_samples = acts.shape[:-1].numel()
28
+ acts_topk_flat = torch.topk(flat_acts, int(self.k * num_samples), dim=-1)
27
29
  return (
28
30
  torch.zeros_like(flat_acts)
29
31
  .scatter(-1, acts_topk_flat.indices, acts_topk_flat.values)
sae_lens/saes/sae.py CHANGED
@@ -217,6 +217,7 @@ class TrainStepInput:
217
217
  sae_in: torch.Tensor
218
218
  coefficients: dict[str, float]
219
219
  dead_neuron_mask: torch.Tensor | None
220
+ n_training_steps: int
220
221
 
221
222
 
222
223
  class TrainCoefficientConfig(NamedTuple):
@@ -249,6 +249,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
249
249
  sae_in=sae_in,
250
250
  dead_neuron_mask=self.dead_neurons,
251
251
  coefficients=self.get_coefficients(),
252
+ n_training_steps=self.n_training_steps,
252
253
  ),
253
254
  )
254
255
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.15.0
3
+ Version: 6.16.2
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -1,4 +1,4 @@
1
- sae_lens/__init__.py,sha256=ab8Lj2QJE3i1uOP_4B9LLh_vCgi__3XXx66_eO8rcrA,3886
1
+ sae_lens/__init__.py,sha256=vLRjrn1HSwh2xAVptsmnj6iIrzTZwUKE2Vxe3N0mvek,3886
2
2
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  sae_lens/analysis/hooked_sae_transformer.py,sha256=vRu6JseH1lZaEeILD5bEkQEQ1wYHHDcxD-f2olKmE9Y,14275
4
4
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
@@ -15,11 +15,11 @@ sae_lens/pretokenize_runner.py,sha256=x-reJzVPFDS9iRFbZtrFYSzNguJYki9gd0pbHjYJ3r
15
15
  sae_lens/pretrained_saes.yaml,sha256=6ca3geEB6NyhULUrmdtPDK8ea0YdpLp8_au78vIFC5w,602553
16
16
  sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
17
  sae_lens/saes/__init__.py,sha256=sIfZUxZ4m3HPtPymCJlpBEofiCrL8_QziE6ChS-v4lE,1677
18
- sae_lens/saes/batchtopk_sae.py,sha256=zxIke8lOBKkQEMVFk6sSW6q_s6F9RKhysLqfqG9ecwI,5300
18
+ sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
19
19
  sae_lens/saes/gated_sae.py,sha256=qcmM9JwBA8aZR8z_IRHV1_gQX-q_63tKewWXRnhdXuo,8986
20
20
  sae_lens/saes/jumprelu_sae.py,sha256=HHBF1sJ95lZvxwP5vwLSQFKdnJN2KKYK0WAEaLTrta0,13399
21
21
  sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=zrS4MksbxdhhftmU3UWjRCWjR7iEBpAk6N00c6GrXks,6291
22
- sae_lens/saes/sae.py,sha256=nuII6ZmaVtJWhPjyhasHQyiv_Wj-zdAtRQqJRYbVBQs,38274
22
+ sae_lens/saes/sae.py,sha256=vABlwyZ0JtL896xxBGIoqfiByoszIf-e4ggPgz34RL0,38300
23
23
  sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
24
24
  sae_lens/saes/topk_sae.py,sha256=tzQM5eQFifMe--8_8NUBYWY7hpjQa6A_olNe6U71FE8,21275
25
25
  sae_lens/saes/transcoder.py,sha256=BfLSbTYVNZh-ruGxseZiZJ_acEL6_7QyTdfqUr0lDOg,12156
@@ -29,12 +29,12 @@ sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOw
29
29
  sae_lens/training/activations_store.py,sha256=hHY6rW-T7sLq2a8JPEyWdm8leuIRm_MsObZs3jRTZmE,31931
30
30
  sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
31
31
  sae_lens/training/optim.py,sha256=TiI9nbffzXNsI8WjcIsqa2uheW6suxqL_KDDmWXobWI,5312
32
- sae_lens/training/sae_trainer.py,sha256=il4Evf-c4F3Uf2n_v-AOItCasX-uPxYTzn_sZLvLkl0,15633
32
+ sae_lens/training/sae_trainer.py,sha256=ESA8FjHmIsLSgU-p5zUDtOkhzLEMFAijrWDmcYow6Rs,15693
33
33
  sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
34
34
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
35
35
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
36
36
  sae_lens/util.py,sha256=tCovQ-eZa1L7thPpNDL6PGOJrIMML2yLI5e0EHCOpS8,3309
37
- sae_lens-6.15.0.dist-info/METADATA,sha256=UmBQ8quUJBWyLclhbnDcXAkL-6jnOW4SbT8_X3rrcbE,5318
38
- sae_lens-6.15.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
39
- sae_lens-6.15.0.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
40
- sae_lens-6.15.0.dist-info/RECORD,,
37
+ sae_lens-6.16.2.dist-info/METADATA,sha256=bixpAL-bXPWMgRmtao5pk8J_jWrRHeOdfZQRu-R4SNU,5318
38
+ sae_lens-6.16.2.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
39
+ sae_lens-6.16.2.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
40
+ sae_lens-6.16.2.dist-info/RECORD,,