sae-lens 6.28.2__py3-none-any.whl → 6.32.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +14 -1
- sae_lens/analysis/__init__.py +15 -0
- sae_lens/analysis/compat.py +16 -0
- sae_lens/analysis/hooked_sae_transformer.py +1 -1
- sae_lens/analysis/sae_transformer_bridge.py +348 -0
- sae_lens/config.py +9 -1
- sae_lens/evals.py +2 -2
- sae_lens/loading/pretrained_sae_loaders.py +11 -4
- sae_lens/pretrained_saes.yaml +36 -0
- sae_lens/saes/temporal_sae.py +1 -1
- sae_lens/synthetic/__init__.py +6 -0
- sae_lens/synthetic/activation_generator.py +197 -25
- sae_lens/synthetic/correlation.py +217 -36
- sae_lens/synthetic/feature_dictionary.py +11 -2
- sae_lens/synthetic/hierarchy.py +314 -2
- sae_lens/synthetic/training.py +16 -3
- sae_lens/training/activation_scaler.py +3 -1
- {sae_lens-6.28.2.dist-info → sae_lens-6.32.1.dist-info}/METADATA +2 -2
- {sae_lens-6.28.2.dist-info → sae_lens-6.32.1.dist-info}/RECORD +21 -19
- {sae_lens-6.28.2.dist-info → sae_lens-6.32.1.dist-info}/WHEEL +1 -1
- {sae_lens-6.28.2.dist-info → sae_lens-6.32.1.dist-info}/licenses/LICENSE +0 -0
The hunks below are from `sae_lens/synthetic/correlation.py`. The first adds a NamedTuple for a low-rank correlation representation:

```diff
@@ -1,7 +1,32 @@
 import random
+from typing import NamedTuple
 
 import torch
 
+from sae_lens.util import str_to_dtype
+
+
+class LowRankCorrelationMatrix(NamedTuple):
+    """
+    Low-rank representation of a correlation matrix for scalable correlated sampling.
+
+    The correlation structure is represented as:
+        correlation = correlation_factor @ correlation_factor.T + diag(correlation_diag)
+
+    This requires O(num_features * rank) storage instead of O(num_features^2),
+    making it suitable for very large numbers of features (e.g., 1M+).
+
+    Attributes:
+        correlation_factor: Factor matrix of shape (num_features, rank) that captures
+            correlations through shared latent factors.
+        correlation_diag: Diagonal variance term of shape (num_features,). Should be
+            chosen such that the diagonal of the full correlation matrix equals 1.
+            Typically: correlation_diag[i] = 1 - sum(correlation_factor[i, :]^2)
+    """
+
+    correlation_factor: torch.Tensor
+    correlation_diag: torch.Tensor
+
 
 def create_correlation_matrix_from_correlations(
     num_features: int,
```
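The point of the low-rank form is that correlated Gaussian noise can be drawn without ever materializing the O(num_features^2) matrix. A minimal sketch of how a `LowRankCorrelationMatrix` can be consumed for sampling; `sample_correlated_gaussian` is a hypothetical helper, not a sae-lens API:

```python
import torch

def sample_correlated_gaussian(
    factor: torch.Tensor,  # correlation_factor, shape (num_features, rank)
    diag: torch.Tensor,    # correlation_diag, shape (num_features,)
    batch: int,
) -> torch.Tensor:
    """Draw N(0, C) samples with C = factor @ factor.T + diag(diag), without building C."""
    _, rank = factor.shape
    latent = torch.randn(batch, rank)          # shared latent factors
    noise = torch.randn(batch, diag.shape[0])  # independent per-feature noise
    # Cov(latent @ factor.T) = factor @ factor.T and Cov(sqrt(diag) * noise) = diag(diag);
    # the two terms are independent, so their sum has covariance C.
    return latent @ factor.T + torch.sqrt(diag) * noise
```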
```diff
@@ -11,14 +36,18 @@ def create_correlation_matrix_from_correlations(
     """
     Create a correlation matrix with specified pairwise correlations.
 
+    Note: If the resulting matrix is not positive definite, it will be adjusted
+    to ensure validity. This adjustment may change the specified correlation
+    values. To minimize this effect, use smaller correlation magnitudes.
+
     Args:
         num_features: Number of features
         correlations: Dict mapping (i, j) pairs to correlation values.
-            Pairs should have i < j.
+            Pairs should have i < j. Pairs not specified will use default_correlation.
        default_correlation: Default correlation for unspecified pairs
 
     Returns:
-        Correlation matrix of shape
+        Correlation matrix of shape (num_features, num_features)
     """
     matrix = torch.eye(num_features) + default_correlation * (
         1 - torch.eye(num_features)
```
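Putting the docstring together, a call looks like the following sketch; the argument values are illustrative, and only the three documented parameters are confirmed by this hunk:

```python
import torch
from sae_lens.synthetic.correlation import create_correlation_matrix_from_correlations

# Features 0/1 positively correlated, 2/3 negatively correlated;
# every unspecified pair falls back to default_correlation.
corr = create_correlation_matrix_from_correlations(
    num_features=4,
    correlations={(0, 1): 0.6, (2, 3): -0.4},
    default_correlation=0.0,
)
assert corr.shape == (4, 4)
assert torch.allclose(torch.diag(corr), torch.ones(4))  # unit diagonal
```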
```diff
@@ -50,6 +79,7 @@ def _fix_correlation_matrix(
     fixed_matrix = eigenvecs @ torch.diag(eigenvals) @ eigenvecs.T
 
     diag_vals = torch.diag(fixed_matrix)
+    diag_vals = torch.clamp(diag_vals, min=1e-8)  # Prevent division by zero
     fixed_matrix = fixed_matrix / torch.sqrt(
         diag_vals.unsqueeze(0) * diag_vals.unsqueeze(1)
     )
```
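The new clamp guards the renormalization step: after eigenvalue clipping, a reconstructed diagonal entry can land at or numerically below zero, and dividing by `sqrt(diag_i * diag_j)` would then yield infs or NaNs. For context, a self-contained sketch of the repair recipe the surrounding function appears to implement (clip negative eigenvalues, rebuild, rescale to unit diagonal); the eigenvalue threshold is an assumption:

```python
import torch

def fix_correlation_matrix_sketch(
    matrix: torch.Tensor, min_eigenvalue: float = 1e-6
) -> torch.Tensor:
    # Clip negative eigenvalues so the matrix becomes positive definite
    eigenvals, eigenvecs = torch.linalg.eigh(matrix)
    eigenvals = torch.clamp(eigenvals, min=min_eigenvalue)
    fixed_matrix = eigenvecs @ torch.diag(eigenvals) @ eigenvecs.T

    # Rescale so the diagonal returns to 1; the clamp prevents division by zero
    diag_vals = torch.clamp(torch.diag(fixed_matrix), min=1e-8)
    return fixed_matrix / torch.sqrt(diag_vals.unsqueeze(0) * diag_vals.unsqueeze(1))
```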
```diff
@@ -58,6 +88,25 @@
     return fixed_matrix
 
 
+def _validate_correlation_params(
+    positive_ratio: float,
+    uncorrelated_ratio: float,
+    min_correlation_strength: float,
+    max_correlation_strength: float,
+) -> None:
+    """Validate parameters for correlation generation."""
+    if not 0.0 <= positive_ratio <= 1.0:
+        raise ValueError("positive_ratio must be between 0.0 and 1.0")
+    if not 0.0 <= uncorrelated_ratio <= 1.0:
+        raise ValueError("uncorrelated_ratio must be between 0.0 and 1.0")
+    if min_correlation_strength < 0:
+        raise ValueError("min_correlation_strength must be non-negative")
+    if max_correlation_strength > 1.0:
+        raise ValueError("max_correlation_strength must be <= 1.0")
+    if min_correlation_strength > max_correlation_strength:
+        raise ValueError("min_correlation_strength must be <= max_correlation_strength")
+
+
 def generate_random_correlations(
     num_features: int,
     positive_ratio: float = 0.5,
```
```diff
@@ -71,29 +120,26 @@ def generate_random_correlations(
 
     Args:
         num_features: Number of features
-        positive_ratio: Fraction of …
-        uncorrelated_ratio: Fraction of feature pairs that should …
-        …
+        positive_ratio: Fraction of correlated pairs that should be positive (0.0 to 1.0)
+        uncorrelated_ratio: Fraction of feature pairs that should have zero correlation
+            (0.0 to 1.0). These pairs are omitted from the returned dictionary.
+        min_correlation_strength: Minimum absolute correlation strength for correlated pairs
+        max_correlation_strength: Maximum absolute correlation strength for correlated pairs
         seed: Random seed for reproducibility
 
     Returns:
-        Dictionary mapping (i, j) pairs to correlation values
+        Dictionary mapping (i, j) pairs to correlation values. Pairs with zero
+            correlation (determined by uncorrelated_ratio) are not included.
     """
     # Use local random number generator to avoid side effects on global state
     rng = random.Random(seed)
 
-    …
-    if not 0.0 <= positive_ratio <= 1.0:
-        raise ValueError("positive_ratio must be between 0.0 and 1.0")
-    if not 0.0 <= uncorrelated_ratio <= 1.0:
-        raise ValueError("uncorrelated_ratio must be between 0.0 and 1.0")
-    if min_correlation_strength < 0:
-        raise ValueError("min_correlation_strength must be non-negative")
-    if max_correlation_strength > 1.0:
-        raise ValueError("max_correlation_strength must be <= 1.0")
-    if min_correlation_strength > max_correlation_strength:
-        raise ValueError("min_correlation_strength must be <= max_correlation_strength")
+    _validate_correlation_params(
+        positive_ratio,
+        uncorrelated_ratio,
+        min_correlation_strength,
+        max_correlation_strength,
+    )
 
     # Generate all possible feature pairs (i, j) where i < j
     all_pairs = [
```
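A usage sketch for the dictionary-returning variant; defaults other than `positive_ratio=0.5` are not visible in this diff, so all arguments are passed explicitly:

```python
from sae_lens.synthetic.correlation import generate_random_correlations

correlations = generate_random_correlations(
    num_features=8,
    positive_ratio=0.7,
    uncorrelated_ratio=0.5,
    min_correlation_strength=0.1,
    max_correlation_strength=0.3,
    seed=0,
)
# Keys are (i, j) with i < j; uncorrelated pairs are simply absent.
assert all(i < j for i, j in correlations)
assert all(0.1 <= abs(v) <= 0.3 for v in correlations.values())
```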
```diff
@@ -136,35 +182,170 @@ def generate_random_correlation_matrix(
     min_correlation_strength: float = 0.1,
     max_correlation_strength: float = 0.8,
     seed: int | None = None,
+    device: torch.device | str = "cpu",
+    dtype: torch.dtype | str = torch.float32,
 ) -> torch.Tensor:
     """
     Generate a random correlation matrix with specified constraints.
 
-    …
+    Uses vectorized torch operations for efficiency with large numbers of features.
+
+    Note: If the randomly generated matrix is not positive definite, it will be
+    adjusted to ensure validity. This adjustment may change correlation values,
+    including turning some zero correlations into non-zero values. To minimize
+    this effect, use smaller correlation strengths (e.g., 0.01-0.1).
+
+    Args:
+        num_features: Number of features
+        positive_ratio: Fraction of correlated pairs that should be positive (0.0 to 1.0)
+        uncorrelated_ratio: Fraction of feature pairs that should have zero correlation
+            (0.0 to 1.0). Note that matrix fixing for positive definiteness may reduce
+            the actual number of zero correlations.
+        min_correlation_strength: Minimum absolute correlation strength for correlated pairs
+        max_correlation_strength: Maximum absolute correlation strength for correlated pairs
+        seed: Random seed for reproducibility
+        device: Device to create the matrix on
+        dtype: Data type for the matrix
+
+    Returns:
+        Random correlation matrix of shape (num_features, num_features)
+    """
+    dtype = str_to_dtype(dtype)
+    _validate_correlation_params(
+        positive_ratio,
+        uncorrelated_ratio,
+        min_correlation_strength,
+        max_correlation_strength,
+    )
+
+    if num_features <= 1:
+        return torch.eye(num_features, device=device, dtype=dtype)
+
+    # Set random seed if provided
+    generator = torch.Generator(device=device)
+    if seed is not None:
+        generator.manual_seed(seed)
+
+    # Get upper triangular indices (i < j)
+    row_idx, col_idx = torch.triu_indices(num_features, num_features, offset=1)
+    num_pairs = row_idx.shape[0]
+
+    # Generate random values for all pairs at once
+    # is_correlated: 1 if this pair should have a correlation, 0 otherwise
+    is_correlated = (
+        torch.rand(num_pairs, generator=generator, device=device) >= uncorrelated_ratio
+    )
+
+    # signs: +1 for positive correlation, -1 for negative
+    is_positive = (
+        torch.rand(num_pairs, generator=generator, device=device) < positive_ratio
+    )
+    signs = torch.where(is_positive, 1.0, -1.0)
+
+    # strengths: uniform in [min_strength, max_strength]
+    strengths = (
+        torch.rand(num_pairs, generator=generator, device=device, dtype=dtype)
+        * (max_correlation_strength - min_correlation_strength)
+        + min_correlation_strength
+    )
+
+    # Combine: correlation = is_correlated * sign * strength
+    correlations = is_correlated.to(dtype) * signs.to(dtype) * strengths
+
+    # Build the symmetric matrix
+    matrix = torch.eye(num_features, device=device, dtype=dtype)
+    matrix[row_idx, col_idx] = correlations
+    matrix[col_idx, row_idx] = correlations
+
+    # Check positive definiteness and fix if necessary
+    eigenvals = torch.linalg.eigvalsh(matrix)
+    if torch.any(eigenvals < -1e-6):
+        matrix = _fix_correlation_matrix(matrix)
+
+    return matrix
+
+
```
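A usage sketch for the rewritten dense generator; the argument values are illustrative:

```python
import torch
from sae_lens.synthetic.correlation import generate_random_correlation_matrix

corr = generate_random_correlation_matrix(
    num_features=512,
    uncorrelated_ratio=0.9,         # leave most pairs uncorrelated
    min_correlation_strength=0.02,  # small strengths make fixing less likely
    max_correlation_strength=0.1,
    seed=0,
    device="cpu",
    dtype=torch.float32,
)
assert torch.allclose(corr, corr.T, atol=1e-5)                       # symmetric
assert torch.allclose(torch.diag(corr), torch.ones(512), atol=1e-5)  # unit diagonal
print(torch.linalg.eigvalsh(corr).min())  # non-negative up to tolerance
```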
The same hunk continues with the newly added low-rank generator:

```diff
+def generate_random_low_rank_correlation_matrix(
+    num_features: int,
+    rank: int,
+    correlation_scale: float = 0.1,
+    seed: int | None = None,
+    device: torch.device | str = "cpu",
+    dtype: torch.dtype | str = torch.float32,
+) -> LowRankCorrelationMatrix:
+    """
+    Generate a random low-rank correlation structure for scalable correlated sampling.
+
+    The correlation structure is represented as:
+        correlation = factor @ factor.T + diag(diag_term)
+
+    This requires O(num_features * rank) storage instead of O(num_features^2),
+    making it suitable for very large numbers of features (e.g., 1M+).
+
+    The factor matrix is initialized with random values scaled by correlation_scale,
+    and the diagonal term is computed to ensure the implied correlation matrix has
+    unit diagonal.
 
     Args:
         num_features: Number of features
-        …
+        rank: Rank of the low-rank approximation. Higher rank allows more complex
+            correlation structures but uses more memory. Typical values: 10-100.
+        correlation_scale: Scale factor for random correlations. Larger values produce
+            stronger correlations between features. Use 0 for no correlations (identity
+            matrix). Should be small enough that rank * correlation_scale^2 < 1 to
+            ensure valid diagonal terms.
         seed: Random seed for reproducibility
+        device: Device to create tensors on
+        dtype: Data type for tensors
 
     Returns:
-        …
+        LowRankCorrelationMatrix containing the factor matrix and diagonal term
     """
-    # …
+    # Minimum diagonal value to ensure numerical stability in the covariance matrix.
+    # This limits how much variance can come from the low-rank factor.
+    _MIN_DIAG = 0.01
+
+    dtype = str_to_dtype(dtype)
+    device = torch.device(device)
+
+    if rank <= 0:
+        raise ValueError("rank must be positive")
+    if correlation_scale < 0:
+        raise ValueError("correlation_scale must be non-negative")
+
+    # Set random seed if provided
+    generator = torch.Generator(device=device)
+    if seed is not None:
+        generator.manual_seed(seed)
+
+    # Generate random factor matrix
+    # Each row has norm roughly sqrt(rank) * correlation_scale
+    factor = (
+        torch.randn(num_features, rank, generator=generator, device=device, dtype=dtype)
+        * correlation_scale
     )
 
-    # …
+    # Compute diagonal term to ensure unit diagonal in implied correlation matrix
+    # diag(factor @ factor.T) + diag_term = 1
+    # diag_term = 1 - sum(factor[i, :]^2)
+    factor_sq_sum = (factor**2).sum(dim=1)
+    diag_term = 1 - factor_sq_sum
+
+    # Ensure diagonal terms are at least _MIN_DIAG for numerical stability
+    # If any diagonal term is too small, scale down the factor matrix
+    if torch.any(diag_term < _MIN_DIAG):
+        # Scale factor so max row norm squared is at most (1 - _MIN_DIAG)
+        # This ensures all diagonal terms are >= _MIN_DIAG
+        max_factor_contribution = 1 - _MIN_DIAG
+        max_sq_sum = factor_sq_sum.max()
+        scale = torch.sqrt(
+            torch.tensor(max_factor_contribution, device=device, dtype=dtype)
+            / max_sq_sum
+        )
+        factor = factor * scale
+        factor_sq_sum = (factor**2).sum(dim=1)
+        diag_term = 1 - factor_sq_sum
+
+    return LowRankCorrelationMatrix(
+        correlation_factor=factor, correlation_diag=diag_term
     )
```
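The unit-diagonal invariant stated in the docstring can be checked directly from the returned tuple; the shapes here are illustrative:

```python
import torch
from sae_lens.synthetic.correlation import generate_random_low_rank_correlation_matrix

low_rank = generate_random_low_rank_correlation_matrix(
    num_features=100_000, rank=32, correlation_scale=0.05, seed=0
)
factor, diag = low_rank.correlation_factor, low_rank.correlation_diag
assert factor.shape == (100_000, 32)  # O(num_features * rank) storage

# diag(factor @ factor.T)[i] + diag[i] = sum(factor[i, :]^2) + diag[i] = 1
implied_diag = (factor**2).sum(dim=1) + diag
assert torch.allclose(implied_diag, torch.ones_like(implied_diag))
```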
The remaining hunks are from `sae_lens/synthetic/feature_dictionary.py`. The first swaps the tqdm import:

```diff
@@ -9,7 +9,7 @@ from typing import Callable
 
 import torch
 from torch import nn
-from tqdm import tqdm
+from tqdm.auto import tqdm
 
 FeatureDictionaryInitializer = Callable[["FeatureDictionary"], None]
 
```
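`tqdm.auto` selects the notebook widget progress bar when running under Jupyter and falls back to the plain terminal bar otherwise, so only the rendering of progress changes.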
The second teaches `FeatureDictionary.forward` to accept sparse COO inputs:

```diff
@@ -168,9 +168,18 @@ class FeatureDictionary(nn.Module):
 
         Args:
             feature_activations: Tensor of shape [batch, num_features] containing
-                sparse feature activation values
+                sparse feature activation values. Can be dense or sparse COO.
 
         Returns:
             Tensor of shape [batch, hidden_dim] containing dense hidden activations
         """
+        if feature_activations.is_sparse:
+            # autocast is disabled here because sparse matmul is not supported with bfloat16
+            with torch.autocast(
+                device_type=feature_activations.device.type, enabled=False
+            ):
+                return (
+                    torch.sparse.mm(feature_activations, self.feature_vectors)
+                    + self.bias
+                )
         return feature_activations @ self.feature_vectors + self.bias
```