sae-lens 6.26.1__py3-none-any.whl → 6.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,141 @@
+ """
+ Utilities for training SAEs on synthetic data.
+
+ This module provides helpers for:
+
+ - Generating training data from feature dictionaries
+ - Training SAEs on synthetic data
+ - Evaluating SAEs against known ground truth features
+ - Initializing SAEs to match feature dictionaries
+ """
+
+ from dataclasses import dataclass
+
+ import torch
+ from scipy.optimize import linear_sum_assignment
+
+ from sae_lens.synthetic.activation_generator import ActivationGenerator
+ from sae_lens.synthetic.feature_dictionary import FeatureDictionary
+
+
+ def mean_correlation_coefficient(
+     features_a: torch.Tensor,
+     features_b: torch.Tensor,
+ ) -> float:
+     """
+     Compute the Mean Correlation Coefficient (MCC) between two sets of feature vectors.
+
+     MCC measures how well learned features align with ground truth features by finding
+     an optimal one-to-one matching using the Hungarian algorithm and computing the
+     mean absolute cosine similarity of matched pairs.
+
+     Reference: O'Neill et al., "Compute Optimal Inference and Provable Amortisation
+     Gap in Sparse Autoencoders" (arXiv:2411.13117)
+
+     Args:
+         features_a: Feature vectors of shape [num_features_a, hidden_dim]
+         features_b: Feature vectors of shape [num_features_b, hidden_dim]
+
+     Returns:
+         MCC score in the range [0, 1], where 1 indicates perfect alignment
+     """
+     # Normalize to unit vectors
+     a_norm = features_a / features_a.norm(dim=1, keepdim=True).clamp(min=1e-8)
+     b_norm = features_b / features_b.norm(dim=1, keepdim=True).clamp(min=1e-8)
+
+     # Compute absolute cosine similarity matrix
+     cos_sim = torch.abs(a_norm @ b_norm.T)
+
+     # Convert to a cost matrix for the Hungarian algorithm (which minimizes)
+     cost_matrix = 1 - cos_sim.cpu().numpy()
+
+     # Find optimal matching
+     row_ind, col_ind = linear_sum_assignment(cost_matrix)
+
+     # Compute mean of matched similarities
+     matched_similarities = cos_sim[row_ind, col_ind]
+     return matched_similarities.mean().item()
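Editorial note (not part of the diff): because the Hungarian matching pairs each learned vector with its best ground-truth counterpart and uses absolute cosine similarity, MCC is invariant to permutation and sign flips of the learned features. A minimal sketch, assuming the function is importable as shown; the module's filename is not visible in this diff:

import torch
from sae_lens.synthetic.evals import mean_correlation_coefficient  # assumed module path

gt = torch.randn(16, 64)                         # ground-truth feature directions
learned = -gt[torch.randperm(16)]                # permuted, sign-flipped copy of gt
print(mean_correlation_coefficient(learned, gt))              # ~1.0: perfect alignment
print(mean_correlation_coefficient(torch.randn(16, 64), gt))  # much lower for random vectors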
+
+
+ @dataclass
+ class SyntheticDataEvalResult:
+     """Results from evaluating an SAE on synthetic data."""
+
+     true_l0: float
+     """Average L0 of the true feature activations"""
+
+     sae_l0: float
+     """Average L0 of the SAE's latent activations"""
+
+     dead_latents: int
+     """Number of SAE latents that never fired"""
+
+     shrinkage: float
+     """Average ratio of SAE output norm to input norm (1.0 = no shrinkage)"""
+
+     mcc: float
+     """Mean Correlation Coefficient between SAE decoder and ground truth features"""
+
+
+ @torch.no_grad()
+ def eval_sae_on_synthetic_data(
+     sae: torch.nn.Module,
+     feature_dict: FeatureDictionary,
+     activations_generator: ActivationGenerator,
+     num_samples: int = 100_000,
+ ) -> SyntheticDataEvalResult:
+     """
+     Evaluate an SAE on synthetic data with known ground truth.
+
+     Args:
+         sae: The SAE to evaluate. Must have encode() and decode() methods and a W_dec attribute.
+         feature_dict: The feature dictionary used to generate activations
+         activations_generator: Generator that produces feature activations
+         num_samples: Number of samples to use for evaluation
+
+     Returns:
+         SyntheticDataEvalResult containing evaluation metrics
+     """
+     sae.eval()
+
+     # Generate samples
+     feature_acts = activations_generator.sample(num_samples)
+     true_l0 = (feature_acts > 0).float().sum(dim=-1).mean().item()
+     hidden_acts = feature_dict(feature_acts)
+
+     # Filter out entries where no features fire
+     non_zero_mask = hidden_acts.norm(dim=-1) > 0
+     hidden_acts_filtered = hidden_acts[non_zero_mask]
+
+     # Get SAE reconstructions
+     sae_latents = sae.encode(hidden_acts_filtered)  # type: ignore[attr-defined]
+     sae_output = sae.decode(sae_latents)  # type: ignore[attr-defined]
+
+     sae_l0 = (sae_latents > 0).float().sum(dim=-1).mean().item()
+     dead_latents = int(
+         ((sae_latents == 0).sum(dim=0) == sae_latents.shape[0]).sum().item()
+     )
+     if hidden_acts_filtered.shape[0] == 0:
+         shrinkage = 0.0
+     else:
+         shrinkage = (
+             (
+                 sae_output.norm(dim=-1)
+                 / hidden_acts_filtered.norm(dim=-1).clamp(min=1e-8)
+             )
+             .mean()
+             .item()
+         )
+
+     # Compute MCC between SAE decoder and ground truth features
+     sae_decoder: torch.Tensor = sae.W_dec  # type: ignore[attr-defined]
+     gt_features = feature_dict.feature_vectors
+     mcc = mean_correlation_coefficient(sae_decoder, gt_features)
+
+     return SyntheticDataEvalResult(
+         true_l0=true_l0,
+         sae_l0=sae_l0,
+         dead_latents=dead_latents,
+         shrinkage=shrinkage,
+         mcc=mcc,
+     )
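The evaluator above relies only on a small interface: the SAE must expose encode(), decode(), and a W_dec weight, and the generator must expose sample(n). A minimal, illustrative sketch with stand-in objects; IdentitySAE and BernoulliGenerator are hypothetical helpers written for this note, and the evals import path is an assumption, since the diff does not show the filename:

import torch
from torch import nn
from sae_lens.synthetic.feature_dictionary import FeatureDictionary
from sae_lens.synthetic.evals import eval_sae_on_synthetic_data  # assumed module path

class IdentitySAE(nn.Module):
    """Toy stand-in, not a real SAE: latents are the inputs, decoding is exact."""
    def __init__(self, hidden_dim: int):
        super().__init__()
        self.W_dec = nn.Parameter(torch.eye(hidden_dim))
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return x
    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        return latents @ self.W_dec

class BernoulliGenerator:
    """Toy stand-in for ActivationGenerator: each feature fires independently with p=0.05."""
    def __init__(self, num_features: int):
        self.num_features = num_features
    def sample(self, n: int) -> torch.Tensor:
        return (torch.rand(n, self.num_features) < 0.05).float()

feature_dict = FeatureDictionary(num_features=128, hidden_dim=64)
result = eval_sae_on_synthetic_data(
    sae=IdentitySAE(hidden_dim=64),
    feature_dict=feature_dict,
    activations_generator=BernoulliGenerator(num_features=128),  # stand-in, not the real class
    num_samples=10_000,
)
print(result.sae_l0, result.dead_latents, result.shrinkage, result.mcc)

With an exact-reconstruction stand-in, shrinkage should come out near 1.0; the MCC here compares the stand-in's identity W_dec against the dictionary's feature vectors, so it will be well below 1.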
@@ -0,0 +1,138 @@
+ """
+ Feature dictionary for generating synthetic activations.
+
+ A FeatureDictionary maps feature activations (sparse coefficients) to dense hidden activations
+ by multiplying with a learned or constructed feature embedding matrix.
+ """
+
+ from typing import Callable
+
+ import torch
+ from torch import nn
+ from tqdm import tqdm
+
+ FeatureDictionaryInitializer = Callable[["FeatureDictionary"], None]
+
+
+ def orthogonalize_embeddings(
+     embeddings: torch.Tensor,
+     target_cos_sim: float = 0,
+     num_steps: int = 200,
+     lr: float = 0.01,
+     show_progress: bool = False,
+ ) -> torch.Tensor:
+     """Iteratively adjust `embeddings` with Adam so that pairwise dot products move
+     toward `target_cos_sim` while each vector is softly constrained to unit norm.
+     Returns a new detached tensor; the input is left unchanged."""
+     num_vectors = embeddings.shape[0]
+     # Create a detached copy and normalize, then enable gradients
+     embeddings = embeddings.detach().clone()
+     embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(min=1e-8)
+     embeddings = embeddings.requires_grad_(True)
+
+     optimizer = torch.optim.Adam([embeddings], lr=lr)  # type: ignore[list-item]
+
+     # Create a mask to zero out diagonal elements (avoid in-place operations)
+     off_diagonal_mask = ~torch.eye(
+         num_vectors, dtype=torch.bool, device=embeddings.device
+     )
+
+     pbar = tqdm(
+         range(num_steps), desc="Orthogonalizing vectors", disable=not show_progress
+     )
+     for _ in pbar:
+         optimizer.zero_grad()
+
+         dot_products = embeddings @ embeddings.T
+         diff = dot_products - target_cos_sim
+         # Use masking instead of in-place fill_diagonal_
+         off_diagonal_diff = diff * off_diagonal_mask.float()
+         loss = off_diagonal_diff.pow(2).sum()
+         # Penalize diagonal entries drifting away from 1 to keep vectors near unit norm
+         loss = loss + num_vectors * (dot_products.diag() - 1).pow(2).sum()
+
+         loss.backward()
+         optimizer.step()
+         pbar.set_description(f"loss: {loss.item():.3f}")
+
+     with torch.no_grad():
+         embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
+             min=1e-8
+         )
+     return embeddings.detach().clone()
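An illustrative check of the helper above (a sketch, assuming orthogonalize_embeddings is in scope): it pushes pairwise similarities toward the target while re-normalizing the rows, so the maximum off-diagonal cosine similarity should drop noticeably.

import torch

vecs = torch.randn(50, 32)
vecs = vecs / vecs.norm(dim=1, keepdim=True)

def max_offdiag_cos(x: torch.Tensor) -> float:
    sims = x @ x.T
    sims.fill_diagonal_(0)
    return sims.abs().max().item()

print("before:", max_offdiag_cos(vecs))
ortho = orthogonalize_embeddings(vecs, target_cos_sim=0.0, num_steps=200)
# 50 vectors in 32 dimensions cannot be exactly orthogonal, but the maximum
# off-diagonal similarity should be noticeably smaller after optimization.
print("after:", max_offdiag_cos(ortho))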
+
+
+ def orthogonal_initializer(
+     num_steps: int = 200, lr: float = 0.01, show_progress: bool = False
+ ) -> FeatureDictionaryInitializer:
+     def initializer(feature_dict: "FeatureDictionary") -> None:
+         feature_dict.feature_vectors.data = orthogonalize_embeddings(
+             feature_dict.feature_vectors,
+             num_steps=num_steps,
+             lr=lr,
+             show_progress=show_progress,
+         )
+
+     return initializer
+
+
+ class FeatureDictionary(nn.Module):
+     """
+     A feature dictionary that maps sparse feature activations to dense hidden activations.
+
+     This class creates a set of feature vectors (the "dictionary") and provides methods
+     to generate hidden activations from feature activations via a linear transformation.
+
+     The feature vectors can be configured to have a specific pairwise cosine similarity,
+     which is useful for controlling the difficulty of sparse recovery.
+
+     Attributes:
+         feature_vectors: Parameter of shape [num_features, hidden_dim] containing the
+             feature embedding vectors
+         bias: Parameter of shape [hidden_dim] containing the bias term (zeros if bias=False)
+     """
+
+     feature_vectors: nn.Parameter
+     bias: nn.Parameter
+
+     def __init__(
+         self,
+         num_features: int,
+         hidden_dim: int,
+         bias: bool = False,
+         initializer: FeatureDictionaryInitializer | None = orthogonal_initializer(),
+     ):
+         """
+         Create a new FeatureDictionary.
+
+         Args:
+             num_features: Number of features in the dictionary
+             hidden_dim: Dimensionality of the hidden space
+             bias: Whether to include a bias term in the embedding
+             initializer: Initializer function applied to the new dictionary. Defaults to
+                 orthogonal_initializer(), which orthogonalizes the embeddings; if None,
+                 the embeddings are left as random unit vectors.
+         """
+         super().__init__()
+         self.num_features = num_features
+         self.hidden_dim = hidden_dim
+
+         # Initialize feature vectors as unit vectors
+         embeddings = torch.randn(num_features, hidden_dim)
+         embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
+             min=1e-8
+         )
+         self.feature_vectors = nn.Parameter(embeddings)
+
+         # Initialize bias (zeros if not using bias, but still a parameter for a consistent API)
+         self.bias = nn.Parameter(torch.zeros(hidden_dim), requires_grad=bias)
+
+         if initializer is not None:
+             initializer(self)
+
+     def forward(self, feature_activations: torch.Tensor) -> torch.Tensor:
+         """
+         Convert feature activations to hidden activations.
+
+         Args:
+             feature_activations: Tensor of shape [batch, num_features] containing
+                 sparse feature activation values
+
+         Returns:
+             Tensor of shape [batch, hidden_dim] containing dense hidden activations
+         """
+         return feature_activations @ self.feature_vectors + self.bias
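A short usage sketch for the class above (illustrative; the import path matches the one used by the evals file earlier in this diff):

import torch
from sae_lens.synthetic.feature_dictionary import FeatureDictionary, orthogonal_initializer

# Default construction drives pairwise similarities of the 512 feature vectors toward 0.
feature_dict = FeatureDictionary(num_features=512, hidden_dim=128)

# A sparse batch: roughly 2% of features active per sample, with positive magnitudes.
feature_acts = torch.rand(8, 512) * (torch.rand(8, 512) < 0.02).float()
hidden_acts = feature_dict(feature_acts)
print(hidden_acts.shape)  # torch.Size([8, 128])

# Skip orthogonalization entirely, or tune its optimization settings.
plain_dict = FeatureDictionary(512, 128, initializer=None)
tuned_dict = FeatureDictionary(
    512, 128, initializer=orthogonal_initializer(num_steps=500, lr=0.005)
)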
@@ -0,0 +1,104 @@
+ """
+ Helper functions for generating firing probability distributions.
+ """
+
+ import torch
+
+
+ def zipfian_firing_probabilities(
+     num_features: int,
+     exponent: float = 1.0,
+     max_prob: float = 0.3,
+     min_prob: float = 0.01,
+ ) -> torch.Tensor:
+     """
+     Generate firing probabilities following a Zipfian (power-law) distribution.
+
+     Creates probabilities where a few features fire frequently and most fire rarely,
+     which mirrors the distribution often observed in real neural network features.
+
+     Args:
+         num_features: Number of features to generate probabilities for
+         exponent: Zipf exponent (higher = steeper dropoff). Default 1.0.
+         max_prob: Maximum firing probability (for the most frequent feature)
+         min_prob: Minimum firing probability (for the least frequent feature)
+
+     Returns:
+         Tensor of shape [num_features] with firing probabilities in descending order
+     """
+     if num_features < 1:
+         raise ValueError("num_features must be at least 1")
+     if exponent <= 0:
+         raise ValueError("exponent must be positive")
+     if not 0 < min_prob < max_prob <= 1:
+         raise ValueError("Must have 0 < min_prob < max_prob <= 1")
+
+     ranks = torch.arange(1, num_features + 1, dtype=torch.float32)
+     probs = 1.0 / ranks**exponent
+
+     # Scale to [min_prob, max_prob]
+     if num_features == 1:
+         return torch.tensor([max_prob])
+
+     probs_min, probs_max = probs.min(), probs.max()
+     return min_prob + (max_prob - min_prob) * (probs - probs_min) / (
+         probs_max - probs_min
+     )
+
+
+ def linear_firing_probabilities(
+     num_features: int,
+     max_prob: float = 0.3,
+     min_prob: float = 0.01,
+ ) -> torch.Tensor:
+     """
+     Generate firing probabilities that decay linearly from max to min.
+
+     Args:
+         num_features: Number of features to generate probabilities for
+         max_prob: Firing probability for the first feature
+         min_prob: Firing probability for the last feature
+
+     Returns:
+         Tensor of shape [num_features] with linearly decaying probabilities
+     """
+     if num_features < 1:
+         raise ValueError("num_features must be at least 1")
+     if not 0 < min_prob <= max_prob <= 1:
+         raise ValueError("Must have 0 < min_prob <= max_prob <= 1")
+
+     if num_features == 1:
+         return torch.tensor([max_prob])
+
+     return torch.linspace(max_prob, min_prob, num_features)
+
+
+ def random_firing_probabilities(
+     num_features: int,
+     max_prob: float = 0.5,
+     min_prob: float = 0.01,
+     seed: int | None = None,
+ ) -> torch.Tensor:
+     """
+     Generate random firing probabilities uniformly sampled from a range.
+
+     Args:
+         num_features: Number of features to generate probabilities for
+         max_prob: Maximum firing probability
+         min_prob: Minimum firing probability
+         seed: Optional random seed for reproducibility
+
+     Returns:
+         Tensor of shape [num_features] with random firing probabilities
+     """
+     if num_features < 1:
+         raise ValueError("num_features must be at least 1")
+     if not 0 < min_prob < max_prob <= 1:
+         raise ValueError("Must have 0 < min_prob < max_prob <= 1")
+
+     generator = torch.Generator()
+     if seed is not None:
+         generator.manual_seed(seed)
+
+     probs = torch.rand(num_features, generator=generator, dtype=torch.float32)
+     return min_prob + (max_prob - min_prob) * probs
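The three helpers above all return a [num_features] tensor of per-feature firing probabilities; they differ only in how probability mass is spread across features. A brief comparison (illustrative, assuming the functions above are in scope):

zipf = zipfian_firing_probabilities(1000, exponent=1.0, max_prob=0.3, min_prob=0.01)
lin = linear_firing_probabilities(1000, max_prob=0.3, min_prob=0.01)
rand = random_firing_probabilities(1000, max_prob=0.5, min_prob=0.01, seed=0)

print(zipf[:4])   # steep drop-off: only the first few features fire often
print(lin[:4])    # even decay from 0.3 down toward 0.01
print(rand[:4])   # unordered uniform draws in [0.01, 0.5]
# Summing the probabilities gives the expected L0 if features fired independently.
print(zipf.sum().item(), lin.sum().item(), rand.sum().item())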