sae-lens 6.26.2__py3-none-any.whl → 6.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.26.2"
2
+ __version__ = "6.27.0"
3
3
 
4
4
  import logging
5
5
 
sae_lens/config.py CHANGED
@@ -148,6 +148,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
148
148
  seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
149
149
  disable_concat_sequences (bool): Whether to disable concatenating sequences and ignore sequences shorter than the context size. If True, disables concatenating and ignores short sequences.
150
150
  sequence_separator_token (int | Literal["bos", "eos", "sep"] | None): If not `None`, this token will be placed between sentences in a batch to act as a separator. By default, this is the `<bos>` token.
151
+ activations_mixing_fraction (float): Fraction of the activation buffer to keep for mixing with new activations (default 0.5). Higher values mean more temporal shuffling but slower throughput. If 0, activations are served in order without shuffling (no temporal mixing).
151
152
  device (str): The device to use. Usually "cuda".
152
153
  act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
153
154
  seed (int): The seed to use.
@@ -217,6 +218,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
217
218
  sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
218
219
  special_token_field(default="bos")
219
220
  )
221
+ activations_mixing_fraction: float = 0.5
220
222
 
221
223
  # Misc
222
224
  device: str = "cpu"
@@ -148,6 +148,7 @@ class ActivationsStore:
148
148
  exclude_special_tokens=exclude_special_tokens,
149
149
  disable_concat_sequences=cfg.disable_concat_sequences,
150
150
  sequence_separator_token=cfg.sequence_separator_token,
151
+ activations_mixing_fraction=cfg.activations_mixing_fraction,
151
152
  )
152
153
 
153
154
  @classmethod
@@ -222,6 +223,7 @@ class ActivationsStore:
222
223
  exclude_special_tokens: torch.Tensor | None = None,
223
224
  disable_concat_sequences: bool = False,
224
225
  sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
226
+ activations_mixing_fraction: float = 0.5,
225
227
  ):
226
228
  self.model = model
227
229
  if model_kwargs is None:
@@ -269,6 +271,7 @@ class ActivationsStore:
269
271
  self.sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
270
272
  sequence_separator_token
271
273
  )
274
+ self.activations_mixing_fraction = activations_mixing_fraction
272
275
 
273
276
  self.n_dataset_processed = 0
274
277
 
@@ -708,6 +711,7 @@ class ActivationsStore:
708
711
  buffer_size=self.n_batches_in_buffer * self.training_context_size,
709
712
  batch_size=self.train_batch_size_tokens,
710
713
  activations_loader=self._iterate_filtered_activations(),
714
+ mix_fraction=self.activations_mixing_fraction,
711
715
  )
712
716
 
713
717
  def next_batch(self) -> torch.Tensor:
@@ -8,15 +8,19 @@ def mixing_buffer(
8
8
  buffer_size: int,
9
9
  batch_size: int,
10
10
  activations_loader: Iterator[torch.Tensor],
11
+ mix_fraction: float = 0.5,
11
12
  ) -> Iterator[torch.Tensor]:
12
13
  """
13
14
  A generator that maintains a mix of old and new activations for better training.
14
- It stores half of the activations and mixes them with new ones to create batches.
15
+ It keeps a portion of activations and mixes them with new ones to create batches.
15
16
 
16
17
  Args:
17
- buffer_size: Total size of the buffer (will store buffer_size/2 activations)
18
+ buffer_size: Total size of the buffer
18
19
  batch_size: Size of batches to return
19
20
  activations_loader: Iterator providing new activations
21
+ mix_fraction: Fraction of buffer to keep for mixing (default 0.5).
22
+ Higher values mean more temporal mixing but slower throughput.
23
+ If 0, no shuffling occurs (passthrough mode).
20
24
 
21
25
  Yields:
22
26
  Batches of activations of shape (batch_size, *activation_dims)
@@ -24,6 +28,8 @@ def mixing_buffer(
24
28
 
25
29
  if buffer_size < batch_size:
26
30
  raise ValueError("Buffer size must be greater than or equal to batch size")
31
+ if not 0 <= mix_fraction <= 1:
32
+ raise ValueError("mix_fraction must be in [0, 1]")
27
33
 
28
34
  storage_buffer: torch.Tensor | None = None
29
35
 
@@ -35,10 +41,12 @@ def mixing_buffer(
35
41
  )
36
42
 
37
43
  if storage_buffer.shape[0] >= buffer_size:
38
- # Shuffle
39
- storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
44
+ if mix_fraction > 0:
45
+ storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
40
46
 
41
- num_serving_batches = max(1, storage_buffer.shape[0] // (2 * batch_size))
47
+ num_serving_batches = max(
48
+ 1, int(storage_buffer.shape[0] * (1 - mix_fraction)) // batch_size
49
+ )
42
50
  serving_cutoff = num_serving_batches * batch_size
43
51
  serving_buffer = storage_buffer[:serving_cutoff]
44
52
  storage_buffer = storage_buffer[serving_cutoff:]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.26.2
3
+ Version: 6.27.0
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -1,9 +1,9 @@
1
- sae_lens/__init__.py,sha256=8muF12kzUe8sePiovnUMEXCu1OcotIVw-VvDjGEK2Zw,4725
1
+ sae_lens/__init__.py,sha256=379YK4TU5y4Gl_sjF9JG5b7c_ywo3PjcY37e3EW2IyA,4725
2
2
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
4
4
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
5
5
  sae_lens/cache_activations_runner.py,sha256=Lvlz-k5-3XxVRtUdC4b1CiKyx5s0ckLa8GDGv9_kcxs,12566
6
- sae_lens/config.py,sha256=C982bUELhGHcfTwzeMTtXIf2hPtc946thYpUyctLiBo,30516
6
+ sae_lens/config.py,sha256=sseYcRMsAyopj8FICup1RGTXjFxzAithZ2OH7OpQV3Y,30839
7
7
  sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
8
8
  sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
9
9
  sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
@@ -28,15 +28,15 @@ sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,1
28
28
  sae_lens/tokenization_and_batching.py,sha256=D_o7cXvRqhT89H3wNzoRymNALNE6eHojBWLdXOUwUGE,5438
29
29
  sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
30
  sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
31
- sae_lens/training/activations_store.py,sha256=rQadexm2BiwK7_MZIPlRkcKSqabi3iuOTC-R8aJchS8,33778
32
- sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
31
+ sae_lens/training/activations_store.py,sha256=2BVajHRcozKQFf1tkeraUCdFuut3spdk0hhgtdpizzI,34031
32
+ sae_lens/training/mixing_buffer.py,sha256=DK22yPwEop4suG0K-8XFw5ZGNl0JrgCEjypmKEUAaGY,2394
33
33
  sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
34
34
  sae_lens/training/sae_trainer.py,sha256=zhkabyIKxI_tZTV3_kwz6zMrHZ95Ecr97krmwc-9ffs,17600
35
35
  sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
36
36
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
37
37
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
38
38
  sae_lens/util.py,sha256=spkcmQUsjVYFn5H2032nQYr1CKGVnv3tAdfIpY59-Mg,3919
39
- sae_lens-6.26.2.dist-info/METADATA,sha256=TPTLR3wKbPcGOsJ9P5hxVSQu-O6JIioFxoXUHP4Tj2w,5361
40
- sae_lens-6.26.2.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
41
- sae_lens-6.26.2.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
42
- sae_lens-6.26.2.dist-info/RECORD,,
39
+ sae_lens-6.27.0.dist-info/METADATA,sha256=S3GYpJhfYx05i-ZfX8rpwbiR1IFDlFAR0nSURgJQmJk,5361
40
+ sae_lens-6.27.0.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
41
+ sae_lens-6.27.0.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
42
+ sae_lens-6.27.0.dist-info/RECORD,,