sae-lens 6.22.1__py3-none-any.whl → 6.24.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +8 -1
- sae_lens/loading/pretrained_sae_loaders.py +242 -24
- sae_lens/pretokenize_runner.py +3 -3
- sae_lens/pretrained_saes.yaml +26933 -97
- sae_lens/saes/__init__.py +4 -0
- sae_lens/saes/gated_sae.py +2 -2
- sae_lens/saes/jumprelu_sae.py +4 -4
- sae_lens/saes/sae.py +1 -1
- sae_lens/saes/topk_sae.py +1 -1
- sae_lens/saes/transcoder.py +41 -0
- sae_lens/training/activations_store.py +5 -5
- {sae_lens-6.22.1.dist-info → sae_lens-6.24.1.dist-info}/METADATA +2 -2
- {sae_lens-6.22.1.dist-info → sae_lens-6.24.1.dist-info}/RECORD +15 -15
- {sae_lens-6.22.1.dist-info → sae_lens-6.24.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.22.1.dist-info → sae_lens-6.24.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/saes/__init__.py
CHANGED
|
@@ -33,6 +33,8 @@ from .topk_sae import (
|
|
|
33
33
|
TopKTrainingSAEConfig,
|
|
34
34
|
)
|
|
35
35
|
from .transcoder import (
|
|
36
|
+
JumpReLUSkipTranscoder,
|
|
37
|
+
JumpReLUSkipTranscoderConfig,
|
|
36
38
|
JumpReLUTranscoder,
|
|
37
39
|
JumpReLUTranscoderConfig,
|
|
38
40
|
SkipTranscoder,
|
|
@@ -70,6 +72,8 @@ __all__ = [
|
|
|
70
72
|
"SkipTranscoderConfig",
|
|
71
73
|
"JumpReLUTranscoder",
|
|
72
74
|
"JumpReLUTranscoderConfig",
|
|
75
|
+
"JumpReLUSkipTranscoder",
|
|
76
|
+
"JumpReLUSkipTranscoderConfig",
|
|
73
77
|
"MatryoshkaBatchTopKTrainingSAE",
|
|
74
78
|
"MatryoshkaBatchTopKTrainingSAEConfig",
|
|
75
79
|
"TemporalSAE",
|
sae_lens/saes/gated_sae.py
CHANGED
|
@@ -89,7 +89,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
|
|
|
89
89
|
@torch.no_grad()
|
|
90
90
|
def fold_W_dec_norm(self):
|
|
91
91
|
"""Override to handle gated-specific parameters."""
|
|
92
|
-
W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
|
|
92
|
+
W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
|
|
93
93
|
self.W_dec.data = self.W_dec.data / W_dec_norms
|
|
94
94
|
self.W_enc.data = self.W_enc.data * W_dec_norms.T
|
|
95
95
|
|
|
@@ -217,7 +217,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
|
|
|
217
217
|
@torch.no_grad()
|
|
218
218
|
def fold_W_dec_norm(self):
|
|
219
219
|
"""Override to handle gated-specific parameters."""
|
|
220
|
-
W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
|
|
220
|
+
W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
|
|
221
221
|
self.W_dec.data = self.W_dec.data / W_dec_norms
|
|
222
222
|
self.W_enc.data = self.W_enc.data * W_dec_norms.T
|
|
223
223
|
|
sae_lens/saes/jumprelu_sae.py
CHANGED
|
@@ -167,8 +167,8 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
|
|
|
167
167
|
# Save the current threshold before calling parent method
|
|
168
168
|
current_thresh = self.threshold.clone()
|
|
169
169
|
|
|
170
|
-
# Get W_dec norms that will be used for scaling
|
|
171
|
-
W_dec_norms = self.W_dec.norm(dim=-1)
|
|
170
|
+
# Get W_dec norms that will be used for scaling (clamped to avoid division by zero)
|
|
171
|
+
W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8)
|
|
172
172
|
|
|
173
173
|
# Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
|
|
174
174
|
super().fold_W_dec_norm()
|
|
@@ -325,8 +325,8 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
|
|
|
325
325
|
# Save the current threshold before we call the parent method
|
|
326
326
|
current_thresh = self.threshold.clone()
|
|
327
327
|
|
|
328
|
-
# Get W_dec norms
|
|
329
|
-
W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
|
|
328
|
+
# Get W_dec norms (clamped to avoid division by zero)
|
|
329
|
+
W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
|
|
330
330
|
|
|
331
331
|
# Call parent implementation to handle W_enc and W_dec adjustment
|
|
332
332
|
super().fold_W_dec_norm()
|
sae_lens/saes/sae.py
CHANGED
|
@@ -484,7 +484,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
|
|
|
484
484
|
@torch.no_grad()
|
|
485
485
|
def fold_W_dec_norm(self):
|
|
486
486
|
"""Fold decoder norms into encoder."""
|
|
487
|
-
W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
|
|
487
|
+
W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
|
|
488
488
|
self.W_dec.data = self.W_dec.data / W_dec_norms
|
|
489
489
|
self.W_enc.data = self.W_enc.data * W_dec_norms.T
|
|
490
490
|
|
sae_lens/saes/topk_sae.py
CHANGED
|
@@ -531,7 +531,7 @@ def _fold_norm_topk(
|
|
|
531
531
|
b_enc: torch.Tensor,
|
|
532
532
|
W_dec: torch.Tensor,
|
|
533
533
|
) -> None:
|
|
534
|
-
W_dec_norm = W_dec.norm(dim=-1)
|
|
534
|
+
W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8)
|
|
535
535
|
b_enc.data = b_enc.data * W_dec_norm
|
|
536
536
|
W_dec_norms = W_dec_norm.unsqueeze(1)
|
|
537
537
|
W_dec.data = W_dec.data / W_dec_norms
|
sae_lens/saes/transcoder.py
CHANGED
|
@@ -368,3 +368,44 @@ class JumpReLUTranscoder(Transcoder):
|
|
|
368
368
|
def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUTranscoder":
|
|
369
369
|
cfg = JumpReLUTranscoderConfig.from_dict(config_dict)
|
|
370
370
|
return cls(cfg)
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
@dataclass
|
|
374
|
+
class JumpReLUSkipTranscoderConfig(JumpReLUTranscoderConfig):
|
|
375
|
+
"""Configuration for JumpReLU transcoder."""
|
|
376
|
+
|
|
377
|
+
@classmethod
|
|
378
|
+
def architecture(cls) -> str:
|
|
379
|
+
"""Return the architecture name for this config."""
|
|
380
|
+
return "jumprelu_skip_transcoder"
|
|
381
|
+
|
|
382
|
+
@classmethod
|
|
383
|
+
def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUSkipTranscoderConfig":
|
|
384
|
+
"""Create a JumpReLUSkipTranscoderConfig from a dictionary."""
|
|
385
|
+
# Filter to only include valid dataclass fields
|
|
386
|
+
filtered_config_dict = filter_valid_dataclass_fields(config_dict, cls)
|
|
387
|
+
|
|
388
|
+
# Create the config instance
|
|
389
|
+
res = cls(**filtered_config_dict)
|
|
390
|
+
|
|
391
|
+
# Handle metadata if present
|
|
392
|
+
if "metadata" in config_dict:
|
|
393
|
+
res.metadata = SAEMetadata(**config_dict["metadata"])
|
|
394
|
+
|
|
395
|
+
return res
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
class JumpReLUSkipTranscoder(JumpReLUTranscoder, SkipTranscoder):
|
|
399
|
+
"""
|
|
400
|
+
A transcoder with a learnable skip connection and JumpReLU activation function.
|
|
401
|
+
"""
|
|
402
|
+
|
|
403
|
+
cfg: JumpReLUSkipTranscoderConfig # type: ignore[assignment]
|
|
404
|
+
|
|
405
|
+
def __init__(self, cfg: JumpReLUSkipTranscoderConfig):
|
|
406
|
+
super().__init__(cfg)
|
|
407
|
+
|
|
408
|
+
@classmethod
|
|
409
|
+
def from_dict(cls, config_dict: dict[str, Any]) -> "JumpReLUSkipTranscoder":
|
|
410
|
+
cfg = JumpReLUSkipTranscoderConfig.from_dict(config_dict)
|
|
411
|
+
return cls(cfg)
|
|
@@ -166,9 +166,11 @@ class ActivationsStore:
|
|
|
166
166
|
disable_concat_sequences: bool = False,
|
|
167
167
|
sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
|
|
168
168
|
) -> ActivationsStore:
|
|
169
|
+
if context_size is None:
|
|
170
|
+
context_size = sae.cfg.metadata.context_size
|
|
169
171
|
if sae.cfg.metadata.hook_name is None:
|
|
170
172
|
raise ValueError("hook_name is required")
|
|
171
|
-
if
|
|
173
|
+
if context_size is None:
|
|
172
174
|
raise ValueError("context_size is required")
|
|
173
175
|
if sae.cfg.metadata.prepend_bos is None:
|
|
174
176
|
raise ValueError("prepend_bos is required")
|
|
@@ -178,9 +180,7 @@ class ActivationsStore:
|
|
|
178
180
|
d_in=sae.cfg.d_in,
|
|
179
181
|
hook_name=sae.cfg.metadata.hook_name,
|
|
180
182
|
hook_head_index=sae.cfg.metadata.hook_head_index,
|
|
181
|
-
context_size=
|
|
182
|
-
if context_size is None
|
|
183
|
-
else context_size,
|
|
183
|
+
context_size=context_size,
|
|
184
184
|
prepend_bos=sae.cfg.metadata.prepend_bos,
|
|
185
185
|
streaming=streaming,
|
|
186
186
|
store_batch_size_prompts=store_batch_size_prompts,
|
|
@@ -230,7 +230,7 @@ class ActivationsStore:
|
|
|
230
230
|
load_dataset(
|
|
231
231
|
dataset,
|
|
232
232
|
split="train",
|
|
233
|
-
streaming=streaming,
|
|
233
|
+
streaming=streaming, # type: ignore
|
|
234
234
|
trust_remote_code=dataset_trust_remote_code, # type: ignore
|
|
235
235
|
)
|
|
236
236
|
if isinstance(dataset, str)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sae-lens
|
|
3
|
-
Version: 6.
|
|
3
|
+
Version: 6.24.1
|
|
4
4
|
Summary: Training and Analyzing Sparse Autoencoders (SAEs)
|
|
5
5
|
License: MIT
|
|
6
6
|
License-File: LICENSE
|
|
@@ -62,7 +62,7 @@ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [
|
|
|
62
62
|
|
|
63
63
|
## Loading Pre-trained SAEs.
|
|
64
64
|
|
|
65
|
-
Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/
|
|
65
|
+
Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
|
|
66
66
|
|
|
67
67
|
## Migrating to SAELens v6
|
|
68
68
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
sae_lens/__init__.py,sha256=
|
|
1
|
+
sae_lens/__init__.py,sha256=spLEw4TR2BzzKc3R-ik8MbHlYOAR__wVmkSmJqOB4Tc,4268
|
|
2
2
|
sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
|
|
4
4
|
sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
|
|
@@ -9,25 +9,25 @@ sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
|
|
|
9
9
|
sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
|
|
10
10
|
sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
|
|
11
11
|
sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
|
-
sae_lens/loading/pretrained_sae_loaders.py,sha256=
|
|
12
|
+
sae_lens/loading/pretrained_sae_loaders.py,sha256=W2eIvUU1wAHrYxGiZs4s2D6DnGBQqqKjq0wvXzWbD5c,63561
|
|
13
13
|
sae_lens/loading/pretrained_saes_directory.py,sha256=hejNfLUepYCSGPalRfQwxxCEUqMMUPsn1tufwvwct5k,3820
|
|
14
|
-
sae_lens/pretokenize_runner.py,sha256=
|
|
15
|
-
sae_lens/pretrained_saes.yaml,sha256=
|
|
14
|
+
sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
|
|
15
|
+
sae_lens/pretrained_saes.yaml,sha256=Hd1GgaPL4TAXoS2gizG9e_9jc_9LpfI4w_hwGkEz9xQ,1509314
|
|
16
16
|
sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
|
|
17
|
-
sae_lens/saes/__init__.py,sha256=
|
|
17
|
+
sae_lens/saes/__init__.py,sha256=fYVujOzNnUgpzLL0MBLBt_DNX2CPcTaheukzCd2bEPo,1906
|
|
18
18
|
sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
|
|
19
|
-
sae_lens/saes/gated_sae.py,sha256=
|
|
20
|
-
sae_lens/saes/jumprelu_sae.py,sha256=
|
|
19
|
+
sae_lens/saes/gated_sae.py,sha256=mHnmw-RD7hqIbP9_EBj3p2SK0OqQIkZivdOKRygeRgw,8825
|
|
20
|
+
sae_lens/saes/jumprelu_sae.py,sha256=udjGHp3WTABQSL2Qq57j-bINWX61GCmo68EmdjMOXoo,13310
|
|
21
21
|
sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
|
|
22
|
-
sae_lens/saes/sae.py,sha256=
|
|
22
|
+
sae_lens/saes/sae.py,sha256=Vb1aGSDPRv_0J2aL8-EICRSkIxsO6Q4lJaJE9NNmfdA,37749
|
|
23
23
|
sae_lens/saes/standard_sae.py,sha256=nEVETwAmRD2tyX7ESIic1fij48gAq1Dh7s_GQ2fqCZ4,5747
|
|
24
24
|
sae_lens/saes/temporal_sae.py,sha256=DsecivcHWId-MTuJpQbz8OhqtmGhZACxJauYZGHo0Ok,13272
|
|
25
|
-
sae_lens/saes/topk_sae.py,sha256=
|
|
26
|
-
sae_lens/saes/transcoder.py,sha256=
|
|
25
|
+
sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
|
|
26
|
+
sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
|
|
27
27
|
sae_lens/tokenization_and_batching.py,sha256=D_o7cXvRqhT89H3wNzoRymNALNE6eHojBWLdXOUwUGE,5438
|
|
28
28
|
sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
29
29
|
sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
|
|
30
|
-
sae_lens/training/activations_store.py,sha256=
|
|
30
|
+
sae_lens/training/activations_store.py,sha256=yDWw7TZGPFM_O8_Oi78j8lLIHJJesxq9TKVP_TrMX-M,33768
|
|
31
31
|
sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
|
|
32
32
|
sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
|
|
33
33
|
sae_lens/training/sae_trainer.py,sha256=zhkabyIKxI_tZTV3_kwz6zMrHZ95Ecr97krmwc-9ffs,17600
|
|
@@ -35,7 +35,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
|
|
|
35
35
|
sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
|
|
36
36
|
sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
|
|
37
37
|
sae_lens/util.py,sha256=tCovQ-eZa1L7thPpNDL6PGOJrIMML2yLI5e0EHCOpS8,3309
|
|
38
|
-
sae_lens-6.
|
|
39
|
-
sae_lens-6.
|
|
40
|
-
sae_lens-6.
|
|
41
|
-
sae_lens-6.
|
|
38
|
+
sae_lens-6.24.1.dist-info/METADATA,sha256=5TlxCqEZoJV4S0F9IP6Ak_aitVkMkFfUhlFOl5NIJBc,5361
|
|
39
|
+
sae_lens-6.24.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
|
|
40
|
+
sae_lens-6.24.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
|
|
41
|
+
sae_lens-6.24.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|