sae-lens 6.22.2__tar.gz → 6.22.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.22.2 → sae_lens-6.22.3}/PKG-INFO +1 -1
- {sae_lens-6.22.2 → sae_lens-6.22.3}/pyproject.toml +1 -1
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/__init__.py +1 -1
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/gated_sae.py +2 -2
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/jumprelu_sae.py +4 -4
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/sae.py +1 -1
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/topk_sae.py +1 -1
- {sae_lens-6.22.2 → sae_lens-6.22.3}/LICENSE +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/README.md +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/config.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/constants.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/evals.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/load_model.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/registry.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/temporal_sae.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/types.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.22.2 → sae_lens-6.22.3}/sae_lens/util.py +0 -0
--- sae_lens-6.22.2/sae_lens/saes/gated_sae.py
+++ sae_lens-6.22.3/sae_lens/saes/gated_sae.py
@@ -89,7 +89,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Override to handle gated-specific parameters."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
@@ -217,7 +217,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Override to handle gated-specific parameters."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
--- sae_lens-6.22.2/sae_lens/saes/jumprelu_sae.py
+++ sae_lens-6.22.3/sae_lens/saes/jumprelu_sae.py
@@ -167,8 +167,8 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
         # Save the current threshold before calling parent method
         current_thresh = self.threshold.clone()
 
-        # Get W_dec norms that will be used for scaling
-        W_dec_norms = self.W_dec.norm(dim=-1)
+        # Get W_dec norms that will be used for scaling (clamped to avoid division by zero)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8)
 
         # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
         super().fold_W_dec_norm()
@@ -325,8 +325,8 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         # Save the current threshold before we call the parent method
         current_thresh = self.threshold.clone()
 
-        # Get W_dec norms
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        # Get W_dec norms (clamped to avoid division by zero)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
 
         # Call parent implementation to handle W_enc and W_dec adjustment
         super().fold_W_dec_norm()
--- sae_lens-6.22.2/sae_lens/saes/sae.py
+++ sae_lens-6.22.3/sae_lens/saes/sae.py
@@ -484,7 +484,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     @torch.no_grad()
     def fold_W_dec_norm(self):
         """Fold decoder norms into encoder."""
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
         self.W_dec.data = self.W_dec.data / W_dec_norms
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
--- sae_lens-6.22.2/sae_lens/saes/topk_sae.py
+++ sae_lens-6.22.3/sae_lens/saes/topk_sae.py
@@ -531,7 +531,7 @@ def _fold_norm_topk(
     b_enc: torch.Tensor,
     W_dec: torch.Tensor,
 ) -> None:
-    W_dec_norm = W_dec.norm(dim=-1)
+    W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8)
     b_enc.data = b_enc.data * W_dec_norm
     W_dec_norms = W_dec_norm.unsqueeze(1)
     W_dec.data = W_dec.data / W_dec_norms
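
Every hunk above applies the same guard: decoder-row norms are clamped to at least 1e-8 before they are folded into the encoder, so an all-zero decoder row (e.g. a dead latent) no longer divides by zero and fills the folded weights with NaNs. The sketch below is not sae-lens code; it is a minimal standalone illustration, assuming only a generic W_enc/W_dec pair and a hypothetical fold_w_dec_norm helper, of the failure mode the clamp prevents.

import torch

def fold_w_dec_norm(W_enc: torch.Tensor, W_dec: torch.Tensor, clamp: bool) -> tuple[torch.Tensor, torch.Tensor]:
    # Fold each decoder row's norm into the corresponding encoder column,
    # mirroring the pattern changed in this release (illustrative only).
    W_dec_norms = W_dec.norm(dim=-1)
    if clamp:
        W_dec_norms = W_dec_norms.clamp(min=1e-8)  # the guard added in 6.22.3
    W_dec_norms = W_dec_norms.unsqueeze(1)
    return W_enc * W_dec_norms.T, W_dec / W_dec_norms

W_enc = torch.randn(16, 8)
W_dec = torch.randn(8, 16)
W_dec[3] = 0.0  # a dead latent: its decoder row has zero norm

_, folded_unclamped = fold_w_dec_norm(W_enc, W_dec, clamp=False)
_, folded_clamped = fold_w_dec_norm(W_enc, W_dec, clamp=True)

print(torch.isnan(folded_unclamped).any())  # tensor(True): 0 / 0 produces NaNs
print(torch.isnan(folded_clamped).any())    # tensor(False): clamped norm stays finite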