sae-lens 6.26.1-py3-none-any.whl → 6.28.2-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sae_lens/__init__.py +3 -1
- sae_lens/cache_activations_runner.py +12 -5
- sae_lens/config.py +2 -0
- sae_lens/loading/pretrained_sae_loaders.py +2 -1
- sae_lens/loading/pretrained_saes_directory.py +18 -0
- sae_lens/pretrained_saes.yaml +1 -1
- sae_lens/saes/gated_sae.py +1 -0
- sae_lens/saes/jumprelu_sae.py +3 -0
- sae_lens/saes/sae.py +13 -0
- sae_lens/saes/standard_sae.py +2 -0
- sae_lens/saes/temporal_sae.py +1 -0
- sae_lens/synthetic/__init__.py +89 -0
- sae_lens/synthetic/activation_generator.py +216 -0
- sae_lens/synthetic/correlation.py +170 -0
- sae_lens/synthetic/evals.py +141 -0
- sae_lens/synthetic/feature_dictionary.py +176 -0
- sae_lens/synthetic/firing_probabilities.py +104 -0
- sae_lens/synthetic/hierarchy.py +596 -0
- sae_lens/synthetic/initialization.py +40 -0
- sae_lens/synthetic/plotting.py +230 -0
- sae_lens/synthetic/training.py +145 -0
- sae_lens/tokenization_and_batching.py +1 -1
- sae_lens/training/activations_store.py +51 -91
- sae_lens/training/mixing_buffer.py +14 -5
- sae_lens/training/sae_trainer.py +1 -1
- sae_lens/util.py +26 -1
- {sae_lens-6.26.1.dist-info → sae_lens-6.28.2.dist-info}/METADATA +13 -1
- sae_lens-6.28.2.dist-info/RECORD +52 -0
- sae_lens-6.26.1.dist-info/RECORD +0 -42
- {sae_lens-6.26.1.dist-info → sae_lens-6.28.2.dist-info}/WHEEL +0 -0
- {sae_lens-6.26.1.dist-info → sae_lens-6.28.2.dist-info}/licenses/LICENSE +0 -0
sae_lens/synthetic/hierarchy.py
@@ -0,0 +1,596 @@
+"""
+Hierarchical feature modifier for activation generators.
+
+This module provides HierarchyNode, which enforces hierarchical dependencies
+on feature activations. Child features are deactivated when their parent is inactive,
+and children can optionally be mutually exclusive.
+
+Based on Noa Nabeshima's Matryoshka SAEs:
+https://github.com/noanabeshima/matryoshka-saes/blob/main/toy_model.py
+"""
+
+from __future__ import annotations
+
+from collections import deque
+from collections.abc import Callable, Sequence
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+
+ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
+
+
+@torch.no_grad()
+def _validate_hierarchy(roots: Sequence[HierarchyNode]) -> None:
+    """
+    Validate a forest of hierarchy trees.
+
+    Treats the input as children of a virtual root node and validates the
+    entire structure.
+
+    Checks that:
+    1. There are no loops (no node is its own ancestor)
+    2. Each node has at most one parent (no node appears in multiple children lists)
+    3. No feature index appears in multiple trees
+
+    Args:
+        roots: Root nodes of the hierarchy trees to validate
+
+    Raises:
+        ValueError: If the hierarchy is invalid
+    """
+    if not roots:
+        return
+
+    # Collect all nodes and check for loops, treating roots as children of virtual root
+    all_nodes: list[HierarchyNode] = []
+    virtual_root_id = id(roots)  # Use the list itself as virtual root identity
+
+    for root in roots:
+        all_nodes.append(root)
+        _collect_nodes_and_check_loops(root, all_nodes, ancestors={virtual_root_id})
+
+    # Check for multiple parents (same node appearing multiple times)
+    seen_ids: set[int] = set()
+    for node in all_nodes:
+        node_id = id(node)
+        if node_id in seen_ids:
+            node_desc = _node_description(node)
+            raise ValueError(
+                f"Node ({node_desc}) has multiple parents. "
+                "Each node must have at most one parent."
+            )
+        seen_ids.add(node_id)
+
+    # Check for overlapping feature indices across trees
+    if len(roots) > 1:
+        all_indices: set[int] = set()
+        for root in roots:
+            tree_indices = root.get_all_feature_indices()
+            overlap = all_indices & set(tree_indices)
+            if overlap:
+                raise ValueError(
+                    f"Feature indices {overlap} appear in multiple hierarchy trees. "
+                    "Each feature should belong to at most one hierarchy."
+                )
+            all_indices.update(tree_indices)
+
+
+def _collect_nodes_and_check_loops(
+    node: HierarchyNode,
+    all_nodes: list[HierarchyNode],
+    ancestors: set[int],
+) -> None:
+    """Recursively collect nodes and check for loops."""
+    node_id = id(node)
+
+    if node_id in ancestors:
+        node_desc = _node_description(node)
+        raise ValueError(f"Loop detected: node ({node_desc}) is its own ancestor.")
+
+    # Add to ancestors for children traversal
+    new_ancestors = ancestors | {node_id}
+
+    for child in node.children:
+        # Collect child (before recursing, so we can detect multiple parents)
+        all_nodes.append(child)
+        _collect_nodes_and_check_loops(child, all_nodes, new_ancestors)
+
+
+def _node_description(node: HierarchyNode) -> str:
+    """Get a human-readable description of a node for error messages."""
+    if node.feature_index is not None:
+        return f"feature_index={node.feature_index}"
+    if node.feature_id:
+        return f"id={node.feature_id}"
+    return "unnamed node"
+
+
+# ---------------------------------------------------------------------------
+# Vectorized hierarchy implementation
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class _LevelData:
+    """Data for a single level in the hierarchy."""
+
+    # Features at this level and their parents (for parent deactivation)
+    features: torch.Tensor  # [num_features_at_level]
+    parents: torch.Tensor  # [num_features_at_level]
+
+    # ME group indices to process AFTER this level's parent deactivation
+    # These are groups whose parent node is at this level
+    # ME must be applied here before processing next level's parent deactivation
+    me_group_indices: torch.Tensor  # [num_groups_at_level], may be empty
+
+
+@dataclass
+class _SparseHierarchyData:
+    """Precomputed data for sparse hierarchy processing.
+
+    This structure enables O(active_features) processing instead of O(all_groups).
+    ME is applied at each level after parent deactivation to ensure cascading works.
+    """
+
+    # Per-level data for parent deactivation and ME (processed in order)
+    level_data: list[_LevelData]
+
+    # ME group data (shared across levels, indexed by me_group_indices)
+    me_group_siblings: torch.Tensor  # [num_groups, max_siblings]
+    me_group_sizes: torch.Tensor  # [num_groups]
+    me_group_parents: (
+        torch.Tensor
+    )  # [num_groups] - parent feature index (-1 if no parent)
+
+    # Total number of ME groups
+    num_groups: int
+
+
+def _build_sparse_hierarchy(
+    roots: Sequence[HierarchyNode],
+) -> _SparseHierarchyData:
+    """
+    Build sparse hierarchy data structure for O(active_features) processing.
+
+    The key insight is that ME groups must be applied at the level of their parent node,
+    AFTER parent deactivation at that level, but BEFORE processing the next level.
+    This ensures that when a child is deactivated by ME, its grandchildren are also
+    deactivated during the next level's parent deactivation.
+    """
+    # Collect feature info by level using BFS
+    # Each entry: (feature_index, effective_parent, level)
+    feature_info: list[tuple[int, int, int]] = []
+
+    # ME groups: list of (parent_level, parent_feature, child_feature_indices)
+    me_groups: list[tuple[int, int, list[int]]] = []
+
+    # BFS queue: (node, effective_parent, level)
+    queue: deque[tuple[HierarchyNode, int, int]] = deque()
+    for root in roots:
+        queue.append((root, -1, 0))
+
+    while queue:
+        node, effective_parent, level = queue.popleft()
+
+        if node.feature_index is not None:
+            feature_info.append((node.feature_index, effective_parent, level))
+            new_effective_parent = node.feature_index
+        else:
+            new_effective_parent = effective_parent
+
+        # Handle mutual exclusion children - record the parent's level and feature
+        if node.mutually_exclusive_children and len(node.children) >= 2:
+            child_feats = [
+                c.feature_index for c in node.children if c.feature_index is not None
+            ]
+            if len(child_feats) >= 2:
+                # ME group belongs to the parent's level (current level)
+                # Parent feature is the node's feature_index (-1 if organizational node)
+                parent_feat = (
+                    node.feature_index if node.feature_index is not None else -1
+                )
+                me_groups.append((level, parent_feat, child_feats))
+
+        for child in node.children:
+            queue.append((child, new_effective_parent, level + 1))
+
+    # Determine max level for both features and ME groups
+    max_feature_level = max((info[2] for info in feature_info), default=-1)
+    max_me_level = max((lvl for lvl, _, _ in me_groups), default=-1)
+    max_level = max(max_feature_level, max_me_level)
+
+    # Build level data with ME group indices per level
+    level_data: list[_LevelData] = []
+
+    # Group ME groups by their parent level
+    me_groups_by_level: dict[int, list[int]] = {}
+    for g_idx, (parent_level, _, _) in enumerate(me_groups):
+        if parent_level not in me_groups_by_level:
+            me_groups_by_level[parent_level] = []
+        me_groups_by_level[parent_level].append(g_idx)
+
+    for level in range(max_level + 1):
+        # Get features at this level that have parents
+        features_at_level = [
+            (feat, parent) for feat, parent, lv in feature_info if lv == level
+        ]
+        with_parents = [(f, p) for f, p in features_at_level if p >= 0]
+
+        if with_parents:
+            feats = torch.tensor([f for f, _ in with_parents], dtype=torch.long)
+            parents = torch.tensor([p for _, p in with_parents], dtype=torch.long)
+        else:
+            feats = torch.empty(0, dtype=torch.long)
+            parents = torch.empty(0, dtype=torch.long)
+
+        # Get ME group indices for this level
+        if level in me_groups_by_level:
+            me_indices = torch.tensor(me_groups_by_level[level], dtype=torch.long)
+        else:
+            me_indices = torch.empty(0, dtype=torch.long)
+
+        level_data.append(
+            _LevelData(features=feats, parents=parents, me_group_indices=me_indices)
+        )
+
+    # Build group siblings and parents tensors
+    if me_groups:
+        max_siblings = max(len(children) for _, _, children in me_groups)
+        num_groups = len(me_groups)
+        me_group_siblings = torch.full((num_groups, max_siblings), -1, dtype=torch.long)
+        me_group_sizes = torch.zeros(num_groups, dtype=torch.long)
+        me_group_parents = torch.full((num_groups,), -1, dtype=torch.long)
+        for g_idx, (_, parent_feat, siblings) in enumerate(me_groups):
+            me_group_sizes[g_idx] = len(siblings)
+            me_group_parents[g_idx] = parent_feat
+            me_group_siblings[g_idx, : len(siblings)] = torch.tensor(
+                siblings, dtype=torch.long
+            )
+    else:
+        me_group_siblings = torch.empty((0, 0), dtype=torch.long)
+        me_group_sizes = torch.empty(0, dtype=torch.long)
+        me_group_parents = torch.empty(0, dtype=torch.long)
+        num_groups = 0
+
+    return _SparseHierarchyData(
+        level_data=level_data,
+        me_group_siblings=me_group_siblings,
+        me_group_sizes=me_group_sizes,
+        me_group_parents=me_group_parents,
+        num_groups=num_groups,
+    )
+
+
+def _apply_hierarchy_sparse(
+    activations: torch.Tensor,
+    sparse_data: _SparseHierarchyData,
+) -> torch.Tensor:
+    """
+    Apply hierarchy constraints using precomputed sparse indices.
+
+    Processes level by level:
+    1. Apply parent deactivation for features at this level
+    2. Apply mutual exclusion for groups whose parent is at this level
+    3. Move to next level
+
+    This ensures that ME at level L affects parent deactivation at level L+1.
+    """
+    result = activations.clone()
+
+    # Data is already on correct device from cache
+    me_group_siblings = sparse_data.me_group_siblings
+    me_group_sizes = sparse_data.me_group_sizes
+    me_group_parents = sparse_data.me_group_parents
+
+    for level_data in sparse_data.level_data:
+        # Step 1: Deactivate children where parent is inactive
+        if level_data.features.numel() > 0:
+            parent_vals = result[:, level_data.parents]
+            child_vals = result[:, level_data.features]
+            result[:, level_data.features] = child_vals * (parent_vals > 0)
+
+        # Step 2: Apply ME for groups whose parent is at this level
+        if level_data.me_group_indices.numel() > 0:
+            _apply_me_for_groups(
+                result,
+                level_data.me_group_indices,
+                me_group_siblings,
+                me_group_sizes,
+                me_group_parents,
+            )
+
+    return result
+
+
+def _apply_me_for_groups(
+    activations: torch.Tensor,
+    group_indices: torch.Tensor,
+    me_group_siblings: torch.Tensor,
+    me_group_sizes: torch.Tensor,
+    me_group_parents: torch.Tensor,
+) -> None:
+    """
+    Apply mutual exclusion for the specified groups.
+
+    Only processes groups where the parent is active (or has no parent).
+    This is a key optimization since most groups are skipped when parent is inactive.
+
+    Args:
+        activations: [batch_size, num_features] - modified in place
+        group_indices: [num_groups_to_process] - which groups to apply ME for
+        me_group_siblings: [total_groups, max_siblings] - sibling indices per group
+        me_group_sizes: [total_groups] - number of valid siblings per group
+        me_group_parents: [total_groups] - parent feature index (-1 if no parent)
+    """
+    batch_size = activations.shape[0]
+    device = activations.device
+    num_groups = group_indices.numel()
+
+    if num_groups == 0:
+        return
+
+    # Get parent indices for these groups
+    parents = me_group_parents[group_indices]  # [num_groups]
+
+    # Check which parents are active: [batch_size, num_groups]
+    # Groups with parent=-1 are always active (root-level ME)
+    has_parent = parents >= 0
+    if has_parent.all():
+        # All groups have parents - check their activation directly
+        parent_active = activations[:, parents] > 0  # [batch, num_groups]
+        if not parent_active.any():
+            return
+    elif has_parent.any():
+        # Mixed case: some groups have parents, some don't
+        # Use clamp to avoid indexing with -1 (reads feature 0, but result is masked out)
+        safe_parents = parents.clamp(min=0)
+        parent_active = activations[:, safe_parents] > 0  # [batch, num_groups]
+        # Groups without parent are always "active"
+        parent_active = parent_active | ~has_parent
+    else:
+        # No groups have parents - all are always active, skip parent check
+        parent_active = None
+
+    # Get siblings for the groups we're processing
+    siblings = me_group_siblings[group_indices]  # [num_groups, max_siblings]
+    sizes = me_group_sizes[group_indices]  # [num_groups]
+    max_siblings = siblings.shape[1]
+
+    # Get activations for all siblings: [batch_size, num_groups, max_siblings]
+    safe_siblings = siblings.clamp(min=0)
+    sibling_activations = activations[:, safe_siblings.view(-1)].view(
+        batch_size, num_groups, max_siblings
+    )
+
+    # Create validity mask for padding: [num_groups, max_siblings]
+    sibling_range = torch.arange(max_siblings, device=device)
+    valid_mask = sibling_range < sizes.unsqueeze(1)
+
+    # Find active valid siblings, but only where parent is active: [batch, groups, siblings]
+    sibling_active = (sibling_activations > 0) & valid_mask
+    if parent_active is not None:
+        sibling_active = sibling_active & parent_active.unsqueeze(2)
+
+    # Count active per group and check for conflicts: [batch_size, num_groups]
+    active_counts = sibling_active.sum(dim=2)
+    needs_exclusion = active_counts > 1
+
+    if not needs_exclusion.any():
+        return
+
+    # Get (batch, group) pairs needing exclusion
+    batch_with_conflict, groups_with_conflict = torch.where(needs_exclusion)
+    num_conflicts = batch_with_conflict.numel()
+
+    if num_conflicts == 0:
+        return
+
+    # Get siblings and activations for conflicts
+    conflict_siblings = siblings[groups_with_conflict]  # [num_conflicts, max_siblings]
+    conflict_active = sibling_active[
+        batch_with_conflict, groups_with_conflict
+    ]  # [num_conflicts, max_siblings]
+
+    # Random selection for winner
+    # Use -1e9 instead of -inf to avoid creating a tensor (torch.tensor(-float("inf")))
+    # on every call. Since random scores are in [0,1], -1e9 is effectively -inf for argmax.
+    random_scores = torch.rand(num_conflicts, max_siblings, device=device)
+    random_scores[~conflict_active] = -1e9
+
+    winner_idx = random_scores.argmax(dim=1)
+
+    # Determine losers using scatter for efficiency
+    is_winner = torch.zeros(
+        num_conflicts, max_siblings, dtype=torch.bool, device=device
+    )
+    is_winner.scatter_(1, winner_idx.unsqueeze(1), True)
+    should_deactivate = conflict_active & ~is_winner
+
+    # Get (conflict, sibling) pairs to deactivate
+    conflict_idx, sib_idx = torch.where(should_deactivate)
+
+    if conflict_idx.numel() == 0:
+        return
+
+    # Map back to (batch, feature) and deactivate
+    deact_batch = batch_with_conflict[conflict_idx]
+    deact_feat = conflict_siblings[conflict_idx, sib_idx]
+    activations[deact_batch, deact_feat] = 0
+
+
+@torch.no_grad()
+def hierarchy_modifier(
+    roots: Sequence[HierarchyNode] | HierarchyNode,
+) -> ActivationsModifier:
+    """
+    Create an activations modifier from one or more hierarchy trees.
+
+    This is the recommended way to use hierarchies with ActivationGenerator.
+    It validates the hierarchy structure and returns a modifier function that
+    applies all hierarchy constraints.
+
+    Args:
+        roots: One or more root HierarchyNode objects. Each root defines an
+            independent hierarchy tree. All trees are validated and applied.
+
+    Returns:
+        An ActivationsModifier function that can be passed to ActivationGenerator.
+
+    Raises:
+        ValueError: If any hierarchy contains loops, nodes with multiple
+            parents, or feature indices shared across trees.
+    """
+    if not roots:
+        # No hierarchies - return identity function
+        def identity(activations: torch.Tensor) -> torch.Tensor:
+            return activations
+
+        return identity
+
+    if isinstance(roots, HierarchyNode):
+        roots = [roots]
+    _validate_hierarchy(roots)
+
+    # Build sparse hierarchy data
+    sparse_data = _build_sparse_hierarchy(roots)
+
+    # Cache for device-specific tensors
+    device_cache: dict[torch.device, _SparseHierarchyData] = {}
+
+    def _get_sparse_for_device(device: torch.device) -> _SparseHierarchyData:
+        """Get or create device-specific sparse hierarchy data."""
+        if device not in device_cache:
+            device_cache[device] = _SparseHierarchyData(
+                level_data=[
+                    _LevelData(
+                        features=ld.features.to(device),
+                        parents=ld.parents.to(device),
+                        me_group_indices=ld.me_group_indices.to(device),
+                    )
+                    for ld in sparse_data.level_data
+                ],
+                me_group_siblings=sparse_data.me_group_siblings.to(device),
+                me_group_sizes=sparse_data.me_group_sizes.to(device),
+                me_group_parents=sparse_data.me_group_parents.to(device),
+                num_groups=sparse_data.num_groups,
+            )
+        return device_cache[device]
+
+    def modifier(activations: torch.Tensor) -> torch.Tensor:
+        device = activations.device
+        cached = _get_sparse_for_device(device)
+        return _apply_hierarchy_sparse(activations, cached)
+
+    return modifier
+
+
+class HierarchyNode:
+    """
+    Represents a node in a feature hierarchy tree.
+
+    Used to define hierarchical dependencies between features. Children are
+    deactivated when their parent is inactive, and children can optionally
+    be mutually exclusive.
+
+    Use `hierarchy_modifier()` to create an ActivationsModifier from one or
+    more HierarchyNode trees.
+
+    Attributes:
+        feature_index: Index of this feature in the activation tensor
+        children: Child HierarchyNode nodes
+        mutually_exclusive_children: If True, at most one child is active per sample
+        feature_id: Optional identifier for debugging
+    """
+
+    children: Sequence[HierarchyNode]
+    feature_index: int | None
+
+    @classmethod
+    def from_dict(cls, tree_dict: dict[str, Any]) -> HierarchyNode:
+        """
+        Create a HierarchyNode from a dictionary specification.
+
+        Args:
+            tree_dict: Dictionary with keys:
+
+                - feature_index (optional): Index in the activation tensor
+                - children (optional): List of child tree dictionaries
+                - mutually_exclusive_children (optional): Whether children are exclusive
+                - id (optional): Identifier for this node
+
+        Returns:
+            HierarchyNode instance
+        """
+        children = [
+            HierarchyNode.from_dict(child_dict)
+            for child_dict in tree_dict.get("children", [])
+        ]
+        return cls(
+            feature_index=tree_dict.get("feature_index"),
+            children=children,
+            mutually_exclusive_children=tree_dict.get(
+                "mutually_exclusive_children", False
+            ),
+            feature_id=tree_dict.get("id"),
+        )
+
+    def __init__(
+        self,
+        feature_index: int | None = None,
+        children: Sequence[HierarchyNode] | None = None,
+        mutually_exclusive_children: bool = False,
+        feature_id: str | None = None,
+    ):
+        """
+        Create a new HierarchyNode.
+
+        Args:
+            feature_index: Index of this feature in the activation tensor.
+                Use None for organizational nodes that don't correspond to a feature.
+            children: Child nodes that depend on this feature
+            mutually_exclusive_children: If True, only one child can be active per sample
+            feature_id: Optional identifier for debugging
+        """
+        self.feature_index = feature_index
+        self.children = children or []
+        self.mutually_exclusive_children = mutually_exclusive_children
+        self.feature_id = feature_id
+
+        if self.mutually_exclusive_children and len(self.children) < 2:
+            raise ValueError("Need at least 2 children for mutual exclusion")
+
+    def get_all_feature_indices(self) -> list[int]:
+        """Get all feature indices in this subtree."""
+        indices = []
+        if self.feature_index is not None:
+            indices.append(self.feature_index)
+        for child in self.children:
+            indices.extend(child.get_all_feature_indices())
+        return indices
+
+    def validate(self) -> None:
+        """
+        Validate the hierarchy structure.
+
+        Checks that:
+        1. There are no loops (no node is its own ancestor)
+        2. Each node has at most one parent (no node appears in multiple children lists)
+
+        Raises:
+            ValueError: If the hierarchy is invalid
+        """
+        _validate_hierarchy([self])
+
+    def __repr__(self, indent: int = 0) -> str:
+        s = " " * (indent * 2)
+        s += str(self.feature_index) if self.feature_index is not None else "-"
+        s += "x" if self.mutually_exclusive_children else " "
+        if self.feature_id:
+            s += f" ({self.feature_id})"
+
+        for child in self.children:
+            s += "\n" + child.__repr__(indent + 2)
+        return s
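
A minimal usage sketch for illustration (not part of the diff itself): it builds a two-level hierarchy in which feature 0 gates features 1 and 2, with mutually exclusive children, then applies the resulting modifier to a [batch, num_features] activation tensor, assuming the symbols are imported from sae_lens.synthetic.hierarchy as defined above.

import torch

from sae_lens.synthetic.hierarchy import HierarchyNode, hierarchy_modifier

# Feature 0 gates features 1 and 2; at most one child stays active per sample.
# The same tree could be built with HierarchyNode.from_dict({"feature_index": 0,
# "mutually_exclusive_children": True,
# "children": [{"feature_index": 1}, {"feature_index": 2}]}).
root = HierarchyNode(
    feature_index=0,
    children=[HierarchyNode(feature_index=1), HierarchyNode(feature_index=2)],
    mutually_exclusive_children=True,
)
modifier = hierarchy_modifier(root)

activations = torch.tensor(
    [
        [0.0, 1.0, 2.0],  # parent inactive -> both children are zeroed
        [1.0, 1.0, 2.0],  # parent active -> ME keeps one random child
    ]
)
out = modifier(activations)
assert out[0, 1] == 0 and out[0, 2] == 0
assert (out[1, 1] == 0).item() != (out[1, 2] == 0).item()
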
sae_lens/synthetic/initialization.py
@@ -0,0 +1,40 @@
+import torch
+
+from sae_lens.synthetic import FeatureDictionary
+
+
+@torch.no_grad()
+def init_sae_to_match_feature_dict(
+    sae: torch.nn.Module,
+    feature_dict: FeatureDictionary,
+    noise_level: float = 0.0,
+    feature_ordering: torch.Tensor | None = None,
+) -> None:
+    """
+    Initialize an SAE's weights to match a feature dictionary.
+
+    This can be useful for:
+
+    - Starting training from a known good initialization
+    - Testing SAE evaluation code with ground truth
+    - Ablation studies on initialization
+
+    Args:
+        sae: The SAE to initialize. Must have W_enc and W_dec attributes.
+        feature_dict: The feature dictionary to match
+        noise_level: Standard deviation of Gaussian noise to add (0 = exact match)
+        feature_ordering: Optional permutation of feature indices
+    """
+    features = feature_dict.feature_vectors  # [num_features, hidden_dim]
+    min_dim = min(sae.W_enc.shape[1], features.shape[0])  # type: ignore[attr-defined]
+
+    if feature_ordering is not None:
+        features = features[feature_ordering]
+
+    features = features[:min_dim]
+
+    # W_enc is [hidden_dim, d_sae], feature vectors are [num_features, hidden_dim]
+    sae.W_enc.data[:, :min_dim] = (  # type: ignore[index]
+        features.T + torch.randn_like(features.T) * noise_level
+    )
+    sae.W_dec.data = sae.W_enc.data.T.clone()  # type: ignore[union-attr]
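
A usage sketch for illustration (not part of the diff itself): init_sae_to_match_feature_dict only touches sae.W_enc, sae.W_dec, and feature_dict.feature_vectors, so hypothetical stand-ins carrying those attributes are enough to exercise it here; in practice a real SAE and a FeatureDictionary from sae_lens.synthetic would be passed.

import torch

from sae_lens.synthetic.initialization import init_sae_to_match_feature_dict


class ToySAE(torch.nn.Module):
    """Stand-in exposing the W_enc/W_dec attributes the function expects."""

    def __init__(self, d_in: int, d_sae: int):
        super().__init__()
        self.W_enc = torch.nn.Parameter(torch.zeros(d_in, d_sae))
        self.W_dec = torch.nn.Parameter(torch.zeros(d_sae, d_in))


class ToyFeatureDict:
    """Hypothetical stand-in carrying [num_features, hidden_dim] vectors."""

    def __init__(self, feature_vectors: torch.Tensor):
        self.feature_vectors = feature_vectors


sae = ToySAE(d_in=16, d_sae=8)
feature_dict = ToyFeatureDict(torch.randn(8, 16))

init_sae_to_match_feature_dict(sae, feature_dict, noise_level=0.0)  # type: ignore[arg-type]

# With noise_level=0, the tied decoder reproduces the feature vectors exactly.
assert torch.allclose(sae.W_dec.data, feature_dict.feature_vectors)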