sae-lens 6.26.0__py3-none-any.whl → 6.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,104 @@
+ """
+ Helper functions for generating firing probability distributions.
+ """
+
+ import torch
+
+
+ def zipfian_firing_probabilities(
+     num_features: int,
+     exponent: float = 1.0,
+     max_prob: float = 0.3,
+     min_prob: float = 0.01,
+ ) -> torch.Tensor:
+     """
+     Generate firing probabilities following a Zipfian (power-law) distribution.
+
+     Creates probabilities where a few features fire frequently and most fire rarely,
+     which mirrors the distribution often observed in real neural network features.
+
+     Args:
+         num_features: Number of features to generate probabilities for
+         exponent: Zipf exponent (higher = steeper dropoff). Default 1.0.
+         max_prob: Maximum firing probability (for the most frequent feature)
+         min_prob: Minimum firing probability (for the least frequent feature)
+
+     Returns:
+         Tensor of shape [num_features] with firing probabilities in descending order
+     """
+     if num_features < 1:
+         raise ValueError("num_features must be at least 1")
+     if exponent <= 0:
+         raise ValueError("exponent must be positive")
+     if not 0 < min_prob < max_prob <= 1:
+         raise ValueError("Must have 0 < min_prob < max_prob <= 1")
+
+     ranks = torch.arange(1, num_features + 1, dtype=torch.float32)
+     probs = 1.0 / ranks**exponent
+
+     # Scale to [min_prob, max_prob]
+     if num_features == 1:
+         return torch.tensor([max_prob])
+
+     probs_min, probs_max = probs.min(), probs.max()
+     return min_prob + (max_prob - min_prob) * (probs - probs_min) / (
+         probs_max - probs_min
+     )
+
+
+ def linear_firing_probabilities(
+     num_features: int,
+     max_prob: float = 0.3,
+     min_prob: float = 0.01,
+ ) -> torch.Tensor:
+     """
+     Generate firing probabilities that decay linearly from max to min.
+
+     Args:
+         num_features: Number of features to generate probabilities for
+         max_prob: Firing probability for the first feature
+         min_prob: Firing probability for the last feature
+
+     Returns:
+         Tensor of shape [num_features] with linearly decaying probabilities
+     """
+     if num_features < 1:
+         raise ValueError("num_features must be at least 1")
+     if not 0 < min_prob <= max_prob <= 1:
+         raise ValueError("Must have 0 < min_prob <= max_prob <= 1")
+
+     if num_features == 1:
+         return torch.tensor([max_prob])
+
+     return torch.linspace(max_prob, min_prob, num_features)
+
+
+ def random_firing_probabilities(
+     num_features: int,
+     max_prob: float = 0.5,
+     min_prob: float = 0.01,
+     seed: int | None = None,
+ ) -> torch.Tensor:
+     """
+     Generate random firing probabilities uniformly sampled from a range.
+
+     Args:
+         num_features: Number of features to generate probabilities for
+         max_prob: Maximum firing probability
+         min_prob: Minimum firing probability
+         seed: Optional random seed for reproducibility
+
+     Returns:
+         Tensor of shape [num_features] with random firing probabilities
+     """
+     if num_features < 1:
+         raise ValueError("num_features must be at least 1")
+     if not 0 < min_prob < max_prob <= 1:
+         raise ValueError("Must have 0 < min_prob < max_prob <= 1")
+
+     generator = torch.Generator()
+     if seed is not None:
+         generator.manual_seed(seed)
+
+     probs = torch.rand(num_features, generator=generator, dtype=torch.float32)
+     return min_prob + (max_prob - min_prob) * probs
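For orientation, here is a minimal usage sketch of the generators above. The diff shows file contents but not file paths, so the import location is an assumption; the returned per-feature probabilities can be compared against uniform noise to Bernoulli-sample which features fire in each batch row:

    import torch

    # Assumed import location; the diff does not name this module.
    from sae_lens.synthetic import zipfian_firing_probabilities

    probs = zipfian_firing_probabilities(num_features=1000, exponent=1.2)
    # probs has shape [1000], descending from max_prob down to min_prob

    # Bernoulli-sample a [batch, num_features] boolean firing mask
    firing_mask = torch.rand(64, 1000) < probs

linear_firing_probabilities and random_firing_probabilities are drop-in alternatives with the same shape contract.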
@@ -0,0 +1,335 @@
+ """
+ Hierarchical feature modifier for activation generators.
+
+ This module provides HierarchyNode, which enforces hierarchical dependencies
+ on feature activations. Child features are deactivated when their parent is inactive,
+ and children can optionally be mutually exclusive.
+
+ Based on Noa Nabeshima's Matryoshka SAEs:
+ https://github.com/noanabeshima/matryoshka-saes/blob/main/toy_model.py
+ """
+
+ from __future__ import annotations
+
+ from collections.abc import Callable, Sequence
+ from typing import Any
+
+ import torch
+
+ ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
+
+
+ def _validate_hierarchy(roots: Sequence[HierarchyNode]) -> None:
+     """
+     Validate a forest of hierarchy trees.
+
+     Treats the input as children of a virtual root node and validates the
+     entire structure.
+
+     Checks that:
+     1. There are no loops (no node is its own ancestor)
+     2. Each node has at most one parent (no node appears in multiple children lists)
+     3. No feature index appears in multiple trees
+
+     Args:
+         roots: Root nodes of the hierarchy trees to validate
+
+     Raises:
+         ValueError: If the hierarchy is invalid
+     """
+     if not roots:
+         return
+
+     # Collect all nodes and check for loops, treating roots as children of virtual root
+     all_nodes: list[HierarchyNode] = []
+     virtual_root_id = id(roots)  # Use the list itself as virtual root identity
+
+     for root in roots:
+         all_nodes.append(root)
+         _collect_nodes_and_check_loops(root, all_nodes, ancestors={virtual_root_id})
+
+     # Check for multiple parents (same node appearing multiple times)
+     seen_ids: set[int] = set()
+     for node in all_nodes:
+         node_id = id(node)
+         if node_id in seen_ids:
+             node_desc = _node_description(node)
+             raise ValueError(
+                 f"Node ({node_desc}) has multiple parents. "
+                 "Each node must have at most one parent."
+             )
+         seen_ids.add(node_id)
+
+     # Check for overlapping feature indices across trees
+     if len(roots) > 1:
+         all_indices: set[int] = set()
+         for root in roots:
+             tree_indices = root.get_all_feature_indices()
+             overlap = all_indices & set(tree_indices)
+             if overlap:
+                 raise ValueError(
+                     f"Feature indices {overlap} appear in multiple hierarchy trees. "
+                     "Each feature should belong to at most one hierarchy."
+                 )
+             all_indices.update(tree_indices)
+
+
+ def _collect_nodes_and_check_loops(
+     node: HierarchyNode,
+     all_nodes: list[HierarchyNode],
+     ancestors: set[int],
+ ) -> None:
+     """Recursively collect nodes and check for loops."""
+     node_id = id(node)
+
+     if node_id in ancestors:
+         node_desc = _node_description(node)
+         raise ValueError(f"Loop detected: node ({node_desc}) is its own ancestor.")
+
+     # Add to ancestors for children traversal
+     new_ancestors = ancestors | {node_id}
+
+     for child in node.children:
+         # Collect child (before recursing, so we can detect multiple parents)
+         all_nodes.append(child)
+         _collect_nodes_and_check_loops(child, all_nodes, new_ancestors)
+
+
+ def _node_description(node: HierarchyNode) -> str:
+     """Get a human-readable description of a node for error messages."""
+     if node.feature_index is not None:
+         return f"feature_index={node.feature_index}"
+     if node.feature_id:
+         return f"id={node.feature_id}"
+     return "unnamed node"
+
+
+ def hierarchy_modifier(
+     roots: Sequence[HierarchyNode] | HierarchyNode,
+ ) -> ActivationsModifier:
+     """
+     Create an activations modifier from one or more hierarchy trees.
+
+     This is the recommended way to use hierarchies with ActivationGenerator.
+     It validates the hierarchy structure and returns a modifier function that
+     applies all hierarchy constraints.
+
+     Args:
+         roots: One or more root HierarchyNode objects. Each root defines an
+             independent hierarchy tree. All trees are validated and applied.
+
+     Returns:
+         An ActivationsModifier function that can be passed to ActivationGenerator.
+
+     Raises:
+         ValueError: If any hierarchy contains loops or nodes with multiple parents.
+     """
+     if not roots:
+         # No hierarchies - return identity function
+         def identity(activations: torch.Tensor) -> torch.Tensor:
+             return activations
+
+         return identity
+
+     if isinstance(roots, HierarchyNode):
+         roots = [roots]
+     _validate_hierarchy(roots)
+
+     # Create modifier function that applies all hierarchies
+     def modifier(activations: torch.Tensor) -> torch.Tensor:
+         result = activations.clone()
+         for root in roots:
+             root._apply_hierarchy(result, parent_active_mask=None)
+         return result
+
+     return modifier
+
+
+ class HierarchyNode:
+     """
+     Represents a node in a feature hierarchy tree.
+
+     Used to define hierarchical dependencies between features. Children are
+     deactivated when their parent is inactive, and children can optionally
+     be mutually exclusive.
+
+     Use `hierarchy_modifier()` to create an ActivationsModifier from one or
+     more HierarchyNode trees.
+
+     Attributes:
+         feature_index: Index of this feature in the activation tensor
+         children: Child HierarchyNode nodes
+         mutually_exclusive_children: If True, at most one child is active per sample
+         feature_id: Optional identifier for debugging
+     """
+
+     children: Sequence[HierarchyNode]
+     feature_index: int | None
+
+     @classmethod
+     def from_dict(cls, tree_dict: dict[str, Any]) -> HierarchyNode:
+         """
+         Create a HierarchyNode from a dictionary specification.
+
+         Args:
+             tree_dict: Dictionary with keys:
+
+             - feature_index (optional): Index in the activation tensor
+             - children (optional): List of child tree dictionaries
+             - mutually_exclusive_children (optional): Whether children are exclusive
+             - id (optional): Identifier for this node
+
+         Returns:
+             HierarchyNode instance
+         """
+         children = [
+             HierarchyNode.from_dict(child_dict)
+             for child_dict in tree_dict.get("children", [])
+         ]
+         return cls(
+             feature_index=tree_dict.get("feature_index"),
+             children=children,
+             mutually_exclusive_children=tree_dict.get(
+                 "mutually_exclusive_children", False
+             ),
+             feature_id=tree_dict.get("id"),
+         )
+
+     def __init__(
+         self,
+         feature_index: int | None = None,
+         children: Sequence[HierarchyNode] | None = None,
+         mutually_exclusive_children: bool = False,
+         feature_id: str | None = None,
+     ):
+         """
+         Create a new HierarchyNode.
+
+         Args:
+             feature_index: Index of this feature in the activation tensor.
+                 Use None for organizational nodes that don't correspond to a feature.
+             children: Child nodes that depend on this feature
+             mutually_exclusive_children: If True, only one child can be active per sample
+             feature_id: Optional identifier for debugging
+         """
+         self.feature_index = feature_index
+         self.children = children or []
+         self.mutually_exclusive_children = mutually_exclusive_children
+         self.feature_id = feature_id
+
+         if self.mutually_exclusive_children and len(self.children) < 2:
+             raise ValueError("Need at least 2 children for mutual exclusion")
+
+     def _apply_hierarchy(
+         self,
+         activations: torch.Tensor,
+         parent_active_mask: torch.Tensor | None,
+     ) -> None:
+         """Recursively apply hierarchical constraints."""
+         batch_size = activations.shape[0]
+
+         # Determine which samples have this node active
+         if self.feature_index is not None:
+             is_active = activations[:, self.feature_index] > 0
+         else:
+             # Non-readout node: active if parent is active (or always if root)
+             is_active = (
+                 parent_active_mask
+                 if parent_active_mask is not None
+                 else torch.ones(batch_size, dtype=torch.bool, device=activations.device)
+             )
+
+         # Deactivate this node if parent is inactive
+         if parent_active_mask is not None and self.feature_index is not None:
+             activations[~parent_active_mask, self.feature_index] = 0
+             # Update is_active after deactivation
+             is_active = activations[:, self.feature_index] > 0
+
+         # Handle mutually exclusive children
+         if self.mutually_exclusive_children and len(self.children) >= 2:
+             self._enforce_mutual_exclusion(activations, is_active)
+
+         # Recursively process children
+         for child in self.children:
+             child._apply_hierarchy(activations, parent_active_mask=is_active)
+
+     def _enforce_mutual_exclusion(
+         self,
+         activations: torch.Tensor,
+         parent_active_mask: torch.Tensor,
+     ) -> None:
+         """Ensure at most one child is active per sample."""
+         batch_size = activations.shape[0]
+
+         # Get indices of children that have feature indices
+         child_indices = [
+             child.feature_index
+             for child in self.children
+             if child.feature_index is not None
+         ]
+
+         if len(child_indices) < 2:
+             return
+
+         # For each sample where parent is active, enforce mutual exclusion.
+         # Note: This loop is not vectorized because we need to randomly select
+         # which child to keep active per sample. Vectorizing would require either
+         # a deterministic selection (losing randomness) or complex gather/scatter
+         # operations that aren't more efficient for typical batch sizes.
+         for batch_idx in range(batch_size):
+             if not parent_active_mask[batch_idx]:
+                 continue
+
+             # Find which children are active
+             active_children = [
+                 i
+                 for i, feat_idx in enumerate(child_indices)
+                 if activations[batch_idx, feat_idx] > 0
+             ]
+
+             if len(active_children) <= 1:
+                 continue
+
+             # Randomly select one to keep active
+             random_idx = int(torch.randint(len(active_children), (1,)).item())
+             keep_idx = active_children[random_idx]
+
+             # Deactivate all others
+             for i, feat_idx in enumerate(child_indices):
+                 if i != keep_idx and i in active_children:
+                     activations[batch_idx, feat_idx] = 0
+
+     def get_all_feature_indices(self) -> list[int]:
+         """Get all feature indices in this subtree."""
+         indices = []
+         if self.feature_index is not None:
+             indices.append(self.feature_index)
+         for child in self.children:
+             indices.extend(child.get_all_feature_indices())
+         return indices
+
+     def validate(self) -> None:
+         """
+         Validate the hierarchy structure.
+
+         Checks that:
+         1. There are no loops (no node is its own ancestor)
+         2. Each node has at most one parent (no node appears in multiple children lists)
+
+         Raises:
+             ValueError: If the hierarchy is invalid
+         """
+         _validate_hierarchy([self])
+
+     def __repr__(self, indent: int = 0) -> str:
+         s = " " * (indent * 2)
+         s += str(self.feature_index) if self.feature_index is not None else "-"
+         s += "x" if self.mutually_exclusive_children else " "
+         if self.feature_id:
+             s += f" ({self.feature_id})"
+
+         for child in self.children:
+             s += "\n" + child.__repr__(indent + 2)
+         return s
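As a minimal sketch of how these pieces compose (the import location is again an assumption): feature 0 gates features 1 and 2, the children are mutually exclusive, and the resulting modifier is applied to a batch of activations:

    import torch

    # Assumed import location; the diff does not name this module.
    from sae_lens.synthetic import HierarchyNode, hierarchy_modifier

    root = HierarchyNode(
        feature_index=0,
        children=[HierarchyNode(feature_index=1), HierarchyNode(feature_index=2)],
        mutually_exclusive_children=True,
    )
    modifier = hierarchy_modifier(root)

    acts = torch.tensor(
        [
            [0.0, 2.0, 3.0],  # parent inactive: both children get zeroed
            [1.0, 2.0, 3.0],  # parent active: one child is randomly kept
        ]
    )
    out = modifier(acts)
    # out[0] == [0., 0., 0.]; out[1] keeps feature 0 and exactly one of features 1/2

The same tree can be built declaratively with HierarchyNode.from_dict.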
@@ -0,0 +1,40 @@
+ import torch
+
+ from sae_lens.synthetic import FeatureDictionary
+
+
+ @torch.no_grad()
+ def init_sae_to_match_feature_dict(
+     sae: torch.nn.Module,
+     feature_dict: FeatureDictionary,
+     noise_level: float = 0.0,
+     feature_ordering: torch.Tensor | None = None,
+ ) -> None:
+     """
+     Initialize an SAE's weights to match a feature dictionary.
+
+     This can be useful for:
+
+     - Starting training from a known good initialization
+     - Testing SAE evaluation code with ground truth
+     - Ablation studies on initialization
+
+     Args:
+         sae: The SAE to initialize. Must have W_enc and W_dec attributes.
+         feature_dict: The feature dictionary to match
+         noise_level: Standard deviation of Gaussian noise to add (0 = exact match)
+         feature_ordering: Optional permutation of feature indices
+     """
+     features = feature_dict.feature_vectors  # [num_features, hidden_dim]
+     min_dim = min(sae.W_enc.shape[1], features.shape[0])  # type: ignore[attr-defined]
+
+     if feature_ordering is not None:
+         features = features[feature_ordering]
+
+     features = features[:min_dim]
+
+     # W_enc is [hidden_dim, d_sae], feature vectors are [num_features, hidden_dim]
+     sae.W_enc.data[:, :min_dim] = (  # type: ignore[index]
+         features.T + torch.randn_like(features.T) * noise_level
+     )
+     sae.W_dec.data = sae.W_enc.data.T.clone()  # type: ignore[union-attr]
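A quick sanity check of the weight-copying contract above. FeatureDictionary's constructor is not shown in this diff, so this sketch fakes both it and the SAE with stand-ins exposing only the attributes the function touches (feature_vectors, W_enc of shape [hidden_dim, d_sae], W_dec of shape [d_sae, hidden_dim]); the import location is likewise an assumption:

    import torch
    from types import SimpleNamespace

    # Assumed import location; the diff does not name this module.
    from sae_lens.synthetic import init_sae_to_match_feature_dict

    hidden_dim, d_sae = 16, 32

    # Stand-ins for illustration only; real objects come from sae_lens itself.
    feature_dict = SimpleNamespace(feature_vectors=torch.randn(d_sae, hidden_dim))
    sae = SimpleNamespace(
        W_enc=torch.nn.Parameter(torch.zeros(hidden_dim, d_sae)),
        W_dec=torch.nn.Parameter(torch.zeros(d_sae, hidden_dim)),
    )

    init_sae_to_match_feature_dict(sae, feature_dict)  # noise_level=0: exact copy
    assert torch.equal(sae.W_enc.data.T, feature_dict.feature_vectors)
    assert torch.equal(sae.W_dec.data, feature_dict.feature_vectors)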