sae-lens 6.22.1__py3-none-any.whl → 6.22.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.22.1"
2
+ __version__ = "6.22.3"
3
3
 
4
4
  import logging
5
5
 
sae_lens/saes/gated_sae.py CHANGED
@@ -89,7 +89,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
89
89
  @torch.no_grad()
90
90
  def fold_W_dec_norm(self):
91
91
  """Override to handle gated-specific parameters."""
92
- W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
92
+ W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
93
93
  self.W_dec.data = self.W_dec.data / W_dec_norms
94
94
  self.W_enc.data = self.W_enc.data * W_dec_norms.T
95
95
 
@@ -217,7 +217,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
217
217
  @torch.no_grad()
218
218
  def fold_W_dec_norm(self):
219
219
  """Override to handle gated-specific parameters."""
220
- W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
220
+ W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
221
221
  self.W_dec.data = self.W_dec.data / W_dec_norms
222
222
  self.W_enc.data = self.W_enc.data * W_dec_norms.T
223
223
 
sae_lens/saes/jumprelu_sae.py CHANGED
@@ -167,8 +167,8 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
167
167
  # Save the current threshold before calling parent method
168
168
  current_thresh = self.threshold.clone()
169
169
 
170
- # Get W_dec norms that will be used for scaling
171
- W_dec_norms = self.W_dec.norm(dim=-1)
170
+ # Get W_dec norms that will be used for scaling (clamped to avoid division by zero)
171
+ W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8)
172
172
 
173
173
  # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
174
174
  super().fold_W_dec_norm()
@@ -325,8 +325,8 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
325
325
  # Save the current threshold before we call the parent method
326
326
  current_thresh = self.threshold.clone()
327
327
 
328
- # Get W_dec norms
329
- W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
328
+ # Get W_dec norms (clamped to avoid division by zero)
329
+ W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
330
330
 
331
331
  # Call parent implementation to handle W_enc and W_dec adjustment
332
332
  super().fold_W_dec_norm()
sae_lens/saes/sae.py CHANGED
@@ -484,7 +484,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
484
484
  @torch.no_grad()
485
485
  def fold_W_dec_norm(self):
486
486
  """Fold decoder norms into encoder."""
487
- W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
487
+ W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
488
488
  self.W_dec.data = self.W_dec.data / W_dec_norms
489
489
  self.W_enc.data = self.W_enc.data * W_dec_norms.T
490
490
 
sae_lens/saes/topk_sae.py CHANGED
@@ -531,7 +531,7 @@ def _fold_norm_topk(
531
531
  b_enc: torch.Tensor,
532
532
  W_dec: torch.Tensor,
533
533
  ) -> None:
534
- W_dec_norm = W_dec.norm(dim=-1)
534
+ W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8)
535
535
  b_enc.data = b_enc.data * W_dec_norm
536
536
  W_dec_norms = W_dec_norm.unsqueeze(1)
537
537
  W_dec.data = W_dec.data / W_dec_norms
sae_lens/training/activations_store.py CHANGED
@@ -166,9 +166,11 @@ class ActivationsStore:
166
166
  disable_concat_sequences: bool = False,
167
167
  sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
168
168
  ) -> ActivationsStore:
169
+ if context_size is None:
170
+ context_size = sae.cfg.metadata.context_size
169
171
  if sae.cfg.metadata.hook_name is None:
170
172
  raise ValueError("hook_name is required")
171
- if sae.cfg.metadata.context_size is None:
173
+ if context_size is None:
172
174
  raise ValueError("context_size is required")
173
175
  if sae.cfg.metadata.prepend_bos is None:
174
176
  raise ValueError("prepend_bos is required")
@@ -178,9 +180,7 @@ class ActivationsStore:
178
180
  d_in=sae.cfg.d_in,
179
181
  hook_name=sae.cfg.metadata.hook_name,
180
182
  hook_head_index=sae.cfg.metadata.hook_head_index,
181
- context_size=sae.cfg.metadata.context_size
182
- if context_size is None
183
- else context_size,
183
+ context_size=context_size,
184
184
  prepend_bos=sae.cfg.metadata.prepend_bos,
185
185
  streaming=streaming,
186
186
  store_batch_size_prompts=store_batch_size_prompts,
{sae_lens-6.22.1.dist-info → sae_lens-6.22.3.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.22.1
3
+ Version: 6.22.3
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
{sae_lens-6.22.1.dist-info → sae_lens-6.22.3.dist-info}/RECORD CHANGED
@@ -1,4 +1,4 @@
1
- sae_lens/__init__.py,sha256=v-2uKiNW5UNVCRt7vyBrvI0olJsXIxaPp9TJvo-m9wg,4033
1
+ sae_lens/__init__.py,sha256=-q6U-a-hKOgMuJj8jBEyV5JvAZolVh6ccirbA3zrpYg,4033
2
2
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
4
4
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
@@ -16,18 +16,18 @@ sae_lens/pretrained_saes.yaml,sha256=VzgJ_t-IEWpO2MabgQY6CAcg8FFsqZWiOVXjqvqfgeE
16
16
  sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
17
  sae_lens/saes/__init__.py,sha256=nTNPnJ7edyfedo1MX96xwn9WOG8504yHbT9LFw9od_0,1778
18
18
  sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
19
- sae_lens/saes/gated_sae.py,sha256=Jq74JGtqpO6tW3XdJGbURTTWN_fAoAMKu9T7O-MZTeE,8793
20
- sae_lens/saes/jumprelu_sae.py,sha256=zUGHWOFXbeDBS3mjkOE3ikxlEniq2EX9rCAizLMOpp4,13206
19
+ sae_lens/saes/gated_sae.py,sha256=mHnmw-RD7hqIbP9_EBj3p2SK0OqQIkZivdOKRygeRgw,8825
20
+ sae_lens/saes/jumprelu_sae.py,sha256=udjGHp3WTABQSL2Qq57j-bINWX61GCmo68EmdjMOXoo,13310
21
21
  sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
22
- sae_lens/saes/sae.py,sha256=q8ylAdqtkNAms7X-3y1QIBfHOZ-FvKHvCap7Tw_cnzE,37733
22
+ sae_lens/saes/sae.py,sha256=Vb1aGSDPRv_0J2aL8-EICRSkIxsO6Q4lJaJE9NNmfdA,37749
23
23
  sae_lens/saes/standard_sae.py,sha256=nEVETwAmRD2tyX7ESIic1fij48gAq1Dh7s_GQ2fqCZ4,5747
24
24
  sae_lens/saes/temporal_sae.py,sha256=DsecivcHWId-MTuJpQbz8OhqtmGhZACxJauYZGHo0Ok,13272
25
- sae_lens/saes/topk_sae.py,sha256=D1N4LHGOeV8dhHW0i3HqBT1cqA-E1Plq11uMJtVfNBo,21057
25
+ sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
26
26
  sae_lens/saes/transcoder.py,sha256=BfLSbTYVNZh-ruGxseZiZJ_acEL6_7QyTdfqUr0lDOg,12156
27
27
  sae_lens/tokenization_and_batching.py,sha256=D_o7cXvRqhT89H3wNzoRymNALNE6eHojBWLdXOUwUGE,5438
28
28
  sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
29
29
  sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
30
- sae_lens/training/activations_store.py,sha256=1ozCANGXO8Vx9d_l-heb-MsSpUoYcHagcve5JLGwZYY,33762
30
+ sae_lens/training/activations_store.py,sha256=pG6YhPCLrf7bd7bZTZ3_aS5J92lfnhVtyK8e0bzYFnI,33752
31
31
  sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
32
32
  sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
33
33
  sae_lens/training/sae_trainer.py,sha256=zhkabyIKxI_tZTV3_kwz6zMrHZ95Ecr97krmwc-9ffs,17600
@@ -35,7 +35,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
35
35
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
36
36
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
37
37
  sae_lens/util.py,sha256=tCovQ-eZa1L7thPpNDL6PGOJrIMML2yLI5e0EHCOpS8,3309
38
- sae_lens-6.22.1.dist-info/METADATA,sha256=QoCu9iHTvA66XSkU2aR_4VxP7wGFr_NQPJUZwxvaOak,5369
39
- sae_lens-6.22.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
40
- sae_lens-6.22.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
41
- sae_lens-6.22.1.dist-info/RECORD,,
38
+ sae_lens-6.22.3.dist-info/METADATA,sha256=vsYZOnH0fqH5fsQHPFnwXRLd61-NzOWm1Hz98YRS52c,5369
39
+ sae_lens-6.22.3.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
40
+ sae_lens-6.22.3.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
41
+ sae_lens-6.22.3.dist-info/RECORD,,