sae-lens 6.26.0__py3-none-any.whl → 6.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,215 @@
+ """
+ Functions for generating synthetic feature activations.
+ """
+
+ from collections.abc import Callable, Sequence
+
+ import torch
+ from scipy.stats import norm
+ from torch import nn
+ from torch.distributions import MultivariateNormal
+
+ from sae_lens.util import str_to_dtype
+
+ ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
+ ActivationsModifierInput = ActivationsModifier | Sequence[ActivationsModifier] | None
+
+
+ class ActivationGenerator(nn.Module):
+     """
+     Generator for synthetic feature activations.
+
+     This module provides a generator for synthetic feature activations with controlled properties.
+     """
+
+     num_features: int
+     firing_probabilities: torch.Tensor
+     std_firing_magnitudes: torch.Tensor
+     mean_firing_magnitudes: torch.Tensor
+     modify_activations: ActivationsModifier | None
+     correlation_matrix: torch.Tensor | None
+     correlation_thresholds: torch.Tensor | None
+
+     def __init__(
+         self,
+         num_features: int,
+         firing_probabilities: torch.Tensor | float,
+         std_firing_magnitudes: torch.Tensor | float = 0.0,
+         mean_firing_magnitudes: torch.Tensor | float = 1.0,
+         modify_activations: ActivationsModifierInput = None,
+         correlation_matrix: torch.Tensor | None = None,
+         device: torch.device | str = "cpu",
+         dtype: torch.dtype | str = "float32",
+     ):
+         super().__init__()
+         self.num_features = num_features
+         self.firing_probabilities = _to_tensor(
+             firing_probabilities, num_features, device, dtype
+         )
+         self.std_firing_magnitudes = _to_tensor(
+             std_firing_magnitudes, num_features, device, dtype
+         )
+         self.mean_firing_magnitudes = _to_tensor(
+             mean_firing_magnitudes, num_features, device, dtype
+         )
+         self.modify_activations = _normalize_modifiers(modify_activations)
+         self.correlation_thresholds = None
+         if correlation_matrix is not None:
+             _validate_correlation_matrix(correlation_matrix, num_features)
+             self.correlation_thresholds = torch.tensor(
+                 [norm.ppf(1 - p.item()) for p in self.firing_probabilities],
+                 device=device,
+                 dtype=self.firing_probabilities.dtype,
+             )
+         self.correlation_matrix = correlation_matrix
+
+     def sample(self, batch_size: int) -> torch.Tensor:
+         """
+         Generate a batch of feature activations with controlled properties.
+
+         This is the main function for generating synthetic training data for SAEs.
+         Features fire independently according to their firing probabilities unless
+         a correlation matrix is provided.
+
+         Args:
+             batch_size: Number of samples to generate
+
+         Returns:
+             Tensor of shape [batch_size, num_features] with non-negative activations
+         """
+         # All tensors (firing_probabilities, std_firing_magnitudes, mean_firing_magnitudes)
+         # are on the same device from __init__ via _to_tensor()
+         device = self.firing_probabilities.device
+
+         if self.correlation_matrix is not None:
+             assert self.correlation_thresholds is not None
+             firing_features = _generate_correlated_features(
+                 batch_size,
+                 self.correlation_matrix,
+                 self.correlation_thresholds,
+                 device,
+             )
+         else:
+             firing_features = torch.bernoulli(
+                 self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
+             )
+
+         firing_magnitude_delta = torch.normal(
+             torch.zeros_like(self.firing_probabilities)
+             .unsqueeze(0)
+             .expand(batch_size, -1),
+             self.std_firing_magnitudes.unsqueeze(0).expand(batch_size, -1),
+         )
+         firing_magnitude_delta[firing_features == 0] = 0
+         feature_activations = (
+             firing_features * self.mean_firing_magnitudes + firing_magnitude_delta
+         ).relu()
+
+         if self.modify_activations is not None:
+             feature_activations = self.modify_activations(feature_activations).relu()
+         return feature_activations
+
+     def forward(self, batch_size: int) -> torch.Tensor:
+         return self.sample(batch_size)
+
+
+ def _generate_correlated_features(
+     batch_size: int,
+     correlation_matrix: torch.Tensor,
+     thresholds: torch.Tensor,
+     device: torch.device,
+ ) -> torch.Tensor:
+     """
+     Generate correlated binary features using multivariate Gaussian sampling.
+
+     Uses the Gaussian copula approach: sample from a multivariate normal
+     distribution, then threshold to get binary features.
+
+     Args:
+         batch_size: Number of samples to generate
+         correlation_matrix: Correlation matrix between features
+         thresholds: Pre-computed thresholds for each feature (from inverse normal CDF)
+         device: Device to generate samples on
+
+     Returns:
+         Binary feature matrix of shape [batch_size, num_features]
+     """
+     num_features = correlation_matrix.shape[0]
+
+     mvn = MultivariateNormal(
+         loc=torch.zeros(num_features, device=device, dtype=thresholds.dtype),
+         covariance_matrix=correlation_matrix.to(device=device, dtype=thresholds.dtype),
+     )
+
+     gaussian_samples = mvn.sample((batch_size,))
+     return (gaussian_samples > thresholds.unsqueeze(0)).float()
+
+
+ def _to_tensor(
+     value: torch.Tensor | float,
+     num_features: int,
+     device: torch.device | str,
+     dtype: torch.dtype | str,
+ ) -> torch.Tensor:
+     dtype = str_to_dtype(dtype)
+     device = torch.device(device)
+     if not isinstance(value, torch.Tensor):
+         value = value * torch.ones(num_features, device=device, dtype=dtype)
+     if value.shape != (num_features,):
+         raise ValueError(
+             f"Value must be a tensor of shape ({num_features},) or a float. Got {value.shape}"
+         )
+     return value.to(device, dtype)
+
+
+ def _normalize_modifiers(
+     modify_activations: ActivationsModifierInput,
+ ) -> ActivationsModifier | None:
+     """Convert modifier input to a single modifier or None."""
+     if modify_activations is None:
+         return None
+     if callable(modify_activations):
+         return modify_activations
+     # It's a sequence of modifiers - chain them
+     modifiers = list(modify_activations)
+     if len(modifiers) == 0:
+         return None
+     if len(modifiers) == 1:
+         return modifiers[0]
+
+     def chained(activations: torch.Tensor) -> torch.Tensor:
+         result = activations
+         for modifier in modifiers:
+             result = modifier(result)
+         return result
+
+     return chained
+
+
+ def _validate_correlation_matrix(
+     correlation_matrix: torch.Tensor, num_features: int
+ ) -> None:
+     """Validate that a correlation matrix has correct properties.
+
+     Args:
+         correlation_matrix: The matrix to validate
+         num_features: Expected number of features (matrix should be [num_features, num_features])
+
+     Raises:
+         ValueError: If the matrix has incorrect shape, non-unit diagonal, or is not positive definite
+     """
+     expected_shape = (num_features, num_features)
+     if correlation_matrix.shape != expected_shape:
+         raise ValueError(
+             f"Correlation matrix must have shape {expected_shape}, "
+             f"got {tuple(correlation_matrix.shape)}"
+         )
+
+     diagonal = torch.diag(correlation_matrix)
+     if not torch.allclose(diagonal, torch.ones_like(diagonal)):
+         raise ValueError("Correlation matrix diagonal must be all 1s")
+
+     try:
+         torch.linalg.cholesky(correlation_matrix)
+     except RuntimeError as e:
+         raise ValueError("Correlation matrix must be positive definite") from e
@@ -0,0 +1,170 @@
+ import random
+
+ import torch
+
+
+ def create_correlation_matrix_from_correlations(
+     num_features: int,
+     correlations: dict[tuple[int, int], float] | None = None,
+     default_correlation: float = 0.0,
+ ) -> torch.Tensor:
+     """
+     Create a correlation matrix with specified pairwise correlations.
+
+     Args:
+         num_features: Number of features
+         correlations: Dict mapping (i, j) pairs to correlation values.
+             Pairs should have i < j.
+         default_correlation: Default correlation for unspecified pairs
+
+     Returns:
+         Correlation matrix of shape [num_features, num_features]
+     """
+     matrix = torch.eye(num_features) + default_correlation * (
+         1 - torch.eye(num_features)
+     )
+
+     if correlations is not None:
+         for (i, j), corr in correlations.items():
+             matrix[i, j] = corr
+             matrix[j, i] = corr
+
+     # Ensure matrix is symmetric (numerical precision)
+     matrix = (matrix + matrix.T) / 2
+
+     # Check positive definiteness and fix if necessary
+     # Use eigvalsh for symmetric matrices (returns real eigenvalues)
+     eigenvals = torch.linalg.eigvalsh(matrix)
+     if torch.any(eigenvals < -1e-6):
+         matrix = _fix_correlation_matrix(matrix)
+
+     return matrix
+
+
+ def _fix_correlation_matrix(
+     matrix: torch.Tensor, min_eigenval: float = 1e-6
+ ) -> torch.Tensor:
+     """Fix a correlation matrix to be positive semi-definite."""
+     eigenvals, eigenvecs = torch.linalg.eigh(matrix)
+     eigenvals = torch.clamp(eigenvals, min=min_eigenval)
+     fixed_matrix = eigenvecs @ torch.diag(eigenvals) @ eigenvecs.T
+
+     diag_vals = torch.diag(fixed_matrix)
+     fixed_matrix = fixed_matrix / torch.sqrt(
+         diag_vals.unsqueeze(0) * diag_vals.unsqueeze(1)
+     )
+     fixed_matrix.fill_diagonal_(1.0)
+
+     return fixed_matrix
+
+
+ def generate_random_correlations(
+     num_features: int,
+     positive_ratio: float = 0.5,
+     uncorrelated_ratio: float = 0.3,
+     min_correlation_strength: float = 0.1,
+     max_correlation_strength: float = 0.8,
+     seed: int | None = None,
+ ) -> dict[tuple[int, int], float]:
+     """
+     Generate random correlations between features with specified constraints.
+
+     Args:
+         num_features: Number of features
+         positive_ratio: Fraction of correlations that should be positive (0.0 to 1.0)
+         uncorrelated_ratio: Fraction of feature pairs that should remain uncorrelated (0.0 to 1.0)
+         min_correlation_strength: Minimum absolute correlation strength
+         max_correlation_strength: Maximum absolute correlation strength
+         seed: Random seed for reproducibility
+
+     Returns:
+         Dictionary mapping (i, j) pairs to correlation values
+     """
+     # Use local random number generator to avoid side effects on global state
+     rng = random.Random(seed)
+
+     # Validate inputs
+     if not 0.0 <= positive_ratio <= 1.0:
+         raise ValueError("positive_ratio must be between 0.0 and 1.0")
+     if not 0.0 <= uncorrelated_ratio <= 1.0:
+         raise ValueError("uncorrelated_ratio must be between 0.0 and 1.0")
+     if min_correlation_strength < 0:
+         raise ValueError("min_correlation_strength must be non-negative")
+     if max_correlation_strength > 1.0:
+         raise ValueError("max_correlation_strength must be <= 1.0")
+     if min_correlation_strength > max_correlation_strength:
+         raise ValueError("min_correlation_strength must be <= max_correlation_strength")
+
+     # Generate all possible feature pairs (i, j) where i < j
+     all_pairs = [
+         (i, j) for i in range(num_features) for j in range(i + 1, num_features)
+     ]
+     total_pairs = len(all_pairs)
+
+     if total_pairs == 0:
+         return {}
+
+     # Determine how many pairs to correlate vs leave uncorrelated
+     num_uncorrelated = int(total_pairs * uncorrelated_ratio)
+     num_correlated = total_pairs - num_uncorrelated
+
+     # Randomly select which pairs to correlate
+     correlated_pairs = rng.sample(all_pairs, num_correlated)
+
+     # For correlated pairs, determine positive vs negative
+     num_positive = int(num_correlated * positive_ratio)
+     num_negative = num_correlated - num_positive
+
+     # Assign signs
+     signs = [1] * num_positive + [-1] * num_negative
+     rng.shuffle(signs)
+
+     # Generate correlation strengths
+     correlations = {}
+     for pair, sign in zip(correlated_pairs, signs):
+         # Sample correlation strength uniformly from range
+         strength = rng.uniform(min_correlation_strength, max_correlation_strength)
+         correlations[pair] = sign * strength
+
+     return correlations
+
+
+ def generate_random_correlation_matrix(
+     num_features: int,
+     positive_ratio: float = 0.5,
+     uncorrelated_ratio: float = 0.3,
+     min_correlation_strength: float = 0.1,
+     max_correlation_strength: float = 0.8,
+     seed: int | None = None,
+ ) -> torch.Tensor:
+     """
+     Generate a random correlation matrix with specified constraints.
+
+     This is a convenience function that combines generate_random_correlations()
+     and create_correlation_matrix_from_correlations() into a single call.
+
+     Args:
+         num_features: Number of features
+         positive_ratio: Fraction of correlations that should be positive (0.0 to 1.0)
+         uncorrelated_ratio: Fraction of feature pairs that should remain uncorrelated (0.0 to 1.0)
+         min_correlation_strength: Minimum absolute correlation strength
+         max_correlation_strength: Maximum absolute correlation strength
+         seed: Random seed for reproducibility
+
+     Returns:
+         Random correlation matrix of shape [num_features, num_features]
+     """
+     # Generate random correlations
+     correlations = generate_random_correlations(
+         num_features=num_features,
+         positive_ratio=positive_ratio,
+         uncorrelated_ratio=uncorrelated_ratio,
+         min_correlation_strength=min_correlation_strength,
+         max_correlation_strength=max_correlation_strength,
+         seed=seed,
+     )
+
+     # Create and return correlation matrix
+     return create_correlation_matrix_from_correlations(
+         num_features=num_features, correlations=correlations
+     )
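
These helpers build a valid correlation matrix from pairwise specifications, symmetrizing it and projecting to the nearest positive semi-definite matrix when necessary; the result can be passed as correlation_matrix to ActivationGenerator. A minimal sketch, assuming the functions above are in scope (their module path is not shown in this diff, so the import is left schematic):

```python
# Minimal sketch; the helpers' module path is not shown in this diff, so the
# import is left schematic.
# from sae_lens.synthetic... import (
#     create_correlation_matrix_from_correlations,
#     generate_random_correlation_matrix,
# )
import torch

# Explicit pairwise correlations; keys use i < j.
corr = create_correlation_matrix_from_correlations(
    num_features=4,
    correlations={(0, 1): 0.6, (2, 3): -0.4},
)
assert torch.allclose(torch.diag(corr), torch.ones(4))  # unit diagonal preserved

# Or draw a random, valid correlation matrix with a fixed seed.
rand_corr = generate_random_correlation_matrix(num_features=10, seed=0)
# rand_corr can be passed as correlation_matrix= to ActivationGenerator.
```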
@@ -0,0 +1,141 @@
+ """
+ Utilities for training SAEs on synthetic data.
+
+ This module provides helpers for:
+
+ - Generating training data from feature dictionaries
+ - Training SAEs on synthetic data
+ - Evaluating SAEs against known ground truth features
+ - Initializing SAEs to match feature dictionaries
+ """
+
+ from dataclasses import dataclass
+
+ import torch
+ from scipy.optimize import linear_sum_assignment
+
+ from sae_lens.synthetic.activation_generator import ActivationGenerator
+ from sae_lens.synthetic.feature_dictionary import FeatureDictionary
+
+
+ def mean_correlation_coefficient(
+     features_a: torch.Tensor,
+     features_b: torch.Tensor,
+ ) -> float:
+     """
+     Compute Mean Correlation Coefficient (MCC) between two sets of feature vectors.
+
+     MCC measures how well learned features align with ground truth features by finding
+     an optimal one-to-one matching using the Hungarian algorithm and computing the
+     mean absolute cosine similarity of matched pairs.
+
+     Reference: O'Neill et al. "Compute Optimal Inference and Provable Amortisation
+     Gap in Sparse Autoencoders" (arXiv:2411.13117)
+
+     Args:
+         features_a: Feature vectors of shape [num_features_a, hidden_dim]
+         features_b: Feature vectors of shape [num_features_b, hidden_dim]
+
+     Returns:
+         MCC score in range [0, 1], where 1 indicates perfect alignment
+     """
+     # Normalize to unit vectors
+     a_norm = features_a / features_a.norm(dim=1, keepdim=True).clamp(min=1e-8)
+     b_norm = features_b / features_b.norm(dim=1, keepdim=True).clamp(min=1e-8)
+
+     # Compute absolute cosine similarity matrix
+     cos_sim = torch.abs(a_norm @ b_norm.T)
+
+     # Convert to cost matrix for Hungarian algorithm (which minimizes)
+     cost_matrix = 1 - cos_sim.cpu().numpy()
+
+     # Find optimal matching
+     row_ind, col_ind = linear_sum_assignment(cost_matrix)
+
+     # Compute mean of matched similarities
+     matched_similarities = cos_sim[row_ind, col_ind]
+     return matched_similarities.mean().item()
+
+
+ @dataclass
+ class SyntheticDataEvalResult:
+     """Results from evaluating an SAE on synthetic data."""
+
+     true_l0: float
+     """Average L0 of the true feature activations"""
+
+     sae_l0: float
+     """Average L0 of the SAE's latent activations"""
+
+     dead_latents: int
+     """Number of SAE latents that never fired"""
+
+     shrinkage: float
+     """Average ratio of SAE output norm to input norm (1.0 = no shrinkage)"""
+
+     mcc: float
+     """Mean Correlation Coefficient between SAE decoder and ground truth features"""
+
+
+ @torch.no_grad()
+ def eval_sae_on_synthetic_data(
+     sae: torch.nn.Module,
+     feature_dict: FeatureDictionary,
+     activations_generator: ActivationGenerator,
+     num_samples: int = 100_000,
+ ) -> SyntheticDataEvalResult:
+     """
+     Evaluate an SAE on synthetic data with known ground truth.
+
+     Args:
+         sae: The SAE to evaluate. Must have encode() and decode() methods.
+         feature_dict: The feature dictionary used to generate activations
+         activations_generator: Generator that produces feature activations
+         num_samples: Number of samples to use for evaluation
+
+     Returns:
+         SyntheticDataEvalResult containing evaluation metrics
+     """
+     sae.eval()
+
+     # Generate samples
+     feature_acts = activations_generator.sample(num_samples)
+     true_l0 = (feature_acts > 0).float().sum(dim=-1).mean().item()
+     hidden_acts = feature_dict(feature_acts)
+
+     # Filter out entries where no features fire
+     non_zero_mask = hidden_acts.norm(dim=-1) > 0
+     hidden_acts_filtered = hidden_acts[non_zero_mask]
+
+     # Get SAE reconstructions
+     sae_latents = sae.encode(hidden_acts_filtered)  # type: ignore[attr-defined]
+     sae_output = sae.decode(sae_latents)  # type: ignore[attr-defined]
+
+     sae_l0 = (sae_latents > 0).float().sum(dim=-1).mean().item()
+     dead_latents = int(
+         ((sae_latents == 0).sum(dim=0) == sae_latents.shape[0]).sum().item()
+     )
+     if hidden_acts_filtered.shape[0] == 0:
+         shrinkage = 0.0
+     else:
+         shrinkage = (
+             (
+                 sae_output.norm(dim=-1)
+                 / hidden_acts_filtered.norm(dim=-1).clamp(min=1e-8)
+             )
+             .mean()
+             .item()
+         )
+
+     # Compute MCC between SAE decoder and ground truth features
+     sae_decoder: torch.Tensor = sae.W_dec  # type: ignore[attr-defined]
+     gt_features = feature_dict.feature_vectors
+     mcc = mean_correlation_coefficient(sae_decoder, gt_features)
+
+     return SyntheticDataEvalResult(
+         true_l0=true_l0,
+         sae_l0=sae_l0,
+         dead_latents=dead_latents,
+         shrinkage=shrinkage,
+         mcc=mcc,
+     )
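
Because mean_correlation_coefficient matches rows of one matrix to rows of the other with the Hungarian algorithm and averages absolute cosine similarities, a permuted and sign-flipped copy of a dictionary scores 1.0. A minimal sketch of that property, assuming the functions defined in the hunk above are in scope, followed by a hedged sketch of calling the evaluator with any SAE that exposes encode(), decode(), and W_dec:

```python
# Minimal sketch: MCC is invariant to permutation and sign of the feature vectors.
import torch

true_features = torch.randn(16, 32)
shuffled = -true_features[torch.randperm(16)]            # permute rows, flip signs
print(mean_correlation_coefficient(shuffled, true_features))  # ~1.0

# With a trained SAE (any module exposing encode(), decode(), and W_dec):
# result = eval_sae_on_synthetic_data(sae, feature_dict, generator, num_samples=50_000)
# print(result.mcc, result.sae_l0, result.dead_latents, result.shrinkage)
```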
@@ -0,0 +1,138 @@
+ """
+ Feature dictionary for generating synthetic activations.
+
+ A FeatureDictionary maps feature activations (sparse coefficients) to dense hidden activations
+ by multiplying with a learned or constructed feature embedding matrix.
+ """
+
+ from typing import Callable
+
+ import torch
+ from torch import nn
+ from tqdm import tqdm
+
+ FeatureDictionaryInitializer = Callable[["FeatureDictionary"], None]
+
+
+ def orthogonalize_embeddings(
+     embeddings: torch.Tensor,
+     target_cos_sim: float = 0,
+     num_steps: int = 200,
+     lr: float = 0.01,
+     show_progress: bool = False,
+ ) -> torch.Tensor:
+     num_vectors = embeddings.shape[0]
+     # Create a detached copy and normalize, then enable gradients
+     embeddings = embeddings.detach().clone()
+     embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(min=1e-8)
+     embeddings = embeddings.requires_grad_(True)
+
+     optimizer = torch.optim.Adam([embeddings], lr=lr)  # type: ignore[list-item]
+
+     # Create a mask to zero out diagonal elements (avoid in-place operations)
+     off_diagonal_mask = ~torch.eye(
+         num_vectors, dtype=torch.bool, device=embeddings.device
+     )
+
+     pbar = tqdm(
+         range(num_steps), desc="Orthogonalizing vectors", disable=not show_progress
+     )
+     for _ in pbar:
+         optimizer.zero_grad()
+
+         dot_products = embeddings @ embeddings.T
+         diff = dot_products - target_cos_sim
+         # Use masking instead of in-place fill_diagonal_
+         off_diagonal_diff = diff * off_diagonal_mask.float()
+         loss = off_diagonal_diff.pow(2).sum()
+         loss = loss + num_vectors * (dot_products.diag() - 1).pow(2).sum()
+
+         loss.backward()
+         optimizer.step()
+         pbar.set_description(f"loss: {loss.item():.3f}")
+
+     with torch.no_grad():
+         embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
+             min=1e-8
+         )
+     return embeddings.detach().clone()
+
+
+ def orthogonal_initializer(
+     num_steps: int = 200, lr: float = 0.01, show_progress: bool = False
+ ) -> FeatureDictionaryInitializer:
+     def initializer(feature_dict: "FeatureDictionary") -> None:
+         feature_dict.feature_vectors.data = orthogonalize_embeddings(
+             feature_dict.feature_vectors,
+             num_steps=num_steps,
+             lr=lr,
+             show_progress=show_progress,
+         )
+
+     return initializer
+
+
+ class FeatureDictionary(nn.Module):
+     """
+     A feature dictionary that maps sparse feature activations to dense hidden activations.
+
+     This class creates a set of feature vectors (the "dictionary") and provides methods
+     to generate hidden activations from feature activations via a linear transformation.
+
+     The feature vectors can be configured to have a specific pairwise cosine similarity,
+     which is useful for controlling the difficulty of sparse recovery.
+
+     Attributes:
+         feature_vectors: Parameter of shape [num_features, hidden_dim] containing the
+             feature embedding vectors
+         bias: Parameter of shape [hidden_dim] containing the bias term (zeros if bias=False)
+     """
+
+     feature_vectors: nn.Parameter
+     bias: nn.Parameter
+
+     def __init__(
+         self,
+         num_features: int,
+         hidden_dim: int,
+         bias: bool = False,
+         initializer: FeatureDictionaryInitializer | None = orthogonal_initializer(),
+     ):
+         """
+         Create a new FeatureDictionary.
+
+         Args:
+             num_features: Number of features in the dictionary
+             hidden_dim: Dimensionality of the hidden space
+             bias: Whether to include a bias term in the embedding
+             initializer: Initializer function to use. If None, the embeddings are initialized to random unit vectors. By default will orthogonalize embeddings.
+         """
+         super().__init__()
+         self.num_features = num_features
+         self.hidden_dim = hidden_dim
+
+         # Initialize feature vectors as unit vectors
+         embeddings = torch.randn(num_features, hidden_dim)
+         embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
+             min=1e-8
+         )
+         self.feature_vectors = nn.Parameter(embeddings)
+
+         # Initialize bias (zeros if not using bias, but still a parameter for consistent API)
+         self.bias = nn.Parameter(torch.zeros(hidden_dim), requires_grad=bias)
+
+         if initializer is not None:
+             initializer(self)
+
+     def forward(self, feature_activations: torch.Tensor) -> torch.Tensor:
+         """
+         Convert feature activations to hidden activations.
+
+         Args:
+             feature_activations: Tensor of shape [batch, num_features] containing
+                 sparse feature activation values
+
+         Returns:
+             Tensor of shape [batch, hidden_dim] containing dense hidden activations
+         """
+         return feature_activations @ self.feature_vectors + self.bias
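
FeatureDictionary turns sparse feature activations into dense hidden activations via feature_activations @ feature_vectors + bias, with the feature vectors orthogonalized by default. Combined with ActivationGenerator, this yields synthetic SAE training data with known ground-truth features. A minimal end-to-end sketch, assuming the sae_lens.synthetic import paths referenced elsewhere in this diff:

```python
# Minimal sketch; import paths are inferred from the training-utilities module above.
from sae_lens.synthetic.activation_generator import ActivationGenerator
from sae_lens.synthetic.feature_dictionary import FeatureDictionary

num_features, hidden_dim = 100, 64
feature_dict = FeatureDictionary(num_features, hidden_dim)   # orthogonalized by default
generator = ActivationGenerator(num_features, firing_probabilities=0.02)

feature_acts = generator.sample(8192)      # [8192, 100] sparse, non-negative
hidden_acts = feature_dict(feature_acts)   # [8192, 64] dense "model activations"
# hidden_acts serve as SAE training inputs with known ground-truth features.
```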