sae-lens 6.10.0.tar.gz → 6.11.0.tar.gz
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- {sae_lens-6.10.0 → sae_lens-6.11.0}/PKG-INFO +1 -1
- {sae_lens-6.10.0 → sae_lens-6.11.0}/pyproject.toml +1 -1
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/__init__.py +1 -1
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/saes/jumprelu_sae.py +64 -4
- {sae_lens-6.10.0 → sae_lens-6.11.0}/LICENSE +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/README.md +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/config.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.10.0 → sae_lens-6.11.0}/sae_lens/util.py +0 -0
The rendered diff below covers sae_lens/saes/jumprelu_sae.py (+64 -4):

```diff
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Literal
 
 import numpy as np
 import torch
```
```diff
@@ -187,13 +187,29 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
 class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
     """
     Configuration class for training a JumpReLUTrainingSAE.
+
+    - jumprelu_init_threshold: initial threshold for the JumpReLU activation
+    - jumprelu_bandwidth: bandwidth for the JumpReLU activation
+    - jumprelu_sparsity_loss_mode: mode for the sparsity loss, either "step" or "tanh". "step" is Google Deepmind's L0 loss, "tanh" is Anthropic's sparsity loss.
+    - l0_coefficient: coefficient for the l0 sparsity loss
+    - l0_warm_up_steps: number of warm-up steps for the l0 sparsity loss
+    - pre_act_loss_coefficient: coefficient for the pre-activation loss. Set to None to disable. Set to 3e-6 to match Anthropic's setup. Default is None.
+    - jumprelu_tanh_scale: scale for the tanh sparsity loss. Only relevant for "tanh" sparsity loss mode. Default is 4.0.
     """
 
     jumprelu_init_threshold: float = 0.01
     jumprelu_bandwidth: float = 0.05
+    # step is Google Deepmind, tanh is Anthropic
+    jumprelu_sparsity_loss_mode: Literal["step", "tanh"] = "step"
     l0_coefficient: float = 1.0
     l0_warm_up_steps: int = 0
 
+    # anthropic's auxiliary loss to avoid dead features
+    pre_act_loss_coefficient: float | None = None
+
+    # only relevant for tanh sparsity loss mode
+    jumprelu_tanh_scale: float = 4.0
+
     @override
     @classmethod
     def architecture(cls) -> str:
```
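For orientation, here is a hedged sketch of how the new fields might be set. Only the jumprelu_*, l0_*, and pre_act_loss_coefficient fields come from this diff; the base TrainingSAEConfig fields (d_in, d_sae) and the import path are assumptions about the surrounding API, not verified here.

```python
# Hypothetical configuration sketch. Only the jumprelu_* / l0_* /
# pre_act_loss_coefficient fields are taken from this diff; d_in, d_sae,
# and the import path are assumptions.
from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig

cfg = JumpReLUTrainingSAEConfig(
    d_in=768,                            # assumed base-config field
    d_sae=16384,                         # assumed base-config field
    jumprelu_init_threshold=0.01,
    jumprelu_bandwidth=0.05,
    jumprelu_sparsity_loss_mode="tanh",  # "step" = DeepMind L0, "tanh" = Anthropic
    jumprelu_tanh_scale=4.0,             # only used in "tanh" mode
    l0_coefficient=1.0,
    l0_warm_up_steps=0,
    pre_act_loss_coefficient=3e-6,       # None disables; 3e-6 matches Anthropic
)
```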
```diff
@@ -267,9 +283,35 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         sae_out: torch.Tensor,
     ) -> dict[str, torch.Tensor]:
         """Calculate architecture-specific auxiliary loss terms."""
-
-
-
+
+        threshold = self.threshold
+        W_dec_norm = self.W_dec.norm(dim=1)
+        if self.cfg.jumprelu_sparsity_loss_mode == "step":
+            l0 = torch.sum(
+                Step.apply(hidden_pre, threshold, self.bandwidth),  # type: ignore
+                dim=-1,
+            )
+            l0_loss = (step_input.coefficients["l0"] * l0).mean()
+        elif self.cfg.jumprelu_sparsity_loss_mode == "tanh":
+            per_item_l0_loss = torch.tanh(
+                self.cfg.jumprelu_tanh_scale * feature_acts * W_dec_norm
+            ).sum(dim=-1)
+            l0_loss = (step_input.coefficients["l0"] * per_item_l0_loss).mean()
+        else:
+            raise ValueError(
+                f"Invalid sparsity loss mode: {self.cfg.jumprelu_sparsity_loss_mode}"
+            )
+        losses = {"l0_loss": l0_loss}
+
+        if self.cfg.pre_act_loss_coefficient is not None:
+            losses["pre_act_loss"] = calculate_pre_act_loss(
+                self.cfg.pre_act_loss_coefficient,
+                threshold,
+                hidden_pre,
+                step_input.dead_neuron_mask,
+                W_dec_norm,
+            )
+        return losses
 
     @override
     def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
```
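As a plain-PyTorch illustration of the two branches above, the following standalone sketch computes both penalties on toy tensors. It paraphrases the diff rather than calling the library: the plain Heaviside forward stands in for Step.apply (whose surrogate gradient is defined elsewhere in this file), and all shapes and coefficient values are assumptions.

```python
# Standalone sketch of the two sparsity penalties from the diff. Toy shapes
# are assumptions; the plain Heaviside below has no surrogate gradient,
# unlike the library's Step.apply.
import torch

batch, d_sae = 8, 64
hidden_pre = torch.randn(batch, d_sae)                 # pre-activations
threshold = torch.full((d_sae,), 0.01)                 # per-latent JumpReLU threshold
W_dec_norm = torch.ones(d_sae)                         # decoder row norms ||W_dec||
feature_acts = hidden_pre * (hidden_pre > threshold)   # JumpReLU(z) = z * H(z - theta)

# "step" mode (DeepMind-style): count latents above threshold per item.
l0 = (hidden_pre > threshold).float().sum(dim=-1)
step_loss = (1.0 * l0).mean()                          # 1.0 stands in for the l0 coefficient

# "tanh" mode (Anthropic-style): saturating penalty on decoder-norm-weighted acts.
tanh_scale = 4.0
per_item = torch.tanh(tanh_scale * feature_acts * W_dec_norm).sum(dim=-1)
tanh_loss = (1.0 * per_item).mean()

print(step_loss, tanh_loss)
```

The tanh variant stays differentiable everywhere but saturates for large activations, whereas the step variant counts active latents exactly and relies on a surrogate gradient through the threshold.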
```diff
@@ -310,3 +352,21 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         threshold = state_dict["threshold"]
         del state_dict["threshold"]
         state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
+
+
+def calculate_pre_act_loss(
+    pre_act_loss_coefficient: float,
+    threshold: torch.Tensor,
+    hidden_pre: torch.Tensor,
+    dead_neuron_mask: torch.Tensor | None,
+    W_dec_norm: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Calculate Anthropic's pre-activation loss, except we only calculate this for latents that are actually dead.
+    """
+    if dead_neuron_mask is None or not dead_neuron_mask.any():
+        return hidden_pre.new_tensor(0.0)
+    per_item_loss = (
+        (threshold - hidden_pre).relu() * dead_neuron_mask * W_dec_norm
+    ).sum(dim=-1)
+    return pre_act_loss_coefficient * per_item_loss.mean()
```
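A toy invocation of the new helper, to show the intended shapes. The function body is exactly the one added above, but the import path and tensor shapes (per-latent threshold and decoder norms, boolean dead-latent mask) are assumptions.

```python
# Toy invocation of calculate_pre_act_loss. Import path and shapes are
# assumptions; the helper itself is the one added in this release.
import torch
from sae_lens.saes.jumprelu_sae import calculate_pre_act_loss

batch, d_sae = 4, 16
hidden_pre = torch.randn(batch, d_sae)
threshold = torch.full((d_sae,), 0.01)
W_dec_norm = torch.ones(d_sae)
dead_neuron_mask = torch.zeros(d_sae, dtype=torch.bool)
dead_neuron_mask[:3] = True  # pretend the first three latents are dead

loss = calculate_pre_act_loss(3e-6, threshold, hidden_pre, dead_neuron_mask, W_dec_norm)
# Small positive scalar that pushes dead latents' pre-activations up toward
# their thresholds; exactly zero when no latents are flagged dead.
print(loss)
```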