sae-lens 6.6.4.tar.gz → 6.7.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {sae_lens-6.6.4 → sae_lens-6.7.0}/PKG-INFO +1 -1
  2. {sae_lens-6.6.4 → sae_lens-6.7.0}/pyproject.toml +1 -1
  3. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/evals.py +14 -10
  5. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/loading/pretrained_sae_loaders.py +68 -3
  6. {sae_lens-6.6.4 → sae_lens-6.7.0}/LICENSE +0 -0
  7. {sae_lens-6.6.4 → sae_lens-6.7.0}/README.md +0 -0
  8. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/analysis/__init__.py +0 -0
  9. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  10. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  11. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/cache_activations_runner.py +0 -0
  12. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/config.py +0 -0
  13. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/constants.py +0 -0
  14. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/llm_sae_training_runner.py +0 -0
  15. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/load_model.py +0 -0
  16. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/loading/__init__.py +0 -0
  17. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  18. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/pretokenize_runner.py +0 -0
  19. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/pretrained_saes.yaml +0 -0
  20. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/registry.py +0 -0
  21. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/__init__.py +0 -0
  22. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  23. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/gated_sae.py +0 -0
  24. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/jumprelu_sae.py +0 -0
  25. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/sae.py +0 -0
  26. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/standard_sae.py +0 -0
  27. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/topk_sae.py +0 -0
  28. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/transcoder.py +0 -0
  29. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/tokenization_and_batching.py +0 -0
  30. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/__init__.py +0 -0
  31. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/activation_scaler.py +0 -0
  32. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/activations_store.py +0 -0
  33. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/mixing_buffer.py +0 -0
  34. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/optim.py +0 -0
  35. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/sae_trainer.py +0 -0
  36. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/types.py +0 -0
  37. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  38. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/tutorial/tsea.py +0 -0
  39. {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/util.py +0 -0
{sae_lens-6.6.4 → sae_lens-6.7.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: sae-lens
-Version: 6.6.4
+Version: 6.7.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
{sae_lens-6.6.4 → sae_lens-6.7.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.6.4"
+version = "6.7.0"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
{sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/__init__.py
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.6.4"
+__version__ = "6.7.0"
 
 import logging
 
{sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/evals.py
@@ -718,17 +718,9 @@ def get_recons_loss(
         **model_kwargs,
     )
 
-    def kl(original_logits: torch.Tensor, new_logits: torch.Tensor):
-        original_probs = torch.nn.functional.softmax(original_logits, dim=-1)
-        log_original_probs = torch.log(original_probs)
-        new_probs = torch.nn.functional.softmax(new_logits, dim=-1)
-        log_new_probs = torch.log(new_probs)
-        kl_div = original_probs * (log_original_probs - log_new_probs)
-        return kl_div.sum(dim=-1)
-
     if compute_kl:
-        recons_kl_div = kl(original_logits, recons_logits)
-        zero_abl_kl_div = kl(original_logits, zero_abl_logits)
+        recons_kl_div = _kl(original_logits, recons_logits)
+        zero_abl_kl_div = _kl(original_logits, zero_abl_logits)
         metrics["kl_div_with_sae"] = recons_kl_div
         metrics["kl_div_with_ablation"] = zero_abl_kl_div
 
@@ -740,6 +732,18 @@ def get_recons_loss(
     return metrics
 
 
+def _kl(original_logits: torch.Tensor, new_logits: torch.Tensor):
+    # Computes the log-probabilities of the new logits (approximation).
+    log_probs_new = torch.nn.functional.log_softmax(new_logits, dim=-1)
+    # Computes the probabilities of the original logits (true distribution).
+    probs_orig = torch.nn.functional.softmax(original_logits, dim=-1)
+    # Compute the KL divergence. torch.nn.functional.kl_div expects the first argument to be the log
+    # probabilities of the approximation (new), and the second argument to be the true distribution
+    # (original) as probabilities. This computes KL(original || new).
+    kl = torch.nn.functional.kl_div(log_probs_new, probs_orig, reduction="none")
+    return kl.sum(dim=-1)
+
+
 def all_loadable_saes() -> list[tuple[str, str, float, float]]:
     all_loadable_saes = []
     saes_directory = get_pretrained_saes_directory()
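
As the inline comments in the new _kl helper note, torch.nn.functional.kl_div takes the approximation as log-probabilities and the true distribution as probabilities. A minimal standalone sketch (not part of the package) checking that this formulation matches the manual sum computed by the removed nested kl() function:

import torch
import torch.nn.functional as F

# Hypothetical logits: 2 positions over a 5-token vocabulary.
original_logits = torch.randn(2, 5)
new_logits = torch.randn(2, 5)

# Formulation used by the new _kl helper: KL(original || new).
kl_helper = F.kl_div(
    F.log_softmax(new_logits, dim=-1),   # log-probs of the approximation (new)
    F.softmax(original_logits, dim=-1),  # probs of the true distribution (original)
    reduction="none",
).sum(dim=-1)

# Manual formulation equivalent to the removed nested kl(): sum_i p_orig * (log p_orig - log p_new),
# written with log_softmax for numerical stability.
p_orig = F.softmax(original_logits, dim=-1)
kl_manual = (
    p_orig * (F.log_softmax(original_logits, dim=-1) - F.log_softmax(new_logits, dim=-1))
).sum(dim=-1)

assert torch.allclose(kl_helper, kl_manual, atol=1e-6)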
{sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/loading/pretrained_sae_loaders.py
@@ -1001,10 +1001,14 @@ def get_sparsify_config_from_disk(
     layer = int(match.group(1))
     hook_name = f"blocks.{layer}.hook_resid_post"
 
+    d_sae = old_cfg_dict.get("num_latents")
+    if d_sae is None:
+        d_sae = old_cfg_dict["d_in"] * old_cfg_dict["expansion_factor"]
+
     cfg_dict: dict[str, Any] = {
         "architecture": "standard",
         "d_in": old_cfg_dict["d_in"],
-        "d_sae": old_cfg_dict["d_in"] * old_cfg_dict["expansion_factor"],
+        "d_sae": d_sae,
         "dtype": "bfloat16",
         "device": device or "cpu",
         "model_name": config_dict.get("model", path.parts[-2]),
@@ -1248,11 +1252,11 @@ def get_mwhanna_transcoder_config_from_hf(
     try:
         # mwhanna transcoders sometimes have a typo in the config file name, so check for both
         wandb_config_path = hf_hub_download(
-            repo_id, "wanb-config.yaml", force_download=force_download
+            repo_id, "wandb-config.yaml", force_download=force_download
         )
     except EntryNotFoundError:
         wandb_config_path = hf_hub_download(
-            repo_id, "wandb-config.yaml", force_download=force_download
+            repo_id, "wanb-config.yaml", force_download=force_download
         )
     try:
         base_config_path = hf_hub_download(
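
The hunk above only swaps which filename is tried first: the correctly spelled wandb-config.yaml becomes the primary candidate and the typo'd wanb-config.yaml the fallback. A hedged sketch of the same try-in-order pattern, using a hypothetical repo id:

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

def download_first_existing(repo_id: str, candidates: list[str]) -> str:
    # Try each candidate filename in order; the last attempt is allowed to raise if nothing exists.
    for name in candidates[:-1]:
        try:
            return hf_hub_download(repo_id, name)
        except EntryNotFoundError:
            continue
    return hf_hub_download(repo_id, candidates[-1])

# "some-user/gpt2-transcoders" is a hypothetical repo id; the filenames are the two the loader checks.
config_path = download_first_existing(
    "some-user/gpt2-transcoders", ["wandb-config.yaml", "wanb-config.yaml"]
)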
@@ -1326,6 +1330,66 @@ def mwhanna_transcoder_huggingface_loader(
     return cfg_dict, state_dict, None
 
 
+def mntss_clt_layer_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,  # noqa: ARG001
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """
+    Load a MNTSS CLT layer as a single layer transcoder.
+    The assumption is that the `folder_name` is the layer to load as an int
+    """
+    base_config_path = hf_hub_download(
+        repo_id, "config.yaml", force_download=force_download
+    )
+    with open(base_config_path) as f:
+        cfg_info: dict[str, Any] = yaml.safe_load(f)
+
+    # We need to actually load the weights, since the config is missing most information
+    encoder_path = hf_hub_download(
+        repo_id,
+        f"W_enc_{folder_name}.safetensors",
+        force_download=force_download,
+    )
+    decoder_path = hf_hub_download(
+        repo_id,
+        f"W_dec_{folder_name}.safetensors",
+        force_download=force_download,
+    )
+
+    encoder_state_dict = load_file(encoder_path, device=device)
+    decoder_state_dict = load_file(decoder_path, device=device)
+
+    with torch.no_grad():
+        state_dict = {
+            "W_enc": encoder_state_dict[f"W_enc_{folder_name}"].T,  # type: ignore
+            "b_enc": encoder_state_dict[f"b_enc_{folder_name}"],  # type: ignore
+            "b_dec": encoder_state_dict[f"b_dec_{folder_name}"],  # type: ignore
+            "W_dec": decoder_state_dict[f"W_dec_{folder_name}"].sum(dim=1),  # type: ignore
+        }
+
+    cfg_dict = {
+        "architecture": "transcoder",
+        "d_in": state_dict["b_dec"].shape[0],
+        "d_out": state_dict["b_dec"].shape[0],
+        "d_sae": state_dict["b_enc"].shape[0],
+        "dtype": "float32",
+        "device": device if device is not None else "cpu",
+        "activation_fn": "relu",
+        "normalize_activations": "none",
+        "model_name": cfg_info["model_name"],
+        "hook_name": f"blocks.{folder_name}.{cfg_info['feature_input_hook']}",
+        "hook_name_out": f"blocks.{folder_name}.{cfg_info['feature_output_hook']}",
+        "apply_b_dec_to_input": False,
+        "model_from_pretrained_kwargs": {"fold_ln": False},
+        **(cfg_overrides or {}),
+    }
+
+    return cfg_dict, state_dict, None
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1337,6 +1401,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sparsify": sparsify_huggingface_loader,
     "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
    "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
+    "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
 }
 
 
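Since the new loader is registered under the "mntss_clt_layer_transcoder" key, it can also be invoked directly. A minimal sketch, assuming a hypothetical Hugging Face repo that stores per-layer W_enc_<layer>.safetensors / W_dec_<layer>.safetensors files plus a config.yaml, as the loader expects:

from sae_lens.loading.pretrained_sae_loaders import (
    NAMED_PRETRAINED_SAE_LOADERS,
    mntss_clt_layer_huggingface_loader,
)

# The registry entry added in this release points at the new loader.
assert NAMED_PRETRAINED_SAE_LOADERS["mntss_clt_layer_transcoder"] is mntss_clt_layer_huggingface_loader

# "some-user/clt-gpt2" is a hypothetical repo id; folder_name is the CLT layer to load.
cfg_dict, state_dict, log_sparsity = mntss_clt_layer_huggingface_loader(
    repo_id="some-user/clt-gpt2",
    folder_name="5",  # selects W_enc_5.safetensors / W_dec_5.safetensors
    device="cpu",
)

print(cfg_dict["hook_name"], cfg_dict["d_sae"])  # e.g. blocks.5.<input hook from config.yaml>, latent count
assert log_sparsity is None  # this loader does not provide log-sparsity values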