sae-lens 6.28.1__py3-none-any.whl → 6.29.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -23,6 +23,8 @@ def train_toy_sae(
23
23
  device: str | torch.device = "cpu",
24
24
  n_snapshots: int = 0,
25
25
  snapshot_fn: Callable[[SAETrainer[Any, Any]], None] | None = None,
26
+ autocast_sae: bool = False,
27
+ autocast_data: bool = False,
26
28
  ) -> None:
27
29
  """
28
30
  Train an SAE on synthetic activations from a feature dictionary.
@@ -46,6 +48,8 @@ def train_toy_sae(
46
48
  snapshot_fn: Callback function called at each snapshot point. Receives
47
49
  the SAETrainer instance, allowing access to the SAE, training step,
48
50
  and other training state. Required if n_snapshots > 0.
51
+ autocast_sae: Whether to autocast the SAE to bfloat16. Only recommended for large SAEs on CUDA.
52
+ autocast_data: Whether to autocast the activations generator and feature dictionary to bfloat16. Only recommended for large data on CUDA.
49
53
  """
50
54
 
51
55
  device_str = str(device) if isinstance(device, torch.device) else device
@@ -55,6 +59,7 @@ def train_toy_sae(
55
59
  feature_dict=feature_dict,
56
60
  activations_generator=activations_generator,
57
61
  batch_size=batch_size,
62
+ autocast=autocast_data,
58
63
  )
59
64
 
60
65
  # Create trainer config
@@ -64,7 +69,7 @@ def train_toy_sae(
64
69
  save_final_checkpoint=False,
65
70
  total_training_samples=training_samples,
66
71
  device=device_str,
67
- autocast=False,
72
+ autocast=autocast_sae,
68
73
  lr=lr,
69
74
  lr_end=lr,
70
75
  lr_scheduler_name="constant",
@@ -119,6 +124,7 @@ class SyntheticActivationIterator(Iterator[torch.Tensor]):
119
124
  feature_dict: FeatureDictionary,
120
125
  activations_generator: ActivationGenerator,
121
126
  batch_size: int,
127
+ autocast: bool = False,
122
128
  ):
123
129
  """
124
130
  Create a new SyntheticActivationIterator.
@@ -127,16 +133,23 @@ class SyntheticActivationIterator(Iterator[torch.Tensor]):
127
133
  feature_dict: The feature dictionary to use for generating hidden activations
128
134
  activations_generator: Generator that produces feature activations
129
135
  batch_size: Number of samples per batch
136
+ autocast: Whether to autocast the activations generator and feature dictionary to bfloat16.
130
137
  """
131
138
  self.feature_dict = feature_dict
132
139
  self.activations_generator = activations_generator
133
140
  self.batch_size = batch_size
141
+ self.autocast = autocast
134
142
 
135
143
  @torch.no_grad()
136
144
  def next_batch(self) -> torch.Tensor:
137
145
  """Generate the next batch of hidden activations."""
138
- features = self.activations_generator(self.batch_size)
139
- return self.feature_dict(features)
146
+ with torch.autocast(
147
+ device_type=self.feature_dict.feature_vectors.device.type,
148
+ dtype=torch.bfloat16,
149
+ enabled=self.autocast,
150
+ ):
151
+ features = self.activations_generator(self.batch_size)
152
+ return self.feature_dict(features)
140
153
 
141
154
  def __iter__(self) -> "SyntheticActivationIterator":
142
155
  return self
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.28.1
3
+ Version: 6.29.1
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -50,6 +50,8 @@ SAELens exists to help researchers:
50
50
  - Analyse sparse autoencoders / research mechanistic interpretability.
51
51
  - Generate insights which make it easier to create safe and aligned AI systems.
52
52
 
53
+ SAELens inference works with any PyTorch-based model, not just TransformerLens. While we provide deep integration with TransformerLens via `HookedSAETransformer`, SAEs can be used with Hugging Face Transformers, NNsight, or any other framework by extracting activations and passing them to the SAE's `encode()` and `decode()` methods.
54
+
53
55
  Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
54
56
 
55
57
  - Download and Analyse pre-trained sparse autoencoders.
@@ -84,6 +86,14 @@ The new v6 update is a major refactor to SAELens and changes the way training co
84
86
 
85
87
  Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
86
88
 
89
+ ## Other SAE Projects
90
+
91
+ - [dictionary-learning](https://github.com/saprmarks/dictionary_learning): An SAE training library that focuses on having hackable code.
92
+ - [Sparsify](https://github.com/EleutherAI/sparsify): A lean SAE training library focused on TopK SAEs.
93
+ - [Overcomplete](https://github.com/KempnerInstitute/overcomplete): SAE training library focused on vision models.
94
+ - [SAE-Vis](https://github.com/callummcdougall/sae_vis): A library for visualizing SAE features, works with SAELens.
95
+ - [SAEBench](https://github.com/adamkarvonen/SAEBench): A suite of LLM SAE benchmarks, works with SAELens.
96
+
87
97
  ## Citation
88
98
 
89
99
  Please cite the package as follows:
@@ -1,4 +1,4 @@
1
- sae_lens/__init__.py,sha256=S-AS72IxkvKO-wItRQjuyczikDxmfDaUgXRSfu5PU-o,4788
1
+ sae_lens/__init__.py,sha256=emqKVNiJwD8YtYhtgHJyAT8YSX1QmruQYuG-J4CStC4,4788
2
2
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
4
4
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
@@ -12,7 +12,7 @@ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
12
12
  sae_lens/loading/pretrained_sae_loaders.py,sha256=hHMlew1u6zVlbzvS9S_SfUPnAG0_OAjjIcjoUTIUZrU,63657
13
13
  sae_lens/loading/pretrained_saes_directory.py,sha256=1at_aQbD8WFywchQCKuwfP-yvCq_Z2aUYrpKDnSN5Nc,4283
14
14
  sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
15
- sae_lens/pretrained_saes.yaml,sha256=Hn8jXwZ7V6QQxzgu41LFEP-LAzuDxwYL5vhoar-pPX8,1509922
15
+ sae_lens/pretrained_saes.yaml,sha256=Nq43dTcFvDDONTuJ9Me_HQ5nHqr9BdbP5-ZJGXj0TAQ,1509932
16
16
  sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
17
  sae_lens/saes/__init__.py,sha256=SBqPaP6Gl5uPFwHlumAZATC4Wd26xKIYLAAAo4MSa5Q,2200
18
18
  sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
@@ -25,16 +25,16 @@ sae_lens/saes/standard_sae.py,sha256=_hldNZkFPAf9VGrxouR1-tN8T2OEk8IkWBcXoatrC1o
25
25
  sae_lens/saes/temporal_sae.py,sha256=83Ap4mYGfdN3sKdPF8nKjhdXph3-7E2QuLobqJ_YuoM,13273
26
26
  sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
27
27
  sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
28
- sae_lens/synthetic/__init__.py,sha256=FGUasB6fLPXRFCcrtKfL7vCKDOWebZ5Rx5F9QNJZklI,2875
29
- sae_lens/synthetic/activation_generator.py,sha256=thWGTwRmhu0K8m66WfJUajHmuIPHkwV4_HjmG0dL3G8,7638
30
- sae_lens/synthetic/correlation.py,sha256=odr-S5h6c2U-bepwrAQeMfV1iBF_cnnQzqw7zapEXZ4,6056
28
+ sae_lens/synthetic/__init__.py,sha256=MtTnGkTfHV2WjkIgs7zZyx10EK9U5fjOHXy69Aq3uKw,3095
29
+ sae_lens/synthetic/activation_generator.py,sha256=8L9nwC4jFRv_wg3QN-n1sFwX8w1NqwJMysWaJ41lLlY,15197
30
+ sae_lens/synthetic/correlation.py,sha256=tMTLo9fBfDpeXwqhyUgFqnTipj9x2W0t4oEtNxB7AG0,13256
31
31
  sae_lens/synthetic/evals.py,sha256=Nhi314ZnRgLfhBj-3tm_zzI-pGyFTcwllDXbIpPFXeU,4584
32
- sae_lens/synthetic/feature_dictionary.py,sha256=2A9wqdT1KejRLuIoFWdoiWdDtaHHgIluaKsHGizsVxI,4864
32
+ sae_lens/synthetic/feature_dictionary.py,sha256=Nd4xjSTxKMnKilZ3uYi8Gv5SS5D4bv4wHiSL1uGB69E,6933
33
33
  sae_lens/synthetic/firing_probabilities.py,sha256=yclz1pWl5gE1r8LAxFvzQS88Lxwk5-3r8BCX9HLVejA,3370
34
- sae_lens/synthetic/hierarchy.py,sha256=dlQdPnnG3VzQDB3QOaqSXwoH8Ij2ioxmTlZg1lXHaRQ,11754
34
+ sae_lens/synthetic/hierarchy.py,sha256=nm7nwnTswktVJeKUsRZ0hLOdXcFWGbxnA1b6lefHm-4,33592
35
35
  sae_lens/synthetic/initialization.py,sha256=orMGW-786wRDHIS2W7bEH0HmlVFQ4g2z4bnnwdv5w4s,1386
36
36
  sae_lens/synthetic/plotting.py,sha256=5lFrej1QOkGAcImFNo5-o-8mI_rUVqvEI57KzUQPPtQ,8208
37
- sae_lens/synthetic/training.py,sha256=Bg6NYxdzifq_8g-dJQSZ_z_TXDdGRtEi7tqNDb-gCVc,4986
37
+ sae_lens/synthetic/training.py,sha256=fHcX2cZ6nDupr71GX0Gk17f1NvQ0SKIVXIA6IuAb2dw,5692
38
38
  sae_lens/tokenization_and_batching.py,sha256=uoHtAs9z3XqG0Fh-iQVYVlrbyB_E3kFFhrKU30BosCo,5438
39
39
  sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
40
  sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
@@ -46,7 +46,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
46
46
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
47
47
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
48
48
  sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
49
- sae_lens-6.28.1.dist-info/METADATA,sha256=OdPVG1dwWoLGqiutKkAJGazfBLLbYQLBUbs_3h58BKg,5633
50
- sae_lens-6.28.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
51
- sae_lens-6.28.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
52
- sae_lens-6.28.1.dist-info/RECORD,,
49
+ sae_lens-6.29.1.dist-info/METADATA,sha256=0Pp1L3vNiUGzkMox_BdQR6B064tTHFgwAPGJz8FY8UM,6573
50
+ sae_lens-6.29.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
51
+ sae_lens-6.29.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
52
+ sae_lens-6.29.1.dist-info/RECORD,,