sae-lens 6.28.1__py3-none-any.whl → 6.29.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.28.1"
2
+ __version__ = "6.29.1"
3
3
 
4
4
  import logging
5
5
 
@@ -40631,7 +40631,7 @@ gemma-3-1b-res-matryoshka-dc:
40631
40631
  conversion_func: null
40632
40632
  links:
40633
40633
  model: https://huggingface.co/google/gemma-3-1b-pt
40634
- model: gemma-3-1b
40634
+ model: google/gemma-3-1b-pt
40635
40635
  repo_id: chanind/gemma-3-1b-batch-topk-matryoshka-saes-w-32k-l0-40
40636
40636
  saes:
40637
40637
  - id: blocks.0.hook_resid_post
@@ -17,11 +17,14 @@ from sae_lens.synthetic.activation_generator import (
17
17
  ActivationGenerator,
18
18
  ActivationsModifier,
19
19
  ActivationsModifierInput,
20
+ CorrelationMatrixInput,
20
21
  )
21
22
  from sae_lens.synthetic.correlation import (
23
+ LowRankCorrelationMatrix,
22
24
  create_correlation_matrix_from_correlations,
23
25
  generate_random_correlation_matrix,
24
26
  generate_random_correlations,
27
+ generate_random_low_rank_correlation_matrix,
25
28
  )
26
29
  from sae_lens.synthetic.evals import (
27
30
  SyntheticDataEvalResult,
@@ -66,6 +69,9 @@ __all__ = [
66
69
  "create_correlation_matrix_from_correlations",
67
70
  "generate_random_correlations",
68
71
  "generate_random_correlation_matrix",
72
+ "generate_random_low_rank_correlation_matrix",
73
+ "LowRankCorrelationMatrix",
74
+ "CorrelationMatrixInput",
69
75
  # Feature modifiers
70
76
  "ActivationsModifier",
71
77
  "ActivationsModifierInput",
@@ -2,17 +2,21 @@
2
2
  Functions for generating synthetic feature activations.
3
3
  """
4
4
 
5
+ import math
5
6
  from collections.abc import Callable, Sequence
6
7
 
7
8
  import torch
8
- from scipy.stats import norm
9
9
  from torch import nn
10
10
  from torch.distributions import MultivariateNormal
11
11
 
12
+ from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
12
13
  from sae_lens.util import str_to_dtype
13
14
 
14
15
  ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
15
16
  ActivationsModifierInput = ActivationsModifier | Sequence[ActivationsModifier] | None
17
+ CorrelationMatrixInput = (
18
+ torch.Tensor | LowRankCorrelationMatrix | tuple[torch.Tensor, torch.Tensor]
19
+ )
16
20
 
17
21
 
18
22
  class ActivationGenerator(nn.Module):
@@ -28,7 +32,9 @@ class ActivationGenerator(nn.Module):
28
32
  mean_firing_magnitudes: torch.Tensor
29
33
  modify_activations: ActivationsModifier | None
30
34
  correlation_matrix: torch.Tensor | None
35
+ low_rank_correlation: tuple[torch.Tensor, torch.Tensor] | None
31
36
  correlation_thresholds: torch.Tensor | None
37
+ use_sparse_tensors: bool
32
38
 
33
39
  def __init__(
34
40
  self,
@@ -37,10 +43,37 @@ class ActivationGenerator(nn.Module):
37
43
  std_firing_magnitudes: torch.Tensor | float = 0.0,
38
44
  mean_firing_magnitudes: torch.Tensor | float = 1.0,
39
45
  modify_activations: ActivationsModifierInput = None,
40
- correlation_matrix: torch.Tensor | None = None,
46
+ correlation_matrix: CorrelationMatrixInput | None = None,
41
47
  device: torch.device | str = "cpu",
42
48
  dtype: torch.dtype | str = "float32",
49
+ use_sparse_tensors: bool = False,
43
50
  ):
51
+ """
52
+ Create a new ActivationGenerator.
53
+
54
+ Args:
55
+ num_features: Number of features to generate activations for.
56
+ firing_probabilities: Probability of each feature firing. Can be a single
57
+ float (applied to all features) or a tensor of shape (num_features,).
58
+ std_firing_magnitudes: Standard deviation of firing magnitudes. Can be a
59
+ single float or a tensor of shape (num_features,). Defaults to 0.0
60
+ (deterministic magnitudes).
61
+ mean_firing_magnitudes: Mean firing magnitude when a feature fires. Can be
62
+ a single float or a tensor of shape (num_features,). Defaults to 1.0.
63
+ modify_activations: Optional function(s) to modify activations after
64
+ generation. Can be a single callable, a sequence of callables (applied
65
+ in order), or None. Useful for applying hierarchy constraints.
66
+ correlation_matrix: Optional correlation structure between features. Can be:
67
+
68
+ - A full correlation matrix tensor of shape (num_features, num_features)
69
+ - A LowRankCorrelationMatrix for memory-efficient large-scale correlations
70
+ - A tuple of (factor, diag) tensors representing low-rank structure
71
+
72
+ device: Device to place tensors on. Defaults to "cpu".
73
+ dtype: Data type for tensors. Defaults to "float32".
74
+ use_sparse_tensors: If True, return sparse COO tensors from sample().
75
+ Only recommended when using massive numbers of features. Defaults to False.
76
+ """
44
77
  super().__init__()
45
78
  self.num_features = num_features
46
79
  self.firing_probabilities = _to_tensor(
@@ -54,15 +87,36 @@ class ActivationGenerator(nn.Module):
54
87
  )
55
88
  self.modify_activations = _normalize_modifiers(modify_activations)
56
89
  self.correlation_thresholds = None
90
+ self.correlation_matrix = None
91
+ self.low_rank_correlation = None
92
+ self.use_sparse_tensors = use_sparse_tensors
93
+
57
94
  if correlation_matrix is not None:
58
- _validate_correlation_matrix(correlation_matrix, num_features)
59
- self.correlation_thresholds = torch.tensor(
60
- [norm.ppf(1 - p.item()) for p in self.firing_probabilities],
61
- device=device,
62
- dtype=self.firing_probabilities.dtype,
95
+ if isinstance(correlation_matrix, torch.Tensor):
96
+ # Full correlation matrix
97
+ _validate_correlation_matrix(correlation_matrix, num_features)
98
+ self.correlation_matrix = correlation_matrix
99
+ else:
100
+ # Low-rank correlation (tuple or LowRankCorrelationMatrix)
101
+ correlation_factor, correlation_diag = (
102
+ correlation_matrix[0],
103
+ correlation_matrix[1],
104
+ )
105
+ _validate_low_rank_correlation(
106
+ correlation_factor, correlation_diag, num_features
107
+ )
108
+ # Pre-compute sqrt for efficiency (used every sample call)
109
+ self.low_rank_correlation = (
110
+ correlation_factor,
111
+ correlation_diag.sqrt(),
112
+ )
113
+
114
+ # Vectorized inverse normal CDF: norm.ppf(1-p) = sqrt(2) * erfinv(1 - 2*p)
115
+ self.correlation_thresholds = math.sqrt(2) * torch.erfinv(
116
+ 1 - 2 * self.firing_probabilities
63
117
  )
64
- self.correlation_matrix = correlation_matrix
65
118
 
119
+ @torch.no_grad()
66
120
  def sample(self, batch_size: int) -> torch.Tensor:
67
121
  """
68
122
  Generate a batch of feature activations with controlled properties.
@@ -83,30 +137,74 @@ class ActivationGenerator(nn.Module):
83
137
 
84
138
  if self.correlation_matrix is not None:
85
139
  assert self.correlation_thresholds is not None
86
- firing_features = _generate_correlated_features(
140
+ firing_indices = _generate_correlated_features(
87
141
  batch_size,
88
142
  self.correlation_matrix,
89
143
  self.correlation_thresholds,
90
144
  device,
91
145
  )
146
+ elif self.low_rank_correlation is not None:
147
+ assert self.correlation_thresholds is not None
148
+ firing_indices = _generate_low_rank_correlated_features(
149
+ batch_size,
150
+ self.low_rank_correlation[0],
151
+ self.low_rank_correlation[1],
152
+ self.correlation_thresholds,
153
+ device,
154
+ )
92
155
  else:
93
- firing_features = torch.bernoulli(
156
+ firing_indices = torch.bernoulli(
94
157
  self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
158
+ ).nonzero(as_tuple=True)
159
+
160
+ # Compute activations only at firing positions (sparse optimization)
161
+ feature_indices = firing_indices[1]
162
+ num_firing = feature_indices.shape[0]
163
+ mean_at_firing = self.mean_firing_magnitudes[feature_indices]
164
+ std_at_firing = self.std_firing_magnitudes[feature_indices]
165
+ random_deltas = (
166
+ torch.randn(
167
+ num_firing, device=device, dtype=self.mean_firing_magnitudes.dtype
95
168
  )
96
-
97
- firing_magnitude_delta = torch.normal(
98
- torch.zeros_like(self.firing_probabilities)
99
- .unsqueeze(0)
100
- .expand(batch_size, -1),
101
- self.std_firing_magnitudes.unsqueeze(0).expand(batch_size, -1),
169
+ * std_at_firing
102
170
  )
103
- firing_magnitude_delta[firing_features == 0] = 0
104
- feature_activations = (
105
- firing_features * self.mean_firing_magnitudes + firing_magnitude_delta
106
- ).relu()
171
+ activations_at_firing = (mean_at_firing + random_deltas).relu()
172
+
173
+ if self.use_sparse_tensors:
174
+ # Return sparse COO tensor
175
+ indices = torch.stack(firing_indices) # [2, nnz]
176
+ feature_activations = torch.sparse_coo_tensor(
177
+ indices,
178
+ activations_at_firing,
179
+ size=(batch_size, self.num_features),
180
+ device=device,
181
+ dtype=self.mean_firing_magnitudes.dtype,
182
+ )
183
+ else:
184
+ # Dense tensor path
185
+ feature_activations = torch.zeros(
186
+ batch_size,
187
+ self.num_features,
188
+ device=device,
189
+ dtype=self.mean_firing_magnitudes.dtype,
190
+ )
191
+ feature_activations[firing_indices] = activations_at_firing
107
192
 
108
193
  if self.modify_activations is not None:
109
- feature_activations = self.modify_activations(feature_activations).relu()
194
+ feature_activations = self.modify_activations(feature_activations)
195
+ if feature_activations.is_sparse:
196
+ # Apply relu to sparse values
197
+ feature_activations = feature_activations.coalesce()
198
+ feature_activations = torch.sparse_coo_tensor(
199
+ feature_activations.indices(),
200
+ feature_activations.values().relu(),
201
+ feature_activations.shape,
202
+ device=feature_activations.device,
203
+ dtype=feature_activations.dtype,
204
+ )
205
+ else:
206
+ feature_activations = feature_activations.relu()
207
+
110
208
  return feature_activations
111
209
 
112
210
  def forward(self, batch_size: int) -> torch.Tensor:
@@ -118,7 +216,7 @@ def _generate_correlated_features(
118
216
  correlation_matrix: torch.Tensor,
119
217
  thresholds: torch.Tensor,
120
218
  device: torch.device,
121
- ) -> torch.Tensor:
219
+ ) -> tuple[torch.Tensor, torch.Tensor]:
122
220
  """
123
221
  Generate correlated binary features using multivariate Gaussian sampling.
124
222
 
@@ -132,7 +230,7 @@ def _generate_correlated_features(
132
230
  device: Device to generate samples on
133
231
 
134
232
  Returns:
135
- Binary feature matrix of shape [batch_size, num_features]
233
+ Tuple of (row_indices, col_indices) for firing features
136
234
  """
137
235
  num_features = correlation_matrix.shape[0]
138
236
 
@@ -142,7 +240,49 @@ def _generate_correlated_features(
142
240
  )
143
241
 
144
242
  gaussian_samples = mvn.sample((batch_size,))
145
- return (gaussian_samples > thresholds.unsqueeze(0)).float()
243
+ indices = (gaussian_samples > thresholds.unsqueeze(0)).nonzero(as_tuple=True)
244
+ return indices[0], indices[1]
245
+
246
+
247
+ def _generate_low_rank_correlated_features(
248
+ batch_size: int,
249
+ correlation_factor: torch.Tensor,
250
+ cov_diag_sqrt: torch.Tensor,
251
+ thresholds: torch.Tensor,
252
+ device: torch.device,
253
+ ) -> tuple[torch.Tensor, torch.Tensor]:
254
+ """
255
+ Generate correlated binary features using low-rank multivariate Gaussian sampling.
256
+
257
+ Uses the Gaussian copula approach with a low-rank covariance structure for scalability.
258
+ The covariance is represented as: cov = factor @ factor.T + diag(diag_term)
259
+
260
+ Args:
261
+ batch_size: Number of samples to generate
262
+ correlation_factor: Factor matrix of shape (num_features, rank)
263
+ cov_diag_sqrt: Pre-computed sqrt of diagonal term, shape (num_features,)
264
+ thresholds: Pre-computed thresholds for each feature (from inverse normal CDF)
265
+ device: Device to generate samples on
266
+
267
+ Returns:
268
+ Tuple of (row_indices, col_indices) for firing features
269
+ """
270
+ # Manual low-rank MVN sampling to enable autocast for the expensive matmul
271
+ # samples = eps @ cov_factor.T + eta * sqrt(cov_diag)
272
+ # where eps ~ N(0, I_rank) and eta ~ N(0, I_n)
273
+
274
+ num_features, rank = correlation_factor.shape
275
+
276
+ # Generate random samples in float32 for numerical stability
277
+ eps = torch.randn(batch_size, rank, device=device, dtype=correlation_factor.dtype)
278
+ eta = torch.randn(
279
+ batch_size, num_features, device=device, dtype=cov_diag_sqrt.dtype
280
+ )
281
+
282
+ gaussian_samples = eps @ correlation_factor.T + eta * cov_diag_sqrt
283
+
284
+ indices = (gaussian_samples > thresholds.unsqueeze(0)).nonzero(as_tuple=True)
285
+ return indices[0], indices[1]
146
286
 
147
287
 
148
288
  def _to_tensor(
@@ -193,7 +333,7 @@ def _validate_correlation_matrix(
193
333
 
194
334
  Args:
195
335
  correlation_matrix: The matrix to validate
196
- num_features: Expected number of features (matrix should be [num_features, num_features])
336
+ num_features: Expected number of features (matrix should be (num_features, num_features))
197
337
 
198
338
  Raises:
199
339
  ValueError: If the matrix has incorrect shape, non-unit diagonal, or is not positive definite
@@ -213,3 +353,36 @@ def _validate_correlation_matrix(
213
353
  torch.linalg.cholesky(correlation_matrix)
214
354
  except RuntimeError as e:
215
355
  raise ValueError("Correlation matrix must be positive definite") from e
356
+
357
+
358
+ def _validate_low_rank_correlation(
359
+ correlation_factor: torch.Tensor,
360
+ correlation_diag: torch.Tensor,
361
+ num_features: int,
362
+ ) -> None:
363
+ """Validate that low-rank correlation parameters have correct properties.
364
+
365
+ Args:
366
+ correlation_factor: Factor matrix of shape (num_features, rank)
367
+ correlation_diag: Diagonal term of shape (num_features,)
368
+ num_features: Expected number of features
369
+
370
+ Raises:
371
+ ValueError: If shapes are incorrect or diagonal terms are not positive
372
+ """
373
+ if correlation_factor.ndim != 2:
374
+ raise ValueError(
375
+ f"correlation_factor must be 2D, got {correlation_factor.ndim}D"
376
+ )
377
+ if correlation_factor.shape[0] != num_features:
378
+ raise ValueError(
379
+ f"correlation_factor must have shape ({num_features}, rank), "
380
+ f"got {tuple(correlation_factor.shape)}"
381
+ )
382
+ if correlation_diag.shape != (num_features,):
383
+ raise ValueError(
384
+ f"correlation_diag must have shape ({num_features},), "
385
+ f"got {tuple(correlation_diag.shape)}"
386
+ )
387
+ if torch.any(correlation_diag <= 0):
388
+ raise ValueError("correlation_diag must have all positive values")
@@ -1,7 +1,32 @@
1
1
  import random
2
+ from typing import NamedTuple
2
3
 
3
4
  import torch
4
5
 
6
+ from sae_lens.util import str_to_dtype
7
+
8
+
9
+ class LowRankCorrelationMatrix(NamedTuple):
10
+ """
11
+ Low-rank representation of a correlation matrix for scalable correlated sampling.
12
+
13
+ The correlation structure is represented as:
14
+ correlation = correlation_factor @ correlation_factor.T + diag(correlation_diag)
15
+
16
+ This requires O(num_features * rank) storage instead of O(num_features^2),
17
+ making it suitable for very large numbers of features (e.g., 1M+).
18
+
19
+ Attributes:
20
+ correlation_factor: Factor matrix of shape (num_features, rank) that captures
21
+ correlations through shared latent factors.
22
+ correlation_diag: Diagonal variance term of shape (num_features,). Should be
23
+ chosen such that the diagonal of the full correlation matrix equals 1.
24
+ Typically: correlation_diag[i] = 1 - sum(correlation_factor[i, :]^2)
25
+ """
26
+
27
+ correlation_factor: torch.Tensor
28
+ correlation_diag: torch.Tensor
29
+
5
30
 
6
31
  def create_correlation_matrix_from_correlations(
7
32
  num_features: int,
@@ -11,14 +36,18 @@ def create_correlation_matrix_from_correlations(
11
36
  """
12
37
  Create a correlation matrix with specified pairwise correlations.
13
38
 
39
+ Note: If the resulting matrix is not positive definite, it will be adjusted
40
+ to ensure validity. This adjustment may change the specified correlation
41
+ values. To minimize this effect, use smaller correlation magnitudes.
42
+
14
43
  Args:
15
44
  num_features: Number of features
16
45
  correlations: Dict mapping (i, j) pairs to correlation values.
17
- Pairs should have i < j.
46
+ Pairs should have i < j. Pairs not specified will use default_correlation.
18
47
  default_correlation: Default correlation for unspecified pairs
19
48
 
20
49
  Returns:
21
- Correlation matrix of shape [num_features, num_features]
50
+ Correlation matrix of shape (num_features, num_features)
22
51
  """
23
52
  matrix = torch.eye(num_features) + default_correlation * (
24
53
  1 - torch.eye(num_features)
@@ -50,6 +79,7 @@ def _fix_correlation_matrix(
50
79
  fixed_matrix = eigenvecs @ torch.diag(eigenvals) @ eigenvecs.T
51
80
 
52
81
  diag_vals = torch.diag(fixed_matrix)
82
+ diag_vals = torch.clamp(diag_vals, min=1e-8) # Prevent division by zero
53
83
  fixed_matrix = fixed_matrix / torch.sqrt(
54
84
  diag_vals.unsqueeze(0) * diag_vals.unsqueeze(1)
55
85
  )
@@ -58,6 +88,25 @@ def _fix_correlation_matrix(
58
88
  return fixed_matrix
59
89
 
60
90
 
91
+ def _validate_correlation_params(
92
+ positive_ratio: float,
93
+ uncorrelated_ratio: float,
94
+ min_correlation_strength: float,
95
+ max_correlation_strength: float,
96
+ ) -> None:
97
+ """Validate parameters for correlation generation."""
98
+ if not 0.0 <= positive_ratio <= 1.0:
99
+ raise ValueError("positive_ratio must be between 0.0 and 1.0")
100
+ if not 0.0 <= uncorrelated_ratio <= 1.0:
101
+ raise ValueError("uncorrelated_ratio must be between 0.0 and 1.0")
102
+ if min_correlation_strength < 0:
103
+ raise ValueError("min_correlation_strength must be non-negative")
104
+ if max_correlation_strength > 1.0:
105
+ raise ValueError("max_correlation_strength must be <= 1.0")
106
+ if min_correlation_strength > max_correlation_strength:
107
+ raise ValueError("min_correlation_strength must be <= max_correlation_strength")
108
+
109
+
61
110
  def generate_random_correlations(
62
111
  num_features: int,
63
112
  positive_ratio: float = 0.5,
@@ -71,29 +120,26 @@ def generate_random_correlations(
71
120
 
72
121
  Args:
73
122
  num_features: Number of features
74
- positive_ratio: Fraction of correlations that should be positive (0.0 to 1.0)
75
- uncorrelated_ratio: Fraction of feature pairs that should remain uncorrelated (0.0 to 1.0)
76
- min_correlation_strength: Minimum absolute correlation strength
77
- max_correlation_strength: Maximum absolute correlation strength
123
+ positive_ratio: Fraction of correlated pairs that should be positive (0.0 to 1.0)
124
+ uncorrelated_ratio: Fraction of feature pairs that should have zero correlation
125
+ (0.0 to 1.0). These pairs are omitted from the returned dictionary.
126
+ min_correlation_strength: Minimum absolute correlation strength for correlated pairs
127
+ max_correlation_strength: Maximum absolute correlation strength for correlated pairs
78
128
  seed: Random seed for reproducibility
79
129
 
80
130
  Returns:
81
- Dictionary mapping (i, j) pairs to correlation values
131
+ Dictionary mapping (i, j) pairs to correlation values. Pairs with zero
132
+ correlation (determined by uncorrelated_ratio) are not included.
82
133
  """
83
134
  # Use local random number generator to avoid side effects on global state
84
135
  rng = random.Random(seed)
85
136
 
86
- # Validate inputs
87
- if not 0.0 <= positive_ratio <= 1.0:
88
- raise ValueError("positive_ratio must be between 0.0 and 1.0")
89
- if not 0.0 <= uncorrelated_ratio <= 1.0:
90
- raise ValueError("uncorrelated_ratio must be between 0.0 and 1.0")
91
- if min_correlation_strength < 0:
92
- raise ValueError("min_correlation_strength must be non-negative")
93
- if max_correlation_strength > 1.0:
94
- raise ValueError("max_correlation_strength must be <= 1.0")
95
- if min_correlation_strength > max_correlation_strength:
96
- raise ValueError("min_correlation_strength must be <= max_correlation_strength")
137
+ _validate_correlation_params(
138
+ positive_ratio,
139
+ uncorrelated_ratio,
140
+ min_correlation_strength,
141
+ max_correlation_strength,
142
+ )
97
143
 
98
144
  # Generate all possible feature pairs (i, j) where i < j
99
145
  all_pairs = [
@@ -136,35 +182,170 @@ def generate_random_correlation_matrix(
136
182
  min_correlation_strength: float = 0.1,
137
183
  max_correlation_strength: float = 0.8,
138
184
  seed: int | None = None,
185
+ device: torch.device | str = "cpu",
186
+ dtype: torch.dtype | str = torch.float32,
139
187
  ) -> torch.Tensor:
140
188
  """
141
189
  Generate a random correlation matrix with specified constraints.
142
190
 
143
- This is a convenience function that combines generate_random_correlations()
144
- and create_correlation_matrix_from_correlations() into a single call.
191
+ Uses vectorized torch operations for efficiency with large numbers of features.
192
+
193
+ Note: If the randomly generated matrix is not positive definite, it will be
194
+ adjusted to ensure validity. This adjustment may change correlation values,
195
+ including turning some zero correlations into non-zero values. To minimize
196
+ this effect, use smaller correlation strengths (e.g., 0.01-0.1).
197
+
198
+ Args:
199
+ num_features: Number of features
200
+ positive_ratio: Fraction of correlated pairs that should be positive (0.0 to 1.0)
201
+ uncorrelated_ratio: Fraction of feature pairs that should have zero correlation
202
+ (0.0 to 1.0). Note that matrix fixing for positive definiteness may reduce
203
+ the actual number of zero correlations.
204
+ min_correlation_strength: Minimum absolute correlation strength for correlated pairs
205
+ max_correlation_strength: Maximum absolute correlation strength for correlated pairs
206
+ seed: Random seed for reproducibility
207
+ device: Device to create the matrix on
208
+ dtype: Data type for the matrix
209
+
210
+ Returns:
211
+ Random correlation matrix of shape (num_features, num_features)
212
+ """
213
+ dtype = str_to_dtype(dtype)
214
+ _validate_correlation_params(
215
+ positive_ratio,
216
+ uncorrelated_ratio,
217
+ min_correlation_strength,
218
+ max_correlation_strength,
219
+ )
220
+
221
+ if num_features <= 1:
222
+ return torch.eye(num_features, device=device, dtype=dtype)
223
+
224
+ # Set random seed if provided
225
+ generator = torch.Generator(device=device)
226
+ if seed is not None:
227
+ generator.manual_seed(seed)
228
+
229
+ # Get upper triangular indices (i < j)
230
+ row_idx, col_idx = torch.triu_indices(num_features, num_features, offset=1)
231
+ num_pairs = row_idx.shape[0]
232
+
233
+ # Generate random values for all pairs at once
234
+ # is_correlated: 1 if this pair should have a correlation, 0 otherwise
235
+ is_correlated = (
236
+ torch.rand(num_pairs, generator=generator, device=device) >= uncorrelated_ratio
237
+ )
238
+
239
+ # signs: +1 for positive correlation, -1 for negative
240
+ is_positive = (
241
+ torch.rand(num_pairs, generator=generator, device=device) < positive_ratio
242
+ )
243
+ signs = torch.where(is_positive, 1.0, -1.0)
244
+
245
+ # strengths: uniform in [min_strength, max_strength]
246
+ strengths = (
247
+ torch.rand(num_pairs, generator=generator, device=device, dtype=dtype)
248
+ * (max_correlation_strength - min_correlation_strength)
249
+ + min_correlation_strength
250
+ )
251
+
252
+ # Combine: correlation = is_correlated * sign * strength
253
+ correlations = is_correlated.to(dtype) * signs.to(dtype) * strengths
254
+
255
+ # Build the symmetric matrix
256
+ matrix = torch.eye(num_features, device=device, dtype=dtype)
257
+ matrix[row_idx, col_idx] = correlations
258
+ matrix[col_idx, row_idx] = correlations
259
+
260
+ # Check positive definiteness and fix if necessary
261
+ eigenvals = torch.linalg.eigvalsh(matrix)
262
+ if torch.any(eigenvals < -1e-6):
263
+ matrix = _fix_correlation_matrix(matrix)
264
+
265
+ return matrix
266
+
267
+
268
+ def generate_random_low_rank_correlation_matrix(
269
+ num_features: int,
270
+ rank: int,
271
+ correlation_scale: float = 0.1,
272
+ seed: int | None = None,
273
+ device: torch.device | str = "cpu",
274
+ dtype: torch.dtype | str = torch.float32,
275
+ ) -> LowRankCorrelationMatrix:
276
+ """
277
+ Generate a random low-rank correlation structure for scalable correlated sampling.
278
+
279
+ The correlation structure is represented as:
280
+ correlation = factor @ factor.T + diag(diag_term)
281
+
282
+ This requires O(num_features * rank) storage instead of O(num_features^2),
283
+ making it suitable for very large numbers of features (e.g., 1M+).
284
+
285
+ The factor matrix is initialized with random values scaled by correlation_scale,
286
+ and the diagonal term is computed to ensure the implied correlation matrix has
287
+ unit diagonal.
145
288
 
146
289
  Args:
147
290
  num_features: Number of features
148
- positive_ratio: Fraction of correlations that should be positive (0.0 to 1.0)
149
- uncorrelated_ratio: Fraction of feature pairs that should remain uncorrelated (0.0 to 1.0)
150
- min_correlation_strength: Minimum absolute correlation strength
151
- max_correlation_strength: Maximum absolute correlation strength
291
+ rank: Rank of the low-rank approximation. Higher rank allows more complex
292
+ correlation structures but uses more memory. Typical values: 10-100.
293
+ correlation_scale: Scale factor for random correlations. Larger values produce
294
+ stronger correlations between features. Use 0 for no correlations (identity
295
+ matrix). Should be small enough that rank * correlation_scale^2 < 1 to
296
+ ensure valid diagonal terms.
152
297
  seed: Random seed for reproducibility
298
+ device: Device to create tensors on
299
+ dtype: Data type for tensors
153
300
 
154
301
  Returns:
155
- Random correlation matrix of shape [num_features, num_features]
302
+ LowRankCorrelationMatrix containing the factor matrix and diagonal term
156
303
  """
157
- # Generate random correlations
158
- correlations = generate_random_correlations(
159
- num_features=num_features,
160
- positive_ratio=positive_ratio,
161
- uncorrelated_ratio=uncorrelated_ratio,
162
- min_correlation_strength=min_correlation_strength,
163
- max_correlation_strength=max_correlation_strength,
164
- seed=seed,
304
+ # Minimum diagonal value to ensure numerical stability in the covariance matrix.
305
+ # This limits how much variance can come from the low-rank factor.
306
+ _MIN_DIAG = 0.01
307
+
308
+ dtype = str_to_dtype(dtype)
309
+ device = torch.device(device)
310
+
311
+ if rank <= 0:
312
+ raise ValueError("rank must be positive")
313
+ if correlation_scale < 0:
314
+ raise ValueError("correlation_scale must be non-negative")
315
+
316
+ # Set random seed if provided
317
+ generator = torch.Generator(device=device)
318
+ if seed is not None:
319
+ generator.manual_seed(seed)
320
+
321
+ # Generate random factor matrix
322
+ # Each row has norm roughly sqrt(rank) * correlation_scale
323
+ factor = (
324
+ torch.randn(num_features, rank, generator=generator, device=device, dtype=dtype)
325
+ * correlation_scale
165
326
  )
166
327
 
167
- # Create and return correlation matrix
168
- return create_correlation_matrix_from_correlations(
169
- num_features=num_features, correlations=correlations
328
+ # Compute diagonal term to ensure unit diagonal in implied correlation matrix
329
+ # diag(factor @ factor.T) + diag_term = 1
330
+ # diag_term = 1 - sum(factor[i, :]^2)
331
+ factor_sq_sum = (factor**2).sum(dim=1)
332
+ diag_term = 1 - factor_sq_sum
333
+
334
+ # Ensure diagonal terms are at least _MIN_DIAG for numerical stability
335
+ # If any diagonal term is too small, scale down the factor matrix
336
+ if torch.any(diag_term < _MIN_DIAG):
337
+ # Scale factor so max row norm squared is at most (1 - _MIN_DIAG)
338
+ # This ensures all diagonal terms are >= _MIN_DIAG
339
+ max_factor_contribution = 1 - _MIN_DIAG
340
+ max_sq_sum = factor_sq_sum.max()
341
+ scale = torch.sqrt(
342
+ torch.tensor(max_factor_contribution, device=device, dtype=dtype)
343
+ / max_sq_sum
344
+ )
345
+ factor = factor * scale
346
+ factor_sq_sum = (factor**2).sum(dim=1)
347
+ diag_term = 1 - factor_sq_sum
348
+
349
+ return LowRankCorrelationMatrix(
350
+ correlation_factor=factor, correlation_diag=diag_term
170
351
  )