sae-lens 6.26.1__tar.gz → 6.28.0__tar.gz
This diff shows the changes between the two package versions as they appear in the public registry.
- {sae_lens-6.26.1 → sae_lens-6.28.0}/PKG-INFO +3 -1
- {sae_lens-6.26.1 → sae_lens-6.28.0}/README.md +2 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/pyproject.toml +2 -1
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/__init__.py +3 -1
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/cache_activations_runner.py +12 -5
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/config.py +2 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/loading/pretrained_sae_loaders.py +2 -1
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/gated_sae.py +1 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/jumprelu_sae.py +3 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/standard_sae.py +2 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/temporal_sae.py +1 -0
- sae_lens-6.28.0/sae_lens/synthetic/__init__.py +89 -0
- sae_lens-6.28.0/sae_lens/synthetic/activation_generator.py +215 -0
- sae_lens-6.28.0/sae_lens/synthetic/correlation.py +170 -0
- sae_lens-6.28.0/sae_lens/synthetic/evals.py +141 -0
- sae_lens-6.28.0/sae_lens/synthetic/feature_dictionary.py +138 -0
- sae_lens-6.28.0/sae_lens/synthetic/firing_probabilities.py +104 -0
- sae_lens-6.28.0/sae_lens/synthetic/hierarchy.py +335 -0
- sae_lens-6.28.0/sae_lens/synthetic/initialization.py +40 -0
- sae_lens-6.28.0/sae_lens/synthetic/plotting.py +230 -0
- sae_lens-6.28.0/sae_lens/synthetic/training.py +145 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/activations_store.py +51 -91
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/mixing_buffer.py +14 -5
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/sae_trainer.py +1 -1
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/util.py +26 -1
- {sae_lens-6.26.1 → sae_lens-6.28.0}/LICENSE +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/matching_pursuit_sae.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/tokenization_and_batching.py +1 -1
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/tutorial/tsea.py +0 -0

{sae_lens-6.26.1 → sae_lens-6.28.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.26.1
+Version: 6.28.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -77,6 +77,8 @@ The new v6 update is a major refactor to SAELens and changes the way training co
   [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
 - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
   [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
+- [Training SAEs on Synthetic Data](tutorials/training_saes_on_synthetic_data.ipynb)
+  [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_saes_on_synthetic_data.ipynb)
 
 ## Join the Slack!
 

{sae_lens-6.26.1 → sae_lens-6.28.0}/README.md
@@ -41,6 +41,8 @@ The new v6 update is a major refactor to SAELens and changes the way training co
   [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
 - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
   [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
+- [Training SAEs on Synthetic Data](tutorials/training_saes_on_synthetic_data.ipynb)
+  [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_saes_on_synthetic_data.ipynb)
 
 ## Join the Slack!
 

{sae_lens-6.26.1 → sae_lens-6.28.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.26.1"
+version = "6.28.0"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
@@ -58,6 +58,7 @@ eai-sparsify = "^1.1.1"
 mike = "^2.0.0"
 trio = "^0.30.0"
 dictionary-learning = "^0.1.0"
+kaleido = "^1.2.0"
 
 [tool.poetry.extras]
 mamba = ["mamba-lens"]

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/__init__.py
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.26.1"
+__version__ = "6.28.0"
 
 import logging
 
@@ -63,6 +63,7 @@ from .loading.pretrained_sae_loaders import (
 from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
 from .registry import register_sae_class, register_sae_training_class
 from .training.activations_store import ActivationsStore
+from .training.sae_trainer import SAETrainer
 from .training.upload_saes_to_huggingface import upload_saes_to_huggingface
 
 __all__ = [
@@ -102,6 +103,7 @@ __all__ = [
     "JumpReLUTrainingSAE",
     "JumpReLUTrainingSAEConfig",
     "SAETrainingRunner",
+    "SAETrainer",
     "LoggingConfig",
     "BatchTopKTrainingSAE",
     "BatchTopKTrainingSAEConfig",
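The practical effect of the hunks above is that the trainer class is now importable from the package root, e.g.:

from sae_lens import SAETrainer  # new top-level export in 6.28.0
# the long-form path added above remains: sae_lens.training.sae_trainer.SAETrainer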

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/cache_activations_runner.py
@@ -263,14 +263,21 @@ class CacheActivationsRunner:
 
         for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
             try:
-
-
-                )
-
+                # Accumulate n_batches_in_buffer batches into one shard
+                buffers: list[tuple[torch.Tensor, torch.Tensor | None]] = []
+                for _ in range(self.cfg.n_batches_in_buffer):
+                    buffers.append(self.activations_store.get_raw_llm_batch())
+                # Concatenate all batches
+                acts = torch.cat([b[0] for b in buffers], dim=0)
+                token_ids: torch.Tensor | None = None
+                if buffers[0][1] is not None:
+                    # All batches have token_ids if the first one does
+                    token_ids = torch.cat([b[1] for b in buffers], dim=0)  # type: ignore[arg-type]
+                shard = self._create_shard((acts, token_ids))
                 shard.save_to_disk(
                     f"{tmp_cached_activation_path}/shard_{i:05d}", num_shards=1
                 )
-                del
+                del buffers, acts, token_ids, shard
             except StopIteration:
                 logger.warning(
                     f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.cfg.n_buffers} batches."
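The reworked loop above gathers several raw LLM batches and concatenates them into a single shard before writing. A self-contained sketch of that accumulate-then-concat step (illustrative only, not sae_lens internals; accumulate_batches is a hypothetical helper):

import torch

def accumulate_batches(
    batches: list[tuple[torch.Tensor, torch.Tensor | None]],
) -> tuple[torch.Tensor, torch.Tensor | None]:
    # Concatenate (activations, token_ids) pairs along the batch dimension,
    # mirroring the logic in the hunk above.
    acts = torch.cat([b[0] for b in batches], dim=0)
    token_ids = None
    if batches[0][1] is not None:
        token_ids = torch.cat([b[1] for b in batches], dim=0)  # type: ignore[arg-type]
    return acts, token_ids

# Two fake batches of 4 activations each, d_in = 8:
b1 = (torch.randn(4, 8), torch.arange(4))
b2 = (torch.randn(4, 8), torch.arange(4))
acts, token_ids = accumulate_batches([b1, b2])
assert acts.shape == (8, 8) and token_ids is not None and token_ids.shape == (8,)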

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/config.py
@@ -148,6 +148,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         seqpos_slice (tuple[int | None, ...]): Determines slicing of activations when constructing batches during training. The slice should be (start_pos, end_pos, optional[step_size]), e.g. for Othello we sometimes use (5, -5). Note, step_size > 0.
         disable_concat_sequences (bool): Whether to disable concatenating sequences and ignore sequences shorter than the context size. If True, disables concatenating and ignores short sequences.
         sequence_separator_token (int | Literal["bos", "eos", "sep"] | None): If not `None`, this token will be placed between sentences in a batch to act as a separator. By default, this is the `<bos>` token.
+        activations_mixing_fraction (float): Fraction of the activation buffer to keep for mixing with new activations (default 0.5). Higher values mean more temporal shuffling but slower throughput. If 0, activations are served in order without shuffling (no temporal mixing).
         device (str): The device to use. Usually "cuda".
         act_store_device (str): The device to use for the activation store. "cpu" is advised in order to save VRAM. Defaults to "with_model" which uses the same device as the main model.
         seed (int): The seed to use.
@@ -217,6 +218,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
         special_token_field(default="bos")
     )
+    activations_mixing_fraction: float = 0.5
 
     # Misc
     device: str = "cpu"
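A self-contained sketch of what the new activations_mixing_fraction controls, per the docstring above: a fraction of the buffer is held back and reshuffled with incoming activations (illustrative only, not the sae_lens mixing-buffer implementation):

import torch

def serve_with_mixing(
    new_acts: torch.Tensor, retained: torch.Tensor, fraction: float
) -> tuple[torch.Tensor, torch.Tensor]:
    # Pool the retained activations with the incoming batch, shuffle,
    # then hold back `fraction` of the pool for the next round.
    pool = torch.cat([retained, new_acts], dim=0)
    pool = pool[torch.randperm(pool.shape[0])]
    n_keep = int(pool.shape[0] * fraction)
    return pool[n_keep:], pool[:n_keep]  # (served, retained)

retained = torch.randn(64, 16)
served, retained = serve_with_mixing(torch.randn(64, 16), retained, fraction=0.5)
# fraction=0.0 would serve everything in arrival order: no temporal mixing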

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/loading/pretrained_sae_loaders.py
@@ -959,7 +959,7 @@ def get_dictionary_learning_config_1_from_hf(
     architecture = "standard"
     if trainer["dict_class"] == "GatedAutoEncoder":
         architecture = "gated"
-    elif trainer["dict_class"]
+    elif trainer["dict_class"] in ["MatryoshkaBatchTopKSAE", "BatchTopKSAE"]:
         architecture = "jumprelu"
 
     return {
@@ -1831,6 +1831,7 @@ def temporal_sae_huggingface_loader(
     Load TemporalSAE from canrager/temporalSAEs format (safetensors version).
 
     Expects folder_name to contain:
+
     - conf.yaml (configuration)
     - latest_ckpt.safetensors (model weights)
     """

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/gated_sae.py
@@ -118,6 +118,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     """
     GatedTrainingSAE is a concrete implementation of BaseTrainingSAE for the "gated" SAE architecture.
     It implements:
+
     - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
     - encode: calls encode_with_hidden_pre (standard training approach).
     - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/jumprelu_sae.py
@@ -105,6 +105,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
     activation function (e.g., ReLU etc.).
 
     It implements:
+
     - initialize_weights: sets up parameters, including a threshold.
     - encode: computes the feature activations using JumpReLU.
     - decode: reconstructs the input from the feature activations.
@@ -216,10 +217,12 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
     JumpReLUTrainingSAE is a training-focused implementation of a SAE using a JumpReLU activation.
 
     Similar to the inference-only JumpReLUSAE, but with:
+
     - A learnable log-threshold parameter (instead of a raw threshold).
     - A specialized auxiliary loss term for sparsity (L0 or similar).
 
     Methods of interest include:
+
     - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
     - encode_with_hidden_pre_jumprelu: runs a forward pass for training.
     - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.
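For reference, the JumpReLU nonlinearity with a learnable log-threshold mentioned in the docstring above, as a hedged sketch (not the library's exact forward pass):

import torch

log_threshold = torch.nn.Parameter(torch.zeros(4))  # one threshold per latent
pre_acts = torch.tensor([[-0.5, 0.2, 1.5, 3.0]])
threshold = log_threshold.exp()                      # exp keeps the threshold positive
acts = pre_acts * (pre_acts > threshold)             # pass through above threshold, zero below
# acts == [[0.0, 0.0, 1.5, 3.0]] here, since threshold = exp(0) = 1.0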

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/standard_sae.py
@@ -34,6 +34,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
     using a simple linear encoder and decoder.
 
     It implements the required abstract methods from BaseSAE:
+
     - initialize_weights: sets up simple parameter initializations for W_enc, b_enc, W_dec, and b_dec.
     - encode: computes the feature activations from an input.
     - decode: reconstructs the input from the feature activations.
@@ -99,6 +100,7 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
     """
     StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture.
     It implements:
+
     - initialize_weights: basic weight initialization for encoder/decoder.
     - encode: inference encoding (invokes encode_with_hidden_pre).
     - decode: a simple linear decoder.

{sae_lens-6.26.1 → sae_lens-6.28.0}/sae_lens/saes/temporal_sae.py
@@ -167,6 +167,7 @@ class TemporalSAE(SAE[TemporalSAEConfig]):
     """TemporalSAE: Sparse Autoencoder with temporal attention.
 
     This SAE decomposes each activation x_t into:
+
     - x_pred: Information aggregated from context {x_0, ..., x_{t-1}}
     - x_novel: Novel information at position t (encoded sparsely)
 

sae_lens-6.28.0/sae_lens/synthetic/__init__.py
@@ -0,0 +1,89 @@
+"""
+Synthetic data utilities for SAE experiments.
+
+This module provides tools for creating feature dictionaries and generating
+synthetic activations for testing and experimenting with SAEs.
+
+Main components:
+
+- FeatureDictionary: Maps sparse feature activations to dense hidden activations
+- ActivationGenerator: Generates batches of synthetic feature activations
+- HierarchyNode: Enforces hierarchical structure on feature activations
+- Training utilities: Helpers for training and evaluating SAEs on synthetic data
+- Plotting utilities: Visualization helpers for understanding SAE behavior
+"""
+
+from sae_lens.synthetic.activation_generator import (
+    ActivationGenerator,
+    ActivationsModifier,
+    ActivationsModifierInput,
+)
+from sae_lens.synthetic.correlation import (
+    create_correlation_matrix_from_correlations,
+    generate_random_correlation_matrix,
+    generate_random_correlations,
+)
+from sae_lens.synthetic.evals import (
+    SyntheticDataEvalResult,
+    eval_sae_on_synthetic_data,
+    mean_correlation_coefficient,
+)
+from sae_lens.synthetic.feature_dictionary import (
+    FeatureDictionary,
+    FeatureDictionaryInitializer,
+    orthogonal_initializer,
+    orthogonalize_embeddings,
+)
+from sae_lens.synthetic.firing_probabilities import (
+    linear_firing_probabilities,
+    random_firing_probabilities,
+    zipfian_firing_probabilities,
+)
+from sae_lens.synthetic.hierarchy import HierarchyNode, hierarchy_modifier
+from sae_lens.synthetic.initialization import init_sae_to_match_feature_dict
+from sae_lens.synthetic.plotting import (
+    find_best_feature_ordering,
+    find_best_feature_ordering_across_saes,
+    find_best_feature_ordering_from_sae,
+    plot_sae_feature_similarity,
+)
+from sae_lens.synthetic.training import (
+    SyntheticActivationIterator,
+    train_toy_sae,
+)
+from sae_lens.util import cosine_similarities
+
+__all__ = [
+    # Main classes
+    "FeatureDictionary",
+    "HierarchyNode",
+    "hierarchy_modifier",
+    "ActivationGenerator",
+    # Activation generation
+    "zipfian_firing_probabilities",
+    "linear_firing_probabilities",
+    "random_firing_probabilities",
+    "create_correlation_matrix_from_correlations",
+    "generate_random_correlations",
+    "generate_random_correlation_matrix",
+    # Feature modifiers
+    "ActivationsModifier",
+    "ActivationsModifierInput",
+    # Utilities
+    "orthogonalize_embeddings",
+    "orthogonal_initializer",
+    "FeatureDictionaryInitializer",
+    "cosine_similarities",
+    # Training utilities
+    "SyntheticActivationIterator",
+    "SyntheticDataEvalResult",
+    "train_toy_sae",
+    "eval_sae_on_synthetic_data",
+    "mean_correlation_coefficient",
+    "init_sae_to_match_feature_dict",
+    # Plotting utilities
+    "find_best_feature_ordering",
+    "find_best_feature_ordering_from_sae",
+    "find_best_feature_ordering_across_saes",
+    "plot_sae_feature_similarity",
+]
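Everything in __all__ above is importable directly from the new subpackage, e.g.:

from sae_lens.synthetic import (
    ActivationGenerator,
    FeatureDictionary,
    generate_random_correlation_matrix,
)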

sae_lens-6.28.0/sae_lens/synthetic/activation_generator.py
@@ -0,0 +1,215 @@
+"""
+Functions for generating synthetic feature activations.
+"""
+
+from collections.abc import Callable, Sequence
+
+import torch
+from scipy.stats import norm
+from torch import nn
+from torch.distributions import MultivariateNormal
+
+from sae_lens.util import str_to_dtype
+
+ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
+ActivationsModifierInput = ActivationsModifier | Sequence[ActivationsModifier] | None
+
+
+class ActivationGenerator(nn.Module):
+    """
+    Generator for synthetic feature activations.
+
+    This module provides a generator for synthetic feature activations with controlled properties.
+    """
+
+    num_features: int
+    firing_probabilities: torch.Tensor
+    std_firing_magnitudes: torch.Tensor
+    mean_firing_magnitudes: torch.Tensor
+    modify_activations: ActivationsModifier | None
+    correlation_matrix: torch.Tensor | None
+    correlation_thresholds: torch.Tensor | None
+
+    def __init__(
+        self,
+        num_features: int,
+        firing_probabilities: torch.Tensor | float,
+        std_firing_magnitudes: torch.Tensor | float = 0.0,
+        mean_firing_magnitudes: torch.Tensor | float = 1.0,
+        modify_activations: ActivationsModifierInput = None,
+        correlation_matrix: torch.Tensor | None = None,
+        device: torch.device | str = "cpu",
+        dtype: torch.dtype | str = "float32",
+    ):
+        super().__init__()
+        self.num_features = num_features
+        self.firing_probabilities = _to_tensor(
+            firing_probabilities, num_features, device, dtype
+        )
+        self.std_firing_magnitudes = _to_tensor(
+            std_firing_magnitudes, num_features, device, dtype
+        )
+        self.mean_firing_magnitudes = _to_tensor(
+            mean_firing_magnitudes, num_features, device, dtype
+        )
+        self.modify_activations = _normalize_modifiers(modify_activations)
+        self.correlation_thresholds = None
+        if correlation_matrix is not None:
+            _validate_correlation_matrix(correlation_matrix, num_features)
+            self.correlation_thresholds = torch.tensor(
+                [norm.ppf(1 - p.item()) for p in self.firing_probabilities],
+                device=device,
+                dtype=self.firing_probabilities.dtype,
+            )
+        self.correlation_matrix = correlation_matrix
+
+    def sample(self, batch_size: int) -> torch.Tensor:
+        """
+        Generate a batch of feature activations with controlled properties.
+
+        This is the main function for generating synthetic training data for SAEs.
+        Features fire independently according to their firing probabilities unless
+        a correlation matrix is provided.
+
+        Args:
+            batch_size: Number of samples to generate
+
+        Returns:
+            Tensor of shape [batch_size, num_features] with non-negative activations
+        """
+        # All tensors (firing_probabilities, std_firing_magnitudes, mean_firing_magnitudes)
+        # are on the same device from __init__ via _to_tensor()
+        device = self.firing_probabilities.device
+
+        if self.correlation_matrix is not None:
+            assert self.correlation_thresholds is not None
+            firing_features = _generate_correlated_features(
+                batch_size,
+                self.correlation_matrix,
+                self.correlation_thresholds,
+                device,
+            )
+        else:
+            firing_features = torch.bernoulli(
+                self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
+            )
+
+        firing_magnitude_delta = torch.normal(
+            torch.zeros_like(self.firing_probabilities)
+            .unsqueeze(0)
+            .expand(batch_size, -1),
+            self.std_firing_magnitudes.unsqueeze(0).expand(batch_size, -1),
+        )
+        firing_magnitude_delta[firing_features == 0] = 0
+        feature_activations = (
+            firing_features * self.mean_firing_magnitudes + firing_magnitude_delta
+        ).relu()
+
+        if self.modify_activations is not None:
+            feature_activations = self.modify_activations(feature_activations).relu()
+        return feature_activations
+
+    def forward(self, batch_size: int) -> torch.Tensor:
+        return self.sample(batch_size)
+
+
+def _generate_correlated_features(
+    batch_size: int,
+    correlation_matrix: torch.Tensor,
+    thresholds: torch.Tensor,
+    device: torch.device,
+) -> torch.Tensor:
+    """
+    Generate correlated binary features using multivariate Gaussian sampling.
+
+    Uses the Gaussian copula approach: sample from a multivariate normal
+    distribution, then threshold to get binary features.
+
+    Args:
+        batch_size: Number of samples to generate
+        correlation_matrix: Correlation matrix between features
+        thresholds: Pre-computed thresholds for each feature (from inverse normal CDF)
+        device: Device to generate samples on
+
+    Returns:
+        Binary feature matrix of shape [batch_size, num_features]
+    """
+    num_features = correlation_matrix.shape[0]
+
+    mvn = MultivariateNormal(
+        loc=torch.zeros(num_features, device=device, dtype=thresholds.dtype),
+        covariance_matrix=correlation_matrix.to(device=device, dtype=thresholds.dtype),
+    )
+
+    gaussian_samples = mvn.sample((batch_size,))
+    return (gaussian_samples > thresholds.unsqueeze(0)).float()
+
+
+def _to_tensor(
+    value: torch.Tensor | float,
+    num_features: int,
+    device: torch.device | str,
+    dtype: torch.dtype | str,
+) -> torch.Tensor:
+    dtype = str_to_dtype(dtype)
+    device = torch.device(device)
+    if not isinstance(value, torch.Tensor):
+        value = value * torch.ones(num_features, device=device, dtype=dtype)
+    if value.shape != (num_features,):
+        raise ValueError(
+            f"Value must be a tensor of shape ({num_features},) or a float. Got {value.shape}"
+        )
+    return value.to(device, dtype)
+
+
+def _normalize_modifiers(
+    modify_activations: ActivationsModifierInput,
+) -> ActivationsModifier | None:
+    """Convert modifier input to a single modifier or None."""
+    if modify_activations is None:
+        return None
+    if callable(modify_activations):
+        return modify_activations
+    # It's a sequence of modifiers - chain them
+    modifiers = list(modify_activations)
+    if len(modifiers) == 0:
+        return None
+    if len(modifiers) == 1:
+        return modifiers[0]
+
+    def chained(activations: torch.Tensor) -> torch.Tensor:
+        result = activations
+        for modifier in modifiers:
+            result = modifier(result)
+        return result
+
+    return chained
+
+
+def _validate_correlation_matrix(
+    correlation_matrix: torch.Tensor, num_features: int
+) -> None:
+    """Validate that a correlation matrix has correct properties.
+
+    Args:
+        correlation_matrix: The matrix to validate
+        num_features: Expected number of features (matrix should be [num_features, num_features])
+
+    Raises:
+        ValueError: If the matrix has incorrect shape, non-unit diagonal, or is not positive definite
+    """
+    expected_shape = (num_features, num_features)
+    if correlation_matrix.shape != expected_shape:
+        raise ValueError(
+            f"Correlation matrix must have shape {expected_shape}, "
+            f"got {tuple(correlation_matrix.shape)}"
+        )
+
+    diagonal = torch.diag(correlation_matrix)
+    if not torch.allclose(diagonal, torch.ones_like(diagonal)):
+        raise ValueError("Correlation matrix diagonal must be all 1s")
+
+    try:
+        torch.linalg.cholesky(correlation_matrix)
+    except RuntimeError as e:
+        raise ValueError("Correlation matrix must be positive definite") from e
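A minimal usage sketch based on the constructor and sample() signatures above (the numbers are illustrative):

import torch
from sae_lens.synthetic import ActivationGenerator

gen = ActivationGenerator(
    num_features=4,
    firing_probabilities=0.25,     # each feature fires ~25% of the time
    mean_firing_magnitudes=1.0,
    std_firing_magnitudes=0.1,
)
acts = gen.sample(batch_size=1000)      # shape [1000, 4], non-negative
print(acts.gt(0).float().mean(dim=0))   # empirical firing rates, ~0.25 each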

sae_lens-6.28.0/sae_lens/synthetic/correlation.py
@@ -0,0 +1,170 @@
+import random
+
+import torch
+
+
+def create_correlation_matrix_from_correlations(
+    num_features: int,
+    correlations: dict[tuple[int, int], float] | None = None,
+    default_correlation: float = 0.0,
+) -> torch.Tensor:
+    """
+    Create a correlation matrix with specified pairwise correlations.
+
+    Args:
+        num_features: Number of features
+        correlations: Dict mapping (i, j) pairs to correlation values.
+            Pairs should have i < j.
+        default_correlation: Default correlation for unspecified pairs
+
+    Returns:
+        Correlation matrix of shape [num_features, num_features]
+    """
+    matrix = torch.eye(num_features) + default_correlation * (
+        1 - torch.eye(num_features)
+    )
+
+    if correlations is not None:
+        for (i, j), corr in correlations.items():
+            matrix[i, j] = corr
+            matrix[j, i] = corr
+
+    # Ensure matrix is symmetric (numerical precision)
+    matrix = (matrix + matrix.T) / 2
+
+    # Check positive definiteness and fix if necessary
+    # Use eigvalsh for symmetric matrices (returns real eigenvalues)
+    eigenvals = torch.linalg.eigvalsh(matrix)
+    if torch.any(eigenvals < -1e-6):
+        matrix = _fix_correlation_matrix(matrix)
+
+    return matrix
+
+
+def _fix_correlation_matrix(
+    matrix: torch.Tensor, min_eigenval: float = 1e-6
+) -> torch.Tensor:
+    """Fix a correlation matrix to be positive semi-definite."""
+    eigenvals, eigenvecs = torch.linalg.eigh(matrix)
+    eigenvals = torch.clamp(eigenvals, min=min_eigenval)
+    fixed_matrix = eigenvecs @ torch.diag(eigenvals) @ eigenvecs.T
+
+    diag_vals = torch.diag(fixed_matrix)
+    fixed_matrix = fixed_matrix / torch.sqrt(
+        diag_vals.unsqueeze(0) * diag_vals.unsqueeze(1)
+    )
+    fixed_matrix.fill_diagonal_(1.0)
+
+    return fixed_matrix
+
+
+def generate_random_correlations(
+    num_features: int,
+    positive_ratio: float = 0.5,
+    uncorrelated_ratio: float = 0.3,
+    min_correlation_strength: float = 0.1,
+    max_correlation_strength: float = 0.8,
+    seed: int | None = None,
+) -> dict[tuple[int, int], float]:
+    """
+    Generate random correlations between features with specified constraints.
+
+    Args:
+        num_features: Number of features
+        positive_ratio: Fraction of correlations that should be positive (0.0 to 1.0)
+        uncorrelated_ratio: Fraction of feature pairs that should remain uncorrelated (0.0 to 1.0)
+        min_correlation_strength: Minimum absolute correlation strength
+        max_correlation_strength: Maximum absolute correlation strength
+        seed: Random seed for reproducibility
+
+    Returns:
+        Dictionary mapping (i, j) pairs to correlation values
+    """
+    # Use local random number generator to avoid side effects on global state
+    rng = random.Random(seed)
+
+    # Validate inputs
+    if not 0.0 <= positive_ratio <= 1.0:
+        raise ValueError("positive_ratio must be between 0.0 and 1.0")
+    if not 0.0 <= uncorrelated_ratio <= 1.0:
+        raise ValueError("uncorrelated_ratio must be between 0.0 and 1.0")
+    if min_correlation_strength < 0:
+        raise ValueError("min_correlation_strength must be non-negative")
+    if max_correlation_strength > 1.0:
+        raise ValueError("max_correlation_strength must be <= 1.0")
+    if min_correlation_strength > max_correlation_strength:
+        raise ValueError("min_correlation_strength must be <= max_correlation_strength")
+
+    # Generate all possible feature pairs (i, j) where i < j
+    all_pairs = [
+        (i, j) for i in range(num_features) for j in range(i + 1, num_features)
+    ]
+    total_pairs = len(all_pairs)
+
+    if total_pairs == 0:
+        return {}
+
+    # Determine how many pairs to correlate vs leave uncorrelated
+    num_uncorrelated = int(total_pairs * uncorrelated_ratio)
+    num_correlated = total_pairs - num_uncorrelated
+
+    # Randomly select which pairs to correlate
+    correlated_pairs = rng.sample(all_pairs, num_correlated)
+
+    # For correlated pairs, determine positive vs negative
+    num_positive = int(num_correlated * positive_ratio)
+    num_negative = num_correlated - num_positive
+
+    # Assign signs
+    signs = [1] * num_positive + [-1] * num_negative
+    rng.shuffle(signs)
+
+    # Generate correlation strengths
+    correlations = {}
+    for pair, sign in zip(correlated_pairs, signs):
+        # Sample correlation strength uniformly from range
+        strength = rng.uniform(min_correlation_strength, max_correlation_strength)
+        correlations[pair] = sign * strength
+
+    return correlations
+
+
+def generate_random_correlation_matrix(
+    num_features: int,
+    positive_ratio: float = 0.5,
+    uncorrelated_ratio: float = 0.3,
+    min_correlation_strength: float = 0.1,
+    max_correlation_strength: float = 0.8,
+    seed: int | None = None,
+) -> torch.Tensor:
+    """
+    Generate a random correlation matrix with specified constraints.
+
+    This is a convenience function that combines generate_random_correlations()
+    and create_correlation_matrix_from_correlations() into a single call.
+
+    Args:
+        num_features: Number of features
+        positive_ratio: Fraction of correlations that should be positive (0.0 to 1.0)
+        uncorrelated_ratio: Fraction of feature pairs that should remain uncorrelated (0.0 to 1.0)
+        min_correlation_strength: Minimum absolute correlation strength
+        max_correlation_strength: Maximum absolute correlation strength
+        seed: Random seed for reproducibility
+
+    Returns:
+        Random correlation matrix of shape [num_features, num_features]
+    """
+    # Generate random correlations
+    correlations = generate_random_correlations(
+        num_features=num_features,
+        positive_ratio=positive_ratio,
+        uncorrelated_ratio=uncorrelated_ratio,
+        min_correlation_strength=min_correlation_strength,
+        max_correlation_strength=max_correlation_strength,
+        seed=seed,
+    )
+
+    # Create and return correlation matrix
+    return create_correlation_matrix_from_correlations(
+        num_features=num_features, correlations=correlations
+    )