sae-lens 6.26.1__py3-none-any.whl → 6.28.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +3 -1
- sae_lens/cache_activations_runner.py +12 -5
- sae_lens/config.py +2 -0
- sae_lens/loading/pretrained_sae_loaders.py +2 -1
- sae_lens/loading/pretrained_saes_directory.py +18 -0
- sae_lens/pretrained_saes.yaml +1 -1
- sae_lens/saes/gated_sae.py +1 -0
- sae_lens/saes/jumprelu_sae.py +3 -0
- sae_lens/saes/sae.py +13 -0
- sae_lens/saes/standard_sae.py +2 -0
- sae_lens/saes/temporal_sae.py +1 -0
- sae_lens/synthetic/__init__.py +89 -0
- sae_lens/synthetic/activation_generator.py +216 -0
- sae_lens/synthetic/correlation.py +170 -0
- sae_lens/synthetic/evals.py +141 -0
- sae_lens/synthetic/feature_dictionary.py +176 -0
- sae_lens/synthetic/firing_probabilities.py +104 -0
- sae_lens/synthetic/hierarchy.py +596 -0
- sae_lens/synthetic/initialization.py +40 -0
- sae_lens/synthetic/plotting.py +230 -0
- sae_lens/synthetic/training.py +145 -0
- sae_lens/tokenization_and_batching.py +1 -1
- sae_lens/training/activations_store.py +51 -91
- sae_lens/training/mixing_buffer.py +14 -5
- sae_lens/training/sae_trainer.py +1 -1
- sae_lens/util.py +26 -1
- {sae_lens-6.26.1.dist-info → sae_lens-6.28.2.dist-info}/METADATA +13 -1
- sae_lens-6.28.2.dist-info/RECORD +52 -0
- sae_lens-6.26.1.dist-info/RECORD +0 -42
- {sae_lens-6.26.1.dist-info → sae_lens-6.28.2.dist-info}/WHEEL +0 -0
- {sae_lens-6.26.1.dist-info → sae_lens-6.28.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utilities for training SAEs on synthetic data.
|
|
3
|
+
|
|
4
|
+
This module provides helpers for:
|
|
5
|
+
|
|
6
|
+
- Generating training data from feature dictionaries
|
|
7
|
+
- Training SAEs on synthetic data
|
|
8
|
+
- Evaluating SAEs against known ground truth features
|
|
9
|
+
- Initializing SAEs to match feature dictionaries
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
from scipy.optimize import linear_sum_assignment
|
|
16
|
+
|
|
17
|
+
from sae_lens.synthetic.activation_generator import ActivationGenerator
|
|
18
|
+
from sae_lens.synthetic.feature_dictionary import FeatureDictionary
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def mean_correlation_coefficient(
    features_a: torch.Tensor,
    features_b: torch.Tensor,
) -> float:
    """
    Compute Mean Correlation Coefficient (MCC) between two sets of feature vectors.

    MCC measures how well learned features align with ground truth features by finding
    an optimal one-to-one matching using the Hungarian algorithm and computing the
    mean absolute cosine similarity of matched pairs.

    Reference: O'Neill et al. "Compute Optimal Inference and Provable Amortisation
    Gap in Sparse Autoencoders" (arXiv:2411.13117)

    Inputs are detached from the autograd graph before use, so trainable tensors
    (for example an ``nn.Parameter``) can be passed directly without wrapping the
    call in ``torch.no_grad()``.

    Args:
        features_a: Feature vectors of shape [num_features_a, hidden_dim]
        features_b: Feature vectors of shape [num_features_b, hidden_dim]

    Returns:
        MCC score in range [0, 1], where 1 indicates perfect alignment
    """
    # The cost matrix is handed to scipy as a numpy array, and Tensor.numpy()
    # refuses tensors that require grad. Detaching is free (no copy) and makes
    # the function safe to call on live model parameters.
    features_a = features_a.detach()
    features_b = features_b.detach()

    # Normalize to unit vectors; the clamp avoids dividing by zero for all-zero rows
    a_norm = features_a / features_a.norm(dim=1, keepdim=True).clamp(min=1e-8)
    b_norm = features_b / features_b.norm(dim=1, keepdim=True).clamp(min=1e-8)

    # Absolute cosine similarity: the sign of a feature direction is ignored
    cos_sim = torch.abs(a_norm @ b_norm.T)

    # Convert to cost matrix for Hungarian algorithm (which minimizes)
    cost_matrix = 1 - cos_sim.cpu().numpy()

    # Find optimal matching
    row_ind, col_ind = linear_sum_assignment(cost_matrix)

    # Compute mean of matched similarities
    matched_similarities = cos_sim[row_ind, col_ind]
    return matched_similarities.mean().item()
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass
class SyntheticDataEvalResult:
    """Metrics produced by evaluating an SAE against synthetic ground truth."""

    true_l0: float
    """Mean L0 of the ground-truth feature activations."""

    sae_l0: float
    """Mean L0 of the SAE latent activations."""

    dead_latents: int
    """How many SAE latents never fired during the evaluation."""

    shrinkage: float
    """Mean ratio of SAE output norm to input norm; 1.0 means no shrinkage."""

    mcc: float
    """Mean Correlation Coefficient of the SAE decoder against the ground truth features."""
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@torch.no_grad()
def eval_sae_on_synthetic_data(
    sae: torch.nn.Module,
    feature_dict: FeatureDictionary,
    activations_generator: ActivationGenerator,
    num_samples: int = 100_000,
) -> SyntheticDataEvalResult:
    """
    Evaluate an SAE on synthetic data with known ground truth.

    The SAE is switched to eval mode and is left in eval mode on return.
    Samples whose hidden activation has zero norm are dropped before the SAE is
    run; ``true_l0`` is computed over all samples, while the SAE metrics are
    computed over the remaining samples only. If no sample remains, ``sae_l0``
    and ``shrinkage`` are both reported as 0.0 rather than NaN.

    Args:
        sae: The SAE to evaluate. Must have encode() and decode() methods and a
            W_dec attribute.
        feature_dict: The feature dictionary used to generate activations
        activations_generator: Generator that produces feature activations
        num_samples: Number of samples to use for evaluation

    Returns:
        SyntheticDataEvalResult containing evaluation metrics
    """
    sae.eval()

    # Generate samples
    feature_acts = activations_generator.sample(num_samples)
    true_l0 = (feature_acts > 0).float().sum(dim=-1).mean().item()
    hidden_acts = feature_dict(feature_acts)

    # Filter out entries where no features fire
    non_zero_mask = hidden_acts.norm(dim=-1) > 0
    hidden_acts_filtered = hidden_acts[non_zero_mask]

    # Get SAE reconstructions
    sae_latents = sae.encode(hidden_acts_filtered)  # type: ignore[attr-defined]
    sae_output = sae.decode(sae_latents)  # type: ignore[attr-defined]

    # A latent is dead when it is exactly zero on every evaluated sample
    dead_latents = int(
        ((sae_latents == 0).sum(dim=0) == sae_latents.shape[0]).sum().item()
    )

    if hidden_acts_filtered.shape[0] == 0:
        # Averaging over an empty batch yields NaN; report neutral zeros for
        # every per-sample mean so the result stays well defined.
        sae_l0 = 0.0
        shrinkage = 0.0
    else:
        sae_l0 = (sae_latents > 0).float().sum(dim=-1).mean().item()
        shrinkage = (
            (
                sae_output.norm(dim=-1)
                / hidden_acts_filtered.norm(dim=-1).clamp(min=1e-8)
            )
            .mean()
            .item()
        )

    # Compute MCC between SAE decoder and ground truth features
    sae_decoder: torch.Tensor = sae.W_dec  # type: ignore[attr-defined]
    gt_features = feature_dict.feature_vectors
    mcc = mean_correlation_coefficient(sae_decoder, gt_features)

    return SyntheticDataEvalResult(
        true_l0=true_l0,
        sae_l0=sae_l0,
        dead_latents=dead_latents,
        shrinkage=shrinkage,
        mcc=mcc,
    )
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Feature dictionary for generating synthetic activations.
|
|
3
|
+
|
|
4
|
+
A FeatureDictionary maps feature activations (sparse coefficients) to dense hidden activations
|
|
5
|
+
by multiplying with a learned or constructed feature embedding matrix.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Callable
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from torch import nn
|
|
12
|
+
from tqdm import tqdm
|
|
13
|
+
|
|
14
|
+
FeatureDictionaryInitializer = Callable[["FeatureDictionary"], None]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def orthogonalize_embeddings(
    embeddings: torch.Tensor,
    num_steps: int = 200,
    lr: float = 0.01,
    show_progress: bool = False,
    chunk_size: int = 1024,
) -> torch.Tensor:
    """
    Orthogonalize embeddings using gradient descent with chunked computation.

    The pairwise dot products are computed one chunk of rows at a time and each
    chunk is backpropagated immediately, so only one [chunk_size, num_vectors]
    block of dot products (and its autograd graph) is alive at any time.
    Gradients from all chunks accumulate before a single optimizer step, so
    each step still minimizes the full loss:

        sum(off-diagonal dot products ** 2)
        + num_vectors * sum((self dot products - 1) ** 2)

    Autograd is enabled locally, so the function can be called from inside a
    ``torch.no_grad()`` block.

    Args:
        embeddings: Tensor of shape [num_vectors, hidden_dim]
        num_steps: Number of optimization steps
        lr: Learning rate for Adam optimizer
        show_progress: Whether to show progress bar
        chunk_size: Number of vectors to process at once. Smaller values use less
            memory but may be slower.

    Returns:
        Orthogonalized embeddings of the same shape, normalized to unit length.
        The input tensor is not modified.
    """
    num_vectors = embeddings.shape[0]

    # The optimisation needs gradients even if the caller disabled them.
    with torch.enable_grad():
        # Create a detached copy and normalize, then enable gradients
        embeddings = embeddings.detach().clone()
        embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
            min=1e-8
        )
        embeddings = embeddings.requires_grad_(True)

        optimizer = torch.optim.Adam([embeddings], lr=lr)  # type: ignore[list-item]

        pbar = tqdm(
            range(num_steps), desc="Orthogonalizing vectors", disable=not show_progress
        )
        for _ in pbar:
            optimizer.zero_grad()

            # Detached running total, used only for the progress display
            step_loss = torch.tensor(0.0, device=embeddings.device)

            for i in range(0, num_vectors, chunk_size):
                end_i = min(i + chunk_size, num_vectors)
                chunk = embeddings[i:end_i]
                chunk_dots = chunk @ embeddings.T  # [chunk_len, num_vectors]

                # Diagonal of full matrix: position (i+k, i+k) -> in chunk_dots: (k, i+k)
                chunk_len = end_i - i
                row_indices = torch.arange(chunk_len, device=embeddings.device)
                col_indices = i + row_indices  # column indices in full matrix

                # Boolean mask: True for off-diagonal elements we want to include
                off_diag_mask = torch.ones_like(chunk_dots, dtype=torch.bool)
                off_diag_mask[row_indices, col_indices] = False
                off_diag_loss = chunk_dots[off_diag_mask].pow(2).sum()

                # Diagonal loss: keep self-dot-products at 1
                diag_vals = chunk_dots[row_indices, col_indices]
                diag_loss = (diag_vals - 1).pow(2).sum()

                chunk_loss = off_diag_loss + num_vectors * diag_loss
                # Backpropagate per chunk so this chunk's graph is released
                # before the next one is built; gradients accumulate in .grad.
                chunk_loss.backward()
                step_loss = step_loss + chunk_loss.detach()

            optimizer.step()
            pbar.set_description(f"loss: {step_loss.item():.3f}")

    with torch.no_grad():
        embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
            min=1e-8
        )
    return embeddings.detach().clone()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def orthogonal_initializer(
    num_steps: int = 200,
    lr: float = 0.01,
    show_progress: bool = False,
    chunk_size: int = 1024,
) -> FeatureDictionaryInitializer:
    """
    Build an initializer that orthogonalizes a FeatureDictionary's feature vectors.

    The arguments are captured and forwarded unchanged to
    ``orthogonalize_embeddings`` each time the returned initializer is called.

    Args:
        num_steps: Number of optimization steps
        lr: Learning rate for the optimizer
        show_progress: Whether to show a progress bar
        chunk_size: Number of vectors to process at once

    Returns:
        A callable that takes a FeatureDictionary and overwrites the data of its
        ``feature_vectors`` parameter with the orthogonalized vectors.
    """

    def initializer(feature_dict: "FeatureDictionary") -> None:
        orthogonalized = orthogonalize_embeddings(
            feature_dict.feature_vectors,
            chunk_size=chunk_size,
            show_progress=show_progress,
            lr=lr,
            num_steps=num_steps,
        )
        # Swap the underlying data; the Parameter object itself stays in place.
        feature_dict.feature_vectors.data = orthogonalized

    return initializer
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class FeatureDictionary(nn.Module):
    """
    Linear map from sparse feature activations to dense hidden activations.

    The dictionary holds one embedding vector per feature. Hidden activations are
    produced by multiplying feature activations with the embedding matrix and
    adding a bias vector.

    Feature vectors start as random unit vectors. An optional initializer callback
    is then applied to the freshly built module; the default one orthogonalizes
    the vectors.

    Attributes:
        feature_vectors: Parameter of shape [num_features, hidden_dim] containing the
            feature embedding vectors
        bias: Parameter of shape [hidden_dim] containing the bias term (zeros, and
            not trainable, if bias=False)
    """

    feature_vectors: nn.Parameter
    bias: nn.Parameter

    def __init__(
        self,
        num_features: int,
        hidden_dim: int,
        bias: bool = False,
        initializer: FeatureDictionaryInitializer | None = orthogonal_initializer(),
        device: str | torch.device = "cpu",
    ):
        """
        Create a new FeatureDictionary.

        Args:
            num_features: Number of features in the dictionary
            hidden_dim: Dimensionality of the hidden space
            bias: Whether the bias term is trainable. The bias always starts at zero.
            initializer: Callback applied to the new dictionary after the random
                unit vectors are created. If None, the random unit vectors are
                kept as they are. By default will orthogonalize embeddings.
            device: Device to use for the feature dictionary.
        """
        super().__init__()
        self.num_features = num_features
        self.hidden_dim = hidden_dim

        # Draw random directions and scale every row to unit length
        raw_vectors = torch.randn(num_features, hidden_dim, device=device)
        row_norms = raw_vectors.norm(p=2, dim=1, keepdim=True).clamp(min=1e-8)
        self.feature_vectors = nn.Parameter(raw_vectors / row_norms)

        # The bias is always registered as a parameter (consistent API); it only
        # receives gradients when bias=True.
        zero_bias = torch.zeros(hidden_dim, device=device)
        self.bias = nn.Parameter(zero_bias, requires_grad=bias)

        if initializer is not None:
            initializer(self)

    def forward(self, feature_activations: torch.Tensor) -> torch.Tensor:
        """
        Convert feature activations to hidden activations.

        Args:
            feature_activations: Tensor of shape [batch, num_features] containing
                sparse feature activation values

        Returns:
            Tensor of shape [batch, hidden_dim] containing dense hidden activations
        """
        projected = torch.matmul(feature_activations, self.feature_vectors)
        return projected + self.bias
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Helper functions for generating firing probability distributions.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def zipfian_firing_probabilities(
    num_features: int,
    exponent: float = 1.0,
    max_prob: float = 0.3,
    min_prob: float = 0.01,
) -> torch.Tensor:
    """
    Build firing probabilities that follow a Zipfian (power-law) profile.

    The raw values ``1 / rank ** exponent`` are min-max rescaled so that the first
    feature gets ``max_prob`` and the last feature gets ``min_prob``. A few
    features therefore fire often while most fire rarely.

    Args:
        num_features: Number of features to generate probabilities for
        exponent: Zipf exponent (higher = steeper dropoff). Default 1.0.
        max_prob: Maximum firing probability (for the most frequent feature)
        min_prob: Minimum firing probability (for the least frequent feature)

    Returns:
        Tensor of shape [num_features] with firing probabilities in descending order

    Raises:
        ValueError: If num_features < 1, exponent <= 0, or the probabilities do
            not satisfy 0 < min_prob < max_prob <= 1.
    """
    if num_features < 1:
        raise ValueError("num_features must be at least 1")
    if exponent <= 0:
        raise ValueError("exponent must be positive")
    bounds_ok = 0 < min_prob < max_prob <= 1
    if not bounds_ok:
        raise ValueError("Must have 0 < min_prob < max_prob <= 1")

    # A single feature has no range to rescale over; it simply gets max_prob.
    if num_features == 1:
        return torch.tensor([max_prob])

    ranks = torch.arange(1, num_features + 1, dtype=torch.float32)
    power_law = 1.0 / ranks**exponent

    # Scale to [min_prob, max_prob]
    lowest = power_law.min()
    highest = power_law.max()
    numerator = (max_prob - min_prob) * (power_law - lowest)
    return min_prob + numerator / (highest - lowest)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def linear_firing_probabilities(
    num_features: int,
    max_prob: float = 0.3,
    min_prob: float = 0.01,
) -> torch.Tensor:
    """
    Build firing probabilities that fall off linearly from max_prob to min_prob.

    Args:
        num_features: Number of features to generate probabilities for
        max_prob: Firing probability for the first feature
        min_prob: Firing probability for the last feature

    Returns:
        Tensor of shape [num_features] with linearly decaying probabilities. With
        a single feature the result is just ``[max_prob]``.

    Raises:
        ValueError: If num_features < 1 or the probabilities do not satisfy
            0 < min_prob <= max_prob <= 1.
    """
    if num_features < 1:
        raise ValueError("num_features must be at least 1")
    bounds_ok = 0 < min_prob <= max_prob <= 1
    if not bounds_ok:
        raise ValueError("Must have 0 < min_prob <= max_prob <= 1")

    if num_features > 1:
        return torch.linspace(max_prob, min_prob, num_features)

    # Only one feature: it takes the maximum probability.
    return torch.tensor([max_prob])
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def random_firing_probabilities(
|
|
77
|
+
num_features: int,
|
|
78
|
+
max_prob: float = 0.5,
|
|
79
|
+
min_prob: float = 0.01,
|
|
80
|
+
seed: int | None = None,
|
|
81
|
+
) -> torch.Tensor:
|
|
82
|
+
"""
|
|
83
|
+
Generate random firing probabilities uniformly sampled from a range.
|
|
84
|
+
|
|
85
|
+
Args:
|
|
86
|
+
num_features: Number of features to generate probabilities for
|
|
87
|
+
max_prob: Maximum firing probability
|
|
88
|
+
min_prob: Minimum firing probability
|
|
89
|
+
seed: Optional random seed for reproducibility
|
|
90
|
+
|
|
91
|
+
Returns:
|
|
92
|
+
Tensor of shape [num_features] with random firing probabilities
|
|
93
|
+
"""
|
|
94
|
+
if num_features < 1:
|
|
95
|
+
raise ValueError("num_features must be at least 1")
|
|
96
|
+
if not 0 < min_prob < max_prob <= 1:
|
|
97
|
+
raise ValueError("Must have 0 < min_prob < max_prob <= 1")
|
|
98
|
+
|
|
99
|
+
generator = torch.Generator()
|
|
100
|
+
if seed is not None:
|
|
101
|
+
generator.manual_seed(seed)
|
|
102
|
+
|
|
103
|
+
probs = torch.rand(num_features, generator=generator, dtype=torch.float32)
|
|
104
|
+
return min_prob + (max_prob - min_prob) * probs
|