sae-lens 6.28.1.tar.gz → 6.29.1.tar.gz
- {sae_lens-6.28.1 → sae_lens-6.29.1}/PKG-INFO +11 -1
- {sae_lens-6.28.1 → sae_lens-6.29.1}/README.md +10 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/pyproject.toml +1 -1
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/__init__.py +1 -1
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/pretrained_saes.yaml +1 -1
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/synthetic/__init__.py +6 -0
- sae_lens-6.29.1/sae_lens/synthetic/activation_generator.py +388 -0
- sae_lens-6.29.1/sae_lens/synthetic/correlation.py +351 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/synthetic/feature_dictionary.py +64 -17
- sae_lens-6.29.1/sae_lens/synthetic/hierarchy.py +908 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/synthetic/training.py +16 -3
- sae_lens-6.28.1/sae_lens/synthetic/activation_generator.py +0 -215
- sae_lens-6.28.1/sae_lens/synthetic/correlation.py +0 -170
- sae_lens-6.28.1/sae_lens/synthetic/hierarchy.py +0 -335
- {sae_lens-6.28.1 → sae_lens-6.29.1}/LICENSE +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/config.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/constants.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/evals.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/load_model.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/registry.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/saes/matching_pursuit_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/saes/temporal_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/synthetic/evals.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/synthetic/firing_probabilities.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/synthetic/initialization.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/synthetic/plotting.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/training/types.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.1}/sae_lens/util.py +0 -0
```diff
--- a/PKG-INFO
+++ b/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.28.1
+Version: 6.29.1
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -50,6 +50,8 @@ SAELens exists to help researchers:
 - Analyse sparse autoencoders / research mechanistic interpretability.
 - Generate insights which make it easier to create safe and aligned AI systems.
 
+SAELens inference works with any PyTorch-based model, not just TransformerLens. While we provide deep integration with TransformerLens via `HookedSAETransformer`, SAEs can be used with Hugging Face Transformers, NNsight, or any other framework by extracting activations and passing them to the SAE's `encode()` and `decode()` methods.
+
 Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
 
 - Download and Analyse pre-trained sparse autoencoders.
@@ -84,6 +86,14 @@ The new v6 update is a major refactor to SAELens and changes the way training co
 
 Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
 
+## Other SAE Projects
+
+- [dictionary-learning](https://github.com/saprmarks/dictionary_learning): An SAE training library that focuses on having hackable code.
+- [Sparsify](https://github.com/EleutherAI/sparsify): A lean SAE training library focused on TopK SAEs.
+- [Overcomplete](https://github.com/KempnerInstitute/overcomplete): SAE training library focused on vision models.
+- [SAE-Vis](https://github.com/callummcdougall/sae_vis): A library for visualizing SAE features, works with SAELens.
+- [SAEBench](https://github.com/adamkarvonen/SAEBench): A suite of LLM SAE benchmarks, works with SAELens.
+
 ## Citation
 
 Please cite the package as follows:
```
```diff
--- a/README.md
+++ b/README.md
@@ -14,6 +14,8 @@ SAELens exists to help researchers:
 - Analyse sparse autoencoders / research mechanistic interpretability.
 - Generate insights which make it easier to create safe and aligned AI systems.
 
+SAELens inference works with any PyTorch-based model, not just TransformerLens. While we provide deep integration with TransformerLens via `HookedSAETransformer`, SAEs can be used with Hugging Face Transformers, NNsight, or any other framework by extracting activations and passing them to the SAE's `encode()` and `decode()` methods.
+
 Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
 
 - Download and Analyse pre-trained sparse autoencoders.
@@ -48,6 +50,14 @@ The new v6 update is a major refactor to SAELens and changes the way training co
 
 Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
 
+## Other SAE Projects
+
+- [dictionary-learning](https://github.com/saprmarks/dictionary_learning): An SAE training library that focuses on having hackable code.
+- [Sparsify](https://github.com/EleutherAI/sparsify): A lean SAE training library focused on TopK SAEs.
+- [Overcomplete](https://github.com/KempnerInstitute/overcomplete): SAE training library focused on vision models.
+- [SAE-Vis](https://github.com/callummcdougall/sae_vis): A library for visualizing SAE features, works with SAELens.
+- [SAEBench](https://github.com/adamkarvonen/SAEBench): A suite of LLM SAE benchmarks, works with SAELens.
+
 ## Citation
 
 Please cite the package as follows:
```
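The README paragraph added above is easy to make concrete with a short sketch. The following example is illustrative only (not part of the diff): it assumes `sae` is an already-loaded SAELens `SAE` whose `d_in` matches GPT-2's hidden size, and uses plain Hugging Face Transformers to extract the activations fed to `encode()` and `decode()`:

```python
# Hedged sketch: using an SAE without TransformerLens, via Hugging Face Transformers.
# Assumes `sae` is an already-loaded sae_lens SAE with d_in == 768 (GPT-2's width).
import torch
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

inputs = tokenizer("Sparse autoencoders decompose activations.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

acts = outputs.hidden_states[6]        # residual stream after block 6: [batch, seq, 768]
features = sae.encode(acts)            # sparse feature activations
reconstruction = sae.decode(features)  # projected back into the residual stream
```

Any framework that can hand you a `[batch, seq, d_in]` activation tensor works the same way; only the activation-extraction step changes.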
```diff
--- a/sae_lens/pretrained_saes.yaml
+++ b/sae_lens/pretrained_saes.yaml
@@ -40631,7 +40631,7 @@ gemma-3-1b-res-matryoshka-dc:
   conversion_func: null
   links:
     model: https://huggingface.co/google/gemma-3-1b-pt
-  model: gemma-3-1b
+  model: google/gemma-3-1b-pt
   repo_id: chanind/gemma-3-1b-batch-topk-matryoshka-saes-w-32k-l0-40
   saes:
   - id: blocks.0.hook_resid_post
```
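The fix points the release's `model` field at the proper Hugging Face model id. For context, a hedged sketch of how such a registry entry is typically consumed, assuming v6's `SAE.from_pretrained` signature (earlier SAELens versions returned a `(sae, cfg_dict, sparsity)` tuple instead of a bare SAE):

```python
# Hypothetical usage; the release and SAE ids are taken from the YAML entry above.
from sae_lens import SAE

sae = SAE.from_pretrained(
    release="gemma-3-1b-res-matryoshka-dc",  # registry key from the hunk header
    sae_id="blocks.0.hook_resid_post",       # first SAE id listed in the entry
    device="cpu",
)
```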
```diff
--- a/sae_lens/synthetic/__init__.py
+++ b/sae_lens/synthetic/__init__.py
@@ -17,11 +17,14 @@ from sae_lens.synthetic.activation_generator import (
     ActivationGenerator,
     ActivationsModifier,
     ActivationsModifierInput,
+    CorrelationMatrixInput,
 )
 from sae_lens.synthetic.correlation import (
+    LowRankCorrelationMatrix,
     create_correlation_matrix_from_correlations,
     generate_random_correlation_matrix,
     generate_random_correlations,
+    generate_random_low_rank_correlation_matrix,
 )
 from sae_lens.synthetic.evals import (
     SyntheticDataEvalResult,
@@ -66,6 +69,9 @@ __all__ = [
     "create_correlation_matrix_from_correlations",
     "generate_random_correlations",
     "generate_random_correlation_matrix",
+    "generate_random_low_rank_correlation_matrix",
+    "LowRankCorrelationMatrix",
+    "CorrelationMatrixInput",
     # Feature modifiers
     "ActivationsModifier",
     "ActivationsModifierInput",
```
New file `sae_lens-6.29.1/sae_lens/synthetic/activation_generator.py` (`@@ -0,0 +1,388 @@`, entire file added; shown as plain Python for readability):

```python
"""
Functions for generating synthetic feature activations.
"""

import math
from collections.abc import Callable, Sequence

import torch
from torch import nn
from torch.distributions import MultivariateNormal

from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
from sae_lens.util import str_to_dtype

ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
ActivationsModifierInput = ActivationsModifier | Sequence[ActivationsModifier] | None
CorrelationMatrixInput = (
    torch.Tensor | LowRankCorrelationMatrix | tuple[torch.Tensor, torch.Tensor]
)


class ActivationGenerator(nn.Module):
    """
    Generator for synthetic feature activations.

    This module provides a generator for synthetic feature activations with controlled properties.
    """

    num_features: int
    firing_probabilities: torch.Tensor
    std_firing_magnitudes: torch.Tensor
    mean_firing_magnitudes: torch.Tensor
    modify_activations: ActivationsModifier | None
    correlation_matrix: torch.Tensor | None
    low_rank_correlation: tuple[torch.Tensor, torch.Tensor] | None
    correlation_thresholds: torch.Tensor | None
    use_sparse_tensors: bool

    def __init__(
        self,
        num_features: int,
        firing_probabilities: torch.Tensor | float,
        std_firing_magnitudes: torch.Tensor | float = 0.0,
        mean_firing_magnitudes: torch.Tensor | float = 1.0,
        modify_activations: ActivationsModifierInput = None,
        correlation_matrix: CorrelationMatrixInput | None = None,
        device: torch.device | str = "cpu",
        dtype: torch.dtype | str = "float32",
        use_sparse_tensors: bool = False,
    ):
        """
        Create a new ActivationGenerator.

        Args:
            num_features: Number of features to generate activations for.
            firing_probabilities: Probability of each feature firing. Can be a single
                float (applied to all features) or a tensor of shape (num_features,).
            std_firing_magnitudes: Standard deviation of firing magnitudes. Can be a
                single float or a tensor of shape (num_features,). Defaults to 0.0
                (deterministic magnitudes).
            mean_firing_magnitudes: Mean firing magnitude when a feature fires. Can be
                a single float or a tensor of shape (num_features,). Defaults to 1.0.
            modify_activations: Optional function(s) to modify activations after
                generation. Can be a single callable, a sequence of callables (applied
                in order), or None. Useful for applying hierarchy constraints.
            correlation_matrix: Optional correlation structure between features. Can be:

                - A full correlation matrix tensor of shape (num_features, num_features)
                - A LowRankCorrelationMatrix for memory-efficient large-scale correlations
                - A tuple of (factor, diag) tensors representing low-rank structure

            device: Device to place tensors on. Defaults to "cpu".
            dtype: Data type for tensors. Defaults to "float32".
            use_sparse_tensors: If True, return sparse COO tensors from sample().
                Only recommended when using massive numbers of features. Defaults to False.
        """
        super().__init__()
        self.num_features = num_features
        self.firing_probabilities = _to_tensor(
            firing_probabilities, num_features, device, dtype
        )
        self.std_firing_magnitudes = _to_tensor(
            std_firing_magnitudes, num_features, device, dtype
        )
        self.mean_firing_magnitudes = _to_tensor(
            mean_firing_magnitudes, num_features, device, dtype
        )
        self.modify_activations = _normalize_modifiers(modify_activations)
        self.correlation_thresholds = None
        self.correlation_matrix = None
        self.low_rank_correlation = None
        self.use_sparse_tensors = use_sparse_tensors

        if correlation_matrix is not None:
            if isinstance(correlation_matrix, torch.Tensor):
                # Full correlation matrix
                _validate_correlation_matrix(correlation_matrix, num_features)
                self.correlation_matrix = correlation_matrix
            else:
                # Low-rank correlation (tuple or LowRankCorrelationMatrix)
                correlation_factor, correlation_diag = (
                    correlation_matrix[0],
                    correlation_matrix[1],
                )
                _validate_low_rank_correlation(
                    correlation_factor, correlation_diag, num_features
                )
                # Pre-compute sqrt for efficiency (used every sample call)
                self.low_rank_correlation = (
                    correlation_factor,
                    correlation_diag.sqrt(),
                )

            # Vectorized inverse normal CDF: norm.ppf(1-p) = sqrt(2) * erfinv(1 - 2*p)
            self.correlation_thresholds = math.sqrt(2) * torch.erfinv(
                1 - 2 * self.firing_probabilities
            )

    @torch.no_grad()
    def sample(self, batch_size: int) -> torch.Tensor:
        """
        Generate a batch of feature activations with controlled properties.

        This is the main function for generating synthetic training data for SAEs.
        Features fire independently according to their firing probabilities unless
        a correlation matrix is provided.

        Args:
            batch_size: Number of samples to generate

        Returns:
            Tensor of shape [batch_size, num_features] with non-negative activations
        """
        # All tensors (firing_probabilities, std_firing_magnitudes, mean_firing_magnitudes)
        # are on the same device from __init__ via _to_tensor()
        device = self.firing_probabilities.device

        if self.correlation_matrix is not None:
            assert self.correlation_thresholds is not None
            firing_indices = _generate_correlated_features(
                batch_size,
                self.correlation_matrix,
                self.correlation_thresholds,
                device,
            )
        elif self.low_rank_correlation is not None:
            assert self.correlation_thresholds is not None
            firing_indices = _generate_low_rank_correlated_features(
                batch_size,
                self.low_rank_correlation[0],
                self.low_rank_correlation[1],
                self.correlation_thresholds,
                device,
            )
        else:
            firing_indices = torch.bernoulli(
                self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
            ).nonzero(as_tuple=True)

        # Compute activations only at firing positions (sparse optimization)
        feature_indices = firing_indices[1]
        num_firing = feature_indices.shape[0]
        mean_at_firing = self.mean_firing_magnitudes[feature_indices]
        std_at_firing = self.std_firing_magnitudes[feature_indices]
        random_deltas = (
            torch.randn(
                num_firing, device=device, dtype=self.mean_firing_magnitudes.dtype
            )
            * std_at_firing
        )
        activations_at_firing = (mean_at_firing + random_deltas).relu()

        if self.use_sparse_tensors:
            # Return sparse COO tensor
            indices = torch.stack(firing_indices)  # [2, nnz]
            feature_activations = torch.sparse_coo_tensor(
                indices,
                activations_at_firing,
                size=(batch_size, self.num_features),
                device=device,
                dtype=self.mean_firing_magnitudes.dtype,
            )
        else:
            # Dense tensor path
            feature_activations = torch.zeros(
                batch_size,
                self.num_features,
                device=device,
                dtype=self.mean_firing_magnitudes.dtype,
            )
            feature_activations[firing_indices] = activations_at_firing

        if self.modify_activations is not None:
            feature_activations = self.modify_activations(feature_activations)
            if feature_activations.is_sparse:
                # Apply relu to sparse values
                feature_activations = feature_activations.coalesce()
                feature_activations = torch.sparse_coo_tensor(
                    feature_activations.indices(),
                    feature_activations.values().relu(),
                    feature_activations.shape,
                    device=feature_activations.device,
                    dtype=feature_activations.dtype,
                )
            else:
                feature_activations = feature_activations.relu()

        return feature_activations

    def forward(self, batch_size: int) -> torch.Tensor:
        return self.sample(batch_size)


def _generate_correlated_features(
    batch_size: int,
    correlation_matrix: torch.Tensor,
    thresholds: torch.Tensor,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Generate correlated binary features using multivariate Gaussian sampling.

    Uses the Gaussian copula approach: sample from a multivariate normal
    distribution, then threshold to get binary features.

    Args:
        batch_size: Number of samples to generate
        correlation_matrix: Correlation matrix between features
        thresholds: Pre-computed thresholds for each feature (from inverse normal CDF)
        device: Device to generate samples on

    Returns:
        Tuple of (row_indices, col_indices) for firing features
    """
    num_features = correlation_matrix.shape[0]

    mvn = MultivariateNormal(
        loc=torch.zeros(num_features, device=device, dtype=thresholds.dtype),
        covariance_matrix=correlation_matrix.to(device=device, dtype=thresholds.dtype),
    )

    gaussian_samples = mvn.sample((batch_size,))
    indices = (gaussian_samples > thresholds.unsqueeze(0)).nonzero(as_tuple=True)
    return indices[0], indices[1]


def _generate_low_rank_correlated_features(
    batch_size: int,
    correlation_factor: torch.Tensor,
    cov_diag_sqrt: torch.Tensor,
    thresholds: torch.Tensor,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Generate correlated binary features using low-rank multivariate Gaussian sampling.

    Uses the Gaussian copula approach with a low-rank covariance structure for scalability.
    The covariance is represented as: cov = factor @ factor.T + diag(diag_term)

    Args:
        batch_size: Number of samples to generate
        correlation_factor: Factor matrix of shape (num_features, rank)
        cov_diag_sqrt: Pre-computed sqrt of diagonal term, shape (num_features,)
        thresholds: Pre-computed thresholds for each feature (from inverse normal CDF)
        device: Device to generate samples on

    Returns:
        Tuple of (row_indices, col_indices) for firing features
    """
    # Manual low-rank MVN sampling to enable autocast for the expensive matmul
    # samples = eps @ cov_factor.T + eta * sqrt(cov_diag)
    # where eps ~ N(0, I_rank) and eta ~ N(0, I_n)

    num_features, rank = correlation_factor.shape

    # Generate random samples in float32 for numerical stability
    eps = torch.randn(batch_size, rank, device=device, dtype=correlation_factor.dtype)
    eta = torch.randn(
        batch_size, num_features, device=device, dtype=cov_diag_sqrt.dtype
    )

    gaussian_samples = eps @ correlation_factor.T + eta * cov_diag_sqrt

    indices = (gaussian_samples > thresholds.unsqueeze(0)).nonzero(as_tuple=True)
    return indices[0], indices[1]


def _to_tensor(
    value: torch.Tensor | float,
    num_features: int,
    device: torch.device | str,
    dtype: torch.dtype | str,
) -> torch.Tensor:
    dtype = str_to_dtype(dtype)
    device = torch.device(device)
    if not isinstance(value, torch.Tensor):
        value = value * torch.ones(num_features, device=device, dtype=dtype)
    if value.shape != (num_features,):
        raise ValueError(
            f"Value must be a tensor of shape ({num_features},) or a float. Got {value.shape}"
        )
    return value.to(device, dtype)


def _normalize_modifiers(
    modify_activations: ActivationsModifierInput,
) -> ActivationsModifier | None:
    """Convert modifier input to a single modifier or None."""
    if modify_activations is None:
        return None
    if callable(modify_activations):
        return modify_activations
    # It's a sequence of modifiers - chain them
    modifiers = list(modify_activations)
    if len(modifiers) == 0:
        return None
    if len(modifiers) == 1:
        return modifiers[0]

    def chained(activations: torch.Tensor) -> torch.Tensor:
        result = activations
        for modifier in modifiers:
            result = modifier(result)
        return result

    return chained


def _validate_correlation_matrix(
    correlation_matrix: torch.Tensor, num_features: int
) -> None:
    """Validate that a correlation matrix has correct properties.

    Args:
        correlation_matrix: The matrix to validate
        num_features: Expected number of features (matrix should be (num_features, num_features))

    Raises:
        ValueError: If the matrix has incorrect shape, non-unit diagonal, or is not positive definite
    """
    expected_shape = (num_features, num_features)
    if correlation_matrix.shape != expected_shape:
        raise ValueError(
            f"Correlation matrix must have shape {expected_shape}, "
            f"got {tuple(correlation_matrix.shape)}"
        )

    diagonal = torch.diag(correlation_matrix)
    if not torch.allclose(diagonal, torch.ones_like(diagonal)):
        raise ValueError("Correlation matrix diagonal must be all 1s")

    try:
        torch.linalg.cholesky(correlation_matrix)
    except RuntimeError as e:
        raise ValueError("Correlation matrix must be positive definite") from e


def _validate_low_rank_correlation(
    correlation_factor: torch.Tensor,
    correlation_diag: torch.Tensor,
    num_features: int,
) -> None:
    """Validate that low-rank correlation parameters have correct properties.

    Args:
        correlation_factor: Factor matrix of shape (num_features, rank)
        correlation_diag: Diagonal term of shape (num_features,)
        num_features: Expected number of features

    Raises:
        ValueError: If shapes are incorrect or diagonal terms are not positive
    """
    if correlation_factor.ndim != 2:
        raise ValueError(
            f"correlation_factor must be 2D, got {correlation_factor.ndim}D"
        )
    if correlation_factor.shape[0] != num_features:
        raise ValueError(
            f"correlation_factor must have shape ({num_features}, rank), "
            f"got {tuple(correlation_factor.shape)}"
        )
    if correlation_diag.shape != (num_features,):
        raise ValueError(
            f"correlation_diag must have shape ({num_features},), "
            f"got {tuple(correlation_diag.shape)}"
        )
    if torch.any(correlation_diag <= 0):
        raise ValueError("correlation_diag must have all positive values")
```
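To close, a hedged usage sketch (not shipped in the package) tying the new low-rank path together: build a `(factor, diag)` pair of the kind `_validate_low_rank_correlation` accepts, sample a batch, and spot-check the threshold identity `sqrt(2) * erfinv(1 - 2p) = Φ⁻¹(1 - p)` that `__init__` relies on. The variable names and constants here are illustrative choices, not values from the diff.

```python
# Illustrative sketch; shapes follow the validators in activation_generator.py.
import math
import torch
from sae_lens.synthetic import ActivationGenerator

num_features, rank = 1_000, 8

# Low-rank correlation input: cov = factor @ factor.T + diag(diag_term).
# Choosing diag = 1 - row-sum-of-squares keeps each feature's total variance
# at 1, so marginal firing rates stay at firing_probabilities; diag must be
# strictly positive to pass validation.
factor = 0.1 * torch.randn(num_features, rank)
diag = torch.ones(num_features) - (factor * factor).sum(dim=1)

gen = ActivationGenerator(
    num_features=num_features,
    firing_probabilities=0.02,          # each feature fires ~2% of the time
    std_firing_magnitudes=0.1,
    mean_firing_magnitudes=1.0,
    correlation_matrix=(factor, diag),  # tuple form of CorrelationMatrixInput
)

acts = gen.sample(4096)  # dense [4096, num_features], non-negative
print(acts.shape, (acts > 0).float().mean().item())  # empirical rate ≈ 0.02

# Threshold sanity check: a feature fires when its Gaussian exceeds
# sqrt(2)*erfinv(1-2p), which happens with probability p under N(0, 1).
p = 0.02
t = math.sqrt(2) * torch.erfinv(torch.tensor(1 - 2 * p))
print(torch.distributions.Normal(0.0, 1.0).cdf(t))  # ≈ 1 - p = 0.98
```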