sae-lens 6.10.0__py3-none-any.whl → 6.11.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.10.0"
2
+ __version__ = "6.11.1"
3
3
 
4
4
  import logging
5
5
 
@@ -15,7 +15,7 @@ class BatchTopK(nn.Module):
15
15
 
16
16
  def __init__(
17
17
  self,
18
- k: int,
18
+ k: float,
19
19
  ):
20
20
  super().__init__()
21
21
  self.k = k
@@ -23,7 +23,7 @@ class BatchTopK(nn.Module):
23
23
  def forward(self, x: torch.Tensor) -> torch.Tensor:
24
24
  acts = x.relu()
25
25
  flat_acts = acts.flatten()
26
- acts_topk_flat = torch.topk(flat_acts, self.k * acts.shape[0], dim=-1)
26
+ acts_topk_flat = torch.topk(flat_acts, int(self.k * acts.shape[0]), dim=-1)
27
27
  return (
28
28
  torch.zeros_like(flat_acts)
29
29
  .scatter(-1, acts_topk_flat.indices, acts_topk_flat.values)
@@ -37,6 +37,7 @@ class BatchTopKTrainingSAEConfig(TopKTrainingSAEConfig):
37
37
  Configuration class for training a BatchTopKTrainingSAE.
38
38
  """
39
39
 
40
+ k: float = 100 # type: ignore[assignment]
40
41
  topk_threshold_lr: float = 0.01
41
42
 
42
43
  @override
@@ -1,5 +1,5 @@
1
1
  from dataclasses import dataclass
2
- from typing import Any
2
+ from typing import Any, Literal
3
3
 
4
4
  import numpy as np
5
5
  import torch
@@ -187,13 +187,29 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
187
187
  class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
188
188
  """
189
189
  Configuration class for training a JumpReLUTrainingSAE.
190
+
191
+ - jumprelu_init_threshold: initial threshold for the JumpReLU activation
192
+ - jumprelu_bandwidth: bandwidth for the JumpReLU activation
193
+ - jumprelu_sparsity_loss_mode: mode for the sparsity loss, either "step" or "tanh". "step" is Google Deepmind's L0 loss, "tanh" is Anthropic's sparsity loss.
194
+ - l0_coefficient: coefficient for the l0 sparsity loss
195
+ - l0_warm_up_steps: number of warm-up steps for the l0 sparsity loss
196
+ - pre_act_loss_coefficient: coefficient for the pre-activation loss. Set to None to disable. Set to 3e-6 to match Anthropic's setup. Default is None.
197
+ - jumprelu_tanh_scale: scale for the tanh sparsity loss. Only relevant for "tanh" sparsity loss mode. Default is 4.0.
190
198
  """
191
199
 
192
200
  jumprelu_init_threshold: float = 0.01
193
201
  jumprelu_bandwidth: float = 0.05
202
+ # step is Google Deepmind, tanh is Anthropic
203
+ jumprelu_sparsity_loss_mode: Literal["step", "tanh"] = "step"
194
204
  l0_coefficient: float = 1.0
195
205
  l0_warm_up_steps: int = 0
196
206
 
207
+ # anthropic's auxiliary loss to avoid dead features
208
+ pre_act_loss_coefficient: float | None = None
209
+
210
+ # only relevant for tanh sparsity loss mode
211
+ jumprelu_tanh_scale: float = 4.0
212
+
197
213
  @override
198
214
  @classmethod
199
215
  def architecture(cls) -> str:
@@ -267,9 +283,35 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
267
283
  sae_out: torch.Tensor,
268
284
  ) -> dict[str, torch.Tensor]:
269
285
  """Calculate architecture-specific auxiliary loss terms."""
270
- l0 = torch.sum(Step.apply(hidden_pre, self.threshold, self.bandwidth), dim=-1) # type: ignore
271
- l0_loss = (step_input.coefficients["l0"] * l0).mean()
272
- return {"l0_loss": l0_loss}
286
+
287
+ threshold = self.threshold
288
+ W_dec_norm = self.W_dec.norm(dim=1)
289
+ if self.cfg.jumprelu_sparsity_loss_mode == "step":
290
+ l0 = torch.sum(
291
+ Step.apply(hidden_pre, threshold, self.bandwidth), # type: ignore
292
+ dim=-1,
293
+ )
294
+ l0_loss = (step_input.coefficients["l0"] * l0).mean()
295
+ elif self.cfg.jumprelu_sparsity_loss_mode == "tanh":
296
+ per_item_l0_loss = torch.tanh(
297
+ self.cfg.jumprelu_tanh_scale * feature_acts * W_dec_norm
298
+ ).sum(dim=-1)
299
+ l0_loss = (step_input.coefficients["l0"] * per_item_l0_loss).mean()
300
+ else:
301
+ raise ValueError(
302
+ f"Invalid sparsity loss mode: {self.cfg.jumprelu_sparsity_loss_mode}"
303
+ )
304
+ losses = {"l0_loss": l0_loss}
305
+
306
+ if self.cfg.pre_act_loss_coefficient is not None:
307
+ losses["pre_act_loss"] = calculate_pre_act_loss(
308
+ self.cfg.pre_act_loss_coefficient,
309
+ threshold,
310
+ hidden_pre,
311
+ step_input.dead_neuron_mask,
312
+ W_dec_norm,
313
+ )
314
+ return losses
273
315
 
274
316
  @override
275
317
  def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
@@ -310,3 +352,21 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
310
352
  threshold = state_dict["threshold"]
311
353
  del state_dict["threshold"]
312
354
  state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
355
+
356
+
357
+ def calculate_pre_act_loss(
358
+ pre_act_loss_coefficient: float,
359
+ threshold: torch.Tensor,
360
+ hidden_pre: torch.Tensor,
361
+ dead_neuron_mask: torch.Tensor | None,
362
+ W_dec_norm: torch.Tensor,
363
+ ) -> torch.Tensor:
364
+ """
365
+ Calculate Anthropic's pre-activation loss, except we only calculate this for latents that are actually dead.
366
+ """
367
+ if dead_neuron_mask is None or not dead_neuron_mask.any():
368
+ return hidden_pre.new_tensor(0.0)
369
+ per_item_loss = (
370
+ (threshold - hidden_pre).relu() * dead_neuron_mask * W_dec_norm
371
+ ).sum(dim=-1)
372
+ return pre_act_loss_coefficient * per_item_loss.mean()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sae-lens
3
- Version: 6.10.0
3
+ Version: 6.11.1
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -1,4 +1,4 @@
1
- sae_lens/__init__.py,sha256=k8M2SyKNE3KpipPxODICdLG8KJNVvf1Zab4KNJuGWMQ,3589
1
+ sae_lens/__init__.py,sha256=DLmCuiml_kjSeA2AlEbJwnCIwOorh5MLGRXt4uL7mqs,3589
2
2
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  sae_lens/analysis/hooked_sae_transformer.py,sha256=vRu6JseH1lZaEeILD5bEkQEQ1wYHHDcxD-f2olKmE9Y,14275
4
4
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
@@ -15,9 +15,9 @@ sae_lens/pretokenize_runner.py,sha256=w0f6SfZLAxbp5eAAKnet8RqUB_DKofZ9RGsoJwFnYb
15
15
  sae_lens/pretrained_saes.yaml,sha256=d6FYfWTdVAPlOCM55C1ICS6lF9nWPPVNwjlXCa9p7NU,600468
16
16
  sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
17
  sae_lens/saes/__init__.py,sha256=jVwazK8Q6dW5J6_zFXPoNAuBvSxgziQ8eMOjGM3t-X8,1475
18
- sae_lens/saes/batchtopk_sae.py,sha256=CyaFG2hMyyDaEaXXrAMJC8wQDW1JoddTKF5mvxxBQKY,3395
18
+ sae_lens/saes/batchtopk_sae.py,sha256=GX_J0vH4vzeLqYxl0mkfsZQpFEoCEHMR4dIG8fz8N8w,3449
19
19
  sae_lens/saes/gated_sae.py,sha256=qcmM9JwBA8aZR8z_IRHV1_gQX-q_63tKewWXRnhdXuo,8986
20
- sae_lens/saes/jumprelu_sae.py,sha256=3xkhBcCol2mEpIBLceymCpudocm2ypOjTeTXbpiXoA4,10794
20
+ sae_lens/saes/jumprelu_sae.py,sha256=HHBF1sJ95lZvxwP5vwLSQFKdnJN2KKYK0WAEaLTrta0,13399
21
21
  sae_lens/saes/sae.py,sha256=gdUZuLaOHQrPjbDj-nZI813B6-_mNAnV9i9z4qTnpHk,38255
22
22
  sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
23
23
  sae_lens/saes/topk_sae.py,sha256=CXMBI6CFvI5829bOhoQ350VXR9d8uFHUDlULTIWHXoU,8686
@@ -33,7 +33,7 @@ sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
33
33
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
34
34
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
35
35
  sae_lens/util.py,sha256=lW7fBn_b8quvRYlen9PUmB7km60YhKyjmuelB1f6KzQ,2253
36
- sae_lens-6.10.0.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
37
- sae_lens-6.10.0.dist-info/METADATA,sha256=7Yq4_hrZVc2CBB4nMvgy_BGFjT5FrF3SfOo8LnJ18Rg,5245
38
- sae_lens-6.10.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
- sae_lens-6.10.0.dist-info/RECORD,,
36
+ sae_lens-6.11.1.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
37
+ sae_lens-6.11.1.dist-info/METADATA,sha256=qRU9qqA2fLgiyLct7lTpOOLjkkXAIzUEdpDrV1NwKX0,5245
38
+ sae_lens-6.11.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
39
+ sae_lens-6.11.1.dist-info/RECORD,,