sae-lens 6.6.5.tar.gz → 6.7.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {sae_lens-6.6.5 → sae_lens-6.7.0}/PKG-INFO +1 -1
  2. {sae_lens-6.6.5 → sae_lens-6.7.0}/pyproject.toml +1 -1
  3. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/loading/pretrained_sae_loaders.py +63 -2
  5. {sae_lens-6.6.5 → sae_lens-6.7.0}/LICENSE +0 -0
  6. {sae_lens-6.6.5 → sae_lens-6.7.0}/README.md +0 -0
  7. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/analysis/__init__.py +0 -0
  8. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  9. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  10. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/cache_activations_runner.py +0 -0
  11. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/config.py +0 -0
  12. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/constants.py +0 -0
  13. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/evals.py +0 -0
  14. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/llm_sae_training_runner.py +0 -0
  15. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/load_model.py +0 -0
  16. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/loading/__init__.py +0 -0
  17. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  18. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/pretokenize_runner.py +0 -0
  19. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/pretrained_saes.yaml +0 -0
  20. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/registry.py +0 -0
  21. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/__init__.py +0 -0
  22. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  23. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/gated_sae.py +0 -0
  24. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/jumprelu_sae.py +0 -0
  25. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/sae.py +0 -0
  26. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/standard_sae.py +0 -0
  27. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/topk_sae.py +0 -0
  28. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/transcoder.py +0 -0
  29. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/tokenization_and_batching.py +0 -0
  30. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/__init__.py +0 -0
  31. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/activation_scaler.py +0 -0
  32. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/activations_store.py +0 -0
  33. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/mixing_buffer.py +0 -0
  34. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/optim.py +0 -0
  35. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/sae_trainer.py +0 -0
  36. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/types.py +0 -0
  37. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  38. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/tutorial/tsea.py +0 -0
  39. {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/util.py +0 -0
--- sae_lens-6.6.5/PKG-INFO
+++ sae_lens-6.7.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: sae-lens
-Version: 6.6.5
+Version: 6.7.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
--- sae_lens-6.6.5/pyproject.toml
+++ sae_lens-6.7.0/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.6.5"
+version = "6.7.0"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
--- sae_lens-6.6.5/sae_lens/__init__.py
+++ sae_lens-6.7.0/sae_lens/__init__.py
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.6.5"
+__version__ = "6.7.0"
 
 import logging
 
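Aside from the loader changes below, the release is a straight version bump. If useful, the installed release can be confirmed at runtime via the module attribute shown in the hunk above:

    import sae_lens

    # After upgrading, this should print "6.7.0".
    print(sae_lens.__version__)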
--- sae_lens-6.6.5/sae_lens/loading/pretrained_sae_loaders.py
+++ sae_lens-6.7.0/sae_lens/loading/pretrained_sae_loaders.py
@@ -1252,11 +1252,11 @@ def get_mwhanna_transcoder_config_from_hf(
     try:
         # mwhanna transcoders sometimes have a typo in the config file name, so check for both
         wandb_config_path = hf_hub_download(
-            repo_id, "wanb-config.yaml", force_download=force_download
+            repo_id, "wandb-config.yaml", force_download=force_download
         )
     except EntryNotFoundError:
         wandb_config_path = hf_hub_download(
-            repo_id, "wandb-config.yaml", force_download=force_download
+            repo_id, "wanb-config.yaml", force_download=force_download
         )
     try:
         base_config_path = hf_hub_download(
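Read as plain Python, this hunk only swaps which spelling is tried first: the correctly spelled "wandb-config.yaml" is now preferred, and the historical "wanb-config.yaml" typo is kept as the fallback. A minimal standalone sketch of the resulting behaviour (the helper name is illustrative, not part of sae_lens):

    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import EntryNotFoundError

    def download_wandb_config(repo_id: str) -> str:
        """Illustrative helper: try the correct filename first, then the typo'd one."""
        try:
            return hf_hub_download(repo_id, "wandb-config.yaml")
        except EntryNotFoundError:
            return hf_hub_download(repo_id, "wanb-config.yaml")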
@@ -1330,6 +1330,66 @@ def mwhanna_transcoder_huggingface_loader(
     return cfg_dict, state_dict, None
 
 
+def mntss_clt_layer_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,  # noqa: ARG001
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """
+    Load a MNTSS CLT layer as a single layer transcoder.
+    The assumption is that the `folder_name` is the layer to load as an int
+    """
+    base_config_path = hf_hub_download(
+        repo_id, "config.yaml", force_download=force_download
+    )
+    with open(base_config_path) as f:
+        cfg_info: dict[str, Any] = yaml.safe_load(f)
+
+    # We need to actually load the weights, since the config is missing most information
+    encoder_path = hf_hub_download(
+        repo_id,
+        f"W_enc_{folder_name}.safetensors",
+        force_download=force_download,
+    )
+    decoder_path = hf_hub_download(
+        repo_id,
+        f"W_dec_{folder_name}.safetensors",
+        force_download=force_download,
+    )
+
+    encoder_state_dict = load_file(encoder_path, device=device)
+    decoder_state_dict = load_file(decoder_path, device=device)
+
+    with torch.no_grad():
+        state_dict = {
+            "W_enc": encoder_state_dict[f"W_enc_{folder_name}"].T,  # type: ignore
+            "b_enc": encoder_state_dict[f"b_enc_{folder_name}"],  # type: ignore
+            "b_dec": encoder_state_dict[f"b_dec_{folder_name}"],  # type: ignore
+            "W_dec": decoder_state_dict[f"W_dec_{folder_name}"].sum(dim=1),  # type: ignore
+        }
+
+    cfg_dict = {
+        "architecture": "transcoder",
+        "d_in": state_dict["b_dec"].shape[0],
+        "d_out": state_dict["b_dec"].shape[0],
+        "d_sae": state_dict["b_enc"].shape[0],
+        "dtype": "float32",
+        "device": device if device is not None else "cpu",
+        "activation_fn": "relu",
+        "normalize_activations": "none",
+        "model_name": cfg_info["model_name"],
+        "hook_name": f"blocks.{folder_name}.{cfg_info['feature_input_hook']}",
+        "hook_name_out": f"blocks.{folder_name}.{cfg_info['feature_output_hook']}",
+        "apply_b_dec_to_input": False,
+        "model_from_pretrained_kwargs": {"fold_ln": False},
+        **(cfg_overrides or {}),
+    }
+
+    return cfg_dict, state_dict, None
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1341,6 +1401,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sparsify": sparsify_huggingface_loader,
     "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
     "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
+    "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
 }
 
 
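A hedged usage sketch of the new loader: the registry key, the loader signature, and the returned config keys are all taken from the diff above, but the repo id and layer index below are placeholders rather than a real release, and the comments on weight reshaping describe what the loader code does (transpose W_enc, sum W_dec over dim 1) rather than documented checkpoint semantics:

    from sae_lens.loading.pretrained_sae_loaders import NAMED_PRETRAINED_SAE_LOADERS

    # "some-user/some-clt-repo" and layer "3" are hypothetical placeholders.
    loader = NAMED_PRETRAINED_SAE_LOADERS["mntss_clt_layer_transcoder"]
    cfg_dict, state_dict, _ = loader("some-user/some-clt-repo", "3", device="cpu")

    # The loader transposes W_enc and sums the cross-layer decoder over dim 1,
    # so the returned tensors follow the usual sae_lens conventions:
    #   W_enc: [d_in, d_sae], W_dec: [d_sae, d_out]
    print(cfg_dict["hook_name"], cfg_dict["hook_name_out"])
    print(state_dict["W_enc"].shape, state_dict["W_dec"].shape)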