sae-lens 6.28.2-py3-none-any.whl → 6.29.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.28.2"
+ __version__ = "6.29.1"

  import logging

@@ -17,11 +17,14 @@ from sae_lens.synthetic.activation_generator import (
      ActivationGenerator,
      ActivationsModifier,
      ActivationsModifierInput,
+     CorrelationMatrixInput,
  )
  from sae_lens.synthetic.correlation import (
+     LowRankCorrelationMatrix,
      create_correlation_matrix_from_correlations,
      generate_random_correlation_matrix,
      generate_random_correlations,
+     generate_random_low_rank_correlation_matrix,
  )
  from sae_lens.synthetic.evals import (
      SyntheticDataEvalResult,
@@ -66,6 +69,9 @@ __all__ = [
      "create_correlation_matrix_from_correlations",
      "generate_random_correlations",
      "generate_random_correlation_matrix",
+     "generate_random_low_rank_correlation_matrix",
+     "LowRankCorrelationMatrix",
+     "CorrelationMatrixInput",
      # Feature modifiers
      "ActivationsModifier",
      "ActivationsModifierInput",
sae_lens/synthetic/activation_generator.py CHANGED
@@ -2,17 +2,21 @@
  Functions for generating synthetic feature activations.
  """

+ import math
  from collections.abc import Callable, Sequence

  import torch
- from scipy.stats import norm
  from torch import nn
  from torch.distributions import MultivariateNormal

+ from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
  from sae_lens.util import str_to_dtype

  ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
  ActivationsModifierInput = ActivationsModifier | Sequence[ActivationsModifier] | None
+ CorrelationMatrixInput = (
+     torch.Tensor | LowRankCorrelationMatrix | tuple[torch.Tensor, torch.Tensor]
+ )


  class ActivationGenerator(nn.Module):
@@ -28,7 +32,9 @@ class ActivationGenerator(nn.Module):
      mean_firing_magnitudes: torch.Tensor
      modify_activations: ActivationsModifier | None
      correlation_matrix: torch.Tensor | None
+     low_rank_correlation: tuple[torch.Tensor, torch.Tensor] | None
      correlation_thresholds: torch.Tensor | None
+     use_sparse_tensors: bool

      def __init__(
          self,
@@ -37,10 +43,37 @@ class ActivationGenerator(nn.Module):
          std_firing_magnitudes: torch.Tensor | float = 0.0,
          mean_firing_magnitudes: torch.Tensor | float = 1.0,
          modify_activations: ActivationsModifierInput = None,
-         correlation_matrix: torch.Tensor | None = None,
+         correlation_matrix: CorrelationMatrixInput | None = None,
          device: torch.device | str = "cpu",
          dtype: torch.dtype | str = "float32",
+         use_sparse_tensors: bool = False,
      ):
+         """
+         Create a new ActivationGenerator.
+
+         Args:
+             num_features: Number of features to generate activations for.
+             firing_probabilities: Probability of each feature firing. Can be a single
+                 float (applied to all features) or a tensor of shape (num_features,).
+             std_firing_magnitudes: Standard deviation of firing magnitudes. Can be a
+                 single float or a tensor of shape (num_features,). Defaults to 0.0
+                 (deterministic magnitudes).
+             mean_firing_magnitudes: Mean firing magnitude when a feature fires. Can be
+                 a single float or a tensor of shape (num_features,). Defaults to 1.0.
+             modify_activations: Optional function(s) to modify activations after
+                 generation. Can be a single callable, a sequence of callables (applied
+                 in order), or None. Useful for applying hierarchy constraints.
+             correlation_matrix: Optional correlation structure between features. Can be:
+
+                 - A full correlation matrix tensor of shape (num_features, num_features)
+                 - A LowRankCorrelationMatrix for memory-efficient large-scale correlations
+                 - A tuple of (factor, diag) tensors representing low-rank structure
+
+             device: Device to place tensors on. Defaults to "cpu".
+             dtype: Data type for tensors. Defaults to "float32".
+             use_sparse_tensors: If True, return sparse COO tensors from sample().
+                 Only recommended when using massive numbers of features. Defaults to False.
+         """
          super().__init__()
          self.num_features = num_features
          self.firing_probabilities = _to_tensor(
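
Note: a minimal usage sketch of the new constructor options. This is not part of the package; names and defaults come from the signature and docstring above, and the sizes and probabilities are illustrative.

    import torch
    from sae_lens import ActivationGenerator, generate_random_low_rank_correlation_matrix

    # Dense path: fine for small feature counts. torch.eye is trivially a valid
    # correlation matrix (unit diagonal, positive definite).
    gen = ActivationGenerator(
        num_features=100,
        firing_probabilities=0.05,
        correlation_matrix=torch.eye(100),
    )
    dense_acts = gen.sample(batch_size=32)  # dense [32, 100] tensor

    # Low-rank path: scales to very large feature counts.
    low_rank = generate_random_low_rank_correlation_matrix(
        num_features=50_000, rank=32, correlation_scale=0.05, seed=0
    )
    big_gen = ActivationGenerator(
        num_features=50_000,
        firing_probabilities=0.001,
        correlation_matrix=low_rank,
        use_sparse_tensors=True,  # sample() then returns sparse COO tensors
    )
    sparse_acts = big_gen.sample(batch_size=8)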
@@ -54,14 +87,34 @@ class ActivationGenerator(nn.Module):
          )
          self.modify_activations = _normalize_modifiers(modify_activations)
          self.correlation_thresholds = None
+         self.correlation_matrix = None
+         self.low_rank_correlation = None
+         self.use_sparse_tensors = use_sparse_tensors
+
          if correlation_matrix is not None:
-             _validate_correlation_matrix(correlation_matrix, num_features)
-             self.correlation_thresholds = torch.tensor(
-                 [norm.ppf(1 - p.item()) for p in self.firing_probabilities],
-                 device=device,
-                 dtype=self.firing_probabilities.dtype,
+             if isinstance(correlation_matrix, torch.Tensor):
+                 # Full correlation matrix
+                 _validate_correlation_matrix(correlation_matrix, num_features)
+                 self.correlation_matrix = correlation_matrix
+             else:
+                 # Low-rank correlation (tuple or LowRankCorrelationMatrix)
+                 correlation_factor, correlation_diag = (
+                     correlation_matrix[0],
+                     correlation_matrix[1],
+                 )
+                 _validate_low_rank_correlation(
+                     correlation_factor, correlation_diag, num_features
+                 )
+                 # Pre-compute sqrt for efficiency (used every sample call)
+                 self.low_rank_correlation = (
+                     correlation_factor,
+                     correlation_diag.sqrt(),
+                 )
+
+             # Vectorized inverse normal CDF: norm.ppf(1-p) = sqrt(2) * erfinv(1 - 2*p)
+             self.correlation_thresholds = math.sqrt(2) * torch.erfinv(
+                 1 - 2 * self.firing_probabilities
              )
-         self.correlation_matrix = correlation_matrix

      @torch.no_grad()
      def sample(self, batch_size: int) -> torch.Tensor:
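
Note: a quick sanity check of the erfinv identity used above, replacing the per-element scipy loop. This is an illustrative snippet, not package code; scipy appears only for comparison against the removed implementation.

    import math

    import torch
    from scipy.stats import norm

    p = torch.tensor([0.01, 0.1, 0.5])
    thresholds = math.sqrt(2) * torch.erfinv(1 - 2 * p)
    expected = torch.tensor([norm.ppf(1 - x.item()) for x in p], dtype=torch.float32)
    assert torch.allclose(thresholds, expected, atol=1e-5)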
@@ -84,30 +137,74 @@ class ActivationGenerator(nn.Module):

          if self.correlation_matrix is not None:
              assert self.correlation_thresholds is not None
-             firing_features = _generate_correlated_features(
+             firing_indices = _generate_correlated_features(
                  batch_size,
                  self.correlation_matrix,
                  self.correlation_thresholds,
                  device,
              )
+         elif self.low_rank_correlation is not None:
+             assert self.correlation_thresholds is not None
+             firing_indices = _generate_low_rank_correlated_features(
+                 batch_size,
+                 self.low_rank_correlation[0],
+                 self.low_rank_correlation[1],
+                 self.correlation_thresholds,
+                 device,
+             )
          else:
-             firing_features = torch.bernoulli(
+             firing_indices = torch.bernoulli(
                  self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
+             ).nonzero(as_tuple=True)
+
+         # Compute activations only at firing positions (sparse optimization)
+         feature_indices = firing_indices[1]
+         num_firing = feature_indices.shape[0]
+         mean_at_firing = self.mean_firing_magnitudes[feature_indices]
+         std_at_firing = self.std_firing_magnitudes[feature_indices]
+         random_deltas = (
+             torch.randn(
+                 num_firing, device=device, dtype=self.mean_firing_magnitudes.dtype
              )
-
-         firing_magnitude_delta = torch.normal(
-             torch.zeros_like(self.firing_probabilities)
-             .unsqueeze(0)
-             .expand(batch_size, -1),
-             self.std_firing_magnitudes.unsqueeze(0).expand(batch_size, -1),
+             * std_at_firing
          )
-         firing_magnitude_delta[firing_features == 0] = 0
-         feature_activations = (
-             firing_features * self.mean_firing_magnitudes + firing_magnitude_delta
-         ).relu()
+         activations_at_firing = (mean_at_firing + random_deltas).relu()
+
+         if self.use_sparse_tensors:
+             # Return sparse COO tensor
+             indices = torch.stack(firing_indices)  # [2, nnz]
+             feature_activations = torch.sparse_coo_tensor(
+                 indices,
+                 activations_at_firing,
+                 size=(batch_size, self.num_features),
+                 device=device,
+                 dtype=self.mean_firing_magnitudes.dtype,
+             )
+         else:
+             # Dense tensor path
+             feature_activations = torch.zeros(
+                 batch_size,
+                 self.num_features,
+                 device=device,
+                 dtype=self.mean_firing_magnitudes.dtype,
+             )
+             feature_activations[firing_indices] = activations_at_firing

          if self.modify_activations is not None:
-             feature_activations = self.modify_activations(feature_activations).relu()
+             feature_activations = self.modify_activations(feature_activations)
+             if feature_activations.is_sparse:
+                 # Apply relu to sparse values
+                 feature_activations = feature_activations.coalesce()
+                 feature_activations = torch.sparse_coo_tensor(
+                     feature_activations.indices(),
+                     feature_activations.values().relu(),
+                     feature_activations.shape,
+                     device=feature_activations.device,
+                     dtype=feature_activations.dtype,
+                 )
+             else:
+                 feature_activations = feature_activations.relu()
+
          return feature_activations

      def forward(self, batch_size: int) -> torch.Tensor:
@@ -119,7 +216,7 @@ def _generate_correlated_features(
      correlation_matrix: torch.Tensor,
      thresholds: torch.Tensor,
      device: torch.device,
- ) -> torch.Tensor:
+ ) -> tuple[torch.Tensor, torch.Tensor]:
      """
      Generate correlated binary features using multivariate Gaussian sampling.

@@ -133,7 +230,7 @@ def _generate_correlated_features(
          device: Device to generate samples on

      Returns:
-         Binary feature matrix of shape [batch_size, num_features]
+         Tuple of (row_indices, col_indices) for firing features
      """
      num_features = correlation_matrix.shape[0]

@@ -143,7 +240,49 @@ def _generate_correlated_features(
      )

      gaussian_samples = mvn.sample((batch_size,))
-     return (gaussian_samples > thresholds.unsqueeze(0)).float()
+     indices = (gaussian_samples > thresholds.unsqueeze(0)).nonzero(as_tuple=True)
+     return indices[0], indices[1]
+
+
+ def _generate_low_rank_correlated_features(
+     batch_size: int,
+     correlation_factor: torch.Tensor,
+     cov_diag_sqrt: torch.Tensor,
+     thresholds: torch.Tensor,
+     device: torch.device,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     """
+     Generate correlated binary features using low-rank multivariate Gaussian sampling.
+
+     Uses the Gaussian copula approach with a low-rank covariance structure for scalability.
+     The covariance is represented as: cov = factor @ factor.T + diag(diag_term)
+
+     Args:
+         batch_size: Number of samples to generate
+         correlation_factor: Factor matrix of shape (num_features, rank)
+         cov_diag_sqrt: Pre-computed sqrt of diagonal term, shape (num_features,)
+         thresholds: Pre-computed thresholds for each feature (from inverse normal CDF)
+         device: Device to generate samples on
+
+     Returns:
+         Tuple of (row_indices, col_indices) for firing features
+     """
+     # Manual low-rank MVN sampling to enable autocast for the expensive matmul
+     # samples = eps @ cov_factor.T + eta * sqrt(cov_diag)
+     # where eps ~ N(0, I_rank) and eta ~ N(0, I_n)
+
+     num_features, rank = correlation_factor.shape
+
+     # Generate random samples in float32 for numerical stability
+     eps = torch.randn(batch_size, rank, device=device, dtype=correlation_factor.dtype)
+     eta = torch.randn(
+         batch_size, num_features, device=device, dtype=cov_diag_sqrt.dtype
+     )
+
+     gaussian_samples = eps @ correlation_factor.T + eta * cov_diag_sqrt
+
+     indices = (gaussian_samples > thresholds.unsqueeze(0)).nonzero(as_tuple=True)
+     return indices[0], indices[1]


  def _to_tensor(
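
Note: an illustrative check, not package code, that the manual sampling above has the intended covariance factor @ factor.T + diag(diag_term); the sizes and scale are arbitrary and the tolerance accounts for Monte Carlo error.

    import torch

    torch.manual_seed(0)
    n, rank, n_samples = 6, 2, 200_000
    factor = 0.1 * torch.randn(n, rank)
    diag = 1 - (factor**2).sum(dim=1)  # unit-diagonal convention

    eps = torch.randn(n_samples, rank)
    eta = torch.randn(n_samples, n)
    samples = eps @ factor.T + eta * diag.sqrt()

    empirical = samples.T @ samples / n_samples
    implied = factor @ factor.T + torch.diag(diag)
    assert torch.allclose(empirical, implied, atol=0.05)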
@@ -194,7 +333,7 @@ def _validate_correlation_matrix(

      Args:
          correlation_matrix: The matrix to validate
-         num_features: Expected number of features (matrix should be [num_features, num_features])
+         num_features: Expected number of features (matrix should be (num_features, num_features))

      Raises:
          ValueError: If the matrix has incorrect shape, non-unit diagonal, or is not positive definite
@@ -214,3 +353,36 @@ def _validate_correlation_matrix(
          torch.linalg.cholesky(correlation_matrix)
      except RuntimeError as e:
          raise ValueError("Correlation matrix must be positive definite") from e
+
+
+ def _validate_low_rank_correlation(
+     correlation_factor: torch.Tensor,
+     correlation_diag: torch.Tensor,
+     num_features: int,
+ ) -> None:
+     """Validate that low-rank correlation parameters have correct properties.
+
+     Args:
+         correlation_factor: Factor matrix of shape (num_features, rank)
+         correlation_diag: Diagonal term of shape (num_features,)
+         num_features: Expected number of features
+
+     Raises:
+         ValueError: If shapes are incorrect or diagonal terms are not positive
+     """
+     if correlation_factor.ndim != 2:
+         raise ValueError(
+             f"correlation_factor must be 2D, got {correlation_factor.ndim}D"
+         )
+     if correlation_factor.shape[0] != num_features:
+         raise ValueError(
+             f"correlation_factor must have shape ({num_features}, rank), "
+             f"got {tuple(correlation_factor.shape)}"
+         )
+     if correlation_diag.shape != (num_features,):
+         raise ValueError(
+             f"correlation_diag must have shape ({num_features},), "
+             f"got {tuple(correlation_diag.shape)}"
+         )
+     if torch.any(correlation_diag <= 0):
+         raise ValueError("correlation_diag must have all positive values")
sae_lens/synthetic/correlation.py CHANGED
@@ -1,7 +1,32 @@
  import random
+ from typing import NamedTuple

  import torch

+ from sae_lens.util import str_to_dtype
+
+
+ class LowRankCorrelationMatrix(NamedTuple):
+     """
+     Low-rank representation of a correlation matrix for scalable correlated sampling.
+
+     The correlation structure is represented as:
+         correlation = correlation_factor @ correlation_factor.T + diag(correlation_diag)
+
+     This requires O(num_features * rank) storage instead of O(num_features^2),
+     making it suitable for very large numbers of features (e.g., 1M+).
+
+     Attributes:
+         correlation_factor: Factor matrix of shape (num_features, rank) that captures
+             correlations through shared latent factors.
+         correlation_diag: Diagonal variance term of shape (num_features,). Should be
+             chosen such that the diagonal of the full correlation matrix equals 1.
+             Typically: correlation_diag[i] = 1 - sum(correlation_factor[i, :]^2)
+     """
+
+     correlation_factor: torch.Tensor
+     correlation_diag: torch.Tensor
+

  def create_correlation_matrix_from_correlations(
      num_features: int,
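
Note: because LowRankCorrelationMatrix is a NamedTuple, it unpacks like a plain (factor, diag) tuple, which is what lets ActivationGenerator index correlation_matrix[0] and correlation_matrix[1] for either input form. A small illustrative check, not package code:

    import torch

    factor = torch.full((4, 1), 0.5)
    diag = 1 - (factor**2).sum(dim=1)  # 0.75 each, keeping the diagonal at 1
    low_rank = LowRankCorrelationMatrix(correlation_factor=factor, correlation_diag=diag)

    f, d = low_rank  # positional unpacking, same as a plain tuple
    dense = f @ f.T + torch.diag(d)
    assert torch.allclose(torch.diagonal(dense), torch.ones(4))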
@@ -11,14 +36,18 @@ def create_correlation_matrix_from_correlations(
      """
      Create a correlation matrix with specified pairwise correlations.

+     Note: If the resulting matrix is not positive definite, it will be adjusted
+     to ensure validity. This adjustment may change the specified correlation
+     values. To minimize this effect, use smaller correlation magnitudes.
+
      Args:
          num_features: Number of features
          correlations: Dict mapping (i, j) pairs to correlation values.
-             Pairs should have i < j.
+             Pairs should have i < j. Pairs not specified will use default_correlation.
          default_correlation: Default correlation for unspecified pairs

      Returns:
-         Correlation matrix of shape [num_features, num_features]
+         Correlation matrix of shape (num_features, num_features)
      """
      matrix = torch.eye(num_features) + default_correlation * (
          1 - torch.eye(num_features)
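
Note: an illustrative call, not package code; it assumes default_correlation defaults to 0.0, consistent with the "unspecified pairs" wording above.

    from sae_lens import create_correlation_matrix_from_correlations

    # Features 0 and 1 co-fire; features 0 and 2 anti-correlate; every other
    # pair falls back to the default correlation.
    corr = create_correlation_matrix_from_correlations(
        num_features=3,
        correlations={(0, 1): 0.6, (0, 2): -0.3},
    )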
@@ -50,6 +79,7 @@ def _fix_correlation_matrix(
      fixed_matrix = eigenvecs @ torch.diag(eigenvals) @ eigenvecs.T

      diag_vals = torch.diag(fixed_matrix)
+     diag_vals = torch.clamp(diag_vals, min=1e-8)  # Prevent division by zero
      fixed_matrix = fixed_matrix / torch.sqrt(
          diag_vals.unsqueeze(0) * diag_vals.unsqueeze(1)
      )
@@ -58,6 +88,25 @@ def _fix_correlation_matrix(
      return fixed_matrix


+ def _validate_correlation_params(
+     positive_ratio: float,
+     uncorrelated_ratio: float,
+     min_correlation_strength: float,
+     max_correlation_strength: float,
+ ) -> None:
+     """Validate parameters for correlation generation."""
+     if not 0.0 <= positive_ratio <= 1.0:
+         raise ValueError("positive_ratio must be between 0.0 and 1.0")
+     if not 0.0 <= uncorrelated_ratio <= 1.0:
+         raise ValueError("uncorrelated_ratio must be between 0.0 and 1.0")
+     if min_correlation_strength < 0:
+         raise ValueError("min_correlation_strength must be non-negative")
+     if max_correlation_strength > 1.0:
+         raise ValueError("max_correlation_strength must be <= 1.0")
+     if min_correlation_strength > max_correlation_strength:
+         raise ValueError("min_correlation_strength must be <= max_correlation_strength")
+
+
  def generate_random_correlations(
      num_features: int,
      positive_ratio: float = 0.5,
@@ -71,29 +120,26 @@ def generate_random_correlations(

      Args:
          num_features: Number of features
-         positive_ratio: Fraction of correlations that should be positive (0.0 to 1.0)
-         uncorrelated_ratio: Fraction of feature pairs that should remain uncorrelated (0.0 to 1.0)
-         min_correlation_strength: Minimum absolute correlation strength
-         max_correlation_strength: Maximum absolute correlation strength
+         positive_ratio: Fraction of correlated pairs that should be positive (0.0 to 1.0)
+         uncorrelated_ratio: Fraction of feature pairs that should have zero correlation
+             (0.0 to 1.0). These pairs are omitted from the returned dictionary.
+         min_correlation_strength: Minimum absolute correlation strength for correlated pairs
+         max_correlation_strength: Maximum absolute correlation strength for correlated pairs
          seed: Random seed for reproducibility

      Returns:
-         Dictionary mapping (i, j) pairs to correlation values
+         Dictionary mapping (i, j) pairs to correlation values. Pairs with zero
+             correlation (determined by uncorrelated_ratio) are not included.
      """
      # Use local random number generator to avoid side effects on global state
      rng = random.Random(seed)

-     # Validate inputs
-     if not 0.0 <= positive_ratio <= 1.0:
-         raise ValueError("positive_ratio must be between 0.0 and 1.0")
-     if not 0.0 <= uncorrelated_ratio <= 1.0:
-         raise ValueError("uncorrelated_ratio must be between 0.0 and 1.0")
-     if min_correlation_strength < 0:
-         raise ValueError("min_correlation_strength must be non-negative")
-     if max_correlation_strength > 1.0:
-         raise ValueError("max_correlation_strength must be <= 1.0")
-     if min_correlation_strength > max_correlation_strength:
-         raise ValueError("min_correlation_strength must be <= max_correlation_strength")
+     _validate_correlation_params(
+         positive_ratio,
+         uncorrelated_ratio,
+         min_correlation_strength,
+         max_correlation_strength,
+     )

      # Generate all possible feature pairs (i, j) where i < j
      all_pairs = [
@@ -136,35 +182,170 @@ def generate_random_correlation_matrix(
      min_correlation_strength: float = 0.1,
      max_correlation_strength: float = 0.8,
      seed: int | None = None,
+     device: torch.device | str = "cpu",
+     dtype: torch.dtype | str = torch.float32,
  ) -> torch.Tensor:
      """
      Generate a random correlation matrix with specified constraints.

-     This is a convenience function that combines generate_random_correlations()
-     and create_correlation_matrix_from_correlations() into a single call.
+     Uses vectorized torch operations for efficiency with large numbers of features.
+
+     Note: If the randomly generated matrix is not positive definite, it will be
+     adjusted to ensure validity. This adjustment may change correlation values,
+     including turning some zero correlations into non-zero values. To minimize
+     this effect, use smaller correlation strengths (e.g., 0.01-0.1).
+
+     Args:
+         num_features: Number of features
+         positive_ratio: Fraction of correlated pairs that should be positive (0.0 to 1.0)
+         uncorrelated_ratio: Fraction of feature pairs that should have zero correlation
+             (0.0 to 1.0). Note that matrix fixing for positive definiteness may reduce
+             the actual number of zero correlations.
+         min_correlation_strength: Minimum absolute correlation strength for correlated pairs
+         max_correlation_strength: Maximum absolute correlation strength for correlated pairs
+         seed: Random seed for reproducibility
+         device: Device to create the matrix on
+         dtype: Data type for the matrix
+
+     Returns:
+         Random correlation matrix of shape (num_features, num_features)
+     """
+     dtype = str_to_dtype(dtype)
+     _validate_correlation_params(
+         positive_ratio,
+         uncorrelated_ratio,
+         min_correlation_strength,
+         max_correlation_strength,
+     )
+
+     if num_features <= 1:
+         return torch.eye(num_features, device=device, dtype=dtype)
+
+     # Set random seed if provided
+     generator = torch.Generator(device=device)
+     if seed is not None:
+         generator.manual_seed(seed)
+
+     # Get upper triangular indices (i < j)
+     row_idx, col_idx = torch.triu_indices(num_features, num_features, offset=1)
+     num_pairs = row_idx.shape[0]
+
+     # Generate random values for all pairs at once
+     # is_correlated: 1 if this pair should have a correlation, 0 otherwise
+     is_correlated = (
+         torch.rand(num_pairs, generator=generator, device=device) >= uncorrelated_ratio
+     )
+
+     # signs: +1 for positive correlation, -1 for negative
+     is_positive = (
+         torch.rand(num_pairs, generator=generator, device=device) < positive_ratio
+     )
+     signs = torch.where(is_positive, 1.0, -1.0)
+
+     # strengths: uniform in [min_strength, max_strength]
+     strengths = (
+         torch.rand(num_pairs, generator=generator, device=device, dtype=dtype)
+         * (max_correlation_strength - min_correlation_strength)
+         + min_correlation_strength
+     )
+
+     # Combine: correlation = is_correlated * sign * strength
+     correlations = is_correlated.to(dtype) * signs.to(dtype) * strengths
+
+     # Build the symmetric matrix
+     matrix = torch.eye(num_features, device=device, dtype=dtype)
+     matrix[row_idx, col_idx] = correlations
+     matrix[col_idx, row_idx] = correlations
+
+     # Check positive definiteness and fix if necessary
+     eigenvals = torch.linalg.eigvalsh(matrix)
+     if torch.any(eigenvals < -1e-6):
+         matrix = _fix_correlation_matrix(matrix)
+
+     return matrix
+
+
+ def generate_random_low_rank_correlation_matrix(
+     num_features: int,
+     rank: int,
+     correlation_scale: float = 0.1,
+     seed: int | None = None,
+     device: torch.device | str = "cpu",
+     dtype: torch.dtype | str = torch.float32,
+ ) -> LowRankCorrelationMatrix:
+     """
+     Generate a random low-rank correlation structure for scalable correlated sampling.
+
+     The correlation structure is represented as:
+         correlation = factor @ factor.T + diag(diag_term)
+
+     This requires O(num_features * rank) storage instead of O(num_features^2),
+     making it suitable for very large numbers of features (e.g., 1M+).
+
+     The factor matrix is initialized with random values scaled by correlation_scale,
+     and the diagonal term is computed to ensure the implied correlation matrix has
+     unit diagonal.

      Args:
          num_features: Number of features
-         positive_ratio: Fraction of correlations that should be positive (0.0 to 1.0)
-         uncorrelated_ratio: Fraction of feature pairs that should remain uncorrelated (0.0 to 1.0)
-         min_correlation_strength: Minimum absolute correlation strength
-         max_correlation_strength: Maximum absolute correlation strength
+         rank: Rank of the low-rank approximation. Higher rank allows more complex
+             correlation structures but uses more memory. Typical values: 10-100.
+         correlation_scale: Scale factor for random correlations. Larger values produce
+             stronger correlations between features. Use 0 for no correlations (identity
+             matrix). Should be small enough that rank * correlation_scale^2 < 1 to
+             ensure valid diagonal terms.
          seed: Random seed for reproducibility
+         device: Device to create tensors on
+         dtype: Data type for tensors

      Returns:
-         Random correlation matrix of shape [num_features, num_features]
+         LowRankCorrelationMatrix containing the factor matrix and diagonal term
      """
-     # Generate random correlations
-     correlations = generate_random_correlations(
-         num_features=num_features,
-         positive_ratio=positive_ratio,
-         uncorrelated_ratio=uncorrelated_ratio,
-         min_correlation_strength=min_correlation_strength,
-         max_correlation_strength=max_correlation_strength,
-         seed=seed,
+     # Minimum diagonal value to ensure numerical stability in the covariance matrix.
+     # This limits how much variance can come from the low-rank factor.
+     _MIN_DIAG = 0.01
+
+     dtype = str_to_dtype(dtype)
+     device = torch.device(device)
+
+     if rank <= 0:
+         raise ValueError("rank must be positive")
+     if correlation_scale < 0:
+         raise ValueError("correlation_scale must be non-negative")
+
+     # Set random seed if provided
+     generator = torch.Generator(device=device)
+     if seed is not None:
+         generator.manual_seed(seed)
+
+     # Generate random factor matrix
+     # Each row has norm roughly sqrt(rank) * correlation_scale
+     factor = (
+         torch.randn(num_features, rank, generator=generator, device=device, dtype=dtype)
+         * correlation_scale
      )

-     # Create and return correlation matrix
-     return create_correlation_matrix_from_correlations(
-         num_features=num_features, correlations=correlations
+     # Compute diagonal term to ensure unit diagonal in implied correlation matrix
+     # diag(factor @ factor.T) + diag_term = 1
+     # diag_term = 1 - sum(factor[i, :]^2)
+     factor_sq_sum = (factor**2).sum(dim=1)
+     diag_term = 1 - factor_sq_sum
+
+     # Ensure diagonal terms are at least _MIN_DIAG for numerical stability
+     # If any diagonal term is too small, scale down the factor matrix
+     if torch.any(diag_term < _MIN_DIAG):
+         # Scale factor so max row norm squared is at most (1 - _MIN_DIAG)
+         # This ensures all diagonal terms are >= _MIN_DIAG
+         max_factor_contribution = 1 - _MIN_DIAG
+         max_sq_sum = factor_sq_sum.max()
+         scale = torch.sqrt(
+             torch.tensor(max_factor_contribution, device=device, dtype=dtype)
+             / max_sq_sum
+         )
+         factor = factor * scale
+         factor_sq_sum = (factor**2).sum(dim=1)
+         diag_term = 1 - factor_sq_sum
+
+     return LowRankCorrelationMatrix(
+         correlation_factor=factor, correlation_diag=diag_term
      )
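
Note: a sketch of what the returned structure guarantees (illustrative, not package code): the implied correlation matrix has unit diagonal by construction, and the chosen rank and scale satisfy rank * correlation_scale^2 < 1.

    import torch
    from sae_lens import generate_random_low_rank_correlation_matrix

    lr = generate_random_low_rank_correlation_matrix(
        num_features=512, rank=16, correlation_scale=0.1, seed=0
    )
    print(lr.correlation_factor.shape)  # torch.Size([512, 16])
    implied_diag = (lr.correlation_factor**2).sum(dim=1) + lr.correlation_diag
    assert torch.allclose(implied_diag, torch.ones(512))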
sae_lens/synthetic/feature_dictionary.py CHANGED
@@ -9,7 +9,7 @@ from typing import Callable

  import torch
  from torch import nn
- from tqdm import tqdm
+ from tqdm.auto import tqdm

  FeatureDictionaryInitializer = Callable[["FeatureDictionary"], None]

@@ -168,9 +168,18 @@ class FeatureDictionary(nn.Module):

          Args:
              feature_activations: Tensor of shape [batch, num_features] containing
-                 sparse feature activation values
+                 sparse feature activation values. Can be dense or sparse COO.

          Returns:
              Tensor of shape [batch, hidden_dim] containing dense hidden activations
          """
+         if feature_activations.is_sparse:
+             # autocast is disabled here because sparse matmul is not supported with bfloat16
+             with torch.autocast(
+                 device_type=feature_activations.device.type, enabled=False
+             ):
+                 return (
+                     torch.sparse.mm(feature_activations, self.feature_vectors)
+                     + self.bias
+                 )
          return feature_activations @ self.feature_vectors + self.bias
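
Note: the sparse branch is numerically the same operation as the dense one; torch.sparse.mm of a sparse COO matrix with a dense matrix returns a dense result. A minimal illustration, not package code:

    import torch

    dense = torch.tensor([[0.0, 2.0, 0.0], [1.0, 0.0, 0.0]])
    weights = torch.randn(3, 5)
    sparse = dense.to_sparse()

    out_sparse = torch.sparse.mm(sparse, weights)  # dense [2, 5] tensor
    assert torch.allclose(out_sparse, dense @ weights)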
sae_lens/synthetic/hierarchy.py CHANGED
@@ -147,6 +147,14 @@ class _SparseHierarchyData:
      # Total number of ME groups
      num_groups: int

+     # Sparse COO support: Feature-to-parent mapping
+     # feat_to_parent[f] = parent feature index, or -1 if root/no parent
+     feat_to_parent: torch.Tensor | None = None  # [num_features]
+
+     # Sparse COO support: Feature-to-ME-group mapping
+     # feat_to_me_group[f] = group index, or -1 if not in any ME group
+     feat_to_me_group: torch.Tensor | None = None  # [num_features]
+

  def _build_sparse_hierarchy(
      roots: Sequence[HierarchyNode],
@@ -232,7 +240,11 @@ def _build_sparse_hierarchy(
              me_indices = torch.empty(0, dtype=torch.long)

          level_data.append(
-             _LevelData(features=feats, parents=parents, me_group_indices=me_indices)
+             _LevelData(
+                 features=feats,
+                 parents=parents,
+                 me_group_indices=me_indices,
+             )
          )

      # Build group siblings and parents tensors
@@ -254,12 +266,30 @@ def _build_sparse_hierarchy(
          me_group_parents = torch.empty(0, dtype=torch.long)
          num_groups = 0

+     # Build sparse COO support: feat_to_parent and feat_to_me_group mappings
+     # First determine num_features (max feature index + 1)
+     all_features = [f for f, _, _ in feature_info]
+     num_features = max(all_features) + 1 if all_features else 0
+
+     # Build feature-to-parent mapping
+     feat_to_parent = torch.full((num_features,), -1, dtype=torch.long)
+     for feat, parent, _ in feature_info:
+         feat_to_parent[feat] = parent
+
+     # Build feature-to-ME-group mapping
+     feat_to_me_group = torch.full((num_features,), -1, dtype=torch.long)
+     for g_idx, (_, _, siblings) in enumerate(me_groups):
+         for sib in siblings:
+             feat_to_me_group[sib] = g_idx
+
      return _SparseHierarchyData(
          level_data=level_data,
          me_group_siblings=me_group_siblings,
          me_group_sizes=me_group_sizes,
          me_group_parents=me_group_parents,
          num_groups=num_groups,
+         feat_to_parent=feat_to_parent,
+         feat_to_me_group=feat_to_me_group,
      )


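Note: for intuition, here is what the two lookup tables might contain for a hypothetical 5-feature hierarchy in which feature 0 is the parent of features 1 and 2, and {1, 2} form ME group 0 (illustrative values, not package code):

    import torch

    feat_to_parent = torch.tensor([-1, 0, 0, -1, -1])    # -1 = root / no parent
    feat_to_me_group = torch.tensor([-1, 0, 0, -1, -1])  # -1 = not in any ME group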
@@ -396,8 +426,9 @@ def _apply_me_for_groups(
      # Random selection for winner
      # Use -1e9 instead of -inf to avoid creating a tensor (torch.tensor(-float("inf")))
      # on every call. Since random scores are in [0,1], -1e9 is effectively -inf for argmax.
+     _INACTIVE_SCORE = -1e9
      random_scores = torch.rand(num_conflicts, max_siblings, device=device)
-     random_scores[~conflict_active] = -1e9
+     random_scores[~conflict_active] = _INACTIVE_SCORE

      winner_idx = random_scores.argmax(dim=1)

@@ -420,6 +451,275 @@ def _apply_me_for_groups(
          activations[deact_batch, deact_feat] = 0


+ # ---------------------------------------------------------------------------
+ # Sparse COO hierarchy implementation
+ # ---------------------------------------------------------------------------
+
+
+ def _apply_hierarchy_sparse_coo(
+     sparse_tensor: torch.Tensor,
+     sparse_data: _SparseHierarchyData,
+ ) -> torch.Tensor:
+     """
+     Apply hierarchy constraints to a sparse COO tensor.
+
+     This is the sparse analog of _apply_hierarchy_sparse. It processes
+     level-by-level, applying parent deactivation then mutual exclusion.
+     """
+     if sparse_tensor._nnz() == 0:
+         return sparse_tensor
+
+     sparse_tensor = sparse_tensor.coalesce()
+
+     for level_data in sparse_data.level_data:
+         # Step 1: Apply parent deactivation for features at this level
+         if level_data.features.numel() > 0:
+             sparse_tensor = _apply_parent_deactivation_coo(
+                 sparse_tensor, level_data, sparse_data
+             )
+
+         # Step 2: Apply ME for groups whose parent is at this level
+         if level_data.me_group_indices.numel() > 0:
+             sparse_tensor = _apply_me_coo(
+                 sparse_tensor, level_data.me_group_indices, sparse_data
+             )
+
+     return sparse_tensor
+
+
+ def _apply_parent_deactivation_coo(
+     sparse_tensor: torch.Tensor,
+     level_data: _LevelData,
+     sparse_data: _SparseHierarchyData,
+ ) -> torch.Tensor:
+     """
+     Remove children from sparse COO tensor when their parent is inactive.
+
+     Uses searchsorted for efficient membership testing of parent activity.
+     """
+     if sparse_tensor._nnz() == 0 or level_data.features.numel() == 0:
+         return sparse_tensor
+
+     sparse_tensor = sparse_tensor.coalesce()
+     indices = sparse_tensor.indices()  # [2, nnz]
+     values = sparse_tensor.values()  # [nnz]
+     batch_indices = indices[0]
+     feat_indices = indices[1]
+
+     _, num_features = sparse_tensor.shape
+     device = sparse_tensor.device
+     nnz = indices.shape[1]
+
+     # Build set of active (batch, feature) pairs for efficient lookup
+     # Encode as: batch_idx * num_features + feat_idx
+     active_pairs = batch_indices * num_features + feat_indices
+     active_pairs_sorted, _ = active_pairs.sort()
+
+     # Use the precomputed feat_to_parent mapping
+     assert sparse_data.feat_to_parent is not None
+     hierarchy_num_features = sparse_data.feat_to_parent.numel()
+
+     # Handle features outside the hierarchy (they have no parent, pass through)
+     in_hierarchy = feat_indices < hierarchy_num_features
+     parent_of_feat = torch.full((nnz,), -1, dtype=torch.long, device=device)
+     parent_of_feat[in_hierarchy] = sparse_data.feat_to_parent[
+         feat_indices[in_hierarchy]
+     ]
+
+     # Find entries that have a parent (parent >= 0 means this feature has a parent)
+     has_parent = parent_of_feat >= 0
+
+     if not has_parent.any():
+         return sparse_tensor
+
+     # For entries with parents, check if parent is active
+     child_entry_indices = torch.where(has_parent)[0]
+     child_batch = batch_indices[has_parent]
+     child_parents = parent_of_feat[has_parent]
+
+     # Look up parent activity using searchsorted
+     parent_pairs = child_batch * num_features + child_parents
+     search_pos = torch.searchsorted(active_pairs_sorted, parent_pairs)
+     search_pos = search_pos.clamp(max=active_pairs_sorted.numel() - 1)
+     parent_active = active_pairs_sorted[search_pos] == parent_pairs
+
+     # Handle empty case
+     if active_pairs_sorted.numel() == 0:
+         parent_active = torch.zeros_like(parent_pairs, dtype=torch.bool)
+
+     # Build keep mask: keep entry if it's a root OR its parent is active
+     keep_mask = torch.ones(nnz, dtype=torch.bool, device=device)
+     keep_mask[child_entry_indices[~parent_active]] = False
+
+     if keep_mask.all():
+         return sparse_tensor
+
+     return torch.sparse_coo_tensor(
+         indices[:, keep_mask],
+         values[keep_mask],
+         sparse_tensor.shape,
+         device=device,
+         dtype=sparse_tensor.dtype,
+     )
+
+
+ def _apply_me_coo(
+     sparse_tensor: torch.Tensor,
+     group_indices: torch.Tensor,
+     sparse_data: _SparseHierarchyData,
+ ) -> torch.Tensor:
+     """
+     Apply mutual exclusion to sparse COO tensor.
+
+     For each ME group with multiple active siblings in the same batch,
+     randomly selects one winner and removes the rest.
+     """
+     if sparse_tensor._nnz() == 0 or group_indices.numel() == 0:
+         return sparse_tensor
+
+     sparse_tensor = sparse_tensor.coalesce()
+     indices = sparse_tensor.indices()  # [2, nnz]
+     values = sparse_tensor.values()  # [nnz]
+     batch_indices = indices[0]
+     feat_indices = indices[1]
+
+     _, num_features = sparse_tensor.shape
+     device = sparse_tensor.device
+     nnz = indices.shape[1]
+
+     # Use precomputed feat_to_me_group mapping
+     assert sparse_data.feat_to_me_group is not None
+     hierarchy_num_features = sparse_data.feat_to_me_group.numel()
+
+     # Handle features outside the hierarchy (they are not in any ME group)
+     in_hierarchy = feat_indices < hierarchy_num_features
+     me_group_of_feat = torch.full((nnz,), -1, dtype=torch.long, device=device)
+     me_group_of_feat[in_hierarchy] = sparse_data.feat_to_me_group[
+         feat_indices[in_hierarchy]
+     ]
+
+     # Find entries that belong to ME groups we're processing (vectorized)
+     in_relevant_group = torch.isin(me_group_of_feat, group_indices)
+
+     if not in_relevant_group.any():
+         return sparse_tensor
+
+     # Get the ME entries
+     me_entry_indices = torch.where(in_relevant_group)[0]
+     me_batch = batch_indices[in_relevant_group]
+     me_group = me_group_of_feat[in_relevant_group]
+
+     # Check parent activity for ME groups (only apply ME if parent is active)
+     me_group_parents = sparse_data.me_group_parents[me_group]
+     has_parent = me_group_parents >= 0
+
+     if has_parent.any():
+         # Build active pairs for parent lookup
+         active_pairs = batch_indices * num_features + feat_indices
+         active_pairs_sorted, _ = active_pairs.sort()
+
+         parent_pairs = (
+             me_batch[has_parent] * num_features + me_group_parents[has_parent]
+         )
+         search_pos = torch.searchsorted(active_pairs_sorted, parent_pairs)
+         search_pos = search_pos.clamp(max=active_pairs_sorted.numel() - 1)
+         parent_active_for_has_parent = active_pairs_sorted[search_pos] == parent_pairs
+
+         # Build full parent_active mask
+         parent_active = torch.ones(
+             me_entry_indices.numel(), dtype=torch.bool, device=device
+         )
+         parent_active[has_parent] = parent_active_for_has_parent
+
+         # Filter to only ME entries where parent is active
+         valid_me = parent_active
+         me_entry_indices = me_entry_indices[valid_me]
+         me_batch = me_batch[valid_me]
+         me_group = me_group[valid_me]
+
+     if me_entry_indices.numel() == 0:
+         return sparse_tensor
+
+     # Encode (batch, group) pairs
+     num_groups = sparse_data.num_groups
+     batch_group_pairs = me_batch * num_groups + me_group
+
+     # Find unique (batch, group) pairs and count occurrences
+     unique_bg, inverse, counts = torch.unique(
+         batch_group_pairs, return_inverse=True, return_counts=True
+     )
+
+     # Only process pairs with count > 1 (conflicts)
+     has_conflict = counts > 1
+
+     if not has_conflict.any():
+         return sparse_tensor
+
+     # For efficiency, we process all conflicts together
+     # Assign random scores to each ME entry
+     random_scores = torch.rand(me_entry_indices.numel(), device=device)
+
+     # For each (batch, group) pair, we want the entry with highest score to be winner
+     # Use scatter_reduce to find max score per (batch, group)
+     bg_to_dense = torch.zeros(unique_bg.numel(), dtype=torch.long, device=device)
+     bg_to_dense[has_conflict.nonzero(as_tuple=True)[0]] = torch.arange(
+         has_conflict.sum(), device=device
+     )
+
+     # Map each ME entry to its dense conflict index
+     entry_has_conflict = has_conflict[inverse]
+
+     if not entry_has_conflict.any():
+         return sparse_tensor
+
+     conflict_entries_mask = entry_has_conflict
+     conflict_entry_indices = me_entry_indices[conflict_entries_mask]
+     conflict_random_scores = random_scores[conflict_entries_mask]
+     conflict_inverse = inverse[conflict_entries_mask]
+     conflict_dense_idx = bg_to_dense[conflict_inverse]
+
+     # Vectorized winner selection using sorting
+     # Sort entries by (group_idx, -random_score) so highest score comes first per group
+     # Use group * 2 - score to sort by group ascending, then score descending
+     sort_keys = conflict_dense_idx.float() * 2.0 - conflict_random_scores
+     sorted_order = sort_keys.argsort()
+     sorted_dense_idx = conflict_dense_idx[sorted_order]
+
+     # Find first entry of each group in sorted order (these are winners)
+     group_starts = torch.cat(
+         [
+             torch.tensor([True], device=device),
+             sorted_dense_idx[1:] != sorted_dense_idx[:-1],
+         ]
+     )
+
+     # Winners are entries at group starts in sorted order
+     winner_positions_in_sorted = torch.where(group_starts)[0]
+     winner_original_positions = sorted_order[winner_positions_in_sorted]
+
+     # Create winner mask (vectorized)
+     is_winner = torch.zeros(
+         conflict_entry_indices.numel(), dtype=torch.bool, device=device
+     )
+     is_winner[winner_original_positions] = True
+
+     # Build keep mask (vectorized)
+     keep_mask = torch.ones(nnz, dtype=torch.bool, device=device)
+     loser_entry_indices = conflict_entry_indices[~is_winner]
+     keep_mask[loser_entry_indices] = False
+
+     if keep_mask.all():
+         return sparse_tensor
+
+     return torch.sparse_coo_tensor(
+         indices[:, keep_mask],
+         values[keep_mask],
+         sparse_tensor.shape,
+         device=device,
+         dtype=sparse_tensor.dtype,
+     )
+
+
  @torch.no_grad()
  def hierarchy_modifier(
      roots: Sequence[HierarchyNode] | HierarchyNode,
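
Note: the parent-activity lookups above rely on a standard trick: encode (batch, feature) pairs as a single integer, sort once, then use searchsorted as a membership test. An isolated sketch, not package code:

    import torch

    num_features = 10
    # Active entries: (batch 0, feat 2), (batch 0, feat 7), (batch 1, feat 4)
    batch = torch.tensor([0, 0, 1])
    feat = torch.tensor([2, 7, 4])
    keys = (batch * num_features + feat).sort().values

    queries = torch.tensor([0 * num_features + 2,   # present
                            0 * num_features + 5,   # absent
                            1 * num_features + 4])  # present
    pos = torch.searchsorted(keys, queries).clamp(max=keys.numel() - 1)
    print(keys[pos] == queries)  # tensor([ True, False,  True])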
@@ -475,12 +775,24 @@ def hierarchy_modifier(
              me_group_sizes=sparse_data.me_group_sizes.to(device),
              me_group_parents=sparse_data.me_group_parents.to(device),
              num_groups=sparse_data.num_groups,
+             feat_to_parent=(
+                 sparse_data.feat_to_parent.to(device)
+                 if sparse_data.feat_to_parent is not None
+                 else None
+             ),
+             feat_to_me_group=(
+                 sparse_data.feat_to_me_group.to(device)
+                 if sparse_data.feat_to_me_group is not None
+                 else None
+             ),
          )
          return device_cache[device]

      def modifier(activations: torch.Tensor) -> torch.Tensor:
          device = activations.device
          cached = _get_sparse_for_device(device)
+         if activations.is_sparse:
+             return _apply_hierarchy_sparse_coo(activations, cached)
          return _apply_hierarchy_sparse(activations, cached)

      return modifier
sae_lens/synthetic/training.py CHANGED
@@ -23,6 +23,8 @@ def train_toy_sae(
      device: str | torch.device = "cpu",
      n_snapshots: int = 0,
      snapshot_fn: Callable[[SAETrainer[Any, Any]], None] | None = None,
+     autocast_sae: bool = False,
+     autocast_data: bool = False,
  ) -> None:
      """
      Train an SAE on synthetic activations from a feature dictionary.
@@ -46,6 +48,8 @@ def train_toy_sae(
          snapshot_fn: Callback function called at each snapshot point. Receives
              the SAETrainer instance, allowing access to the SAE, training step,
              and other training state. Required if n_snapshots > 0.
+         autocast_sae: Whether to autocast the SAE to bfloat16. Only recommended for large SAEs on CUDA.
+         autocast_data: Whether to autocast the activations generator and feature dictionary to bfloat16. Only recommended for large data on CUDA.
      """

      device_str = str(device) if isinstance(device, torch.device) else device
@@ -55,6 +59,7 @@ def train_toy_sae(
          feature_dict=feature_dict,
          activations_generator=activations_generator,
          batch_size=batch_size,
+         autocast=autocast_data,
      )

      # Create trainer config
@@ -64,7 +69,7 @@ def train_toy_sae(
          save_final_checkpoint=False,
          total_training_samples=training_samples,
          device=device_str,
-         autocast=False,
+         autocast=autocast_sae,
          lr=lr,
          lr_end=lr,
          lr_scheduler_name="constant",
@@ -119,6 +124,7 @@ class SyntheticActivationIterator(Iterator[torch.Tensor]):
          feature_dict: FeatureDictionary,
          activations_generator: ActivationGenerator,
          batch_size: int,
+         autocast: bool = False,
      ):
          """
          Create a new SyntheticActivationIterator.
@@ -127,16 +133,23 @@ class SyntheticActivationIterator(Iterator[torch.Tensor]):
              feature_dict: The feature dictionary to use for generating hidden activations
              activations_generator: Generator that produces feature activations
              batch_size: Number of samples per batch
+             autocast: Whether to autocast the activations generator and feature dictionary to bfloat16.
          """
          self.feature_dict = feature_dict
          self.activations_generator = activations_generator
          self.batch_size = batch_size
+         self.autocast = autocast

      @torch.no_grad()
      def next_batch(self) -> torch.Tensor:
          """Generate the next batch of hidden activations."""
-         features = self.activations_generator(self.batch_size)
-         return self.feature_dict(features)
+         with torch.autocast(
+             device_type=self.feature_dict.feature_vectors.device.type,
+             dtype=torch.bfloat16,
+             enabled=self.autocast,
+         ):
+             features = self.activations_generator(self.batch_size)
+             return self.feature_dict(features)

      def __iter__(self) -> "SyntheticActivationIterator":
          return self
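
Note: the autocast pattern next_batch() uses, in isolation (illustrative, not package code). With enabled=False the context is a no-op, so the default behavior is unchanged.

    import torch

    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=True):
        a = torch.randn(64, 128, device=device_type)
        b = torch.randn(128, 32, device=device_type)
        out = a @ b  # runs in bfloat16 where the backend supports autocast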
sae_lens-6.28.2.dist-info/METADATA → sae_lens-6.29.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sae-lens
- Version: 6.28.2
+ Version: 6.29.1
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
  License: MIT
  License-File: LICENSE
sae_lens-6.28.2.dist-info/RECORD → sae_lens-6.29.1.dist-info/RECORD RENAMED
@@ -1,4 +1,4 @@
- sae_lens/__init__.py,sha256=B9tY0Jt21pOHmSQrQLpMxQHyUAdLHIZpVP6pg3O0dfQ,4788
+ sae_lens/__init__.py,sha256=emqKVNiJwD8YtYhtgHJyAT8YSX1QmruQYuG-J4CStC4,4788
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
@@ -25,16 +25,16 @@ sae_lens/saes/standard_sae.py,sha256=_hldNZkFPAf9VGrxouR1-tN8T2OEk8IkWBcXoatrC1o
  sae_lens/saes/temporal_sae.py,sha256=83Ap4mYGfdN3sKdPF8nKjhdXph3-7E2QuLobqJ_YuoM,13273
  sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
  sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
- sae_lens/synthetic/__init__.py,sha256=FGUasB6fLPXRFCcrtKfL7vCKDOWebZ5Rx5F9QNJZklI,2875
- sae_lens/synthetic/activation_generator.py,sha256=JEN7mEgdGDuXr0ArTwUsSdSVUAfvheT_1Eew2ojbA-g,7659
- sae_lens/synthetic/correlation.py,sha256=odr-S5h6c2U-bepwrAQeMfV1iBF_cnnQzqw7zapEXZ4,6056
+ sae_lens/synthetic/__init__.py,sha256=MtTnGkTfHV2WjkIgs7zZyx10EK9U5fjOHXy69Aq3uKw,3095
+ sae_lens/synthetic/activation_generator.py,sha256=8L9nwC4jFRv_wg3QN-n1sFwX8w1NqwJMysWaJ41lLlY,15197
+ sae_lens/synthetic/correlation.py,sha256=tMTLo9fBfDpeXwqhyUgFqnTipj9x2W0t4oEtNxB7AG0,13256
  sae_lens/synthetic/evals.py,sha256=Nhi314ZnRgLfhBj-3tm_zzI-pGyFTcwllDXbIpPFXeU,4584
- sae_lens/synthetic/feature_dictionary.py,sha256=ysn0ihE3JgVlCLUZMb127WYZqbz4kMp9BGHfCZqERBg,6487
+ sae_lens/synthetic/feature_dictionary.py,sha256=Nd4xjSTxKMnKilZ3uYi8Gv5SS5D4bv4wHiSL1uGB69E,6933
  sae_lens/synthetic/firing_probabilities.py,sha256=yclz1pWl5gE1r8LAxFvzQS88Lxwk5-3r8BCX9HLVejA,3370
- sae_lens/synthetic/hierarchy.py,sha256=j9-6K7xq6zQS9N8bB5nK_-EbuzAZsY5Z5AfUK-qlB5M,22138
+ sae_lens/synthetic/hierarchy.py,sha256=nm7nwnTswktVJeKUsRZ0hLOdXcFWGbxnA1b6lefHm-4,33592
  sae_lens/synthetic/initialization.py,sha256=orMGW-786wRDHIS2W7bEH0HmlVFQ4g2z4bnnwdv5w4s,1386
  sae_lens/synthetic/plotting.py,sha256=5lFrej1QOkGAcImFNo5-o-8mI_rUVqvEI57KzUQPPtQ,8208
- sae_lens/synthetic/training.py,sha256=Bg6NYxdzifq_8g-dJQSZ_z_TXDdGRtEi7tqNDb-gCVc,4986
+ sae_lens/synthetic/training.py,sha256=fHcX2cZ6nDupr71GX0Gk17f1NvQ0SKIVXIA6IuAb2dw,5692
  sae_lens/tokenization_and_batching.py,sha256=uoHtAs9z3XqG0Fh-iQVYVlrbyB_E3kFFhrKU30BosCo,5438
  sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
@@ -46,7 +46,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
  sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
- sae_lens-6.28.2.dist-info/METADATA,sha256=i_kbAa64It0NRDrnSlmwNa8qgqOyEMntT_Ifxdx4Q90,6573
- sae_lens-6.28.2.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
- sae_lens-6.28.2.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
- sae_lens-6.28.2.dist-info/RECORD,,
+ sae_lens-6.29.1.dist-info/METADATA,sha256=0Pp1L3vNiUGzkMox_BdQR6B064tTHFgwAPGJz8FY8UM,6573
+ sae_lens-6.29.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+ sae_lens-6.29.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+ sae_lens-6.29.1.dist-info/RECORD,,