sae-lens 6.29.0__tar.gz → 6.30.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {sae_lens-6.29.0 → sae_lens-6.30.1}/PKG-INFO +1 -1
  2. {sae_lens-6.29.0 → sae_lens-6.30.1}/pyproject.toml +1 -1
  3. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/loading/pretrained_sae_loaders.py +9 -3
  5. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/pretrained_saes.yaml +36 -0
  6. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/activation_generator.py +110 -36
  7. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/feature_dictionary.py +10 -1
  8. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/hierarchy.py +314 -2
  9. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/training.py +16 -3
  10. {sae_lens-6.29.0 → sae_lens-6.30.1}/LICENSE +0 -0
  11. {sae_lens-6.29.0 → sae_lens-6.30.1}/README.md +0 -0
  12. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/analysis/__init__.py +0 -0
  13. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  14. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  15. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/cache_activations_runner.py +0 -0
  16. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/config.py +0 -0
  17. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/constants.py +0 -0
  18. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/evals.py +0 -0
  19. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/llm_sae_training_runner.py +0 -0
  20. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/load_model.py +0 -0
  21. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/loading/__init__.py +0 -0
  22. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  23. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/pretokenize_runner.py +0 -0
  24. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/registry.py +0 -0
  25. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/__init__.py +0 -0
  26. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/batchtopk_sae.py +0 -0
  27. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/gated_sae.py +0 -0
  28. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/jumprelu_sae.py +0 -0
  29. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/matching_pursuit_sae.py +0 -0
  30. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
  31. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/sae.py +0 -0
  32. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/standard_sae.py +0 -0
  33. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/temporal_sae.py +0 -0
  34. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/topk_sae.py +0 -0
  35. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/transcoder.py +0 -0
  36. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/__init__.py +0 -0
  37. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/correlation.py +0 -0
  38. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/evals.py +0 -0
  39. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/firing_probabilities.py +0 -0
  40. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/initialization.py +0 -0
  41. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/plotting.py +0 -0
  42. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/tokenization_and_batching.py +0 -0
  43. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/__init__.py +0 -0
  44. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/activation_scaler.py +0 -0
  45. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/activations_store.py +0 -0
  46. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/mixing_buffer.py +0 -0
  47. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/optim.py +0 -0
  48. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/sae_trainer.py +0 -0
  49. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/types.py +0 -0
  50. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  51. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/tutorial/tsea.py +0 -0
  52. {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/util.py +0 -0
{sae_lens-6.29.0 → sae_lens-6.30.1}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.29.0
+Version: 6.30.1
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE

{sae_lens-6.29.0 → sae_lens-6.30.1}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.29.0"
+version = "6.30.1"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"

{sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/__init__.py
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.29.0"
+__version__ = "6.30.1"

 import logging


{sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/loading/pretrained_sae_loaders.py
@@ -575,6 +575,8 @@ def _infer_gemma_3_raw_cfg_dict(repo_id: str, folder_name: str) -> dict[str, Any]:
         "model_name": model_name,
         "hf_hook_point_in": hf_hook_point_in,
     }
+    if "transcoder" in folder_name or "clt" in folder_name:
+        cfg["affine_connection"] = "affine" in folder_name
     if hf_hook_point_out is not None:
         cfg["hf_hook_point_out"] = hf_hook_point_out

@@ -614,11 +616,11 @@ def get_gemma_3_config_from_hf(
     if "resid_post" in folder_name:
         hook_name = f"blocks.{layer}.hook_resid_post"
     elif "attn_out" in folder_name:
-        hook_name = f"blocks.{layer}.hook_attn_out"
+        hook_name = f"blocks.{layer}.attn.hook_z"
     elif "mlp_out" in folder_name:
         hook_name = f"blocks.{layer}.hook_mlp_out"
     elif "transcoder" in folder_name or "clt" in folder_name:
-        hook_name = f"blocks.{layer}.ln2.hook_normalized"
+        hook_name = f"blocks.{layer}.hook_mlp_in"
         hook_name_out = f"blocks.{layer}.hook_mlp_out"
     else:
         raise ValueError("Hook name not found in folder_name.")
@@ -643,7 +645,11 @@

     architecture = "jumprelu"
     if "transcoder" in folder_name or "clt" in folder_name:
-        architecture = "jumprelu_skip_transcoder"
+        architecture = (
+            "jumprelu_skip_transcoder"
+            if raw_cfg_dict.get("affine_connection", False)
+            else "jumprelu_transcoder"
+        )
     d_out = shapes_dict["w_dec"][-1]

     cfg = {
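
Note: the net effect of the loader changes above is that Gemma 3 attn_out SAEs now attach to attn.hook_z, transcoders read from hook_mlp_in, and the architecture is chosen from the folder name. A minimal sketch of that selection logic as reconstructed from this diff (folder names are illustrative):

    def infer_gemma_3_architecture(folder_name: str) -> str:
        # Mirrors the diff: transcoder/CLT folders become JumpReLU transcoders,
        # and only "affine" variants keep the skip connection.
        if "transcoder" in folder_name or "clt" in folder_name:
            affine_connection = "affine" in folder_name
            return (
                "jumprelu_skip_transcoder"
                if affine_connection
                else "jumprelu_transcoder"
            )
        return "jumprelu"

    assert infer_gemma_3_architecture("affine_transcoder/layer_9") == "jumprelu_skip_transcoder"
    assert infer_gemma_3_architecture("transcoder/layer_9") == "jumprelu_transcoder"
    assert infer_gemma_3_architecture("resid_post/layer_9") == "jumprelu"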

{sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/pretrained_saes.yaml
@@ -4148,6 +4148,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_17_width_16k_l0_medium
     path: resid_post/layer_17_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-16k
   - id: layer_17_width_16k_l0_small
     path: resid_post/layer_17_width_16k_l0_small
     l0: 20
@@ -4166,6 +4167,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_17_width_262k_l0_medium
     path: resid_post/layer_17_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-262k
   - id: layer_17_width_262k_l0_medium_seed_1
     path: resid_post/layer_17_width_262k_l0_medium_seed_1
     l0: 60
@@ -4178,6 +4180,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_17_width_65k_l0_medium
     path: resid_post/layer_17_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-65k
   - id: layer_17_width_65k_l0_small
     path: resid_post/layer_17_width_65k_l0_small
     l0: 20
@@ -4187,6 +4190,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_22_width_16k_l0_medium
     path: resid_post/layer_22_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-16k
   - id: layer_22_width_16k_l0_small
     path: resid_post/layer_22_width_16k_l0_small
     l0: 20
@@ -4205,6 +4209,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_22_width_262k_l0_medium
     path: resid_post/layer_22_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-262k
   - id: layer_22_width_262k_l0_medium_seed_1
     path: resid_post/layer_22_width_262k_l0_medium_seed_1
     l0: 60
@@ -4217,6 +4222,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_22_width_65k_l0_medium
     path: resid_post/layer_22_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-65k
   - id: layer_22_width_65k_l0_small
     path: resid_post/layer_22_width_65k_l0_small
     l0: 20
@@ -4226,6 +4232,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_29_width_16k_l0_medium
     path: resid_post/layer_29_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-16k
   - id: layer_29_width_16k_l0_small
     path: resid_post/layer_29_width_16k_l0_small
     l0: 20
@@ -4244,6 +4251,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_29_width_262k_l0_medium
     path: resid_post/layer_29_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-262k
   - id: layer_29_width_262k_l0_medium_seed_1
     path: resid_post/layer_29_width_262k_l0_medium_seed_1
     l0: 60
@@ -4256,6 +4264,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_29_width_65k_l0_medium
     path: resid_post/layer_29_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-65k
   - id: layer_29_width_65k_l0_small
     path: resid_post/layer_29_width_65k_l0_small
     l0: 20
@@ -4265,6 +4274,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_9_width_16k_l0_medium
     path: resid_post/layer_9_width_16k_l0_medium
     l0: 53
+    neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-16k
   - id: layer_9_width_16k_l0_small
     path: resid_post/layer_9_width_16k_l0_small
     l0: 17
@@ -4283,6 +4293,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_9_width_262k_l0_medium
     path: resid_post/layer_9_width_262k_l0_medium
     l0: 53
+    neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-262k
   - id: layer_9_width_262k_l0_medium_seed_1
     path: resid_post/layer_9_width_262k_l0_medium_seed_1
     l0: 53
@@ -4295,6 +4306,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_9_width_65k_l0_medium
     path: resid_post/layer_9_width_65k_l0_medium
     l0: 53
+    neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-65k
   - id: layer_9_width_65k_l0_small
     path: resid_post/layer_9_width_65k_l0_small
     l0: 17
@@ -14491,6 +14503,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_12_width_16k_l0_medium
     path: resid_post/layer_12_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-16k
   - id: layer_12_width_16k_l0_small
     path: resid_post/layer_12_width_16k_l0_small
     l0: 20
@@ -14509,6 +14522,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_12_width_262k_l0_medium
     path: resid_post/layer_12_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-262k
   - id: layer_12_width_262k_l0_medium_seed_1
     path: resid_post/layer_12_width_262k_l0_medium_seed_1
     l0: 60
@@ -14521,6 +14535,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_12_width_65k_l0_medium
     path: resid_post/layer_12_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-65k
   - id: layer_12_width_65k_l0_small
     path: resid_post/layer_12_width_65k_l0_small
     l0: 20
@@ -14530,6 +14545,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_15_width_16k_l0_medium
     path: resid_post/layer_15_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-16k
   - id: layer_15_width_16k_l0_small
     path: resid_post/layer_15_width_16k_l0_small
     l0: 20
@@ -14548,6 +14564,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_15_width_262k_l0_medium
     path: resid_post/layer_15_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-262k
   - id: layer_15_width_262k_l0_medium_seed_1
     path: resid_post/layer_15_width_262k_l0_medium_seed_1
     l0: 60
@@ -14560,6 +14577,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_15_width_65k_l0_medium
     path: resid_post/layer_15_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-65k
   - id: layer_15_width_65k_l0_small
     path: resid_post/layer_15_width_65k_l0_small
     l0: 20
@@ -14569,6 +14587,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_5_width_16k_l0_medium
     path: resid_post/layer_5_width_16k_l0_medium
     l0: 55
+    neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-16k
   - id: layer_5_width_16k_l0_small
     path: resid_post/layer_5_width_16k_l0_small
     l0: 18
@@ -14587,6 +14606,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_5_width_262k_l0_medium
     path: resid_post/layer_5_width_262k_l0_medium
     l0: 55
+    neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-262k
   - id: layer_5_width_262k_l0_medium_seed_1
     path: resid_post/layer_5_width_262k_l0_medium_seed_1
     l0: 55
@@ -14599,6 +14619,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_5_width_65k_l0_medium
     path: resid_post/layer_5_width_65k_l0_medium
     l0: 55
+    neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-65k
   - id: layer_5_width_65k_l0_small
     path: resid_post/layer_5_width_65k_l0_small
     l0: 18
@@ -14608,6 +14629,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_9_width_16k_l0_medium
     path: resid_post/layer_9_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-16k
   - id: layer_9_width_16k_l0_small
     path: resid_post/layer_9_width_16k_l0_small
     l0: 20
@@ -14626,6 +14648,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_9_width_262k_l0_medium
     path: resid_post/layer_9_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-262k
   - id: layer_9_width_262k_l0_medium_seed_1
     path: resid_post/layer_9_width_262k_l0_medium_seed_1
     l0: 60
@@ -14638,6 +14661,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_9_width_65k_l0_medium
     path: resid_post/layer_9_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-65k
   - id: layer_9_width_65k_l0_small
     path: resid_post/layer_9_width_65k_l0_small
     l0: 20
@@ -18727,6 +18751,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_13_width_16k_l0_medium
     path: resid_post/layer_13_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-16k
   - id: layer_13_width_16k_l0_small
     path: resid_post/layer_13_width_16k_l0_small
     l0: 20
@@ -18745,6 +18770,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_13_width_262k_l0_medium
     path: resid_post/layer_13_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-262k
   - id: layer_13_width_262k_l0_medium_seed_1
     path: resid_post/layer_13_width_262k_l0_medium_seed_1
     l0: 60
@@ -18757,6 +18783,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_13_width_65k_l0_medium
     path: resid_post/layer_13_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-65k
   - id: layer_13_width_65k_l0_small
     path: resid_post/layer_13_width_65k_l0_small
     l0: 20
@@ -18766,6 +18793,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_17_width_16k_l0_medium
     path: resid_post/layer_17_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-16k
   - id: layer_17_width_16k_l0_small
     path: resid_post/layer_17_width_16k_l0_small
     l0: 20
@@ -18784,6 +18812,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_17_width_262k_l0_medium
     path: resid_post/layer_17_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-262k
   - id: layer_17_width_262k_l0_medium_seed_1
     path: resid_post/layer_17_width_262k_l0_medium_seed_1
     l0: 60
@@ -18796,6 +18825,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_17_width_65k_l0_medium
     path: resid_post/layer_17_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-65k
   - id: layer_17_width_65k_l0_small
     path: resid_post/layer_17_width_65k_l0_small
     l0: 20
@@ -18805,6 +18835,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_22_width_16k_l0_medium
     path: resid_post/layer_22_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-16k
   - id: layer_22_width_16k_l0_small
     path: resid_post/layer_22_width_16k_l0_small
     l0: 20
@@ -18823,6 +18854,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_22_width_262k_l0_medium
     path: resid_post/layer_22_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-262k
   - id: layer_22_width_262k_l0_medium_seed_1
     path: resid_post/layer_22_width_262k_l0_medium_seed_1
     l0: 60
@@ -18835,6 +18867,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_22_width_65k_l0_medium
     path: resid_post/layer_22_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-65k
   - id: layer_22_width_65k_l0_small
     path: resid_post/layer_22_width_65k_l0_small
     l0: 20
@@ -18844,6 +18877,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_7_width_16k_l0_medium
     path: resid_post/layer_7_width_16k_l0_medium
     l0: 54
+    neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-16k
   - id: layer_7_width_16k_l0_small
     path: resid_post/layer_7_width_16k_l0_small
     l0: 18
@@ -18862,6 +18896,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_7_width_262k_l0_medium
     path: resid_post/layer_7_width_262k_l0_medium
     l0: 54
+    neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-262k
   - id: layer_7_width_262k_l0_medium_seed_1
     path: resid_post/layer_7_width_262k_l0_medium_seed_1
     l0: 54
@@ -18874,6 +18909,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_7_width_65k_l0_medium
     path: resid_post/layer_7_width_65k_l0_medium
     l0: 54
+    neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-65k
   - id: layer_7_width_65k_l0_small
     path: resid_post/layer_7_width_65k_l0_small
     l0: 18
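
Note: each added neuronpedia: entry links an SAE to its Neuronpedia dashboard in model/source form. The SAEs themselves load by release and id; a sketch, assuming the usual SAE.from_pretrained entry point and that the metadata exposes the Neuronpedia id:

    from sae_lens import SAE

    # Load one of the SAEs that gained a Neuronpedia link in this release.
    sae = SAE.from_pretrained(
        release="gemma-scope-2-4b-it-res",
        sae_id="layer_17_width_16k_l0_medium",
        device="cpu",
    )
    # Assumed attribute path; expect "gemma-3-4b-it/17-gemmascope-2-res-16k".
    print(sae.cfg.metadata.neuronpedia_id)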

{sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/activation_generator.py
@@ -2,12 +2,12 @@
 Functions for generating synthetic feature activations.
 """

+import math
 from collections.abc import Callable, Sequence

 import torch
-from scipy.stats import norm
 from torch import nn
-from torch.distributions import LowRankMultivariateNormal, MultivariateNormal
+from torch.distributions import MultivariateNormal

 from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
 from sae_lens.util import str_to_dtype
@@ -34,6 +34,7 @@ class ActivationGenerator(nn.Module):
     correlation_matrix: torch.Tensor | None
     low_rank_correlation: tuple[torch.Tensor, torch.Tensor] | None
     correlation_thresholds: torch.Tensor | None
+    use_sparse_tensors: bool

     def __init__(
         self,
@@ -45,7 +46,34 @@ class ActivationGenerator(nn.Module):
         correlation_matrix: CorrelationMatrixInput | None = None,
         device: torch.device | str = "cpu",
         dtype: torch.dtype | str = "float32",
+        use_sparse_tensors: bool = False,
     ):
+        """
+        Create a new ActivationGenerator.
+
+        Args:
+            num_features: Number of features to generate activations for.
+            firing_probabilities: Probability of each feature firing. Can be a single
+                float (applied to all features) or a tensor of shape (num_features,).
+            std_firing_magnitudes: Standard deviation of firing magnitudes. Can be a
+                single float or a tensor of shape (num_features,). Defaults to 0.0
+                (deterministic magnitudes).
+            mean_firing_magnitudes: Mean firing magnitude when a feature fires. Can be
+                a single float or a tensor of shape (num_features,). Defaults to 1.0.
+            modify_activations: Optional function(s) to modify activations after
+                generation. Can be a single callable, a sequence of callables (applied
+                in order), or None. Useful for applying hierarchy constraints.
+            correlation_matrix: Optional correlation structure between features. Can be:
+
+                - A full correlation matrix tensor of shape (num_features, num_features)
+                - A LowRankCorrelationMatrix for memory-efficient large-scale correlations
+                - A tuple of (factor, diag) tensors representing low-rank structure
+
+            device: Device to place tensors on. Defaults to "cpu".
+            dtype: Data type for tensors. Defaults to "float32".
+            use_sparse_tensors: If True, return sparse COO tensors from sample().
+                Only recommended when using massive numbers of features. Defaults to False.
+        """
         super().__init__()
         self.num_features = num_features
         self.firing_probabilities = _to_tensor(
@@ -61,6 +89,7 @@ class ActivationGenerator(nn.Module):
         self.correlation_thresholds = None
         self.correlation_matrix = None
         self.low_rank_correlation = None
+        self.use_sparse_tensors = use_sparse_tensors

         if correlation_matrix is not None:
             if isinstance(correlation_matrix, torch.Tensor):
@@ -76,12 +105,15 @@ class ActivationGenerator(nn.Module):
                 _validate_low_rank_correlation(
                     correlation_factor, correlation_diag, num_features
                 )
-                self.low_rank_correlation = (correlation_factor, correlation_diag)
+                # Pre-compute sqrt for efficiency (used every sample call)
+                self.low_rank_correlation = (
+                    correlation_factor,
+                    correlation_diag.sqrt(),
+                )

-            self.correlation_thresholds = torch.tensor(
-                [norm.ppf(1 - p.item()) for p in self.firing_probabilities],
-                device=device,
-                dtype=self.firing_probabilities.dtype,
+            # Vectorized inverse normal CDF: norm.ppf(1-p) = sqrt(2) * erfinv(1 - 2*p)
+            self.correlation_thresholds = math.sqrt(2) * torch.erfinv(
+                1 - 2 * self.firing_probabilities
             )

     @torch.no_grad()
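
Note: dropping scipy works because of the identity norm.ppf(1 - p) = sqrt(2) * erfinv(1 - 2p), which follows from Phi(x) = (1 + erf(x / sqrt(2))) / 2. A quick equivalence check (scipy used only for the comparison):

    import math

    import torch
    from scipy.stats import norm

    p = torch.tensor([0.01, 0.1, 0.5, 0.9])
    thresholds = math.sqrt(2) * torch.erfinv(1 - 2 * p)
    expected = torch.tensor([norm.ppf(1 - pi.item()) for pi in p]).float()
    assert torch.allclose(thresholds, expected, atol=1e-5)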
@@ -105,7 +137,7 @@ class ActivationGenerator(nn.Module):

         if self.correlation_matrix is not None:
             assert self.correlation_thresholds is not None
-            firing_features = _generate_correlated_features(
+            firing_indices = _generate_correlated_features(
                 batch_size,
                 self.correlation_matrix,
                 self.correlation_thresholds,
@@ -113,7 +145,7 @@ class ActivationGenerator(nn.Module):
             )
         elif self.low_rank_correlation is not None:
             assert self.correlation_thresholds is not None
-            firing_features = _generate_low_rank_correlated_features(
+            firing_indices = _generate_low_rank_correlated_features(
                 batch_size,
                 self.low_rank_correlation[0],
                 self.low_rank_correlation[1],
@@ -121,23 +153,58 @@ class ActivationGenerator(nn.Module):
                 device,
             )
         else:
-            firing_features = torch.bernoulli(
+            firing_indices = torch.bernoulli(
                 self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
+            ).nonzero(as_tuple=True)
+
+        # Compute activations only at firing positions (sparse optimization)
+        feature_indices = firing_indices[1]
+        num_firing = feature_indices.shape[0]
+        mean_at_firing = self.mean_firing_magnitudes[feature_indices]
+        std_at_firing = self.std_firing_magnitudes[feature_indices]
+        random_deltas = (
+            torch.randn(
+                num_firing, device=device, dtype=self.mean_firing_magnitudes.dtype
             )
-
-        firing_magnitude_delta = torch.normal(
-            torch.zeros_like(self.firing_probabilities)
-            .unsqueeze(0)
-            .expand(batch_size, -1),
-            self.std_firing_magnitudes.unsqueeze(0).expand(batch_size, -1),
+            * std_at_firing
         )
-        firing_magnitude_delta[firing_features == 0] = 0
-        feature_activations = (
-            firing_features * self.mean_firing_magnitudes + firing_magnitude_delta
-        ).relu()
+        activations_at_firing = (mean_at_firing + random_deltas).relu()
+
+        if self.use_sparse_tensors:
+            # Return sparse COO tensor
+            indices = torch.stack(firing_indices)  # [2, nnz]
+            feature_activations = torch.sparse_coo_tensor(
+                indices,
+                activations_at_firing,
+                size=(batch_size, self.num_features),
+                device=device,
+                dtype=self.mean_firing_magnitudes.dtype,
+            )
+        else:
+            # Dense tensor path
+            feature_activations = torch.zeros(
+                batch_size,
+                self.num_features,
+                device=device,
+                dtype=self.mean_firing_magnitudes.dtype,
+            )
+            feature_activations[firing_indices] = activations_at_firing

         if self.modify_activations is not None:
-            feature_activations = self.modify_activations(feature_activations).relu()
+            feature_activations = self.modify_activations(feature_activations)
+            if feature_activations.is_sparse:
+                # Apply relu to sparse values
+                feature_activations = feature_activations.coalesce()
+                feature_activations = torch.sparse_coo_tensor(
+                    feature_activations.indices(),
+                    feature_activations.values().relu(),
+                    feature_activations.shape,
+                    device=feature_activations.device,
+                    dtype=feature_activations.dtype,
+                )
+            else:
+                feature_activations = feature_activations.relu()
+
         return feature_activations

     def forward(self, batch_size: int) -> torch.Tensor:
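
Note: sample() now materializes values only at firing positions, and with use_sparse_tensors=True the batch stays a COO tensor. A small usage sketch under the semantics shown above:

    import torch
    from sae_lens.synthetic.activation_generator import ActivationGenerator

    gen = ActivationGenerator(
        num_features=10_000,
        firing_probabilities=0.001,  # ~10 firing features per sample on average
        use_sparse_tensors=True,
    )
    acts = gen(32)  # nn.Module __call__ -> forward(batch_size)
    print(acts.is_sparse)  # True
    print(acts.shape)      # torch.Size([32, 10000])
    print(acts._nnz())     # ~320 nonzeros instead of 320,000 dense entries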
@@ -149,7 +216,7 @@ def _generate_correlated_features(
     correlation_matrix: torch.Tensor,
     thresholds: torch.Tensor,
     device: torch.device,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Generate correlated binary features using multivariate Gaussian sampling.

@@ -163,7 +230,7 @@ def _generate_correlated_features(
        device: Device to generate samples on

    Returns:
-        Binary feature matrix of shape (batch_size, num_features)
+        Tuple of (row_indices, col_indices) for firing features
    """
    num_features = correlation_matrix.shape[0]

@@ -173,16 +240,17 @@
    )

    gaussian_samples = mvn.sample((batch_size,))
-    return (gaussian_samples > thresholds.unsqueeze(0)).float()
+    indices = (gaussian_samples > thresholds.unsqueeze(0)).nonzero(as_tuple=True)
+    return indices[0], indices[1]


 def _generate_low_rank_correlated_features(
     batch_size: int,
     correlation_factor: torch.Tensor,
-    correlation_diag: torch.Tensor,
+    cov_diag_sqrt: torch.Tensor,
     thresholds: torch.Tensor,
     device: torch.device,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Generate correlated binary features using low-rank multivariate Gaussian sampling.

@@ -192,23 +260,29 @@ def _generate_low_rank_correlated_features(
    Args:
        batch_size: Number of samples to generate
        correlation_factor: Factor matrix of shape (num_features, rank)
-        correlation_diag: Diagonal term of shape (num_features,)
+        cov_diag_sqrt: Pre-computed sqrt of diagonal term, shape (num_features,)
        thresholds: Pre-computed thresholds for each feature (from inverse normal CDF)
        device: Device to generate samples on

    Returns:
-        Binary feature matrix of shape (batch_size, num_features)
+        Tuple of (row_indices, col_indices) for firing features
    """
-    mvn = LowRankMultivariateNormal(
-        loc=torch.zeros(
-            correlation_factor.shape[0], device=device, dtype=thresholds.dtype
-        ),
-        cov_factor=correlation_factor.to(device=device, dtype=thresholds.dtype),
-        cov_diag=correlation_diag.to(device=device, dtype=thresholds.dtype),
+    # Manual low-rank MVN sampling to enable autocast for the expensive matmul
+    # samples = eps @ cov_factor.T + eta * sqrt(cov_diag)
+    # where eps ~ N(0, I_rank) and eta ~ N(0, I_n)
+
+    num_features, rank = correlation_factor.shape
+
+    # Generate random samples in float32 for numerical stability
+    eps = torch.randn(batch_size, rank, device=device, dtype=correlation_factor.dtype)
+    eta = torch.randn(
+        batch_size, num_features, device=device, dtype=cov_diag_sqrt.dtype
    )

-    gaussian_samples = mvn.sample((batch_size,))
-    return (gaussian_samples > thresholds.unsqueeze(0)).float()
+    gaussian_samples = eps @ correlation_factor.T + eta * cov_diag_sqrt
+
+    indices = (gaussian_samples > thresholds.unsqueeze(0)).nonzero(as_tuple=True)
+    return indices[0], indices[1]


 def _to_tensor(
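
Note: the replacement sampler draws from N(0, F Fᵀ + D) directly: x = eps @ Fᵀ + eta * sqrt(D) with eps ~ N(0, I_rank) and eta ~ N(0, I_n), so cov(x) = F Fᵀ + D. An empirical check of that identity (shapes and seed are arbitrary):

    import torch

    torch.manual_seed(0)
    n, rank, n_samples = 8, 3, 200_000
    F = torch.randn(n, rank) * 0.3
    d = torch.rand(n) + 0.5  # diagonal term of the covariance

    eps = torch.randn(n_samples, rank)
    eta = torch.randn(n_samples, n)
    x = eps @ F.T + eta * d.sqrt()  # same construction as the diff

    empirical_cov = (x.T @ x) / n_samples
    expected_cov = F @ F.T + torch.diag(d)
    assert torch.allclose(empirical_cov, expected_cov, atol=0.05)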

{sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/feature_dictionary.py
@@ -168,9 +168,18 @@ class FeatureDictionary(nn.Module):

         Args:
             feature_activations: Tensor of shape [batch, num_features] containing
-                sparse feature activation values
+                sparse feature activation values. Can be dense or sparse COO.

         Returns:
             Tensor of shape [batch, hidden_dim] containing dense hidden activations
         """
+        if feature_activations.is_sparse:
+            # autocast is disabled here because sparse matmul is not supported with bfloat16
+            with torch.autocast(
+                device_type=feature_activations.device.type, enabled=False
+            ):
+                return (
+                    torch.sparse.mm(feature_activations, self.feature_vectors)
+                    + self.bias
+                )
         return feature_activations @ self.feature_vectors + self.bias
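
Note: the sparse branch wraps torch.sparse.mm in a disabled-autocast region because sparse matmul does not support the bfloat16 inputs autocast would feed it; the product runs in the parameters' own dtype instead. The pattern in isolation:

    import torch

    def decode(acts: torch.Tensor, W: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # acts: [batch, num_features], dense or sparse COO
        if acts.is_sparse:
            # Force full precision for the sparse matmul even under autocast
            with torch.autocast(device_type=acts.device.type, enabled=False):
                return torch.sparse.mm(acts, W) + b
        return acts @ W + b

    W, b = torch.randn(100, 16), torch.zeros(16)
    dense = torch.rand(4, 100)
    assert torch.allclose(decode(dense, W, b), decode(dense.to_sparse(), W, b), atol=1e-5)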

{sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/hierarchy.py
@@ -147,6 +147,14 @@ class _SparseHierarchyData:
     # Total number of ME groups
     num_groups: int

+    # Sparse COO support: Feature-to-parent mapping
+    # feat_to_parent[f] = parent feature index, or -1 if root/no parent
+    feat_to_parent: torch.Tensor | None = None  # [num_features]
+
+    # Sparse COO support: Feature-to-ME-group mapping
+    # feat_to_me_group[f] = group index, or -1 if not in any ME group
+    feat_to_me_group: torch.Tensor | None = None  # [num_features]
+

 def _build_sparse_hierarchy(
     roots: Sequence[HierarchyNode],
@@ -232,7 +240,11 @@
            me_indices = torch.empty(0, dtype=torch.long)

        level_data.append(
-            _LevelData(features=feats, parents=parents, me_group_indices=me_indices)
+            _LevelData(
+                features=feats,
+                parents=parents,
+                me_group_indices=me_indices,
+            )
        )

    # Build group siblings and parents tensors
@@ -254,12 +266,30 @@
        me_group_parents = torch.empty(0, dtype=torch.long)
        num_groups = 0

+    # Build sparse COO support: feat_to_parent and feat_to_me_group mappings
+    # First determine num_features (max feature index + 1)
+    all_features = [f for f, _, _ in feature_info]
+    num_features = max(all_features) + 1 if all_features else 0
+
+    # Build feature-to-parent mapping
+    feat_to_parent = torch.full((num_features,), -1, dtype=torch.long)
+    for feat, parent, _ in feature_info:
+        feat_to_parent[feat] = parent
+
+    # Build feature-to-ME-group mapping
+    feat_to_me_group = torch.full((num_features,), -1, dtype=torch.long)
+    for g_idx, (_, _, siblings) in enumerate(me_groups):
+        for sib in siblings:
+            feat_to_me_group[sib] = g_idx
+
    return _SparseHierarchyData(
        level_data=level_data,
        me_group_siblings=me_group_siblings,
        me_group_sizes=me_group_sizes,
        me_group_parents=me_group_parents,
        num_groups=num_groups,
+        feat_to_parent=feat_to_parent,
+        feat_to_me_group=feat_to_me_group,
    )


@@ -396,8 +426,9 @@
    # Random selection for winner
    # Use -1e9 instead of -inf to avoid creating a tensor (torch.tensor(-float("inf")))
    # on every call. Since random scores are in [0,1], -1e9 is effectively -inf for argmax.
+    _INACTIVE_SCORE = -1e9
    random_scores = torch.rand(num_conflicts, max_siblings, device=device)
-    random_scores[~conflict_active] = -1e9
+    random_scores[~conflict_active] = _INACTIVE_SCORE

    winner_idx = random_scores.argmax(dim=1)

@@ -420,6 +451,275 @@
    activations[deact_batch, deact_feat] = 0


+# ---------------------------------------------------------------------------
+# Sparse COO hierarchy implementation
+# ---------------------------------------------------------------------------
+
+
+def _apply_hierarchy_sparse_coo(
+    sparse_tensor: torch.Tensor,
+    sparse_data: _SparseHierarchyData,
+) -> torch.Tensor:
+    """
+    Apply hierarchy constraints to a sparse COO tensor.
+
+    This is the sparse analog of _apply_hierarchy_sparse. It processes
+    level-by-level, applying parent deactivation then mutual exclusion.
+    """
+    if sparse_tensor._nnz() == 0:
+        return sparse_tensor
+
+    sparse_tensor = sparse_tensor.coalesce()
+
+    for level_data in sparse_data.level_data:
+        # Step 1: Apply parent deactivation for features at this level
+        if level_data.features.numel() > 0:
+            sparse_tensor = _apply_parent_deactivation_coo(
+                sparse_tensor, level_data, sparse_data
+            )
+
+        # Step 2: Apply ME for groups whose parent is at this level
+        if level_data.me_group_indices.numel() > 0:
+            sparse_tensor = _apply_me_coo(
+                sparse_tensor, level_data.me_group_indices, sparse_data
+            )
+
+    return sparse_tensor
+
+
+def _apply_parent_deactivation_coo(
+    sparse_tensor: torch.Tensor,
+    level_data: _LevelData,
+    sparse_data: _SparseHierarchyData,
+) -> torch.Tensor:
+    """
+    Remove children from sparse COO tensor when their parent is inactive.
+
+    Uses searchsorted for efficient membership testing of parent activity.
+    """
+    if sparse_tensor._nnz() == 0 or level_data.features.numel() == 0:
+        return sparse_tensor
+
+    sparse_tensor = sparse_tensor.coalesce()
+    indices = sparse_tensor.indices()  # [2, nnz]
+    values = sparse_tensor.values()  # [nnz]
+    batch_indices = indices[0]
+    feat_indices = indices[1]
+
+    _, num_features = sparse_tensor.shape
+    device = sparse_tensor.device
+    nnz = indices.shape[1]
+
+    # Build set of active (batch, feature) pairs for efficient lookup
+    # Encode as: batch_idx * num_features + feat_idx
+    active_pairs = batch_indices * num_features + feat_indices
+    active_pairs_sorted, _ = active_pairs.sort()
+
+    # Use the precomputed feat_to_parent mapping
+    assert sparse_data.feat_to_parent is not None
+    hierarchy_num_features = sparse_data.feat_to_parent.numel()
+
+    # Handle features outside the hierarchy (they have no parent, pass through)
+    in_hierarchy = feat_indices < hierarchy_num_features
+    parent_of_feat = torch.full((nnz,), -1, dtype=torch.long, device=device)
+    parent_of_feat[in_hierarchy] = sparse_data.feat_to_parent[
+        feat_indices[in_hierarchy]
+    ]
+
+    # Find entries that have a parent (parent >= 0 means this feature has a parent)
+    has_parent = parent_of_feat >= 0
+
+    if not has_parent.any():
+        return sparse_tensor
+
+    # For entries with parents, check if parent is active
+    child_entry_indices = torch.where(has_parent)[0]
+    child_batch = batch_indices[has_parent]
+    child_parents = parent_of_feat[has_parent]
+
+    # Look up parent activity using searchsorted
+    parent_pairs = child_batch * num_features + child_parents
+    search_pos = torch.searchsorted(active_pairs_sorted, parent_pairs)
+    search_pos = search_pos.clamp(max=active_pairs_sorted.numel() - 1)
+    parent_active = active_pairs_sorted[search_pos] == parent_pairs
+
+    # Handle empty case
+    if active_pairs_sorted.numel() == 0:
+        parent_active = torch.zeros_like(parent_pairs, dtype=torch.bool)
+
+    # Build keep mask: keep entry if it's a root OR its parent is active
+    keep_mask = torch.ones(nnz, dtype=torch.bool, device=device)
+    keep_mask[child_entry_indices[~parent_active]] = False
+
+    if keep_mask.all():
+        return sparse_tensor
+
+    return torch.sparse_coo_tensor(
+        indices[:, keep_mask],
+        values[keep_mask],
+        sparse_tensor.shape,
+        device=device,
+        dtype=sparse_tensor.dtype,
+    )
+
+
+def _apply_me_coo(
+    sparse_tensor: torch.Tensor,
+    group_indices: torch.Tensor,
+    sparse_data: _SparseHierarchyData,
+) -> torch.Tensor:
+    """
+    Apply mutual exclusion to sparse COO tensor.
+
+    For each ME group with multiple active siblings in the same batch,
+    randomly selects one winner and removes the rest.
+    """
+    if sparse_tensor._nnz() == 0 or group_indices.numel() == 0:
+        return sparse_tensor
+
+    sparse_tensor = sparse_tensor.coalesce()
+    indices = sparse_tensor.indices()  # [2, nnz]
+    values = sparse_tensor.values()  # [nnz]
+    batch_indices = indices[0]
+    feat_indices = indices[1]
+
+    _, num_features = sparse_tensor.shape
+    device = sparse_tensor.device
+    nnz = indices.shape[1]
+
+    # Use precomputed feat_to_me_group mapping
+    assert sparse_data.feat_to_me_group is not None
+    hierarchy_num_features = sparse_data.feat_to_me_group.numel()
+
+    # Handle features outside the hierarchy (they are not in any ME group)
+    in_hierarchy = feat_indices < hierarchy_num_features
+    me_group_of_feat = torch.full((nnz,), -1, dtype=torch.long, device=device)
+    me_group_of_feat[in_hierarchy] = sparse_data.feat_to_me_group[
+        feat_indices[in_hierarchy]
+    ]
+
+    # Find entries that belong to ME groups we're processing (vectorized)
+    in_relevant_group = torch.isin(me_group_of_feat, group_indices)
+
+    if not in_relevant_group.any():
+        return sparse_tensor
+
+    # Get the ME entries
+    me_entry_indices = torch.where(in_relevant_group)[0]
+    me_batch = batch_indices[in_relevant_group]
+    me_group = me_group_of_feat[in_relevant_group]
+
+    # Check parent activity for ME groups (only apply ME if parent is active)
+    me_group_parents = sparse_data.me_group_parents[me_group]
+    has_parent = me_group_parents >= 0
+
+    if has_parent.any():
+        # Build active pairs for parent lookup
+        active_pairs = batch_indices * num_features + feat_indices
+        active_pairs_sorted, _ = active_pairs.sort()
+
+        parent_pairs = (
+            me_batch[has_parent] * num_features + me_group_parents[has_parent]
+        )
+        search_pos = torch.searchsorted(active_pairs_sorted, parent_pairs)
+        search_pos = search_pos.clamp(max=active_pairs_sorted.numel() - 1)
+        parent_active_for_has_parent = active_pairs_sorted[search_pos] == parent_pairs
+
+        # Build full parent_active mask
+        parent_active = torch.ones(
+            me_entry_indices.numel(), dtype=torch.bool, device=device
+        )
+        parent_active[has_parent] = parent_active_for_has_parent
+
+        # Filter to only ME entries where parent is active
+        valid_me = parent_active
+        me_entry_indices = me_entry_indices[valid_me]
+        me_batch = me_batch[valid_me]
+        me_group = me_group[valid_me]
+
+    if me_entry_indices.numel() == 0:
+        return sparse_tensor
+
+    # Encode (batch, group) pairs
+    num_groups = sparse_data.num_groups
+    batch_group_pairs = me_batch * num_groups + me_group
+
+    # Find unique (batch, group) pairs and count occurrences
+    unique_bg, inverse, counts = torch.unique(
+        batch_group_pairs, return_inverse=True, return_counts=True
+    )
+
+    # Only process pairs with count > 1 (conflicts)
+    has_conflict = counts > 1
+
+    if not has_conflict.any():
+        return sparse_tensor
+
+    # For efficiency, we process all conflicts together
+    # Assign random scores to each ME entry
+    random_scores = torch.rand(me_entry_indices.numel(), device=device)
+
+    # For each (batch, group) pair, we want the entry with highest score to be winner
+    # Use scatter_reduce to find max score per (batch, group)
+    bg_to_dense = torch.zeros(unique_bg.numel(), dtype=torch.long, device=device)
+    bg_to_dense[has_conflict.nonzero(as_tuple=True)[0]] = torch.arange(
+        has_conflict.sum(), device=device
+    )
+
+    # Map each ME entry to its dense conflict index
+    entry_has_conflict = has_conflict[inverse]
+
+    if not entry_has_conflict.any():
+        return sparse_tensor
+
+    conflict_entries_mask = entry_has_conflict
+    conflict_entry_indices = me_entry_indices[conflict_entries_mask]
+    conflict_random_scores = random_scores[conflict_entries_mask]
+    conflict_inverse = inverse[conflict_entries_mask]
+    conflict_dense_idx = bg_to_dense[conflict_inverse]
+
+    # Vectorized winner selection using sorting
+    # Sort entries by (group_idx, -random_score) so highest score comes first per group
+    # Use group * 2 - score to sort by group ascending, then score descending
+    sort_keys = conflict_dense_idx.float() * 2.0 - conflict_random_scores
+    sorted_order = sort_keys.argsort()
+    sorted_dense_idx = conflict_dense_idx[sorted_order]
+
+    # Find first entry of each group in sorted order (these are winners)
+    group_starts = torch.cat(
+        [
+            torch.tensor([True], device=device),
+            sorted_dense_idx[1:] != sorted_dense_idx[:-1],
+        ]
+    )
+
+    # Winners are entries at group starts in sorted order
+    winner_positions_in_sorted = torch.where(group_starts)[0]
+    winner_original_positions = sorted_order[winner_positions_in_sorted]
+
+    # Create winner mask (vectorized)
+    is_winner = torch.zeros(
+        conflict_entry_indices.numel(), dtype=torch.bool, device=device
+    )
+    is_winner[winner_original_positions] = True
+
+    # Build keep mask (vectorized)
+    keep_mask = torch.ones(nnz, dtype=torch.bool, device=device)
+    loser_entry_indices = conflict_entry_indices[~is_winner]
+    keep_mask[loser_entry_indices] = False
+
+    if keep_mask.all():
+        return sparse_tensor
+
+    return torch.sparse_coo_tensor(
+        indices[:, keep_mask],
+        values[keep_mask],
+        sparse_tensor.shape,
+        device=device,
+        dtype=sparse_tensor.dtype,
+    )
+
+
 @torch.no_grad()
 def hierarchy_modifier(
     roots: Sequence[HierarchyNode] | HierarchyNode,
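
Note: both COO helpers above test "is feature f active in batch b?" without densifying: every active (batch, feature) pair is encoded as the scalar batch * num_features + feature, the codes are sorted once, and torch.searchsorted answers each query in O(log nnz). The core trick in isolation (helper name is illustrative):

    import torch

    def pairs_active(active_b, active_f, query_b, query_f, num_features):
        # Encode (batch, feature) pairs as scalar codes, then binary-search them.
        codes = (active_b * num_features + active_f).sort().values
        queries = query_b * num_features + query_f
        pos = torch.searchsorted(codes, queries).clamp(max=codes.numel() - 1)
        return codes[pos] == queries

    active_b = torch.tensor([0, 0, 1])
    active_f = torch.tensor([2, 5, 3])
    mask = pairs_active(
        active_b, active_f, torch.tensor([0, 1, 1]), torch.tensor([5, 3, 4]), 10
    )
    print(mask)  # tensor([ True,  True, False])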
@@ -475,12 +775,24 @@
                me_group_sizes=sparse_data.me_group_sizes.to(device),
                me_group_parents=sparse_data.me_group_parents.to(device),
                num_groups=sparse_data.num_groups,
+                feat_to_parent=(
+                    sparse_data.feat_to_parent.to(device)
+                    if sparse_data.feat_to_parent is not None
+                    else None
+                ),
+                feat_to_me_group=(
+                    sparse_data.feat_to_me_group.to(device)
+                    if sparse_data.feat_to_me_group is not None
+                    else None
+                ),
            )
        return device_cache[device]

    def modifier(activations: torch.Tensor) -> torch.Tensor:
        device = activations.device
        cached = _get_sparse_for_device(device)
+        if activations.is_sparse:
+            return _apply_hierarchy_sparse_coo(activations, cached)
        return _apply_hierarchy_sparse(activations, cached)

    return modifier
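
Note: _apply_me_coo picks one winner per (batch, group) conflict without looping by sorting on the composite key group_index * 2.0 - score. Scores lie in [0, 1), so a spacing of 2 keeps groups in ascending order while ordering entries within a group by descending score; the first entry of each group after the sort wins. The trick in isolation:

    import torch

    torch.manual_seed(0)
    group = torch.tensor([0, 0, 1, 1, 1])  # conflict-group index per entry
    score = torch.rand(5)                  # random scores in [0, 1)

    keys = group.float() * 2.0 - score     # group ascending, score descending
    order = keys.argsort()
    sorted_group = group[order]
    is_first = torch.cat(
        [torch.tensor([True]), sorted_group[1:] != sorted_group[:-1]]
    )
    winners = order[is_first]
    print(winners)  # one winning entry index per group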

{sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/training.py
@@ -23,6 +23,8 @@ def train_toy_sae(
     device: str | torch.device = "cpu",
     n_snapshots: int = 0,
     snapshot_fn: Callable[[SAETrainer[Any, Any]], None] | None = None,
+    autocast_sae: bool = False,
+    autocast_data: bool = False,
 ) -> None:
     """
     Train an SAE on synthetic activations from a feature dictionary.
@@ -46,6 +48,8 @@ def train_toy_sae(
         snapshot_fn: Callback function called at each snapshot point. Receives
             the SAETrainer instance, allowing access to the SAE, training step,
             and other training state. Required if n_snapshots > 0.
+        autocast_sae: Whether to autocast the SAE to bfloat16. Only recommend for large SAEs on CUDA
+        autocast_data: Whether to autocast the activations generator and feature dictionary to bfloat16. Only recommend for large data on CUDA.
     """

     device_str = str(device) if isinstance(device, torch.device) else device
@@ -55,6 +59,7 @@ def train_toy_sae(
        feature_dict=feature_dict,
        activations_generator=activations_generator,
        batch_size=batch_size,
+        autocast=autocast_data,
    )

    # Create trainer config
@@ -64,7 +69,7 @@ def train_toy_sae(
        save_final_checkpoint=False,
        total_training_samples=training_samples,
        device=device_str,
-        autocast=False,
+        autocast=autocast_sae,
        lr=lr,
        lr_end=lr,
        lr_scheduler_name="constant",
@@ -119,6 +124,7 @@ class SyntheticActivationIterator(Iterator[torch.Tensor]):
         feature_dict: FeatureDictionary,
         activations_generator: ActivationGenerator,
         batch_size: int,
+        autocast: bool = False,
     ):
         """
         Create a new SyntheticActivationIterator.
@@ -127,16 +133,23 @@ class SyntheticActivationIterator(Iterator[torch.Tensor]):
            feature_dict: The feature dictionary to use for generating hidden activations
            activations_generator: Generator that produces feature activations
            batch_size: Number of samples per batch
+            autocast: Whether to autocast the activations generator and feature dictionary to bfloat16.
        """
        self.feature_dict = feature_dict
        self.activations_generator = activations_generator
        self.batch_size = batch_size
+        self.autocast = autocast

    @torch.no_grad()
    def next_batch(self) -> torch.Tensor:
        """Generate the next batch of hidden activations."""
-        features = self.activations_generator(self.batch_size)
-        return self.feature_dict(features)
+        with torch.autocast(
+            device_type=self.feature_dict.feature_vectors.device.type,
+            dtype=torch.bfloat16,
+            enabled=self.autocast,
+        ):
+            features = self.activations_generator(self.batch_size)
+            return self.feature_dict(features)

    def __iter__(self) -> "SyntheticActivationIterator":
        return self
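
Note: autocast_data wraps both the generator and the feature dictionary in a bfloat16 autocast region via SyntheticActivationIterator, while autocast_sae is forwarded to the trainer config. A hedged sketch of the iterator on its own (the FeatureDictionary constructor arguments are assumed, not taken from this diff):

    from sae_lens.synthetic.activation_generator import ActivationGenerator
    from sae_lens.synthetic.feature_dictionary import FeatureDictionary
    from sae_lens.synthetic.training import SyntheticActivationIterator

    gen = ActivationGenerator(num_features=1_000, firing_probabilities=0.01)
    feature_dict = FeatureDictionary(num_features=1_000, hidden_dim=64)  # assumed signature

    # autocast=True mainly pays off for large feature counts on CUDA.
    iterator = SyntheticActivationIterator(
        feature_dict=feature_dict,
        activations_generator=gen,
        batch_size=256,
        autocast=True,
    )
    batch = iterator.next_batch()  # [256, 64] hidden activations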