sae-lens 6.26.1__py3-none-any.whl → 6.28.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,141 @@
+ """
+ Utilities for training SAEs on synthetic data.
+
+ This module provides helpers for:
+
+ - Generating training data from feature dictionaries
+ - Training SAEs on synthetic data
+ - Evaluating SAEs against known ground truth features
+ - Initializing SAEs to match feature dictionaries
+ """
+
+ from dataclasses import dataclass
+
+ import torch
+ from scipy.optimize import linear_sum_assignment
+
+ from sae_lens.synthetic.activation_generator import ActivationGenerator
+ from sae_lens.synthetic.feature_dictionary import FeatureDictionary
+
+
+ def mean_correlation_coefficient(
+     features_a: torch.Tensor,
+     features_b: torch.Tensor,
+ ) -> float:
+     """
+     Compute Mean Correlation Coefficient (MCC) between two sets of feature vectors.
+
+     MCC measures how well learned features align with ground truth features by finding
+     an optimal one-to-one matching using the Hungarian algorithm and computing the
+     mean absolute cosine similarity of matched pairs.
+
+     Reference: O'Neill et al. "Compute Optimal Inference and Provable Amortisation
+     Gap in Sparse Autoencoders" (arXiv:2411.13117)
+
+     Args:
+         features_a: Feature vectors of shape [num_features_a, hidden_dim]
+         features_b: Feature vectors of shape [num_features_b, hidden_dim]
+
+     Returns:
+         MCC score in range [0, 1], where 1 indicates perfect alignment
+     """
+     # Normalize to unit vectors
+     a_norm = features_a / features_a.norm(dim=1, keepdim=True).clamp(min=1e-8)
+     b_norm = features_b / features_b.norm(dim=1, keepdim=True).clamp(min=1e-8)
+
+     # Compute absolute cosine similarity matrix
+     cos_sim = torch.abs(a_norm @ b_norm.T)
+
+     # Convert to cost matrix for Hungarian algorithm (which minimizes)
+     cost_matrix = 1 - cos_sim.cpu().numpy()
+
+     # Find optimal matching
+     row_ind, col_ind = linear_sum_assignment(cost_matrix)
+
+     # Compute mean of matched similarities
+     matched_similarities = cos_sim[row_ind, col_ind]
+     return matched_similarities.mean().item()
+
+
+ @dataclass
+ class SyntheticDataEvalResult:
+     """Results from evaluating an SAE on synthetic data."""
+
+     true_l0: float
+     """Average L0 of the true feature activations"""
+
+     sae_l0: float
+     """Average L0 of the SAE's latent activations"""
+
+     dead_latents: int
+     """Number of SAE latents that never fired"""
+
+     shrinkage: float
+     """Average ratio of SAE output norm to input norm (1.0 = no shrinkage)"""
+
+     mcc: float
+     """Mean Correlation Coefficient between SAE decoder and ground truth features"""
+
+
+ @torch.no_grad()
+ def eval_sae_on_synthetic_data(
+     sae: torch.nn.Module,
+     feature_dict: FeatureDictionary,
+     activations_generator: ActivationGenerator,
+     num_samples: int = 100_000,
+ ) -> SyntheticDataEvalResult:
+     """
+     Evaluate an SAE on synthetic data with known ground truth.
+
+     Args:
+         sae: The SAE to evaluate. Must have encode() and decode() methods.
+         feature_dict: The feature dictionary used to generate activations
+         activations_generator: Generator that produces feature activations
+         num_samples: Number of samples to use for evaluation
+
+     Returns:
+         SyntheticDataEvalResult containing evaluation metrics
+     """
+     sae.eval()
+
+     # Generate samples
+     feature_acts = activations_generator.sample(num_samples)
+     true_l0 = (feature_acts > 0).float().sum(dim=-1).mean().item()
+     hidden_acts = feature_dict(feature_acts)
+
+     # Filter out entries where no features fire
+     non_zero_mask = hidden_acts.norm(dim=-1) > 0
+     hidden_acts_filtered = hidden_acts[non_zero_mask]
+
+     # Get SAE reconstructions
+     sae_latents = sae.encode(hidden_acts_filtered)  # type: ignore[attr-defined]
+     sae_output = sae.decode(sae_latents)  # type: ignore[attr-defined]
+
+     sae_l0 = (sae_latents > 0).float().sum(dim=-1).mean().item()
+     dead_latents = int(
+         ((sae_latents == 0).sum(dim=0) == sae_latents.shape[0]).sum().item()
+     )
+     if hidden_acts_filtered.shape[0] == 0:
+         shrinkage = 0.0
+     else:
+         shrinkage = (
+             (
+                 sae_output.norm(dim=-1)
+                 / hidden_acts_filtered.norm(dim=-1).clamp(min=1e-8)
+             )
+             .mean()
+             .item()
+         )
+
+     # Compute MCC between SAE decoder and ground truth features
+     sae_decoder: torch.Tensor = sae.W_dec  # type: ignore[attr-defined]
+     gt_features = feature_dict.feature_vectors
+     mcc = mean_correlation_coefficient(sae_decoder, gt_features)
+
+     return SyntheticDataEvalResult(
+         true_l0=true_l0,
+         sae_l0=sae_l0,
+         dead_latents=dead_latents,
+         shrinkage=shrinkage,
+         mcc=mcc,
+     )
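
A quick standalone sketch (not part of the diff) of how the MCC helper above behaves: it matches with the Hungarian algorithm on absolute cosine similarity, so a dictionary that is a permuted, sign-flipped copy of the ground truth should score close to 1.0. The file name of the new module is not visible in this diff, so the sketch assumes `mean_correlation_coefficient` is already in scope rather than guessing an import path.

    import torch

    # Sanity check for mean_correlation_coefficient (defined in the hunk above).
    torch.manual_seed(0)
    ground_truth = torch.randn(16, 64)

    # A "perfect" learned dictionary: the same vectors, permuted and sign-flipped.
    perm = torch.randperm(16)
    learned = -ground_truth[perm]

    # Matching is invariant to order and sign, so this prints a value ~1.0.
    print(mean_correlation_coefficient(learned, ground_truth))

`eval_sae_on_synthetic_data` additionally needs a `FeatureDictionary` and an `ActivationGenerator` (the latter is not shown in this diff), plus any SAE exposing `encode()`, `decode()`, and `W_dec`.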
@@ -0,0 +1,176 @@
+ """
+ Feature dictionary for generating synthetic activations.
+
+ A FeatureDictionary maps feature activations (sparse coefficients) to dense hidden activations
+ by multiplying with a learned or constructed feature embedding matrix.
+ """
+
+ from typing import Callable
+
+ import torch
+ from torch import nn
+ from tqdm import tqdm
+
+ FeatureDictionaryInitializer = Callable[["FeatureDictionary"], None]
+
+
+ def orthogonalize_embeddings(
+     embeddings: torch.Tensor,
+     num_steps: int = 200,
+     lr: float = 0.01,
+     show_progress: bool = False,
+     chunk_size: int = 1024,
+ ) -> torch.Tensor:
+     """
+     Orthogonalize embeddings using gradient descent with chunked computation.
+
+     Uses chunked computation to avoid O(n²) memory usage when computing pairwise
+     dot products. Memory usage is O(chunk_size × n) instead of O(n²).
+
+     Args:
+         embeddings: Tensor of shape [num_vectors, hidden_dim]
+         num_steps: Number of optimization steps
+         lr: Learning rate for Adam optimizer
+         show_progress: Whether to show progress bar
+         chunk_size: Number of vectors to process at once. Smaller values use less
+             memory but may be slower.
+
+     Returns:
+         Orthogonalized embeddings of the same shape, normalized to unit length.
+     """
+     num_vectors = embeddings.shape[0]
+     # Create a detached copy and normalize, then enable gradients
+     embeddings = embeddings.detach().clone()
+     embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(min=1e-8)
+     embeddings = embeddings.requires_grad_(True)
+
+     optimizer = torch.optim.Adam([embeddings], lr=lr)  # type: ignore[list-item]
+
+     pbar = tqdm(
+         range(num_steps), desc="Orthogonalizing vectors", disable=not show_progress
+     )
+     for _ in pbar:
+         optimizer.zero_grad()
+
+         off_diag_loss = torch.tensor(0.0, device=embeddings.device)
+         diag_loss = torch.tensor(0.0, device=embeddings.device)
+
+         for i in range(0, num_vectors, chunk_size):
+             end_i = min(i + chunk_size, num_vectors)
+             chunk = embeddings[i:end_i]
+             chunk_dots = chunk @ embeddings.T  # [chunk_size, num_vectors]
+
+             # Create mask to zero out diagonal elements for this chunk
+             # Diagonal of full matrix: position (i+k, i+k) → in chunk_dots: (k, i+k)
+             chunk_len = end_i - i
+             row_indices = torch.arange(chunk_len, device=embeddings.device)
+             col_indices = i + row_indices  # column indices in full matrix
+
+             # Boolean mask: True for off-diagonal elements we want to include
+             off_diag_mask = torch.ones_like(chunk_dots, dtype=torch.bool)
+             off_diag_mask[row_indices, col_indices] = False
+
+             off_diag_loss = off_diag_loss + chunk_dots[off_diag_mask].pow(2).sum()
+
+             # Diagonal loss: keep self-dot-products at 1
+             diag_vals = chunk_dots[row_indices, col_indices]
+             diag_loss = diag_loss + (diag_vals - 1).pow(2).sum()
+
+         loss = off_diag_loss + num_vectors * diag_loss
+         loss.backward()
+         optimizer.step()
+         pbar.set_description(f"loss: {loss.item():.3f}")
+
+     with torch.no_grad():
+         embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
+             min=1e-8
+         )
+     return embeddings.detach().clone()
+
+
+ def orthogonal_initializer(
+     num_steps: int = 200,
+     lr: float = 0.01,
+     show_progress: bool = False,
+     chunk_size: int = 1024,
+ ) -> FeatureDictionaryInitializer:
+     def initializer(feature_dict: "FeatureDictionary") -> None:
+         feature_dict.feature_vectors.data = orthogonalize_embeddings(
+             feature_dict.feature_vectors,
+             num_steps=num_steps,
+             lr=lr,
+             show_progress=show_progress,
+             chunk_size=chunk_size,
+         )
+
+     return initializer
+
+
+ class FeatureDictionary(nn.Module):
+     """
+     A feature dictionary that maps sparse feature activations to dense hidden activations.
+
+     This class creates a set of feature vectors (the "dictionary") and provides methods
+     to generate hidden activations from feature activations via a linear transformation.
+
+     The feature vectors can be configured to have a specific pairwise cosine similarity,
+     which is useful for controlling the difficulty of sparse recovery.
+
+     Attributes:
+         feature_vectors: Parameter of shape [num_features, hidden_dim] containing the
+             feature embedding vectors
+         bias: Parameter of shape [hidden_dim] containing the bias term (zeros if bias=False)
+     """
+
+     feature_vectors: nn.Parameter
+     bias: nn.Parameter
+
+     def __init__(
+         self,
+         num_features: int,
+         hidden_dim: int,
+         bias: bool = False,
+         initializer: FeatureDictionaryInitializer | None = orthogonal_initializer(),
+         device: str | torch.device = "cpu",
+     ):
+         """
+         Create a new FeatureDictionary.
+
+         Args:
+             num_features: Number of features in the dictionary
+             hidden_dim: Dimensionality of the hidden space
+             bias: Whether to include a bias term in the embedding
+             initializer: Initializer function to use. If None, the embeddings are initialized to random unit vectors. By default will orthogonalize embeddings.
+             device: Device to use for the feature dictionary.
+         """
+         super().__init__()
+         self.num_features = num_features
+         self.hidden_dim = hidden_dim
+
+         # Initialize feature vectors as unit vectors
+         embeddings = torch.randn(num_features, hidden_dim, device=device)
+         embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
+             min=1e-8
+         )
+         self.feature_vectors = nn.Parameter(embeddings)
+
+         # Initialize bias (zeros if not using bias, but still a parameter for consistent API)
+         self.bias = nn.Parameter(
+             torch.zeros(hidden_dim, device=device), requires_grad=bias
+         )
+
+         if initializer is not None:
+             initializer(self)
+
+     def forward(self, feature_activations: torch.Tensor) -> torch.Tensor:
+         """
+         Convert feature activations to hidden activations.
+
+         Args:
+             feature_activations: Tensor of shape [batch, num_features] containing
+                 sparse feature activation values
+
+         Returns:
+             Tensor of shape [batch, hidden_dim] containing dense hidden activations
+         """
+         return feature_activations @ self.feature_vectors + self.bias
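
A rough usage sketch (not part of the diff; the constructor arguments mirror only the signature shown above): build a FeatureDictionary with an explicit orthogonal initializer and apply it to a sparse batch of feature activations.

    import torch

    # Hypothetical example built only from the API shown in the hunk above.
    feature_dict = FeatureDictionary(
        num_features=512,
        hidden_dim=128,
        initializer=orthogonal_initializer(num_steps=50, show_progress=True),
    )

    # A batch of 4 sparse activation vectors, each with a single active feature.
    feature_acts = torch.zeros(4, 512)
    feature_acts[torch.arange(4), torch.randint(0, 512, (4,))] = 1.0

    hidden_acts = feature_dict(feature_acts)
    print(hidden_acts.shape)  # torch.Size([4, 128])

Per the docstring above, passing `initializer=None` skips the orthogonalization pass and leaves the randomly drawn unit vectors as-is, which is the cheaper option when near-orthogonality is not needed.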
@@ -0,0 +1,104 @@
+ """
+ Helper functions for generating firing probability distributions.
+ """
+
+ import torch
+
+
+ def zipfian_firing_probabilities(
+     num_features: int,
+     exponent: float = 1.0,
+     max_prob: float = 0.3,
+     min_prob: float = 0.01,
+ ) -> torch.Tensor:
+     """
+     Generate firing probabilities following a Zipfian (power-law) distribution.
+
+     Creates probabilities where a few features fire frequently and most fire rarely,
+     which mirrors the distribution often observed in real neural network features.
+
+     Args:
+         num_features: Number of features to generate probabilities for
+         exponent: Zipf exponent (higher = steeper dropoff). Default 1.0.
+         max_prob: Maximum firing probability (for the most frequent feature)
+         min_prob: Minimum firing probability (for the least frequent feature)
+
+     Returns:
+         Tensor of shape [num_features] with firing probabilities in descending order
+     """
+     if num_features < 1:
+         raise ValueError("num_features must be at least 1")
+     if exponent <= 0:
+         raise ValueError("exponent must be positive")
+     if not 0 < min_prob < max_prob <= 1:
+         raise ValueError("Must have 0 < min_prob < max_prob <= 1")
+
+     ranks = torch.arange(1, num_features + 1, dtype=torch.float32)
+     probs = 1.0 / ranks**exponent
+
+     # Scale to [min_prob, max_prob]
+     if num_features == 1:
+         return torch.tensor([max_prob])
+
+     probs_min, probs_max = probs.min(), probs.max()
+     return min_prob + (max_prob - min_prob) * (probs - probs_min) / (
+         probs_max - probs_min
+     )
+
+
+ def linear_firing_probabilities(
+     num_features: int,
+     max_prob: float = 0.3,
+     min_prob: float = 0.01,
+ ) -> torch.Tensor:
+     """
+     Generate firing probabilities that decay linearly from max to min.
+
+     Args:
+         num_features: Number of features to generate probabilities for
+         max_prob: Firing probability for the first feature
+         min_prob: Firing probability for the last feature
+
+     Returns:
+         Tensor of shape [num_features] with linearly decaying probabilities
+     """
+     if num_features < 1:
+         raise ValueError("num_features must be at least 1")
+     if not 0 < min_prob <= max_prob <= 1:
+         raise ValueError("Must have 0 < min_prob <= max_prob <= 1")
+
+     if num_features == 1:
+         return torch.tensor([max_prob])
+
+     return torch.linspace(max_prob, min_prob, num_features)
+
+
+ def random_firing_probabilities(
+     num_features: int,
+     max_prob: float = 0.5,
+     min_prob: float = 0.01,
+     seed: int | None = None,
+ ) -> torch.Tensor:
+     """
+     Generate random firing probabilities uniformly sampled from a range.
+
+     Args:
+         num_features: Number of features to generate probabilities for
+         max_prob: Maximum firing probability
+         min_prob: Minimum firing probability
+         seed: Optional random seed for reproducibility
+
+     Returns:
+         Tensor of shape [num_features] with random firing probabilities
+     """
+     if num_features < 1:
+         raise ValueError("num_features must be at least 1")
+     if not 0 < min_prob < max_prob <= 1:
+         raise ValueError("Must have 0 < min_prob < max_prob <= 1")
+
+     generator = torch.Generator()
+     if seed is not None:
+         generator.manual_seed(seed)
+
+     probs = torch.rand(num_features, generator=generator, dtype=torch.float32)
+     return min_prob + (max_prob - min_prob) * probs
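
The three schedules above differ only in how probability mass is spread across feature ranks. A small comparison (not part of the diff, using only the functions shown with their default ranges) makes the shapes concrete; if features fired independently with these probabilities, the sum of each tensor would be the expected L0 of a sample.

    zipf = zipfian_firing_probabilities(num_features=10, exponent=1.0)
    linear = linear_firing_probabilities(num_features=10)
    rand = random_firing_probabilities(num_features=10, seed=0)

    print(zipf)    # steep head: starts at 0.3 and drops quickly toward 0.01
    print(linear)  # evenly spaced from 0.3 down to 0.01
    print(rand)    # uniform draws between 0.01 and 0.5

    # Expected number of active features per sample under independent firing
    print(zipf.sum().item(), linear.sum().item(), rand.sum().item())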