sae-lens 6.28.1__py3-none-any.whl → 6.29.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
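For orientation before the hunks: this release replaces the recursive `HierarchyNode._apply_hierarchy` walk with a precomputed, level-ordered implementation behind the same `hierarchy_modifier` entry point, and adds a sparse COO code path. The sketch below is illustrative only and is not part of the package diff; it assumes `HierarchyNode` can be constructed with `feature_index`, `children`, and `mutually_exclusive_children` keyword arguments (the constructor is not shown in this diff), and the import path is an assumption.

    import torch

    # Hypothetical import -- the exact module path is not shown in this diff:
    # from sae_lens import HierarchyNode, hierarchy_modifier

    # Feature 0 gates features 1 and 2; at most one of 1/2 may stay active per sample.
    root = HierarchyNode(
        feature_index=0,
        children=[HierarchyNode(feature_index=1), HierarchyNode(feature_index=2)],
        mutually_exclusive_children=True,
    )

    modifier = hierarchy_modifier(root)       # returns an ActivationsModifier callable
    acts = torch.relu(torch.randn(8, 16))     # [batch, num_features] dense activations
    constrained = modifier(acts)              # children zeroed wherever feature 0 is inactive

    # The same modifier also accepts sparse COO activations (routed to the *_coo path).
    constrained_sparse = modifier(acts.to_sparse_coo())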
@@ -11,7 +11,9 @@ https://github.com/noanabeshima/matryoshka-saes/blob/main/toy_model.py

  from __future__ import annotations

+ from collections import deque
  from collections.abc import Callable, Sequence
+ from dataclasses import dataclass
  from typing import Any

  import torch
@@ -19,6 +21,7 @@ import torch
  ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]


+ @torch.no_grad()
  def _validate_hierarchy(roots: Sequence[HierarchyNode]) -> None:
      """
      Validate a forest of hierarchy trees.
@@ -104,6 +107,620 @@ def _node_description(node: HierarchyNode) -> str:
      return "unnamed node"


+ # ---------------------------------------------------------------------------
+ # Vectorized hierarchy implementation
+ # ---------------------------------------------------------------------------
+
+
+ @dataclass
+ class _LevelData:
+     """Data for a single level in the hierarchy."""
+
+     # Features at this level and their parents (for parent deactivation)
+     features: torch.Tensor  # [num_features_at_level]
+     parents: torch.Tensor  # [num_features_at_level]
+
+     # ME group indices to process AFTER this level's parent deactivation
+     # These are groups whose parent node is at this level
+     # ME must be applied here before processing next level's parent deactivation
+     me_group_indices: torch.Tensor  # [num_groups_at_level], may be empty
+
+
+ @dataclass
+ class _SparseHierarchyData:
+     """Precomputed data for sparse hierarchy processing.
+
+     This structure enables O(active_features) processing instead of O(all_groups).
+     ME is applied at each level after parent deactivation to ensure cascading works.
+     """
+
+     # Per-level data for parent deactivation and ME (processed in order)
+     level_data: list[_LevelData]
+
+     # ME group data (shared across levels, indexed by me_group_indices)
+     me_group_siblings: torch.Tensor  # [num_groups, max_siblings]
+     me_group_sizes: torch.Tensor  # [num_groups]
+     me_group_parents: (
+         torch.Tensor
+     )  # [num_groups] - parent feature index (-1 if no parent)
+
+     # Total number of ME groups
+     num_groups: int
+
+     # Sparse COO support: Feature-to-parent mapping
+     # feat_to_parent[f] = parent feature index, or -1 if root/no parent
+     feat_to_parent: torch.Tensor | None = None  # [num_features]
+
+     # Sparse COO support: Feature-to-ME-group mapping
+     # feat_to_me_group[f] = group index, or -1 if not in any ME group
+     feat_to_me_group: torch.Tensor | None = None  # [num_features]
+
+
+ def _build_sparse_hierarchy(
+     roots: Sequence[HierarchyNode],
+ ) -> _SparseHierarchyData:
+     """
+     Build sparse hierarchy data structure for O(active_features) processing.
+
+     The key insight is that ME groups must be applied at the level of their parent node,
+     AFTER parent deactivation at that level, but BEFORE processing the next level.
+     This ensures that when a child is deactivated by ME, its grandchildren are also
+     deactivated during the next level's parent deactivation.
+     """
+     # Collect feature info by level using BFS
+     # Each entry: (feature_index, effective_parent, level)
+     feature_info: list[tuple[int, int, int]] = []
+
+     # ME groups: list of (parent_level, parent_feature, child_feature_indices)
+     me_groups: list[tuple[int, int, list[int]]] = []
+
+     # BFS queue: (node, effective_parent, level)
+     queue: deque[tuple[HierarchyNode, int, int]] = deque()
+     for root in roots:
+         queue.append((root, -1, 0))
+
+     while queue:
+         node, effective_parent, level = queue.popleft()
+
+         if node.feature_index is not None:
+             feature_info.append((node.feature_index, effective_parent, level))
+             new_effective_parent = node.feature_index
+         else:
+             new_effective_parent = effective_parent
+
+         # Handle mutual exclusion children - record the parent's level and feature
+         if node.mutually_exclusive_children and len(node.children) >= 2:
+             child_feats = [
+                 c.feature_index for c in node.children if c.feature_index is not None
+             ]
+             if len(child_feats) >= 2:
+                 # ME group belongs to the parent's level (current level)
+                 # Parent feature is the node's feature_index (-1 if organizational node)
+                 parent_feat = (
+                     node.feature_index if node.feature_index is not None else -1
+                 )
+                 me_groups.append((level, parent_feat, child_feats))
+
+         for child in node.children:
+             queue.append((child, new_effective_parent, level + 1))
+
+     # Determine max level for both features and ME groups
+     max_feature_level = max((info[2] for info in feature_info), default=-1)
+     max_me_level = max((lvl for lvl, _, _ in me_groups), default=-1)
+     max_level = max(max_feature_level, max_me_level)
+
+     # Build level data with ME group indices per level
+     level_data: list[_LevelData] = []
+
+     # Group ME groups by their parent level
+     me_groups_by_level: dict[int, list[int]] = {}
+     for g_idx, (parent_level, _, _) in enumerate(me_groups):
+         if parent_level not in me_groups_by_level:
+             me_groups_by_level[parent_level] = []
+         me_groups_by_level[parent_level].append(g_idx)
+
+     for level in range(max_level + 1):
+         # Get features at this level that have parents
+         features_at_level = [
+             (feat, parent) for feat, parent, lv in feature_info if lv == level
+         ]
+         with_parents = [(f, p) for f, p in features_at_level if p >= 0]
+
+         if with_parents:
+             feats = torch.tensor([f for f, _ in with_parents], dtype=torch.long)
+             parents = torch.tensor([p for _, p in with_parents], dtype=torch.long)
+         else:
+             feats = torch.empty(0, dtype=torch.long)
+             parents = torch.empty(0, dtype=torch.long)
+
+         # Get ME group indices for this level
+         if level in me_groups_by_level:
+             me_indices = torch.tensor(me_groups_by_level[level], dtype=torch.long)
+         else:
+             me_indices = torch.empty(0, dtype=torch.long)
+
+         level_data.append(
+             _LevelData(
+                 features=feats,
+                 parents=parents,
+                 me_group_indices=me_indices,
+             )
+         )
+
+     # Build group siblings and parents tensors
+     if me_groups:
+         max_siblings = max(len(children) for _, _, children in me_groups)
+         num_groups = len(me_groups)
+         me_group_siblings = torch.full((num_groups, max_siblings), -1, dtype=torch.long)
+         me_group_sizes = torch.zeros(num_groups, dtype=torch.long)
+         me_group_parents = torch.full((num_groups,), -1, dtype=torch.long)
+         for g_idx, (_, parent_feat, siblings) in enumerate(me_groups):
+             me_group_sizes[g_idx] = len(siblings)
+             me_group_parents[g_idx] = parent_feat
+             me_group_siblings[g_idx, : len(siblings)] = torch.tensor(
+                 siblings, dtype=torch.long
+             )
+     else:
+         me_group_siblings = torch.empty((0, 0), dtype=torch.long)
+         me_group_sizes = torch.empty(0, dtype=torch.long)
+         me_group_parents = torch.empty(0, dtype=torch.long)
+         num_groups = 0
+
+     # Build sparse COO support: feat_to_parent and feat_to_me_group mappings
+     # First determine num_features (max feature index + 1)
+     all_features = [f for f, _, _ in feature_info]
+     num_features = max(all_features) + 1 if all_features else 0
+
+     # Build feature-to-parent mapping
+     feat_to_parent = torch.full((num_features,), -1, dtype=torch.long)
+     for feat, parent, _ in feature_info:
+         feat_to_parent[feat] = parent
+
+     # Build feature-to-ME-group mapping
+     feat_to_me_group = torch.full((num_features,), -1, dtype=torch.long)
+     for g_idx, (_, _, siblings) in enumerate(me_groups):
+         for sib in siblings:
+             feat_to_me_group[sib] = g_idx
+
+     return _SparseHierarchyData(
+         level_data=level_data,
+         me_group_siblings=me_group_siblings,
+         me_group_sizes=me_group_sizes,
+         me_group_parents=me_group_parents,
+         num_groups=num_groups,
+         feat_to_parent=feat_to_parent,
+         feat_to_me_group=feat_to_me_group,
+     )
+
+
+ def _apply_hierarchy_sparse(
+     activations: torch.Tensor,
+     sparse_data: _SparseHierarchyData,
+ ) -> torch.Tensor:
+     """
+     Apply hierarchy constraints using precomputed sparse indices.
+
+     Processes level by level:
+     1. Apply parent deactivation for features at this level
+     2. Apply mutual exclusion for groups whose parent is at this level
+     3. Move to next level
+
+     This ensures that ME at level L affects parent deactivation at level L+1.
+     """
+     result = activations.clone()
+
+     # Data is already on correct device from cache
+     me_group_siblings = sparse_data.me_group_siblings
+     me_group_sizes = sparse_data.me_group_sizes
+     me_group_parents = sparse_data.me_group_parents
+
+     for level_data in sparse_data.level_data:
+         # Step 1: Deactivate children where parent is inactive
+         if level_data.features.numel() > 0:
+             parent_vals = result[:, level_data.parents]
+             child_vals = result[:, level_data.features]
+             result[:, level_data.features] = child_vals * (parent_vals > 0)
+
+         # Step 2: Apply ME for groups whose parent is at this level
+         if level_data.me_group_indices.numel() > 0:
+             _apply_me_for_groups(
+                 result,
+                 level_data.me_group_indices,
+                 me_group_siblings,
+                 me_group_sizes,
+                 me_group_parents,
+             )
+
+     return result
+
+
+ def _apply_me_for_groups(
+     activations: torch.Tensor,
+     group_indices: torch.Tensor,
+     me_group_siblings: torch.Tensor,
+     me_group_sizes: torch.Tensor,
+     me_group_parents: torch.Tensor,
+ ) -> None:
+     """
+     Apply mutual exclusion for the specified groups.
+
+     Only processes groups where the parent is active (or has no parent).
+     This is a key optimization since most groups are skipped when parent is inactive.
+
+     Args:
+         activations: [batch_size, num_features] - modified in place
+         group_indices: [num_groups_to_process] - which groups to apply ME for
+         me_group_siblings: [total_groups, max_siblings] - sibling indices per group
+         me_group_sizes: [total_groups] - number of valid siblings per group
+         me_group_parents: [total_groups] - parent feature index (-1 if no parent)
+     """
+     batch_size = activations.shape[0]
+     device = activations.device
+     num_groups = group_indices.numel()
+
+     if num_groups == 0:
+         return
+
+     # Get parent indices for these groups
+     parents = me_group_parents[group_indices]  # [num_groups]
+
+     # Check which parents are active: [batch_size, num_groups]
+     # Groups with parent=-1 are always active (root-level ME)
+     has_parent = parents >= 0
+     if has_parent.all():
+         # All groups have parents - check their activation directly
+         parent_active = activations[:, parents] > 0  # [batch, num_groups]
+         if not parent_active.any():
+             return
+     elif has_parent.any():
+         # Mixed case: some groups have parents, some don't
+         # Use clamp to avoid indexing with -1 (reads feature 0, but result is masked out)
+         safe_parents = parents.clamp(min=0)
+         parent_active = activations[:, safe_parents] > 0  # [batch, num_groups]
+         # Groups without parent are always "active"
+         parent_active = parent_active | ~has_parent
+     else:
+         # No groups have parents - all are always active, skip parent check
+         parent_active = None
+
+     # Get siblings for the groups we're processing
+     siblings = me_group_siblings[group_indices]  # [num_groups, max_siblings]
+     sizes = me_group_sizes[group_indices]  # [num_groups]
+     max_siblings = siblings.shape[1]
+
+     # Get activations for all siblings: [batch_size, num_groups, max_siblings]
+     safe_siblings = siblings.clamp(min=0)
+     sibling_activations = activations[:, safe_siblings.view(-1)].view(
+         batch_size, num_groups, max_siblings
+     )
+
+     # Create validity mask for padding: [num_groups, max_siblings]
+     sibling_range = torch.arange(max_siblings, device=device)
+     valid_mask = sibling_range < sizes.unsqueeze(1)
+
+     # Find active valid siblings, but only where parent is active: [batch, groups, siblings]
+     sibling_active = (sibling_activations > 0) & valid_mask
+     if parent_active is not None:
+         sibling_active = sibling_active & parent_active.unsqueeze(2)
+
+     # Count active per group and check for conflicts: [batch_size, num_groups]
+     active_counts = sibling_active.sum(dim=2)
+     needs_exclusion = active_counts > 1
+
+     if not needs_exclusion.any():
+         return
+
+     # Get (batch, group) pairs needing exclusion
+     batch_with_conflict, groups_with_conflict = torch.where(needs_exclusion)
+     num_conflicts = batch_with_conflict.numel()
+
+     if num_conflicts == 0:
+         return
+
+     # Get siblings and activations for conflicts
+     conflict_siblings = siblings[groups_with_conflict]  # [num_conflicts, max_siblings]
+     conflict_active = sibling_active[
+         batch_with_conflict, groups_with_conflict
+     ]  # [num_conflicts, max_siblings]
+
+     # Random selection for winner
+     # Use -1e9 instead of -inf to avoid creating a tensor (torch.tensor(-float("inf")))
+     # on every call. Since random scores are in [0,1], -1e9 is effectively -inf for argmax.
+     _INACTIVE_SCORE = -1e9
+     random_scores = torch.rand(num_conflicts, max_siblings, device=device)
+     random_scores[~conflict_active] = _INACTIVE_SCORE
+
+     winner_idx = random_scores.argmax(dim=1)
+
+     # Determine losers using scatter for efficiency
+     is_winner = torch.zeros(
+         num_conflicts, max_siblings, dtype=torch.bool, device=device
+     )
+     is_winner.scatter_(1, winner_idx.unsqueeze(1), True)
+     should_deactivate = conflict_active & ~is_winner
+
+     # Get (conflict, sibling) pairs to deactivate
+     conflict_idx, sib_idx = torch.where(should_deactivate)
+
+     if conflict_idx.numel() == 0:
+         return
+
+     # Map back to (batch, feature) and deactivate
+     deact_batch = batch_with_conflict[conflict_idx]
+     deact_feat = conflict_siblings[conflict_idx, sib_idx]
+     activations[deact_batch, deact_feat] = 0
+
+
+ # ---------------------------------------------------------------------------
+ # Sparse COO hierarchy implementation
+ # ---------------------------------------------------------------------------
+
+
+ def _apply_hierarchy_sparse_coo(
+     sparse_tensor: torch.Tensor,
+     sparse_data: _SparseHierarchyData,
+ ) -> torch.Tensor:
+     """
+     Apply hierarchy constraints to a sparse COO tensor.
+
+     This is the sparse analog of _apply_hierarchy_sparse. It processes
+     level-by-level, applying parent deactivation then mutual exclusion.
+     """
+     if sparse_tensor._nnz() == 0:
+         return sparse_tensor
+
+     sparse_tensor = sparse_tensor.coalesce()
+
+     for level_data in sparse_data.level_data:
+         # Step 1: Apply parent deactivation for features at this level
+         if level_data.features.numel() > 0:
+             sparse_tensor = _apply_parent_deactivation_coo(
+                 sparse_tensor, level_data, sparse_data
+             )
+
+         # Step 2: Apply ME for groups whose parent is at this level
+         if level_data.me_group_indices.numel() > 0:
+             sparse_tensor = _apply_me_coo(
+                 sparse_tensor, level_data.me_group_indices, sparse_data
+             )
+
+     return sparse_tensor
+
+
+ def _apply_parent_deactivation_coo(
+     sparse_tensor: torch.Tensor,
+     level_data: _LevelData,
+     sparse_data: _SparseHierarchyData,
+ ) -> torch.Tensor:
+     """
+     Remove children from sparse COO tensor when their parent is inactive.
+
+     Uses searchsorted for efficient membership testing of parent activity.
+     """
+     if sparse_tensor._nnz() == 0 or level_data.features.numel() == 0:
+         return sparse_tensor
+
+     sparse_tensor = sparse_tensor.coalesce()
+     indices = sparse_tensor.indices()  # [2, nnz]
+     values = sparse_tensor.values()  # [nnz]
+     batch_indices = indices[0]
+     feat_indices = indices[1]
+
+     _, num_features = sparse_tensor.shape
+     device = sparse_tensor.device
+     nnz = indices.shape[1]
+
+     # Build set of active (batch, feature) pairs for efficient lookup
+     # Encode as: batch_idx * num_features + feat_idx
+     active_pairs = batch_indices * num_features + feat_indices
+     active_pairs_sorted, _ = active_pairs.sort()
+
+     # Use the precomputed feat_to_parent mapping
+     assert sparse_data.feat_to_parent is not None
+     hierarchy_num_features = sparse_data.feat_to_parent.numel()
+
+     # Handle features outside the hierarchy (they have no parent, pass through)
+     in_hierarchy = feat_indices < hierarchy_num_features
+     parent_of_feat = torch.full((nnz,), -1, dtype=torch.long, device=device)
+     parent_of_feat[in_hierarchy] = sparse_data.feat_to_parent[
+         feat_indices[in_hierarchy]
+     ]
+
+     # Find entries that have a parent (parent >= 0 means this feature has a parent)
+     has_parent = parent_of_feat >= 0
+
+     if not has_parent.any():
+         return sparse_tensor
+
+     # For entries with parents, check if parent is active
+     child_entry_indices = torch.where(has_parent)[0]
+     child_batch = batch_indices[has_parent]
+     child_parents = parent_of_feat[has_parent]
+
+     # Look up parent activity using searchsorted
+     parent_pairs = child_batch * num_features + child_parents
+     search_pos = torch.searchsorted(active_pairs_sorted, parent_pairs)
+     search_pos = search_pos.clamp(max=active_pairs_sorted.numel() - 1)
+     parent_active = active_pairs_sorted[search_pos] == parent_pairs
+
+     # Handle empty case
+     if active_pairs_sorted.numel() == 0:
+         parent_active = torch.zeros_like(parent_pairs, dtype=torch.bool)
+
+     # Build keep mask: keep entry if it's a root OR its parent is active
+     keep_mask = torch.ones(nnz, dtype=torch.bool, device=device)
+     keep_mask[child_entry_indices[~parent_active]] = False
+
+     if keep_mask.all():
+         return sparse_tensor
+
+     return torch.sparse_coo_tensor(
+         indices[:, keep_mask],
+         values[keep_mask],
+         sparse_tensor.shape,
+         device=device,
+         dtype=sparse_tensor.dtype,
+     )
+
+
+ def _apply_me_coo(
+     sparse_tensor: torch.Tensor,
+     group_indices: torch.Tensor,
+     sparse_data: _SparseHierarchyData,
+ ) -> torch.Tensor:
+     """
+     Apply mutual exclusion to sparse COO tensor.
+
+     For each ME group with multiple active siblings in the same batch,
+     randomly selects one winner and removes the rest.
+     """
+     if sparse_tensor._nnz() == 0 or group_indices.numel() == 0:
+         return sparse_tensor
+
+     sparse_tensor = sparse_tensor.coalesce()
+     indices = sparse_tensor.indices()  # [2, nnz]
+     values = sparse_tensor.values()  # [nnz]
+     batch_indices = indices[0]
+     feat_indices = indices[1]
+
+     _, num_features = sparse_tensor.shape
+     device = sparse_tensor.device
+     nnz = indices.shape[1]
+
+     # Use precomputed feat_to_me_group mapping
+     assert sparse_data.feat_to_me_group is not None
+     hierarchy_num_features = sparse_data.feat_to_me_group.numel()
+
+     # Handle features outside the hierarchy (they are not in any ME group)
+     in_hierarchy = feat_indices < hierarchy_num_features
+     me_group_of_feat = torch.full((nnz,), -1, dtype=torch.long, device=device)
+     me_group_of_feat[in_hierarchy] = sparse_data.feat_to_me_group[
+         feat_indices[in_hierarchy]
+     ]
+
+     # Find entries that belong to ME groups we're processing (vectorized)
+     in_relevant_group = torch.isin(me_group_of_feat, group_indices)
+
+     if not in_relevant_group.any():
+         return sparse_tensor
+
+     # Get the ME entries
+     me_entry_indices = torch.where(in_relevant_group)[0]
+     me_batch = batch_indices[in_relevant_group]
+     me_group = me_group_of_feat[in_relevant_group]
+
+     # Check parent activity for ME groups (only apply ME if parent is active)
+     me_group_parents = sparse_data.me_group_parents[me_group]
+     has_parent = me_group_parents >= 0
+
+     if has_parent.any():
+         # Build active pairs for parent lookup
+         active_pairs = batch_indices * num_features + feat_indices
+         active_pairs_sorted, _ = active_pairs.sort()
+
+         parent_pairs = (
+             me_batch[has_parent] * num_features + me_group_parents[has_parent]
+         )
+         search_pos = torch.searchsorted(active_pairs_sorted, parent_pairs)
+         search_pos = search_pos.clamp(max=active_pairs_sorted.numel() - 1)
+         parent_active_for_has_parent = active_pairs_sorted[search_pos] == parent_pairs
+
+         # Build full parent_active mask
+         parent_active = torch.ones(
+             me_entry_indices.numel(), dtype=torch.bool, device=device
+         )
+         parent_active[has_parent] = parent_active_for_has_parent
+
+         # Filter to only ME entries where parent is active
+         valid_me = parent_active
+         me_entry_indices = me_entry_indices[valid_me]
+         me_batch = me_batch[valid_me]
+         me_group = me_group[valid_me]
+
+     if me_entry_indices.numel() == 0:
+         return sparse_tensor
+
+     # Encode (batch, group) pairs
+     num_groups = sparse_data.num_groups
+     batch_group_pairs = me_batch * num_groups + me_group
+
+     # Find unique (batch, group) pairs and count occurrences
+     unique_bg, inverse, counts = torch.unique(
+         batch_group_pairs, return_inverse=True, return_counts=True
+     )
+
+     # Only process pairs with count > 1 (conflicts)
+     has_conflict = counts > 1
+
+     if not has_conflict.any():
+         return sparse_tensor
+
+     # For efficiency, we process all conflicts together
+     # Assign random scores to each ME entry
+     random_scores = torch.rand(me_entry_indices.numel(), device=device)
+
+     # For each (batch, group) pair, we want the entry with highest score to be winner
+     # Use scatter_reduce to find max score per (batch, group)
+     bg_to_dense = torch.zeros(unique_bg.numel(), dtype=torch.long, device=device)
+     bg_to_dense[has_conflict.nonzero(as_tuple=True)[0]] = torch.arange(
+         has_conflict.sum(), device=device
+     )
+
+     # Map each ME entry to its dense conflict index
+     entry_has_conflict = has_conflict[inverse]
+
+     if not entry_has_conflict.any():
+         return sparse_tensor
+
+     conflict_entries_mask = entry_has_conflict
+     conflict_entry_indices = me_entry_indices[conflict_entries_mask]
+     conflict_random_scores = random_scores[conflict_entries_mask]
+     conflict_inverse = inverse[conflict_entries_mask]
+     conflict_dense_idx = bg_to_dense[conflict_inverse]
+
+     # Vectorized winner selection using sorting
+     # Sort entries by (group_idx, -random_score) so highest score comes first per group
+     # Use group * 2 - score to sort by group ascending, then score descending
+     sort_keys = conflict_dense_idx.float() * 2.0 - conflict_random_scores
+     sorted_order = sort_keys.argsort()
+     sorted_dense_idx = conflict_dense_idx[sorted_order]
+
+     # Find first entry of each group in sorted order (these are winners)
+     group_starts = torch.cat(
+         [
+             torch.tensor([True], device=device),
+             sorted_dense_idx[1:] != sorted_dense_idx[:-1],
+         ]
+     )
+
+     # Winners are entries at group starts in sorted order
+     winner_positions_in_sorted = torch.where(group_starts)[0]
+     winner_original_positions = sorted_order[winner_positions_in_sorted]
+
+     # Create winner mask (vectorized)
+     is_winner = torch.zeros(
+         conflict_entry_indices.numel(), dtype=torch.bool, device=device
+     )
+     is_winner[winner_original_positions] = True
+
+     # Build keep mask (vectorized)
+     keep_mask = torch.ones(nnz, dtype=torch.bool, device=device)
+     loser_entry_indices = conflict_entry_indices[~is_winner]
+     keep_mask[loser_entry_indices] = False
+
+     if keep_mask.all():
+         return sparse_tensor
+
+     return torch.sparse_coo_tensor(
+         indices[:, keep_mask],
+         values[keep_mask],
+         sparse_tensor.shape,
+         device=device,
+         dtype=sparse_tensor.dtype,
+     )
+
+
+ @torch.no_grad()
  def hierarchy_modifier(
      roots: Sequence[HierarchyNode] | HierarchyNode,
  ) -> ActivationsModifier:
@@ -136,12 +753,47 @@ def hierarchy_modifier(
          roots = [roots]
      _validate_hierarchy(roots)

-     # Create modifier function that applies all hierarchies
+     # Build sparse hierarchy data
+     sparse_data = _build_sparse_hierarchy(roots)
+
+     # Cache for device-specific tensors
+     device_cache: dict[torch.device, _SparseHierarchyData] = {}
+
+     def _get_sparse_for_device(device: torch.device) -> _SparseHierarchyData:
+         """Get or create device-specific sparse hierarchy data."""
+         if device not in device_cache:
+             device_cache[device] = _SparseHierarchyData(
+                 level_data=[
+                     _LevelData(
+                         features=ld.features.to(device),
+                         parents=ld.parents.to(device),
+                         me_group_indices=ld.me_group_indices.to(device),
+                     )
+                     for ld in sparse_data.level_data
+                 ],
+                 me_group_siblings=sparse_data.me_group_siblings.to(device),
+                 me_group_sizes=sparse_data.me_group_sizes.to(device),
+                 me_group_parents=sparse_data.me_group_parents.to(device),
+                 num_groups=sparse_data.num_groups,
+                 feat_to_parent=(
+                     sparse_data.feat_to_parent.to(device)
+                     if sparse_data.feat_to_parent is not None
+                     else None
+                 ),
+                 feat_to_me_group=(
+                     sparse_data.feat_to_me_group.to(device)
+                     if sparse_data.feat_to_me_group is not None
+                     else None
+                 ),
+             )
+         return device_cache[device]
+
      def modifier(activations: torch.Tensor) -> torch.Tensor:
-         result = activations.clone()
-         for root in roots:
-             root._apply_hierarchy(result, parent_active_mask=None)
-         return result
+         device = activations.device
+         cached = _get_sparse_for_device(device)
+         if activations.is_sparse:
+             return _apply_hierarchy_sparse_coo(activations, cached)
+         return _apply_hierarchy_sparse(activations, cached)

      return modifier

@@ -222,85 +874,6 @@ class HierarchyNode:
          if self.mutually_exclusive_children and len(self.children) < 2:
              raise ValueError("Need at least 2 children for mutual exclusion")

-     def _apply_hierarchy(
-         self,
-         activations: torch.Tensor,
-         parent_active_mask: torch.Tensor | None,
-     ) -> None:
-         """Recursively apply hierarchical constraints."""
-         batch_size = activations.shape[0]
-
-         # Determine which samples have this node active
-         if self.feature_index is not None:
-             is_active = activations[:, self.feature_index] > 0
-         else:
-             # Non-readout node: active if parent is active (or always if root)
-             is_active = (
-                 parent_active_mask
-                 if parent_active_mask is not None
-                 else torch.ones(batch_size, dtype=torch.bool, device=activations.device)
-             )
-
-         # Deactivate this node if parent is inactive
-         if parent_active_mask is not None and self.feature_index is not None:
-             activations[~parent_active_mask, self.feature_index] = 0
-             # Update is_active after deactivation
-             is_active = activations[:, self.feature_index] > 0
-
-         # Handle mutually exclusive children
-         if self.mutually_exclusive_children and len(self.children) >= 2:
-             self._enforce_mutual_exclusion(activations, is_active)
-
-         # Recursively process children
-         for child in self.children:
-             child._apply_hierarchy(activations, parent_active_mask=is_active)
-
-     def _enforce_mutual_exclusion(
-         self,
-         activations: torch.Tensor,
-         parent_active_mask: torch.Tensor,
-     ) -> None:
-         """Ensure at most one child is active per sample."""
-         batch_size = activations.shape[0]
-
-         # Get indices of children that have feature indices
-         child_indices = [
-             child.feature_index
-             for child in self.children
-             if child.feature_index is not None
-         ]
-
-         if len(child_indices) < 2:
-             return
-
-         # For each sample where parent is active, enforce mutual exclusion.
-         # Note: This loop is not vectorized because we need to randomly select
-         # which child to keep active per sample. Vectorizing would require either
-         # a deterministic selection (losing randomness) or complex gather/scatter
-         # operations that aren't more efficient for typical batch sizes.
-         for batch_idx in range(batch_size):
-             if not parent_active_mask[batch_idx]:
-                 continue
-
-             # Find which children are active
-             active_children = [
-                 i
-                 for i, feat_idx in enumerate(child_indices)
-                 if activations[batch_idx, feat_idx] > 0
-             ]
-
-             if len(active_children) <= 1:
-                 continue
-
-             # Randomly select one to keep active
-             random_idx = int(torch.randint(len(active_children), (1,)).item())
-             keep_idx = active_children[random_idx]
-
-             # Deactivate all others
-             for i, feat_idx in enumerate(child_indices):
-                 if i != keep_idx and i in active_children:
-                     activations[batch_idx, feat_idx] = 0
-
      def get_all_feature_indices(self) -> list[int]:
          """Get all feature indices in this subtree."""
          indices = []