sae-lens 6.6.5__tar.gz → 6.7.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.6.5 → sae_lens-6.7.0}/PKG-INFO +1 -1
- {sae_lens-6.6.5 → sae_lens-6.7.0}/pyproject.toml +1 -1
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/__init__.py +1 -1
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/loading/pretrained_sae_loaders.py +63 -2
- {sae_lens-6.6.5 → sae_lens-6.7.0}/LICENSE +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/README.md +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/config.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.6.5 → sae_lens-6.7.0}/sae_lens/util.py +0 -0
|
@@ -1252,11 +1252,11 @@ def get_mwhanna_transcoder_config_from_hf(
|
|
|
1252
1252
|
try:
|
|
1253
1253
|
# mwhanna transcoders sometimes have a typo in the config file name, so check for both
|
|
1254
1254
|
wandb_config_path = hf_hub_download(
|
|
1255
|
-
repo_id, "
|
|
1255
|
+
repo_id, "wandb-config.yaml", force_download=force_download
|
|
1256
1256
|
)
|
|
1257
1257
|
except EntryNotFoundError:
|
|
1258
1258
|
wandb_config_path = hf_hub_download(
|
|
1259
|
-
repo_id, "
|
|
1259
|
+
repo_id, "wanb-config.yaml", force_download=force_download
|
|
1260
1260
|
)
|
|
1261
1261
|
try:
|
|
1262
1262
|
base_config_path = hf_hub_download(
|
|
@@ -1330,6 +1330,66 @@ def mwhanna_transcoder_huggingface_loader(
|
|
|
1330
1330
|
return cfg_dict, state_dict, None
|
|
1331
1331
|
|
|
1332
1332
|
|
|
1333
|
+
def mntss_clt_layer_huggingface_loader(
    repo_id: str,
    folder_name: str,
    device: str = "cpu",
    force_download: bool = False,
    cfg_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
    """
    Load one layer of a MNTSS cross-layer transcoder (CLT) as a single-layer
    transcoder.

    The repository is expected to contain a ``config.yaml`` plus per-layer
    ``W_enc_{layer}.safetensors`` and ``W_dec_{layer}.safetensors`` files.

    Args:
        repo_id: Hugging Face Hub repository id.
        folder_name: The layer to load, as an int-like string (e.g. ``"5"``).
            It is interpolated into the weight file names, the tensor keys
            and the hook names.
        device: Device the safetensors files are loaded onto; also recorded
            in the returned config.
        force_download: Forwarded to every ``hf_hub_download`` call.
        cfg_overrides: Optional entries merged last into the returned config,
            so they take precedence over every computed value.

    Returns:
        ``(cfg_dict, state_dict, None)``; the third element is always ``None``.
    """
    base_config_path = hf_hub_download(
        repo_id, "config.yaml", force_download=force_download
    )
    with open(base_config_path, encoding="utf-8") as f:
        cfg_info: dict[str, Any] = yaml.safe_load(f)

    # The config is missing most information (e.g. sizes), so the weights
    # themselves must be downloaded and inspected.
    encoder_path = hf_hub_download(
        repo_id,
        f"W_enc_{folder_name}.safetensors",
        force_download=force_download,
    )
    decoder_path = hf_hub_download(
        repo_id,
        f"W_dec_{folder_name}.safetensors",
        force_download=force_download,
    )

    encoder_state_dict = load_file(encoder_path, device=device)
    decoder_state_dict = load_file(decoder_path, device=device)

    with torch.no_grad():
        state_dict = {
            # Transposed into the (d_in, d_sae) W_enc layout; d_in below is read
            # from the first dimension. NOTE(review): assumes the file stores
            # W_enc as (d_sae, d_in) -- confirm against the checkpoint format.
            "W_enc": encoder_state_dict[f"W_enc_{folder_name}"].T,  # type: ignore
            "b_enc": encoder_state_dict[f"b_enc_{folder_name}"],  # type: ignore
            "b_dec": encoder_state_dict[f"b_dec_{folder_name}"],  # type: ignore
            # NOTE(review): dim 1 presumably indexes the downstream layers this
            # CLT layer writes to; summing collapses them into one decoder.
            # Confirm this is the intended single-layer approximation.
            "W_dec": decoder_state_dict[f"W_dec_{folder_name}"].sum(dim=1),  # type: ignore
        }

    cfg_dict = {
        "architecture": "transcoder",
        # Read the input/output widths from the matrices that define them,
        # rather than assuming d_in == d_out via the output bias.
        "d_in": state_dict["W_enc"].shape[0],
        "d_out": state_dict["W_dec"].shape[-1],
        "d_sae": state_dict["b_enc"].shape[0],
        "dtype": "float32",
        "device": device if device is not None else "cpu",
        "activation_fn": "relu",
        "normalize_activations": "none",
        "model_name": cfg_info["model_name"],
        "hook_name": f"blocks.{folder_name}.{cfg_info['feature_input_hook']}",
        "hook_name_out": f"blocks.{folder_name}.{cfg_info['feature_output_hook']}",
        "apply_b_dec_to_input": False,
        "model_from_pretrained_kwargs": {"fold_ln": False},
        **(cfg_overrides or {}),
    }

    return cfg_dict, state_dict, None
|
|
1391
|
+
|
|
1392
|
+
|
|
1333
1393
|
NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
|
|
1334
1394
|
"sae_lens": sae_lens_huggingface_loader,
|
|
1335
1395
|
"connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
|
|
@@ -1341,6 +1401,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
|
|
|
1341
1401
|
"sparsify": sparsify_huggingface_loader,
|
|
1342
1402
|
"gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
|
|
1343
1403
|
"mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
|
|
1404
|
+
"mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
|
|
1344
1405
|
}
|
|
1345
1406
|
|
|
1346
1407
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|