sae-lens 6.28.2__py3-none-any.whl → 6.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -147,6 +147,14 @@ class _SparseHierarchyData:
147
147
  # Total number of ME groups
148
148
  num_groups: int
149
149
 
150
+ # Sparse COO support: Feature-to-parent mapping
151
+ # feat_to_parent[f] = parent feature index, or -1 if root/no parent
152
+ feat_to_parent: torch.Tensor | None = None # [num_features]
153
+
154
+ # Sparse COO support: Feature-to-ME-group mapping
155
+ # feat_to_me_group[f] = group index, or -1 if not in any ME group
156
+ feat_to_me_group: torch.Tensor | None = None # [num_features]
157
+
150
158
 
151
159
  def _build_sparse_hierarchy(
152
160
  roots: Sequence[HierarchyNode],
@@ -232,7 +240,11 @@ def _build_sparse_hierarchy(
232
240
  me_indices = torch.empty(0, dtype=torch.long)
233
241
 
234
242
  level_data.append(
235
- _LevelData(features=feats, parents=parents, me_group_indices=me_indices)
243
+ _LevelData(
244
+ features=feats,
245
+ parents=parents,
246
+ me_group_indices=me_indices,
247
+ )
236
248
  )
237
249
 
238
250
  # Build group siblings and parents tensors
@@ -254,12 +266,30 @@ def _build_sparse_hierarchy(
254
266
  me_group_parents = torch.empty(0, dtype=torch.long)
255
267
  num_groups = 0
256
268
 
269
+ # Build sparse COO support: feat_to_parent and feat_to_me_group mappings
270
+ # First determine num_features (max feature index + 1)
271
+ all_features = [f for f, _, _ in feature_info]
272
+ num_features = max(all_features) + 1 if all_features else 0
273
+
274
+ # Build feature-to-parent mapping
275
+ feat_to_parent = torch.full((num_features,), -1, dtype=torch.long)
276
+ for feat, parent, _ in feature_info:
277
+ feat_to_parent[feat] = parent
278
+
279
+ # Build feature-to-ME-group mapping
280
+ feat_to_me_group = torch.full((num_features,), -1, dtype=torch.long)
281
+ for g_idx, (_, _, siblings) in enumerate(me_groups):
282
+ for sib in siblings:
283
+ feat_to_me_group[sib] = g_idx
284
+
257
285
  return _SparseHierarchyData(
258
286
  level_data=level_data,
259
287
  me_group_siblings=me_group_siblings,
260
288
  me_group_sizes=me_group_sizes,
261
289
  me_group_parents=me_group_parents,
262
290
  num_groups=num_groups,
291
+ feat_to_parent=feat_to_parent,
292
+ feat_to_me_group=feat_to_me_group,
263
293
  )
264
294
 
265
295
 
@@ -396,8 +426,9 @@ def _apply_me_for_groups(
396
426
  # Random selection for winner
397
427
  # Use -1e9 instead of -inf to avoid creating a tensor (torch.tensor(-float("inf")))
398
428
  # on every call. Since random scores are in [0,1], -1e9 is effectively -inf for argmax.
429
+ _INACTIVE_SCORE = -1e9
399
430
  random_scores = torch.rand(num_conflicts, max_siblings, device=device)
400
- random_scores[~conflict_active] = -1e9
431
+ random_scores[~conflict_active] = _INACTIVE_SCORE
401
432
 
402
433
  winner_idx = random_scores.argmax(dim=1)
403
434
 
@@ -420,6 +451,275 @@ def _apply_me_for_groups(
420
451
  activations[deact_batch, deact_feat] = 0
421
452
 
422
453
 
454
+ # ---------------------------------------------------------------------------
455
+ # Sparse COO hierarchy implementation
456
+ # ---------------------------------------------------------------------------
457
+
458
+
459
def _apply_hierarchy_sparse_coo(
    sparse_tensor: torch.Tensor,
    sparse_data: _SparseHierarchyData,
) -> torch.Tensor:
    """
    Enforce hierarchy constraints on a sparse COO activation tensor.

    Sparse counterpart of ``_apply_hierarchy_sparse``. Levels are visited in
    the order stored in ``sparse_data.level_data``; for each level, entries
    whose parent is inactive are removed first, and then mutual exclusion is
    resolved for the ME groups registered on that level.

    Args:
        sparse_tensor: Sparse COO activations to constrain.
        sparse_data: Precomputed hierarchy data (per-level features and ME
            group indices plus the feature lookup tables).

    Returns:
        The input object unchanged when it stores no entries; otherwise a
        sparse COO tensor (coalesced on entry) with the constraints applied.
    """
    if not sparse_tensor._nnz():
        return sparse_tensor

    result = sparse_tensor.coalesce()
    for level in sparse_data.level_data:
        # Both flags depend only on the level description, not on `result`.
        has_features = level.features.numel() > 0
        has_me_groups = level.me_group_indices.numel() > 0
        if has_features:
            result = _apply_parent_deactivation_coo(result, level, sparse_data)
        if has_me_groups:
            result = _apply_me_coo(result, level.me_group_indices, sparse_data)
    return result
488
+
489
+
490
def _apply_parent_deactivation_coo(
    sparse_tensor: torch.Tensor,
    level_data: _LevelData,
    sparse_data: _SparseHierarchyData,
) -> torch.Tensor:
    """
    Remove children from a sparse COO tensor when their parent is inactive.

    An entry ``(b, f)`` is dropped when ``sparse_data.feat_to_parent[f]`` is a
    parent index ``p >= 0`` and ``(b, p)`` is not itself an entry of the
    tensor. Features mapped to ``-1`` and features whose index lies beyond
    ``feat_to_parent`` pass through unchanged. A parent index that lies outside
    the tensor's feature dimension can never be active, so its children are
    dropped.

    Parent activity is tested with ``searchsorted`` over the sorted keys
    ``b * num_features + f`` of the tensor's entries.

    Args:
        sparse_tensor: 2-D sparse COO tensor of shape ``[batch, num_features]``.
        level_data: Level being processed; only used to skip empty levels.
        sparse_data: Precomputed hierarchy data; ``feat_to_parent`` must be set.

    Returns:
        The input as-is when it is empty or the level has no features, the
        coalesced input when nothing is removed, otherwise a new sparse COO
        tensor holding only the kept entries.
    """
    if sparse_tensor._nnz() == 0 or level_data.features.numel() == 0:
        return sparse_tensor

    assert sparse_data.feat_to_parent is not None
    feat_to_parent = sparse_data.feat_to_parent

    sparse_tensor = sparse_tensor.coalesce()
    indices = sparse_tensor.indices()  # [2, nnz]
    values = sparse_tensor.values()  # [nnz]
    batch_indices = indices[0]
    feat_indices = indices[1]

    _, num_features = sparse_tensor.shape
    device = sparse_tensor.device
    nnz = indices.shape[1]

    # Look up each entry's parent; features beyond the hierarchy have none.
    in_hierarchy = feat_indices < feat_to_parent.numel()
    parent_of_feat = torch.full((nnz,), -1, dtype=torch.long, device=device)
    parent_of_feat[in_hierarchy] = feat_to_parent[feat_indices[in_hierarchy]]

    has_parent = parent_of_feat >= 0
    if not has_parent.any():
        return sparse_tensor

    child_entry_indices = torch.where(has_parent)[0]
    child_batch = batch_indices[child_entry_indices]
    child_parents = parent_of_feat[child_entry_indices]

    # Encode (batch, feature) pairs as single integer keys. nnz > 0 here, so
    # active_keys is non-empty and clamping the search positions is safe.
    active_keys, _ = (batch_indices * num_features + feat_indices).sort()
    parent_keys = child_batch * num_features + child_parents
    search_pos = torch.searchsorted(active_keys, parent_keys).clamp(
        max=active_keys.numel() - 1
    )
    # Without the range check, a parent index >= num_features would produce a
    # key that aliases a feature in the next batch row.
    parent_active = (active_keys[search_pos] == parent_keys) & (
        child_parents < num_features
    )

    # Keep an entry unless it is a child whose parent is inactive.
    keep_mask = torch.ones(nnz, dtype=torch.bool, device=device)
    keep_mask[child_entry_indices[~parent_active]] = False

    if keep_mask.all():
        return sparse_tensor

    return torch.sparse_coo_tensor(
        indices[:, keep_mask],
        values[keep_mask],
        sparse_tensor.shape,
        device=device,
        dtype=sparse_tensor.dtype,
    )
564
+
565
+
566
def _apply_me_coo(
    sparse_tensor: torch.Tensor,
    group_indices: torch.Tensor,
    sparse_data: _SparseHierarchyData,
) -> torch.Tensor:
    """
    Apply mutual exclusion to a sparse COO tensor.

    For each ME group listed in ``group_indices``, every batch row keeps at
    most one of the group's active siblings. When two or more siblings of a
    group are active in the same row, one is chosen uniformly at random (using
    the global torch RNG) and the rest are removed. A group is only enforced
    in rows where its parent is active, or where it has no parent (``-1``). A
    parent index outside the tensor's feature dimension counts as inactive.

    Args:
        sparse_tensor: 2-D sparse COO tensor of shape ``[batch, num_features]``.
        group_indices: 1-D tensor of ME group ids to process.
        sparse_data: Precomputed hierarchy data; ``feat_to_me_group`` must be
            set, and ``me_group_parents`` / ``num_groups`` are read.

    Returns:
        The input as-is when it is empty or ``group_indices`` is empty, the
        coalesced input when there is nothing to resolve, otherwise a new
        sparse COO tensor without the losing entries.
    """
    if sparse_tensor._nnz() == 0 or group_indices.numel() == 0:
        return sparse_tensor

    assert sparse_data.feat_to_me_group is not None
    feat_to_me_group = sparse_data.feat_to_me_group

    sparse_tensor = sparse_tensor.coalesce()
    indices = sparse_tensor.indices()  # [2, nnz]
    values = sparse_tensor.values()  # [nnz]
    batch_indices = indices[0]
    feat_indices = indices[1]

    _, num_features = sparse_tensor.shape
    device = sparse_tensor.device
    nnz = indices.shape[1]

    # Look up each entry's ME group; features beyond the hierarchy have none.
    in_hierarchy = feat_indices < feat_to_me_group.numel()
    me_group_of_feat = torch.full((nnz,), -1, dtype=torch.long, device=device)
    me_group_of_feat[in_hierarchy] = feat_to_me_group[feat_indices[in_hierarchy]]

    in_relevant_group = torch.isin(me_group_of_feat, group_indices)
    if not in_relevant_group.any():
        return sparse_tensor

    me_entry_indices = torch.where(in_relevant_group)[0]
    me_batch = batch_indices[me_entry_indices]
    me_group = me_group_of_feat[me_entry_indices]

    # Only siblings whose group parent is active (or absent) compete.
    group_parents = sparse_data.me_group_parents[me_group]
    has_parent = group_parents >= 0
    if has_parent.any():
        # nnz > 0 here, so active_keys is non-empty and clamping is safe.
        active_keys, _ = (batch_indices * num_features + feat_indices).sort()
        checked_parents = group_parents[has_parent]
        parent_keys = me_batch[has_parent] * num_features + checked_parents
        search_pos = torch.searchsorted(active_keys, parent_keys).clamp(
            max=active_keys.numel() - 1
        )
        # The range check stops an out-of-range parent key from aliasing a
        # feature in the next batch row.
        parent_found = (active_keys[search_pos] == parent_keys) & (
            checked_parents < num_features
        )

        parent_active = torch.ones(
            me_entry_indices.numel(), dtype=torch.bool, device=device
        )
        parent_active[has_parent] = parent_found

        me_entry_indices = me_entry_indices[parent_active]
        me_batch = me_batch[parent_active]
        me_group = me_group[parent_active]

    if me_entry_indices.numel() == 0:
        return sparse_tensor

    # Give each (batch, group) pair a dense integer id (`inverse`) and count
    # how many active siblings share it.
    batch_group_keys = me_batch * sparse_data.num_groups + me_group
    _, inverse, counts = torch.unique(
        batch_group_keys, return_inverse=True, return_counts=True
    )

    entry_in_conflict = counts[inverse] > 1
    if not entry_in_conflict.any():
        return sparse_tensor

    conflict_entry_indices = me_entry_indices[entry_in_conflict]
    conflict_pair = inverse[entry_in_conflict]

    # Pick a uniform winner per pair. First shuffle the entries, then
    # stable-sort them by pair id, so each pair's entries stay in random order
    # and the first of each run wins. Integer keys avoid the precision loss
    # of packing (pair, score) into a single float32 key.
    shuffle = torch.randperm(conflict_pair.numel(), device=device)
    _, by_pair = torch.sort(conflict_pair[shuffle], stable=True)
    order = shuffle[by_pair]
    sorted_pair = conflict_pair[order]
    run_start = torch.ones_like(sorted_pair, dtype=torch.bool)
    run_start[1:] = sorted_pair[1:] != sorted_pair[:-1]

    is_loser = torch.ones(conflict_pair.numel(), dtype=torch.bool, device=device)
    is_loser[order[run_start]] = False

    keep_mask = torch.ones(nnz, dtype=torch.bool, device=device)
    keep_mask[conflict_entry_indices[is_loser]] = False

    if keep_mask.all():
        return sparse_tensor

    return torch.sparse_coo_tensor(
        indices[:, keep_mask],
        values[keep_mask],
        sparse_tensor.shape,
        device=device,
        dtype=sparse_tensor.dtype,
    )
721
+
722
+
423
723
  @torch.no_grad()
424
724
  def hierarchy_modifier(
425
725
  roots: Sequence[HierarchyNode] | HierarchyNode,
@@ -475,12 +775,24 @@ def hierarchy_modifier(
475
775
  me_group_sizes=sparse_data.me_group_sizes.to(device),
476
776
  me_group_parents=sparse_data.me_group_parents.to(device),
477
777
  num_groups=sparse_data.num_groups,
778
+ feat_to_parent=(
779
+ sparse_data.feat_to_parent.to(device)
780
+ if sparse_data.feat_to_parent is not None
781
+ else None
782
+ ),
783
+ feat_to_me_group=(
784
+ sparse_data.feat_to_me_group.to(device)
785
+ if sparse_data.feat_to_me_group is not None
786
+ else None
787
+ ),
478
788
  )
479
789
  return device_cache[device]
480
790
 
481
791
  def modifier(activations: torch.Tensor) -> torch.Tensor:
482
792
  device = activations.device
483
793
  cached = _get_sparse_for_device(device)
794
+ if activations.is_sparse:
795
+ return _apply_hierarchy_sparse_coo(activations, cached)
484
796
  return _apply_hierarchy_sparse(activations, cached)
485
797
 
486
798
  return modifier
@@ -23,6 +23,8 @@ def train_toy_sae(
23
23
  device: str | torch.device = "cpu",
24
24
  n_snapshots: int = 0,
25
25
  snapshot_fn: Callable[[SAETrainer[Any, Any]], None] | None = None,
26
+ autocast_sae: bool = False,
27
+ autocast_data: bool = False,
26
28
  ) -> None:
27
29
  """
28
30
  Train an SAE on synthetic activations from a feature dictionary.
@@ -46,6 +48,8 @@ def train_toy_sae(
46
48
  snapshot_fn: Callback function called at each snapshot point. Receives
47
49
  the SAETrainer instance, allowing access to the SAE, training step,
48
50
  and other training state. Required if n_snapshots > 0.
51
+ autocast_sae: Whether to autocast the SAE to bfloat16. Only recommend for large SAEs on CUDA
52
+ autocast_data: Whether to autocast the activations generator and feature dictionary to bfloat16. Only recommend for large data on CUDA.
49
53
  """
50
54
 
51
55
  device_str = str(device) if isinstance(device, torch.device) else device
@@ -55,6 +59,7 @@ def train_toy_sae(
55
59
  feature_dict=feature_dict,
56
60
  activations_generator=activations_generator,
57
61
  batch_size=batch_size,
62
+ autocast=autocast_data,
58
63
  )
59
64
 
60
65
  # Create trainer config
@@ -64,7 +69,7 @@ def train_toy_sae(
64
69
  save_final_checkpoint=False,
65
70
  total_training_samples=training_samples,
66
71
  device=device_str,
67
- autocast=False,
72
+ autocast=autocast_sae,
68
73
  lr=lr,
69
74
  lr_end=lr,
70
75
  lr_scheduler_name="constant",
@@ -119,6 +124,7 @@ class SyntheticActivationIterator(Iterator[torch.Tensor]):
119
124
  feature_dict: FeatureDictionary,
120
125
  activations_generator: ActivationGenerator,
121
126
  batch_size: int,
127
+ autocast: bool = False,
122
128
  ):
123
129
  """
124
130
  Create a new SyntheticActivationIterator.
@@ -127,16 +133,23 @@ class SyntheticActivationIterator(Iterator[torch.Tensor]):
127
133
  feature_dict: The feature dictionary to use for generating hidden activations
128
134
  activations_generator: Generator that produces feature activations
129
135
  batch_size: Number of samples per batch
136
+ autocast: Whether to autocast the activations generator and feature dictionary to bfloat16.
130
137
  """
131
138
  self.feature_dict = feature_dict
132
139
  self.activations_generator = activations_generator
133
140
  self.batch_size = batch_size
141
+ self.autocast = autocast
134
142
 
135
143
  @torch.no_grad()
136
144
  def next_batch(self) -> torch.Tensor:
137
145
  """Generate the next batch of hidden activations."""
138
- features = self.activations_generator(self.batch_size)
139
- return self.feature_dict(features)
146
+ with torch.autocast(
147
+ device_type=self.feature_dict.feature_vectors.device.type,
148
+ dtype=torch.bfloat16,
149
+ enabled=self.autocast,
150
+ ):
151
+ features = self.activations_generator(self.batch_size)
152
+ return self.feature_dict(features)
140
153
 
141
154
  def __iter__(self) -> "SyntheticActivationIterator":
142
155
  return self
@@ -28,7 +28,9 @@ class ActivationScaler:
28
28
  ) -> float:
29
29
  norms_per_batch: list[float] = []
30
30
  for _ in tqdm(
31
- range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
31
+ range(n_batches_for_norm_estimate),
32
+ desc="Estimating norm scaling factor",
33
+ leave=False,
32
34
  ):
33
35
  acts = next(data_provider)
34
36
  norms_per_batch.append(acts.norm(dim=-1).mean().item())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.28.2
3
+ Version: 6.32.1
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -27,7 +27,7 @@ Requires-Dist: pyyaml (>=6.0.1,<7.0.0)
27
27
  Requires-Dist: safetensors (>=0.4.2,<1.0.0)
28
28
  Requires-Dist: simple-parsing (>=0.1.6,<0.2.0)
29
29
  Requires-Dist: tenacity (>=9.0.0)
30
- Requires-Dist: transformer-lens (>=2.16.1,<3.0.0)
30
+ Requires-Dist: transformer-lens (>=2.16.1)
31
31
  Requires-Dist: transformers (>=4.38.1,<5.0.0)
32
32
  Requires-Dist: typing-extensions (>=4.10.0,<5.0.0)
33
33
  Project-URL: Homepage, https://decoderesearch.github.io/SAELens
@@ -1,18 +1,20 @@
1
- sae_lens/__init__.py,sha256=B9tY0Jt21pOHmSQrQLpMxQHyUAdLHIZpVP6pg3O0dfQ,4788
2
- sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
1
+ sae_lens/__init__.py,sha256=Y_TVKGehpnTvQw8tvIn0fjo8uAw-XAYi7carZS_cRjQ,5168
2
+ sae_lens/analysis/__init__.py,sha256=FZExlMviNwWR7OGUSGRbd0l-yUDGSp80gglI_ivILrY,412
3
+ sae_lens/analysis/compat.py,sha256=cgE3nhFcJTcuhppxbL71VanJS7YqVEOefuneB5eOaPw,538
4
+ sae_lens/analysis/hooked_sae_transformer.py,sha256=LpnjxSAcItqqXA4SJyZuxY4Ki0UOuWV683wg9laYAsY,14050
4
5
  sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
6
+ sae_lens/analysis/sae_transformer_bridge.py,sha256=xpJRRcB0g47EOQcmNCwMyrJJsbqMsGxVViDrV6C3upU,14916
5
7
  sae_lens/cache_activations_runner.py,sha256=TjqNWIc46Nw09jHWFjzQzgzG5wdu_87Ahe-iFjI5_0Q,13117
6
- sae_lens/config.py,sha256=sseYcRMsAyopj8FICup1RGTXjFxzAithZ2OH7OpQV3Y,30839
8
+ sae_lens/config.py,sha256=V0BXV8rvpbm5YuVukow9FURPpdyE4HSflbdymAo0Ycg,31205
7
9
  sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
8
- sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
10
+ sae_lens/evals.py,sha256=nEZpUfEUN-plw6Mj9GEqm-cU_tb1qrIF9km9ktQ0vVU,39624
9
11
  sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
10
12
  sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
11
13
  sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- sae_lens/loading/pretrained_sae_loaders.py,sha256=hHMlew1u6zVlbzvS9S_SfUPnAG0_OAjjIcjoUTIUZrU,63657
14
+ sae_lens/loading/pretrained_sae_loaders.py,sha256=kshvA0NivOc7B3sL19lHr_zrC_DDfW2T6YWb5j0hgAk,63930
13
15
  sae_lens/loading/pretrained_saes_directory.py,sha256=1at_aQbD8WFywchQCKuwfP-yvCq_Z2aUYrpKDnSN5Nc,4283
14
16
  sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
15
- sae_lens/pretrained_saes.yaml,sha256=Nq43dTcFvDDONTuJ9Me_HQ5nHqr9BdbP5-ZJGXj0TAQ,1509932
17
+ sae_lens/pretrained_saes.yaml,sha256=IVBLLR8_XNllJ1O-kVv9ED4u0u44Yn8UOL9R-f8Idp4,1511936
16
18
  sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
19
  sae_lens/saes/__init__.py,sha256=SBqPaP6Gl5uPFwHlumAZATC4Wd26xKIYLAAAo4MSa5Q,2200
18
20
  sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
@@ -22,22 +24,22 @@ sae_lens/saes/matching_pursuit_sae.py,sha256=08_G9p1YMLnE5qZVCPp6gll-iG6nHRbMMAS
22
24
  sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
23
25
  sae_lens/saes/sae.py,sha256=xRmgiLuaFlDCv8SyLbL-5TwdrWHpNLqSGe8mC1L6WcI,40942
24
26
  sae_lens/saes/standard_sae.py,sha256=_hldNZkFPAf9VGrxouR1-tN8T2OEk8IkWBcXoatrC1o,5749
25
- sae_lens/saes/temporal_sae.py,sha256=83Ap4mYGfdN3sKdPF8nKjhdXph3-7E2QuLobqJ_YuoM,13273
27
+ sae_lens/saes/temporal_sae.py,sha256=S44sPddVj2xujA02CC8gT1tG0in7c_CSAhspu9FHbaA,13273
26
28
  sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
27
29
  sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
28
- sae_lens/synthetic/__init__.py,sha256=FGUasB6fLPXRFCcrtKfL7vCKDOWebZ5Rx5F9QNJZklI,2875
29
- sae_lens/synthetic/activation_generator.py,sha256=JEN7mEgdGDuXr0ArTwUsSdSVUAfvheT_1Eew2ojbA-g,7659
30
- sae_lens/synthetic/correlation.py,sha256=odr-S5h6c2U-bepwrAQeMfV1iBF_cnnQzqw7zapEXZ4,6056
30
+ sae_lens/synthetic/__init__.py,sha256=MtTnGkTfHV2WjkIgs7zZyx10EK9U5fjOHXy69Aq3uKw,3095
31
+ sae_lens/synthetic/activation_generator.py,sha256=8L9nwC4jFRv_wg3QN-n1sFwX8w1NqwJMysWaJ41lLlY,15197
32
+ sae_lens/synthetic/correlation.py,sha256=tMTLo9fBfDpeXwqhyUgFqnTipj9x2W0t4oEtNxB7AG0,13256
31
33
  sae_lens/synthetic/evals.py,sha256=Nhi314ZnRgLfhBj-3tm_zzI-pGyFTcwllDXbIpPFXeU,4584
32
- sae_lens/synthetic/feature_dictionary.py,sha256=ysn0ihE3JgVlCLUZMb127WYZqbz4kMp9BGHfCZqERBg,6487
34
+ sae_lens/synthetic/feature_dictionary.py,sha256=Nd4xjSTxKMnKilZ3uYi8Gv5SS5D4bv4wHiSL1uGB69E,6933
33
35
  sae_lens/synthetic/firing_probabilities.py,sha256=yclz1pWl5gE1r8LAxFvzQS88Lxwk5-3r8BCX9HLVejA,3370
34
- sae_lens/synthetic/hierarchy.py,sha256=j9-6K7xq6zQS9N8bB5nK_-EbuzAZsY5Z5AfUK-qlB5M,22138
36
+ sae_lens/synthetic/hierarchy.py,sha256=nm7nwnTswktVJeKUsRZ0hLOdXcFWGbxnA1b6lefHm-4,33592
35
37
  sae_lens/synthetic/initialization.py,sha256=orMGW-786wRDHIS2W7bEH0HmlVFQ4g2z4bnnwdv5w4s,1386
36
38
  sae_lens/synthetic/plotting.py,sha256=5lFrej1QOkGAcImFNo5-o-8mI_rUVqvEI57KzUQPPtQ,8208
37
- sae_lens/synthetic/training.py,sha256=Bg6NYxdzifq_8g-dJQSZ_z_TXDdGRtEi7tqNDb-gCVc,4986
39
+ sae_lens/synthetic/training.py,sha256=fHcX2cZ6nDupr71GX0Gk17f1NvQ0SKIVXIA6IuAb2dw,5692
38
40
  sae_lens/tokenization_and_batching.py,sha256=uoHtAs9z3XqG0Fh-iQVYVlrbyB_E3kFFhrKU30BosCo,5438
39
41
  sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
- sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
42
+ sae_lens/training/activation_scaler.py,sha256=SJZzIMX1TGdeN_wT_wqgx2ij6f4p5Dm5lWH6DGNSt5g,2011
41
43
  sae_lens/training/activations_store.py,sha256=kp4-6R4rTJUSt-g-Ifg5B1h7iIe7jZj-XQSKDvDpQMI,32187
42
44
  sae_lens/training/mixing_buffer.py,sha256=1Z-S2CcQXMWGxRZJFnXeZFxbZcALkO_fP6VO37XdJQQ,2519
43
45
  sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
@@ -46,7 +48,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
46
48
  sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
47
49
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
48
50
  sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
49
- sae_lens-6.28.2.dist-info/METADATA,sha256=i_kbAa64It0NRDrnSlmwNa8qgqOyEMntT_Ifxdx4Q90,6573
50
- sae_lens-6.28.2.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
51
- sae_lens-6.28.2.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
52
- sae_lens-6.28.2.dist-info/RECORD,,
51
+ sae_lens-6.32.1.dist-info/METADATA,sha256=TcO6hFEXKdbLp32UTiVluHcMXFetfYJDqTHNCsx9PRw,6566
52
+ sae_lens-6.32.1.dist-info/WHEEL,sha256=3ny-bZhpXrU6vSQ1UPG34FoxZBp3lVcvK0LkgUz6VLk,88
53
+ sae_lens-6.32.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
54
+ sae_lens-6.32.1.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 2.2.1
2
+ Generator: poetry-core 2.3.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any