sae-lens 6.26.0-py3-none-any.whl → 6.28.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +3 -1
- sae_lens/cache_activations_runner.py +12 -5
- sae_lens/config.py +2 -0
- sae_lens/loading/pretrained_sae_loaders.py +2 -1
- sae_lens/loading/pretrained_saes_directory.py +18 -0
- sae_lens/pretrained_saes.yaml +144 -144
- sae_lens/saes/gated_sae.py +1 -0
- sae_lens/saes/jumprelu_sae.py +3 -0
- sae_lens/saes/sae.py +13 -0
- sae_lens/saes/standard_sae.py +2 -0
- sae_lens/saes/temporal_sae.py +1 -0
- sae_lens/synthetic/__init__.py +89 -0
- sae_lens/synthetic/activation_generator.py +215 -0
- sae_lens/synthetic/correlation.py +170 -0
- sae_lens/synthetic/evals.py +141 -0
- sae_lens/synthetic/feature_dictionary.py +138 -0
- sae_lens/synthetic/firing_probabilities.py +104 -0
- sae_lens/synthetic/hierarchy.py +335 -0
- sae_lens/synthetic/initialization.py +40 -0
- sae_lens/synthetic/plotting.py +230 -0
- sae_lens/synthetic/training.py +145 -0
- sae_lens/tokenization_and_batching.py +1 -1
- sae_lens/training/activations_store.py +51 -91
- sae_lens/training/mixing_buffer.py +14 -5
- sae_lens/training/sae_trainer.py +1 -1
- sae_lens/util.py +26 -1
- {sae_lens-6.26.0.dist-info → sae_lens-6.28.1.dist-info}/METADATA +3 -1
- sae_lens-6.28.1.dist-info/RECORD +52 -0
- sae_lens-6.26.0.dist-info/RECORD +0 -42
- {sae_lens-6.26.0.dist-info → sae_lens-6.28.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.26.0.dist-info → sae_lens-6.28.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/synthetic/activation_generator.py (new file):

@@ -0,0 +1,215 @@
+"""
+Functions for generating synthetic feature activations.
+"""
+
+from collections.abc import Callable, Sequence
+
+import torch
+from scipy.stats import norm
+from torch import nn
+from torch.distributions import MultivariateNormal
+
+from sae_lens.util import str_to_dtype
+
+ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
+ActivationsModifierInput = ActivationsModifier | Sequence[ActivationsModifier] | None
+
+
+class ActivationGenerator(nn.Module):
+    """
+    Generator for synthetic feature activations.
+
+    This module provides a generator for synthetic feature activations with controlled properties.
+    """
+
+    num_features: int
+    firing_probabilities: torch.Tensor
+    std_firing_magnitudes: torch.Tensor
+    mean_firing_magnitudes: torch.Tensor
+    modify_activations: ActivationsModifier | None
+    correlation_matrix: torch.Tensor | None
+    correlation_thresholds: torch.Tensor | None
+
+    def __init__(
+        self,
+        num_features: int,
+        firing_probabilities: torch.Tensor | float,
+        std_firing_magnitudes: torch.Tensor | float = 0.0,
+        mean_firing_magnitudes: torch.Tensor | float = 1.0,
+        modify_activations: ActivationsModifierInput = None,
+        correlation_matrix: torch.Tensor | None = None,
+        device: torch.device | str = "cpu",
+        dtype: torch.dtype | str = "float32",
+    ):
+        super().__init__()
+        self.num_features = num_features
+        self.firing_probabilities = _to_tensor(
+            firing_probabilities, num_features, device, dtype
+        )
+        self.std_firing_magnitudes = _to_tensor(
+            std_firing_magnitudes, num_features, device, dtype
+        )
+        self.mean_firing_magnitudes = _to_tensor(
+            mean_firing_magnitudes, num_features, device, dtype
+        )
+        self.modify_activations = _normalize_modifiers(modify_activations)
+        self.correlation_thresholds = None
+        if correlation_matrix is not None:
+            _validate_correlation_matrix(correlation_matrix, num_features)
+            self.correlation_thresholds = torch.tensor(
+                [norm.ppf(1 - p.item()) for p in self.firing_probabilities],
+                device=device,
+                dtype=self.firing_probabilities.dtype,
+            )
+        self.correlation_matrix = correlation_matrix
+
+    def sample(self, batch_size: int) -> torch.Tensor:
+        """
+        Generate a batch of feature activations with controlled properties.
+
+        This is the main function for generating synthetic training data for SAEs.
+        Features fire independently according to their firing probabilities unless
+        a correlation matrix is provided.
+
+        Args:
+            batch_size: Number of samples to generate
+
+        Returns:
+            Tensor of shape [batch_size, num_features] with non-negative activations
+        """
+        # All tensors (firing_probabilities, std_firing_magnitudes, mean_firing_magnitudes)
+        # are on the same device from __init__ via _to_tensor()
+        device = self.firing_probabilities.device
+
+        if self.correlation_matrix is not None:
+            assert self.correlation_thresholds is not None
+            firing_features = _generate_correlated_features(
+                batch_size,
+                self.correlation_matrix,
+                self.correlation_thresholds,
+                device,
+            )
+        else:
+            firing_features = torch.bernoulli(
+                self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
+            )
+
+        firing_magnitude_delta = torch.normal(
+            torch.zeros_like(self.firing_probabilities)
+            .unsqueeze(0)
+            .expand(batch_size, -1),
+            self.std_firing_magnitudes.unsqueeze(0).expand(batch_size, -1),
+        )
+        firing_magnitude_delta[firing_features == 0] = 0
+        feature_activations = (
+            firing_features * self.mean_firing_magnitudes + firing_magnitude_delta
+        ).relu()
+
+        if self.modify_activations is not None:
+            feature_activations = self.modify_activations(feature_activations).relu()
+        return feature_activations
+
+    def forward(self, batch_size: int) -> torch.Tensor:
+        return self.sample(batch_size)
+
+
+def _generate_correlated_features(
+    batch_size: int,
+    correlation_matrix: torch.Tensor,
+    thresholds: torch.Tensor,
+    device: torch.device,
+) -> torch.Tensor:
+    """
+    Generate correlated binary features using multivariate Gaussian sampling.
+
+    Uses the Gaussian copula approach: sample from a multivariate normal
+    distribution, then threshold to get binary features.
+
+    Args:
+        batch_size: Number of samples to generate
+        correlation_matrix: Correlation matrix between features
+        thresholds: Pre-computed thresholds for each feature (from inverse normal CDF)
+        device: Device to generate samples on
+
+    Returns:
+        Binary feature matrix of shape [batch_size, num_features]
+    """
+    num_features = correlation_matrix.shape[0]
+
+    mvn = MultivariateNormal(
+        loc=torch.zeros(num_features, device=device, dtype=thresholds.dtype),
+        covariance_matrix=correlation_matrix.to(device=device, dtype=thresholds.dtype),
+    )
+
+    gaussian_samples = mvn.sample((batch_size,))
+    return (gaussian_samples > thresholds.unsqueeze(0)).float()
+
+
+def _to_tensor(
+    value: torch.Tensor | float,
+    num_features: int,
+    device: torch.device | str,
+    dtype: torch.dtype | str,
+) -> torch.Tensor:
+    dtype = str_to_dtype(dtype)
+    device = torch.device(device)
+    if not isinstance(value, torch.Tensor):
+        value = value * torch.ones(num_features, device=device, dtype=dtype)
+    if value.shape != (num_features,):
+        raise ValueError(
+            f"Value must be a tensor of shape ({num_features},) or a float. Got {value.shape}"
+        )
+    return value.to(device, dtype)
+
+
+def _normalize_modifiers(
+    modify_activations: ActivationsModifierInput,
+) -> ActivationsModifier | None:
+    """Convert modifier input to a single modifier or None."""
+    if modify_activations is None:
+        return None
+    if callable(modify_activations):
+        return modify_activations
+    # It's a sequence of modifiers - chain them
+    modifiers = list(modify_activations)
+    if len(modifiers) == 0:
+        return None
+    if len(modifiers) == 1:
+        return modifiers[0]
+
+    def chained(activations: torch.Tensor) -> torch.Tensor:
+        result = activations
+        for modifier in modifiers:
+            result = modifier(result)
+        return result
+
+    return chained
+
+
+def _validate_correlation_matrix(
+    correlation_matrix: torch.Tensor, num_features: int
+) -> None:
+    """Validate that a correlation matrix has correct properties.
+
+    Args:
+        correlation_matrix: The matrix to validate
+        num_features: Expected number of features (matrix should be [num_features, num_features])
+
+    Raises:
+        ValueError: If the matrix has incorrect shape, non-unit diagonal, or is not positive definite
+    """
+    expected_shape = (num_features, num_features)
+    if correlation_matrix.shape != expected_shape:
+        raise ValueError(
+            f"Correlation matrix must have shape {expected_shape}, "
+            f"got {tuple(correlation_matrix.shape)}"
+        )
+
+    diagonal = torch.diag(correlation_matrix)
+    if not torch.allclose(diagonal, torch.ones_like(diagonal)):
+        raise ValueError("Correlation matrix diagonal must be all 1s")
+
+    try:
+        torch.linalg.cholesky(correlation_matrix)
+    except RuntimeError as e:
+        raise ValueError("Correlation matrix must be positive definite") from e
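To give a sense of how the new generator is meant to be used, here is a minimal sketch based only on the constructor and `sample()` shown above; the feature count, firing probability, noise scale, and the rescaling modifier are illustrative values, not package defaults:

```python
from sae_lens.synthetic.activation_generator import ActivationGenerator

# Hypothetical setup: 100 features, each firing ~5% of the time, with
# Gaussian noise of std 0.1 around a mean firing magnitude of 1.0.
generator = ActivationGenerator(
    num_features=100,
    firing_probabilities=0.05,
    std_firing_magnitudes=0.1,
    mean_firing_magnitudes=1.0,
    # Modifiers are chained in order; this one simply rescales activations.
    modify_activations=[lambda acts: 2.0 * acts],
)

batch = generator.sample(4096)  # shape [4096, 100], non-negative after relu
avg_l0 = (batch > 0).float().sum(dim=-1).mean()  # roughly 5 active features per sample
```

Passing `correlation_matrix=...` instead routes `sample()` through the Gaussian-copula path in `_generate_correlated_features`, so features co-fire according to the supplied correlations rather than independently.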
sae_lens/synthetic/correlation.py (new file):

@@ -0,0 +1,170 @@
+import random
+
+import torch
+
+
+def create_correlation_matrix_from_correlations(
+    num_features: int,
+    correlations: dict[tuple[int, int], float] | None = None,
+    default_correlation: float = 0.0,
+) -> torch.Tensor:
+    """
+    Create a correlation matrix with specified pairwise correlations.
+
+    Args:
+        num_features: Number of features
+        correlations: Dict mapping (i, j) pairs to correlation values.
+            Pairs should have i < j.
+        default_correlation: Default correlation for unspecified pairs
+
+    Returns:
+        Correlation matrix of shape [num_features, num_features]
+    """
+    matrix = torch.eye(num_features) + default_correlation * (
+        1 - torch.eye(num_features)
+    )
+
+    if correlations is not None:
+        for (i, j), corr in correlations.items():
+            matrix[i, j] = corr
+            matrix[j, i] = corr
+
+    # Ensure matrix is symmetric (numerical precision)
+    matrix = (matrix + matrix.T) / 2
+
+    # Check positive definiteness and fix if necessary
+    # Use eigvalsh for symmetric matrices (returns real eigenvalues)
+    eigenvals = torch.linalg.eigvalsh(matrix)
+    if torch.any(eigenvals < -1e-6):
+        matrix = _fix_correlation_matrix(matrix)
+
+    return matrix
+
+
+def _fix_correlation_matrix(
+    matrix: torch.Tensor, min_eigenval: float = 1e-6
+) -> torch.Tensor:
+    """Fix a correlation matrix to be positive semi-definite."""
+    eigenvals, eigenvecs = torch.linalg.eigh(matrix)
+    eigenvals = torch.clamp(eigenvals, min=min_eigenval)
+    fixed_matrix = eigenvecs @ torch.diag(eigenvals) @ eigenvecs.T
+
+    diag_vals = torch.diag(fixed_matrix)
+    fixed_matrix = fixed_matrix / torch.sqrt(
+        diag_vals.unsqueeze(0) * diag_vals.unsqueeze(1)
+    )
+    fixed_matrix.fill_diagonal_(1.0)
+
+    return fixed_matrix
+
+
+def generate_random_correlations(
+    num_features: int,
+    positive_ratio: float = 0.5,
+    uncorrelated_ratio: float = 0.3,
+    min_correlation_strength: float = 0.1,
+    max_correlation_strength: float = 0.8,
+    seed: int | None = None,
+) -> dict[tuple[int, int], float]:
+    """
+    Generate random correlations between features with specified constraints.
+
+    Args:
+        num_features: Number of features
+        positive_ratio: Fraction of correlations that should be positive (0.0 to 1.0)
+        uncorrelated_ratio: Fraction of feature pairs that should remain uncorrelated (0.0 to 1.0)
+        min_correlation_strength: Minimum absolute correlation strength
+        max_correlation_strength: Maximum absolute correlation strength
+        seed: Random seed for reproducibility
+
+    Returns:
+        Dictionary mapping (i, j) pairs to correlation values
+    """
+    # Use local random number generator to avoid side effects on global state
+    rng = random.Random(seed)
+
+    # Validate inputs
+    if not 0.0 <= positive_ratio <= 1.0:
+        raise ValueError("positive_ratio must be between 0.0 and 1.0")
+    if not 0.0 <= uncorrelated_ratio <= 1.0:
+        raise ValueError("uncorrelated_ratio must be between 0.0 and 1.0")
+    if min_correlation_strength < 0:
+        raise ValueError("min_correlation_strength must be non-negative")
+    if max_correlation_strength > 1.0:
+        raise ValueError("max_correlation_strength must be <= 1.0")
+    if min_correlation_strength > max_correlation_strength:
+        raise ValueError("min_correlation_strength must be <= max_correlation_strength")
+
+    # Generate all possible feature pairs (i, j) where i < j
+    all_pairs = [
+        (i, j) for i in range(num_features) for j in range(i + 1, num_features)
+    ]
+    total_pairs = len(all_pairs)
+
+    if total_pairs == 0:
+        return {}
+
+    # Determine how many pairs to correlate vs leave uncorrelated
+    num_uncorrelated = int(total_pairs * uncorrelated_ratio)
+    num_correlated = total_pairs - num_uncorrelated
+
+    # Randomly select which pairs to correlate
+    correlated_pairs = rng.sample(all_pairs, num_correlated)
+
+    # For correlated pairs, determine positive vs negative
+    num_positive = int(num_correlated * positive_ratio)
+    num_negative = num_correlated - num_positive
+
+    # Assign signs
+    signs = [1] * num_positive + [-1] * num_negative
+    rng.shuffle(signs)
+
+    # Generate correlation strengths
+    correlations = {}
+    for pair, sign in zip(correlated_pairs, signs):
+        # Sample correlation strength uniformly from range
+        strength = rng.uniform(min_correlation_strength, max_correlation_strength)
+        correlations[pair] = sign * strength
+
+    return correlations
+
+
+def generate_random_correlation_matrix(
+    num_features: int,
+    positive_ratio: float = 0.5,
+    uncorrelated_ratio: float = 0.3,
+    min_correlation_strength: float = 0.1,
+    max_correlation_strength: float = 0.8,
+    seed: int | None = None,
+) -> torch.Tensor:
+    """
+    Generate a random correlation matrix with specified constraints.
+
+    This is a convenience function that combines generate_random_correlations()
+    and create_correlation_matrix_from_correlations() into a single call.
+
+    Args:
+        num_features: Number of features
+        positive_ratio: Fraction of correlations that should be positive (0.0 to 1.0)
+        uncorrelated_ratio: Fraction of feature pairs that should remain uncorrelated (0.0 to 1.0)
+        min_correlation_strength: Minimum absolute correlation strength
+        max_correlation_strength: Maximum absolute correlation strength
+        seed: Random seed for reproducibility
+
+    Returns:
+        Random correlation matrix of shape [num_features, num_features]
+    """
+    # Generate random correlations
+    correlations = generate_random_correlations(
+        num_features=num_features,
+        positive_ratio=positive_ratio,
+        uncorrelated_ratio=uncorrelated_ratio,
+        min_correlation_strength=min_correlation_strength,
+        max_correlation_strength=max_correlation_strength,
+        seed=seed,
+    )
+
+    # Create and return correlation matrix
+    return create_correlation_matrix_from_correlations(
+        num_features=num_features, correlations=correlations
+    )
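A short sketch of how the two public helpers in this file compose; the feature count and pairwise correlation values below are arbitrary examples:

```python
import torch

from sae_lens.synthetic.correlation import (
    create_correlation_matrix_from_correlations,
    generate_random_correlation_matrix,
)

# Explicit pairwise correlations: features 0 and 1 tend to co-fire,
# features 2 and 3 are anti-correlated; all other pairs stay at 0.
corr = create_correlation_matrix_from_correlations(
    num_features=10,
    correlations={(0, 1): 0.6, (2, 3): -0.4},
)
assert torch.allclose(corr, corr.T)  # symmetric with unit diagonal

# Or sample a random correlation structure reproducibly via a seed.
random_corr = generate_random_correlation_matrix(num_features=10, seed=0)
```

Either matrix can then be passed as `correlation_matrix` to the `ActivationGenerator` above, which validates it (unit diagonal, positive definite) before use.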
sae_lens/synthetic/evals.py (new file):

@@ -0,0 +1,141 @@
+"""
+Utilities for training SAEs on synthetic data.
+
+This module provides helpers for:
+
+- Generating training data from feature dictionaries
+- Training SAEs on synthetic data
+- Evaluating SAEs against known ground truth features
+- Initializing SAEs to match feature dictionaries
+"""
+
+from dataclasses import dataclass
+
+import torch
+from scipy.optimize import linear_sum_assignment
+
+from sae_lens.synthetic.activation_generator import ActivationGenerator
+from sae_lens.synthetic.feature_dictionary import FeatureDictionary
+
+
+def mean_correlation_coefficient(
+    features_a: torch.Tensor,
+    features_b: torch.Tensor,
+) -> float:
+    """
+    Compute Mean Correlation Coefficient (MCC) between two sets of feature vectors.
+
+    MCC measures how well learned features align with ground truth features by finding
+    an optimal one-to-one matching using the Hungarian algorithm and computing the
+    mean absolute cosine similarity of matched pairs.
+
+    Reference: O'Neill et al. "Compute Optimal Inference and Provable Amortisation
+    Gap in Sparse Autoencoders" (arXiv:2411.13117)
+
+    Args:
+        features_a: Feature vectors of shape [num_features_a, hidden_dim]
+        features_b: Feature vectors of shape [num_features_b, hidden_dim]
+
+    Returns:
+        MCC score in range [0, 1], where 1 indicates perfect alignment
+    """
+    # Normalize to unit vectors
+    a_norm = features_a / features_a.norm(dim=1, keepdim=True).clamp(min=1e-8)
+    b_norm = features_b / features_b.norm(dim=1, keepdim=True).clamp(min=1e-8)
+
+    # Compute absolute cosine similarity matrix
+    cos_sim = torch.abs(a_norm @ b_norm.T)
+
+    # Convert to cost matrix for Hungarian algorithm (which minimizes)
+    cost_matrix = 1 - cos_sim.cpu().numpy()
+
+    # Find optimal matching
+    row_ind, col_ind = linear_sum_assignment(cost_matrix)
+
+    # Compute mean of matched similarities
+    matched_similarities = cos_sim[row_ind, col_ind]
+    return matched_similarities.mean().item()
+
+
+@dataclass
+class SyntheticDataEvalResult:
+    """Results from evaluating an SAE on synthetic data."""
+
+    true_l0: float
+    """Average L0 of the true feature activations"""
+
+    sae_l0: float
+    """Average L0 of the SAE's latent activations"""
+
+    dead_latents: int
+    """Number of SAE latents that never fired"""
+
+    shrinkage: float
+    """Average ratio of SAE output norm to input norm (1.0 = no shrinkage)"""
+
+    mcc: float
+    """Mean Correlation Coefficient between SAE decoder and ground truth features"""
+
+
+@torch.no_grad()
+def eval_sae_on_synthetic_data(
+    sae: torch.nn.Module,
+    feature_dict: FeatureDictionary,
+    activations_generator: ActivationGenerator,
+    num_samples: int = 100_000,
+) -> SyntheticDataEvalResult:
+    """
+    Evaluate an SAE on synthetic data with known ground truth.
+
+    Args:
+        sae: The SAE to evaluate. Must have encode() and decode() methods.
+        feature_dict: The feature dictionary used to generate activations
+        activations_generator: Generator that produces feature activations
+        num_samples: Number of samples to use for evaluation
+
+    Returns:
+        SyntheticDataEvalResult containing evaluation metrics
+    """
+    sae.eval()
+
+    # Generate samples
+    feature_acts = activations_generator.sample(num_samples)
+    true_l0 = (feature_acts > 0).float().sum(dim=-1).mean().item()
+    hidden_acts = feature_dict(feature_acts)
+
+    # Filter out entries where no features fire
+    non_zero_mask = hidden_acts.norm(dim=-1) > 0
+    hidden_acts_filtered = hidden_acts[non_zero_mask]
+
+    # Get SAE reconstructions
+    sae_latents = sae.encode(hidden_acts_filtered)  # type: ignore[attr-defined]
+    sae_output = sae.decode(sae_latents)  # type: ignore[attr-defined]
+
+    sae_l0 = (sae_latents > 0).float().sum(dim=-1).mean().item()
+    dead_latents = int(
+        ((sae_latents == 0).sum(dim=0) == sae_latents.shape[0]).sum().item()
+    )
+    if hidden_acts_filtered.shape[0] == 0:
+        shrinkage = 0.0
+    else:
+        shrinkage = (
+            (
+                sae_output.norm(dim=-1)
+                / hidden_acts_filtered.norm(dim=-1).clamp(min=1e-8)
+            )
+            .mean()
+            .item()
+        )
+
+    # Compute MCC between SAE decoder and ground truth features
+    sae_decoder: torch.Tensor = sae.W_dec  # type: ignore[attr-defined]
+    gt_features = feature_dict.feature_vectors
+    mcc = mean_correlation_coefficient(sae_decoder, gt_features)
+
+    return SyntheticDataEvalResult(
+        true_l0=true_l0,
+        sae_l0=sae_l0,
+        dead_latents=dead_latents,
+        shrinkage=shrinkage,
+        mcc=mcc,
+    )
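As a rough illustration of how the MCC metric behaves (the tensors below are arbitrary): a dictionary that matches the ground truth up to sign and permutation scores essentially 1.0, while an unrelated random dictionary scores much lower. `eval_sae_on_synthetic_data` wraps this metric together with L0, dead-latent, and shrinkage statistics, but it additionally needs an SAE exposing `encode()`, `decode()`, and `W_dec`, as its docstring notes.

```python
import torch

from sae_lens.synthetic.evals import mean_correlation_coefficient

true_features = torch.randn(16, 64)

# Same directions, permuted and sign-flipped: the optimal matching recovers them.
shuffled = -true_features[torch.randperm(16)]
print(mean_correlation_coefficient(shuffled, true_features))  # ~1.0

# Unrelated random directions: much lower absolute cosine similarities.
print(mean_correlation_coefficient(torch.randn(16, 64), true_features))
```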
sae_lens/synthetic/feature_dictionary.py (new file):

@@ -0,0 +1,138 @@
+"""
+Feature dictionary for generating synthetic activations.
+
+A FeatureDictionary maps feature activations (sparse coefficients) to dense hidden activations
+by multiplying with a learned or constructed feature embedding matrix.
+"""
+
+from typing import Callable
+
+import torch
+from torch import nn
+from tqdm import tqdm
+
+FeatureDictionaryInitializer = Callable[["FeatureDictionary"], None]
+
+
+def orthogonalize_embeddings(
+    embeddings: torch.Tensor,
+    target_cos_sim: float = 0,
+    num_steps: int = 200,
+    lr: float = 0.01,
+    show_progress: bool = False,
+) -> torch.Tensor:
+    num_vectors = embeddings.shape[0]
+    # Create a detached copy and normalize, then enable gradients
+    embeddings = embeddings.detach().clone()
+    embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(min=1e-8)
+    embeddings = embeddings.requires_grad_(True)
+
+    optimizer = torch.optim.Adam([embeddings], lr=lr)  # type: ignore[list-item]
+
+    # Create a mask to zero out diagonal elements (avoid in-place operations)
+    off_diagonal_mask = ~torch.eye(
+        num_vectors, dtype=torch.bool, device=embeddings.device
+    )
+
+    pbar = tqdm(
+        range(num_steps), desc="Orthogonalizing vectors", disable=not show_progress
+    )
+    for _ in pbar:
+        optimizer.zero_grad()
+
+        dot_products = embeddings @ embeddings.T
+        diff = dot_products - target_cos_sim
+        # Use masking instead of in-place fill_diagonal_
+        off_diagonal_diff = diff * off_diagonal_mask.float()
+        loss = off_diagonal_diff.pow(2).sum()
+        loss = loss + num_vectors * (dot_products.diag() - 1).pow(2).sum()
+
+        loss.backward()
+        optimizer.step()
+        pbar.set_description(f"loss: {loss.item():.3f}")
+
+    with torch.no_grad():
+        embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
+            min=1e-8
+        )
+    return embeddings.detach().clone()
+
+
+def orthogonal_initializer(
+    num_steps: int = 200, lr: float = 0.01, show_progress: bool = False
+) -> FeatureDictionaryInitializer:
+    def initializer(feature_dict: "FeatureDictionary") -> None:
+        feature_dict.feature_vectors.data = orthogonalize_embeddings(
+            feature_dict.feature_vectors,
+            num_steps=num_steps,
+            lr=lr,
+            show_progress=show_progress,
+        )
+
+    return initializer
+
+
+class FeatureDictionary(nn.Module):
+    """
+    A feature dictionary that maps sparse feature activations to dense hidden activations.
+
+    This class creates a set of feature vectors (the "dictionary") and provides methods
+    to generate hidden activations from feature activations via a linear transformation.
+
+    The feature vectors can be configured to have a specific pairwise cosine similarity,
+    which is useful for controlling the difficulty of sparse recovery.
+
+    Attributes:
+        feature_vectors: Parameter of shape [num_features, hidden_dim] containing the
+            feature embedding vectors
+        bias: Parameter of shape [hidden_dim] containing the bias term (zeros if bias=False)
+    """
+
+    feature_vectors: nn.Parameter
+    bias: nn.Parameter
+
+    def __init__(
+        self,
+        num_features: int,
+        hidden_dim: int,
+        bias: bool = False,
+        initializer: FeatureDictionaryInitializer | None = orthogonal_initializer(),
+    ):
+        """
+        Create a new FeatureDictionary.
+
+        Args:
+            num_features: Number of features in the dictionary
+            hidden_dim: Dimensionality of the hidden space
+            bias: Whether to include a bias term in the embedding
+            initializer: Initializer function to use. If None, the embeddings are initialized to random unit vectors. By default will orthogonalize embeddings.
+        """
+        super().__init__()
+        self.num_features = num_features
+        self.hidden_dim = hidden_dim
+
+        # Initialize feature vectors as unit vectors
+        embeddings = torch.randn(num_features, hidden_dim)
+        embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
+            min=1e-8
+        )
+        self.feature_vectors = nn.Parameter(embeddings)
+
+        # Initialize bias (zeros if not using bias, but still a parameter for consistent API)
+        self.bias = nn.Parameter(torch.zeros(hidden_dim), requires_grad=bias)
+
+        if initializer is not None:
+            initializer(self)
+
+    def forward(self, feature_activations: torch.Tensor) -> torch.Tensor:
+        """
+        Convert feature activations to hidden activations.
+
+        Args:
+            feature_activations: Tensor of shape [batch, num_features] containing
+                sparse feature activation values
+
+        Returns:
+            Tensor of shape [batch, hidden_dim] containing dense hidden activations
+        """
+        return feature_activations @ self.feature_vectors + self.bias
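Putting the pieces together, here is a sketch of the intended synthetic-data pipeline: sparse activations from `ActivationGenerator` are projected into a dense hidden space by `FeatureDictionary`, producing the inputs an SAE would be trained on. The dimensions are arbitrary, and the overcomplete setting (more features than hidden dimensions) is only an assumption about typical use:

```python
from sae_lens.synthetic.activation_generator import ActivationGenerator
from sae_lens.synthetic.feature_dictionary import FeatureDictionary

num_features, hidden_dim = 256, 64

# The default initializer drives pairwise cosine similarities toward 0.
feature_dict = FeatureDictionary(num_features=num_features, hidden_dim=hidden_dim)
generator = ActivationGenerator(num_features=num_features, firing_probabilities=0.02)

feature_acts = generator.sample(1024)     # sparse, shape [1024, 256]
hidden_acts = feature_dict(feature_acts)  # dense, shape [1024, 64]
```

An SAE trained on `hidden_acts` can then be scored against `feature_dict.feature_vectors` using `eval_sae_on_synthetic_data` from the evals hunk above.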