sae-lens 6.28.2__py3-none-any.whl → 6.29.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +1 -1
- sae_lens/synthetic/__init__.py +6 -0
- sae_lens/synthetic/activation_generator.py +197 -25
- sae_lens/synthetic/correlation.py +217 -36
- sae_lens/synthetic/feature_dictionary.py +11 -2
- sae_lens/synthetic/hierarchy.py +314 -2
- sae_lens/synthetic/training.py +16 -3
- {sae_lens-6.28.2.dist-info → sae_lens-6.29.1.dist-info}/METADATA +1 -1
- {sae_lens-6.28.2.dist-info → sae_lens-6.29.1.dist-info}/RECORD +11 -11
- {sae_lens-6.28.2.dist-info → sae_lens-6.29.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.28.2.dist-info → sae_lens-6.29.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
sae_lens/synthetic/__init__.py
CHANGED
@@ -17,11 +17,14 @@ from sae_lens.synthetic.activation_generator import (
     ActivationGenerator,
     ActivationsModifier,
     ActivationsModifierInput,
+    CorrelationMatrixInput,
 )
 from sae_lens.synthetic.correlation import (
+    LowRankCorrelationMatrix,
     create_correlation_matrix_from_correlations,
     generate_random_correlation_matrix,
     generate_random_correlations,
+    generate_random_low_rank_correlation_matrix,
 )
 from sae_lens.synthetic.evals import (
     SyntheticDataEvalResult,
@@ -66,6 +69,9 @@ __all__ = [
     "create_correlation_matrix_from_correlations",
     "generate_random_correlations",
     "generate_random_correlation_matrix",
+    "generate_random_low_rank_correlation_matrix",
+    "LowRankCorrelationMatrix",
+    "CorrelationMatrixInput",
     # Feature modifiers
     "ActivationsModifier",
     "ActivationsModifierInput",
sae_lens/synthetic/activation_generator.py
CHANGED

@@ -2,17 +2,21 @@
 Functions for generating synthetic feature activations.
 """
 
+import math
 from collections.abc import Callable, Sequence
 
 import torch
-from scipy.stats import norm
 from torch import nn
 from torch.distributions import MultivariateNormal
 
+from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
 from sae_lens.util import str_to_dtype
 
 ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
 ActivationsModifierInput = ActivationsModifier | Sequence[ActivationsModifier] | None
+CorrelationMatrixInput = (
+    torch.Tensor | LowRankCorrelationMatrix | tuple[torch.Tensor, torch.Tensor]
+)
 
 
 class ActivationGenerator(nn.Module):
@@ -28,7 +32,9 @@ class ActivationGenerator(nn.Module):
     mean_firing_magnitudes: torch.Tensor
     modify_activations: ActivationsModifier | None
     correlation_matrix: torch.Tensor | None
+    low_rank_correlation: tuple[torch.Tensor, torch.Tensor] | None
     correlation_thresholds: torch.Tensor | None
+    use_sparse_tensors: bool
 
     def __init__(
         self,
@@ -37,10 +43,37 @@ class ActivationGenerator(nn.Module):
         std_firing_magnitudes: torch.Tensor | float = 0.0,
         mean_firing_magnitudes: torch.Tensor | float = 1.0,
         modify_activations: ActivationsModifierInput = None,
-        correlation_matrix:
+        correlation_matrix: CorrelationMatrixInput | None = None,
         device: torch.device | str = "cpu",
         dtype: torch.dtype | str = "float32",
+        use_sparse_tensors: bool = False,
     ):
+        """
+        Create a new ActivationGenerator.
+
+        Args:
+            num_features: Number of features to generate activations for.
+            firing_probabilities: Probability of each feature firing. Can be a single
+                float (applied to all features) or a tensor of shape (num_features,).
+            std_firing_magnitudes: Standard deviation of firing magnitudes. Can be a
+                single float or a tensor of shape (num_features,). Defaults to 0.0
+                (deterministic magnitudes).
+            mean_firing_magnitudes: Mean firing magnitude when a feature fires. Can be
+                a single float or a tensor of shape (num_features,). Defaults to 1.0.
+            modify_activations: Optional function(s) to modify activations after
+                generation. Can be a single callable, a sequence of callables (applied
+                in order), or None. Useful for applying hierarchy constraints.
+            correlation_matrix: Optional correlation structure between features. Can be:
+
+                - A full correlation matrix tensor of shape (num_features, num_features)
+                - A LowRankCorrelationMatrix for memory-efficient large-scale correlations
+                - A tuple of (factor, diag) tensors representing low-rank structure
+
+            device: Device to place tensors on. Defaults to "cpu".
+            dtype: Data type for tensors. Defaults to "float32".
+            use_sparse_tensors: If True, return sparse COO tensors from sample().
+                Only recommended when using massive numbers of features. Defaults to False.
+        """
         super().__init__()
         self.num_features = num_features
         self.firing_probabilities = _to_tensor(
@@ -54,14 +87,34 @@ class ActivationGenerator(nn.Module):
         )
         self.modify_activations = _normalize_modifiers(modify_activations)
         self.correlation_thresholds = None
+        self.correlation_matrix = None
+        self.low_rank_correlation = None
+        self.use_sparse_tensors = use_sparse_tensors
+
         if correlation_matrix is not None:
-
-
-
-
-
+            if isinstance(correlation_matrix, torch.Tensor):
+                # Full correlation matrix
+                _validate_correlation_matrix(correlation_matrix, num_features)
+                self.correlation_matrix = correlation_matrix
+            else:
+                # Low-rank correlation (tuple or LowRankCorrelationMatrix)
+                correlation_factor, correlation_diag = (
+                    correlation_matrix[0],
+                    correlation_matrix[1],
+                )
+                _validate_low_rank_correlation(
+                    correlation_factor, correlation_diag, num_features
+                )
+                # Pre-compute sqrt for efficiency (used every sample call)
+                self.low_rank_correlation = (
+                    correlation_factor,
+                    correlation_diag.sqrt(),
+                )
+
+            # Vectorized inverse normal CDF: norm.ppf(1-p) = sqrt(2) * erfinv(1 - 2*p)
+            self.correlation_thresholds = math.sqrt(2) * torch.erfinv(
+                1 - 2 * self.firing_probabilities
             )
-            self.correlation_matrix = correlation_matrix
 
     @torch.no_grad()
     def sample(self, batch_size: int) -> torch.Tensor:
@@ -84,30 +137,74 @@ class ActivationGenerator(nn.Module):
 
         if self.correlation_matrix is not None:
             assert self.correlation_thresholds is not None
-
+            firing_indices = _generate_correlated_features(
                 batch_size,
                 self.correlation_matrix,
                 self.correlation_thresholds,
                 device,
             )
+        elif self.low_rank_correlation is not None:
+            assert self.correlation_thresholds is not None
+            firing_indices = _generate_low_rank_correlated_features(
+                batch_size,
+                self.low_rank_correlation[0],
+                self.low_rank_correlation[1],
+                self.correlation_thresholds,
+                device,
+            )
         else:
-
+            firing_indices = torch.bernoulli(
                 self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
+            ).nonzero(as_tuple=True)
+
+        # Compute activations only at firing positions (sparse optimization)
+        feature_indices = firing_indices[1]
+        num_firing = feature_indices.shape[0]
+        mean_at_firing = self.mean_firing_magnitudes[feature_indices]
+        std_at_firing = self.std_firing_magnitudes[feature_indices]
+        random_deltas = (
+            torch.randn(
+                num_firing, device=device, dtype=self.mean_firing_magnitudes.dtype
            )
-
-        firing_magnitude_delta = torch.normal(
-            torch.zeros_like(self.firing_probabilities)
-            .unsqueeze(0)
-            .expand(batch_size, -1),
-            self.std_firing_magnitudes.unsqueeze(0).expand(batch_size, -1),
+            * std_at_firing
         )
-
-
-
-
+        activations_at_firing = (mean_at_firing + random_deltas).relu()
+
+        if self.use_sparse_tensors:
+            # Return sparse COO tensor
+            indices = torch.stack(firing_indices)  # [2, nnz]
+            feature_activations = torch.sparse_coo_tensor(
+                indices,
+                activations_at_firing,
+                size=(batch_size, self.num_features),
+                device=device,
+                dtype=self.mean_firing_magnitudes.dtype,
+            )
+        else:
+            # Dense tensor path
+            feature_activations = torch.zeros(
+                batch_size,
+                self.num_features,
+                device=device,
+                dtype=self.mean_firing_magnitudes.dtype,
+            )
+            feature_activations[firing_indices] = activations_at_firing
 
         if self.modify_activations is not None:
-            feature_activations = self.modify_activations(feature_activations)
+            feature_activations = self.modify_activations(feature_activations)
+            if feature_activations.is_sparse:
+                # Apply relu to sparse values
+                feature_activations = feature_activations.coalesce()
+                feature_activations = torch.sparse_coo_tensor(
+                    feature_activations.indices(),
+                    feature_activations.values().relu(),
+                    feature_activations.shape,
+                    device=feature_activations.device,
+                    dtype=feature_activations.dtype,
+                )
+            else:
+                feature_activations = feature_activations.relu()
+
         return feature_activations
 
     def forward(self, batch_size: int) -> torch.Tensor:
@@ -119,7 +216,7 @@ def _generate_correlated_features(
     correlation_matrix: torch.Tensor,
     thresholds: torch.Tensor,
     device: torch.device,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Generate correlated binary features using multivariate Gaussian sampling.
 
@@ -133,7 +230,7 @@ def _generate_correlated_features(
         device: Device to generate samples on
 
     Returns:
-
+        Tuple of (row_indices, col_indices) for firing features
     """
     num_features = correlation_matrix.shape[0]
 
@@ -143,7 +240,49 @@ def _generate_correlated_features(
     )
 
     gaussian_samples = mvn.sample((batch_size,))
-
+    indices = (gaussian_samples > thresholds.unsqueeze(0)).nonzero(as_tuple=True)
+    return indices[0], indices[1]
+
+
+def _generate_low_rank_correlated_features(
+    batch_size: int,
+    correlation_factor: torch.Tensor,
+    cov_diag_sqrt: torch.Tensor,
+    thresholds: torch.Tensor,
+    device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Generate correlated binary features using low-rank multivariate Gaussian sampling.
+
+    Uses the Gaussian copula approach with a low-rank covariance structure for scalability.
+    The covariance is represented as: cov = factor @ factor.T + diag(diag_term)
+
+    Args:
+        batch_size: Number of samples to generate
+        correlation_factor: Factor matrix of shape (num_features, rank)
+        cov_diag_sqrt: Pre-computed sqrt of diagonal term, shape (num_features,)
+        thresholds: Pre-computed thresholds for each feature (from inverse normal CDF)
+        device: Device to generate samples on
+
+    Returns:
+        Tuple of (row_indices, col_indices) for firing features
+    """
+    # Manual low-rank MVN sampling to enable autocast for the expensive matmul
+    # samples = eps @ cov_factor.T + eta * sqrt(cov_diag)
+    # where eps ~ N(0, I_rank) and eta ~ N(0, I_n)
+
+    num_features, rank = correlation_factor.shape
+
+    # Generate random samples in float32 for numerical stability
+    eps = torch.randn(batch_size, rank, device=device, dtype=correlation_factor.dtype)
+    eta = torch.randn(
+        batch_size, num_features, device=device, dtype=cov_diag_sqrt.dtype
+    )
+
+    gaussian_samples = eps @ correlation_factor.T + eta * cov_diag_sqrt
+
+    indices = (gaussian_samples > thresholds.unsqueeze(0)).nonzero(as_tuple=True)
+    return indices[0], indices[1]
 
 
 def _to_tensor(
@@ -194,7 +333,7 @@ def _validate_correlation_matrix(
 
     Args:
         correlation_matrix: The matrix to validate
-        num_features: Expected number of features (matrix should be
+        num_features: Expected number of features (matrix should be (num_features, num_features))
 
     Raises:
         ValueError: If the matrix has incorrect shape, non-unit diagonal, or is not positive definite
@@ -214,3 +353,36 @@ def _validate_correlation_matrix(
         torch.linalg.cholesky(correlation_matrix)
     except RuntimeError as e:
         raise ValueError("Correlation matrix must be positive definite") from e
+
+
+def _validate_low_rank_correlation(
+    correlation_factor: torch.Tensor,
+    correlation_diag: torch.Tensor,
+    num_features: int,
+) -> None:
+    """Validate that low-rank correlation parameters have correct properties.
+
+    Args:
+        correlation_factor: Factor matrix of shape (num_features, rank)
+        correlation_diag: Diagonal term of shape (num_features,)
+        num_features: Expected number of features
+
+    Raises:
+        ValueError: If shapes are incorrect or diagonal terms are not positive
+    """
+    if correlation_factor.ndim != 2:
+        raise ValueError(
+            f"correlation_factor must be 2D, got {correlation_factor.ndim}D"
+        )
+    if correlation_factor.shape[0] != num_features:
+        raise ValueError(
+            f"correlation_factor must have shape ({num_features}, rank), "
+            f"got {tuple(correlation_factor.shape)}"
+        )
+    if correlation_diag.shape != (num_features,):
+        raise ValueError(
+            f"correlation_diag must have shape ({num_features},), "
+            f"got {tuple(correlation_diag.shape)}"
+        )
+    if torch.any(correlation_diag <= 0):
+        raise ValueError("correlation_diag must have all positive values")
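Illustrative usage of the new correlation and sparse-tensor options above — a minimal sketch, assuming the public re-exports from sae_lens.synthetic shown in this diff; the surrounding setup and chosen values are illustrative, not taken from the package:

    import torch
    from sae_lens.synthetic import (
        ActivationGenerator,
        generate_random_low_rank_correlation_matrix,
    )

    # Low-rank correlation structure: O(num_features * rank) memory instead of O(num_features^2).
    low_rank = generate_random_low_rank_correlation_matrix(
        num_features=10_000, rank=32, correlation_scale=0.05, seed=0
    )

    gen = ActivationGenerator(
        num_features=10_000,
        firing_probabilities=0.01,          # single float broadcast to all features
        correlation_matrix=low_rank,        # full tensor, LowRankCorrelationMatrix, or (factor, diag) tuple
        use_sparse_tensors=True,            # sample() returns a sparse COO tensor
    )
    acts = gen.sample(batch_size=64)        # shape (64, 10_000)

The low-rank path avoids materializing the full (num_features, num_features) correlation matrix, which is what makes correlated sampling feasible at very large feature counts.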
sae_lens/synthetic/correlation.py
CHANGED

@@ -1,7 +1,32 @@
 import random
+from typing import NamedTuple
 
 import torch
 
+from sae_lens.util import str_to_dtype
+
+
+class LowRankCorrelationMatrix(NamedTuple):
+    """
+    Low-rank representation of a correlation matrix for scalable correlated sampling.
+
+    The correlation structure is represented as:
+    correlation = correlation_factor @ correlation_factor.T + diag(correlation_diag)
+
+    This requires O(num_features * rank) storage instead of O(num_features^2),
+    making it suitable for very large numbers of features (e.g., 1M+).
+
+    Attributes:
+        correlation_factor: Factor matrix of shape (num_features, rank) that captures
+            correlations through shared latent factors.
+        correlation_diag: Diagonal variance term of shape (num_features,). Should be
+            chosen such that the diagonal of the full correlation matrix equals 1.
+            Typically: correlation_diag[i] = 1 - sum(correlation_factor[i, :]^2)
+    """
+
+    correlation_factor: torch.Tensor
+    correlation_diag: torch.Tensor
+
 
 def create_correlation_matrix_from_correlations(
     num_features: int,
@@ -11,14 +36,18 @@ def create_correlation_matrix_from_correlations(
     """
     Create a correlation matrix with specified pairwise correlations.
 
+    Note: If the resulting matrix is not positive definite, it will be adjusted
+    to ensure validity. This adjustment may change the specified correlation
+    values. To minimize this effect, use smaller correlation magnitudes.
+
     Args:
         num_features: Number of features
         correlations: Dict mapping (i, j) pairs to correlation values.
-            Pairs should have i < j.
+            Pairs should have i < j. Pairs not specified will use default_correlation.
         default_correlation: Default correlation for unspecified pairs
 
     Returns:
-        Correlation matrix of shape
+        Correlation matrix of shape (num_features, num_features)
     """
     matrix = torch.eye(num_features) + default_correlation * (
         1 - torch.eye(num_features)
@@ -50,6 +79,7 @@ def _fix_correlation_matrix(
     fixed_matrix = eigenvecs @ torch.diag(eigenvals) @ eigenvecs.T
 
     diag_vals = torch.diag(fixed_matrix)
+    diag_vals = torch.clamp(diag_vals, min=1e-8)  # Prevent division by zero
     fixed_matrix = fixed_matrix / torch.sqrt(
         diag_vals.unsqueeze(0) * diag_vals.unsqueeze(1)
     )
@@ -58,6 +88,25 @@ def _fix_correlation_matrix(
     return fixed_matrix
 
 
+def _validate_correlation_params(
+    positive_ratio: float,
+    uncorrelated_ratio: float,
+    min_correlation_strength: float,
+    max_correlation_strength: float,
+) -> None:
+    """Validate parameters for correlation generation."""
+    if not 0.0 <= positive_ratio <= 1.0:
+        raise ValueError("positive_ratio must be between 0.0 and 1.0")
+    if not 0.0 <= uncorrelated_ratio <= 1.0:
+        raise ValueError("uncorrelated_ratio must be between 0.0 and 1.0")
+    if min_correlation_strength < 0:
+        raise ValueError("min_correlation_strength must be non-negative")
+    if max_correlation_strength > 1.0:
+        raise ValueError("max_correlation_strength must be <= 1.0")
+    if min_correlation_strength > max_correlation_strength:
+        raise ValueError("min_correlation_strength must be <= max_correlation_strength")
+
+
 def generate_random_correlations(
     num_features: int,
     positive_ratio: float = 0.5,
@@ -71,29 +120,26 @@ def generate_random_correlations(
 
     Args:
         num_features: Number of features
-        positive_ratio: Fraction of
-        uncorrelated_ratio: Fraction of feature pairs that should
-
-
+        positive_ratio: Fraction of correlated pairs that should be positive (0.0 to 1.0)
+        uncorrelated_ratio: Fraction of feature pairs that should have zero correlation
+            (0.0 to 1.0). These pairs are omitted from the returned dictionary.
+        min_correlation_strength: Minimum absolute correlation strength for correlated pairs
+        max_correlation_strength: Maximum absolute correlation strength for correlated pairs
         seed: Random seed for reproducibility
 
     Returns:
-        Dictionary mapping (i, j) pairs to correlation values
+        Dictionary mapping (i, j) pairs to correlation values. Pairs with zero
+        correlation (determined by uncorrelated_ratio) are not included.
     """
     # Use local random number generator to avoid side effects on global state
     rng = random.Random(seed)
 
-
-
-
-
-
-
-        raise ValueError("min_correlation_strength must be non-negative")
-    if max_correlation_strength > 1.0:
-        raise ValueError("max_correlation_strength must be <= 1.0")
-    if min_correlation_strength > max_correlation_strength:
-        raise ValueError("min_correlation_strength must be <= max_correlation_strength")
+    _validate_correlation_params(
+        positive_ratio,
+        uncorrelated_ratio,
+        min_correlation_strength,
+        max_correlation_strength,
+    )
 
     # Generate all possible feature pairs (i, j) where i < j
     all_pairs = [
@@ -136,35 +182,170 @@ def generate_random_correlation_matrix(
     min_correlation_strength: float = 0.1,
     max_correlation_strength: float = 0.8,
     seed: int | None = None,
+    device: torch.device | str = "cpu",
+    dtype: torch.dtype | str = torch.float32,
 ) -> torch.Tensor:
     """
     Generate a random correlation matrix with specified constraints.
 
-
-
+    Uses vectorized torch operations for efficiency with large numbers of features.
+
+    Note: If the randomly generated matrix is not positive definite, it will be
+    adjusted to ensure validity. This adjustment may change correlation values,
+    including turning some zero correlations into non-zero values. To minimize
+    this effect, use smaller correlation strengths (e.g., 0.01-0.1).
+
+    Args:
+        num_features: Number of features
+        positive_ratio: Fraction of correlated pairs that should be positive (0.0 to 1.0)
+        uncorrelated_ratio: Fraction of feature pairs that should have zero correlation
+            (0.0 to 1.0). Note that matrix fixing for positive definiteness may reduce
+            the actual number of zero correlations.
+        min_correlation_strength: Minimum absolute correlation strength for correlated pairs
+        max_correlation_strength: Maximum absolute correlation strength for correlated pairs
+        seed: Random seed for reproducibility
+        device: Device to create the matrix on
+        dtype: Data type for the matrix
+
+    Returns:
+        Random correlation matrix of shape (num_features, num_features)
+    """
+    dtype = str_to_dtype(dtype)
+    _validate_correlation_params(
+        positive_ratio,
+        uncorrelated_ratio,
+        min_correlation_strength,
+        max_correlation_strength,
+    )
+
+    if num_features <= 1:
+        return torch.eye(num_features, device=device, dtype=dtype)
+
+    # Set random seed if provided
+    generator = torch.Generator(device=device)
+    if seed is not None:
+        generator.manual_seed(seed)
+
+    # Get upper triangular indices (i < j)
+    row_idx, col_idx = torch.triu_indices(num_features, num_features, offset=1)
+    num_pairs = row_idx.shape[0]
+
+    # Generate random values for all pairs at once
+    # is_correlated: 1 if this pair should have a correlation, 0 otherwise
+    is_correlated = (
+        torch.rand(num_pairs, generator=generator, device=device) >= uncorrelated_ratio
+    )
+
+    # signs: +1 for positive correlation, -1 for negative
+    is_positive = (
+        torch.rand(num_pairs, generator=generator, device=device) < positive_ratio
+    )
+    signs = torch.where(is_positive, 1.0, -1.0)
+
+    # strengths: uniform in [min_strength, max_strength]
+    strengths = (
+        torch.rand(num_pairs, generator=generator, device=device, dtype=dtype)
+        * (max_correlation_strength - min_correlation_strength)
+        + min_correlation_strength
+    )
+
+    # Combine: correlation = is_correlated * sign * strength
+    correlations = is_correlated.to(dtype) * signs.to(dtype) * strengths
+
+    # Build the symmetric matrix
+    matrix = torch.eye(num_features, device=device, dtype=dtype)
+    matrix[row_idx, col_idx] = correlations
+    matrix[col_idx, row_idx] = correlations
+
+    # Check positive definiteness and fix if necessary
+    eigenvals = torch.linalg.eigvalsh(matrix)
+    if torch.any(eigenvals < -1e-6):
+        matrix = _fix_correlation_matrix(matrix)
+
+    return matrix
+
+
+def generate_random_low_rank_correlation_matrix(
+    num_features: int,
+    rank: int,
+    correlation_scale: float = 0.1,
+    seed: int | None = None,
+    device: torch.device | str = "cpu",
+    dtype: torch.dtype | str = torch.float32,
+) -> LowRankCorrelationMatrix:
+    """
+    Generate a random low-rank correlation structure for scalable correlated sampling.
+
+    The correlation structure is represented as:
+    correlation = factor @ factor.T + diag(diag_term)
+
+    This requires O(num_features * rank) storage instead of O(num_features^2),
+    making it suitable for very large numbers of features (e.g., 1M+).
+
+    The factor matrix is initialized with random values scaled by correlation_scale,
+    and the diagonal term is computed to ensure the implied correlation matrix has
+    unit diagonal.
 
     Args:
         num_features: Number of features
-
-
-
-
+        rank: Rank of the low-rank approximation. Higher rank allows more complex
+            correlation structures but uses more memory. Typical values: 10-100.
+        correlation_scale: Scale factor for random correlations. Larger values produce
+            stronger correlations between features. Use 0 for no correlations (identity
+            matrix). Should be small enough that rank * correlation_scale^2 < 1 to
+            ensure valid diagonal terms.
         seed: Random seed for reproducibility
+        device: Device to create tensors on
+        dtype: Data type for tensors
 
     Returns:
-
+        LowRankCorrelationMatrix containing the factor matrix and diagonal term
     """
-    #
-
-
-
-
-
-
-
+    # Minimum diagonal value to ensure numerical stability in the covariance matrix.
+    # This limits how much variance can come from the low-rank factor.
+    _MIN_DIAG = 0.01
+
+    dtype = str_to_dtype(dtype)
+    device = torch.device(device)
+
+    if rank <= 0:
+        raise ValueError("rank must be positive")
+    if correlation_scale < 0:
+        raise ValueError("correlation_scale must be non-negative")
+
+    # Set random seed if provided
+    generator = torch.Generator(device=device)
+    if seed is not None:
+        generator.manual_seed(seed)
+
+    # Generate random factor matrix
+    # Each row has norm roughly sqrt(rank) * correlation_scale
+    factor = (
+        torch.randn(num_features, rank, generator=generator, device=device, dtype=dtype)
+        * correlation_scale
     )
 
-    #
-
-
+    # Compute diagonal term to ensure unit diagonal in implied correlation matrix
+    # diag(factor @ factor.T) + diag_term = 1
+    # diag_term = 1 - sum(factor[i, :]^2)
+    factor_sq_sum = (factor**2).sum(dim=1)
+    diag_term = 1 - factor_sq_sum
+
+    # Ensure diagonal terms are at least _MIN_DIAG for numerical stability
+    # If any diagonal term is too small, scale down the factor matrix
+    if torch.any(diag_term < _MIN_DIAG):
+        # Scale factor so max row norm squared is at most (1 - _MIN_DIAG)
+        # This ensures all diagonal terms are >= _MIN_DIAG
+        max_factor_contribution = 1 - _MIN_DIAG
+        max_sq_sum = factor_sq_sum.max()
+        scale = torch.sqrt(
+            torch.tensor(max_factor_contribution, device=device, dtype=dtype)
+            / max_sq_sum
+        )
+        factor = factor * scale
+        factor_sq_sum = (factor**2).sum(dim=1)
+        diag_term = 1 - factor_sq_sum
+
+    return LowRankCorrelationMatrix(
+        correlation_factor=factor, correlation_diag=diag_term
    )
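A small sketch of what the low-rank structure represents, using the field names from LowRankCorrelationMatrix above; reconstructing the dense matrix is only feasible at small num_features and is shown purely as a check:

    import torch
    from sae_lens.synthetic import generate_random_low_rank_correlation_matrix

    low_rank = generate_random_low_rank_correlation_matrix(
        num_features=1_000, rank=16, correlation_scale=0.05, seed=0
    )
    factor, diag = low_rank.correlation_factor, low_rank.correlation_diag

    # Implied dense correlation matrix (illustration only; do not do this at 1M+ features).
    dense = factor @ factor.T + torch.diag(diag)

    # By construction the implied correlation matrix has unit diagonal.
    assert torch.allclose(torch.diagonal(dense), torch.ones(1_000), atol=1e-5)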
sae_lens/synthetic/feature_dictionary.py
CHANGED

@@ -9,7 +9,7 @@ from typing import Callable
 
 import torch
 from torch import nn
-from tqdm import tqdm
+from tqdm.auto import tqdm
 
 FeatureDictionaryInitializer = Callable[["FeatureDictionary"], None]
 
@@ -168,9 +168,18 @@ class FeatureDictionary(nn.Module):
 
         Args:
             feature_activations: Tensor of shape [batch, num_features] containing
-                sparse feature activation values
+                sparse feature activation values. Can be dense or sparse COO.
 
         Returns:
             Tensor of shape [batch, hidden_dim] containing dense hidden activations
         """
+        if feature_activations.is_sparse:
+            # autocast is disabled here because sparse matmul is not supported with bfloat16
+            with torch.autocast(
+                device_type=feature_activations.device.type, enabled=False
+            ):
+                return (
+                    torch.sparse.mm(feature_activations, self.feature_vectors)
+                    + self.bias
+                )
         return feature_activations @ self.feature_vectors + self.bias
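The sparse branch above is numerically equivalent to the dense matmul it replaces. A self-contained sketch with toy tensors (stand-ins for feature_activations and feature_vectors, not the library objects):

    import torch

    # [batch, num_features] activations that are mostly zero, and a [num_features, hidden_dim] dictionary.
    feature_activations = torch.rand(4, 100)
    feature_activations[feature_activations < 0.9] = 0
    feature_vectors = torch.randn(100, 32)
    bias = torch.zeros(32)

    dense_out = feature_activations @ feature_vectors + bias

    # Equivalent sparse COO path, as used when ActivationGenerator is built with use_sparse_tensors=True.
    sparse_acts = feature_activations.to_sparse()
    sparse_out = torch.sparse.mm(sparse_acts, feature_vectors) + bias

    assert torch.allclose(dense_out, sparse_out, atol=1e-5)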
sae_lens/synthetic/hierarchy.py
CHANGED
@@ -147,6 +147,14 @@ class _SparseHierarchyData:
     # Total number of ME groups
     num_groups: int
 
+    # Sparse COO support: Feature-to-parent mapping
+    # feat_to_parent[f] = parent feature index, or -1 if root/no parent
+    feat_to_parent: torch.Tensor | None = None  # [num_features]
+
+    # Sparse COO support: Feature-to-ME-group mapping
+    # feat_to_me_group[f] = group index, or -1 if not in any ME group
+    feat_to_me_group: torch.Tensor | None = None  # [num_features]
+
 
 def _build_sparse_hierarchy(
     roots: Sequence[HierarchyNode],
@@ -232,7 +240,11 @@ def _build_sparse_hierarchy(
         me_indices = torch.empty(0, dtype=torch.long)
 
         level_data.append(
-            _LevelData(
+            _LevelData(
+                features=feats,
+                parents=parents,
+                me_group_indices=me_indices,
+            )
         )
 
     # Build group siblings and parents tensors
@@ -254,12 +266,30 @@ def _build_sparse_hierarchy(
         me_group_parents = torch.empty(0, dtype=torch.long)
         num_groups = 0
 
+    # Build sparse COO support: feat_to_parent and feat_to_me_group mappings
+    # First determine num_features (max feature index + 1)
+    all_features = [f for f, _, _ in feature_info]
+    num_features = max(all_features) + 1 if all_features else 0
+
+    # Build feature-to-parent mapping
+    feat_to_parent = torch.full((num_features,), -1, dtype=torch.long)
+    for feat, parent, _ in feature_info:
+        feat_to_parent[feat] = parent
+
+    # Build feature-to-ME-group mapping
+    feat_to_me_group = torch.full((num_features,), -1, dtype=torch.long)
+    for g_idx, (_, _, siblings) in enumerate(me_groups):
+        for sib in siblings:
+            feat_to_me_group[sib] = g_idx
+
     return _SparseHierarchyData(
         level_data=level_data,
         me_group_siblings=me_group_siblings,
         me_group_sizes=me_group_sizes,
         me_group_parents=me_group_parents,
         num_groups=num_groups,
+        feat_to_parent=feat_to_parent,
+        feat_to_me_group=feat_to_me_group,
     )
 
 
@@ -396,8 +426,9 @@ def _apply_me_for_groups(
     # Random selection for winner
     # Use -1e9 instead of -inf to avoid creating a tensor (torch.tensor(-float("inf")))
     # on every call. Since random scores are in [0,1], -1e9 is effectively -inf for argmax.
+    _INACTIVE_SCORE = -1e9
     random_scores = torch.rand(num_conflicts, max_siblings, device=device)
-    random_scores[~conflict_active] =
+    random_scores[~conflict_active] = _INACTIVE_SCORE
 
     winner_idx = random_scores.argmax(dim=1)
 
@@ -420,6 +451,275 @@ def _apply_me_for_groups(
     activations[deact_batch, deact_feat] = 0
 
 
+# ---------------------------------------------------------------------------
+# Sparse COO hierarchy implementation
+# ---------------------------------------------------------------------------
+
+
+def _apply_hierarchy_sparse_coo(
+    sparse_tensor: torch.Tensor,
+    sparse_data: _SparseHierarchyData,
+) -> torch.Tensor:
+    """
+    Apply hierarchy constraints to a sparse COO tensor.
+
+    This is the sparse analog of _apply_hierarchy_sparse. It processes
+    level-by-level, applying parent deactivation then mutual exclusion.
+    """
+    if sparse_tensor._nnz() == 0:
+        return sparse_tensor
+
+    sparse_tensor = sparse_tensor.coalesce()
+
+    for level_data in sparse_data.level_data:
+        # Step 1: Apply parent deactivation for features at this level
+        if level_data.features.numel() > 0:
+            sparse_tensor = _apply_parent_deactivation_coo(
+                sparse_tensor, level_data, sparse_data
+            )
+
+        # Step 2: Apply ME for groups whose parent is at this level
+        if level_data.me_group_indices.numel() > 0:
+            sparse_tensor = _apply_me_coo(
+                sparse_tensor, level_data.me_group_indices, sparse_data
+            )
+
+    return sparse_tensor
+
+
+def _apply_parent_deactivation_coo(
+    sparse_tensor: torch.Tensor,
+    level_data: _LevelData,
+    sparse_data: _SparseHierarchyData,
+) -> torch.Tensor:
+    """
+    Remove children from sparse COO tensor when their parent is inactive.
+
+    Uses searchsorted for efficient membership testing of parent activity.
+    """
+    if sparse_tensor._nnz() == 0 or level_data.features.numel() == 0:
+        return sparse_tensor
+
+    sparse_tensor = sparse_tensor.coalesce()
+    indices = sparse_tensor.indices()  # [2, nnz]
+    values = sparse_tensor.values()  # [nnz]
+    batch_indices = indices[0]
+    feat_indices = indices[1]
+
+    _, num_features = sparse_tensor.shape
+    device = sparse_tensor.device
+    nnz = indices.shape[1]
+
+    # Build set of active (batch, feature) pairs for efficient lookup
+    # Encode as: batch_idx * num_features + feat_idx
+    active_pairs = batch_indices * num_features + feat_indices
+    active_pairs_sorted, _ = active_pairs.sort()
+
+    # Use the precomputed feat_to_parent mapping
+    assert sparse_data.feat_to_parent is not None
+    hierarchy_num_features = sparse_data.feat_to_parent.numel()
+
+    # Handle features outside the hierarchy (they have no parent, pass through)
+    in_hierarchy = feat_indices < hierarchy_num_features
+    parent_of_feat = torch.full((nnz,), -1, dtype=torch.long, device=device)
+    parent_of_feat[in_hierarchy] = sparse_data.feat_to_parent[
+        feat_indices[in_hierarchy]
+    ]
+
+    # Find entries that have a parent (parent >= 0 means this feature has a parent)
+    has_parent = parent_of_feat >= 0
+
+    if not has_parent.any():
+        return sparse_tensor
+
+    # For entries with parents, check if parent is active
+    child_entry_indices = torch.where(has_parent)[0]
+    child_batch = batch_indices[has_parent]
+    child_parents = parent_of_feat[has_parent]
+
+    # Look up parent activity using searchsorted
+    parent_pairs = child_batch * num_features + child_parents
+    search_pos = torch.searchsorted(active_pairs_sorted, parent_pairs)
+    search_pos = search_pos.clamp(max=active_pairs_sorted.numel() - 1)
+    parent_active = active_pairs_sorted[search_pos] == parent_pairs
+
+    # Handle empty case
+    if active_pairs_sorted.numel() == 0:
+        parent_active = torch.zeros_like(parent_pairs, dtype=torch.bool)
+
+    # Build keep mask: keep entry if it's a root OR its parent is active
+    keep_mask = torch.ones(nnz, dtype=torch.bool, device=device)
+    keep_mask[child_entry_indices[~parent_active]] = False
+
+    if keep_mask.all():
+        return sparse_tensor
+
+    return torch.sparse_coo_tensor(
+        indices[:, keep_mask],
+        values[keep_mask],
+        sparse_tensor.shape,
+        device=device,
+        dtype=sparse_tensor.dtype,
+    )
+
+
+def _apply_me_coo(
+    sparse_tensor: torch.Tensor,
+    group_indices: torch.Tensor,
+    sparse_data: _SparseHierarchyData,
+) -> torch.Tensor:
+    """
+    Apply mutual exclusion to sparse COO tensor.
+
+    For each ME group with multiple active siblings in the same batch,
+    randomly selects one winner and removes the rest.
+    """
+    if sparse_tensor._nnz() == 0 or group_indices.numel() == 0:
+        return sparse_tensor
+
+    sparse_tensor = sparse_tensor.coalesce()
+    indices = sparse_tensor.indices()  # [2, nnz]
+    values = sparse_tensor.values()  # [nnz]
+    batch_indices = indices[0]
+    feat_indices = indices[1]
+
+    _, num_features = sparse_tensor.shape
+    device = sparse_tensor.device
+    nnz = indices.shape[1]
+
+    # Use precomputed feat_to_me_group mapping
+    assert sparse_data.feat_to_me_group is not None
+    hierarchy_num_features = sparse_data.feat_to_me_group.numel()
+
+    # Handle features outside the hierarchy (they are not in any ME group)
+    in_hierarchy = feat_indices < hierarchy_num_features
+    me_group_of_feat = torch.full((nnz,), -1, dtype=torch.long, device=device)
+    me_group_of_feat[in_hierarchy] = sparse_data.feat_to_me_group[
+        feat_indices[in_hierarchy]
+    ]
+
+    # Find entries that belong to ME groups we're processing (vectorized)
+    in_relevant_group = torch.isin(me_group_of_feat, group_indices)
+
+    if not in_relevant_group.any():
+        return sparse_tensor
+
+    # Get the ME entries
+    me_entry_indices = torch.where(in_relevant_group)[0]
+    me_batch = batch_indices[in_relevant_group]
+    me_group = me_group_of_feat[in_relevant_group]
+
+    # Check parent activity for ME groups (only apply ME if parent is active)
+    me_group_parents = sparse_data.me_group_parents[me_group]
+    has_parent = me_group_parents >= 0
+
+    if has_parent.any():
+        # Build active pairs for parent lookup
+        active_pairs = batch_indices * num_features + feat_indices
+        active_pairs_sorted, _ = active_pairs.sort()
+
+        parent_pairs = (
+            me_batch[has_parent] * num_features + me_group_parents[has_parent]
+        )
+        search_pos = torch.searchsorted(active_pairs_sorted, parent_pairs)
+        search_pos = search_pos.clamp(max=active_pairs_sorted.numel() - 1)
+        parent_active_for_has_parent = active_pairs_sorted[search_pos] == parent_pairs
+
+        # Build full parent_active mask
+        parent_active = torch.ones(
+            me_entry_indices.numel(), dtype=torch.bool, device=device
+        )
+        parent_active[has_parent] = parent_active_for_has_parent
+
+        # Filter to only ME entries where parent is active
+        valid_me = parent_active
+        me_entry_indices = me_entry_indices[valid_me]
+        me_batch = me_batch[valid_me]
+        me_group = me_group[valid_me]
+
+    if me_entry_indices.numel() == 0:
+        return sparse_tensor
+
+    # Encode (batch, group) pairs
+    num_groups = sparse_data.num_groups
+    batch_group_pairs = me_batch * num_groups + me_group
+
+    # Find unique (batch, group) pairs and count occurrences
+    unique_bg, inverse, counts = torch.unique(
+        batch_group_pairs, return_inverse=True, return_counts=True
+    )
+
+    # Only process pairs with count > 1 (conflicts)
+    has_conflict = counts > 1
+
+    if not has_conflict.any():
+        return sparse_tensor
+
+    # For efficiency, we process all conflicts together
+    # Assign random scores to each ME entry
+    random_scores = torch.rand(me_entry_indices.numel(), device=device)
+
+    # For each (batch, group) pair, we want the entry with highest score to be winner
+    # Use scatter_reduce to find max score per (batch, group)
+    bg_to_dense = torch.zeros(unique_bg.numel(), dtype=torch.long, device=device)
+    bg_to_dense[has_conflict.nonzero(as_tuple=True)[0]] = torch.arange(
+        has_conflict.sum(), device=device
+    )
+
+    # Map each ME entry to its dense conflict index
+    entry_has_conflict = has_conflict[inverse]
+
+    if not entry_has_conflict.any():
+        return sparse_tensor
+
+    conflict_entries_mask = entry_has_conflict
+    conflict_entry_indices = me_entry_indices[conflict_entries_mask]
+    conflict_random_scores = random_scores[conflict_entries_mask]
+    conflict_inverse = inverse[conflict_entries_mask]
+    conflict_dense_idx = bg_to_dense[conflict_inverse]
+
+    # Vectorized winner selection using sorting
+    # Sort entries by (group_idx, -random_score) so highest score comes first per group
+    # Use group * 2 - score to sort by group ascending, then score descending
+    sort_keys = conflict_dense_idx.float() * 2.0 - conflict_random_scores
+    sorted_order = sort_keys.argsort()
+    sorted_dense_idx = conflict_dense_idx[sorted_order]
+
+    # Find first entry of each group in sorted order (these are winners)
+    group_starts = torch.cat(
+        [
+            torch.tensor([True], device=device),
+            sorted_dense_idx[1:] != sorted_dense_idx[:-1],
+        ]
+    )
+
+    # Winners are entries at group starts in sorted order
+    winner_positions_in_sorted = torch.where(group_starts)[0]
+    winner_original_positions = sorted_order[winner_positions_in_sorted]
+
+    # Create winner mask (vectorized)
+    is_winner = torch.zeros(
+        conflict_entry_indices.numel(), dtype=torch.bool, device=device
+    )
+    is_winner[winner_original_positions] = True
+
+    # Build keep mask (vectorized)
+    keep_mask = torch.ones(nnz, dtype=torch.bool, device=device)
+    loser_entry_indices = conflict_entry_indices[~is_winner]
+    keep_mask[loser_entry_indices] = False
+
+    if keep_mask.all():
+        return sparse_tensor
+
+    return torch.sparse_coo_tensor(
+        indices[:, keep_mask],
+        values[keep_mask],
+        sparse_tensor.shape,
+        device=device,
+        dtype=sparse_tensor.dtype,
+    )
+
+
 @torch.no_grad()
 def hierarchy_modifier(
     roots: Sequence[HierarchyNode] | HierarchyNode,
@@ -475,12 +775,24 @@ def hierarchy_modifier(
             me_group_sizes=sparse_data.me_group_sizes.to(device),
             me_group_parents=sparse_data.me_group_parents.to(device),
             num_groups=sparse_data.num_groups,
+            feat_to_parent=(
+                sparse_data.feat_to_parent.to(device)
+                if sparse_data.feat_to_parent is not None
+                else None
+            ),
+            feat_to_me_group=(
+                sparse_data.feat_to_me_group.to(device)
+                if sparse_data.feat_to_me_group is not None
+                else None
+            ),
         )
         return device_cache[device]
 
     def modifier(activations: torch.Tensor) -> torch.Tensor:
         device = activations.device
        cached = _get_sparse_for_device(device)
+        if activations.is_sparse:
+            return _apply_hierarchy_sparse_coo(activations, cached)
         return _apply_hierarchy_sparse(activations, cached)
 
     return modifier
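The parent-activity lookups in _apply_parent_deactivation_coo and _apply_me_coo above encode (batch, feature) pairs as single integers and test membership with a sort plus torch.searchsorted. A toy sketch of that trick in isolation (illustrative data, not the library code):

    import torch

    num_features = 10
    # Active (batch, feature) pairs of a sparse activation tensor.
    batch_idx = torch.tensor([0, 0, 1, 1])
    feat_idx = torch.tensor([2, 7, 2, 3])

    # Encode each pair as a single integer and sort for binary search.
    active_pairs_sorted, _ = (batch_idx * num_features + feat_idx).sort()

    # Query: is feature 7 active in batch 0, and in batch 1?
    queries = torch.tensor([0 * num_features + 7, 1 * num_features + 7])
    pos = torch.searchsorted(active_pairs_sorted, queries)
    pos = pos.clamp(max=active_pairs_sorted.numel() - 1)
    is_active = active_pairs_sorted[pos] == queries
    print(is_active)  # tensor([ True, False])

This gives set-membership over nonzero entries in O(nnz log nnz) without ever building a dense (batch, num_features) mask.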
sae_lens/synthetic/training.py
CHANGED
@@ -23,6 +23,8 @@ def train_toy_sae(
     device: str | torch.device = "cpu",
     n_snapshots: int = 0,
     snapshot_fn: Callable[[SAETrainer[Any, Any]], None] | None = None,
+    autocast_sae: bool = False,
+    autocast_data: bool = False,
 ) -> None:
     """
     Train an SAE on synthetic activations from a feature dictionary.
@@ -46,6 +48,8 @@ def train_toy_sae(
         snapshot_fn: Callback function called at each snapshot point. Receives
             the SAETrainer instance, allowing access to the SAE, training step,
             and other training state. Required if n_snapshots > 0.
+        autocast_sae: Whether to autocast the SAE to bfloat16. Only recommend for large SAEs on CUDA
+        autocast_data: Whether to autocast the activations generator and feature dictionary to bfloat16. Only recommend for large data on CUDA.
     """
 
     device_str = str(device) if isinstance(device, torch.device) else device
@@ -55,6 +59,7 @@ def train_toy_sae(
         feature_dict=feature_dict,
         activations_generator=activations_generator,
         batch_size=batch_size,
+        autocast=autocast_data,
     )
 
     # Create trainer config
@@ -64,7 +69,7 @@ def train_toy_sae(
         save_final_checkpoint=False,
         total_training_samples=training_samples,
         device=device_str,
-        autocast=
+        autocast=autocast_sae,
         lr=lr,
         lr_end=lr,
         lr_scheduler_name="constant",
@@ -119,6 +124,7 @@ class SyntheticActivationIterator(Iterator[torch.Tensor]):
         feature_dict: FeatureDictionary,
         activations_generator: ActivationGenerator,
         batch_size: int,
+        autocast: bool = False,
     ):
         """
         Create a new SyntheticActivationIterator.
@@ -127,16 +133,23 @@ class SyntheticActivationIterator(Iterator[torch.Tensor]):
             feature_dict: The feature dictionary to use for generating hidden activations
             activations_generator: Generator that produces feature activations
             batch_size: Number of samples per batch
+            autocast: Whether to autocast the activations generator and feature dictionary to bfloat16.
         """
         self.feature_dict = feature_dict
         self.activations_generator = activations_generator
         self.batch_size = batch_size
+        self.autocast = autocast
 
     @torch.no_grad()
     def next_batch(self) -> torch.Tensor:
         """Generate the next batch of hidden activations."""
-
-
+        with torch.autocast(
+            device_type=self.feature_dict.feature_vectors.device.type,
+            dtype=torch.bfloat16,
+            enabled=self.autocast,
+        ):
+            features = self.activations_generator(self.batch_size)
+            return self.feature_dict(features)
 
     def __iter__(self) -> "SyntheticActivationIterator":
         return self
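A rough sketch of the new autocast plumbing, using the SyntheticActivationIterator constructor arguments shown above; the feature_dict and gen objects are assumed to be built elsewhere (e.g., a FeatureDictionary and ActivationGenerator already placed on a CUDA device):

    iterator = SyntheticActivationIterator(
        feature_dict=feature_dict,
        activations_generator=gen,
        batch_size=4096,
        autocast=True,  # generate features and decode them under bfloat16 autocast
    )
    hidden_batch = iterator.next_batch()  # [batch_size, hidden_dim]

train_toy_sae exposes the same switch separately for the data pipeline (autocast_data) and for the SAE training step itself (autocast_sae).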
{sae_lens-6.28.2.dist-info → sae_lens-6.29.1.dist-info}/RECORD
CHANGED

@@ -1,4 +1,4 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=emqKVNiJwD8YtYhtgHJyAT8YSX1QmruQYuG-J4CStC4,4788
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
 sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
@@ -25,16 +25,16 @@ sae_lens/saes/standard_sae.py,sha256=_hldNZkFPAf9VGrxouR1-tN8T2OEk8IkWBcXoatrC1o
 sae_lens/saes/temporal_sae.py,sha256=83Ap4mYGfdN3sKdPF8nKjhdXph3-7E2QuLobqJ_YuoM,13273
 sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
 sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
-sae_lens/synthetic/__init__.py,sha256=
-sae_lens/synthetic/activation_generator.py,sha256=
-sae_lens/synthetic/correlation.py,sha256=
+sae_lens/synthetic/__init__.py,sha256=MtTnGkTfHV2WjkIgs7zZyx10EK9U5fjOHXy69Aq3uKw,3095
+sae_lens/synthetic/activation_generator.py,sha256=8L9nwC4jFRv_wg3QN-n1sFwX8w1NqwJMysWaJ41lLlY,15197
+sae_lens/synthetic/correlation.py,sha256=tMTLo9fBfDpeXwqhyUgFqnTipj9x2W0t4oEtNxB7AG0,13256
 sae_lens/synthetic/evals.py,sha256=Nhi314ZnRgLfhBj-3tm_zzI-pGyFTcwllDXbIpPFXeU,4584
-sae_lens/synthetic/feature_dictionary.py,sha256=
+sae_lens/synthetic/feature_dictionary.py,sha256=Nd4xjSTxKMnKilZ3uYi8Gv5SS5D4bv4wHiSL1uGB69E,6933
 sae_lens/synthetic/firing_probabilities.py,sha256=yclz1pWl5gE1r8LAxFvzQS88Lxwk5-3r8BCX9HLVejA,3370
-sae_lens/synthetic/hierarchy.py,sha256=
+sae_lens/synthetic/hierarchy.py,sha256=nm7nwnTswktVJeKUsRZ0hLOdXcFWGbxnA1b6lefHm-4,33592
 sae_lens/synthetic/initialization.py,sha256=orMGW-786wRDHIS2W7bEH0HmlVFQ4g2z4bnnwdv5w4s,1386
 sae_lens/synthetic/plotting.py,sha256=5lFrej1QOkGAcImFNo5-o-8mI_rUVqvEI57KzUQPPtQ,8208
-sae_lens/synthetic/training.py,sha256=
+sae_lens/synthetic/training.py,sha256=fHcX2cZ6nDupr71GX0Gk17f1NvQ0SKIVXIA6IuAb2dw,5692
 sae_lens/tokenization_and_batching.py,sha256=uoHtAs9z3XqG0Fh-iQVYVlrbyB_E3kFFhrKU30BosCo,5438
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
@@ -46,7 +46,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
 sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
+sae_lens-6.29.1.dist-info/METADATA,sha256=0Pp1L3vNiUGzkMox_BdQR6B064tTHFgwAPGJz8FY8UM,6573
+sae_lens-6.29.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+sae_lens-6.29.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.29.1.dist-info/RECORD,,
{sae_lens-6.28.2.dist-info → sae_lens-6.29.1.dist-info}/WHEEL
File without changes

{sae_lens-6.28.2.dist-info → sae_lens-6.29.1.dist-info}/licenses/LICENSE
File without changes