sae-lens 6.26.1__py3-none-any.whl → 6.28.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +3 -1
- sae_lens/cache_activations_runner.py +12 -5
- sae_lens/config.py +2 -0
- sae_lens/loading/pretrained_sae_loaders.py +2 -1
- sae_lens/loading/pretrained_saes_directory.py +18 -0
- sae_lens/saes/gated_sae.py +1 -0
- sae_lens/saes/jumprelu_sae.py +3 -0
- sae_lens/saes/sae.py +13 -0
- sae_lens/saes/standard_sae.py +2 -0
- sae_lens/saes/temporal_sae.py +1 -0
- sae_lens/synthetic/__init__.py +89 -0
- sae_lens/synthetic/activation_generator.py +215 -0
- sae_lens/synthetic/correlation.py +170 -0
- sae_lens/synthetic/evals.py +141 -0
- sae_lens/synthetic/feature_dictionary.py +138 -0
- sae_lens/synthetic/firing_probabilities.py +104 -0
- sae_lens/synthetic/hierarchy.py +335 -0
- sae_lens/synthetic/initialization.py +40 -0
- sae_lens/synthetic/plotting.py +230 -0
- sae_lens/synthetic/training.py +145 -0
- sae_lens/tokenization_and_batching.py +1 -1
- sae_lens/training/activations_store.py +51 -91
- sae_lens/training/mixing_buffer.py +14 -5
- sae_lens/training/sae_trainer.py +1 -1
- sae_lens/util.py +26 -1
- {sae_lens-6.26.1.dist-info → sae_lens-6.28.1.dist-info}/METADATA +3 -1
- sae_lens-6.28.1.dist-info/RECORD +52 -0
- sae_lens-6.26.1.dist-info/RECORD +0 -42
- {sae_lens-6.26.1.dist-info → sae_lens-6.28.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.26.1.dist-info → sae_lens-6.28.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/synthetic/hierarchy.py
@@ -0,0 +1,335 @@
+"""
+Hierarchical feature modifier for activation generators.
+
+This module provides HierarchyNode, which enforces hierarchical dependencies
+on feature activations. Child features are deactivated when their parent is inactive,
+and children can optionally be mutually exclusive.
+
+Based on Noa Nabeshima's Matryoshka SAEs:
+https://github.com/noanabeshima/matryoshka-saes/blob/main/toy_model.py
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Sequence
+from typing import Any
+
+import torch
+
+ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
+
+
+def _validate_hierarchy(roots: Sequence[HierarchyNode]) -> None:
+    """
+    Validate a forest of hierarchy trees.
+
+    Treats the input as children of a virtual root node and validates the
+    entire structure.
+
+    Checks that:
+    1. There are no loops (no node is its own ancestor)
+    2. Each node has at most one parent (no node appears in multiple children lists)
+    3. No feature index appears in multiple trees
+
+    Args:
+        roots: Root nodes of the hierarchy trees to validate
+
+    Raises:
+        ValueError: If the hierarchy is invalid
+    """
+    if not roots:
+        return
+
+    # Collect all nodes and check for loops, treating roots as children of virtual root
+    all_nodes: list[HierarchyNode] = []
+    virtual_root_id = id(roots)  # Use the list itself as virtual root identity
+
+    for root in roots:
+        all_nodes.append(root)
+        _collect_nodes_and_check_loops(root, all_nodes, ancestors={virtual_root_id})
+
+    # Check for multiple parents (same node appearing multiple times)
+    seen_ids: set[int] = set()
+    for node in all_nodes:
+        node_id = id(node)
+        if node_id in seen_ids:
+            node_desc = _node_description(node)
+            raise ValueError(
+                f"Node ({node_desc}) has multiple parents. "
+                "Each node must have at most one parent."
+            )
+        seen_ids.add(node_id)
+
+    # Check for overlapping feature indices across trees
+    if len(roots) > 1:
+        all_indices: set[int] = set()
+        for root in roots:
+            tree_indices = root.get_all_feature_indices()
+            overlap = all_indices & set(tree_indices)
+            if overlap:
+                raise ValueError(
+                    f"Feature indices {overlap} appear in multiple hierarchy trees. "
+                    "Each feature should belong to at most one hierarchy."
+                )
+            all_indices.update(tree_indices)
+
+
+def _collect_nodes_and_check_loops(
+    node: HierarchyNode,
+    all_nodes: list[HierarchyNode],
+    ancestors: set[int],
+) -> None:
+    """Recursively collect nodes and check for loops."""
+    node_id = id(node)
+
+    if node_id in ancestors:
+        node_desc = _node_description(node)
+        raise ValueError(f"Loop detected: node ({node_desc}) is its own ancestor.")
+
+    # Add to ancestors for children traversal
+    new_ancestors = ancestors | {node_id}
+
+    for child in node.children:
+        # Collect child (before recursing, so we can detect multiple parents)
+        all_nodes.append(child)
+        _collect_nodes_and_check_loops(child, all_nodes, new_ancestors)
+
+
+def _node_description(node: HierarchyNode) -> str:
+    """Get a human-readable description of a node for error messages."""
+    if node.feature_index is not None:
+        return f"feature_index={node.feature_index}"
+    if node.feature_id:
+        return f"id={node.feature_id}"
+    return "unnamed node"
+
+
+def hierarchy_modifier(
+    roots: Sequence[HierarchyNode] | HierarchyNode,
+) -> ActivationsModifier:
+    """
+    Create an activations modifier from one or more hierarchy trees.
+
+    This is the recommended way to use hierarchies with ActivationGenerator.
+    It validates the hierarchy structure and returns a modifier function that
+    applies all hierarchy constraints.
+
+    Args:
+        roots: One or more root HierarchyNode objects. Each root defines an
+            independent hierarchy tree. All trees are validated and applied.
+
+    Returns:
+        An ActivationsModifier function that can be passed to ActivationGenerator.
+
+    Raises:
+        ValueError: If any hierarchy contains loops, nodes with multiple
+            parents, or feature indices shared across trees.
+    """
+    if not roots:
+        # No hierarchies - return identity function
+        def identity(activations: torch.Tensor) -> torch.Tensor:
+            return activations
+
+        return identity
+
+    if isinstance(roots, HierarchyNode):
+        roots = [roots]
+    _validate_hierarchy(roots)
+
+    # Create modifier function that applies all hierarchies
+    def modifier(activations: torch.Tensor) -> torch.Tensor:
+        result = activations.clone()
+        for root in roots:
+            root._apply_hierarchy(result, parent_active_mask=None)
+        return result
+
+    return modifier
+
+
+class HierarchyNode:
+    """
+    Represents a node in a feature hierarchy tree.
+
+    Used to define hierarchical dependencies between features. Children are
+    deactivated when their parent is inactive, and children can optionally
+    be mutually exclusive.
+
+    Use `hierarchy_modifier()` to create an ActivationsModifier from one or
+    more HierarchyNode trees.
+
+    Attributes:
+        feature_index: Index of this feature in the activation tensor
+        children: Child HierarchyNode nodes
+        mutually_exclusive_children: If True, at most one child is active per sample
+        feature_id: Optional identifier for debugging
+    """
+
+    children: Sequence[HierarchyNode]
+    feature_index: int | None
+
+    @classmethod
+    def from_dict(cls, tree_dict: dict[str, Any]) -> HierarchyNode:
+        """
+        Create a HierarchyNode from a dictionary specification.
+
+        Args:
+            tree_dict: Dictionary with keys:
+
+                - feature_index (optional): Index in the activation tensor
+                - children (optional): List of child tree dictionaries
+                - mutually_exclusive_children (optional): Whether children are exclusive
+                - id (optional): Identifier for this node
+
+        Returns:
+            HierarchyNode instance
+        """
+        children = [
+            HierarchyNode.from_dict(child_dict)
+            for child_dict in tree_dict.get("children", [])
+        ]
+        return cls(
+            feature_index=tree_dict.get("feature_index"),
+            children=children,
+            mutually_exclusive_children=tree_dict.get(
+                "mutually_exclusive_children", False
+            ),
+            feature_id=tree_dict.get("id"),
+        )
+
+    def __init__(
+        self,
+        feature_index: int | None = None,
+        children: Sequence[HierarchyNode] | None = None,
+        mutually_exclusive_children: bool = False,
+        feature_id: str | None = None,
+    ):
+        """
+        Create a new HierarchyNode.
+
+        Args:
+            feature_index: Index of this feature in the activation tensor.
+                Use None for organizational nodes that don't correspond to a feature.
+            children: Child nodes that depend on this feature
+            mutually_exclusive_children: If True, only one child can be active per sample
+            feature_id: Optional identifier for debugging
+        """
+        self.feature_index = feature_index
+        self.children = children or []
+        self.mutually_exclusive_children = mutually_exclusive_children
+        self.feature_id = feature_id
+
+        if self.mutually_exclusive_children and len(self.children) < 2:
+            raise ValueError("Need at least 2 children for mutual exclusion")
+
+    def _apply_hierarchy(
+        self,
+        activations: torch.Tensor,
+        parent_active_mask: torch.Tensor | None,
+    ) -> None:
+        """Recursively apply hierarchical constraints."""
+        batch_size = activations.shape[0]
+
+        # Determine which samples have this node active
+        if self.feature_index is not None:
+            is_active = activations[:, self.feature_index] > 0
+        else:
+            # Non-readout node: active if parent is active (or always if root)
+            is_active = (
+                parent_active_mask
+                if parent_active_mask is not None
+                else torch.ones(batch_size, dtype=torch.bool, device=activations.device)
+            )
+
+        # Deactivate this node if parent is inactive
+        if parent_active_mask is not None and self.feature_index is not None:
+            activations[~parent_active_mask, self.feature_index] = 0
+            # Update is_active after deactivation
+            is_active = activations[:, self.feature_index] > 0
+
+        # Handle mutually exclusive children
+        if self.mutually_exclusive_children and len(self.children) >= 2:
+            self._enforce_mutual_exclusion(activations, is_active)
+
+        # Recursively process children
+        for child in self.children:
+            child._apply_hierarchy(activations, parent_active_mask=is_active)
+
+    def _enforce_mutual_exclusion(
+        self,
+        activations: torch.Tensor,
+        parent_active_mask: torch.Tensor,
+    ) -> None:
+        """Ensure at most one child is active per sample."""
+        batch_size = activations.shape[0]
+
+        # Get indices of children that have feature indices
+        child_indices = [
+            child.feature_index
+            for child in self.children
+            if child.feature_index is not None
+        ]
+
+        if len(child_indices) < 2:
+            return
+
+        # For each sample where parent is active, enforce mutual exclusion.
+        # Note: This loop is not vectorized because we need to randomly select
+        # which child to keep active per sample. Vectorizing would require either
+        # a deterministic selection (losing randomness) or complex gather/scatter
+        # operations that aren't more efficient for typical batch sizes.
+        for batch_idx in range(batch_size):
+            if not parent_active_mask[batch_idx]:
+                continue
+
+            # Find which children are active
+            active_children = [
+                i
+                for i, feat_idx in enumerate(child_indices)
+                if activations[batch_idx, feat_idx] > 0
+            ]
+
+            if len(active_children) <= 1:
+                continue
+
+            # Randomly select one to keep active
+            random_idx = int(torch.randint(len(active_children), (1,)).item())
+            keep_idx = active_children[random_idx]
+
+            # Deactivate all others
+            for i, feat_idx in enumerate(child_indices):
+                if i != keep_idx and i in active_children:
+                    activations[batch_idx, feat_idx] = 0
+
+    def get_all_feature_indices(self) -> list[int]:
+        """Get all feature indices in this subtree."""
+        indices = []
+        if self.feature_index is not None:
+            indices.append(self.feature_index)
+        for child in self.children:
+            indices.extend(child.get_all_feature_indices())
+        return indices
+
+    def validate(self) -> None:
+        """
+        Validate the hierarchy structure.
+
+        Checks that:
+        1. There are no loops (no node is its own ancestor)
+        2. Each node has at most one parent (no node appears in multiple children lists)
+
+        Raises:
+            ValueError: If the hierarchy is invalid
+        """
+        _validate_hierarchy([self])
+
+    def __repr__(self, indent: int = 0) -> str:
+        s = " " * (indent * 2)
+        s += str(self.feature_index) if self.feature_index is not None else "-"
+        s += "x" if self.mutually_exclusive_children else " "
+        if self.feature_id:
+            s += f" ({self.feature_id})"
+
+        for child in self.children:
+            s += "\n" + child.__repr__(indent + 2)
+        return s
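
A quick usage sketch of the new hierarchy API (illustrative, not part of the diff; the import path follows the file layout above, and the tensor values are made up):

    import torch

    from sae_lens.synthetic.hierarchy import HierarchyNode, hierarchy_modifier

    # Parent feature 0 gates children 1 and 2; at most one child may fire.
    root = HierarchyNode(
        feature_index=0,
        children=[HierarchyNode(feature_index=1), HierarchyNode(feature_index=2)],
        mutually_exclusive_children=True,
    )
    modifier = hierarchy_modifier(root)

    # Batch of 2 samples over 3 features. Sample 0 has an inactive parent,
    # so both children are zeroed; sample 1 has both children active, so
    # one of the two is randomly kept.
    acts = torch.tensor([[0.0, 1.0, 2.0], [1.0, 1.0, 2.0]])
    out = modifier(acts)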
sae_lens/synthetic/initialization.py
@@ -0,0 +1,40 @@
+import torch
+
+from sae_lens.synthetic import FeatureDictionary
+
+
+@torch.no_grad()
+def init_sae_to_match_feature_dict(
+    sae: torch.nn.Module,
+    feature_dict: FeatureDictionary,
+    noise_level: float = 0.0,
+    feature_ordering: torch.Tensor | None = None,
+) -> None:
+    """
+    Initialize an SAE's weights to match a feature dictionary.
+
+    This can be useful for:
+
+    - Starting training from a known good initialization
+    - Testing SAE evaluation code with ground truth
+    - Ablation studies on initialization
+
+    Args:
+        sae: The SAE to initialize. Must have W_enc and W_dec attributes.
+        feature_dict: The feature dictionary to match
+        noise_level: Standard deviation of Gaussian noise to add (0 = exact match)
+        feature_ordering: Optional permutation of feature indices
+    """
+    features = feature_dict.feature_vectors  # [num_features, hidden_dim]
+    min_dim = min(sae.W_enc.shape[1], features.shape[0])  # type: ignore[attr-defined]
+
+    if feature_ordering is not None:
+        features = features[feature_ordering]
+
+    features = features[:min_dim]
+
+    # W_enc is [hidden_dim, d_sae], feature vectors are [num_features, hidden_dim]
+    sae.W_enc.data[:, :min_dim] = (  # type: ignore[index]
+        features.T + torch.randn_like(features.T) * noise_level
+    )
+    sae.W_dec.data = sae.W_enc.data.T.clone()  # type: ignore[union-attr]
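
A minimal sketch of calling the new initializer (illustrative only: the real FeatureDictionary constructor is not shown in this diff, so a stand-in object exposing `feature_vectors` is used, along with a bare module carrying W_enc/W_dec in the expected shapes):

    import torch

    from sae_lens.synthetic.initialization import init_sae_to_match_feature_dict

    class FakeFeatureDict:  # hypothetical stand-in for FeatureDictionary
        def __init__(self, num_features: int, hidden_dim: int):
            self.feature_vectors = torch.randn(num_features, hidden_dim)

    class TinySAE(torch.nn.Module):  # any module with W_enc/W_dec works
        def __init__(self, hidden_dim: int, d_sae: int):
            super().__init__()
            self.W_enc = torch.nn.Parameter(torch.zeros(hidden_dim, d_sae))
            self.W_dec = torch.nn.Parameter(torch.zeros(d_sae, hidden_dim))

    sae = TinySAE(hidden_dim=16, d_sae=8)
    feature_dict = FakeFeatureDict(num_features=8, hidden_dim=16)
    init_sae_to_match_feature_dict(sae, feature_dict, noise_level=0.01)
    # W_enc columns now match the (noised) feature vectors; W_dec is set to W_enc.T.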
sae_lens/synthetic/plotting.py
@@ -0,0 +1,230 @@
+"""
+Plotting utilities for visualizing SAE training on synthetic data.
+
+This module provides functions for:
+
+- Plotting cosine similarities between SAE features and true features
+- Automatically reordering features for better visualization
+- Creating comparison plots between encoder and decoder
+"""
+
+from collections.abc import Iterable
+from pathlib import Path
+from typing import Any
+
+import plotly.graph_objects as go
+import torch
+from plotly.subplots import make_subplots
+
+from sae_lens.saes import SAE
+from sae_lens.synthetic.feature_dictionary import FeatureDictionary
+from sae_lens.util import cosine_similarities
+
+
+def find_best_feature_ordering(
+    sae_features: torch.Tensor,
+    true_features: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Find the best ordering of SAE features to match true features.
+
+    Reorders SAE features so that each SAE latent aligns with its best-matching
+    true feature in order. This makes cosine similarity plots more interpretable.
+
+    Args:
+        sae_features: SAE decoder weights of shape [d_sae, hidden_dim]
+        true_features: True feature vectors of shape [num_features, hidden_dim]
+
+    Returns:
+        Tensor of indices that reorders sae_features for best alignment
+    """
+    cos_sims = cosine_similarities(sae_features, true_features)
+    best_matches = torch.argmax(torch.abs(cos_sims), dim=1)
+    return torch.argsort(best_matches)
+
+
+def find_best_feature_ordering_from_sae(
+    sae: torch.nn.Module,
+    feature_dict: FeatureDictionary,
+) -> torch.Tensor:
+    """
+    Find the best feature ordering for an SAE given a feature dictionary.
+
+    Args:
+        sae: SAE with W_dec attribute of shape [d_sae, hidden_dim]
+        feature_dict: The feature dictionary containing true features
+
+    Returns:
+        Tensor of indices that reorders SAE latents for best alignment
+    """
+    sae_features = sae.W_dec.detach()  # type: ignore[attr-defined]
+    true_features = feature_dict.feature_vectors.detach()
+    return find_best_feature_ordering(sae_features, true_features)
+
+
+def find_best_feature_ordering_across_saes(
+    saes: Iterable[torch.nn.Module],
+    feature_dict: FeatureDictionary,
+) -> torch.Tensor:
+    """
+    Find the best feature ordering that works across multiple SAEs.
+
+    Useful for creating consistent orderings across training snapshots.
+
+    Args:
+        saes: Iterable of SAEs to consider
+        feature_dict: The feature dictionary containing true features
+
+    Returns:
+        The best ordering tensor found across all SAEs
+    """
+    best_score = float("-inf")
+    best_ordering: torch.Tensor | None = None
+
+    true_features = feature_dict.feature_vectors.detach()
+
+    for sae in saes:
+        sae_features = sae.W_dec.detach()  # type: ignore[attr-defined]
+        cos_sims = cosine_similarities(sae_features, true_features)
+        cos_sims = torch.round(cos_sims * 100) / 100  # Reduce numerical noise
+
+        ordering = find_best_feature_ordering(sae_features, true_features)
+        score = cos_sims[ordering, torch.arange(cos_sims.shape[1])].mean().item()
+
+        if score > best_score:
+            best_score = score
+            best_ordering = ordering
+
+    if best_ordering is None:
+        raise ValueError("No SAEs provided")
+
+    return best_ordering
+
+
+def plot_sae_feature_similarity(
+    sae: SAE[Any],
+    feature_dict: FeatureDictionary,
+    title: str | None = None,
+    reorder_features: bool | torch.Tensor = False,
+    decoder_only: bool = False,
+    show_values: bool = False,
+    height: int = 400,
+    width: int = 800,
+    save_path: str | Path | None = None,
+    show_plot: bool = True,
+    dtick: int | None = 1,
+    scale: float = 1.0,
+):
+    """
+    Plot cosine similarities between SAE features and true features.
+
+    Creates a heatmap showing how well each SAE latent aligns with each
+    true feature. Useful for understanding what the SAE has learned.
+
+    Args:
+        sae: The SAE to visualize. Must have W_enc and W_dec attributes.
+        feature_dict: The feature dictionary containing true features
+        title: Plot title. If None, a default title is used.
+        reorder_features: If True, automatically reorders features for best alignment.
+            If a tensor, uses that as the ordering.
+        decoder_only: If True, only plots the decoder (not encoder and decoder side-by-side)
+        show_values: If True, shows numeric values on the heatmap
+        height: Height of the figure in pixels
+        width: Width of the figure in pixels
+        save_path: If provided, saves the figure to this path
+        show_plot: If True, displays the plot
+        dtick: Tick spacing for axes
+        scale: Scale factor for image resolution when saving
+    """
+    # Get cosine similarities
+    true_features = feature_dict.feature_vectors.detach()
+    dec_cos_sims = cosine_similarities(sae.W_dec.detach(), true_features)  # type: ignore[attr-defined]
+    enc_cos_sims = cosine_similarities(sae.W_enc.T.detach(), true_features)  # type: ignore[attr-defined]
+
+    # Round to reduce numerical noise
+    dec_cos_sims = torch.round(dec_cos_sims * 100) / 100
+    enc_cos_sims = torch.round(enc_cos_sims * 100) / 100
+
+    # Apply feature reordering if requested
+    if reorder_features is not False:
+        if isinstance(reorder_features, bool):
+            sorted_indices = find_best_feature_ordering(
+                sae.W_dec.detach(),  # type: ignore[attr-defined]
+                true_features,
+            )
+        else:
+            sorted_indices = reorder_features
+        dec_cos_sims = dec_cos_sims[sorted_indices]
+        enc_cos_sims = enc_cos_sims[sorted_indices]
+
+    hovertemplate = "True feature: %{x}<br>SAE Latent: %{y}<br>Cosine Similarity: %{z:.3f}<extra></extra>"
+
+    if decoder_only:
+        fig = make_subplots(rows=1, cols=1)
+
+        decoder_args: dict[str, Any] = {
+            "z": dec_cos_sims.cpu().numpy(),
+            "zmin": -1,
+            "zmax": 1,
+            "colorscale": "RdBu",
+            "colorbar": dict(title="cos sim", x=1.0, dtick=1, tickvals=[-1, 0, 1]),
+            "hovertemplate": hovertemplate,
+        }
+        if show_values:
+            decoder_args["texttemplate"] = "%{z:.2f}"
+            decoder_args["textfont"] = {"size": 10}
+
+        fig.add_trace(go.Heatmap(**decoder_args), row=1, col=1)
+        fig.update_xaxes(title_text="True feature", row=1, col=1, dtick=dtick)
+        fig.update_yaxes(title_text="SAE Latent", row=1, col=1, dtick=dtick)
+    else:
+        fig = make_subplots(
+            rows=1, cols=2, subplot_titles=("SAE encoder", "SAE decoder")
+        )
+
+        # Encoder heatmap
+        encoder_args: dict[str, Any] = {
+            "z": enc_cos_sims.cpu().numpy(),
+            "zmin": -1,
+            "zmax": 1,
+            "colorscale": "RdBu",
+            "showscale": False,
+            "hovertemplate": hovertemplate,
+        }
+        if show_values:
+            encoder_args["texttemplate"] = "%{z:.2f}"
+            encoder_args["textfont"] = {"size": 10}
+
+        fig.add_trace(go.Heatmap(**encoder_args), row=1, col=1)
+
+        # Decoder heatmap
+        decoder_args = {
+            "z": dec_cos_sims.cpu().numpy(),
+            "zmin": -1,
+            "zmax": 1,
+            "colorscale": "RdBu",
+            "colorbar": dict(title="cos sim", x=1.0, dtick=1, tickvals=[-1, 0, 1]),
+            "hovertemplate": hovertemplate,
+        }
+        if show_values:
+            decoder_args["texttemplate"] = "%{z:.2f}"
+            decoder_args["textfont"] = {"size": 10}
+
+        fig.add_trace(go.Heatmap(**decoder_args), row=1, col=2)
+
+        fig.update_xaxes(title_text="True feature", row=1, col=1, dtick=dtick)
+        fig.update_xaxes(title_text="True feature", row=1, col=2, dtick=dtick)
+        fig.update_yaxes(title_text="SAE Latent", row=1, col=1, dtick=dtick)
+        fig.update_yaxes(title_text="SAE Latent", row=1, col=2, dtick=dtick)
+
+    # Set main title
+    if title is None:
+        title = "Cosine similarity with true features"
+    fig.update_layout(height=height, width=width, title_text=title)
+
+    if save_path:
+        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
+        fig.write_image(save_path, scale=scale)
+
+    if show_plot:
+        fig.show()
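
For completeness, a small sketch of the reordering helper on plain tensors (illustrative; the values are made up):

    import torch

    from sae_lens.synthetic.plotting import find_best_feature_ordering

    # Decoder rows are noisy copies of the true features, deliberately swapped.
    true_features = torch.eye(2, 4)                  # [num_features, hidden_dim]
    sae_features = true_features[[1, 0]] + 0.01 * torch.randn(2, 4)

    ordering = find_best_feature_ordering(sae_features, true_features)
    # sae_features[ordering] now aligns latent i with true feature i; the same
    # tensor can be passed as reorder_features= to plot_sae_feature_similarity
    # for a near-diagonal heatmap.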