sae-lens 6.8.0__tar.gz → 6.9.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.8.0 → sae_lens-6.9.1}/PKG-INFO +1 -1
- {sae_lens-6.8.0 → sae_lens-6.9.1}/pyproject.toml +1 -1
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/__init__.py +1 -1
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/loading/pretrained_sae_loaders.py +91 -11
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/pretrained_saes.yaml +85 -1
- {sae_lens-6.8.0 → sae_lens-6.9.1}/LICENSE +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/README.md +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/config.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/constants.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/evals.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/load_model.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/registry.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/training/types.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/util.py +0 -0
{sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/loading/pretrained_sae_loaders.py

```diff
@@ -1,12 +1,14 @@
 import json
 import re
+import warnings
 from pathlib import Path
 from typing import Any, Protocol
 
 import numpy as np
+import requests
 import torch
 import yaml
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError
 from packaging.version import Version
 from safetensors import safe_open
@@ -1330,6 +1332,48 @@ def mwhanna_transcoder_huggingface_loader(
     return cfg_dict, state_dict, None
 
 
+def get_safetensors_tensor_shapes(url: str) -> dict[str, list[int]]:
+    """
+    Get tensor shapes from a safetensors file using HTTP range requests
+    without downloading the entire file.
+
+    Args:
+        url: Direct URL to the safetensors file
+
+    Returns:
+        Dictionary mapping tensor names to their shapes
+    """
+    # Check if server supports range requests
+    response = requests.head(url, timeout=10)
+    response.raise_for_status()
+
+    accept_ranges = response.headers.get("Accept-Ranges", "")
+    if "bytes" not in accept_ranges:
+        raise ValueError("Server does not support range requests")
+
+    # Fetch first 8 bytes to get metadata size
+    headers = {"Range": "bytes=0-7"}
+    response = requests.get(url, headers=headers, timeout=10)
+    if response.status_code != 206:
+        raise ValueError("Failed to fetch initial bytes for metadata size")
+
+    meta_size = int.from_bytes(response.content, byteorder="little")
+
+    # Fetch the metadata header
+    headers = {"Range": f"bytes=8-{8 + meta_size - 1}"}
+    response = requests.get(url, headers=headers, timeout=10)
+    if response.status_code != 206:
+        raise ValueError("Failed to fetch metadata header")
+
+    metadata_json = response.content.decode("utf-8").strip()
+    metadata = json.loads(metadata_json)
+
+    # Extract tensor shapes, excluding the __metadata__ key
+    return {
+        name: info["shape"] for name, info in metadata.items() if name != "__metadata__"
+    }
+
+
 def mntss_clt_layer_huggingface_loader(
     repo_id: str,
     folder_name: str,
```
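For context: a safetensors file begins with an unsigned 64-bit little-endian integer giving the byte length of a JSON header, followed by that header, which maps each tensor name to its dtype, shape, and data offsets. The helper above therefore needs only two small range requests to recover every shape. A minimal usage sketch (the repo and filename are illustrative, mirroring the CLT loader below; printed values are hypothetical):

```python
from huggingface_hub import hf_hub_url

from sae_lens.loading.pretrained_sae_loaders import get_safetensors_tensor_shapes

# Illustrative: the per-layer encoder file that the CLT loader below inspects.
url = hf_hub_url("mntss/clt-gemma-2-2b-2.5M", "W_enc_0.safetensors")
for name, shape in get_safetensors_tensor_shapes(url).items():
    print(name, shape)  # e.g. "b_dec_0 [2304]" (hypothetical values)
```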
```diff
@@ -1341,11 +1385,20 @@ def mntss_clt_layer_huggingface_loader(
     Load a MNTSS CLT layer as a single layer transcoder.
     The assumption is that the `folder_name` is the layer to load as an int
     """
-    base_config_path = hf_hub_download(
-        repo_id, "config.yaml", force_download=force_download
+
+    # warn that this sums the decoders together, so should only be used to find feature activations, not for reconstruction
+    warnings.warn(
+        "This loads the CLT layer as a single layer transcoder by summing all decoders together. This should only be used to find feature activations, not for reconstruction",
+        UserWarning,
+    )
+
+    cfg_dict = get_mntss_clt_layer_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
     )
-    with open(base_config_path) as f:
-        cfg_info: dict[str, Any] = yaml.safe_load(f)
 
     # We need to actually load the weights, since the config is missing most information
     encoder_path = hf_hub_download(
```
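The warning matters only on the decoder side: feature activations depend solely on `W_enc` and `b_enc`, while summing the per-target decoders distorts reconstruction. A minimal sketch of the distinction, with illustrative shapes (the `sum(dim=1)` mirrors the loader's handling of `W_dec` below):

```python
import torch

d_in, d_sae, n_targets = 4, 8, 3  # illustrative sizes
W_enc = torch.randn(d_in, d_sae)
b_enc = torch.zeros(d_sae)
W_dec = torch.randn(d_sae, n_targets, d_in)  # one decoder per downstream layer

x = torch.randn(d_in)
acts = torch.relu(x @ W_enc + b_enc)  # unchanged by how W_dec is collapsed

# Summing per-target decoders merges distinct output streams, so this
# "reconstruction" does not match what the CLT writes to any single layer.
recon = acts @ W_dec.sum(dim=1)
```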
```diff
@@ -1370,11 +1423,39 @@ def mntss_clt_layer_huggingface_loader(
         "W_dec": decoder_state_dict[f"W_dec_{folder_name}"].sum(dim=1),  # type: ignore
     }
 
-    cfg_dict = {
+    return cfg_dict, state_dict, None
+
+
+def get_mntss_clt_layer_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,  # noqa: ARG001
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """
+    Load a MNTSS CLT layer as a single layer transcoder.
+    The assumption is that the `folder_name` is the layer to load as an int
+    """
+    base_config_path = hf_hub_download(
+        repo_id, "config.yaml", force_download=force_download
+    )
+    with open(base_config_path) as f:
+        cfg_info: dict[str, Any] = yaml.safe_load(f)
+
+    # Get tensor shapes without downloading full files using HTTP range requests
+    encoder_url = hf_hub_url(repo_id, f"W_enc_{folder_name}.safetensors")
+    encoder_shapes = get_safetensors_tensor_shapes(encoder_url)
+
+    # Extract shapes for the required tensors
+    b_dec_shape = encoder_shapes[f"b_dec_{folder_name}"]
+    b_enc_shape = encoder_shapes[f"b_enc_{folder_name}"]
+
+    return {
         "architecture": "transcoder",
-        "d_in":
-        "d_out":
-        "d_sae":
+        "d_in": b_dec_shape[0],
+        "d_out": b_dec_shape[0],
+        "d_sae": b_enc_shape[0],
         "dtype": "float32",
         "device": device if device is not None else "cpu",
         "activation_fn": "relu",
```
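A practical consequence of this split: a layer's dimensions can now be resolved from `config.yaml` plus a few header bytes, without downloading any weight tensors. A hedged sketch of calling the new getter directly (argument values are illustrative):

```python
from sae_lens.loading.pretrained_sae_loaders import get_mntss_clt_layer_config_from_hf

# Resolve a CLT layer's config without fetching its weights.
cfg = get_mntss_clt_layer_config_from_hf(
    repo_id="mntss/clt-gemma-2-2b-2.5M",
    folder_name="0",  # the layer index, as a string
    device="cpu",
)
print(cfg["architecture"], cfg["d_in"], cfg["d_sae"])
```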
```diff
@@ -1387,8 +1468,6 @@ def mntss_clt_layer_huggingface_loader(
         **(cfg_overrides or {}),
     }
 
-    return cfg_dict, state_dict, None
-
 
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
@@ -1416,4 +1495,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoader]
     "sparsify": get_sparsify_config_from_hf,
     "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
     "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
+    "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
 }
```
{sae_lens-6.8.0 → sae_lens-6.9.1}/sae_lens/pretrained_saes.yaml

```diff
@@ -14744,4 +14744,88 @@ mwhanna-qwen3-0.6b-transcoders-lowl0:
     neuronpedia: qwen3-0.6b/26-transcoder-hp-lowl0
   - id: layer_27
     path: layer_27.safetensors
-    neuronpedia: qwen3-0.6b/27-transcoder-hp-lowl0
+    neuronpedia: qwen3-0.6b/27-transcoder-hp-lowl0
+
+mntss-gemma-2-2b-2.5m-clt-as-per-layer:
+  conversion_func: mntss_clt_layer_transcoder
+  model: gemma-2-2b
+  repo_id: mntss/clt-gemma-2-2b-2.5M
+  saes:
+  - id: layer_0
+    path: 0
+    neuronpedia: gemma-2-2b/0-clt-hp
+  - id: layer_1
+    path: 1
+    neuronpedia: gemma-2-2b/1-clt-hp
+  - id: layer_2
+    path: 2
+    neuronpedia: gemma-2-2b/2-clt-hp
+  - id: layer_3
+    path: 3
+    neuronpedia: gemma-2-2b/3-clt-hp
+  - id: layer_4
+    path: 4
+    neuronpedia: gemma-2-2b/4-clt-hp
+  - id: layer_5
+    path: 5
+    neuronpedia: gemma-2-2b/5-clt-hp
+  - id: layer_6
+    path: 6
+    neuronpedia: gemma-2-2b/6-clt-hp
+  - id: layer_7
+    path: 7
+    neuronpedia: gemma-2-2b/7-clt-hp
+  - id: layer_8
+    path: 8
+    neuronpedia: gemma-2-2b/8-clt-hp
+  - id: layer_9
+    path: 9
+    neuronpedia: gemma-2-2b/9-clt-hp
+  - id: layer_10
+    path: 10
+    neuronpedia: gemma-2-2b/10-clt-hp
+  - id: layer_11
+    path: 11
+    neuronpedia: gemma-2-2b/11-clt-hp
+  - id: layer_12
+    path: 12
+    neuronpedia: gemma-2-2b/12-clt-hp
+  - id: layer_13
+    path: 13
+    neuronpedia: gemma-2-2b/13-clt-hp
+  - id: layer_14
+    path: 14
+    neuronpedia: gemma-2-2b/14-clt-hp
+  - id: layer_15
+    path: 15
+    neuronpedia: gemma-2-2b/15-clt-hp
+  - id: layer_16
+    path: 16
+    neuronpedia: gemma-2-2b/16-clt-hp
+  - id: layer_17
+    path: 17
+    neuronpedia: gemma-2-2b/17-clt-hp
+  - id: layer_18
+    path: 18
+    neuronpedia: gemma-2-2b/18-clt-hp
+  - id: layer_19
+    path: 19
+    neuronpedia: gemma-2-2b/19-clt-hp
+  - id: layer_20
+    path: 20
+    neuronpedia: gemma-2-2b/20-clt-hp
+  - id: layer_21
+    path: 21
+    neuronpedia: gemma-2-2b/21-clt-hp
+  - id: layer_22
+    path: 22
+    neuronpedia: gemma-2-2b/22-clt-hp
+  - id: layer_23
+    path: 23
+    neuronpedia: gemma-2-2b/23-clt-hp
+  - id: layer_24
+    path: 24
+    neuronpedia: gemma-2-2b/24-clt-hp
+  - id: layer_25
+    path: 25
+    neuronpedia: gemma-2-2b/25-clt-hp
```
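With this registry entry, each CLT layer is addressable like any other pretrained SAE. A minimal usage sketch, assuming `SAE.from_pretrained` returns the SAE object directly as in recent 6.x releases (loading will emit the summed-decoder `UserWarning` shown earlier):

```python
from sae_lens import SAE

# Load layer 0 of the CLT as a single-layer transcoder (feature analysis only).
sae = SAE.from_pretrained(
    release="mntss-gemma-2-2b-2.5m-clt-as-per-layer",
    sae_id="layer_0",
    device="cpu",
)
print(sae.cfg.d_in, sae.cfg.d_sae)
```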
|