sae-lens 6.28.2-py3-none-any.whl → 6.32.1-py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- sae_lens/__init__.py +14 -1
- sae_lens/analysis/__init__.py +15 -0
- sae_lens/analysis/compat.py +16 -0
- sae_lens/analysis/hooked_sae_transformer.py +1 -1
- sae_lens/analysis/sae_transformer_bridge.py +348 -0
- sae_lens/config.py +9 -1
- sae_lens/evals.py +2 -2
- sae_lens/loading/pretrained_sae_loaders.py +11 -4
- sae_lens/pretrained_saes.yaml +36 -0
- sae_lens/saes/temporal_sae.py +1 -1
- sae_lens/synthetic/__init__.py +6 -0
- sae_lens/synthetic/activation_generator.py +197 -25
- sae_lens/synthetic/correlation.py +217 -36
- sae_lens/synthetic/feature_dictionary.py +11 -2
- sae_lens/synthetic/hierarchy.py +314 -2
- sae_lens/synthetic/training.py +16 -3
- sae_lens/training/activation_scaler.py +3 -1
- {sae_lens-6.28.2.dist-info → sae_lens-6.32.1.dist-info}/METADATA +2 -2
- {sae_lens-6.28.2.dist-info → sae_lens-6.32.1.dist-info}/RECORD +21 -19
- {sae_lens-6.28.2.dist-info → sae_lens-6.32.1.dist-info}/WHEEL +1 -1
- {sae_lens-6.28.2.dist-info → sae_lens-6.32.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/pretrained_saes.yaml
CHANGED
@@ -4148,6 +4148,7 @@ gemma-scope-2-4b-it-res:
     - id: layer_17_width_16k_l0_medium
       path: resid_post/layer_17_width_16k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-16k
     - id: layer_17_width_16k_l0_small
       path: resid_post/layer_17_width_16k_l0_small
       l0: 20
@@ -4166,6 +4167,7 @@ gemma-scope-2-4b-it-res:
     - id: layer_17_width_262k_l0_medium
       path: resid_post/layer_17_width_262k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-262k
     - id: layer_17_width_262k_l0_medium_seed_1
       path: resid_post/layer_17_width_262k_l0_medium_seed_1
       l0: 60
@@ -4178,6 +4180,7 @@ gemma-scope-2-4b-it-res:
     - id: layer_17_width_65k_l0_medium
       path: resid_post/layer_17_width_65k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-65k
     - id: layer_17_width_65k_l0_small
       path: resid_post/layer_17_width_65k_l0_small
       l0: 20
@@ -4187,6 +4190,7 @@ gemma-scope-2-4b-it-res:
     - id: layer_22_width_16k_l0_medium
       path: resid_post/layer_22_width_16k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-16k
     - id: layer_22_width_16k_l0_small
       path: resid_post/layer_22_width_16k_l0_small
       l0: 20
@@ -4205,6 +4209,7 @@ gemma-scope-2-4b-it-res:
     - id: layer_22_width_262k_l0_medium
       path: resid_post/layer_22_width_262k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-262k
     - id: layer_22_width_262k_l0_medium_seed_1
       path: resid_post/layer_22_width_262k_l0_medium_seed_1
       l0: 60
@@ -4217,6 +4222,7 @@ gemma-scope-2-4b-it-res:
     - id: layer_22_width_65k_l0_medium
       path: resid_post/layer_22_width_65k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-65k
     - id: layer_22_width_65k_l0_small
       path: resid_post/layer_22_width_65k_l0_small
       l0: 20
@@ -4226,6 +4232,7 @@ gemma-scope-2-4b-it-res:
     - id: layer_29_width_16k_l0_medium
       path: resid_post/layer_29_width_16k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-16k
     - id: layer_29_width_16k_l0_small
       path: resid_post/layer_29_width_16k_l0_small
       l0: 20
@@ -4244,6 +4251,7 @@ gemma-scope-2-4b-it-res:
     - id: layer_29_width_262k_l0_medium
       path: resid_post/layer_29_width_262k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-262k
     - id: layer_29_width_262k_l0_medium_seed_1
       path: resid_post/layer_29_width_262k_l0_medium_seed_1
       l0: 60
@@ -4256,6 +4264,7 @@ gemma-scope-2-4b-it-res:
     - id: layer_29_width_65k_l0_medium
       path: resid_post/layer_29_width_65k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-65k
     - id: layer_29_width_65k_l0_small
       path: resid_post/layer_29_width_65k_l0_small
       l0: 20
@@ -4265,6 +4274,7 @@ gemma-scope-2-4b-it-res:
     - id: layer_9_width_16k_l0_medium
       path: resid_post/layer_9_width_16k_l0_medium
       l0: 53
+      neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-16k
     - id: layer_9_width_16k_l0_small
       path: resid_post/layer_9_width_16k_l0_small
       l0: 17
@@ -4283,6 +4293,7 @@ gemma-scope-2-4b-it-res:
     - id: layer_9_width_262k_l0_medium
       path: resid_post/layer_9_width_262k_l0_medium
       l0: 53
+      neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-262k
     - id: layer_9_width_262k_l0_medium_seed_1
       path: resid_post/layer_9_width_262k_l0_medium_seed_1
       l0: 53
@@ -4295,6 +4306,7 @@ gemma-scope-2-4b-it-res:
     - id: layer_9_width_65k_l0_medium
       path: resid_post/layer_9_width_65k_l0_medium
       l0: 53
+      neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-65k
     - id: layer_9_width_65k_l0_small
       path: resid_post/layer_9_width_65k_l0_small
       l0: 17
@@ -14491,6 +14503,7 @@ gemma-scope-2-270m-it-res:
     - id: layer_12_width_16k_l0_medium
       path: resid_post/layer_12_width_16k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-16k
     - id: layer_12_width_16k_l0_small
       path: resid_post/layer_12_width_16k_l0_small
       l0: 20
@@ -14509,6 +14522,7 @@ gemma-scope-2-270m-it-res:
     - id: layer_12_width_262k_l0_medium
       path: resid_post/layer_12_width_262k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-262k
     - id: layer_12_width_262k_l0_medium_seed_1
       path: resid_post/layer_12_width_262k_l0_medium_seed_1
       l0: 60
@@ -14521,6 +14535,7 @@ gemma-scope-2-270m-it-res:
     - id: layer_12_width_65k_l0_medium
       path: resid_post/layer_12_width_65k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-65k
     - id: layer_12_width_65k_l0_small
       path: resid_post/layer_12_width_65k_l0_small
       l0: 20
@@ -14530,6 +14545,7 @@ gemma-scope-2-270m-it-res:
     - id: layer_15_width_16k_l0_medium
       path: resid_post/layer_15_width_16k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-16k
     - id: layer_15_width_16k_l0_small
       path: resid_post/layer_15_width_16k_l0_small
       l0: 20
@@ -14548,6 +14564,7 @@ gemma-scope-2-270m-it-res:
     - id: layer_15_width_262k_l0_medium
       path: resid_post/layer_15_width_262k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-262k
     - id: layer_15_width_262k_l0_medium_seed_1
       path: resid_post/layer_15_width_262k_l0_medium_seed_1
       l0: 60
@@ -14560,6 +14577,7 @@ gemma-scope-2-270m-it-res:
     - id: layer_15_width_65k_l0_medium
       path: resid_post/layer_15_width_65k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-65k
     - id: layer_15_width_65k_l0_small
       path: resid_post/layer_15_width_65k_l0_small
       l0: 20
@@ -14569,6 +14587,7 @@ gemma-scope-2-270m-it-res:
     - id: layer_5_width_16k_l0_medium
       path: resid_post/layer_5_width_16k_l0_medium
       l0: 55
+      neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-16k
     - id: layer_5_width_16k_l0_small
       path: resid_post/layer_5_width_16k_l0_small
       l0: 18
@@ -14587,6 +14606,7 @@ gemma-scope-2-270m-it-res:
     - id: layer_5_width_262k_l0_medium
       path: resid_post/layer_5_width_262k_l0_medium
       l0: 55
+      neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-262k
     - id: layer_5_width_262k_l0_medium_seed_1
       path: resid_post/layer_5_width_262k_l0_medium_seed_1
       l0: 55
@@ -14599,6 +14619,7 @@ gemma-scope-2-270m-it-res:
     - id: layer_5_width_65k_l0_medium
       path: resid_post/layer_5_width_65k_l0_medium
       l0: 55
+      neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-65k
     - id: layer_5_width_65k_l0_small
       path: resid_post/layer_5_width_65k_l0_small
       l0: 18
@@ -14608,6 +14629,7 @@ gemma-scope-2-270m-it-res:
     - id: layer_9_width_16k_l0_medium
       path: resid_post/layer_9_width_16k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-16k
     - id: layer_9_width_16k_l0_small
       path: resid_post/layer_9_width_16k_l0_small
       l0: 20
@@ -14626,6 +14648,7 @@ gemma-scope-2-270m-it-res:
     - id: layer_9_width_262k_l0_medium
       path: resid_post/layer_9_width_262k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-262k
     - id: layer_9_width_262k_l0_medium_seed_1
       path: resid_post/layer_9_width_262k_l0_medium_seed_1
       l0: 60
@@ -14638,6 +14661,7 @@ gemma-scope-2-270m-it-res:
     - id: layer_9_width_65k_l0_medium
       path: resid_post/layer_9_width_65k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-65k
     - id: layer_9_width_65k_l0_small
       path: resid_post/layer_9_width_65k_l0_small
       l0: 20
@@ -18727,6 +18751,7 @@ gemma-scope-2-1b-it-res:
     - id: layer_13_width_16k_l0_medium
       path: resid_post/layer_13_width_16k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-16k
     - id: layer_13_width_16k_l0_small
       path: resid_post/layer_13_width_16k_l0_small
       l0: 20
@@ -18745,6 +18770,7 @@ gemma-scope-2-1b-it-res:
     - id: layer_13_width_262k_l0_medium
       path: resid_post/layer_13_width_262k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-262k
     - id: layer_13_width_262k_l0_medium_seed_1
       path: resid_post/layer_13_width_262k_l0_medium_seed_1
       l0: 60
@@ -18757,6 +18783,7 @@ gemma-scope-2-1b-it-res:
     - id: layer_13_width_65k_l0_medium
       path: resid_post/layer_13_width_65k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-65k
     - id: layer_13_width_65k_l0_small
       path: resid_post/layer_13_width_65k_l0_small
       l0: 20
@@ -18766,6 +18793,7 @@ gemma-scope-2-1b-it-res:
     - id: layer_17_width_16k_l0_medium
       path: resid_post/layer_17_width_16k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-16k
     - id: layer_17_width_16k_l0_small
       path: resid_post/layer_17_width_16k_l0_small
       l0: 20
@@ -18784,6 +18812,7 @@ gemma-scope-2-1b-it-res:
     - id: layer_17_width_262k_l0_medium
       path: resid_post/layer_17_width_262k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-262k
     - id: layer_17_width_262k_l0_medium_seed_1
       path: resid_post/layer_17_width_262k_l0_medium_seed_1
       l0: 60
@@ -18796,6 +18825,7 @@ gemma-scope-2-1b-it-res:
     - id: layer_17_width_65k_l0_medium
       path: resid_post/layer_17_width_65k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-65k
     - id: layer_17_width_65k_l0_small
       path: resid_post/layer_17_width_65k_l0_small
       l0: 20
@@ -18805,6 +18835,7 @@ gemma-scope-2-1b-it-res:
     - id: layer_22_width_16k_l0_medium
       path: resid_post/layer_22_width_16k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-16k
     - id: layer_22_width_16k_l0_small
       path: resid_post/layer_22_width_16k_l0_small
       l0: 20
@@ -18823,6 +18854,7 @@ gemma-scope-2-1b-it-res:
     - id: layer_22_width_262k_l0_medium
       path: resid_post/layer_22_width_262k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-262k
     - id: layer_22_width_262k_l0_medium_seed_1
       path: resid_post/layer_22_width_262k_l0_medium_seed_1
       l0: 60
@@ -18835,6 +18867,7 @@ gemma-scope-2-1b-it-res:
     - id: layer_22_width_65k_l0_medium
       path: resid_post/layer_22_width_65k_l0_medium
       l0: 60
+      neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-65k
     - id: layer_22_width_65k_l0_small
       path: resid_post/layer_22_width_65k_l0_small
       l0: 20
@@ -18844,6 +18877,7 @@ gemma-scope-2-1b-it-res:
     - id: layer_7_width_16k_l0_medium
       path: resid_post/layer_7_width_16k_l0_medium
       l0: 54
+      neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-16k
     - id: layer_7_width_16k_l0_small
       path: resid_post/layer_7_width_16k_l0_small
       l0: 18
@@ -18862,6 +18896,7 @@ gemma-scope-2-1b-it-res:
     - id: layer_7_width_262k_l0_medium
       path: resid_post/layer_7_width_262k_l0_medium
       l0: 54
+      neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-262k
     - id: layer_7_width_262k_l0_medium_seed_1
       path: resid_post/layer_7_width_262k_l0_medium_seed_1
       l0: 54
@@ -18874,6 +18909,7 @@ gemma-scope-2-1b-it-res:
     - id: layer_7_width_65k_l0_medium
       path: resid_post/layer_7_width_65k_l0_medium
       l0: 54
+      neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-65k
     - id: layer_7_width_65k_l0_small
       path: resid_post/layer_7_width_65k_l0_small
       l0: 18
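Every hunk in this file follows the same pattern: each Gemma Scope 2 instruction-tuned SAE entry gains a `neuronpedia` key linking it to its dashboard. A minimal loading sketch, assuming the usual `SAE.from_pretrained` loader (the exact return value has varied across sae_lens versions, so treat this as illustrative):

from sae_lens import SAE

# The release and sae_id strings come straight from the
# pretrained_saes.yaml entries shown above.
sae = SAE.from_pretrained(
    release="gemma-scope-2-4b-it-res",
    sae_id="layer_17_width_16k_l0_medium",
    device="cpu",
)
# The new `neuronpedia` field ("gemma-3-4b-it/17-gemmascope-2-res-16k")
# identifies the matching dashboard on neuronpedia.org.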
sae_lens/saes/temporal_sae.py
CHANGED
sae_lens/synthetic/__init__.py
CHANGED
@@ -17,11 +17,14 @@ from sae_lens.synthetic.activation_generator import (
     ActivationGenerator,
     ActivationsModifier,
     ActivationsModifierInput,
+    CorrelationMatrixInput,
 )
 from sae_lens.synthetic.correlation import (
+    LowRankCorrelationMatrix,
     create_correlation_matrix_from_correlations,
     generate_random_correlation_matrix,
     generate_random_correlations,
+    generate_random_low_rank_correlation_matrix,
 )
 from sae_lens.synthetic.evals import (
     SyntheticDataEvalResult,
@@ -66,6 +69,9 @@ __all__ = [
     "create_correlation_matrix_from_correlations",
     "generate_random_correlations",
     "generate_random_correlation_matrix",
+    "generate_random_low_rank_correlation_matrix",
+    "LowRankCorrelationMatrix",
+    "CorrelationMatrixInput",
     # Feature modifiers
     "ActivationsModifier",
     "ActivationsModifierInput",
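With these exports, the low-rank correlation utilities become part of the public API and can be imported directly from the package:

# The three names newly exported from sae_lens.synthetic:
from sae_lens.synthetic import (
    CorrelationMatrixInput,
    LowRankCorrelationMatrix,
    generate_random_low_rank_correlation_matrix,
)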
sae_lens/synthetic/activation_generator.py
CHANGED
@@ -2,17 +2,21 @@
 Functions for generating synthetic feature activations.
 """
 
+import math
 from collections.abc import Callable, Sequence
 
 import torch
-from scipy.stats import norm
 from torch import nn
 from torch.distributions import MultivariateNormal
 
+from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
 from sae_lens.util import str_to_dtype
 
 ActivationsModifier = Callable[[torch.Tensor], torch.Tensor]
 ActivationsModifierInput = ActivationsModifier | Sequence[ActivationsModifier] | None
+CorrelationMatrixInput = (
+    torch.Tensor | LowRankCorrelationMatrix | tuple[torch.Tensor, torch.Tensor]
+)
 
 
 class ActivationGenerator(nn.Module):
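The new `CorrelationMatrixInput` alias means a correlation structure can be supplied in three interchangeable forms. A sketch of the low-memory forms, assuming only that `LowRankCorrelationMatrix` behaves like the `(factor, diag)` pair the constructor indexes below:

import torch

num_features, rank = 10_000, 16

# Form 1: a full (num_features, num_features) correlation matrix.
full = torch.eye(num_features)

# Forms 2 and 3: low-rank structure cov = factor @ factor.T + diag(d),
# either as a LowRankCorrelationMatrix or as a plain (factor, diag) tuple.
factor = 0.1 * torch.randn(num_features, rank)
diag = 1.0 - (factor * factor).sum(dim=-1)  # keeps unit variance per feature
low_rank = (factor, diag)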
@@ -28,7 +32,9 @@ class ActivationGenerator(nn.Module):
     mean_firing_magnitudes: torch.Tensor
     modify_activations: ActivationsModifier | None
     correlation_matrix: torch.Tensor | None
+    low_rank_correlation: tuple[torch.Tensor, torch.Tensor] | None
     correlation_thresholds: torch.Tensor | None
+    use_sparse_tensors: bool
 
     def __init__(
         self,
@@ -37,10 +43,37 @@ class ActivationGenerator(nn.Module):
         std_firing_magnitudes: torch.Tensor | float = 0.0,
         mean_firing_magnitudes: torch.Tensor | float = 1.0,
         modify_activations: ActivationsModifierInput = None,
-        correlation_matrix: torch.Tensor | None = None,
+        correlation_matrix: CorrelationMatrixInput | None = None,
         device: torch.device | str = "cpu",
         dtype: torch.dtype | str = "float32",
+        use_sparse_tensors: bool = False,
     ):
+        """
+        Create a new ActivationGenerator.
+
+        Args:
+            num_features: Number of features to generate activations for.
+            firing_probabilities: Probability of each feature firing. Can be a single
+                float (applied to all features) or a tensor of shape (num_features,).
+            std_firing_magnitudes: Standard deviation of firing magnitudes. Can be a
+                single float or a tensor of shape (num_features,). Defaults to 0.0
+                (deterministic magnitudes).
+            mean_firing_magnitudes: Mean firing magnitude when a feature fires. Can be
+                a single float or a tensor of shape (num_features,). Defaults to 1.0.
+            modify_activations: Optional function(s) to modify activations after
+                generation. Can be a single callable, a sequence of callables (applied
+                in order), or None. Useful for applying hierarchy constraints.
+            correlation_matrix: Optional correlation structure between features. Can be:
+
+                - A full correlation matrix tensor of shape (num_features, num_features)
+                - A LowRankCorrelationMatrix for memory-efficient large-scale correlations
+                - A tuple of (factor, diag) tensors representing low-rank structure
+
+            device: Device to place tensors on. Defaults to "cpu".
+            dtype: Data type for tensors. Defaults to "float32".
+            use_sparse_tensors: If True, return sparse COO tensors from sample().
+                Only recommended when using massive numbers of features. Defaults to False.
+        """
         super().__init__()
         self.num_features = num_features
         self.firing_probabilities = _to_tensor(
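Putting the new constructor arguments together, a hedged sketch (illustrative values only) of a generator that would be impractical with a dense 50k × 50k correlation matrix:

import torch
from sae_lens.synthetic import ActivationGenerator

factor = 0.05 * torch.randn(50_000, 32)
diag = 1.0 - (factor * factor).sum(dim=-1)  # must stay strictly positive

gen = ActivationGenerator(
    num_features=50_000,
    firing_probabilities=0.01,
    correlation_matrix=(factor, diag),  # low-rank (factor, diag) form
    use_sparse_tensors=True,            # sample() returns sparse COO tensors
)
acts = gen.sample(batch_size=64)        # sparse, shape (64, 50000)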
@@ -54,14 +87,34 @@ class ActivationGenerator(nn.Module):
         )
         self.modify_activations = _normalize_modifiers(modify_activations)
         self.correlation_thresholds = None
+        self.correlation_matrix = None
+        self.low_rank_correlation = None
+        self.use_sparse_tensors = use_sparse_tensors
+
         if correlation_matrix is not None:
-
-
-
-
-
+            if isinstance(correlation_matrix, torch.Tensor):
+                # Full correlation matrix
+                _validate_correlation_matrix(correlation_matrix, num_features)
+                self.correlation_matrix = correlation_matrix
+            else:
+                # Low-rank correlation (tuple or LowRankCorrelationMatrix)
+                correlation_factor, correlation_diag = (
+                    correlation_matrix[0],
+                    correlation_matrix[1],
+                )
+                _validate_low_rank_correlation(
+                    correlation_factor, correlation_diag, num_features
+                )
+                # Pre-compute sqrt for efficiency (used every sample call)
+                self.low_rank_correlation = (
+                    correlation_factor,
+                    correlation_diag.sqrt(),
+                )
+
+            # Vectorized inverse normal CDF: norm.ppf(1-p) = sqrt(2) * erfinv(1 - 2*p)
+            self.correlation_thresholds = math.sqrt(2) * torch.erfinv(
+                1 - 2 * self.firing_probabilities
             )
-        self.correlation_matrix = correlation_matrix
 
     @torch.no_grad()
     def sample(self, batch_size: int) -> torch.Tensor:
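The scipy dependency (`norm.ppf`) could be dropped because the inverse normal CDF has the closed form Φ⁻¹(1 − p) = √2 · erfinv(1 − 2p). A quick equivalence check, with scipy imported here only for comparison:

import math

import torch
from scipy.stats import norm

p = torch.tensor([0.001, 0.01, 0.1, 0.5])
thresholds = math.sqrt(2) * torch.erfinv(1 - 2 * p)        # new torch path
reference = torch.tensor(norm.ppf(1 - p.numpy())).float()  # old scipy path
assert torch.allclose(thresholds, reference, atol=1e-4)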
@@ -84,30 +137,74 @@ class ActivationGenerator(nn.Module):
 
         if self.correlation_matrix is not None:
             assert self.correlation_thresholds is not None
-
+            firing_indices = _generate_correlated_features(
                 batch_size,
                 self.correlation_matrix,
                 self.correlation_thresholds,
                 device,
             )
+        elif self.low_rank_correlation is not None:
+            assert self.correlation_thresholds is not None
+            firing_indices = _generate_low_rank_correlated_features(
+                batch_size,
+                self.low_rank_correlation[0],
+                self.low_rank_correlation[1],
+                self.correlation_thresholds,
+                device,
+            )
         else:
-
+            firing_indices = torch.bernoulli(
                 self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
+            ).nonzero(as_tuple=True)
+
+        # Compute activations only at firing positions (sparse optimization)
+        feature_indices = firing_indices[1]
+        num_firing = feature_indices.shape[0]
+        mean_at_firing = self.mean_firing_magnitudes[feature_indices]
+        std_at_firing = self.std_firing_magnitudes[feature_indices]
+        random_deltas = (
+            torch.randn(
+                num_firing, device=device, dtype=self.mean_firing_magnitudes.dtype
             )
-
-        firing_magnitude_delta = torch.normal(
-            torch.zeros_like(self.firing_probabilities)
-            .unsqueeze(0)
-            .expand(batch_size, -1),
-            self.std_firing_magnitudes.unsqueeze(0).expand(batch_size, -1),
+            * std_at_firing
         )
-
-
-
-
+        activations_at_firing = (mean_at_firing + random_deltas).relu()
+
+        if self.use_sparse_tensors:
+            # Return sparse COO tensor
+            indices = torch.stack(firing_indices)  # [2, nnz]
+            feature_activations = torch.sparse_coo_tensor(
+                indices,
+                activations_at_firing,
+                size=(batch_size, self.num_features),
+                device=device,
+                dtype=self.mean_firing_magnitudes.dtype,
+            )
+        else:
+            # Dense tensor path
+            feature_activations = torch.zeros(
+                batch_size,
+                self.num_features,
+                device=device,
+                dtype=self.mean_firing_magnitudes.dtype,
+            )
+            feature_activations[firing_indices] = activations_at_firing
 
         if self.modify_activations is not None:
-            feature_activations = self.modify_activations(feature_activations)
+            feature_activations = self.modify_activations(feature_activations)
+            if feature_activations.is_sparse:
+                # Apply relu to sparse values
+                feature_activations = feature_activations.coalesce()
+                feature_activations = torch.sparse_coo_tensor(
+                    feature_activations.indices(),
+                    feature_activations.values().relu(),
+                    feature_activations.shape,
+                    device=feature_activations.device,
+                    dtype=feature_activations.dtype,
+                )
+            else:
+                feature_activations = feature_activations.relu()
+
         return feature_activations
 
     def forward(self, batch_size: int) -> torch.Tensor:
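Because `sample()` can now return a sparse COO tensor, downstream code needs an explicit densify step; a minimal sketch:

import torch
from sae_lens.synthetic import ActivationGenerator

gen = ActivationGenerator(
    num_features=1_000, firing_probabilities=0.05, use_sparse_tensors=True
)
acts = gen.sample(batch_size=8)
if acts.is_sparse:            # sparse COO path from the code above
    acts = acts.to_dense()
assert acts.shape == (8, 1_000)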
@@ -119,7 +216,7 @@ def _generate_correlated_features(
     correlation_matrix: torch.Tensor,
     thresholds: torch.Tensor,
     device: torch.device,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Generate correlated binary features using multivariate Gaussian sampling.
 
@@ -133,7 +230,7 @@ def _generate_correlated_features(
         device: Device to generate samples on
 
     Returns:
-
+        Tuple of (row_indices, col_indices) for firing features
     """
     num_features = correlation_matrix.shape[0]
 
@@ -143,7 +240,49 @@ def _generate_correlated_features(
     )
 
     gaussian_samples = mvn.sample((batch_size,))
-
+    indices = (gaussian_samples > thresholds.unsqueeze(0)).nonzero(as_tuple=True)
+    return indices[0], indices[1]
+
+
+def _generate_low_rank_correlated_features(
+    batch_size: int,
+    correlation_factor: torch.Tensor,
+    cov_diag_sqrt: torch.Tensor,
+    thresholds: torch.Tensor,
+    device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Generate correlated binary features using low-rank multivariate Gaussian sampling.
+
+    Uses the Gaussian copula approach with a low-rank covariance structure for scalability.
+    The covariance is represented as: cov = factor @ factor.T + diag(diag_term)
+
+    Args:
+        batch_size: Number of samples to generate
+        correlation_factor: Factor matrix of shape (num_features, rank)
+        cov_diag_sqrt: Pre-computed sqrt of diagonal term, shape (num_features,)
+        thresholds: Pre-computed thresholds for each feature (from inverse normal CDF)
+        device: Device to generate samples on
+
+    Returns:
+        Tuple of (row_indices, col_indices) for firing features
+    """
+    # Manual low-rank MVN sampling to enable autocast for the expensive matmul
+    # samples = eps @ cov_factor.T + eta * sqrt(cov_diag)
+    # where eps ~ N(0, I_rank) and eta ~ N(0, I_n)
+
+    num_features, rank = correlation_factor.shape
+
+    # Generate random samples in float32 for numerical stability
+    eps = torch.randn(batch_size, rank, device=device, dtype=correlation_factor.dtype)
+    eta = torch.randn(
+        batch_size, num_features, device=device, dtype=cov_diag_sqrt.dtype
+    )
+
+    gaussian_samples = eps @ correlation_factor.T + eta * cov_diag_sqrt
+
+    indices = (gaussian_samples > thresholds.unsqueeze(0)).nonzero(as_tuple=True)
+    return indices[0], indices[1]
 
 
 def _to_tensor(
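The manual sampler rests on the identity that if eps ~ N(0, I_rank) and eta ~ N(0, I_n), then eps @ F.T + eta * sqrt(d) has covariance F @ F.T + diag(d). A standalone numeric check (demo values only):

import torch

torch.manual_seed(0)
n, rank, n_samples = 8, 3, 200_000
F = 0.3 * torch.randn(n, rank)
d = torch.rand(n) + 0.5

eps = torch.randn(n_samples, rank)
eta = torch.randn(n_samples, n)
x = eps @ F.T + eta * d.sqrt()       # same construction as above

empirical = x.T @ x / n_samples      # samples are zero-mean
expected = F @ F.T + torch.diag(d)
assert torch.allclose(empirical, expected, atol=0.05)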
@@ -194,7 +333,7 @@ def _validate_correlation_matrix(
 
     Args:
         correlation_matrix: The matrix to validate
-        num_features: Expected number of features (matrix should be
+        num_features: Expected number of features (matrix should be (num_features, num_features))
 
     Raises:
         ValueError: If the matrix has incorrect shape, non-unit diagonal, or is not positive definite
@@ -214,3 +353,36 @@
         torch.linalg.cholesky(correlation_matrix)
     except RuntimeError as e:
         raise ValueError("Correlation matrix must be positive definite") from e
+
+
+def _validate_low_rank_correlation(
+    correlation_factor: torch.Tensor,
+    correlation_diag: torch.Tensor,
+    num_features: int,
+) -> None:
+    """Validate that low-rank correlation parameters have correct properties.
+
+    Args:
+        correlation_factor: Factor matrix of shape (num_features, rank)
+        correlation_diag: Diagonal term of shape (num_features,)
+        num_features: Expected number of features
+
+    Raises:
+        ValueError: If shapes are incorrect or diagonal terms are not positive
+    """
+    if correlation_factor.ndim != 2:
+        raise ValueError(
+            f"correlation_factor must be 2D, got {correlation_factor.ndim}D"
+        )
+    if correlation_factor.shape[0] != num_features:
+        raise ValueError(
+            f"correlation_factor must have shape ({num_features}, rank), "
+            f"got {tuple(correlation_factor.shape)}"
+        )
+    if correlation_diag.shape != (num_features,):
+        raise ValueError(
+            f"correlation_diag must have shape ({num_features},), "
+            f"got {tuple(correlation_diag.shape)}"
+        )
+    if torch.any(correlation_diag <= 0):
+        raise ValueError("correlation_diag must have all positive values")