sae-lens 6.28.1__tar.gz → 6.29.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.28.1 → sae_lens-6.29.0}/PKG-INFO +11 -1
- {sae_lens-6.28.1 → sae_lens-6.29.0}/README.md +10 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/pyproject.toml +1 -1
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/__init__.py +1 -1
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/pretrained_saes.yaml +1 -1
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/synthetic/__init__.py +6 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/synthetic/activation_generator.py +105 -6
- sae_lens-6.29.0/sae_lens/synthetic/correlation.py +351 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/synthetic/feature_dictionary.py +54 -16
- sae_lens-6.29.0/sae_lens/synthetic/hierarchy.py +596 -0
- sae_lens-6.28.1/sae_lens/synthetic/correlation.py +0 -170
- sae_lens-6.28.1/sae_lens/synthetic/hierarchy.py +0 -335
- {sae_lens-6.28.1 → sae_lens-6.29.0}/LICENSE +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/config.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/saes/matching_pursuit_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/saes/temporal_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/synthetic/evals.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/synthetic/firing_probabilities.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/synthetic/initialization.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/synthetic/plotting.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/synthetic/training.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/util.py +0 -0
{sae_lens-6.28.1 → sae_lens-6.29.0}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.28.1
+Version: 6.29.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -50,6 +50,8 @@ SAELens exists to help researchers:
 - Analyse sparse autoencoders / research mechanistic interpretability.
 - Generate insights which make it easier to create safe and aligned AI systems.
 
+SAELens inference works with any PyTorch-based model, not just TransformerLens. While we provide deep integration with TransformerLens via `HookedSAETransformer`, SAEs can be used with Hugging Face Transformers, NNsight, or any other framework by extracting activations and passing them to the SAE's `encode()` and `decode()` methods.
+
 Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
 
 - Download and Analyse pre-trained sparse autoencoders.
@@ -84,6 +86,14 @@ The new v6 update is a major refactor to SAELens and changes the way training co
 
 Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
 
+## Other SAE Projects
+
+- [dictionary-learning](https://github.com/saprmarks/dictionary_learning): An SAE training library that focuses on having hackable code.
+- [Sparsify](https://github.com/EleutherAI/sparsify): A lean SAE training library focused on TopK SAEs.
+- [Overcomplete](https://github.com/KempnerInstitute/overcomplete): SAE training library focused on vision models.
+- [SAE-Vis](https://github.com/callummcdougall/sae_vis): A library for visualizing SAE features, works with SAELens.
+- [SAEBench](https://github.com/adamkarvonen/SAEBench): A suite of LLM SAE benchmarks, works with SAELens.
+
 ## Citation
 
 Please cite the package as follows:
```
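The paragraph added in this release (above) is the substantive README change: SAE inference no longer assumes TransformerLens. Below is a hedged sketch of that workflow, not code from the package: it captures GPT-2 activations with a plain forward hook and runs them through `encode()`/`decode()`. The release name `gpt2-small-res-jb`, the layer index, and the exact return shape of `SAE.from_pretrained` are assumptions that vary by SAELens version.

```python
# Sketch: using an SAE with Hugging Face Transformers instead of
# TransformerLens, per the README paragraph above. Assumes a v6-style
# SAE.from_pretrained that returns the SAE directly; older versions
# returned a (sae, cfg_dict, sparsity) tuple.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from sae_lens import SAE

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
sae = SAE.from_pretrained(release="gpt2-small-res-jb", sae_id="blocks.8.hook_resid_pre")

captured = {}

def grab_resid(module, args, output):
    # GPT-2 blocks return a tuple; the first element is the hidden state.
    captured["resid"] = output[0]

# The output of block 7 is the residual stream entering block 8
# (i.e. blocks.8.hook_resid_pre in TransformerLens naming).
handle = model.transformer.h[7].register_forward_hook(grab_resid)
with torch.no_grad():
    model(**tokenizer("Hello world", return_tensors="pt"))
handle.remove()

feature_acts = sae.encode(captured["resid"])  # (batch, seq, d_sae)
reconstruction = sae.decode(feature_acts)  # (batch, seq, d_in)
```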
{sae_lens-6.28.1 → sae_lens-6.29.0}/README.md

```diff
@@ -14,6 +14,8 @@ SAELens exists to help researchers:
 - Analyse sparse autoencoders / research mechanistic interpretability.
 - Generate insights which make it easier to create safe and aligned AI systems.
 
+SAELens inference works with any PyTorch-based model, not just TransformerLens. While we provide deep integration with TransformerLens via `HookedSAETransformer`, SAEs can be used with Hugging Face Transformers, NNsight, or any other framework by extracting activations and passing them to the SAE's `encode()` and `decode()` methods.
+
 Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
 
 - Download and Analyse pre-trained sparse autoencoders.
@@ -48,6 +50,14 @@ The new v6 update is a major refactor to SAELens and changes the way training co
 
 Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
 
+## Other SAE Projects
+
+- [dictionary-learning](https://github.com/saprmarks/dictionary_learning): An SAE training library that focuses on having hackable code.
+- [Sparsify](https://github.com/EleutherAI/sparsify): A lean SAE training library focused on TopK SAEs.
+- [Overcomplete](https://github.com/KempnerInstitute/overcomplete): SAE training library focused on vision models.
+- [SAE-Vis](https://github.com/callummcdougall/sae_vis): A library for visualizing SAE features, works with SAELens.
+- [SAEBench](https://github.com/adamkarvonen/SAEBench): A suite of LLM SAE benchmarks, works with SAELens.
+
 ## Citation
 
 Please cite the package as follows:
```
{sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/pretrained_saes.yaml

```diff
@@ -40631,7 +40631,7 @@ gemma-3-1b-res-matryoshka-dc:
   conversion_func: null
   links:
     model: https://huggingface.co/google/gemma-3-1b-pt
-  model: gemma-3-1b
+  model: google/gemma-3-1b-pt
   repo_id: chanind/gemma-3-1b-batch-topk-matryoshka-saes-w-32k-l0-40
   saes:
   - id: blocks.0.hook_resid_post
```
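The corrected `model` field is the release metadata consumers see when loading these SAEs. A minimal loading sketch, assuming the v6-style `SAE.from_pretrained` signature; the release and `sae_id` strings come straight from the YAML above:

```python
from sae_lens import SAE

# Load one of the gemma-3-1b matryoshka SAEs whose metadata is fixed above.
# Older SAELens versions return a (sae, cfg_dict, sparsity) tuple instead.
sae = SAE.from_pretrained(
    release="gemma-3-1b-res-matryoshka-dc",
    sae_id="blocks.0.hook_resid_post",
)
# The release metadata now records the base model as "google/gemma-3-1b-pt"
# rather than the bare "gemma-3-1b" string.
```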
{sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/synthetic/__init__.py

```diff
@@ -17,11 +17,14 @@ from sae_lens.synthetic.activation_generator import (
     ActivationGenerator,
     ActivationsModifier,
     ActivationsModifierInput,
+    CorrelationMatrixInput,
 )
 from sae_lens.synthetic.correlation import (
+    LowRankCorrelationMatrix,
     create_correlation_matrix_from_correlations,
     generate_random_correlation_matrix,
     generate_random_correlations,
+    generate_random_low_rank_correlation_matrix,
 )
 from sae_lens.synthetic.evals import (
     SyntheticDataEvalResult,
@@ -66,6 +69,9 @@ __all__ = [
     "create_correlation_matrix_from_correlations",
     "generate_random_correlations",
     "generate_random_correlation_matrix",
+    "generate_random_low_rank_correlation_matrix",
+    "LowRankCorrelationMatrix",
+    "CorrelationMatrixInput",
     # Feature modifiers
     "ActivationsModifier",
    "ActivationsModifierInput",
```
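The new exports can be exercised directly from `sae_lens.synthetic`. A small sketch using only signatures visible in this diff; argument values are illustrative:

```python
from sae_lens.synthetic import (
    CorrelationMatrixInput,
    LowRankCorrelationMatrix,
    generate_random_low_rank_correlation_matrix,
)

# Build a rank-32 correlation structure over 10k features: O(n * rank)
# storage instead of an O(n^2) dense matrix.
low_rank = generate_random_low_rank_correlation_matrix(
    num_features=10_000, rank=32, correlation_scale=0.05, seed=0
)
assert isinstance(low_rank, LowRankCorrelationMatrix)

# CorrelationMatrixInput accepts a dense matrix, a LowRankCorrelationMatrix,
# or a raw (factor, diag) tuple.
corr: CorrelationMatrixInput = low_rank
```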
{sae_lens-6.28.1 → sae_lens-6.29.0}/sae_lens/synthetic/activation_generator.py

```diff
@@ -7,12 +7,16 @@ from collections.abc import Callable, Sequence
 import torch
 from scipy.stats import norm
 from torch import nn
-from torch.distributions import MultivariateNormal
+from torch.distributions import LowRankMultivariateNormal, MultivariateNormal
 
+from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
 from sae_lens.util import str_to_dtype
 
 ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
 ActivationsModifierInput = ActivationsModifier | Sequence[ActivationsModifier] | None
+CorrelationMatrixInput = (
+    torch.Tensor | LowRankCorrelationMatrix | tuple[torch.Tensor, torch.Tensor]
+)
 
 
 class ActivationGenerator(nn.Module):
@@ -28,6 +32,7 @@ class ActivationGenerator(nn.Module):
     mean_firing_magnitudes: torch.Tensor
     modify_activations: ActivationsModifier | None
     correlation_matrix: torch.Tensor | None
+    low_rank_correlation: tuple[torch.Tensor, torch.Tensor] | None
     correlation_thresholds: torch.Tensor | None
 
     def __init__(
@@ -37,7 +42,7 @@ class ActivationGenerator(nn.Module):
         std_firing_magnitudes: torch.Tensor | float = 0.0,
         mean_firing_magnitudes: torch.Tensor | float = 1.0,
         modify_activations: ActivationsModifierInput = None,
-        correlation_matrix: torch.Tensor | None = None,
+        correlation_matrix: CorrelationMatrixInput | None = None,
         device: torch.device | str = "cpu",
         dtype: torch.dtype | str = "float32",
     ):
@@ -54,15 +59,32 @@ class ActivationGenerator(nn.Module):
         )
         self.modify_activations = _normalize_modifiers(modify_activations)
         self.correlation_thresholds = None
+        self.correlation_matrix = None
+        self.low_rank_correlation = None
+
         if correlation_matrix is not None:
-            _validate_correlation_matrix(correlation_matrix, num_features)
+            if isinstance(correlation_matrix, torch.Tensor):
+                # Full correlation matrix
+                _validate_correlation_matrix(correlation_matrix, num_features)
+                self.correlation_matrix = correlation_matrix
+            else:
+                # Low-rank correlation (tuple or LowRankCorrelationMatrix)
+                correlation_factor, correlation_diag = (
+                    correlation_matrix[0],
+                    correlation_matrix[1],
+                )
+                _validate_low_rank_correlation(
+                    correlation_factor, correlation_diag, num_features
+                )
+                self.low_rank_correlation = (correlation_factor, correlation_diag)
+
             self.correlation_thresholds = torch.tensor(
                 [norm.ppf(1 - p.item()) for p in self.firing_probabilities],
                 device=device,
                 dtype=self.firing_probabilities.dtype,
             )
-            self.correlation_matrix = correlation_matrix
 
+    @torch.no_grad()
     def sample(self, batch_size: int) -> torch.Tensor:
         """
         Generate a batch of feature activations with controlled properties.
@@ -89,6 +111,15 @@
                 self.correlation_thresholds,
                 device,
             )
+        elif self.low_rank_correlation is not None:
+            assert self.correlation_thresholds is not None
+            firing_features = _generate_low_rank_correlated_features(
+                batch_size,
+                self.low_rank_correlation[0],
+                self.low_rank_correlation[1],
+                self.correlation_thresholds,
+                device,
+            )
         else:
             firing_features = torch.bernoulli(
                 self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
@@ -132,7 +163,7 @@ def _generate_correlated_features(
         device: Device to generate samples on
 
     Returns:
-        Binary feature matrix of shape
+        Binary feature matrix of shape (batch_size, num_features)
     """
     num_features = correlation_matrix.shape[0]
 
@@ -145,6 +176,41 @@
     return (gaussian_samples > thresholds.unsqueeze(0)).float()
 
 
+def _generate_low_rank_correlated_features(
+    batch_size: int,
+    correlation_factor: torch.Tensor,
+    correlation_diag: torch.Tensor,
+    thresholds: torch.Tensor,
+    device: torch.device,
+) -> torch.Tensor:
+    """
+    Generate correlated binary features using low-rank multivariate Gaussian sampling.
+
+    Uses the Gaussian copula approach with a low-rank covariance structure for scalability.
+    The covariance is represented as: cov = factor @ factor.T + diag(diag_term)
+
+    Args:
+        batch_size: Number of samples to generate
+        correlation_factor: Factor matrix of shape (num_features, rank)
+        correlation_diag: Diagonal term of shape (num_features,)
+        thresholds: Pre-computed thresholds for each feature (from inverse normal CDF)
+        device: Device to generate samples on
+
+    Returns:
+        Binary feature matrix of shape (batch_size, num_features)
+    """
+    mvn = LowRankMultivariateNormal(
+        loc=torch.zeros(
+            correlation_factor.shape[0], device=device, dtype=thresholds.dtype
+        ),
+        cov_factor=correlation_factor.to(device=device, dtype=thresholds.dtype),
+        cov_diag=correlation_diag.to(device=device, dtype=thresholds.dtype),
+    )
+
+    gaussian_samples = mvn.sample((batch_size,))
+    return (gaussian_samples > thresholds.unsqueeze(0)).float()
+
+
 def _to_tensor(
     value: torch.Tensor | float,
     num_features: int,
@@ -193,7 +259,7 @@ def _validate_correlation_matrix(
 
     Args:
         correlation_matrix: The matrix to validate
-        num_features: Expected number of features (matrix should be
+        num_features: Expected number of features (matrix should be (num_features, num_features))
 
     Raises:
         ValueError: If the matrix has incorrect shape, non-unit diagonal, or is not positive definite
@@ -213,3 +279,36 @@
         torch.linalg.cholesky(correlation_matrix)
     except RuntimeError as e:
         raise ValueError("Correlation matrix must be positive definite") from e
+
+
+def _validate_low_rank_correlation(
+    correlation_factor: torch.Tensor,
+    correlation_diag: torch.Tensor,
+    num_features: int,
+) -> None:
+    """Validate that low-rank correlation parameters have correct properties.
+
+    Args:
+        correlation_factor: Factor matrix of shape (num_features, rank)
+        correlation_diag: Diagonal term of shape (num_features,)
+        num_features: Expected number of features
+
+    Raises:
+        ValueError: If shapes are incorrect or diagonal terms are not positive
+    """
+    if correlation_factor.ndim != 2:
+        raise ValueError(
+            f"correlation_factor must be 2D, got {correlation_factor.ndim}D"
+        )
+    if correlation_factor.shape[0] != num_features:
+        raise ValueError(
+            f"correlation_factor must have shape ({num_features}, rank), "
+            f"got {tuple(correlation_factor.shape)}"
+        )
+    if correlation_diag.shape != (num_features,):
+        raise ValueError(
+            f"correlation_diag must have shape ({num_features},), "
+            f"got {tuple(correlation_diag.shape)}"
+        )
+    if torch.any(correlation_diag <= 0):
+        raise ValueError("correlation_diag must have all positive values")
```
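Putting the new path together: passing a low-rank structure as `correlation_matrix` routes `sample()` through `LowRankMultivariateNormal` instead of a dense `MultivariateNormal`. A sketch follows; the leading constructor arguments (`num_features`, `firing_probabilities`) are inferred from attributes used in the hunks above and do not appear in this diff, so treat them as assumptions.

```python
import torch

from sae_lens.synthetic import (
    ActivationGenerator,
    generate_random_low_rank_correlation_matrix,
)

num_features = 50_000
low_rank = generate_random_low_rank_correlation_matrix(
    num_features=num_features, rank=64, correlation_scale=0.05, seed=0
)

# Hypothetical leading arguments; only the later parameters are shown in the diff.
gen = ActivationGenerator(
    num_features=num_features,
    firing_probabilities=torch.full((num_features,), 0.01),
    correlation_matrix=low_rank,  # a (factor, diag) tuple also satisfies CorrelationMatrixInput
)
acts = gen.sample(batch_size=256)  # sampled via the low-rank Gaussian copula
```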
sae_lens-6.29.0/sae_lens/synthetic/correlation.py (new file)

```diff
@@ -0,0 +1,351 @@
+import random
+from typing import NamedTuple
+
+import torch
+
+from sae_lens.util import str_to_dtype
+
+
+class LowRankCorrelationMatrix(NamedTuple):
+    """
+    Low-rank representation of a correlation matrix for scalable correlated sampling.
+
+    The correlation structure is represented as:
+        correlation = correlation_factor @ correlation_factor.T + diag(correlation_diag)
+
+    This requires O(num_features * rank) storage instead of O(num_features^2),
+    making it suitable for very large numbers of features (e.g., 1M+).
+
+    Attributes:
+        correlation_factor: Factor matrix of shape (num_features, rank) that captures
+            correlations through shared latent factors.
+        correlation_diag: Diagonal variance term of shape (num_features,). Should be
+            chosen such that the diagonal of the full correlation matrix equals 1.
+            Typically: correlation_diag[i] = 1 - sum(correlation_factor[i, :]^2)
+    """
+
+    correlation_factor: torch.Tensor
+    correlation_diag: torch.Tensor
+
+
+def create_correlation_matrix_from_correlations(
+    num_features: int,
+    correlations: dict[tuple[int, int], float] | None = None,
+    default_correlation: float = 0.0,
+) -> torch.Tensor:
+    """
+    Create a correlation matrix with specified pairwise correlations.
+
+    Note: If the resulting matrix is not positive definite, it will be adjusted
+    to ensure validity. This adjustment may change the specified correlation
+    values. To minimize this effect, use smaller correlation magnitudes.
+
+    Args:
+        num_features: Number of features
+        correlations: Dict mapping (i, j) pairs to correlation values.
+            Pairs should have i < j. Pairs not specified will use default_correlation.
+        default_correlation: Default correlation for unspecified pairs
+
+    Returns:
+        Correlation matrix of shape (num_features, num_features)
+    """
+    matrix = torch.eye(num_features) + default_correlation * (
+        1 - torch.eye(num_features)
+    )
+
+    if correlations is not None:
+        for (i, j), corr in correlations.items():
+            matrix[i, j] = corr
+            matrix[j, i] = corr
+
+    # Ensure matrix is symmetric (numerical precision)
+    matrix = (matrix + matrix.T) / 2
+
+    # Check positive definiteness and fix if necessary
+    # Use eigvalsh for symmetric matrices (returns real eigenvalues)
+    eigenvals = torch.linalg.eigvalsh(matrix)
+    if torch.any(eigenvals < -1e-6):
+        matrix = _fix_correlation_matrix(matrix)
+
+    return matrix
+
+
+def _fix_correlation_matrix(
+    matrix: torch.Tensor, min_eigenval: float = 1e-6
+) -> torch.Tensor:
+    """Fix a correlation matrix to be positive semi-definite."""
+    eigenvals, eigenvecs = torch.linalg.eigh(matrix)
+    eigenvals = torch.clamp(eigenvals, min=min_eigenval)
+    fixed_matrix = eigenvecs @ torch.diag(eigenvals) @ eigenvecs.T
+
+    diag_vals = torch.diag(fixed_matrix)
+    diag_vals = torch.clamp(diag_vals, min=1e-8)  # Prevent division by zero
+    fixed_matrix = fixed_matrix / torch.sqrt(
+        diag_vals.unsqueeze(0) * diag_vals.unsqueeze(1)
+    )
+    fixed_matrix.fill_diagonal_(1.0)
+
+    return fixed_matrix
+
+
+def _validate_correlation_params(
+    positive_ratio: float,
+    uncorrelated_ratio: float,
+    min_correlation_strength: float,
+    max_correlation_strength: float,
+) -> None:
+    """Validate parameters for correlation generation."""
+    if not 0.0 <= positive_ratio <= 1.0:
+        raise ValueError("positive_ratio must be between 0.0 and 1.0")
+    if not 0.0 <= uncorrelated_ratio <= 1.0:
+        raise ValueError("uncorrelated_ratio must be between 0.0 and 1.0")
+    if min_correlation_strength < 0:
+        raise ValueError("min_correlation_strength must be non-negative")
+    if max_correlation_strength > 1.0:
+        raise ValueError("max_correlation_strength must be <= 1.0")
+    if min_correlation_strength > max_correlation_strength:
+        raise ValueError("min_correlation_strength must be <= max_correlation_strength")
+
+
+def generate_random_correlations(
+    num_features: int,
+    positive_ratio: float = 0.5,
+    uncorrelated_ratio: float = 0.3,
+    min_correlation_strength: float = 0.1,
+    max_correlation_strength: float = 0.8,
+    seed: int | None = None,
+) -> dict[tuple[int, int], float]:
+    """
+    Generate random correlations between features with specified constraints.
+
+    Args:
+        num_features: Number of features
+        positive_ratio: Fraction of correlated pairs that should be positive (0.0 to 1.0)
+        uncorrelated_ratio: Fraction of feature pairs that should have zero correlation
+            (0.0 to 1.0). These pairs are omitted from the returned dictionary.
+        min_correlation_strength: Minimum absolute correlation strength for correlated pairs
+        max_correlation_strength: Maximum absolute correlation strength for correlated pairs
+        seed: Random seed for reproducibility
+
+    Returns:
+        Dictionary mapping (i, j) pairs to correlation values. Pairs with zero
+        correlation (determined by uncorrelated_ratio) are not included.
+    """
+    # Use local random number generator to avoid side effects on global state
+    rng = random.Random(seed)
+
+    _validate_correlation_params(
+        positive_ratio,
+        uncorrelated_ratio,
+        min_correlation_strength,
+        max_correlation_strength,
+    )
+
+    # Generate all possible feature pairs (i, j) where i < j
+    all_pairs = [
+        (i, j) for i in range(num_features) for j in range(i + 1, num_features)
+    ]
+    total_pairs = len(all_pairs)
+
+    if total_pairs == 0:
+        return {}
+
+    # Determine how many pairs to correlate vs leave uncorrelated
+    num_uncorrelated = int(total_pairs * uncorrelated_ratio)
+    num_correlated = total_pairs - num_uncorrelated
+
+    # Randomly select which pairs to correlate
+    correlated_pairs = rng.sample(all_pairs, num_correlated)
+
+    # For correlated pairs, determine positive vs negative
+    num_positive = int(num_correlated * positive_ratio)
+    num_negative = num_correlated - num_positive
+
+    # Assign signs
+    signs = [1] * num_positive + [-1] * num_negative
+    rng.shuffle(signs)
+
+    # Generate correlation strengths
+    correlations = {}
+    for pair, sign in zip(correlated_pairs, signs):
+        # Sample correlation strength uniformly from range
+        strength = rng.uniform(min_correlation_strength, max_correlation_strength)
+        correlations[pair] = sign * strength
+
+    return correlations
+
+
+def generate_random_correlation_matrix(
+    num_features: int,
+    positive_ratio: float = 0.5,
+    uncorrelated_ratio: float = 0.3,
+    min_correlation_strength: float = 0.1,
+    max_correlation_strength: float = 0.8,
+    seed: int | None = None,
+    device: torch.device | str = "cpu",
+    dtype: torch.dtype | str = torch.float32,
+) -> torch.Tensor:
+    """
+    Generate a random correlation matrix with specified constraints.
+
+    Uses vectorized torch operations for efficiency with large numbers of features.
+
+    Note: If the randomly generated matrix is not positive definite, it will be
+    adjusted to ensure validity. This adjustment may change correlation values,
+    including turning some zero correlations into non-zero values. To minimize
+    this effect, use smaller correlation strengths (e.g., 0.01-0.1).
+
+    Args:
+        num_features: Number of features
+        positive_ratio: Fraction of correlated pairs that should be positive (0.0 to 1.0)
+        uncorrelated_ratio: Fraction of feature pairs that should have zero correlation
+            (0.0 to 1.0). Note that matrix fixing for positive definiteness may reduce
+            the actual number of zero correlations.
+        min_correlation_strength: Minimum absolute correlation strength for correlated pairs
+        max_correlation_strength: Maximum absolute correlation strength for correlated pairs
+        seed: Random seed for reproducibility
+        device: Device to create the matrix on
+        dtype: Data type for the matrix
+
+    Returns:
+        Random correlation matrix of shape (num_features, num_features)
+    """
+    dtype = str_to_dtype(dtype)
+    _validate_correlation_params(
+        positive_ratio,
+        uncorrelated_ratio,
+        min_correlation_strength,
+        max_correlation_strength,
+    )
+
+    if num_features <= 1:
+        return torch.eye(num_features, device=device, dtype=dtype)
+
+    # Set random seed if provided
+    generator = torch.Generator(device=device)
+    if seed is not None:
+        generator.manual_seed(seed)
+
+    # Get upper triangular indices (i < j)
+    row_idx, col_idx = torch.triu_indices(num_features, num_features, offset=1)
+    num_pairs = row_idx.shape[0]
+
+    # Generate random values for all pairs at once
+    # is_correlated: 1 if this pair should have a correlation, 0 otherwise
+    is_correlated = (
+        torch.rand(num_pairs, generator=generator, device=device) >= uncorrelated_ratio
+    )
+
+    # signs: +1 for positive correlation, -1 for negative
+    is_positive = (
+        torch.rand(num_pairs, generator=generator, device=device) < positive_ratio
+    )
+    signs = torch.where(is_positive, 1.0, -1.0)
+
+    # strengths: uniform in [min_strength, max_strength]
+    strengths = (
+        torch.rand(num_pairs, generator=generator, device=device, dtype=dtype)
+        * (max_correlation_strength - min_correlation_strength)
+        + min_correlation_strength
+    )
+
+    # Combine: correlation = is_correlated * sign * strength
+    correlations = is_correlated.to(dtype) * signs.to(dtype) * strengths
+
+    # Build the symmetric matrix
+    matrix = torch.eye(num_features, device=device, dtype=dtype)
+    matrix[row_idx, col_idx] = correlations
+    matrix[col_idx, row_idx] = correlations
+
+    # Check positive definiteness and fix if necessary
+    eigenvals = torch.linalg.eigvalsh(matrix)
+    if torch.any(eigenvals < -1e-6):
+        matrix = _fix_correlation_matrix(matrix)
+
+    return matrix
+
+
+def generate_random_low_rank_correlation_matrix(
+    num_features: int,
+    rank: int,
+    correlation_scale: float = 0.1,
+    seed: int | None = None,
+    device: torch.device | str = "cpu",
+    dtype: torch.dtype | str = torch.float32,
+) -> LowRankCorrelationMatrix:
+    """
+    Generate a random low-rank correlation structure for scalable correlated sampling.
+
+    The correlation structure is represented as:
+        correlation = factor @ factor.T + diag(diag_term)
+
+    This requires O(num_features * rank) storage instead of O(num_features^2),
+    making it suitable for very large numbers of features (e.g., 1M+).
+
+    The factor matrix is initialized with random values scaled by correlation_scale,
+    and the diagonal term is computed to ensure the implied correlation matrix has
+    unit diagonal.
+
+    Args:
+        num_features: Number of features
+        rank: Rank of the low-rank approximation. Higher rank allows more complex
+            correlation structures but uses more memory. Typical values: 10-100.
+        correlation_scale: Scale factor for random correlations. Larger values produce
+            stronger correlations between features. Use 0 for no correlations (identity
+            matrix). Should be small enough that rank * correlation_scale^2 < 1 to
+            ensure valid diagonal terms.
+        seed: Random seed for reproducibility
+        device: Device to create tensors on
+        dtype: Data type for tensors
+
+    Returns:
+        LowRankCorrelationMatrix containing the factor matrix and diagonal term
+    """
+    # Minimum diagonal value to ensure numerical stability in the covariance matrix.
+    # This limits how much variance can come from the low-rank factor.
+    _MIN_DIAG = 0.01
+
+    dtype = str_to_dtype(dtype)
+    device = torch.device(device)
+
+    if rank <= 0:
+        raise ValueError("rank must be positive")
+    if correlation_scale < 0:
+        raise ValueError("correlation_scale must be non-negative")
+
+    # Set random seed if provided
+    generator = torch.Generator(device=device)
+    if seed is not None:
+        generator.manual_seed(seed)
+
+    # Generate random factor matrix
+    # Each row has norm roughly sqrt(rank) * correlation_scale
+    factor = (
+        torch.randn(num_features, rank, generator=generator, device=device, dtype=dtype)
+        * correlation_scale
+    )
+
+    # Compute diagonal term to ensure unit diagonal in implied correlation matrix
+    # diag(factor @ factor.T) + diag_term = 1
+    # diag_term = 1 - sum(factor[i, :]^2)
+    factor_sq_sum = (factor**2).sum(dim=1)
+    diag_term = 1 - factor_sq_sum
+
+    # Ensure diagonal terms are at least _MIN_DIAG for numerical stability
+    # If any diagonal term is too small, scale down the factor matrix
+    if torch.any(diag_term < _MIN_DIAG):
+        # Scale factor so max row norm squared is at most (1 - _MIN_DIAG)
+        # This ensures all diagonal terms are >= _MIN_DIAG
+        max_factor_contribution = 1 - _MIN_DIAG
+        max_sq_sum = factor_sq_sum.max()
+        scale = torch.sqrt(
+            torch.tensor(max_factor_contribution, device=device, dtype=dtype)
+            / max_sq_sum
+        )
+        factor = factor * scale
+        factor_sq_sum = (factor**2).sum(dim=1)
+        diag_term = 1 - factor_sq_sum
+
+    return LowRankCorrelationMatrix(
+        correlation_factor=factor, correlation_diag=diag_term
+    )
```
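A quick sanity check on the copula construction shared by both sampling paths: because the implied correlation matrix has unit diagonal, each marginal is a standard normal, so thresholding at `norm.ppf(1 - p)` fires with probability `p` regardless of rank. An illustrative sketch, not test code from the package:

```python
import torch
from scipy.stats import norm
from torch.distributions import LowRankMultivariateNormal

from sae_lens.synthetic import generate_random_low_rank_correlation_matrix

p = torch.tensor([0.02, 0.10, 0.50])
thresholds = torch.tensor(norm.ppf(1 - p.numpy()), dtype=torch.float32)

low_rank = generate_random_low_rank_correlation_matrix(num_features=3, rank=2, seed=0)
mvn = LowRankMultivariateNormal(
    loc=torch.zeros(3),
    cov_factor=low_rank.correlation_factor,
    cov_diag=low_rank.correlation_diag,
)
fired = (mvn.sample((200_000,)) > thresholds).float()
print(fired.mean(dim=0))  # approximately p: marginals are unchanged by the correlation structure
```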