sae-lens 6.26.1__tar.gz → 6.27.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.26.1 → sae_lens-6.27.0}/PKG-INFO +1 -1
- {sae_lens-6.26.1 → sae_lens-6.27.0}/pyproject.toml +1 -1
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/__init__.py +1 -1
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/config.py +2 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/loading/pretrained_sae_loaders.py +1 -1
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/training/activations_store.py +4 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/training/mixing_buffer.py +13 -5
- {sae_lens-6.26.1 → sae_lens-6.27.0}/LICENSE +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/README.md +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/saes/matching_pursuit_sae.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/saes/temporal_sae.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.27.0}/sae_lens/util.py +0 -0
--- sae_lens-6.26.1/sae_lens/config.py
+++ sae_lens-6.27.0/sae_lens/config.py
@@ -148,6 +148,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
         disable_concat_sequences (bool): Whether to disable concatenating sequences and ignore sequences shorter than the context size. If True, disables concatenating and ignores short sequences.
         sequence_separator_token (int | Literal["bos", "eos", "sep"] | None): If not `None`, this token will be placed between sentences in a batch to act as a separator. By default, this is the `<bos>` token.
+        activations_mixing_fraction (float): Fraction of the activation buffer to keep for mixing with new activations (default 0.5). Higher values mean more temporal shuffling but slower throughput. If 0, activations are served in order without shuffling (no temporal mixing).
         device (str): The device to use. Usually "cuda".
         act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
         seed (int): The seed to use.
@@ -217,6 +218,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
         special_token_field(default="bos")
     )
+    activations_mixing_fraction: float = 0.5
 
     # Misc
     device: str = "cpu"
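
The new `activations_mixing_fraction` field slots into the runner config like any other training option. A minimal, hedged sketch of overriding it is below; everything other than `activations_mixing_fraction` itself (the `sae` sub-config, `model_name`, `hook_name`, `dataset_path`) is an assumption about the surrounding sae_lens 6.x API rather than part of this diff.

```python
# Hedged sketch: only activations_mixing_fraction comes from this diff; the
# other constructor arguments are assumed from the wider sae_lens 6.x API
# and may need adjusting for a real run.
from sae_lens import LanguageModelSAERunnerConfig, StandardTrainingSAEConfig

cfg = LanguageModelSAERunnerConfig(
    sae=StandardTrainingSAEConfig(d_in=768, d_sae=16384),
    model_name="gpt2",
    hook_name="blocks.8.hook_resid_pre",
    dataset_path="monology/pile-uncopyrighted",
    activations_mixing_fraction=0.25,  # keep 25% of the buffer for temporal mixing
)
# 0.0 serves activations in collection order (no shuffling); values closer to
# 1.0 shuffle more aggressively at the cost of throughput.
```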
--- sae_lens-6.26.1/sae_lens/loading/pretrained_sae_loaders.py
+++ sae_lens-6.27.0/sae_lens/loading/pretrained_sae_loaders.py
@@ -959,7 +959,7 @@ def get_dictionary_learning_config_1_from_hf(
     architecture = "standard"
     if trainer["dict_class"] == "GatedAutoEncoder":
         architecture = "gated"
-    elif trainer["dict_class"]
+    elif trainer["dict_class"] in ["MatryoshkaBatchTopKSAE", "BatchTopKSAE"]:
         architecture = "jumprelu"
 
     return {
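
The loader change only widens the `dict_class` check so that plain `BatchTopKSAE` dictionaries resolve to the same architecture as Matryoshka ones. A small standalone sketch of the resulting mapping (using a hypothetical helper name, not the library's own function) behaves as follows:

```python
# Illustrative helper (hypothetical name); it mirrors the dict_class ->
# architecture branch shown in the hunk above.
def infer_architecture(dict_class: str) -> str:
    architecture = "standard"
    if dict_class == "GatedAutoEncoder":
        architecture = "gated"
    elif dict_class in ["MatryoshkaBatchTopKSAE", "BatchTopKSAE"]:
        architecture = "jumprelu"
    return architecture

assert infer_architecture("GatedAutoEncoder") == "gated"
assert infer_architecture("MatryoshkaBatchTopKSAE") == "jumprelu"
assert infer_architecture("BatchTopKSAE") == "jumprelu"  # newly covered in 6.27.0
assert infer_architecture("AutoEncoder") == "standard"
```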
--- sae_lens-6.26.1/sae_lens/training/activations_store.py
+++ sae_lens-6.27.0/sae_lens/training/activations_store.py
@@ -148,6 +148,7 @@ class ActivationsStore:
             exclude_special_tokens=exclude_special_tokens,
             disable_concat_sequences=cfg.disable_concat_sequences,
             sequence_separator_token=cfg.sequence_separator_token,
+            activations_mixing_fraction=cfg.activations_mixing_fraction,
         )
 
     @classmethod
@@ -222,6 +223,7 @@ class ActivationsStore:
         exclude_special_tokens: torch.Tensor | None = None,
         disable_concat_sequences: bool = False,
         sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
+        activations_mixing_fraction: float = 0.5,
     ):
         self.model = model
         if model_kwargs is None:
@@ -269,6 +271,7 @@ class ActivationsStore:
         self.sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
             sequence_separator_token
        )
+        self.activations_mixing_fraction = activations_mixing_fraction
 
         self.n_dataset_processed = 0
 
@@ -708,6 +711,7 @@ class ActivationsStore:
             buffer_size=self.n_batches_in_buffer * self.training_context_size,
             batch_size=self.train_batch_size_tokens,
             activations_loader=self._iterate_filtered_activations(),
+            mix_fraction=self.activations_mixing_fraction,
         )
 
     def next_batch(self) -> torch.Tensor:
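
The store simply threads the configured fraction through to `mixing_buffer` as `mix_fraction`. That call signature (see the mixing_buffer.py hunks below) can be exercised directly with a toy activation iterator; the `toy_activations` helper here is illustrative and not part of the library.

```python
# Hedged usage sketch of the updated mixing_buffer signature, fed by toy data.
import torch
from sae_lens.training.mixing_buffer import mixing_buffer

def toy_activations(n_chunks: int = 8, chunk: int = 64, d_in: int = 16):
    # stand-in for the store's internal filtered-activations iterator
    for _ in range(n_chunks):
        yield torch.randn(chunk, d_in)

batches = mixing_buffer(
    buffer_size=128,
    batch_size=32,
    activations_loader=toy_activations(),
    mix_fraction=0.5,  # the value ActivationsStore now forwards from its config
)
print(next(batches).shape)  # torch.Size([32, 16])
```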
--- sae_lens-6.26.1/sae_lens/training/mixing_buffer.py
+++ sae_lens-6.27.0/sae_lens/training/mixing_buffer.py
@@ -8,15 +8,19 @@ def mixing_buffer(
     buffer_size: int,
     batch_size: int,
     activations_loader: Iterator[torch.Tensor],
+    mix_fraction: float = 0.5,
 ) -> Iterator[torch.Tensor]:
     """
     A generator that maintains a mix of old and new activations for better training.
-    It
+    It keeps a portion of activations and mixes them with new ones to create batches.
 
     Args:
-        buffer_size: Total size of the buffer
+        buffer_size: Total size of the buffer
         batch_size: Size of batches to return
         activations_loader: Iterator providing new activations
+        mix_fraction: Fraction of buffer to keep for mixing (default 0.5).
+            Higher values mean more temporal mixing but slower throughput.
+            If 0, no shuffling occurs (passthrough mode).
 
     Yields:
         Batches of activations of shape (batch_size, *activation_dims)
@@ -24,6 +28,8 @@ def mixing_buffer(
 
     if buffer_size < batch_size:
         raise ValueError("Buffer size must be greater than or equal to batch size")
+    if not 0 <= mix_fraction <= 1:
+        raise ValueError("mix_fraction must be in [0, 1]")
 
     storage_buffer: torch.Tensor | None = None
 
@@ -35,10 +41,12 @@ def mixing_buffer(
         )
 
         if storage_buffer.shape[0] >= buffer_size:
-
-
+            if mix_fraction > 0:
+                storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
 
-            num_serving_batches = max(
+            num_serving_batches = max(
+                1, int(storage_buffer.shape[0] * (1 - mix_fraction)) // batch_size
+            )
             serving_cutoff = num_serving_batches * batch_size
             serving_buffer = storage_buffer[:serving_cutoff]
             storage_buffer = storage_buffer[serving_cutoff:]
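
The serving arithmetic above determines how much of a full buffer is emitted per refill versus retained for re-mixing with the next load of fresh activations. A worked example with illustrative numbers:

```python
# Worked example of the new serving arithmetic (values are illustrative only).
N, batch_size, mix_fraction = 4096, 1024, 0.5

num_serving_batches = max(1, int(N * (1 - mix_fraction)) // batch_size)  # -> 2
serving_cutoff = num_serving_batches * batch_size                        # -> 2048
retained_for_mixing = N - serving_cutoff                                 # -> 2048
print(num_serving_batches, serving_cutoff, retained_for_mixing)

# With mix_fraction = 0.0 the randperm shuffle is skipped and the whole buffer
# is served in order: max(1, int(4096 * 1.0) // 1024) == 4 batches, 0 retained.
```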