sae-lens 6.29.0__tar.gz → 6.30.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.29.0 → sae_lens-6.30.1}/PKG-INFO +1 -1
- {sae_lens-6.29.0 → sae_lens-6.30.1}/pyproject.toml +1 -1
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/__init__.py +1 -1
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/loading/pretrained_sae_loaders.py +9 -3
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/pretrained_saes.yaml +36 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/activation_generator.py +110 -36
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/feature_dictionary.py +10 -1
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/hierarchy.py +314 -2
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/training.py +16 -3
- {sae_lens-6.29.0 → sae_lens-6.30.1}/LICENSE +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/README.md +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/config.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/constants.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/evals.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/load_model.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/registry.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/matching_pursuit_sae.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/temporal_sae.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/__init__.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/correlation.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/evals.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/firing_probabilities.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/initialization.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/synthetic/plotting.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/types.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.29.0 → sae_lens-6.30.1}/sae_lens/util.py +0 -0
sae_lens/loading/pretrained_sae_loaders.py
@@ -575,6 +575,8 @@ def _infer_gemma_3_raw_cfg_dict(repo_id: str, folder_name: str) -> dict[str, Any
         "model_name": model_name,
         "hf_hook_point_in": hf_hook_point_in,
     }
+    if "transcoder" in folder_name or "clt" in folder_name:
+        cfg["affine_connection"] = "affine" in folder_name
     if hf_hook_point_out is not None:
         cfg["hf_hook_point_out"] = hf_hook_point_out
 
@@ -614,11 +616,11 @@ def get_gemma_3_config_from_hf(
     if "resid_post" in folder_name:
         hook_name = f"blocks.{layer}.hook_resid_post"
     elif "attn_out" in folder_name:
-        hook_name = f"blocks.{layer}.
+        hook_name = f"blocks.{layer}.attn.hook_z"
     elif "mlp_out" in folder_name:
         hook_name = f"blocks.{layer}.hook_mlp_out"
     elif "transcoder" in folder_name or "clt" in folder_name:
-        hook_name = f"blocks.{layer}.
+        hook_name = f"blocks.{layer}.hook_mlp_in"
         hook_name_out = f"blocks.{layer}.hook_mlp_out"
     else:
         raise ValueError("Hook name not found in folder_name.")
@@ -643,7 +645,11 @@ def get_gemma_3_config_from_hf(
 
     architecture = "jumprelu"
     if "transcoder" in folder_name or "clt" in folder_name:
-        architecture =
+        architecture = (
+            "jumprelu_skip_transcoder"
+            if raw_cfg_dict.get("affine_connection", False)
+            else "jumprelu_transcoder"
+        )
     d_out = shapes_dict["w_dec"][-1]
 
     cfg = {
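Note: architecture selection for the Gemma 3 loaders is driven entirely by the checkpoint folder name plus the inferred affine_connection flag. A minimal standalone sketch of the combined logic from the two hunks above (the helper function is ours, for illustration only, not sae-lens API):

import re  # noqa: F401  (only the string checks below are needed)

def infer_architecture(folder_name: str) -> str:
    # Mirrors the diff: transcoder/CLT folders get a transcoder architecture,
    # and an "affine" folder name selects the skip-transcoder variant.
    architecture = "jumprelu"
    if "transcoder" in folder_name or "clt" in folder_name:
        affine_connection = "affine" in folder_name  # as in _infer_gemma_3_raw_cfg_dict
        architecture = (
            "jumprelu_skip_transcoder" if affine_connection else "jumprelu_transcoder"
        )
    return architecture

assert infer_architecture("resid_post/layer_9") == "jumprelu"
assert infer_architecture("affine_transcoder/layer_9") == "jumprelu_skip_transcoder"
assert infer_architecture("clt/layer_9") == "jumprelu_transcoder"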
sae_lens/pretrained_saes.yaml
@@ -4148,6 +4148,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_17_width_16k_l0_medium
     path: resid_post/layer_17_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-16k
   - id: layer_17_width_16k_l0_small
     path: resid_post/layer_17_width_16k_l0_small
     l0: 20
@@ -4166,6 +4167,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_17_width_262k_l0_medium
     path: resid_post/layer_17_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-262k
   - id: layer_17_width_262k_l0_medium_seed_1
     path: resid_post/layer_17_width_262k_l0_medium_seed_1
     l0: 60
@@ -4178,6 +4180,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_17_width_65k_l0_medium
     path: resid_post/layer_17_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-65k
   - id: layer_17_width_65k_l0_small
     path: resid_post/layer_17_width_65k_l0_small
     l0: 20
@@ -4187,6 +4190,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_22_width_16k_l0_medium
     path: resid_post/layer_22_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-16k
   - id: layer_22_width_16k_l0_small
     path: resid_post/layer_22_width_16k_l0_small
     l0: 20
@@ -4205,6 +4209,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_22_width_262k_l0_medium
     path: resid_post/layer_22_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-262k
   - id: layer_22_width_262k_l0_medium_seed_1
     path: resid_post/layer_22_width_262k_l0_medium_seed_1
     l0: 60
@@ -4217,6 +4222,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_22_width_65k_l0_medium
     path: resid_post/layer_22_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-65k
   - id: layer_22_width_65k_l0_small
     path: resid_post/layer_22_width_65k_l0_small
     l0: 20
@@ -4226,6 +4232,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_29_width_16k_l0_medium
     path: resid_post/layer_29_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-16k
   - id: layer_29_width_16k_l0_small
     path: resid_post/layer_29_width_16k_l0_small
     l0: 20
@@ -4244,6 +4251,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_29_width_262k_l0_medium
     path: resid_post/layer_29_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-262k
   - id: layer_29_width_262k_l0_medium_seed_1
     path: resid_post/layer_29_width_262k_l0_medium_seed_1
     l0: 60
@@ -4256,6 +4264,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_29_width_65k_l0_medium
     path: resid_post/layer_29_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-65k
   - id: layer_29_width_65k_l0_small
     path: resid_post/layer_29_width_65k_l0_small
     l0: 20
@@ -4265,6 +4274,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_9_width_16k_l0_medium
     path: resid_post/layer_9_width_16k_l0_medium
     l0: 53
+    neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-16k
   - id: layer_9_width_16k_l0_small
     path: resid_post/layer_9_width_16k_l0_small
     l0: 17
@@ -4283,6 +4293,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_9_width_262k_l0_medium
     path: resid_post/layer_9_width_262k_l0_medium
     l0: 53
+    neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-262k
   - id: layer_9_width_262k_l0_medium_seed_1
     path: resid_post/layer_9_width_262k_l0_medium_seed_1
     l0: 53
@@ -4295,6 +4306,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_9_width_65k_l0_medium
     path: resid_post/layer_9_width_65k_l0_medium
     l0: 53
+    neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-65k
   - id: layer_9_width_65k_l0_small
     path: resid_post/layer_9_width_65k_l0_small
     l0: 17
@@ -14491,6 +14503,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_12_width_16k_l0_medium
     path: resid_post/layer_12_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-16k
   - id: layer_12_width_16k_l0_small
     path: resid_post/layer_12_width_16k_l0_small
     l0: 20
@@ -14509,6 +14522,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_12_width_262k_l0_medium
     path: resid_post/layer_12_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-262k
   - id: layer_12_width_262k_l0_medium_seed_1
     path: resid_post/layer_12_width_262k_l0_medium_seed_1
     l0: 60
@@ -14521,6 +14535,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_12_width_65k_l0_medium
     path: resid_post/layer_12_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-65k
   - id: layer_12_width_65k_l0_small
     path: resid_post/layer_12_width_65k_l0_small
     l0: 20
@@ -14530,6 +14545,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_15_width_16k_l0_medium
     path: resid_post/layer_15_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-16k
   - id: layer_15_width_16k_l0_small
     path: resid_post/layer_15_width_16k_l0_small
     l0: 20
@@ -14548,6 +14564,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_15_width_262k_l0_medium
     path: resid_post/layer_15_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-262k
   - id: layer_15_width_262k_l0_medium_seed_1
     path: resid_post/layer_15_width_262k_l0_medium_seed_1
     l0: 60
@@ -14560,6 +14577,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_15_width_65k_l0_medium
     path: resid_post/layer_15_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-65k
   - id: layer_15_width_65k_l0_small
     path: resid_post/layer_15_width_65k_l0_small
     l0: 20
@@ -14569,6 +14587,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_5_width_16k_l0_medium
     path: resid_post/layer_5_width_16k_l0_medium
     l0: 55
+    neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-16k
   - id: layer_5_width_16k_l0_small
     path: resid_post/layer_5_width_16k_l0_small
     l0: 18
@@ -14587,6 +14606,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_5_width_262k_l0_medium
     path: resid_post/layer_5_width_262k_l0_medium
     l0: 55
+    neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-262k
   - id: layer_5_width_262k_l0_medium_seed_1
     path: resid_post/layer_5_width_262k_l0_medium_seed_1
     l0: 55
@@ -14599,6 +14619,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_5_width_65k_l0_medium
     path: resid_post/layer_5_width_65k_l0_medium
     l0: 55
+    neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-65k
   - id: layer_5_width_65k_l0_small
     path: resid_post/layer_5_width_65k_l0_small
     l0: 18
@@ -14608,6 +14629,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_9_width_16k_l0_medium
     path: resid_post/layer_9_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-16k
   - id: layer_9_width_16k_l0_small
     path: resid_post/layer_9_width_16k_l0_small
     l0: 20
@@ -14626,6 +14648,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_9_width_262k_l0_medium
     path: resid_post/layer_9_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-262k
   - id: layer_9_width_262k_l0_medium_seed_1
     path: resid_post/layer_9_width_262k_l0_medium_seed_1
     l0: 60
@@ -14638,6 +14661,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_9_width_65k_l0_medium
     path: resid_post/layer_9_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-65k
   - id: layer_9_width_65k_l0_small
     path: resid_post/layer_9_width_65k_l0_small
     l0: 20
@@ -18727,6 +18751,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_13_width_16k_l0_medium
     path: resid_post/layer_13_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-16k
   - id: layer_13_width_16k_l0_small
     path: resid_post/layer_13_width_16k_l0_small
     l0: 20
@@ -18745,6 +18770,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_13_width_262k_l0_medium
     path: resid_post/layer_13_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-262k
   - id: layer_13_width_262k_l0_medium_seed_1
     path: resid_post/layer_13_width_262k_l0_medium_seed_1
     l0: 60
@@ -18757,6 +18783,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_13_width_65k_l0_medium
     path: resid_post/layer_13_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-65k
   - id: layer_13_width_65k_l0_small
     path: resid_post/layer_13_width_65k_l0_small
     l0: 20
@@ -18766,6 +18793,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_17_width_16k_l0_medium
     path: resid_post/layer_17_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-16k
   - id: layer_17_width_16k_l0_small
     path: resid_post/layer_17_width_16k_l0_small
     l0: 20
@@ -18784,6 +18812,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_17_width_262k_l0_medium
     path: resid_post/layer_17_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-262k
   - id: layer_17_width_262k_l0_medium_seed_1
     path: resid_post/layer_17_width_262k_l0_medium_seed_1
     l0: 60
@@ -18796,6 +18825,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_17_width_65k_l0_medium
     path: resid_post/layer_17_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-65k
   - id: layer_17_width_65k_l0_small
     path: resid_post/layer_17_width_65k_l0_small
     l0: 20
@@ -18805,6 +18835,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_22_width_16k_l0_medium
     path: resid_post/layer_22_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-16k
   - id: layer_22_width_16k_l0_small
     path: resid_post/layer_22_width_16k_l0_small
     l0: 20
@@ -18823,6 +18854,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_22_width_262k_l0_medium
     path: resid_post/layer_22_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-262k
   - id: layer_22_width_262k_l0_medium_seed_1
     path: resid_post/layer_22_width_262k_l0_medium_seed_1
     l0: 60
@@ -18835,6 +18867,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_22_width_65k_l0_medium
     path: resid_post/layer_22_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-65k
   - id: layer_22_width_65k_l0_small
     path: resid_post/layer_22_width_65k_l0_small
     l0: 20
@@ -18844,6 +18877,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_7_width_16k_l0_medium
     path: resid_post/layer_7_width_16k_l0_medium
     l0: 54
+    neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-16k
   - id: layer_7_width_16k_l0_small
     path: resid_post/layer_7_width_16k_l0_small
     l0: 18
@@ -18862,6 +18896,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_7_width_262k_l0_medium
     path: resid_post/layer_7_width_262k_l0_medium
     l0: 54
+    neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-262k
   - id: layer_7_width_262k_l0_medium_seed_1
     path: resid_post/layer_7_width_262k_l0_medium_seed_1
     l0: 54
@@ -18874,6 +18909,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_7_width_65k_l0_medium
     path: resid_post/layer_7_width_65k_l0_medium
     l0: 54
+    neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-65k
   - id: layer_7_width_65k_l0_small
     path: resid_post/layer_7_width_65k_l0_small
     l0: 18
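Note: every added neuronpedia value follows the same pattern, <model>/<layer>-gemmascope-2-res-<width>, and only the l0_medium variant of each width receives one (the l0_small and seed_1 entries do not). A throwaway sketch of that mapping (the helper is ours, not part of sae-lens):

import re

def neuronpedia_slug(model: str, sae_id: str) -> str:
    # e.g. ("gemma-3-4b-it", "layer_17_width_16k_l0_medium")
    #   -> "gemma-3-4b-it/17-gemmascope-2-res-16k"
    m = re.match(r"layer_(\d+)_width_(\d+k)_l0_medium$", sae_id)
    if m is None:
        raise ValueError(f"no neuronpedia slug for {sae_id}")
    layer, width = m.groups()
    return f"{model}/{layer}-gemmascope-2-res-{width}"

assert (
    neuronpedia_slug("gemma-3-270m-it", "layer_12_width_262k_l0_medium")
    == "gemma-3-270m-it/12-gemmascope-2-res-262k"
)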
sae_lens/synthetic/activation_generator.py
@@ -2,12 +2,12 @@
 Functions for generating synthetic feature activations.
 """
 
+import math
 from collections.abc import Callable, Sequence
 
 import torch
-from scipy.stats import norm
 from torch import nn
-from torch.distributions import
+from torch.distributions import MultivariateNormal
 
 from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
 from sae_lens.util import str_to_dtype
@@ -34,6 +34,7 @@ class ActivationGenerator(nn.Module):
     correlation_matrix: torch.Tensor | None
     low_rank_correlation: tuple[torch.Tensor, torch.Tensor] | None
     correlation_thresholds: torch.Tensor | None
+    use_sparse_tensors: bool
 
     def __init__(
         self,
@@ -45,7 +46,34 @@
         correlation_matrix: CorrelationMatrixInput | None = None,
         device: torch.device | str = "cpu",
         dtype: torch.dtype | str = "float32",
+        use_sparse_tensors: bool = False,
     ):
+        """
+        Create a new ActivationGenerator.
+
+        Args:
+            num_features: Number of features to generate activations for.
+            firing_probabilities: Probability of each feature firing. Can be a single
+                float (applied to all features) or a tensor of shape (num_features,).
+            std_firing_magnitudes: Standard deviation of firing magnitudes. Can be a
+                single float or a tensor of shape (num_features,). Defaults to 0.0
+                (deterministic magnitudes).
+            mean_firing_magnitudes: Mean firing magnitude when a feature fires. Can be
+                a single float or a tensor of shape (num_features,). Defaults to 1.0.
+            modify_activations: Optional function(s) to modify activations after
+                generation. Can be a single callable, a sequence of callables (applied
+                in order), or None. Useful for applying hierarchy constraints.
+            correlation_matrix: Optional correlation structure between features. Can be:
+
+                - A full correlation matrix tensor of shape (num_features, num_features)
+                - A LowRankCorrelationMatrix for memory-efficient large-scale correlations
+                - A tuple of (factor, diag) tensors representing low-rank structure
+
+            device: Device to place tensors on. Defaults to "cpu".
+            dtype: Data type for tensors. Defaults to "float32".
+            use_sparse_tensors: If True, return sparse COO tensors from sample().
+                Only recommended when using massive numbers of features. Defaults to False.
+        """
         super().__init__()
         self.num_features = num_features
         self.firing_probabilities = _to_tensor(
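Note: a hedged usage sketch of the new flag (import path and argument names as they appear in this diff; the values are illustrative):

import torch
from sae_lens.synthetic.activation_generator import ActivationGenerator

gen = ActivationGenerator(
    num_features=100_000,
    firing_probabilities=0.001,
    use_sparse_tensors=True,  # sample() returns sparse COO tensors
)
acts = gen(64)  # forward(batch_size) -> [64, 100_000] feature activations
print(acts.is_sparse)  # True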
@@ -61,6 +89,7 @@
         self.correlation_thresholds = None
         self.correlation_matrix = None
         self.low_rank_correlation = None
+        self.use_sparse_tensors = use_sparse_tensors
 
         if correlation_matrix is not None:
             if isinstance(correlation_matrix, torch.Tensor):
@@ -76,12 +105,15 @@
             _validate_low_rank_correlation(
                 correlation_factor, correlation_diag, num_features
             )
-
+            # Pre-compute sqrt for efficiency (used every sample call)
+            self.low_rank_correlation = (
+                correlation_factor,
+                correlation_diag.sqrt(),
+            )
 
-
-
-
-                dtype=self.firing_probabilities.dtype,
+        # Vectorized inverse normal CDF: norm.ppf(1-p) = sqrt(2) * erfinv(1 - 2*p)
+        self.correlation_thresholds = math.sqrt(2) * torch.erfinv(
+            1 - 2 * self.firing_probabilities
         )
 
     @torch.no_grad()
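Note: the deleted scipy call and the new torch expression are mathematically identical. Since Phi(x) = (1 + erf(x / sqrt(2))) / 2, solving Phi(t) = 1 - p gives erf(t / sqrt(2)) = 1 - 2p, hence t = sqrt(2) * erfinv(1 - 2p), exactly the line added above, which removes the scipy dependency from this module. A quick equivalence check (scipy used here only to verify):

import math
import torch
from scipy.stats import norm

p = torch.tensor([0.001, 0.01, 0.1, 0.5])
thresholds = math.sqrt(2) * torch.erfinv(1 - 2 * p)
expected = torch.tensor(norm.ppf(1 - p.numpy()), dtype=torch.float32)
assert torch.allclose(thresholds, expected, atol=1e-4)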
@@ -105,7 +137,7 @@
 
         if self.correlation_matrix is not None:
             assert self.correlation_thresholds is not None
-
+            firing_indices = _generate_correlated_features(
                 batch_size,
                 self.correlation_matrix,
                 self.correlation_thresholds,
@@ -113,7 +145,7 @@
             )
         elif self.low_rank_correlation is not None:
             assert self.correlation_thresholds is not None
-
+            firing_indices = _generate_low_rank_correlated_features(
                 batch_size,
                 self.low_rank_correlation[0],
                 self.low_rank_correlation[1],
@@ -121,23 +153,58 @@
                 device,
             )
         else:
-
+            firing_indices = torch.bernoulli(
                 self.firing_probabilities.unsqueeze(0).expand(batch_size, -1)
+            ).nonzero(as_tuple=True)
+
+        # Compute activations only at firing positions (sparse optimization)
+        feature_indices = firing_indices[1]
+        num_firing = feature_indices.shape[0]
+        mean_at_firing = self.mean_firing_magnitudes[feature_indices]
+        std_at_firing = self.std_firing_magnitudes[feature_indices]
+        random_deltas = (
+            torch.randn(
+                num_firing, device=device, dtype=self.mean_firing_magnitudes.dtype
             )
-
-        firing_magnitude_delta = torch.normal(
-            torch.zeros_like(self.firing_probabilities)
-            .unsqueeze(0)
-            .expand(batch_size, -1),
-            self.std_firing_magnitudes.unsqueeze(0).expand(batch_size, -1),
+            * std_at_firing
         )
-
-
-
-
+        activations_at_firing = (mean_at_firing + random_deltas).relu()
+
+        if self.use_sparse_tensors:
+            # Return sparse COO tensor
+            indices = torch.stack(firing_indices)  # [2, nnz]
+            feature_activations = torch.sparse_coo_tensor(
+                indices,
+                activations_at_firing,
+                size=(batch_size, self.num_features),
+                device=device,
+                dtype=self.mean_firing_magnitudes.dtype,
+            )
+        else:
+            # Dense tensor path
+            feature_activations = torch.zeros(
+                batch_size,
+                self.num_features,
+                device=device,
+                dtype=self.mean_firing_magnitudes.dtype,
+            )
+            feature_activations[firing_indices] = activations_at_firing
 
         if self.modify_activations is not None:
-            feature_activations = self.modify_activations(feature_activations)
+            feature_activations = self.modify_activations(feature_activations)
+            if feature_activations.is_sparse:
+                # Apply relu to sparse values
+                feature_activations = feature_activations.coalesce()
+                feature_activations = torch.sparse_coo_tensor(
+                    feature_activations.indices(),
+                    feature_activations.values().relu(),
+                    feature_activations.shape,
+                    device=feature_activations.device,
+                    dtype=feature_activations.dtype,
+                )
+            else:
+                feature_activations = feature_activations.relu()
+
         return feature_activations
 
     def forward(self, batch_size: int) -> torch.Tensor:
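Note: the rewrite above replaces a dense [batch, num_features] magnitude draw with draws only at the positions that actually fire, so the cost scales with the number of nonzeros rather than batch_size * num_features. A self-contained sketch of the pattern:

import torch

batch, n = 4, 8
probs = torch.full((n,), 0.25)
# Row/column indices of the firing positions only
rows, cols = torch.bernoulli(probs.expand(batch, n)).nonzero(as_tuple=True)

mean, std = torch.ones(n), torch.full((n,), 0.1)
vals = (mean[cols] + torch.randn(cols.shape[0]) * std[cols]).relu()

dense = torch.zeros(batch, n)
dense[rows, cols] = vals  # scatter back into a dense tensor
sparse = torch.sparse_coo_tensor(torch.stack([rows, cols]), vals, (batch, n))
assert torch.allclose(sparse.to_dense(), dense)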
@@ -149,7 +216,7 @@ def _generate_correlated_features(
     correlation_matrix: torch.Tensor,
     thresholds: torch.Tensor,
     device: torch.device,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Generate correlated binary features using multivariate Gaussian sampling.
 
@@ -163,7 +230,7 @@
         device: Device to generate samples on
 
     Returns:
-
+        Tuple of (row_indices, col_indices) for firing features
     """
     num_features = correlation_matrix.shape[0]
 
@@ -173,16 +240,17 @@
     )
 
     gaussian_samples = mvn.sample((batch_size,))
-
+    indices = (gaussian_samples > thresholds.unsqueeze(0)).nonzero(as_tuple=True)
+    return indices[0], indices[1]
 
 
 def _generate_low_rank_correlated_features(
     batch_size: int,
     correlation_factor: torch.Tensor,
-
+    cov_diag_sqrt: torch.Tensor,
     thresholds: torch.Tensor,
     device: torch.device,
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Generate correlated binary features using low-rank multivariate Gaussian sampling.
 
@@ -192,23 +260,29 @@
     Args:
         batch_size: Number of samples to generate
         correlation_factor: Factor matrix of shape (num_features, rank)
-
+        cov_diag_sqrt: Pre-computed sqrt of diagonal term, shape (num_features,)
         thresholds: Pre-computed thresholds for each feature (from inverse normal CDF)
         device: Device to generate samples on
 
     Returns:
-
+        Tuple of (row_indices, col_indices) for firing features
     """
-
-
-
-
-
-
+    # Manual low-rank MVN sampling to enable autocast for the expensive matmul
+    # samples = eps @ cov_factor.T + eta * sqrt(cov_diag)
+    # where eps ~ N(0, I_rank) and eta ~ N(0, I_n)
+
+    num_features, rank = correlation_factor.shape
+
+    # Generate random samples in float32 for numerical stability
+    eps = torch.randn(batch_size, rank, device=device, dtype=correlation_factor.dtype)
+    eta = torch.randn(
+        batch_size, num_features, device=device, dtype=cov_diag_sqrt.dtype
     )
 
-    gaussian_samples =
-
+    gaussian_samples = eps @ correlation_factor.T + eta * cov_diag_sqrt
+
+    indices = (gaussian_samples > thresholds.unsqueeze(0)).nonzero(as_tuple=True)
+    return indices[0], indices[1]
 
 
 def _to_tensor(
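Note: the manual sampler is the standard low-rank MVN construction. With eps ~ N(0, I_rank) and eta ~ N(0, I_n), x = eps @ F.T + eta * sqrt(d) has Cov(x) = F F^T + diag(d), the same covariance that torch.distributions.LowRankMultivariateNormal parameterizes, but written as a single large matmul that autocast can accelerate. An empirical check of that identity:

import torch

torch.manual_seed(0)
n, rank, samples = 6, 2, 200_000
F = torch.randn(n, rank)
d = torch.rand(n) + 0.1

eps = torch.randn(samples, rank)
eta = torch.randn(samples, n)
x = eps @ F.T + eta * d.sqrt()  # same recipe as the function above

emp_cov = (x.T @ x) / samples           # x has mean zero by construction
target = F @ F.T + torch.diag(d)
assert torch.allclose(emp_cov, target, atol=0.1)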
sae_lens/synthetic/feature_dictionary.py
@@ -168,9 +168,18 @@ class FeatureDictionary(nn.Module):
 
         Args:
             feature_activations: Tensor of shape [batch, num_features] containing
-                sparse feature activation values
+                sparse feature activation values. Can be dense or sparse COO.
 
         Returns:
             Tensor of shape [batch, hidden_dim] containing dense hidden activations
         """
+        if feature_activations.is_sparse:
+            # autocast is disabled here because sparse matmul is not supported with bfloat16
+            with torch.autocast(
+                device_type=feature_activations.device.type, enabled=False
+            ):
+                return (
+                    torch.sparse.mm(feature_activations, self.feature_vectors)
+                    + self.bias
+                )
         return feature_activations @ self.feature_vectors + self.bias
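Note: torch.sparse.mm multiplies a sparse COO matrix by a dense one and matches the dense result; the autocast guard above exists because, per the code comment, this kernel lacks bfloat16 support. A small equivalence sketch:

import torch

b, n, h = 4, 16, 8
idx = torch.tensor([[0, 1, 3], [2, 5, 9]])  # [2, nnz]: (batch, feature) coords
vals = torch.tensor([1.0, 0.5, 2.0])
feats = torch.sparse_coo_tensor(idx, vals, (b, n))
W, bias = torch.randn(n, h), torch.zeros(h)

dense_out = feats.to_dense() @ W + bias
sparse_out = torch.sparse.mm(feats, W) + bias  # sparse input path
assert torch.allclose(dense_out, sparse_out, atol=1e-6)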
sae_lens/synthetic/hierarchy.py
@@ -147,6 +147,14 @@ class _SparseHierarchyData:
     # Total number of ME groups
     num_groups: int
 
+    # Sparse COO support: Feature-to-parent mapping
+    # feat_to_parent[f] = parent feature index, or -1 if root/no parent
+    feat_to_parent: torch.Tensor | None = None  # [num_features]
+
+    # Sparse COO support: Feature-to-ME-group mapping
+    # feat_to_me_group[f] = group index, or -1 if not in any ME group
+    feat_to_me_group: torch.Tensor | None = None  # [num_features]
+
 
 def _build_sparse_hierarchy(
     roots: Sequence[HierarchyNode],
@@ -232,7 +240,11 @@ def _build_sparse_hierarchy(
             me_indices = torch.empty(0, dtype=torch.long)
 
         level_data.append(
-            _LevelData(
+            _LevelData(
+                features=feats,
+                parents=parents,
+                me_group_indices=me_indices,
+            )
         )
 
     # Build group siblings and parents tensors
@@ -254,12 +266,30 @@ def _build_sparse_hierarchy(
         me_group_parents = torch.empty(0, dtype=torch.long)
         num_groups = 0
 
+    # Build sparse COO support: feat_to_parent and feat_to_me_group mappings
+    # First determine num_features (max feature index + 1)
+    all_features = [f for f, _, _ in feature_info]
+    num_features = max(all_features) + 1 if all_features else 0
+
+    # Build feature-to-parent mapping
+    feat_to_parent = torch.full((num_features,), -1, dtype=torch.long)
+    for feat, parent, _ in feature_info:
+        feat_to_parent[feat] = parent
+
+    # Build feature-to-ME-group mapping
+    feat_to_me_group = torch.full((num_features,), -1, dtype=torch.long)
+    for g_idx, (_, _, siblings) in enumerate(me_groups):
+        for sib in siblings:
+            feat_to_me_group[sib] = g_idx
+
     return _SparseHierarchyData(
         level_data=level_data,
         me_group_siblings=me_group_siblings,
         me_group_sizes=me_group_sizes,
         me_group_parents=me_group_parents,
         num_groups=num_groups,
+        feat_to_parent=feat_to_parent,
+        feat_to_me_group=feat_to_me_group,
     )
 
 
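Note: the two new lookup tables flatten the tree into arrays indexed by feature id, so the sparse path can answer "parent of feature f" and "ME group of feature f" with a single gather instead of Python-level traversal. Sketch mirroring the construction above (the (feat, parent, _) triples stand in for feature_info):

import torch

feature_info = [(0, -1, None), (1, 0, None), (2, 0, None)]
num_features = max(f for f, _, _ in feature_info) + 1
feat_to_parent = torch.full((num_features,), -1, dtype=torch.long)
for feat, parent, _ in feature_info:
    feat_to_parent[feat] = parent

active_feats = torch.tensor([1, 2])
print(feat_to_parent[active_feats])  # tensor([0, 0]): one gather, no loop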
@@ -396,8 +426,9 @@ def _apply_me_for_groups(
     # Random selection for winner
     # Use -1e9 instead of -inf to avoid creating a tensor (torch.tensor(-float("inf")))
     # on every call. Since random scores are in [0,1], -1e9 is effectively -inf for argmax.
+    _INACTIVE_SCORE = -1e9
     random_scores = torch.rand(num_conflicts, max_siblings, device=device)
-    random_scores[~conflict_active] =
+    random_scores[~conflict_active] = _INACTIVE_SCORE
 
     winner_idx = random_scores.argmax(dim=1)
 
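Note: hoisting -1e9 into _INACTIVE_SCORE just names the sentinel; the trick itself is that the scores are uniform in [0, 1], so any value below 0 behaves like -inf under argmax without allocating a fresh -inf tensor per call. A short demonstration:

import torch

scores = torch.rand(3, 4)
mask = torch.tensor([[True, False, True, True]] * 3)
scores[~mask] = -1e9  # masked slots can never win
assert bool(mask.gather(1, scores.argmax(dim=1, keepdim=True)).all())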
@@ -420,6 +451,275 @@
     activations[deact_batch, deact_feat] = 0
 
 
+# ---------------------------------------------------------------------------
+# Sparse COO hierarchy implementation
+# ---------------------------------------------------------------------------
+
+
+def _apply_hierarchy_sparse_coo(
+    sparse_tensor: torch.Tensor,
+    sparse_data: _SparseHierarchyData,
+) -> torch.Tensor:
+    """
+    Apply hierarchy constraints to a sparse COO tensor.
+
+    This is the sparse analog of _apply_hierarchy_sparse. It processes
+    level-by-level, applying parent deactivation then mutual exclusion.
+    """
+    if sparse_tensor._nnz() == 0:
+        return sparse_tensor
+
+    sparse_tensor = sparse_tensor.coalesce()
+
+    for level_data in sparse_data.level_data:
+        # Step 1: Apply parent deactivation for features at this level
+        if level_data.features.numel() > 0:
+            sparse_tensor = _apply_parent_deactivation_coo(
+                sparse_tensor, level_data, sparse_data
+            )
+
+        # Step 2: Apply ME for groups whose parent is at this level
+        if level_data.me_group_indices.numel() > 0:
+            sparse_tensor = _apply_me_coo(
+                sparse_tensor, level_data.me_group_indices, sparse_data
+            )
+
+    return sparse_tensor
+
+
+def _apply_parent_deactivation_coo(
+    sparse_tensor: torch.Tensor,
+    level_data: _LevelData,
+    sparse_data: _SparseHierarchyData,
+) -> torch.Tensor:
+    """
+    Remove children from sparse COO tensor when their parent is inactive.
+
+    Uses searchsorted for efficient membership testing of parent activity.
+    """
+    if sparse_tensor._nnz() == 0 or level_data.features.numel() == 0:
+        return sparse_tensor
+
+    sparse_tensor = sparse_tensor.coalesce()
+    indices = sparse_tensor.indices()  # [2, nnz]
+    values = sparse_tensor.values()  # [nnz]
+    batch_indices = indices[0]
+    feat_indices = indices[1]
+
+    _, num_features = sparse_tensor.shape
+    device = sparse_tensor.device
+    nnz = indices.shape[1]
+
+    # Build set of active (batch, feature) pairs for efficient lookup
+    # Encode as: batch_idx * num_features + feat_idx
+    active_pairs = batch_indices * num_features + feat_indices
+    active_pairs_sorted, _ = active_pairs.sort()
+
+    # Use the precomputed feat_to_parent mapping
+    assert sparse_data.feat_to_parent is not None
+    hierarchy_num_features = sparse_data.feat_to_parent.numel()
+
+    # Handle features outside the hierarchy (they have no parent, pass through)
+    in_hierarchy = feat_indices < hierarchy_num_features
+    parent_of_feat = torch.full((nnz,), -1, dtype=torch.long, device=device)
+    parent_of_feat[in_hierarchy] = sparse_data.feat_to_parent[
+        feat_indices[in_hierarchy]
+    ]
+
+    # Find entries that have a parent (parent >= 0 means this feature has a parent)
+    has_parent = parent_of_feat >= 0
+
+    if not has_parent.any():
+        return sparse_tensor
+
+    # For entries with parents, check if parent is active
+    child_entry_indices = torch.where(has_parent)[0]
+    child_batch = batch_indices[has_parent]
+    child_parents = parent_of_feat[has_parent]
+
+    # Look up parent activity using searchsorted
+    parent_pairs = child_batch * num_features + child_parents
+    search_pos = torch.searchsorted(active_pairs_sorted, parent_pairs)
+    search_pos = search_pos.clamp(max=active_pairs_sorted.numel() - 1)
+    parent_active = active_pairs_sorted[search_pos] == parent_pairs
+
+    # Handle empty case
+    if active_pairs_sorted.numel() == 0:
+        parent_active = torch.zeros_like(parent_pairs, dtype=torch.bool)
+
+    # Build keep mask: keep entry if it's a root OR its parent is active
+    keep_mask = torch.ones(nnz, dtype=torch.bool, device=device)
+    keep_mask[child_entry_indices[~parent_active]] = False
+
+    if keep_mask.all():
+        return sparse_tensor
+
+    return torch.sparse_coo_tensor(
+        indices[:, keep_mask],
+        values[keep_mask],
+        sparse_tensor.shape,
+        device=device,
+        dtype=sparse_tensor.dtype,
+    )
+
+
+def _apply_me_coo(
+    sparse_tensor: torch.Tensor,
+    group_indices: torch.Tensor,
+    sparse_data: _SparseHierarchyData,
+) -> torch.Tensor:
+    """
+    Apply mutual exclusion to sparse COO tensor.
+
+    For each ME group with multiple active siblings in the same batch,
+    randomly selects one winner and removes the rest.
+    """
+    if sparse_tensor._nnz() == 0 or group_indices.numel() == 0:
+        return sparse_tensor
+
+    sparse_tensor = sparse_tensor.coalesce()
+    indices = sparse_tensor.indices()  # [2, nnz]
+    values = sparse_tensor.values()  # [nnz]
+    batch_indices = indices[0]
+    feat_indices = indices[1]
+
+    _, num_features = sparse_tensor.shape
+    device = sparse_tensor.device
+    nnz = indices.shape[1]
+
+    # Use precomputed feat_to_me_group mapping
+    assert sparse_data.feat_to_me_group is not None
+    hierarchy_num_features = sparse_data.feat_to_me_group.numel()
+
+    # Handle features outside the hierarchy (they are not in any ME group)
+    in_hierarchy = feat_indices < hierarchy_num_features
+    me_group_of_feat = torch.full((nnz,), -1, dtype=torch.long, device=device)
+    me_group_of_feat[in_hierarchy] = sparse_data.feat_to_me_group[
+        feat_indices[in_hierarchy]
+    ]
+
+    # Find entries that belong to ME groups we're processing (vectorized)
+    in_relevant_group = torch.isin(me_group_of_feat, group_indices)
+
+    if not in_relevant_group.any():
+        return sparse_tensor
+
+    # Get the ME entries
+    me_entry_indices = torch.where(in_relevant_group)[0]
+    me_batch = batch_indices[in_relevant_group]
+    me_group = me_group_of_feat[in_relevant_group]
+
+    # Check parent activity for ME groups (only apply ME if parent is active)
+    me_group_parents = sparse_data.me_group_parents[me_group]
+    has_parent = me_group_parents >= 0
+
+    if has_parent.any():
+        # Build active pairs for parent lookup
+        active_pairs = batch_indices * num_features + feat_indices
+        active_pairs_sorted, _ = active_pairs.sort()
+
+        parent_pairs = (
+            me_batch[has_parent] * num_features + me_group_parents[has_parent]
+        )
+        search_pos = torch.searchsorted(active_pairs_sorted, parent_pairs)
+        search_pos = search_pos.clamp(max=active_pairs_sorted.numel() - 1)
+        parent_active_for_has_parent = active_pairs_sorted[search_pos] == parent_pairs
+
+        # Build full parent_active mask
+        parent_active = torch.ones(
+            me_entry_indices.numel(), dtype=torch.bool, device=device
+        )
+        parent_active[has_parent] = parent_active_for_has_parent
+
+        # Filter to only ME entries where parent is active
+        valid_me = parent_active
+        me_entry_indices = me_entry_indices[valid_me]
+        me_batch = me_batch[valid_me]
+        me_group = me_group[valid_me]
+
+    if me_entry_indices.numel() == 0:
+        return sparse_tensor
+
+    # Encode (batch, group) pairs
+    num_groups = sparse_data.num_groups
+    batch_group_pairs = me_batch * num_groups + me_group
+
+    # Find unique (batch, group) pairs and count occurrences
+    unique_bg, inverse, counts = torch.unique(
+        batch_group_pairs, return_inverse=True, return_counts=True
+    )
+
+    # Only process pairs with count > 1 (conflicts)
+    has_conflict = counts > 1
+
+    if not has_conflict.any():
+        return sparse_tensor
+
+    # For efficiency, we process all conflicts together
+    # Assign random scores to each ME entry
+    random_scores = torch.rand(me_entry_indices.numel(), device=device)
+
+    # For each (batch, group) pair, we want the entry with highest score to be winner
+    # Use scatter_reduce to find max score per (batch, group)
+    bg_to_dense = torch.zeros(unique_bg.numel(), dtype=torch.long, device=device)
+    bg_to_dense[has_conflict.nonzero(as_tuple=True)[0]] = torch.arange(
+        has_conflict.sum(), device=device
+    )
+
+    # Map each ME entry to its dense conflict index
+    entry_has_conflict = has_conflict[inverse]
+
+    if not entry_has_conflict.any():
+        return sparse_tensor
+
+    conflict_entries_mask = entry_has_conflict
+    conflict_entry_indices = me_entry_indices[conflict_entries_mask]
+    conflict_random_scores = random_scores[conflict_entries_mask]
+    conflict_inverse = inverse[conflict_entries_mask]
+    conflict_dense_idx = bg_to_dense[conflict_inverse]
+
+    # Vectorized winner selection using sorting
+    # Sort entries by (group_idx, -random_score) so highest score comes first per group
+    # Use group * 2 - score to sort by group ascending, then score descending
+    sort_keys = conflict_dense_idx.float() * 2.0 - conflict_random_scores
+    sorted_order = sort_keys.argsort()
+    sorted_dense_idx = conflict_dense_idx[sorted_order]
+
+    # Find first entry of each group in sorted order (these are winners)
+    group_starts = torch.cat(
+        [
+            torch.tensor([True], device=device),
+            sorted_dense_idx[1:] != sorted_dense_idx[:-1],
+        ]
+    )
+
+    # Winners are entries at group starts in sorted order
+    winner_positions_in_sorted = torch.where(group_starts)[0]
+    winner_original_positions = sorted_order[winner_positions_in_sorted]
+
+    # Create winner mask (vectorized)
+    is_winner = torch.zeros(
+        conflict_entry_indices.numel(), dtype=torch.bool, device=device
+    )
+    is_winner[winner_original_positions] = True
+
+    # Build keep mask (vectorized)
+    keep_mask = torch.ones(nnz, dtype=torch.bool, device=device)
+    loser_entry_indices = conflict_entry_indices[~is_winner]
+    keep_mask[loser_entry_indices] = False
+
+    if keep_mask.all():
+        return sparse_tensor
+
+    return torch.sparse_coo_tensor(
+        indices[:, keep_mask],
+        values[keep_mask],
+        sparse_tensor.shape,
+        device=device,
+        dtype=sparse_tensor.dtype,
+    )
+
+
 @torch.no_grad()
 def hierarchy_modifier(
     roots: Sequence[HierarchyNode] | HierarchyNode,
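Note: both COO helpers above test "is (batch, parent) active?" by encoding each pair as a single integer key, batch * num_features + feature, sorting the keys once, and probing with torch.searchsorted; the clamp keeps the probe index in range when a query is larger than every key. The same idea in isolation:

import torch

num_features = 10
batch_idx = torch.tensor([0, 0, 1])
feat_idx = torch.tensor([2, 7, 2])
active = (batch_idx * num_features + feat_idx).sort().values  # sorted keys

# Is feature 2 active in batch 1? Is feature 7 active in batch 1?
queries = torch.tensor([1 * num_features + 2, 1 * num_features + 7])
pos = torch.searchsorted(active, queries).clamp(max=active.numel() - 1)
print(active[pos] == queries)  # tensor([ True, False])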
@@ -475,12 +775,24 @@ def hierarchy_modifier(
             me_group_sizes=sparse_data.me_group_sizes.to(device),
             me_group_parents=sparse_data.me_group_parents.to(device),
             num_groups=sparse_data.num_groups,
+            feat_to_parent=(
+                sparse_data.feat_to_parent.to(device)
+                if sparse_data.feat_to_parent is not None
+                else None
+            ),
+            feat_to_me_group=(
+                sparse_data.feat_to_me_group.to(device)
+                if sparse_data.feat_to_me_group is not None
+                else None
+            ),
         )
         return device_cache[device]
 
     def modifier(activations: torch.Tensor) -> torch.Tensor:
        device = activations.device
         cached = _get_sparse_for_device(device)
+        if activations.is_sparse:
+            return _apply_hierarchy_sparse_coo(activations, cached)
         return _apply_hierarchy_sparse(activations, cached)
 
     return modifier
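Note: the modifier closure now dispatches on activations.is_sparse, so a single hierarchy_modifier serves both layouts. The same shape in miniature (a generic sketch, not the sae-lens constraint logic):

import torch

def make_modifier(zero_feature: int):
    def modifier(activations: torch.Tensor) -> torch.Tensor:
        if activations.is_sparse:
            acts = activations.coalesce()
            keep = acts.indices()[1] != zero_feature  # drop one feature column
            return torch.sparse_coo_tensor(
                acts.indices()[:, keep], acts.values()[keep], acts.shape
            )
        activations[:, zero_feature] = 0
        return activations
    return modifier

m = make_modifier(1)
dense, sparse = torch.ones(2, 3), torch.ones(2, 3).to_sparse()
assert torch.equal(m(dense), m(sparse).to_dense())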
sae_lens/synthetic/training.py
@@ -23,6 +23,8 @@ def train_toy_sae(
     device: str | torch.device = "cpu",
     n_snapshots: int = 0,
     snapshot_fn: Callable[[SAETrainer[Any, Any]], None] | None = None,
+    autocast_sae: bool = False,
+    autocast_data: bool = False,
 ) -> None:
     """
     Train an SAE on synthetic activations from a feature dictionary.
@@ -46,6 +48,8 @@ def train_toy_sae(
         snapshot_fn: Callback function called at each snapshot point. Receives
             the SAETrainer instance, allowing access to the SAE, training step,
             and other training state. Required if n_snapshots > 0.
+        autocast_sae: Whether to autocast the SAE to bfloat16. Only recommend for large SAEs on CUDA
+        autocast_data: Whether to autocast the activations generator and feature dictionary to bfloat16. Only recommend for large data on CUDA.
     """
 
     device_str = str(device) if isinstance(device, torch.device) else device
@@ -55,6 +59,7 @@ def train_toy_sae(
         feature_dict=feature_dict,
         activations_generator=activations_generator,
         batch_size=batch_size,
+        autocast=autocast_data,
     )
 
     # Create trainer config
@@ -64,7 +69,7 @@
         save_final_checkpoint=False,
         total_training_samples=training_samples,
         device=device_str,
-        autocast=
+        autocast=autocast_sae,
         lr=lr,
         lr_end=lr,
         lr_scheduler_name="constant",
@@ -119,6 +124,7 @@ class SyntheticActivationIterator(Iterator[torch.Tensor]):
     feature_dict: FeatureDictionary,
         activations_generator: ActivationGenerator,
         batch_size: int,
+        autocast: bool = False,
     ):
         """
         Create a new SyntheticActivationIterator.
@@ -127,16 +133,23 @@ class SyntheticActivationIterator(Iterator[torch.Tensor]):
             feature_dict: The feature dictionary to use for generating hidden activations
             activations_generator: Generator that produces feature activations
             batch_size: Number of samples per batch
+            autocast: Whether to autocast the activations generator and feature dictionary to bfloat16.
         """
         self.feature_dict = feature_dict
         self.activations_generator = activations_generator
         self.batch_size = batch_size
+        self.autocast = autocast
 
     @torch.no_grad()
     def next_batch(self) -> torch.Tensor:
         """Generate the next batch of hidden activations."""
-
-
+        with torch.autocast(
+            device_type=self.feature_dict.feature_vectors.device.type,
+            dtype=torch.bfloat16,
+            enabled=self.autocast,
+        ):
+            features = self.activations_generator(self.batch_size)
+            return self.feature_dict(features)
 
     def __iter__(self) -> "SyntheticActivationIterator":
         return self
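Note: both new flags funnel into torch.autocast(..., enabled=flag), which lets one code path serve both precisions; with enabled=False the context manager is a no-op. A minimal demonstration of the enabled switch:

import torch

x = torch.randn(8, 8)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True):
    on = (x @ x).dtype   # torch.bfloat16 under autocast
with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=False):
    off = (x @ x).dtype  # torch.float32 when disabled
print(on, off)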