sae-lens 6.28.1__py3-none-any.whl → 6.29.1__py3-none-any.whl
- sae_lens/__init__.py +1 -1
- sae_lens/pretrained_saes.yaml +1 -1
- sae_lens/synthetic/__init__.py +6 -0
- sae_lens/synthetic/activation_generator.py +198 -25
- sae_lens/synthetic/correlation.py +217 -36
- sae_lens/synthetic/feature_dictionary.py +64 -17
- sae_lens/synthetic/hierarchy.py +657 -84
- sae_lens/synthetic/training.py +16 -3
- {sae_lens-6.28.1.dist-info → sae_lens-6.29.1.dist-info}/METADATA +11 -1
- {sae_lens-6.28.1.dist-info → sae_lens-6.29.1.dist-info}/RECORD +12 -12
- {sae_lens-6.28.1.dist-info → sae_lens-6.29.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.28.1.dist-info → sae_lens-6.29.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/synthetic/training.py
CHANGED
@@ -23,6 +23,8 @@ def train_toy_sae(
     device: str | torch.device = "cpu",
     n_snapshots: int = 0,
     snapshot_fn: Callable[[SAETrainer[Any, Any]], None] | None = None,
+    autocast_sae: bool = False,
+    autocast_data: bool = False,
 ) -> None:
     """
     Train an SAE on synthetic activations from a feature dictionary.
@@ -46,6 +48,8 @@ def train_toy_sae(
         snapshot_fn: Callback function called at each snapshot point. Receives
             the SAETrainer instance, allowing access to the SAE, training step,
             and other training state. Required if n_snapshots > 0.
+        autocast_sae: Whether to autocast the SAE to bfloat16. Only recommend for large SAEs on CUDA
+        autocast_data: Whether to autocast the activations generator and feature dictionary to bfloat16. Only recommend for large data on CUDA.
     """

     device_str = str(device) if isinstance(device, torch.device) else device
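As the later hunks show, `autocast_data` wraps batch generation in `torch.autocast`, and `autocast_sae` is forwarded to the trainer config's `autocast` option. For intuition, a standalone sketch of the bfloat16 autocast mechanism these flags toggle (names and shapes are placeholders, not sae-lens code):

```python
import torch

# Illustration only: torch.autocast with enabled=False is a no-op, so one
# code path covers both the autocast and non-autocast settings.
def make_batch(W: torch.Tensor, features: torch.Tensor, autocast: bool) -> torch.Tensor:
    with torch.autocast(device_type=W.device.type, dtype=torch.bfloat16, enabled=autocast):
        return features @ W  # matmul runs in bfloat16 when autocast is enabled

W = torch.randn(512, 128)
feats = torch.rand(32, 512)
print(make_batch(W, feats, autocast=False).dtype)  # torch.float32
print(make_batch(W, feats, autocast=True).dtype)   # torch.bfloat16
```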
@@ -55,6 +59,7 @@
         feature_dict=feature_dict,
         activations_generator=activations_generator,
         batch_size=batch_size,
+        autocast=autocast_data,
     )

     # Create trainer config
@@ -64,7 +69,7 @@
         save_final_checkpoint=False,
         total_training_samples=training_samples,
         device=device_str,
-        autocast=
+        autocast=autocast_sae,
         lr=lr,
         lr_end=lr,
         lr_scheduler_name="constant",
@@ -119,6 +124,7 @@ class SyntheticActivationIterator(Iterator[torch.Tensor]):
         feature_dict: FeatureDictionary,
         activations_generator: ActivationGenerator,
         batch_size: int,
+        autocast: bool = False,
     ):
         """
         Create a new SyntheticActivationIterator.
@@ -127,16 +133,23 @@ class SyntheticActivationIterator(Iterator[torch.Tensor]):
             feature_dict: The feature dictionary to use for generating hidden activations
             activations_generator: Generator that produces feature activations
             batch_size: Number of samples per batch
+            autocast: Whether to autocast the activations generator and feature dictionary to bfloat16.
         """
         self.feature_dict = feature_dict
         self.activations_generator = activations_generator
         self.batch_size = batch_size
+        self.autocast = autocast

     @torch.no_grad()
     def next_batch(self) -> torch.Tensor:
         """Generate the next batch of hidden activations."""
-
-
+        with torch.autocast(
+            device_type=self.feature_dict.feature_vectors.device.type,
+            dtype=torch.bfloat16,
+            enabled=self.autocast,
+        ):
+            features = self.activations_generator(self.batch_size)
+            return self.feature_dict(features)

     def __iter__(self) -> "SyntheticActivationIterator":
         return self
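For completeness, a hedged sketch of driving the iterator's new `autocast` flag directly. Only the constructor arguments visible in the diff are used; `feature_dict` and `activations_generator` are assumed to be pre-built `FeatureDictionary` and `ActivationGenerator` instances, and the batch size is illustrative.

```python
from sae_lens.synthetic.training import SyntheticActivationIterator

# Sketch only: feature_dict and activations_generator are assumed to have been
# created elsewhere with the sae_lens.synthetic helpers, on a CUDA device.
iterator = SyntheticActivationIterator(
    feature_dict=feature_dict,
    activations_generator=activations_generator,
    batch_size=4096,
    autocast=True,  # generate each batch inside a bfloat16 autocast context
)

batch = iterator.next_batch()  # hidden activations for one batch
print(batch.shape, batch.dtype)
```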
{sae_lens-6.28.1.dist-info → sae_lens-6.29.1.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.28.1
+Version: 6.29.1
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -50,6 +50,8 @@ SAELens exists to help researchers:
 - Analyse sparse autoencoders / research mechanistic interpretability.
 - Generate insights which make it easier to create safe and aligned AI systems.
 
+SAELens inference works with any PyTorch-based model, not just TransformerLens. While we provide deep integration with TransformerLens via `HookedSAETransformer`, SAEs can be used with Hugging Face Transformers, NNsight, or any other framework by extracting activations and passing them to the SAE's `encode()` and `decode()` methods.
+
 Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
 
 - Download and Analyse pre-trained sparse autoencoders.
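The added paragraph describes a framework-agnostic workflow; a minimal sketch, assuming `sae` is an already loaded SAELens SAE (e.g. obtained via `SAE.from_pretrained`) whose hook point matches the hidden state selected below. The model name and layer index are placeholders.

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Placeholder model and layer; any PyTorch model that exposes activations works.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModel.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

acts = outputs.hidden_states[6]            # [batch, seq, d_model] activations
feature_acts = sae.encode(acts)            # sparse SAE feature activations
reconstruction = sae.decode(feature_acts)  # reconstructed activations
```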
@@ -84,6 +86,14 @@ The new v6 update is a major refactor to SAELens and changes the way training co
 
 Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
 
+## Other SAE Projects
+
+- [dictionary-learning](https://github.com/saprmarks/dictionary_learning): An SAE training library that focuses on having hackable code.
+- [Sparsify](https://github.com/EleutherAI/sparsify): A lean SAE training library focused on TopK SAEs.
+- [Overcomplete](https://github.com/KempnerInstitute/overcomplete): SAE training library focused on vision models.
+- [SAE-Vis](https://github.com/callummcdougall/sae_vis): A library for visualizing SAE features, works with SAELens.
+- [SAEBench](https://github.com/adamkarvonen/SAEBench): A suite of LLM SAE benchmarks, works with SAELens.
+
 ## Citation
 
 Please cite the package as follows:
{sae_lens-6.28.1.dist-info → sae_lens-6.29.1.dist-info}/RECORD
CHANGED

@@ -1,4 +1,4 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=emqKVNiJwD8YtYhtgHJyAT8YSX1QmruQYuG-J4CStC4,4788
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
 sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
@@ -12,7 +12,7 @@ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
 sae_lens/loading/pretrained_sae_loaders.py,sha256=hHMlew1u6zVlbzvS9S_SfUPnAG0_OAjjIcjoUTIUZrU,63657
 sae_lens/loading/pretrained_saes_directory.py,sha256=1at_aQbD8WFywchQCKuwfP-yvCq_Z2aUYrpKDnSN5Nc,4283
 sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
-sae_lens/pretrained_saes.yaml,sha256=
+sae_lens/pretrained_saes.yaml,sha256=Nq43dTcFvDDONTuJ9Me_HQ5nHqr9BdbP5-ZJGXj0TAQ,1509932
 sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
 sae_lens/saes/__init__.py,sha256=SBqPaP6Gl5uPFwHlumAZATC4Wd26xKIYLAAAo4MSa5Q,2200
 sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
@@ -25,16 +25,16 @@ sae_lens/saes/standard_sae.py,sha256=_hldNZkFPAf9VGrxouR1-tN8T2OEk8IkWBcXoatrC1o
 sae_lens/saes/temporal_sae.py,sha256=83Ap4mYGfdN3sKdPF8nKjhdXph3-7E2QuLobqJ_YuoM,13273
 sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
 sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
-sae_lens/synthetic/__init__.py,sha256=
-sae_lens/synthetic/activation_generator.py,sha256=
-sae_lens/synthetic/correlation.py,sha256=
+sae_lens/synthetic/__init__.py,sha256=MtTnGkTfHV2WjkIgs7zZyx10EK9U5fjOHXy69Aq3uKw,3095
+sae_lens/synthetic/activation_generator.py,sha256=8L9nwC4jFRv_wg3QN-n1sFwX8w1NqwJMysWaJ41lLlY,15197
+sae_lens/synthetic/correlation.py,sha256=tMTLo9fBfDpeXwqhyUgFqnTipj9x2W0t4oEtNxB7AG0,13256
 sae_lens/synthetic/evals.py,sha256=Nhi314ZnRgLfhBj-3tm_zzI-pGyFTcwllDXbIpPFXeU,4584
-sae_lens/synthetic/feature_dictionary.py,sha256=
+sae_lens/synthetic/feature_dictionary.py,sha256=Nd4xjSTxKMnKilZ3uYi8Gv5SS5D4bv4wHiSL1uGB69E,6933
 sae_lens/synthetic/firing_probabilities.py,sha256=yclz1pWl5gE1r8LAxFvzQS88Lxwk5-3r8BCX9HLVejA,3370
-sae_lens/synthetic/hierarchy.py,sha256=
+sae_lens/synthetic/hierarchy.py,sha256=nm7nwnTswktVJeKUsRZ0hLOdXcFWGbxnA1b6lefHm-4,33592
 sae_lens/synthetic/initialization.py,sha256=orMGW-786wRDHIS2W7bEH0HmlVFQ4g2z4bnnwdv5w4s,1386
 sae_lens/synthetic/plotting.py,sha256=5lFrej1QOkGAcImFNo5-o-8mI_rUVqvEI57KzUQPPtQ,8208
-sae_lens/synthetic/training.py,sha256=
+sae_lens/synthetic/training.py,sha256=fHcX2cZ6nDupr71GX0Gk17f1NvQ0SKIVXIA6IuAb2dw,5692
 sae_lens/tokenization_and_batching.py,sha256=uoHtAs9z3XqG0Fh-iQVYVlrbyB_E3kFFhrKU30BosCo,5438
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
@@ -46,7 +46,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
 sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
+sae_lens-6.29.1.dist-info/METADATA,sha256=0Pp1L3vNiUGzkMox_BdQR6B064tTHFgwAPGJz8FY8UM,6573
+sae_lens-6.29.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+sae_lens-6.29.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.29.1.dist-info/RECORD,,
{sae_lens-6.28.1.dist-info → sae_lens-6.29.1.dist-info}/WHEEL
File without changes

{sae_lens-6.28.1.dist-info → sae_lens-6.29.1.dist-info}/licenses/LICENSE
File without changes