sae-lens 6.28.1__py3-none-any.whl → 6.28.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.28.1"
2
+ __version__ = "6.28.2"
3
3
 
4
4
  import logging
5
5
 
@@ -40631,7 +40631,7 @@ gemma-3-1b-res-matryoshka-dc:
40631
40631
  conversion_func: null
40632
40632
  links:
40633
40633
  model: https://huggingface.co/google/gemma-3-1b-pt
40634
- model: gemma-3-1b
40634
+ model: google/gemma-3-1b-pt
40635
40635
  repo_id: chanind/gemma-3-1b-batch-topk-matryoshka-saes-w-32k-l0-40
40636
40636
  saes:
40637
40637
  - id: blocks.0.hook_resid_post
@@ -63,6 +63,7 @@ class ActivationGenerator(nn.Module):
63
63
  )
64
64
  self.correlation_matrix = correlation_matrix
65
65
 
66
+ @torch.no_grad()
66
67
  def sample(self, batch_size: int) -> torch.Tensor:
67
68
  """
68
69
  Generate a batch of feature activations with controlled properties.
@@ -16,11 +16,28 @@ FeatureDictionaryInitializer = Callable[["FeatureDictionary"], None]
16
16
 
17
17
  def orthogonalize_embeddings(
18
18
  embeddings: torch.Tensor,
19
- target_cos_sim: float = 0,
20
19
  num_steps: int = 200,
21
20
  lr: float = 0.01,
22
21
  show_progress: bool = False,
22
+ chunk_size: int = 1024,
23
23
  ) -> torch.Tensor:
24
+ """
25
+ Orthogonalize embeddings using gradient descent with chunked computation.
26
+
27
+ Uses chunked computation to avoid O(n²) memory usage when computing pairwise
28
+ dot products. Memory usage is O(chunk_size × n) instead of O(n²).
29
+
30
+ Args:
31
+ embeddings: Tensor of shape [num_vectors, hidden_dim]
32
+ num_steps: Number of optimization steps
33
+ lr: Learning rate for Adam optimizer
34
+ show_progress: Whether to show progress bar
35
+ chunk_size: Number of vectors to process at once. Smaller values use less
36
+ memory but may be slower.
37
+
38
+ Returns:
39
+ Orthogonalized embeddings of the same shape, normalized to unit length.
40
+ """
24
41
  num_vectors = embeddings.shape[0]
25
42
  # Create a detached copy and normalize, then enable gradients
26
43
  embeddings = embeddings.detach().clone()
@@ -29,24 +46,37 @@ def orthogonalize_embeddings(
29
46
 
30
47
  optimizer = torch.optim.Adam([embeddings], lr=lr) # type: ignore[list-item]
31
48
 
32
- # Create a mask to zero out diagonal elements (avoid in-place operations)
33
- off_diagonal_mask = ~torch.eye(
34
- num_vectors, dtype=torch.bool, device=embeddings.device
35
- )
36
-
37
49
  pbar = tqdm(
38
50
  range(num_steps), desc="Orthogonalizing vectors", disable=not show_progress
39
51
  )
40
52
  for _ in pbar:
41
53
  optimizer.zero_grad()
42
54
 
43
- dot_products = embeddings @ embeddings.T
44
- diff = dot_products - target_cos_sim
45
- # Use masking instead of in-place fill_diagonal_
46
- off_diagonal_diff = diff * off_diagonal_mask.float()
47
- loss = off_diagonal_diff.pow(2).sum()
48
- loss = loss + num_vectors * (dot_products.diag() - 1).pow(2).sum()
55
+ off_diag_loss = torch.tensor(0.0, device=embeddings.device)
56
+ diag_loss = torch.tensor(0.0, device=embeddings.device)
57
+
58
+ for i in range(0, num_vectors, chunk_size):
59
+ end_i = min(i + chunk_size, num_vectors)
60
+ chunk = embeddings[i:end_i]
61
+ chunk_dots = chunk @ embeddings.T # [chunk_size, num_vectors]
49
62
 
63
+ # Create mask to zero out diagonal elements for this chunk
64
+ # Diagonal of full matrix: position (i+k, i+k) → in chunk_dots: (k, i+k)
65
+ chunk_len = end_i - i
66
+ row_indices = torch.arange(chunk_len, device=embeddings.device)
67
+ col_indices = i + row_indices # column indices in full matrix
68
+
69
+ # Boolean mask: True for off-diagonal elements we want to include
70
+ off_diag_mask = torch.ones_like(chunk_dots, dtype=torch.bool)
71
+ off_diag_mask[row_indices, col_indices] = False
72
+
73
+ off_diag_loss = off_diag_loss + chunk_dots[off_diag_mask].pow(2).sum()
74
+
75
+ # Diagonal loss: keep self-dot-products at 1
76
+ diag_vals = chunk_dots[row_indices, col_indices]
77
+ diag_loss = diag_loss + (diag_vals - 1).pow(2).sum()
78
+
79
+ loss = off_diag_loss + num_vectors * diag_loss
50
80
  loss.backward()
51
81
  optimizer.step()
52
82
  pbar.set_description(f"loss: {loss.item():.3f}")
@@ -59,7 +89,10 @@ def orthogonalize_embeddings(
59
89
 
60
90
 
61
91
  def orthogonal_initializer(
62
- num_steps: int = 200, lr: float = 0.01, show_progress: bool = False
92
+ num_steps: int = 200,
93
+ lr: float = 0.01,
94
+ show_progress: bool = False,
95
+ chunk_size: int = 1024,
63
96
  ) -> FeatureDictionaryInitializer:
64
97
  def initializer(feature_dict: "FeatureDictionary") -> None:
65
98
  feature_dict.feature_vectors.data = orthogonalize_embeddings(
@@ -67,6 +100,7 @@ def orthogonal_initializer(
67
100
  num_steps=num_steps,
68
101
  lr=lr,
69
102
  show_progress=show_progress,
103
+ chunk_size=chunk_size,
70
104
  )
71
105
 
72
106
  return initializer
@@ -97,6 +131,7 @@ class FeatureDictionary(nn.Module):
97
131
  hidden_dim: int,
98
132
  bias: bool = False,
99
133
  initializer: FeatureDictionaryInitializer | None = orthogonal_initializer(),
134
+ device: str | torch.device = "cpu",
100
135
  ):
101
136
  """
102
137
  Create a new FeatureDictionary.
@@ -106,20 +141,23 @@ class FeatureDictionary(nn.Module):
106
141
  hidden_dim: Dimensionality of the hidden space
107
142
  bias: Whether to include a bias term in the embedding
108
143
  initializer: Initializer function to use. If None, the embeddings are initialized to random unit vectors. By default will orthogonalize embeddings.
144
+ device: Device to use for the feature dictionary.
109
145
  """
110
146
  super().__init__()
111
147
  self.num_features = num_features
112
148
  self.hidden_dim = hidden_dim
113
149
 
114
150
  # Initialize feature vectors as unit vectors
115
- embeddings = torch.randn(num_features, hidden_dim)
151
+ embeddings = torch.randn(num_features, hidden_dim, device=device)
116
152
  embeddings = embeddings / embeddings.norm(p=2, dim=1, keepdim=True).clamp(
117
153
  min=1e-8
118
154
  )
119
155
  self.feature_vectors = nn.Parameter(embeddings)
120
156
 
121
157
  # Initialize bias (zeros if not using bias, but still a parameter for consistent API)
122
- self.bias = nn.Parameter(torch.zeros(hidden_dim), requires_grad=bias)
158
+ self.bias = nn.Parameter(
159
+ torch.zeros(hidden_dim, device=device), requires_grad=bias
160
+ )
123
161
 
124
162
  if initializer is not None:
125
163
  initializer(self)
@@ -11,7 +11,9 @@ https://github.com/noanabeshima/matryoshka-saes/blob/main/toy_model.py
11
11
 
12
12
  from __future__ import annotations
13
13
 
14
+ from collections import deque
14
15
  from collections.abc import Callable, Sequence
16
+ from dataclasses import dataclass
15
17
  from typing import Any
16
18
 
17
19
  import torch
@@ -19,6 +21,7 @@ import torch
19
21
  ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
20
22
 
21
23
 
24
+ @torch.no_grad()
22
25
  def _validate_hierarchy(roots: Sequence[HierarchyNode]) -> None:
23
26
  """
24
27
  Validate a forest of hierarchy trees.
@@ -104,6 +107,320 @@ def _node_description(node: HierarchyNode) -> str:
104
107
  return "unnamed node"
105
108
 
106
109
 
110
+ # ---------------------------------------------------------------------------
111
+ # Vectorized hierarchy implementation
112
+ # ---------------------------------------------------------------------------
113
+
114
+
115
+ @dataclass
116
+ class _LevelData:
117
+ """Data for a single level in the hierarchy."""
118
+
119
+ # Features at this level and their parents (for parent deactivation)
120
+ features: torch.Tensor # [num_features_at_level]
121
+ parents: torch.Tensor # [num_features_at_level]
122
+
123
+ # ME group indices to process AFTER this level's parent deactivation
124
+ # These are groups whose parent node is at this level
125
+ # ME must be applied here before processing next level's parent deactivation
126
+ me_group_indices: torch.Tensor # [num_groups_at_level], may be empty
127
+
128
+
129
+ @dataclass
130
+ class _SparseHierarchyData:
131
+ """Precomputed data for sparse hierarchy processing.
132
+
133
+ This structure enables O(active_features) processing instead of O(all_groups).
134
+ ME is applied at each level after parent deactivation to ensure cascading works.
135
+ """
136
+
137
+ # Per-level data for parent deactivation and ME (processed in order)
138
+ level_data: list[_LevelData]
139
+
140
+ # ME group data (shared across levels, indexed by me_group_indices)
141
+ me_group_siblings: torch.Tensor # [num_groups, max_siblings]
142
+ me_group_sizes: torch.Tensor # [num_groups]
143
+ me_group_parents: (
144
+ torch.Tensor
145
+ ) # [num_groups] - parent feature index (-1 if no parent)
146
+
147
+ # Total number of ME groups
148
+ num_groups: int
149
+
150
+
151
+ def _build_sparse_hierarchy(
152
+ roots: Sequence[HierarchyNode],
153
+ ) -> _SparseHierarchyData:
154
+ """
155
+ Build sparse hierarchy data structure for O(active_features) processing.
156
+
157
+ The key insight is that ME groups must be applied at the level of their parent node,
158
+ AFTER parent deactivation at that level, but BEFORE processing the next level.
159
+ This ensures that when a child is deactivated by ME, its grandchildren are also
160
+ deactivated during the next level's parent deactivation.
161
+ """
162
+ # Collect feature info by level using BFS
163
+ # Each entry: (feature_index, effective_parent, level)
164
+ feature_info: list[tuple[int, int, int]] = []
165
+
166
+ # ME groups: list of (parent_level, parent_feature, child_feature_indices)
167
+ me_groups: list[tuple[int, int, list[int]]] = []
168
+
169
+ # BFS queue: (node, effective_parent, level)
170
+ queue: deque[tuple[HierarchyNode, int, int]] = deque()
171
+ for root in roots:
172
+ queue.append((root, -1, 0))
173
+
174
+ while queue:
175
+ node, effective_parent, level = queue.popleft()
176
+
177
+ if node.feature_index is not None:
178
+ feature_info.append((node.feature_index, effective_parent, level))
179
+ new_effective_parent = node.feature_index
180
+ else:
181
+ new_effective_parent = effective_parent
182
+
183
+ # Handle mutual exclusion children - record the parent's level and feature
184
+ if node.mutually_exclusive_children and len(node.children) >= 2:
185
+ child_feats = [
186
+ c.feature_index for c in node.children if c.feature_index is not None
187
+ ]
188
+ if len(child_feats) >= 2:
189
+ # ME group belongs to the parent's level (current level)
190
+ # Parent feature is the node's feature_index (-1 if organizational node)
191
+ parent_feat = (
192
+ node.feature_index if node.feature_index is not None else -1
193
+ )
194
+ me_groups.append((level, parent_feat, child_feats))
195
+
196
+ for child in node.children:
197
+ queue.append((child, new_effective_parent, level + 1))
198
+
199
+ # Determine max level for both features and ME groups
200
+ max_feature_level = max((info[2] for info in feature_info), default=-1)
201
+ max_me_level = max((lvl for lvl, _, _ in me_groups), default=-1)
202
+ max_level = max(max_feature_level, max_me_level)
203
+
204
+ # Build level data with ME group indices per level
205
+ level_data: list[_LevelData] = []
206
+
207
+ # Group ME groups by their parent level
208
+ me_groups_by_level: dict[int, list[int]] = {}
209
+ for g_idx, (parent_level, _, _) in enumerate(me_groups):
210
+ if parent_level not in me_groups_by_level:
211
+ me_groups_by_level[parent_level] = []
212
+ me_groups_by_level[parent_level].append(g_idx)
213
+
214
+ for level in range(max_level + 1):
215
+ # Get features at this level that have parents
216
+ features_at_level = [
217
+ (feat, parent) for feat, parent, lv in feature_info if lv == level
218
+ ]
219
+ with_parents = [(f, p) for f, p in features_at_level if p >= 0]
220
+
221
+ if with_parents:
222
+ feats = torch.tensor([f for f, _ in with_parents], dtype=torch.long)
223
+ parents = torch.tensor([p for _, p in with_parents], dtype=torch.long)
224
+ else:
225
+ feats = torch.empty(0, dtype=torch.long)
226
+ parents = torch.empty(0, dtype=torch.long)
227
+
228
+ # Get ME group indices for this level
229
+ if level in me_groups_by_level:
230
+ me_indices = torch.tensor(me_groups_by_level[level], dtype=torch.long)
231
+ else:
232
+ me_indices = torch.empty(0, dtype=torch.long)
233
+
234
+ level_data.append(
235
+ _LevelData(features=feats, parents=parents, me_group_indices=me_indices)
236
+ )
237
+
238
+ # Build group siblings and parents tensors
239
+ if me_groups:
240
+ max_siblings = max(len(children) for _, _, children in me_groups)
241
+ num_groups = len(me_groups)
242
+ me_group_siblings = torch.full((num_groups, max_siblings), -1, dtype=torch.long)
243
+ me_group_sizes = torch.zeros(num_groups, dtype=torch.long)
244
+ me_group_parents = torch.full((num_groups,), -1, dtype=torch.long)
245
+ for g_idx, (_, parent_feat, siblings) in enumerate(me_groups):
246
+ me_group_sizes[g_idx] = len(siblings)
247
+ me_group_parents[g_idx] = parent_feat
248
+ me_group_siblings[g_idx, : len(siblings)] = torch.tensor(
249
+ siblings, dtype=torch.long
250
+ )
251
+ else:
252
+ me_group_siblings = torch.empty((0, 0), dtype=torch.long)
253
+ me_group_sizes = torch.empty(0, dtype=torch.long)
254
+ me_group_parents = torch.empty(0, dtype=torch.long)
255
+ num_groups = 0
256
+
257
+ return _SparseHierarchyData(
258
+ level_data=level_data,
259
+ me_group_siblings=me_group_siblings,
260
+ me_group_sizes=me_group_sizes,
261
+ me_group_parents=me_group_parents,
262
+ num_groups=num_groups,
263
+ )
264
+
265
+
266
+ def _apply_hierarchy_sparse(
267
+ activations: torch.Tensor,
268
+ sparse_data: _SparseHierarchyData,
269
+ ) -> torch.Tensor:
270
+ """
271
+ Apply hierarchy constraints using precomputed sparse indices.
272
+
273
+ Processes level by level:
274
+ 1. Apply parent deactivation for features at this level
275
+ 2. Apply mutual exclusion for groups whose parent is at this level
276
+ 3. Move to next level
277
+
278
+ This ensures that ME at level L affects parent deactivation at level L+1.
279
+ """
280
+ result = activations.clone()
281
+
282
+ # Data is already on correct device from cache
283
+ me_group_siblings = sparse_data.me_group_siblings
284
+ me_group_sizes = sparse_data.me_group_sizes
285
+ me_group_parents = sparse_data.me_group_parents
286
+
287
+ for level_data in sparse_data.level_data:
288
+ # Step 1: Deactivate children where parent is inactive
289
+ if level_data.features.numel() > 0:
290
+ parent_vals = result[:, level_data.parents]
291
+ child_vals = result[:, level_data.features]
292
+ result[:, level_data.features] = child_vals * (parent_vals > 0)
293
+
294
+ # Step 2: Apply ME for groups whose parent is at this level
295
+ if level_data.me_group_indices.numel() > 0:
296
+ _apply_me_for_groups(
297
+ result,
298
+ level_data.me_group_indices,
299
+ me_group_siblings,
300
+ me_group_sizes,
301
+ me_group_parents,
302
+ )
303
+
304
+ return result
305
+
306
+
307
+ def _apply_me_for_groups(
308
+ activations: torch.Tensor,
309
+ group_indices: torch.Tensor,
310
+ me_group_siblings: torch.Tensor,
311
+ me_group_sizes: torch.Tensor,
312
+ me_group_parents: torch.Tensor,
313
+ ) -> None:
314
+ """
315
+ Apply mutual exclusion for the specified groups.
316
+
317
+ Only processes groups where the parent is active (or has no parent).
318
+ This is a key optimization since most groups are skipped when parent is inactive.
319
+
320
+ Args:
321
+ activations: [batch_size, num_features] - modified in place
322
+ group_indices: [num_groups_to_process] - which groups to apply ME for
323
+ me_group_siblings: [total_groups, max_siblings] - sibling indices per group
324
+ me_group_sizes: [total_groups] - number of valid siblings per group
325
+ me_group_parents: [total_groups] - parent feature index (-1 if no parent)
326
+ """
327
+ batch_size = activations.shape[0]
328
+ device = activations.device
329
+ num_groups = group_indices.numel()
330
+
331
+ if num_groups == 0:
332
+ return
333
+
334
+ # Get parent indices for these groups
335
+ parents = me_group_parents[group_indices] # [num_groups]
336
+
337
+ # Check which parents are active: [batch_size, num_groups]
338
+ # Groups with parent=-1 are always active (root-level ME)
339
+ has_parent = parents >= 0
340
+ if has_parent.all():
341
+ # All groups have parents - check their activation directly
342
+ parent_active = activations[:, parents] > 0 # [batch, num_groups]
343
+ if not parent_active.any():
344
+ return
345
+ elif has_parent.any():
346
+ # Mixed case: some groups have parents, some don't
347
+ # Use clamp to avoid indexing with -1 (reads feature 0, but result is masked out)
348
+ safe_parents = parents.clamp(min=0)
349
+ parent_active = activations[:, safe_parents] > 0 # [batch, num_groups]
350
+ # Groups without parent are always "active"
351
+ parent_active = parent_active | ~has_parent
352
+ else:
353
+ # No groups have parents - all are always active, skip parent check
354
+ parent_active = None
355
+
356
+ # Get siblings for the groups we're processing
357
+ siblings = me_group_siblings[group_indices] # [num_groups, max_siblings]
358
+ sizes = me_group_sizes[group_indices] # [num_groups]
359
+ max_siblings = siblings.shape[1]
360
+
361
+ # Get activations for all siblings: [batch_size, num_groups, max_siblings]
362
+ safe_siblings = siblings.clamp(min=0)
363
+ sibling_activations = activations[:, safe_siblings.view(-1)].view(
364
+ batch_size, num_groups, max_siblings
365
+ )
366
+
367
+ # Create validity mask for padding: [num_groups, max_siblings]
368
+ sibling_range = torch.arange(max_siblings, device=device)
369
+ valid_mask = sibling_range < sizes.unsqueeze(1)
370
+
371
+ # Find active valid siblings, but only where parent is active: [batch, groups, siblings]
372
+ sibling_active = (sibling_activations > 0) & valid_mask
373
+ if parent_active is not None:
374
+ sibling_active = sibling_active & parent_active.unsqueeze(2)
375
+
376
+ # Count active per group and check for conflicts: [batch_size, num_groups]
377
+ active_counts = sibling_active.sum(dim=2)
378
+ needs_exclusion = active_counts > 1
379
+
380
+ if not needs_exclusion.any():
381
+ return
382
+
383
+ # Get (batch, group) pairs needing exclusion
384
+ batch_with_conflict, groups_with_conflict = torch.where(needs_exclusion)
385
+ num_conflicts = batch_with_conflict.numel()
386
+
387
+ if num_conflicts == 0:
388
+ return
389
+
390
+ # Get siblings and activations for conflicts
391
+ conflict_siblings = siblings[groups_with_conflict] # [num_conflicts, max_siblings]
392
+ conflict_active = sibling_active[
393
+ batch_with_conflict, groups_with_conflict
394
+ ] # [num_conflicts, max_siblings]
395
+
396
+ # Random selection for winner
397
+ # Use -1e9 instead of -inf to avoid creating a tensor (torch.tensor(-float("inf")))
398
+ # on every call. Since random scores are in [0,1], -1e9 is effectively -inf for argmax.
399
+ random_scores = torch.rand(num_conflicts, max_siblings, device=device)
400
+ random_scores[~conflict_active] = -1e9
401
+
402
+ winner_idx = random_scores.argmax(dim=1)
403
+
404
+ # Determine losers using scatter for efficiency
405
+ is_winner = torch.zeros(
406
+ num_conflicts, max_siblings, dtype=torch.bool, device=device
407
+ )
408
+ is_winner.scatter_(1, winner_idx.unsqueeze(1), True)
409
+ should_deactivate = conflict_active & ~is_winner
410
+
411
+ # Get (conflict, sibling) pairs to deactivate
412
+ conflict_idx, sib_idx = torch.where(should_deactivate)
413
+
414
+ if conflict_idx.numel() == 0:
415
+ return
416
+
417
+ # Map back to (batch, feature) and deactivate
418
+ deact_batch = batch_with_conflict[conflict_idx]
419
+ deact_feat = conflict_siblings[conflict_idx, sib_idx]
420
+ activations[deact_batch, deact_feat] = 0
421
+
422
+
423
+ @torch.no_grad()
107
424
  def hierarchy_modifier(
108
425
  roots: Sequence[HierarchyNode] | HierarchyNode,
109
426
  ) -> ActivationsModifier:
@@ -136,12 +453,35 @@ def hierarchy_modifier(
136
453
  roots = [roots]
137
454
  _validate_hierarchy(roots)
138
455
 
139
- # Create modifier function that applies all hierarchies
456
+ # Build sparse hierarchy data
457
+ sparse_data = _build_sparse_hierarchy(roots)
458
+
459
+ # Cache for device-specific tensors
460
+ device_cache: dict[torch.device, _SparseHierarchyData] = {}
461
+
462
+ def _get_sparse_for_device(device: torch.device) -> _SparseHierarchyData:
463
+ """Get or create device-specific sparse hierarchy data."""
464
+ if device not in device_cache:
465
+ device_cache[device] = _SparseHierarchyData(
466
+ level_data=[
467
+ _LevelData(
468
+ features=ld.features.to(device),
469
+ parents=ld.parents.to(device),
470
+ me_group_indices=ld.me_group_indices.to(device),
471
+ )
472
+ for ld in sparse_data.level_data
473
+ ],
474
+ me_group_siblings=sparse_data.me_group_siblings.to(device),
475
+ me_group_sizes=sparse_data.me_group_sizes.to(device),
476
+ me_group_parents=sparse_data.me_group_parents.to(device),
477
+ num_groups=sparse_data.num_groups,
478
+ )
479
+ return device_cache[device]
480
+
140
481
  def modifier(activations: torch.Tensor) -> torch.Tensor:
141
- result = activations.clone()
142
- for root in roots:
143
- root._apply_hierarchy(result, parent_active_mask=None)
144
- return result
482
+ device = activations.device
483
+ cached = _get_sparse_for_device(device)
484
+ return _apply_hierarchy_sparse(activations, cached)
145
485
 
146
486
  return modifier
147
487
 
@@ -222,85 +562,6 @@ class HierarchyNode:
222
562
  if self.mutually_exclusive_children and len(self.children) < 2:
223
563
  raise ValueError("Need at least 2 children for mutual exclusion")
224
564
 
225
- def _apply_hierarchy(
226
- self,
227
- activations: torch.Tensor,
228
- parent_active_mask: torch.Tensor | None,
229
- ) -> None:
230
- """Recursively apply hierarchical constraints."""
231
- batch_size = activations.shape[0]
232
-
233
- # Determine which samples have this node active
234
- if self.feature_index is not None:
235
- is_active = activations[:, self.feature_index] > 0
236
- else:
237
- # Non-readout node: active if parent is active (or always if root)
238
- is_active = (
239
- parent_active_mask
240
- if parent_active_mask is not None
241
- else torch.ones(batch_size, dtype=torch.bool, device=activations.device)
242
- )
243
-
244
- # Deactivate this node if parent is inactive
245
- if parent_active_mask is not None and self.feature_index is not None:
246
- activations[~parent_active_mask, self.feature_index] = 0
247
- # Update is_active after deactivation
248
- is_active = activations[:, self.feature_index] > 0
249
-
250
- # Handle mutually exclusive children
251
- if self.mutually_exclusive_children and len(self.children) >= 2:
252
- self._enforce_mutual_exclusion(activations, is_active)
253
-
254
- # Recursively process children
255
- for child in self.children:
256
- child._apply_hierarchy(activations, parent_active_mask=is_active)
257
-
258
- def _enforce_mutual_exclusion(
259
- self,
260
- activations: torch.Tensor,
261
- parent_active_mask: torch.Tensor,
262
- ) -> None:
263
- """Ensure at most one child is active per sample."""
264
- batch_size = activations.shape[0]
265
-
266
- # Get indices of children that have feature indices
267
- child_indices = [
268
- child.feature_index
269
- for child in self.children
270
- if child.feature_index is not None
271
- ]
272
-
273
- if len(child_indices) < 2:
274
- return
275
-
276
- # For each sample where parent is active, enforce mutual exclusion.
277
- # Note: This loop is not vectorized because we need to randomly select
278
- # which child to keep active per sample. Vectorizing would require either
279
- # a deterministic selection (losing randomness) or complex gather/scatter
280
- # operations that aren't more efficient for typical batch sizes.
281
- for batch_idx in range(batch_size):
282
- if not parent_active_mask[batch_idx]:
283
- continue
284
-
285
- # Find which children are active
286
- active_children = [
287
- i
288
- for i, feat_idx in enumerate(child_indices)
289
- if activations[batch_idx, feat_idx] > 0
290
- ]
291
-
292
- if len(active_children) <= 1:
293
- continue
294
-
295
- # Randomly select one to keep active
296
- random_idx = int(torch.randint(len(active_children), (1,)).item())
297
- keep_idx = active_children[random_idx]
298
-
299
- # Deactivate all others
300
- for i, feat_idx in enumerate(child_indices):
301
- if i != keep_idx and i in active_children:
302
- activations[batch_idx, feat_idx] = 0
303
-
304
565
  def get_all_feature_indices(self) -> list[int]:
305
566
  """Get all feature indices in this subtree."""
306
567
  indices = []
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.28.1
3
+ Version: 6.28.2
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -50,6 +50,8 @@ SAELens exists to help researchers:
50
50
  - Analyse sparse autoencoders / research mechanistic interpretability.
51
51
  - Generate insights which make it easier to create safe and aligned AI systems.
52
52
 
53
+ SAELens inference works with any PyTorch-based model, not just TransformerLens. While we provide deep integration with TransformerLens via `HookedSAETransformer`, SAEs can be used with Hugging Face Transformers, NNsight, or any other framework by extracting activations and passing them to the SAE's `encode()` and `decode()` methods.
54
+
53
55
  Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
54
56
 
55
57
  - Download and Analyse pre-trained sparse autoencoders.
@@ -84,6 +86,14 @@ The new v6 update is a major refactor to SAELens and changes the way training co
84
86
 
85
87
  Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
86
88
 
89
+ ## Other SAE Projects
90
+
91
+ - [dictionary-learning](https://github.com/saprmarks/dictionary_learning): An SAE training library that focuses on having hackable code.
92
+ - [Sparsify](https://github.com/EleutherAI/sparsify): A lean SAE training library focused on TopK SAEs.
93
+ - [Overcomplete](https://github.com/KempnerInstitute/overcomplete): SAE training library focused on vision models.
94
+ - [SAE-Vis](https://github.com/callummcdougall/sae_vis): A library for visualizing SAE features, works with SAELens.
95
+ - [SAEBench](https://github.com/adamkarvonen/SAEBench): A suite of LLM SAE benchmarks, works with SAELens.
96
+
87
97
  ## Citation
88
98
 
89
99
  Please cite the package as follows:
@@ -1,4 +1,4 @@
1
- sae_lens/__init__.py,sha256=S-AS72IxkvKO-wItRQjuyczikDxmfDaUgXRSfu5PU-o,4788
1
+ sae_lens/__init__.py,sha256=B9tY0Jt21pOHmSQrQLpMxQHyUAdLHIZpVP6pg3O0dfQ,4788
2
2
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
4
4
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
@@ -12,7 +12,7 @@ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
12
12
  sae_lens/loading/pretrained_sae_loaders.py,sha256=hHMlew1u6zVlbzvS9S_SfUPnAG0_OAjjIcjoUTIUZrU,63657
13
13
  sae_lens/loading/pretrained_saes_directory.py,sha256=1at_aQbD8WFywchQCKuwfP-yvCq_Z2aUYrpKDnSN5Nc,4283
14
14
  sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
15
- sae_lens/pretrained_saes.yaml,sha256=Hn8jXwZ7V6QQxzgu41LFEP-LAzuDxwYL5vhoar-pPX8,1509922
15
+ sae_lens/pretrained_saes.yaml,sha256=Nq43dTcFvDDONTuJ9Me_HQ5nHqr9BdbP5-ZJGXj0TAQ,1509932
16
16
  sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
17
  sae_lens/saes/__init__.py,sha256=SBqPaP6Gl5uPFwHlumAZATC4Wd26xKIYLAAAo4MSa5Q,2200
18
18
  sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
@@ -26,12 +26,12 @@ sae_lens/saes/temporal_sae.py,sha256=83Ap4mYGfdN3sKdPF8nKjhdXph3-7E2QuLobqJ_YuoM
26
26
  sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
27
27
  sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
28
28
  sae_lens/synthetic/__init__.py,sha256=FGUasB6fLPXRFCcrtKfL7vCKDOWebZ5Rx5F9QNJZklI,2875
29
- sae_lens/synthetic/activation_generator.py,sha256=thWGTwRmhu0K8m66WfJUajHmuIPHkwV4_HjmG0dL3G8,7638
29
+ sae_lens/synthetic/activation_generator.py,sha256=JEN7mEgdGDuXr0ArTwUsSdSVUAfvheT_1Eew2ojbA-g,7659
30
30
  sae_lens/synthetic/correlation.py,sha256=odr-S5h6c2U-bepwrAQeMfV1iBF_cnnQzqw7zapEXZ4,6056
31
31
  sae_lens/synthetic/evals.py,sha256=Nhi314ZnRgLfhBj-3tm_zzI-pGyFTcwllDXbIpPFXeU,4584
32
- sae_lens/synthetic/feature_dictionary.py,sha256=2A9wqdT1KejRLuIoFWdoiWdDtaHHgIluaKsHGizsVxI,4864
32
+ sae_lens/synthetic/feature_dictionary.py,sha256=ysn0ihE3JgVlCLUZMb127WYZqbz4kMp9BGHfCZqERBg,6487
33
33
  sae_lens/synthetic/firing_probabilities.py,sha256=yclz1pWl5gE1r8LAxFvzQS88Lxwk5-3r8BCX9HLVejA,3370
34
- sae_lens/synthetic/hierarchy.py,sha256=dlQdPnnG3VzQDB3QOaqSXwoH8Ij2ioxmTlZg1lXHaRQ,11754
34
+ sae_lens/synthetic/hierarchy.py,sha256=j9-6K7xq6zQS9N8bB5nK_-EbuzAZsY5Z5AfUK-qlB5M,22138
35
35
  sae_lens/synthetic/initialization.py,sha256=orMGW-786wRDHIS2W7bEH0HmlVFQ4g2z4bnnwdv5w4s,1386
36
36
  sae_lens/synthetic/plotting.py,sha256=5lFrej1QOkGAcImFNo5-o-8mI_rUVqvEI57KzUQPPtQ,8208
37
37
  sae_lens/synthetic/training.py,sha256=Bg6NYxdzifq_8g-dJQSZ_z_TXDdGRtEi7tqNDb-gCVc,4986
@@ -46,7 +46,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
46
46
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
47
47
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
48
48
  sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
49
- sae_lens-6.28.1.dist-info/METADATA,sha256=OdPVG1dwWoLGqiutKkAJGazfBLLbYQLBUbs_3h58BKg,5633
50
- sae_lens-6.28.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
51
- sae_lens-6.28.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
52
- sae_lens-6.28.1.dist-info/RECORD,,
49
+ sae_lens-6.28.2.dist-info/METADATA,sha256=i_kbAa64It0NRDrnSlmwNa8qgqOyEMntT_Ifxdx4Q90,6573
50
+ sae_lens-6.28.2.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
51
+ sae_lens-6.28.2.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
52
+ sae_lens-6.28.2.dist-info/RECORD,,