sae-lens 6.6.4__tar.gz → 6.7.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.6.4 → sae_lens-6.7.0}/PKG-INFO +1 -1
- {sae_lens-6.6.4 → sae_lens-6.7.0}/pyproject.toml +1 -1
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/__init__.py +1 -1
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/evals.py +14 -10
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/loading/pretrained_sae_loaders.py +68 -3
- {sae_lens-6.6.4 → sae_lens-6.7.0}/LICENSE +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/README.md +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/config.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.6.4 → sae_lens-6.7.0}/sae_lens/util.py +0 -0
|
@@ -718,17 +718,9 @@ def get_recons_loss(
|
|
|
718
718
|
**model_kwargs,
|
|
719
719
|
)
|
|
720
720
|
|
|
721
|
-
def kl(original_logits: torch.Tensor, new_logits: torch.Tensor):
|
|
722
|
-
original_probs = torch.nn.functional.softmax(original_logits, dim=-1)
|
|
723
|
-
log_original_probs = torch.log(original_probs)
|
|
724
|
-
new_probs = torch.nn.functional.softmax(new_logits, dim=-1)
|
|
725
|
-
log_new_probs = torch.log(new_probs)
|
|
726
|
-
kl_div = original_probs * (log_original_probs - log_new_probs)
|
|
727
|
-
return kl_div.sum(dim=-1)
|
|
728
|
-
|
|
729
721
|
if compute_kl:
|
|
730
|
-
recons_kl_div =
|
|
731
|
-
zero_abl_kl_div =
|
|
722
|
+
recons_kl_div = _kl(original_logits, recons_logits)
|
|
723
|
+
zero_abl_kl_div = _kl(original_logits, zero_abl_logits)
|
|
732
724
|
metrics["kl_div_with_sae"] = recons_kl_div
|
|
733
725
|
metrics["kl_div_with_ablation"] = zero_abl_kl_div
|
|
734
726
|
|
|
@@ -740,6 +732,18 @@ def get_recons_loss(
|
|
|
740
732
|
return metrics
|
|
741
733
|
|
|
742
734
|
|
|
735
|
+
def _kl(original_logits: torch.Tensor, new_logits: torch.Tensor):
|
|
736
|
+
# Computes the log-probabilities of the new logits (approximation).
|
|
737
|
+
log_probs_new = torch.nn.functional.log_softmax(new_logits, dim=-1)
|
|
738
|
+
# Computes the probabilities of the original logits (true distribution).
|
|
739
|
+
probs_orig = torch.nn.functional.softmax(original_logits, dim=-1)
|
|
740
|
+
# Compute the KL divergence. torch.nn.functional.kl_div expects the first argument to be the log
|
|
741
|
+
# probabilities of the approximation (new), and the second argument to be the true distribution
|
|
742
|
+
# (original) as probabilities. This computes KL(original || new).
|
|
743
|
+
kl = torch.nn.functional.kl_div(log_probs_new, probs_orig, reduction="none")
|
|
744
|
+
return kl.sum(dim=-1)
|
|
745
|
+
|
|
746
|
+
|
|
743
747
|
def all_loadable_saes() -> list[tuple[str, str, float, float]]:
|
|
744
748
|
all_loadable_saes = []
|
|
745
749
|
saes_directory = get_pretrained_saes_directory()
|
|
@@ -1001,10 +1001,14 @@ def get_sparsify_config_from_disk(
|
|
|
1001
1001
|
layer = int(match.group(1))
|
|
1002
1002
|
hook_name = f"blocks.{layer}.hook_resid_post"
|
|
1003
1003
|
|
|
1004
|
+
d_sae = old_cfg_dict.get("num_latents")
|
|
1005
|
+
if d_sae is None:
|
|
1006
|
+
d_sae = old_cfg_dict["d_in"] * old_cfg_dict["expansion_factor"]
|
|
1007
|
+
|
|
1004
1008
|
cfg_dict: dict[str, Any] = {
|
|
1005
1009
|
"architecture": "standard",
|
|
1006
1010
|
"d_in": old_cfg_dict["d_in"],
|
|
1007
|
-
"d_sae":
|
|
1011
|
+
"d_sae": d_sae,
|
|
1008
1012
|
"dtype": "bfloat16",
|
|
1009
1013
|
"device": device or "cpu",
|
|
1010
1014
|
"model_name": config_dict.get("model", path.parts[-2]),
|
|
@@ -1248,11 +1252,11 @@ def get_mwhanna_transcoder_config_from_hf(
|
|
|
1248
1252
|
try:
|
|
1249
1253
|
# mwhanna transcoders sometimes have a typo in the config file name, so check for both
|
|
1250
1254
|
wandb_config_path = hf_hub_download(
|
|
1251
|
-
repo_id, "
|
|
1255
|
+
repo_id, "wandb-config.yaml", force_download=force_download
|
|
1252
1256
|
)
|
|
1253
1257
|
except EntryNotFoundError:
|
|
1254
1258
|
wandb_config_path = hf_hub_download(
|
|
1255
|
-
repo_id, "
|
|
1259
|
+
repo_id, "wanb-config.yaml", force_download=force_download
|
|
1256
1260
|
)
|
|
1257
1261
|
try:
|
|
1258
1262
|
base_config_path = hf_hub_download(
|
|
@@ -1326,6 +1330,66 @@ def mwhanna_transcoder_huggingface_loader(
|
|
|
1326
1330
|
return cfg_dict, state_dict, None
|
|
1327
1331
|
|
|
1328
1332
|
|
|
1333
|
+
def mntss_clt_layer_huggingface_loader(
    repo_id: str,
    folder_name: str,
    device: str = "cpu",
    force_download: bool = False,
    cfg_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
    """
    Load a MNTSS CLT layer as a single layer transcoder.

    The assumption is that `folder_name` is the layer to load as an int (it is
    interpolated verbatim into the weight file names, the tensor keys and the
    `blocks.{folder_name}.*` hook names).

    Args:
        repo_id: Hugging Face repo holding `config.yaml` plus the per-layer
            `W_enc_{folder_name}.safetensors` / `W_dec_{folder_name}.safetensors`.
        folder_name: Layer index to load.
        device: Device the safetensors are loaded onto; also recorded in the
            returned config.
        force_download: Forwarded to every `hf_hub_download` call.
        cfg_overrides: Entries that replace the generated config values.

    Returns:
        `(cfg_dict, state_dict, None)`; this loader never returns a
        log-sparsity tensor.
    """
    base_config_path = hf_hub_download(
        repo_id, "config.yaml", force_download=force_download
    )
    with open(base_config_path) as f:
        cfg_info: dict[str, Any] = yaml.safe_load(f)

    # We need to actually load the weights, since the config is missing most
    # information (the widths below are read from tensor shapes).
    encoder_path = hf_hub_download(
        repo_id,
        f"W_enc_{folder_name}.safetensors",
        force_download=force_download,
    )
    decoder_path = hf_hub_download(
        repo_id,
        f"W_dec_{folder_name}.safetensors",
        force_download=force_download,
    )

    encoder_state_dict = load_file(encoder_path, device=device)
    decoder_state_dict = load_file(decoder_path, device=device)

    with torch.no_grad():
        state_dict = {
            # NOTE(review): transposed on the assumption that the checkpoint
            # stores the encoder as (d_sae, d_in) — confirm against the repo.
            "W_enc": encoder_state_dict[f"W_enc_{folder_name}"].T,  # type: ignore
            "b_enc": encoder_state_dict[f"b_enc_{folder_name}"],  # type: ignore
            # b_dec lives in the *encoder* file, not the decoder file.
            "b_dec": encoder_state_dict[f"b_dec_{folder_name}"],  # type: ignore
            # NOTE(review): summing over dim 1 collapses the CLT decoder into a
            # single decoder; presumably dim 1 indexes the downstream layers
            # this layer writes to — confirm against the checkpoint layout.
            "W_dec": decoder_state_dict[f"W_dec_{folder_name}"].sum(dim=1),  # type: ignore
        }

    cfg_dict = {
        "architecture": "transcoder",
        # Input and output widths are both taken from b_dec's length.
        "d_in": state_dict["b_dec"].shape[0],
        "d_out": state_dict["b_dec"].shape[0],
        "d_sae": state_dict["b_enc"].shape[0],
        "dtype": "float32",
        "device": device if device is not None else "cpu",
        "activation_fn": "relu",
        "normalize_activations": "none",
        "model_name": cfg_info["model_name"],
        "hook_name": f"blocks.{folder_name}.{cfg_info['feature_input_hook']}",
        "hook_name_out": f"blocks.{folder_name}.{cfg_info['feature_output_hook']}",
        "apply_b_dec_to_input": False,
        "model_from_pretrained_kwargs": {"fold_ln": False},
        # Spread last so caller-supplied overrides win over the defaults above.
        **(cfg_overrides or {}),
    }

    return cfg_dict, state_dict, None
|
|
1391
|
+
|
|
1392
|
+
|
|
1329
1393
|
NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
|
|
1330
1394
|
"sae_lens": sae_lens_huggingface_loader,
|
|
1331
1395
|
"connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
|
|
@@ -1337,6 +1401,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
|
|
|
1337
1401
|
"sparsify": sparsify_huggingface_loader,
|
|
1338
1402
|
"gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
|
|
1339
1403
|
"mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
|
|
1404
|
+
"mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
|
|
1340
1405
|
}
|
|
1341
1406
|
|
|
1342
1407
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|