sae-lens 6.18.0__py3-none-any.whl → 6.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +6 -1
- sae_lens/loading/pretrained_sae_loaders.py +188 -0
- sae_lens/loading/pretrained_saes_directory.py +5 -3
- sae_lens/pretrained_saes.yaml +51 -1
- sae_lens/saes/__init__.py +3 -0
- sae_lens/saes/sae.py +4 -12
- sae_lens/saes/temporal_sae.py +372 -0
- sae_lens/training/activations_store.py +1 -1
- {sae_lens-6.18.0.dist-info → sae_lens-6.20.1.dist-info}/METADATA +16 -16
- {sae_lens-6.18.0.dist-info → sae_lens-6.20.1.dist-info}/RECORD +12 -11
- {sae_lens-6.18.0.dist-info → sae_lens-6.20.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.18.0.dist-info → sae_lens-6.20.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.18.0"
+__version__ = "6.20.1"
 
 import logging
 
@@ -28,6 +28,8 @@ from sae_lens.saes import (
     StandardSAEConfig,
     StandardTrainingSAE,
     StandardTrainingSAEConfig,
+    TemporalSAE,
+    TemporalSAEConfig,
     TopKSAE,
     TopKSAEConfig,
     TopKTrainingSAE,
@@ -105,6 +107,8 @@ __all__ = [
     "JumpReLUTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
+    "TemporalSAE",
+    "TemporalSAEConfig",
 ]
 
 
@@ -127,3 +131,4 @@ register_sae_training_class(
 register_sae_class("transcoder", Transcoder, TranscoderConfig)
 register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
 register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
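With the `__all__` entries and the `register_sae_class("temporal", ...)` call above, the temporal classes become part of the public package surface. A minimal sketch of what this release newly allows (import names taken verbatim from the diff):

import sae_lens

# Top-level exports added in 6.20.1; configs whose "architecture" field is
# "temporal" now resolve to these classes via the registry call above.
from sae_lens import TemporalSAE, TemporalSAEConfig

print(sae_lens.__version__)  # "6.20.1"
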
sae_lens/loading/pretrained_sae_loaders.py
CHANGED

@@ -523,6 +523,82 @@ def gemma_2_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity
 
 
+def get_goodfire_config_from_hf(
+    repo_id: str,
+    folder_name: str,  # noqa: ARG001
+    device: str,
+    force_download: bool = False,  # noqa: ARG001
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    cfg_dict = None
+    if repo_id == "Goodfire/Llama-3.3-70B-Instruct-SAE-l50":
+        if folder_name != "Llama-3.3-70B-Instruct-SAE-l50.pt":
+            raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+        cfg_dict = {
+            "architecture": "standard",
+            "d_in": 8192,
+            "d_sae": 65536,
+            "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+            "hook_name": "blocks.50.hook_resid_post",
+            "hook_head_index": None,
+            "dataset_path": "lmsys/lmsys-chat-1m",
+            "apply_b_dec_to_input": False,
+        }
+    elif repo_id == "Goodfire/Llama-3.1-8B-Instruct-SAE-l19":
+        if folder_name != "Llama-3.1-8B-Instruct-SAE-l19.pth":
+            raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+        cfg_dict = {
+            "architecture": "standard",
+            "d_in": 4096,
+            "d_sae": 65536,
+            "model_name": "meta-llama/Llama-3.1-8B-Instruct",
+            "hook_name": "blocks.19.hook_resid_post",
+            "hook_head_index": None,
+            "dataset_path": "lmsys/lmsys-chat-1m",
+            "apply_b_dec_to_input": False,
+        }
+    if cfg_dict is None:
+        raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+    if device is not None:
+        cfg_dict["device"] = device
+    if cfg_overrides is not None:
+        cfg_dict.update(cfg_overrides)
+    return cfg_dict
+
+
+def get_goodfire_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    cfg_dict = get_goodfire_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
+    )
+
+    # Download the SAE weights
+    sae_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=folder_name,
+        force_download=force_download,
+    )
+    raw_state_dict = torch.load(sae_path, map_location=device)
+
+    state_dict = {
+        "W_enc": raw_state_dict["encoder_linear.weight"].T,
+        "W_dec": raw_state_dict["decoder_linear.weight"].T,
+        "b_enc": raw_state_dict["encoder_linear.bias"],
+        "b_dec": raw_state_dict["decoder_linear.bias"],
+    }
+
+    return cfg_dict, state_dict, None
+
+
 def get_llama_scope_config_from_hf(
     repo_id: str,
     folder_name: str,
@@ -1475,6 +1551,114 @@ def get_mntss_clt_layer_config_from_hf(
     }
 
 
+def get_temporal_sae_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Get TemporalSAE config without loading weights."""
+    # Download config file
+    conf_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=f"{folder_name}/conf.yaml",
+        force_download=force_download,
+    )
+
+    # Load and parse config
+    with open(conf_path) as f:
+        yaml_config = yaml.safe_load(f)
+
+    # Extract parameters
+    d_in = yaml_config["llm"]["dimin"]
+    exp_factor = yaml_config["sae"]["exp_factor"]
+    d_sae = int(d_in * exp_factor)
+
+    # extract layer from folder_name eg : "layer_12/temporal"
+    layer = re.search(r"layer_(\d+)", folder_name)
+    if layer is None:
+        raise ValueError(f"Could not find layer in folder_name: {folder_name}")
+    layer = int(layer.group(1))
+
+    # Build config dict
+    cfg_dict = {
+        "architecture": "temporal",
+        "hook_name": f"blocks.{layer}.hook_resid_post",
+        "d_in": d_in,
+        "d_sae": d_sae,
+        "n_heads": yaml_config["sae"]["n_heads"],
+        "n_attn_layers": yaml_config["sae"]["n_attn_layers"],
+        "bottleneck_factor": yaml_config["sae"]["bottleneck_factor"],
+        "sae_diff_type": yaml_config["sae"]["sae_diff_type"],
+        "kval_topk": yaml_config["sae"]["kval_topk"],
+        "tied_weights": yaml_config["sae"]["tied_weights"],
+        "dtype": yaml_config["data"]["dtype"],
+        "device": device,
+        "normalize_activations": "constant_scalar_rescale",
+        "activation_normalization_factor": yaml_config["sae"]["scaling_factor"],
+        "apply_b_dec_to_input": True,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def temporal_sae_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """
+    Load TemporalSAE from canrager/temporalSAEs format (safetensors version).
+
+    Expects folder_name to contain:
+    - conf.yaml (configuration)
+    - latest_ckpt.safetensors (model weights)
+    """
+
+    cfg_dict = get_temporal_sae_config_from_hf(
+        repo_id=repo_id,
+        folder_name=folder_name,
+        device=device,
+        force_download=force_download,
+        cfg_overrides=cfg_overrides,
+    )
+
+    # Download checkpoint (safetensors format)
+    ckpt_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=f"{folder_name}/latest_ckpt.safetensors",
+        force_download=force_download,
+    )
+
+    # Load checkpoint from safetensors
+    state_dict_raw = load_file(ckpt_path, device=device)
+
+    # Convert to SAELens naming convention
+    # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
+    state_dict = {}
+
+    # Copy attention layers as-is
+    for key, value in state_dict_raw.items():
+        if key.startswith("attn_layers."):
+            state_dict[key] = value.to(device)
+
+    # Main parameters
+    state_dict["W_dec"] = state_dict_raw["D"].to(device)
+    state_dict["b_dec"] = state_dict_raw["b"].to(device)
+
+    # Handle tied/untied weights
+    if "E" in state_dict_raw:
+        state_dict["W_enc"] = state_dict_raw["E"].to(device)
+
+    return cfg_dict, state_dict, None
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1487,6 +1671,8 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
     "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
     "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
+    "temporal": temporal_sae_huggingface_loader,
+    "goodfire": get_goodfire_huggingface_loader,
 }
 
 
@@ -1502,4 +1688,6 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
     "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
     "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
+    "temporal": get_temporal_sae_config_from_hf,
+    "goodfire": get_goodfire_config_from_hf,
 }
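For illustration, the new Goodfire config getter can be exercised without touching the network; this sketch only re-states values hard-coded in the function above:

from sae_lens.loading.pretrained_sae_loaders import get_goodfire_config_from_hf

# Sketch: config lookup for the supported 8B checkpoint (no download happens
# here; weights are only fetched by get_goodfire_huggingface_loader).
cfg = get_goodfire_config_from_hf(
    repo_id="Goodfire/Llama-3.1-8B-Instruct-SAE-l19",
    folder_name="Llama-3.1-8B-Instruct-SAE-l19.pth",
    device="cpu",
)
assert cfg["architecture"] == "standard"
assert cfg["d_in"] == 4096 and cfg["d_sae"] == 65536
assert cfg["hook_name"] == "blocks.19.hook_resid_post"
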
sae_lens/loading/pretrained_saes_directory.py
CHANGED

@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from functools import cache
-from importlib import
+from importlib.resources import files
 from typing import Any
 
 import yaml
@@ -24,7 +24,8 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
     package = "sae_lens"
     # Access the file within the package using importlib.resources
     directory: dict[str, PretrainedSAELookup] = {}
-
+    yaml_file = files(package).joinpath("pretrained_saes.yaml")
+    with yaml_file.open("r") as file:
         # Load the YAML file content
         data = yaml.safe_load(file)
         for release, value in data.items():
@@ -68,7 +69,8 @@ def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
         float | None: The norm_scaling_factor if it exists, None otherwise.
     """
     package = "sae_lens"
-
+    yaml_file = files(package).joinpath("pretrained_saes.yaml")
+    with yaml_file.open("r") as file:
         data = yaml.safe_load(file)
         if release in data:
             for sae_info in data[release]["saes"]:
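The import change above moves both functions onto the `importlib.resources.files` traversable API, the modern replacement for the older module-level resource helpers. A self-contained sketch of the pattern now used:

from importlib.resources import files

# Read a data file bundled inside an installed package, exactly as the two
# functions above now do for pretrained_saes.yaml.
yaml_file = files("sae_lens").joinpath("pretrained_saes.yaml")
with yaml_file.open("r") as f:
    first_line = f.readline()
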
sae_lens/pretrained_saes.yaml
CHANGED
@@ -1,3 +1,35 @@
+temporal-sae-gemma-2-2b:
+  conversion_func: temporal
+  model: gemma-2-2b
+  repo_id: canrager/temporalSAEs
+  config_overrides:
+    model_name: gemma-2-2b
+    hook_name: blocks.12.hook_resid_post
+    dataset_path: monology/pile-uncopyrighted
+  saes:
+    - id: blocks.12.hook_resid_post
+      l0: 192
+      norm_scaling_factor: 0.00666666667
+      path: gemma-2-2B/layer_12/temporal
+      neuronpedia: gemma-2-2b/12-temporal-res
+temporal-sae-llama-3.1-8b:
+  conversion_func: temporal
+  model: meta-llama/Llama-3.1-8B
+  repo_id: canrager/temporalSAEs
+  config_overrides:
+    model_name: meta-llama/Llama-3.1-8B
+    dataset_path: monology/pile-uncopyrighted
+  saes:
+    - id: blocks.15.hook_resid_post
+      l0: 256
+      norm_scaling_factor: 0.029
+      path: llama-3.1-8B/layer_15/temporal
+      neuronpedia: llama3.1-8b/15-temporal-res
+    - id: blocks.26.hook_resid_post
+      l0: 256
+      norm_scaling_factor: 0.029
+      path: llama-3.1-8B/layer_26/temporal
+      neuronpedia: llama3.1-8b/26-temporal-res
 deepseek-r1-distill-llama-8b-qresearch:
   conversion_func: deepseek_r1
   model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
@@ -14882,4 +14914,22 @@ qwen2.5-7b-instruct-andyrdt:
       neuronpedia: qwen2.5-7b-it/23-resid-post-aa
     - id: resid_post_layer_27_trainer_1
       path: resid_post_layer_27/trainer_1
-      neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+      neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+
+goodfire-llama-3.3-70b-instruct:
+  conversion_func: goodfire
+  model: meta-llama/Llama-3.3-70B-Instruct
+  repo_id: Goodfire/Llama-3.3-70B-Instruct-SAE-l50
+  saes:
+    - id: layer_50
+      path: Llama-3.3-70B-Instruct-SAE-l50.pt
+      l0: 121
+
+goodfire-llama-3.1-8b-instruct:
+  conversion_func: goodfire
+  model: meta-llama/Llama-3.1-8B-Instruct
+  repo_id: Goodfire/Llama-3.1-8B-Instruct-SAE-l19
+  saes:
+    - id: layer_19
+      path: Llama-3.1-8B-Instruct-SAE-l19.pth
+      l0: 91
sae_lens/saes/__init__.py
CHANGED
@@ -25,6 +25,7 @@ from .standard_sae import (
     StandardTrainingSAE,
     StandardTrainingSAEConfig,
 )
+from .temporal_sae import TemporalSAE, TemporalSAEConfig
 from .topk_sae import (
     TopKSAE,
     TopKSAEConfig,
@@ -71,4 +72,6 @@ __all__ = [
     "JumpReLUTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
+    "TemporalSAE",
+    "TemporalSAEConfig",
 ]
sae_lens/saes/sae.py
CHANGED
@@ -155,9 +155,9 @@ class SAEConfig(ABC):
     dtype: str = "float32"
     device: str = "cpu"
     apply_b_dec_to_input: bool = True
-    normalize_activations: Literal[
-        "none",
-
+    normalize_activations: Literal["none", "expected_average_only_in", "layer_norm"] = (
+        "none"  # none, expected_average_only_in (Anthropic April Update)
+    )
     reshape_activations: Literal["none", "hook_z"] = "none"
     metadata: SAEMetadata = field(default_factory=SAEMetadata)
 
@@ -309,6 +309,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
             self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
             self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+
         elif self.cfg.normalize_activations == "layer_norm":
             # we need to scale the norm of the input and store the scaling factor
             def run_time_activation_ln_in(
@@ -452,23 +453,14 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     def process_sae_in(
         self, sae_in: Float[torch.Tensor, "... d_in"]
     ) -> Float[torch.Tensor, "... d_in"]:
-        # print(f"Input shape to process_sae_in: {sae_in.shape}")
-        # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
-        # print(f"self.b_dec shape: {self.b_dec.shape}")
-        # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")
-
         sae_in = sae_in.to(self.dtype)
-
-        # print(f"Shape before reshape_fn_in: {sae_in.shape}")
         sae_in = self.reshape_fn_in(sae_in)
-        # print(f"Shape after reshape_fn_in: {sae_in.shape}")
 
         sae_in = self.hook_sae_input(sae_in)
         sae_in = self.run_time_activation_norm_fn_in(sae_in)
 
         # Here's where the error happens
         bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
-        # print(f"Bias term shape: {bias_term.shape}")
 
         return sae_in - bias_term
 
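The lines deleted from `process_sae_in` were commented-out debug prints, so behavior is unchanged. The surviving pipeline is: cast to the SAE dtype, reshape, input hook, runtime normalization, then subtract the decoder bias when configured. A condensed restatement (sketch only; the real method lives on the `SAE` base class):

# Hypothetical helper mirroring the cleaned-up process_sae_in above.
def process_sae_in_sketch(sae, sae_in):
    x = sae.reshape_fn_in(sae_in.to(sae.dtype))
    x = sae.run_time_activation_norm_fn_in(sae.hook_sae_input(x))
    return x - sae.b_dec * sae.cfg.apply_b_dec_to_input
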
sae_lens/saes/temporal_sae.py
ADDED

@@ -0,0 +1,372 @@
+"""TemporalSAE: A Sparse Autoencoder with temporal attention mechanism.
+
+TemporalSAE decomposes activations into:
+1. Predicted codes (from attention over context)
+2. Novel codes (sparse features of the residual)
+
+See: https://arxiv.org/abs/2410.04185
+"""
+
+import math
+from dataclasses import dataclass
+from typing import Literal
+
+import torch
+import torch.nn.functional as F
+from jaxtyping import Float
+from torch import nn
+from typing_extensions import override
+
+from sae_lens import logger
+from sae_lens.saes.sae import SAE, SAEConfig
+
+
+def get_attention(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
+    """Compute causal attention weights."""
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1))
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+    temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+    attn_bias.to(query.dtype)
+
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    return torch.softmax(attn_weight, dim=-1)
+
+
+class ManualAttention(nn.Module):
+    """Manual attention implementation for TemporalSAE."""
+
+    def __init__(
+        self,
+        dimin: int,
+        n_heads: int = 4,
+        bottleneck_factor: int = 64,
+        bias_k: bool = True,
+        bias_q: bool = True,
+        bias_v: bool = True,
+        bias_o: bool = True,
+    ):
+        super().__init__()
+        assert dimin % (bottleneck_factor * n_heads) == 0
+
+        self.n_heads = n_heads
+        self.n_embds = dimin // bottleneck_factor
+        self.dimin = dimin
+
+        # Key, query, value projections
+        self.k_ctx = nn.Linear(dimin, self.n_embds, bias=bias_k)
+        self.q_target = nn.Linear(dimin, self.n_embds, bias=bias_q)
+        self.v_ctx = nn.Linear(dimin, dimin, bias=bias_v)
+        self.c_proj = nn.Linear(dimin, dimin, bias=bias_o)
+
+        # Normalize to match scale with representations
+        with torch.no_grad():
+            scaling = 1 / math.sqrt(self.n_embds // self.n_heads)
+            self.k_ctx.weight.copy_(
+                scaling
+                * self.k_ctx.weight
+                / (1e-6 + torch.linalg.norm(self.k_ctx.weight, dim=1, keepdim=True))
+            )
+            self.q_target.weight.copy_(
+                scaling
+                * self.q_target.weight
+                / (1e-6 + torch.linalg.norm(self.q_target.weight, dim=1, keepdim=True))
+            )
+
+            scaling = 1 / math.sqrt(self.dimin // self.n_heads)
+            self.v_ctx.weight.copy_(
+                scaling
+                * self.v_ctx.weight
+                / (1e-6 + torch.linalg.norm(self.v_ctx.weight, dim=1, keepdim=True))
+            )
+
+            scaling = 1 / math.sqrt(self.dimin)
+            self.c_proj.weight.copy_(
+                scaling
+                * self.c_proj.weight
+                / (1e-6 + torch.linalg.norm(self.c_proj.weight, dim=1, keepdim=True))
+            )
+
+    def forward(
+        self, x_ctx: torch.Tensor, x_target: torch.Tensor, get_attn_map: bool = False
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        """Compute projective attention output."""
+        k = self.k_ctx(x_ctx)
+        v = self.v_ctx(x_ctx)
+        q = self.q_target(x_target)
+
+        # Split into heads
+        B, T, _ = x_ctx.size()
+        k = k.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
+        q = q.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
+        v = v.view(B, T, self.n_heads, self.dimin // self.n_heads).transpose(1, 2)
+
+        # Attention map (optional)
+        attn_map = None
+        if get_attn_map:
+            attn_map = get_attention(query=q, key=k)
+
+        # Scaled dot-product attention
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=None, dropout_p=0, is_causal=True
+        )
+
+        # Reshape and project
+        d_target = self.c_proj(
+            attn_output.transpose(1, 2).contiguous().view(B, T, self.dimin)
+        )
+
+        return d_target, attn_map
+
+
+@dataclass
+class TemporalSAEConfig(SAEConfig):
+    """Configuration for TemporalSAE inference.
+
+    Args:
+        d_in: Input dimension (dimensionality of the activations being encoded)
+        d_sae: SAE latent dimension (number of features)
+        n_heads: Number of attention heads in temporal attention
+        n_attn_layers: Number of attention layers
+        bottleneck_factor: Bottleneck factor for attention dimension
+        sae_diff_type: Type of SAE for novel codes ('relu' or 'topk')
+        kval_topk: K value for top-k sparsity (if sae_diff_type='topk')
+        tied_weights: Whether to tie encoder and decoder weights
+        activation_normalization_factor: Scalar factor for rescaling activations (used with normalize_activations='constant_scalar_rescale')
+    """
+
+    n_heads: int = 8
+    n_attn_layers: int = 1
+    bottleneck_factor: int = 64
+    sae_diff_type: Literal["relu", "topk"] = "topk"
+    kval_topk: int | None = None
+    tied_weights: bool = True
+    activation_normalization_factor: float = 1.0
+
+    def __post_init__(self):
+        # Call parent's __post_init__ first, but allow constant_scalar_rescale
+        if self.normalize_activations not in [
+            "none",
+            "expected_average_only_in",
+            "constant_norm_rescale",
+            "constant_scalar_rescale",  # Temporal SAEs support this
+            "layer_norm",
+        ]:
+            raise ValueError(
+                f"normalize_activations must be none, expected_average_only_in, layer_norm, constant_norm_rescale, or constant_scalar_rescale. Got {self.normalize_activations}"
+            )
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "temporal"
+
+
+class TemporalSAE(SAE[TemporalSAEConfig]):
+    """TemporalSAE: Sparse Autoencoder with temporal attention.
+
+    This SAE decomposes each activation x_t into:
+    - x_pred: Information aggregated from context {x_0, ..., x_{t-1}}
+    - x_novel: Novel information at position t (encoded sparsely)
+
+    The forward pass:
+    1. Uses attention layers to predict x_t from context
+    2. Encodes the residual (novel part) with a sparse SAE
+    3. Combines both for reconstruction
+    """
+
+    # Custom parameters (in addition to W_enc, W_dec, b_dec from base)
+    attn_layers: nn.ModuleList  # Attention layers
+    eps: float
+    lam: float
+
+    def __init__(self, cfg: TemporalSAEConfig, use_error_term: bool = False):
+        # Call parent init first
+        super().__init__(cfg, use_error_term)
+
+        # Initialize attention layers after parent init and move to correct device
+        self.attn_layers = nn.ModuleList(
+            [
+                ManualAttention(
+                    dimin=cfg.d_sae,
+                    n_heads=cfg.n_heads,
+                    bottleneck_factor=cfg.bottleneck_factor,
+                    bias_k=True,
+                    bias_q=True,
+                    bias_v=True,
+                    bias_o=True,
+                ).to(device=self.device, dtype=self.dtype)
+                for _ in range(cfg.n_attn_layers)
+            ]
+        )
+
+        self.eps = 1e-6
+        self.lam = 1 / (4 * self.cfg.d_in)
+
+    @override
+    def _setup_activation_normalization(self):
+        """Set up activation normalization functions for TemporalSAE.
+
+        Overrides the base implementation to handle constant_scalar_rescale
+        using the temporal-specific activation_normalization_factor.
+        """
+        if self.cfg.normalize_activations == "constant_scalar_rescale":
+            # Handle constant scalar rescaling for temporal SAEs
+            def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
+                return x * self.cfg.activation_normalization_factor
+
+            def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
+                return x / self.cfg.activation_normalization_factor
+
+            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
+            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+        else:
+            # Delegate to parent for all other normalization types
+            super()._setup_activation_normalization()
+
+    @override
+    def initialize_weights(self) -> None:
+        """Initialize TemporalSAE weights."""
+        # Initialize D (decoder) and b (bias)
+        self.W_dec = nn.Parameter(
+            torch.randn(
+                (self.cfg.d_sae, self.cfg.d_in), dtype=self.dtype, device=self.device
+            )
+        )
+        self.b_dec = nn.Parameter(
+            torch.zeros((self.cfg.d_in), dtype=self.dtype, device=self.device)
+        )
+
+        # Initialize E (encoder) if not tied
+        if not self.cfg.tied_weights:
+            self.W_enc = nn.Parameter(
+                torch.randn(
+                    (self.cfg.d_in, self.cfg.d_sae),
+                    dtype=self.dtype,
+                    device=self.device,
+                )
+            )
+
+    def encode_with_predictions(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        """Encode input to novel codes only.
+
+        Returns only the sparse novel codes (not predicted codes).
+        This is the main feature representation for TemporalSAE.
+        """
+        # Process input through SAELens preprocessing
+        x = self.process_sae_in(x)
+
+        B, L, _ = x.shape
+
+        if self.cfg.tied_weights:  # noqa: SIM108
+            W_enc = self.W_dec.T
+        else:
+            W_enc = self.W_enc
+
+        # Compute predicted codes using attention
+        x_residual = x
+        z_pred = torch.zeros((B, L, self.cfg.d_sae), device=x.device, dtype=x.dtype)
+
+        for attn_layer in self.attn_layers:
+            # Encode input to latent space
+            z_input = F.relu(torch.matmul(x_residual * self.lam, W_enc))
+
+            # Shift context (causal masking)
+            z_ctx = torch.cat(
+                (torch.zeros_like(z_input[:, :1, :]), z_input[:, :-1, :].clone()), dim=1
+            )
+
+            # Apply attention to get predicted codes
+            z_pred_, _ = attn_layer(z_ctx, z_input, get_attn_map=False)
+            z_pred_ = F.relu(z_pred_)
+
+            # Project predicted codes back to input space
+            Dz_pred_ = torch.matmul(z_pred_, self.W_dec)
+            Dz_norm_ = Dz_pred_.norm(dim=-1, keepdim=True) + self.eps
+
+            # Compute projection scale
+            proj_scale = (Dz_pred_ * x_residual).sum(
+                dim=-1, keepdim=True
+            ) / Dz_norm_.pow(2)
+
+            # Accumulate predicted codes
+            z_pred = z_pred + (z_pred_ * proj_scale)
+
+            # Remove prediction from residual
+            x_residual = x_residual - proj_scale * Dz_pred_
+
+        # Encode residual (novel part) with sparse SAE
+        z_novel = F.relu(torch.matmul(x_residual * self.lam, W_enc))
+        if self.cfg.sae_diff_type == "topk":
+            kval = self.cfg.kval_topk
+            if kval is not None:
+                _, topk_indices = torch.topk(z_novel, kval, dim=-1)
+                mask = torch.zeros_like(z_novel)
+                mask.scatter_(-1, topk_indices, 1)
+                z_novel = z_novel * mask
+
+        # Return only novel codes (these are the interpretable features)
+        return z_novel, z_pred
+
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        return self.encode_with_predictions(x)[0]
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """Decode novel codes to reconstruction.
+
+        Note: This only decodes the novel codes. For full reconstruction,
+        use forward() which includes predicted codes.
+        """
+        # Decode novel codes
+        sae_out = torch.matmul(feature_acts, self.W_dec)
+        sae_out = sae_out + self.b_dec
+
+        # Apply hook
+        sae_out = self.hook_sae_recons(sae_out)
+
+        # Apply output activation normalization (reverses input normalization)
+        sae_out = self.run_time_activation_norm_fn_out(sae_out)
+
+        # Add bias (already removed in process_sae_in)
+        logger.warning(
+            "NOTE this only decodes x_novel. The x_pred is missing, so we're not reconstructing the full x."
+        )
+        return sae_out
+
+    @override
+    def forward(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """Full forward pass through TemporalSAE.
+
+        Returns complete reconstruction (predicted + novel).
+        """
+        # Encode
+        z_novel, z_pred = self.encode_with_predictions(x)
+
+        # Decode the sum of predicted and novel codes.
+        x_recons = torch.matmul(z_novel + z_pred, self.W_dec) + self.b_dec
+
+        # Apply output activation normalization (reverses input normalization)
+        x_recons = self.run_time_activation_norm_fn_out(x_recons)
+
+        return self.hook_sae_output(x_recons)
+
+    @override
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError("Folding W_dec_norm is not supported for TemporalSAE")
+
+    @override
+    @torch.no_grad()
+    def fold_activation_norm_scaling_factor(self, scaling_factor: float) -> None:
+        raise NotImplementedError(
+            "Folding activation norm scaling factor is not supported for TemporalSAE"
+        )
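Since the file is entirely new, a usage sketch may help. Dimensions below are illustrative only, chosen to satisfy the `ManualAttention` assertion (`d_sae % (bottleneck_factor * n_heads) == 0` with the defaults `bottleneck_factor=64`, `n_heads=8`), and do not correspond to any released checkpoint:

import torch

from sae_lens.saes.temporal_sae import TemporalSAE, TemporalSAEConfig

# 512 % (64 * 8) == 0, so the attention bottleneck assert passes.
cfg = TemporalSAEConfig(d_in=64, d_sae=512, n_heads=8, kval_topk=16)
sae = TemporalSAE(cfg)

x = torch.randn(2, 10, 64)  # TemporalSAE expects (batch, seq, d_in)
z_novel, z_pred = sae.encode_with_predictions(x)
recon = sae(x)  # forward() decodes z_novel + z_pred

assert recon.shape == x.shape
# top-k sparsity is applied to the novel codes only
assert int((z_novel != 0).sum(dim=-1).max()) <= cfg.kval_topk
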
sae_lens/training/activations_store.py
CHANGED

@@ -319,7 +319,7 @@ class ActivationsStore:
             )
         else:
             warnings.warn(
-                "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://
+                "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://decoderesearch.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
             )
 
         self.iterable_sequences = self._iterate_tokenized_sequences()
{sae_lens-6.18.0.dist-info → sae_lens-6.20.1.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.18.0
+Version: 6.20.1
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -30,19 +30,19 @@ Requires-Dist: tenacity (>=9.0.0)
 Requires-Dist: transformer-lens (>=2.16.1,<3.0.0)
 Requires-Dist: transformers (>=4.38.1,<5.0.0)
 Requires-Dist: typing-extensions (>=4.10.0,<5.0.0)
-Project-URL: Homepage, https://
-Project-URL: Repository, https://github.com/
+Project-URL: Homepage, https://decoderesearch.github.io/SAELens
+Project-URL: Repository, https://github.com/decoderesearch/SAELens
 Description-Content-Type: text/markdown
 
-<img width="1308"
+<img width="1308" height="532" alt="saes_pic" src="https://github.com/user-attachments/assets/2a5d752f-b261-4ee4-ad5d-ebf282321371" />
 
 # SAE Lens
 
 [](https://pypi.org/project/sae-lens/)
 [](https://opensource.org/licenses/MIT)
-[](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml)
+[](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml)
+[](https://codecov.io/gh/decoderesearch/SAELens)
 
 SAELens exists to help researchers:
 
@@ -50,7 +50,7 @@ SAELens exists to help researchers:
 - Analyse sparse autoencoders / research mechanistic interpretability.
 - Generate insights which make it easier to create safe and aligned AI systems.
 
-Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for information on how to:
+Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
 
 - Download and Analyse pre-trained sparse autoencoders.
 - Train your own sparse autoencoders.
@@ -58,25 +58,25 @@ Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for in
 
 SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to [safeguard humanity from risks posed by artificial intelligence](https://80000hours.org/problem-profiles/artificial-intelligence/).
 
-This library is maintained by [Joseph Bloom](https://www.
+This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
 
 ## Loading Pre-trained SAEs.
 
-Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://
+Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.
 
 ## Migrating to SAELens v6
 
-The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://
+The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://decoderesearch.github.io/SAELens/latest/migrating/) for more details.
 
 ## Tutorials
 
-- [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[](https://githubtocolab.com/
+- [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
 - [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
-  [](https://githubtocolab.com/
+  [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
 - [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
-  [](https://githubtocolab.com/
+  [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
 - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
-  [](https://githubtocolab.com/
+  [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
 
 ## Join the Slack!
 
@@ -91,7 +91,7 @@ Please cite the package as follows:
   title = {SAELens},
   author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
   year = {2024},
-  howpublished = {\url{https://github.com/
+  howpublished = {\url{https://github.com/decoderesearch/SAELens}},
 }
 ```
 
{sae_lens-6.18.0.dist-info → sae_lens-6.20.1.dist-info}/RECORD
CHANGED

@@ -1,4 +1,4 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=Q90bKWhK5R8nzb_mNI1WrVVtxJOVNyiOSgotaLNq3sU,4033
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=vRu6JseH1lZaEeILD5bEkQEQ1wYHHDcxD-f2olKmE9Y,14275
 sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
@@ -9,24 +9,25 @@ sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
 sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
 sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
 sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/loading/pretrained_sae_loaders.py,sha256=
-sae_lens/loading/pretrained_saes_directory.py,sha256=
+sae_lens/loading/pretrained_sae_loaders.py,sha256=X-gVZ4A74E85lSMFMsZ_rEQhHlR9AYFwhxvoA_vt2CQ,56051
+sae_lens/loading/pretrained_saes_directory.py,sha256=hejNfLUepYCSGPalRfQwxxCEUqMMUPsn1tufwvwct5k,3820
 sae_lens/pretokenize_runner.py,sha256=x-reJzVPFDS9iRFbZtrFYSzNguJYki9gd0pbHjYJ3r4,7085
-sae_lens/pretrained_saes.yaml,sha256=
+sae_lens/pretrained_saes.yaml,sha256=LO0r6ISTAWTicy2ecL7s4qU6hfI3K8jJr71HuXFBR_o,604001
 sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
-sae_lens/saes/__init__.py,sha256=
+sae_lens/saes/__init__.py,sha256=nTNPnJ7edyfedo1MX96xwn9WOG8504yHbT9LFw9od_0,1778
 sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
 sae_lens/saes/gated_sae.py,sha256=qcmM9JwBA8aZR8z_IRHV1_gQX-q_63tKewWXRnhdXuo,8986
 sae_lens/saes/jumprelu_sae.py,sha256=HHBF1sJ95lZvxwP5vwLSQFKdnJN2KKYK0WAEaLTrta0,13399
 sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=4_1cVaxk6c6jgJEbxqebtG-cjQNIzaMAfjSPGfR7_VU,6062
-sae_lens/saes/sae.py,sha256=
+sae_lens/saes/sae.py,sha256=i6HwULvCrFQhRqKruCQo2aOY5a7c6FuUX7sD3TbnNTY,38084
 sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
+sae_lens/saes/temporal_sae.py,sha256=jnHBqz3FvKMOw_lH2aGIJqgrMVJbpprronittCGGXPQ,13517
 sae_lens/saes/topk_sae.py,sha256=tzQM5eQFifMe--8_8NUBYWY7hpjQa6A_olNe6U71FE8,21275
 sae_lens/saes/transcoder.py,sha256=BfLSbTYVNZh-ruGxseZiZJ_acEL6_7QyTdfqUr0lDOg,12156
 sae_lens/tokenization_and_batching.py,sha256=D_o7cXvRqhT89H3wNzoRymNALNE6eHojBWLdXOUwUGE,5438
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
-sae_lens/training/activations_store.py,sha256=
+sae_lens/training/activations_store.py,sha256=7fWUahMUqrvHhN5b_8Ma6kJX1NkRbIqNuNKzKvYmHFQ,33881
 sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
 sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
 sae_lens/training/sae_trainer.py,sha256=zhkabyIKxI_tZTV3_kwz6zMrHZ95Ecr97krmwc-9ffs,17600
@@ -34,7 +35,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
 sae_lens/util.py,sha256=tCovQ-eZa1L7thPpNDL6PGOJrIMML2yLI5e0EHCOpS8,3309
-sae_lens-6.18.0.dist-info/METADATA,sha256=
-sae_lens-6.18.0.dist-info/WHEEL,sha256=
-sae_lens-6.18.0.dist-info/licenses/LICENSE,sha256=
-sae_lens-6.18.0.dist-info/RECORD,,
+sae_lens-6.20.1.dist-info/METADATA,sha256=FrirZts5VERstYBsFiVUZcp1OU2fwYTnhZ8KjsvYHm0,5369
+sae_lens-6.20.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+sae_lens-6.20.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.20.1.dist-info/RECORD,,
{sae_lens-6.18.0.dist-info → sae_lens-6.20.1.dist-info}/WHEEL
File without changes

{sae_lens-6.18.0.dist-info → sae_lens-6.20.1.dist-info}/licenses/LICENSE
File without changes