sae-lens 6.26.0__py3-none-any.whl → 6.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,15 +8,19 @@ def mixing_buffer(
8
8
  buffer_size: int,
9
9
  batch_size: int,
10
10
  activations_loader: Iterator[torch.Tensor],
11
+ mix_fraction: float = 0.5,
11
12
  ) -> Iterator[torch.Tensor]:
12
13
  """
13
14
  A generator that maintains a mix of old and new activations for better training.
14
- It stores half of the activations and mixes them with new ones to create batches.
15
+ It keeps a portion of activations and mixes them with new ones to create batches.
15
16
 
16
17
  Args:
17
- buffer_size: Total size of the buffer (will store buffer_size/2 activations)
18
+ buffer_size: Total size of the buffer
18
19
  batch_size: Size of batches to return
19
20
  activations_loader: Iterator providing new activations
21
+ mix_fraction: Fraction of buffer to keep for mixing (default 0.5).
22
+ Higher values mean more temporal mixing but slower throughput.
23
+ If 0, no shuffling occurs (passthrough mode).
20
24
 
21
25
  Yields:
22
26
  Batches of activations of shape (batch_size, *activation_dims)
@@ -24,6 +28,8 @@ def mixing_buffer(
24
28
 
25
29
  if buffer_size < batch_size:
26
30
  raise ValueError("Buffer size must be greater than or equal to batch size")
31
+ if not 0 <= mix_fraction <= 1:
32
+ raise ValueError("mix_fraction must be in [0, 1]")
27
33
 
28
34
  storage_buffer: torch.Tensor | None = None
29
35
 
@@ -35,10 +41,13 @@ def mixing_buffer(
35
41
  )
36
42
 
37
43
  if storage_buffer.shape[0] >= buffer_size:
38
- # Shuffle
39
- storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
44
+ if mix_fraction > 0:
45
+ storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
40
46
 
41
- num_serving_batches = max(1, storage_buffer.shape[0] // (2 * batch_size))
47
+ # Keep a fixed amount for mixing, serve the rest
48
+ keep_for_mixing = int(buffer_size * mix_fraction)
49
+ num_to_serve = storage_buffer.shape[0] - keep_for_mixing
50
+ num_serving_batches = max(1, num_to_serve // batch_size)
42
51
  serving_cutoff = num_serving_batches * batch_size
43
52
  serving_buffer = storage_buffer[:serving_cutoff]
44
53
  storage_buffer = storage_buffer[serving_cutoff:]
@@ -55,7 +55,7 @@ Evaluator = Callable[[T_TRAINING_SAE, DataProvider, ActivationScaler], dict[str,
55
55
 
56
56
  class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
57
57
  """
58
- Core SAE class used for inference. For training, see TrainingSAE.
58
+ Trainer for Sparse Autoencoder (SAE) models.
59
59
  """
60
60
 
61
61
  data_provider: DataProvider
sae_lens/util.py CHANGED
@@ -95,8 +95,10 @@ def get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
95
95
  return list(special_tokens)
96
96
 
97
97
 
98
- def str_to_dtype(dtype: str) -> torch.dtype:
98
+ def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
99
99
  """Convert a string to a torch.dtype."""
100
+ if isinstance(dtype, torch.dtype):
101
+ return dtype
100
102
  if dtype not in DTYPE_MAP:
101
103
  raise ValueError(
102
104
  f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_MAP.keys())}"
@@ -111,3 +113,26 @@ def dtype_to_str(dtype: torch.dtype) -> str:
111
113
  f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_TO_STR.keys())}"
112
114
  )
113
115
  return DTYPE_TO_STR[dtype]
116
+
117
+
118
+ def cosine_similarities(
119
+ mat1: torch.Tensor, mat2: torch.Tensor | None = None
120
+ ) -> torch.Tensor:
121
+ """
122
+ Compute cosine similarities between each row of mat1 and each row of mat2.
123
+
124
+ Args:
125
+ mat1: Tensor of shape [n1, d]
126
+ mat2: Tensor of shape [n2, d]. If not provided, defaults to mat1
127
+
128
+ Returns:
129
+ Tensor of shape [n1, n2] with cosine similarities
130
+ """
131
+ if mat2 is None:
132
+ mat2 = mat1
133
+ # Clamp norm to 1e-8 to prevent division by zero. This threshold is chosen
134
+ # to be small enough to not affect normal vectors but large enough to avoid
135
+ # numerical instability. Zero vectors will effectively map to zero similarity.
136
+ mat1_normed = mat1 / mat1.norm(dim=1, keepdim=True).clamp(min=1e-8)
137
+ mat2_normed = mat2 / mat2.norm(dim=1, keepdim=True).clamp(min=1e-8)
138
+ return mat1_normed @ mat2_normed.T
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.26.0
3
+ Version: 6.28.1
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -77,6 +77,8 @@ The new v6 update is a major refactor to SAELens and changes the way training co
77
77
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
78
78
  - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
79
79
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
80
+ - [Training SAEs on Synthetic Data](tutorials/training_saes_on_synthetic_data.ipynb)
81
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_saes_on_synthetic_data.ipynb)
80
82
 
81
83
  ## Join the Slack!
82
84
 
@@ -0,0 +1,52 @@
1
+ sae_lens/__init__.py,sha256=S-AS72IxkvKO-wItRQjuyczikDxmfDaUgXRSfu5PU-o,4788
2
+ sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
4
+ sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
5
+ sae_lens/cache_activations_runner.py,sha256=TjqNWIc46Nw09jHWFjzQzgzG5wdu_87Ahe-iFjI5_0Q,13117
6
+ sae_lens/config.py,sha256=sseYcRMsAyopj8FICup1RGTXjFxzAithZ2OH7OpQV3Y,30839
7
+ sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
8
+ sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
9
+ sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
10
+ sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
11
+ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
+ sae_lens/loading/pretrained_sae_loaders.py,sha256=hHMlew1u6zVlbzvS9S_SfUPnAG0_OAjjIcjoUTIUZrU,63657
13
+ sae_lens/loading/pretrained_saes_directory.py,sha256=1at_aQbD8WFywchQCKuwfP-yvCq_Z2aUYrpKDnSN5Nc,4283
14
+ sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
15
+ sae_lens/pretrained_saes.yaml,sha256=Hn8jXwZ7V6QQxzgu41LFEP-LAzuDxwYL5vhoar-pPX8,1509922
16
+ sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
+ sae_lens/saes/__init__.py,sha256=SBqPaP6Gl5uPFwHlumAZATC4Wd26xKIYLAAAo4MSa5Q,2200
18
+ sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
19
+ sae_lens/saes/gated_sae.py,sha256=V_2ZNlV4gRD-rX5JSx1xqY7idT8ChfdQ5yxWDdu_6hg,8826
20
+ sae_lens/saes/jumprelu_sae.py,sha256=miiF-xI_yXdV9EkKjwAbU9zSMsx9KtKCz5YdXEzkN8g,13313
21
+ sae_lens/saes/matching_pursuit_sae.py,sha256=08_G9p1YMLnE5qZVCPp6gll-iG6nHRbMMASf4_bkFt8,13207
22
+ sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
23
+ sae_lens/saes/sae.py,sha256=xRmgiLuaFlDCv8SyLbL-5TwdrWHpNLqSGe8mC1L6WcI,40942
24
+ sae_lens/saes/standard_sae.py,sha256=_hldNZkFPAf9VGrxouR1-tN8T2OEk8IkWBcXoatrC1o,5749
25
+ sae_lens/saes/temporal_sae.py,sha256=83Ap4mYGfdN3sKdPF8nKjhdXph3-7E2QuLobqJ_YuoM,13273
26
+ sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
27
+ sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
28
+ sae_lens/synthetic/__init__.py,sha256=FGUasB6fLPXRFCcrtKfL7vCKDOWebZ5Rx5F9QNJZklI,2875
29
+ sae_lens/synthetic/activation_generator.py,sha256=thWGTwRmhu0K8m66WfJUajHmuIPHkwV4_HjmG0dL3G8,7638
30
+ sae_lens/synthetic/correlation.py,sha256=odr-S5h6c2U-bepwrAQeMfV1iBF_cnnQzqw7zapEXZ4,6056
31
+ sae_lens/synthetic/evals.py,sha256=Nhi314ZnRgLfhBj-3tm_zzI-pGyFTcwllDXbIpPFXeU,4584
32
+ sae_lens/synthetic/feature_dictionary.py,sha256=2A9wqdT1KejRLuIoFWdoiWdDtaHHgIluaKsHGizsVxI,4864
33
+ sae_lens/synthetic/firing_probabilities.py,sha256=yclz1pWl5gE1r8LAxFvzQS88Lxwk5-3r8BCX9HLVejA,3370
34
+ sae_lens/synthetic/hierarchy.py,sha256=dlQdPnnG3VzQDB3QOaqSXwoH8Ij2ioxmTlZg1lXHaRQ,11754
35
+ sae_lens/synthetic/initialization.py,sha256=orMGW-786wRDHIS2W7bEH0HmlVFQ4g2z4bnnwdv5w4s,1386
36
+ sae_lens/synthetic/plotting.py,sha256=5lFrej1QOkGAcImFNo5-o-8mI_rUVqvEI57KzUQPPtQ,8208
37
+ sae_lens/synthetic/training.py,sha256=Bg6NYxdzifq_8g-dJQSZ_z_TXDdGRtEi7tqNDb-gCVc,4986
38
+ sae_lens/tokenization_and_batching.py,sha256=uoHtAs9z3XqG0Fh-iQVYVlrbyB_E3kFFhrKU30BosCo,5438
39
+ sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
+ sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
41
+ sae_lens/training/activations_store.py,sha256=kp4-6R4rTJUSt-g-Ifg5B1h7iIe7jZj-XQSKDvDpQMI,32187
42
+ sae_lens/training/mixing_buffer.py,sha256=1Z-S2CcQXMWGxRZJFnXeZFxbZcALkO_fP6VO37XdJQQ,2519
43
+ sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
44
+ sae_lens/training/sae_trainer.py,sha256=iiGrNwmiX0xSHnJit0lH66yQzB6q8Fww1WNJZbTSBGY,17579
45
+ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
46
+ sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
47
+ sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
48
+ sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
49
+ sae_lens-6.28.1.dist-info/METADATA,sha256=OdPVG1dwWoLGqiutKkAJGazfBLLbYQLBUbs_3h58BKg,5633
50
+ sae_lens-6.28.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
51
+ sae_lens-6.28.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
52
+ sae_lens-6.28.1.dist-info/RECORD,,
@@ -1,42 +0,0 @@
1
- sae_lens/__init__.py,sha256=X-4FIAiwkUhduQlGp0nITAOCpFlU89339hmuLEfMz8A,4725
2
- sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
4
- sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
5
- sae_lens/cache_activations_runner.py,sha256=Lvlz-k5-3XxVRtUdC4b1CiKyx5s0ckLa8GDGv9_kcxs,12566
6
- sae_lens/config.py,sha256=C982bUELhGHcfTwzeMTtXIf2hPtc946thYpUyctLiBo,30516
7
- sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
8
- sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
9
- sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
10
- sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
11
- sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- sae_lens/loading/pretrained_sae_loaders.py,sha256=hq-dhxsEdUmlAnZEiZBqX7lNyQQwZ6KXmXZWpzAc5FY,63638
13
- sae_lens/loading/pretrained_saes_directory.py,sha256=hejNfLUepYCSGPalRfQwxxCEUqMMUPsn1tufwvwct5k,3820
14
- sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
15
- sae_lens/pretrained_saes.yaml,sha256=Hy9mk4Liy50B0CIBD4ER1ETcho2drFFiIy-bPVCN_lc,1510210
16
- sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
- sae_lens/saes/__init__.py,sha256=SBqPaP6Gl5uPFwHlumAZATC4Wd26xKIYLAAAo4MSa5Q,2200
18
- sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
19
- sae_lens/saes/gated_sae.py,sha256=mHnmw-RD7hqIbP9_EBj3p2SK0OqQIkZivdOKRygeRgw,8825
20
- sae_lens/saes/jumprelu_sae.py,sha256=udjGHp3WTABQSL2Qq57j-bINWX61GCmo68EmdjMOXoo,13310
21
- sae_lens/saes/matching_pursuit_sae.py,sha256=08_G9p1YMLnE5qZVCPp6gll-iG6nHRbMMASf4_bkFt8,13207
22
- sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
23
- sae_lens/saes/sae.py,sha256=fzXv8lwHskSxsf8hm_wlKPkpq50iafmBjBNQzwZ6a00,40050
24
- sae_lens/saes/standard_sae.py,sha256=nEVETwAmRD2tyX7ESIic1fij48gAq1Dh7s_GQ2fqCZ4,5747
25
- sae_lens/saes/temporal_sae.py,sha256=DsecivcHWId-MTuJpQbz8OhqtmGhZACxJauYZGHo0Ok,13272
26
- sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
27
- sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
28
- sae_lens/tokenization_and_batching.py,sha256=D_o7cXvRqhT89H3wNzoRymNALNE6eHojBWLdXOUwUGE,5438
29
- sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
- sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
31
- sae_lens/training/activations_store.py,sha256=rQadexm2BiwK7_MZIPlRkcKSqabi3iuOTC-R8aJchS8,33778
32
- sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
33
- sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
34
- sae_lens/training/sae_trainer.py,sha256=zhkabyIKxI_tZTV3_kwz6zMrHZ95Ecr97krmwc-9ffs,17600
35
- sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
36
- sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
37
- sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
38
- sae_lens/util.py,sha256=spkcmQUsjVYFn5H2032nQYr1CKGVnv3tAdfIpY59-Mg,3919
39
- sae_lens-6.26.0.dist-info/METADATA,sha256=tN-oGNK-u9iasBLcJIGXjKxuyM7hDDG8P3KeS7264uU,5361
40
- sae_lens-6.26.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
41
- sae_lens-6.26.0.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
42
- sae_lens-6.26.0.dist-info/RECORD,,