sae-lens 6.28.1-py3-none-any.whl → 6.28.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +1 -1
- sae_lens/pretrained_saes.yaml +1 -1
- sae_lens/synthetic/activation_generator.py +1 -0
- sae_lens/synthetic/feature_dictionary.py +53 -15
- sae_lens/synthetic/hierarchy.py +345 -84
- {sae_lens-6.28.1.dist-info → sae_lens-6.28.2.dist-info}/METADATA +11 -1
- {sae_lens-6.28.1.dist-info → sae_lens-6.28.2.dist-info}/RECORD +9 -9
- {sae_lens-6.28.1.dist-info → sae_lens-6.28.2.dist-info}/WHEEL +0 -0
- {sae_lens-6.28.1.dist-info → sae_lens-6.28.2.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
sae_lens/pretrained_saes.yaml
CHANGED
@@ -40631,7 +40631,7 @@ gemma-3-1b-res-matryoshka-dc:
   conversion_func: null
   links:
     model: https://huggingface.co/google/gemma-3-1b-pt
-  model: gemma-3-1b
+  model: google/gemma-3-1b-pt
   repo_id: chanind/gemma-3-1b-batch-topk-matryoshka-saes-w-32k-l0-40
   saes:
   - id: blocks.0.hook_resid_post
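The corrected `model:` field is the metadata SAELens uses to resolve the base model for this release. A minimal sketch of loading one of the affected SAEs, assuming the v6-style `SAE.from_pretrained(release, sae_id)` API (return types have varied across major versions):

```python
from sae_lens import SAE

# Hypothetical usage: load the block-0 SAE from the release fixed above.
# Its metadata should now name google/gemma-3-1b-pt, the actual Hugging Face
# repo, rather than the bare "gemma-3-1b" string.
sae = SAE.from_pretrained(
    release="gemma-3-1b-res-matryoshka-dc",
    sae_id="blocks.0.hook_resid_post",
)
```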
sae_lens/synthetic/feature_dictionary.py
CHANGED
@@ -16,11 +16,28 @@ FeatureDictionaryInitializer = Callable[["FeatureDictionary"], None]

 def orthogonalize_embeddings(
     embeddings: torch.Tensor,
-    target_cos_sim: float = 0,
     num_steps: int = 200,
     lr: float = 0.01,
     show_progress: bool = False,
+    chunk_size: int = 1024,
 ) -> torch.Tensor:
+    """
+    Orthogonalize embeddings using gradient descent with chunked computation.
+
+    Uses chunked computation to avoid O(n²) memory usage when computing pairwise
+    dot products. Memory usage is O(chunk_size × n) instead of O(n²).
+
+    Args:
+        embeddings: Tensor of shape [num_vectors, hidden_dim]
+        num_steps: Number of optimization steps
+        lr: Learning rate for Adam optimizer
+        show_progress: Whether to show progress bar
+        chunk_size: Number of vectors to process at once. Smaller values use less
+            memory but may be slower.
+
+    Returns:
+        Orthogonalized embeddings of the same shape, normalized to unit length.
+    """
     num_vectors = embeddings.shape[0]
     # Create a detached copy and normalize, then enable gradients
     embeddings = embeddings.detach().clone()
@@ -29,24 +46,37 @@ def orthogonalize_embeddings(

     optimizer = torch.optim.Adam([embeddings], lr=lr)  # type: ignore[list-item]

-    # Create a mask to zero out diagonal elements (avoid in-place operations)
-    off_diagonal_mask = ~torch.eye(
-        num_vectors, dtype=torch.bool, device=embeddings.device
-    )
-
     pbar = tqdm(
         range(num_steps), desc="Orthogonalizing vectors", disable=not show_progress
     )
     for _ in pbar:
         optimizer.zero_grad()

-
-
-
-
-
-
+        off_diag_loss = torch.tensor(0.0, device=embeddings.device)
+        diag_loss = torch.tensor(0.0, device=embeddings.device)
+
+        for i in range(0, num_vectors, chunk_size):
+            end_i = min(i + chunk_size, num_vectors)
+            chunk = embeddings[i:end_i]
+            chunk_dots = chunk @ embeddings.T  # [chunk_size, num_vectors]

+            # Create mask to zero out diagonal elements for this chunk
+            # Diagonal of full matrix: position (i+k, i+k) → in chunk_dots: (k, i+k)
+            chunk_len = end_i - i
+            row_indices = torch.arange(chunk_len, device=embeddings.device)
+            col_indices = i + row_indices  # column indices in full matrix
+
+            # Boolean mask: True for off-diagonal elements we want to include
+            off_diag_mask = torch.ones_like(chunk_dots, dtype=torch.bool)
+            off_diag_mask[row_indices, col_indices] = False
+
+            off_diag_loss = off_diag_loss + chunk_dots[off_diag_mask].pow(2).sum()
+
+            # Diagonal loss: keep self-dot-products at 1
+            diag_vals = chunk_dots[row_indices, col_indices]
+            diag_loss = diag_loss + (diag_vals - 1).pow(2).sum()
+
+        loss = off_diag_loss + num_vectors * diag_loss
         loss.backward()
         optimizer.step()
         pbar.set_description(f"loss: {loss.item():.3f}")
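The chunked loss trades the dense n × n similarity matrix for one chunk_size × n slice at a time, which is where the O(chunk_size × n) memory claim in the docstring comes from. A quick back-of-the-envelope check (hypothetical sizes, float32):

```python
# Peak memory for the pairwise dot products at 4 bytes per float32 element.
n, chunk = 50_000, 1_024          # hypothetical vector count and chunk size
dense_bytes = n * n * 4           # full [n, n] matrix, all at once
chunked_bytes = chunk * n * 4     # one [chunk, n] slice at a time
print(f"dense: {dense_bytes / 2**30:.1f} GiB")      # dense: 9.3 GiB
print(f"chunked: {chunked_bytes / 2**20:.0f} MiB")  # chunked: 195 MiB
```

The gradients still cover all pairs within each optimizer step, since every chunk's loss is accumulated before `backward()` is called.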
@@ -59,7 +89,10 @@ def orthogonalize_embeddings(


 def orthogonal_initializer(
-    num_steps: int = 200,
+    num_steps: int = 200,
+    lr: float = 0.01,
+    show_progress: bool = False,
+    chunk_size: int = 1024,
 ) -> FeatureDictionaryInitializer:
     def initializer(feature_dict: "FeatureDictionary") -> None:
         feature_dict.feature_vectors.data = orthogonalize_embeddings(
@@ -67,6 +100,7 @@ def orthogonal_initializer(
             num_steps=num_steps,
             lr=lr,
             show_progress=show_progress,
+            chunk_size=chunk_size,
         )

     return initializer
@@ -97,6 +131,7 @@ class FeatureDictionary(nn.Module):
         hidden_dim: int,
         bias: bool = False,
         initializer: FeatureDictionaryInitializer | None = orthogonal_initializer(),
+        device: str | torch.device = "cpu",
     ):
         """
         Create a new FeatureDictionary.
@@ -106,20 +141,23 @@ class FeatureDictionary(nn.Module):
             hidden_dim: Dimensionality of the hidden space
             bias: Whether to include a bias term in the embedding
             initializer: Initializer function to use. If None, the embeddings are initialized to random unit vectors. By default will orthogonalize embeddings.
+            device: Device to use for the feature dictionary.
         """
         super().__init__()
         self.num_features = num_features
         self.hidden_dim = hidden_dim

         # Initialize feature vectors as unit vectors
-        embeddings = torch.randn(num_features, hidden_dim)
+        embeddings = torch.randn(num_features, hidden_dim, device=device)
         embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
             min=1e-8
         )
         self.feature_vectors = nn.Parameter(embeddings)

         # Initialize bias (zeros if not using bias, but still a parameter for consistent API)
-        self.bias = nn.Parameter(
+        self.bias = nn.Parameter(
+            torch.zeros(hidden_dim, device=device), requires_grad=bias
+        )

         if initializer is not None:
             initializer(self)
sae_lens/synthetic/hierarchy.py
CHANGED
@@ -11,7 +11,9 @@ https://github.com/noanabeshima/matryoshka-saes/blob/main/toy_model.py

 from __future__ import annotations

+from collections import deque
 from collections.abc import Callable, Sequence
+from dataclasses import dataclass
 from typing import Any

 import torch
@@ -19,6 +21,7 @@ import torch
 ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]


+@torch.no_grad()
 def _validate_hierarchy(roots: Sequence[HierarchyNode]) -> None:
     """
     Validate a forest of hierarchy trees.
@@ -104,6 +107,320 @@ def _node_description(node: HierarchyNode) -> str:
     return "unnamed node"


+# ---------------------------------------------------------------------------
+# Vectorized hierarchy implementation
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class _LevelData:
+    """Data for a single level in the hierarchy."""
+
+    # Features at this level and their parents (for parent deactivation)
+    features: torch.Tensor  # [num_features_at_level]
+    parents: torch.Tensor  # [num_features_at_level]
+
+    # ME group indices to process AFTER this level's parent deactivation
+    # These are groups whose parent node is at this level
+    # ME must be applied here before processing next level's parent deactivation
+    me_group_indices: torch.Tensor  # [num_groups_at_level], may be empty
+
+
+@dataclass
+class _SparseHierarchyData:
+    """Precomputed data for sparse hierarchy processing.
+
+    This structure enables O(active_features) processing instead of O(all_groups).
+    ME is applied at each level after parent deactivation to ensure cascading works.
+    """
+
+    # Per-level data for parent deactivation and ME (processed in order)
+    level_data: list[_LevelData]
+
+    # ME group data (shared across levels, indexed by me_group_indices)
+    me_group_siblings: torch.Tensor  # [num_groups, max_siblings]
+    me_group_sizes: torch.Tensor  # [num_groups]
+    me_group_parents: (
+        torch.Tensor
+    )  # [num_groups] - parent feature index (-1 if no parent)
+
+    # Total number of ME groups
+    num_groups: int
+
+
+def _build_sparse_hierarchy(
+    roots: Sequence[HierarchyNode],
+) -> _SparseHierarchyData:
+    """
+    Build sparse hierarchy data structure for O(active_features) processing.
+
+    The key insight is that ME groups must be applied at the level of their parent node,
+    AFTER parent deactivation at that level, but BEFORE processing the next level.
+    This ensures that when a child is deactivated by ME, its grandchildren are also
+    deactivated during the next level's parent deactivation.
+    """
+    # Collect feature info by level using BFS
+    # Each entry: (feature_index, effective_parent, level)
+    feature_info: list[tuple[int, int, int]] = []
+
+    # ME groups: list of (parent_level, parent_feature, child_feature_indices)
+    me_groups: list[tuple[int, int, list[int]]] = []
+
+    # BFS queue: (node, effective_parent, level)
+    queue: deque[tuple[HierarchyNode, int, int]] = deque()
+    for root in roots:
+        queue.append((root, -1, 0))
+
+    while queue:
+        node, effective_parent, level = queue.popleft()
+
+        if node.feature_index is not None:
+            feature_info.append((node.feature_index, effective_parent, level))
+            new_effective_parent = node.feature_index
+        else:
+            new_effective_parent = effective_parent
+
+        # Handle mutual exclusion children - record the parent's level and feature
+        if node.mutually_exclusive_children and len(node.children) >= 2:
+            child_feats = [
+                c.feature_index for c in node.children if c.feature_index is not None
+            ]
+            if len(child_feats) >= 2:
+                # ME group belongs to the parent's level (current level)
+                # Parent feature is the node's feature_index (-1 if organizational node)
+                parent_feat = (
+                    node.feature_index if node.feature_index is not None else -1
+                )
+                me_groups.append((level, parent_feat, child_feats))
+
+        for child in node.children:
+            queue.append((child, new_effective_parent, level + 1))
+
+    # Determine max level for both features and ME groups
+    max_feature_level = max((info[2] for info in feature_info), default=-1)
+    max_me_level = max((lvl for lvl, _, _ in me_groups), default=-1)
+    max_level = max(max_feature_level, max_me_level)
+
+    # Build level data with ME group indices per level
+    level_data: list[_LevelData] = []
+
+    # Group ME groups by their parent level
+    me_groups_by_level: dict[int, list[int]] = {}
+    for g_idx, (parent_level, _, _) in enumerate(me_groups):
+        if parent_level not in me_groups_by_level:
+            me_groups_by_level[parent_level] = []
+        me_groups_by_level[parent_level].append(g_idx)
+
+    for level in range(max_level + 1):
+        # Get features at this level that have parents
+        features_at_level = [
+            (feat, parent) for feat, parent, lv in feature_info if lv == level
+        ]
+        with_parents = [(f, p) for f, p in features_at_level if p >= 0]
+
+        if with_parents:
+            feats = torch.tensor([f for f, _ in with_parents], dtype=torch.long)
+            parents = torch.tensor([p for _, p in with_parents], dtype=torch.long)
+        else:
+            feats = torch.empty(0, dtype=torch.long)
+            parents = torch.empty(0, dtype=torch.long)
+
+        # Get ME group indices for this level
+        if level in me_groups_by_level:
+            me_indices = torch.tensor(me_groups_by_level[level], dtype=torch.long)
+        else:
+            me_indices = torch.empty(0, dtype=torch.long)
+
+        level_data.append(
+            _LevelData(features=feats, parents=parents, me_group_indices=me_indices)
+        )
+
+    # Build group siblings and parents tensors
+    if me_groups:
+        max_siblings = max(len(children) for _, _, children in me_groups)
+        num_groups = len(me_groups)
+        me_group_siblings = torch.full((num_groups, max_siblings), -1, dtype=torch.long)
+        me_group_sizes = torch.zeros(num_groups, dtype=torch.long)
+        me_group_parents = torch.full((num_groups,), -1, dtype=torch.long)
+        for g_idx, (_, parent_feat, siblings) in enumerate(me_groups):
+            me_group_sizes[g_idx] = len(siblings)
+            me_group_parents[g_idx] = parent_feat
+            me_group_siblings[g_idx, : len(siblings)] = torch.tensor(
+                siblings, dtype=torch.long
+            )
+    else:
+        me_group_siblings = torch.empty((0, 0), dtype=torch.long)
+        me_group_sizes = torch.empty(0, dtype=torch.long)
+        me_group_parents = torch.empty(0, dtype=torch.long)
+        num_groups = 0
+
+    return _SparseHierarchyData(
+        level_data=level_data,
+        me_group_siblings=me_group_siblings,
+        me_group_sizes=me_group_sizes,
+        me_group_parents=me_group_parents,
+        num_groups=num_groups,
+    )
+
+
+def _apply_hierarchy_sparse(
+    activations: torch.Tensor,
+    sparse_data: _SparseHierarchyData,
+) -> torch.Tensor:
+    """
+    Apply hierarchy constraints using precomputed sparse indices.
+
+    Processes level by level:
+    1. Apply parent deactivation for features at this level
+    2. Apply mutual exclusion for groups whose parent is at this level
+    3. Move to next level
+
+    This ensures that ME at level L affects parent deactivation at level L+1.
+    """
+    result = activations.clone()
+
+    # Data is already on correct device from cache
+    me_group_siblings = sparse_data.me_group_siblings
+    me_group_sizes = sparse_data.me_group_sizes
+    me_group_parents = sparse_data.me_group_parents
+
+    for level_data in sparse_data.level_data:
+        # Step 1: Deactivate children where parent is inactive
+        if level_data.features.numel() > 0:
+            parent_vals = result[:, level_data.parents]
+            child_vals = result[:, level_data.features]
+            result[:, level_data.features] = child_vals * (parent_vals > 0)
+
+        # Step 2: Apply ME for groups whose parent is at this level
+        if level_data.me_group_indices.numel() > 0:
+            _apply_me_for_groups(
+                result,
+                level_data.me_group_indices,
+                me_group_siblings,
+                me_group_sizes,
+                me_group_parents,
+            )
+
+    return result
+
+
+def _apply_me_for_groups(
+    activations: torch.Tensor,
+    group_indices: torch.Tensor,
+    me_group_siblings: torch.Tensor,
+    me_group_sizes: torch.Tensor,
+    me_group_parents: torch.Tensor,
+) -> None:
+    """
+    Apply mutual exclusion for the specified groups.
+
+    Only processes groups where the parent is active (or has no parent).
+    This is a key optimization since most groups are skipped when parent is inactive.
+
+    Args:
+        activations: [batch_size, num_features] - modified in place
+        group_indices: [num_groups_to_process] - which groups to apply ME for
+        me_group_siblings: [total_groups, max_siblings] - sibling indices per group
+        me_group_sizes: [total_groups] - number of valid siblings per group
+        me_group_parents: [total_groups] - parent feature index (-1 if no parent)
+    """
+    batch_size = activations.shape[0]
+    device = activations.device
+    num_groups = group_indices.numel()
+
+    if num_groups == 0:
+        return
+
+    # Get parent indices for these groups
+    parents = me_group_parents[group_indices]  # [num_groups]
+
+    # Check which parents are active: [batch_size, num_groups]
+    # Groups with parent=-1 are always active (root-level ME)
+    has_parent = parents >= 0
+    if has_parent.all():
+        # All groups have parents - check their activation directly
+        parent_active = activations[:, parents] > 0  # [batch, num_groups]
+        if not parent_active.any():
+            return
+    elif has_parent.any():
+        # Mixed case: some groups have parents, some don't
+        # Use clamp to avoid indexing with -1 (reads feature 0, but result is masked out)
+        safe_parents = parents.clamp(min=0)
+        parent_active = activations[:, safe_parents] > 0  # [batch, num_groups]
+        # Groups without parent are always "active"
+        parent_active = parent_active | ~has_parent
+    else:
+        # No groups have parents - all are always active, skip parent check
+        parent_active = None
+
+    # Get siblings for the groups we're processing
+    siblings = me_group_siblings[group_indices]  # [num_groups, max_siblings]
+    sizes = me_group_sizes[group_indices]  # [num_groups]
+    max_siblings = siblings.shape[1]
+
+    # Get activations for all siblings: [batch_size, num_groups, max_siblings]
+    safe_siblings = siblings.clamp(min=0)
+    sibling_activations = activations[:, safe_siblings.view(-1)].view(
+        batch_size, num_groups, max_siblings
+    )
+
+    # Create validity mask for padding: [num_groups, max_siblings]
+    sibling_range = torch.arange(max_siblings, device=device)
+    valid_mask = sibling_range < sizes.unsqueeze(1)
+
+    # Find active valid siblings, but only where parent is active: [batch, groups, siblings]
+    sibling_active = (sibling_activations > 0) & valid_mask
+    if parent_active is not None:
+        sibling_active = sibling_active & parent_active.unsqueeze(2)
+
+    # Count active per group and check for conflicts: [batch_size, num_groups]
+    active_counts = sibling_active.sum(dim=2)
+    needs_exclusion = active_counts > 1
+
+    if not needs_exclusion.any():
+        return
+
+    # Get (batch, group) pairs needing exclusion
+    batch_with_conflict, groups_with_conflict = torch.where(needs_exclusion)
+    num_conflicts = batch_with_conflict.numel()
+
+    if num_conflicts == 0:
+        return
+
+    # Get siblings and activations for conflicts
+    conflict_siblings = siblings[groups_with_conflict]  # [num_conflicts, max_siblings]
+    conflict_active = sibling_active[
+        batch_with_conflict, groups_with_conflict
+    ]  # [num_conflicts, max_siblings]
+
+    # Random selection for winner
+    # Use -1e9 instead of -inf to avoid creating a tensor (torch.tensor(-float("inf")))
+    # on every call. Since random scores are in [0,1], -1e9 is effectively -inf for argmax.
+    random_scores = torch.rand(num_conflicts, max_siblings, device=device)
+    random_scores[~conflict_active] = -1e9
+
+    winner_idx = random_scores.argmax(dim=1)
+
+    # Determine losers using scatter for efficiency
+    is_winner = torch.zeros(
+        num_conflicts, max_siblings, dtype=torch.bool, device=device
+    )
+    is_winner.scatter_(1, winner_idx.unsqueeze(1), True)
+    should_deactivate = conflict_active & ~is_winner
+
+    # Get (conflict, sibling) pairs to deactivate
+    conflict_idx, sib_idx = torch.where(should_deactivate)
+
+    if conflict_idx.numel() == 0:
+        return
+
+    # Map back to (batch, feature) and deactivate
+    deact_batch = batch_with_conflict[conflict_idx]
+    deact_feat = conflict_siblings[conflict_idx, sib_idx]
+    activations[deact_batch, deact_feat] = 0
+
+
+@torch.no_grad()
 def hierarchy_modifier(
     roots: Sequence[HierarchyNode] | HierarchyNode,
 ) -> ActivationsModifier:
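The winner selection in `_apply_me_for_groups` uses a standard vectorized trick for picking one random element per row: draw uniform scores, push invalid slots to a large negative value, and take `argmax`. A self-contained sketch of the same pattern with toy shapes (not the library's API):

```python
import torch

# Toy conflict matrix: 3 rows, up to 4 candidate siblings each.
active = torch.tensor([
    [True, True, False, False],
    [True, False, True, True],
    [False, True, True, False],
])

scores = torch.rand(active.shape)
scores[~active] = -1e9             # inactive slots can never win
winner = scores.argmax(dim=1)      # uniformly random choice among active slots

is_winner = torch.zeros_like(active)
is_winner.scatter_(1, winner.unsqueeze(1), True)
losers = active & ~is_winner       # everything active except the winner

assert (active.sum(dim=1) - losers.sum(dim=1) == 1).all()  # one survivor per row
```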
@@ -136,12 +453,35 @@ def hierarchy_modifier(
         roots = [roots]
     _validate_hierarchy(roots)

-    #
+    # Build sparse hierarchy data
+    sparse_data = _build_sparse_hierarchy(roots)
+
+    # Cache for device-specific tensors
+    device_cache: dict[torch.device, _SparseHierarchyData] = {}
+
+    def _get_sparse_for_device(device: torch.device) -> _SparseHierarchyData:
+        """Get or create device-specific sparse hierarchy data."""
+        if device not in device_cache:
+            device_cache[device] = _SparseHierarchyData(
+                level_data=[
+                    _LevelData(
+                        features=ld.features.to(device),
+                        parents=ld.parents.to(device),
+                        me_group_indices=ld.me_group_indices.to(device),
+                    )
+                    for ld in sparse_data.level_data
+                ],
+                me_group_siblings=sparse_data.me_group_siblings.to(device),
+                me_group_sizes=sparse_data.me_group_sizes.to(device),
+                me_group_parents=sparse_data.me_group_parents.to(device),
+                num_groups=sparse_data.num_groups,
+            )
+        return device_cache[device]
+
     def modifier(activations: torch.Tensor) -> torch.Tensor:
-
-
-
-        return result
+        device = activations.device
+        cached = _get_sparse_for_device(device)
+        return _apply_hierarchy_sparse(activations, cached)

     return modifier
@@ -222,85 +562,6 @@ class HierarchyNode:
         if self.mutually_exclusive_children and len(self.children) < 2:
             raise ValueError("Need at least 2 children for mutual exclusion")

-    def _apply_hierarchy(
-        self,
-        activations: torch.Tensor,
-        parent_active_mask: torch.Tensor | None,
-    ) -> None:
-        """Recursively apply hierarchical constraints."""
-        batch_size = activations.shape[0]
-
-        # Determine which samples have this node active
-        if self.feature_index is not None:
-            is_active = activations[:, self.feature_index] > 0
-        else:
-            # Non-readout node: active if parent is active (or always if root)
-            is_active = (
-                parent_active_mask
-                if parent_active_mask is not None
-                else torch.ones(batch_size, dtype=torch.bool, device=activations.device)
-            )
-
-        # Deactivate this node if parent is inactive
-        if parent_active_mask is not None and self.feature_index is not None:
-            activations[~parent_active_mask, self.feature_index] = 0
-            # Update is_active after deactivation
-            is_active = activations[:, self.feature_index] > 0
-
-        # Handle mutually exclusive children
-        if self.mutually_exclusive_children and len(self.children) >= 2:
-            self._enforce_mutual_exclusion(activations, is_active)
-
-        # Recursively process children
-        for child in self.children:
-            child._apply_hierarchy(activations, parent_active_mask=is_active)
-
-    def _enforce_mutual_exclusion(
-        self,
-        activations: torch.Tensor,
-        parent_active_mask: torch.Tensor,
-    ) -> None:
-        """Ensure at most one child is active per sample."""
-        batch_size = activations.shape[0]
-
-        # Get indices of children that have feature indices
-        child_indices = [
-            child.feature_index
-            for child in self.children
-            if child.feature_index is not None
-        ]
-
-        if len(child_indices) < 2:
-            return
-
-        # For each sample where parent is active, enforce mutual exclusion.
-        # Note: This loop is not vectorized because we need to randomly select
-        # which child to keep active per sample. Vectorizing would require either
-        # a deterministic selection (losing randomness) or complex gather/scatter
-        # operations that aren't more efficient for typical batch sizes.
-        for batch_idx in range(batch_size):
-            if not parent_active_mask[batch_idx]:
-                continue
-
-            # Find which children are active
-            active_children = [
-                i
-                for i, feat_idx in enumerate(child_indices)
-                if activations[batch_idx, feat_idx] > 0
-            ]
-
-            if len(active_children) <= 1:
-                continue
-
-            # Randomly select one to keep active
-            random_idx = int(torch.randint(len(active_children), (1,)).item())
-            keep_idx = active_children[random_idx]
-
-            # Deactivate all others
-            for i, feat_idx in enumerate(child_indices):
-                if i != keep_idx and i in active_children:
-                    activations[batch_idx, feat_idx] = 0
-
     def get_all_feature_indices(self) -> list[int]:
         """Get all feature indices in this subtree."""
         indices = []
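The call-site behavior is unchanged by the rewrite; only the internals moved from per-node recursion to level-indexed tensor operations. A hedged usage sketch, assuming `HierarchyNode` accepts `feature_index`, `children`, and `mutually_exclusive_children` as keyword arguments (consistent with the attributes the diff references):

```python
import torch

from sae_lens.synthetic.hierarchy import HierarchyNode, hierarchy_modifier

# Feature 0 is the parent of two mutually exclusive children, features 1 and 2.
root = HierarchyNode(
    feature_index=0,
    mutually_exclusive_children=True,
    children=[HierarchyNode(feature_index=1), HierarchyNode(feature_index=2)],
)

modifier = hierarchy_modifier(root)
acts = torch.tensor([[0.0, 0.5, 0.9], [1.0, 0.5, 0.9]])
out = modifier(acts)
# Row 0: parent inactive, so both children are zeroed.
# Row 1: parent active, so a random one of features 1 and 2 survives.
```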
{sae_lens-6.28.1.dist-info → sae_lens-6.28.2.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.28.1
+Version: 6.28.2
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -50,6 +50,8 @@ SAELens exists to help researchers:
 - Analyse sparse autoencoders / research mechanistic interpretability.
 - Generate insights which make it easier to create safe and aligned AI systems.

+SAELens inference works with any PyTorch-based model, not just TransformerLens. While we provide deep integration with TransformerLens via `HookedSAETransformer`, SAEs can be used with Hugging Face Transformers, NNsight, or any other framework by extracting activations and passing them to the SAE's `encode()` and `decode()` methods.
+
 Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:

 - Download and Analyse pre-trained sparse autoencoders.
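The paragraph added above describes the framework-agnostic workflow. A hedged sketch of what that looks like with Hugging Face Transformers; the release, layer choice, and hidden-state indexing here are illustrative, not prescribed by SAELens:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from sae_lens import SAE

model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-pt")
sae = SAE.from_pretrained(
    release="gemma-3-1b-res-matryoshka-dc", sae_id="blocks.0.hook_resid_post"
)

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    # hidden_states[1] approximates the residual stream after block 0.
    acts = model(**inputs, output_hidden_states=True).hidden_states[1]

features = sae.encode(acts)   # sparse feature activations
recon = sae.decode(features)  # reconstruction of the activations
```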
@@ -84,6 +86,14 @@ The new v6 update is a major refactor to SAELens and changes the way training co

 Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!

+## Other SAE Projects
+
+- [dictionary-learning](https://github.com/saprmarks/dictionary_learning): An SAE training library that focuses on having hackable code.
+- [Sparsify](https://github.com/EleutherAI/sparsify): A lean SAE training library focused on TopK SAEs.
+- [Overcomplete](https://github.com/KempnerInstitute/overcomplete): SAE training library focused on vision models.
+- [SAE-Vis](https://github.com/callummcdougall/sae_vis): A library for visualizing SAE features, works with SAELens.
+- [SAEBench](https://github.com/adamkarvonen/SAEBench): A suite of LLM SAE benchmarks, works with SAELens.
+
 ## Citation

 Please cite the package as follows:
{sae_lens-6.28.1.dist-info → sae_lens-6.28.2.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=B9tY0Jt21pOHmSQrQLpMxQHyUAdLHIZpVP6pg3O0dfQ,4788
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
 sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
@@ -12,7 +12,7 @@ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
 sae_lens/loading/pretrained_sae_loaders.py,sha256=hHMlew1u6zVlbzvS9S_SfUPnAG0_OAjjIcjoUTIUZrU,63657
 sae_lens/loading/pretrained_saes_directory.py,sha256=1at_aQbD8WFywchQCKuwfP-yvCq_Z2aUYrpKDnSN5Nc,4283
 sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
-sae_lens/pretrained_saes.yaml,sha256=
+sae_lens/pretrained_saes.yaml,sha256=Nq43dTcFvDDONTuJ9Me_HQ5nHqr9BdbP5-ZJGXj0TAQ,1509932
 sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
 sae_lens/saes/__init__.py,sha256=SBqPaP6Gl5uPFwHlumAZATC4Wd26xKIYLAAAo4MSa5Q,2200
 sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
@@ -26,12 +26,12 @@ sae_lens/saes/temporal_sae.py,sha256=83Ap4mYGfdN3sKdPF8nKjhdXph3-7E2QuLobqJ_YuoM
 sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
 sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
 sae_lens/synthetic/__init__.py,sha256=FGUasB6fLPXRFCcrtKfL7vCKDOWebZ5Rx5F9QNJZklI,2875
-sae_lens/synthetic/activation_generator.py,sha256=
+sae_lens/synthetic/activation_generator.py,sha256=JEN7mEgdGDuXr0ArTwUsSdSVUAfvheT_1Eew2ojbA-g,7659
 sae_lens/synthetic/correlation.py,sha256=odr-S5h6c2U-bepwrAQeMfV1iBF_cnnQzqw7zapEXZ4,6056
 sae_lens/synthetic/evals.py,sha256=Nhi314ZnRgLfhBj-3tm_zzI-pGyFTcwllDXbIpPFXeU,4584
-sae_lens/synthetic/feature_dictionary.py,sha256=
+sae_lens/synthetic/feature_dictionary.py,sha256=ysn0ihE3JgVlCLUZMb127WYZqbz4kMp9BGHfCZqERBg,6487
 sae_lens/synthetic/firing_probabilities.py,sha256=yclz1pWl5gE1r8LAxFvzQS88Lxwk5-3r8BCX9HLVejA,3370
-sae_lens/synthetic/hierarchy.py,sha256=
+sae_lens/synthetic/hierarchy.py,sha256=j9-6K7xq6zQS9N8bB5nK_-EbuzAZsY5Z5AfUK-qlB5M,22138
 sae_lens/synthetic/initialization.py,sha256=orMGW-786wRDHIS2W7bEH0HmlVFQ4g2z4bnnwdv5w4s,1386
 sae_lens/synthetic/plotting.py,sha256=5lFrej1QOkGAcImFNo5-o-8mI_rUVqvEI57KzUQPPtQ,8208
 sae_lens/synthetic/training.py,sha256=Bg6NYxdzifq_8g-dJQSZ_z_TXDdGRtEi7tqNDb-gCVc,4986
@@ -46,7 +46,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
 sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
-sae_lens-6.28.
-sae_lens-6.28.
-sae_lens-6.28.
-sae_lens-6.28.
+sae_lens-6.28.2.dist-info/METADATA,sha256=i_kbAa64It0NRDrnSlmwNa8qgqOyEMntT_Ifxdx4Q90,6573
+sae_lens-6.28.2.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+sae_lens-6.28.2.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.28.2.dist-info/RECORD,,

{sae_lens-6.28.1.dist-info → sae_lens-6.28.2.dist-info}/WHEEL
File without changes

{sae_lens-6.28.1.dist-info → sae_lens-6.28.2.dist-info}/licenses/LICENSE
File without changes