sae-lens 6.28.1-py3-none-any.whl → 6.29.1-py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- sae_lens/__init__.py +1 -1
- sae_lens/pretrained_saes.yaml +1 -1
- sae_lens/synthetic/__init__.py +6 -0
- sae_lens/synthetic/activation_generator.py +198 -25
- sae_lens/synthetic/correlation.py +217 -36
- sae_lens/synthetic/feature_dictionary.py +64 -17
- sae_lens/synthetic/hierarchy.py +657 -84
- sae_lens/synthetic/training.py +16 -3
- {sae_lens-6.28.1.dist-info → sae_lens-6.29.1.dist-info}/METADATA +11 -1
- {sae_lens-6.28.1.dist-info → sae_lens-6.29.1.dist-info}/RECORD +12 -12
- {sae_lens-6.28.1.dist-info → sae_lens-6.29.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.28.1.dist-info → sae_lens-6.29.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/synthetic/feature_dictionary.py

@@ -9,18 +9,35 @@ from typing import Callable
 
 import torch
 from torch import nn
-from tqdm import tqdm
+from tqdm.auto import tqdm
 
 FeatureDictionaryInitializer = Callable[["FeatureDictionary"], None]
 
 
 def orthogonalize_embeddings(
     embeddings: torch.Tensor,
-    target_cos_sim: float = 0,
     num_steps: int = 200,
     lr: float = 0.01,
     show_progress: bool = False,
+    chunk_size: int = 1024,
 ) -> torch.Tensor:
+    """
+    Orthogonalize embeddings using gradient descent with chunked computation.
+
+    Uses chunked computation to avoid O(n²) memory usage when computing pairwise
+    dot products. Memory usage is O(chunk_size × n) instead of O(n²).
+
+    Args:
+        embeddings: Tensor of shape [num_vectors, hidden_dim]
+        num_steps: Number of optimization steps
+        lr: Learning rate for Adam optimizer
+        show_progress: Whether to show progress bar
+        chunk_size: Number of vectors to process at once. Smaller values use less
+            memory but may be slower.
+
+    Returns:
+        Orthogonalized embeddings of the same shape, normalized to unit length.
+    """
     num_vectors = embeddings.shape[0]
     # Create a detached copy and normalize, then enable gradients
     embeddings = embeddings.detach().clone()
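The new chunk_size argument controls the memory/speed trade-off described in the docstring above. A minimal usage sketch, assuming orthogonalize_embeddings can be imported directly from sae_lens.synthetic.feature_dictionary and with illustrative tensor sizes:

```python
import torch

from sae_lens.synthetic.feature_dictionary import orthogonalize_embeddings

vectors = torch.randn(20_000, 768)

# With chunk_size=512, each step materializes at most a 512 x 20_000 block of
# pairwise dot products instead of the full 20_000 x 20_000 Gram matrix.
ortho = orthogonalize_embeddings(
    vectors,
    num_steps=200,
    lr=0.01,
    show_progress=True,
    chunk_size=512,
)
print(ortho.shape)  # torch.Size([20000, 768]), rows normalized to unit length
```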
@@ -29,24 +46,37 @@ def orthogonalize_embeddings(
 
     optimizer = torch.optim.Adam([embeddings], lr=lr)  # type: ignore[list-item]
 
-    # Create a mask to zero out diagonal elements (avoid in-place operations)
-    off_diagonal_mask = ~torch.eye(
-        num_vectors, dtype=torch.bool, device=embeddings.device
-    )
-
     pbar = tqdm(
         range(num_steps), desc="Orthogonalizing vectors", disable=not show_progress
     )
     for _ in pbar:
         optimizer.zero_grad()
 
-
-
-
-
-
-
+        off_diag_loss = torch.tensor(0.0, device=embeddings.device)
+        diag_loss = torch.tensor(0.0, device=embeddings.device)
+
+        for i in range(0, num_vectors, chunk_size):
+            end_i = min(i + chunk_size, num_vectors)
+            chunk = embeddings[i:end_i]
+            chunk_dots = chunk @ embeddings.T  # [chunk_size, num_vectors]
 
+            # Create mask to zero out diagonal elements for this chunk
+            # Diagonal of full matrix: position (i+k, i+k) → in chunk_dots: (k, i+k)
+            chunk_len = end_i - i
+            row_indices = torch.arange(chunk_len, device=embeddings.device)
+            col_indices = i + row_indices  # column indices in full matrix
+
+            # Boolean mask: True for off-diagonal elements we want to include
+            off_diag_mask = torch.ones_like(chunk_dots, dtype=torch.bool)
+            off_diag_mask[row_indices, col_indices] = False
+
+            off_diag_loss = off_diag_loss + chunk_dots[off_diag_mask].pow(2).sum()
+
+            # Diagonal loss: keep self-dot-products at 1
+            diag_vals = chunk_dots[row_indices, col_indices]
+            diag_loss = diag_loss + (diag_vals - 1).pow(2).sum()
+
+        loss = off_diag_loss + num_vectors * diag_loss
         loss.backward()
         optimizer.step()
         pbar.set_description(f"loss: {loss.item():.3f}")
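As a sanity check on the chunked accumulation above, the standalone sketch below (not part of sae-lens; all names are local to the example) reproduces the loop on a small tensor and confirms it equals the same loss computed from the full Gram matrix in one shot:

```python
import torch

torch.manual_seed(0)
n, d, chunk_size = 10, 4, 3
emb = torch.randn(n, d)

# Full-matrix reference: squared off-diagonal dot products plus the weighted
# penalty keeping self-dot-products at 1.
gram = emb @ emb.T
off_diag = ~torch.eye(n, dtype=torch.bool)
full_loss = gram[off_diag].pow(2).sum() + n * (gram.diagonal() - 1).pow(2).sum()

# Chunked accumulation, mirroring the loop in the diff.
off_diag_loss = torch.tensor(0.0)
diag_loss = torch.tensor(0.0)
for i in range(0, n, chunk_size):
    end_i = min(i + chunk_size, n)
    chunk_dots = emb[i:end_i] @ emb.T
    rows = torch.arange(end_i - i)
    cols = i + rows
    mask = torch.ones_like(chunk_dots, dtype=torch.bool)
    mask[rows, cols] = False
    off_diag_loss = off_diag_loss + chunk_dots[mask].pow(2).sum()
    diag_loss = diag_loss + (chunk_dots[rows, cols] - 1).pow(2).sum()
chunked_loss = off_diag_loss + n * diag_loss

assert torch.allclose(full_loss, chunked_loss)
```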
@@ -59,7 +89,10 @@ def orthogonalize_embeddings(
 
 
 def orthogonal_initializer(
-    num_steps: int = 200,
+    num_steps: int = 200,
+    lr: float = 0.01,
+    show_progress: bool = False,
+    chunk_size: int = 1024,
 ) -> FeatureDictionaryInitializer:
     def initializer(feature_dict: "FeatureDictionary") -> None:
         feature_dict.feature_vectors.data = orthogonalize_embeddings(
@@ -67,6 +100,7 @@ def orthogonal_initializer(
             num_steps=num_steps,
             lr=lr,
             show_progress=show_progress,
+            chunk_size=chunk_size,
         )
 
     return initializer
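orthogonal_initializer now forwards lr, show_progress, and chunk_size to orthogonalize_embeddings. A short sketch of the initializer pattern, assuming these names are importable from sae_lens.synthetic.feature_dictionary; the custom initializer and its 0.5 scaling are purely illustrative, not part of the package:

```python
import torch

from sae_lens.synthetic.feature_dictionary import (
    FeatureDictionary,
    FeatureDictionaryInitializer,
    orthogonal_initializer,
)

# Built-in initializer, configured via the keyword arguments added in this release.
ortho_init: FeatureDictionaryInitializer = orthogonal_initializer(
    num_steps=100, lr=0.01, show_progress=False, chunk_size=256
)


def scaled_random_initializer(feature_dict: FeatureDictionary) -> None:
    """Hypothetical custom initializer: random unit rows scaled by 0.5."""
    vecs = torch.randn_like(feature_dict.feature_vectors)
    vecs = vecs / vecs.norm(dim=1, keepdim=True).clamp(min=1e-8)
    feature_dict.feature_vectors.data = 0.5 * vecs
```

Any Callable[["FeatureDictionary"], None] can be passed as the initializer argument shown in the constructor hunks below.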
@@ -97,6 +131,7 @@ class FeatureDictionary(nn.Module):
         hidden_dim: int,
         bias: bool = False,
         initializer: FeatureDictionaryInitializer | None = orthogonal_initializer(),
+        device: str | torch.device = "cpu",
     ):
         """
         Create a new FeatureDictionary.
@@ -106,20 +141,23 @@ class FeatureDictionary(nn.Module):
             hidden_dim: Dimensionality of the hidden space
             bias: Whether to include a bias term in the embedding
             initializer: Initializer function to use. If None, the embeddings are initialized to random unit vectors. By default will orthogonalize embeddings.
+            device: Device to use for the feature dictionary.
         """
         super().__init__()
         self.num_features = num_features
         self.hidden_dim = hidden_dim
 
         # Initialize feature vectors as unit vectors
-        embeddings = torch.randn(num_features, hidden_dim)
+        embeddings = torch.randn(num_features, hidden_dim, device=device)
         embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
             min=1e-8
         )
         self.feature_vectors = nn.Parameter(embeddings)
 
         # Initialize bias (zeros if not using bias, but still a parameter for consistent API)
-        self.bias = nn.Parameter(
+        self.bias = nn.Parameter(
+            torch.zeros(hidden_dim, device=device), requires_grad=bias
+        )
 
         if initializer is not None:
             initializer(self)
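With the new device argument, both parameters are created on the target device before the initializer runs, so the orthogonalization can stay on the GPU when one is available. A usage sketch (num_features is assumed to be the first constructor argument; sizes are illustrative):

```python
import torch

from sae_lens.synthetic.feature_dictionary import (
    FeatureDictionary,
    orthogonal_initializer,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
feature_dict = FeatureDictionary(
    num_features=4096,
    hidden_dim=512,
    bias=True,
    initializer=orthogonal_initializer(num_steps=200, chunk_size=1024),
    device=device,
)
# Both parameters live on the requested device.
print(feature_dict.feature_vectors.device, feature_dict.bias.device)
```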
@@ -130,9 +168,18 @@ class FeatureDictionary(nn.Module):
 
         Args:
             feature_activations: Tensor of shape [batch, num_features] containing
-                sparse feature activation values
+                sparse feature activation values. Can be dense or sparse COO.
 
         Returns:
             Tensor of shape [batch, hidden_dim] containing dense hidden activations
         """
+        if feature_activations.is_sparse:
+            # autocast is disabled here because sparse matmul is not supported with bfloat16
+            with torch.autocast(
+                device_type=feature_activations.device.type, enabled=False
+            ):
+                return (
+                    torch.sparse.mm(feature_activations, self.feature_vectors)
+                    + self.bias
+                )
         return feature_activations @ self.feature_vectors + self.bias
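The decoding method now accepts sparse COO activations and routes them through torch.sparse.mm with autocast disabled. A sketch comparing the dense and sparse paths, assuming the method shown above is the module's forward (so the dictionary can be called directly) and with illustrative sizes:

```python
import torch

from sae_lens.synthetic.feature_dictionary import FeatureDictionary

feature_dict = FeatureDictionary(num_features=1000, hidden_dim=64)

batch, num_features = 8, 1000
dense_acts = torch.zeros(batch, num_features)
dense_acts[torch.arange(batch), torch.randint(0, num_features, (batch,))] = 1.0
sparse_acts = dense_acts.to_sparse_coo()

# Dense and sparse COO inputs decode to the same dense hidden activations.
dense_out = feature_dict(dense_acts)
sparse_out = feature_dict(sparse_acts)
assert torch.allclose(dense_out, sparse_out, atol=1e-6)
print(dense_out.shape)  # torch.Size([8, 64])
```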