sae-lens 6.16.3__py3-none-any.whl → 6.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +6 -1
- sae_lens/cache_activations_runner.py +1 -1
- sae_lens/config.py +39 -2
- sae_lens/constants.py +1 -0
- sae_lens/llm_sae_training_runner.py +9 -4
- sae_lens/loading/pretrained_sae_loaders.py +188 -0
- sae_lens/loading/pretrained_saes_directory.py +5 -3
- sae_lens/pretrained_saes.yaml +75 -1
- sae_lens/saes/__init__.py +3 -0
- sae_lens/saes/sae.py +11 -13
- sae_lens/saes/temporal_sae.py +372 -0
- sae_lens/training/activation_scaler.py +7 -0
- sae_lens/training/activations_store.py +47 -4
- sae_lens/training/optim.py +11 -0
- sae_lens/training/sae_trainer.py +49 -11
- {sae_lens-6.16.3.dist-info → sae_lens-6.21.0.dist-info}/METADATA +16 -16
- {sae_lens-6.16.3.dist-info → sae_lens-6.21.0.dist-info}/RECORD +19 -18
- {sae_lens-6.16.3.dist-info → sae_lens-6.21.0.dist-info}/WHEEL +0 -0
- {sae_lens-6.16.3.dist-info → sae_lens-6.21.0.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.16.3"
+__version__ = "6.21.0"
 
 import logging
 
@@ -28,6 +28,8 @@ from sae_lens.saes import (
     StandardSAEConfig,
     StandardTrainingSAE,
     StandardTrainingSAEConfig,
+    TemporalSAE,
+    TemporalSAEConfig,
     TopKSAE,
     TopKSAEConfig,
     TopKTrainingSAE,
@@ -105,6 +107,8 @@ __all__ = [
     "JumpReLUTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
+    "TemporalSAE",
+    "TemporalSAEConfig",
 ]
 
 
@@ -127,3 +131,4 @@ register_sae_training_class(
 register_sae_class("transcoder", Transcoder, TranscoderConfig)
 register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
 register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
sae_lens/cache_activations_runner.py
CHANGED
@@ -10,7 +10,7 @@ from datasets import Array2D, Dataset, Features, Sequence, Value
 from datasets.fingerprint import generate_fingerprint
 from huggingface_hub import HfApi
 from jaxtyping import Float, Int
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from transformer_lens.HookedTransformer import HookedRootModule
 
 from sae_lens import logger
sae_lens/config.py
CHANGED
@@ -18,6 +18,7 @@ from datasets import (
 
 from sae_lens import __version__, logger
 from sae_lens.constants import DTYPE_MAP
+from sae_lens.registry import get_sae_training_class
 from sae_lens.saes.sae import TrainingSAEConfig
 
 if TYPE_CHECKING:
@@ -171,6 +172,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
        n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
        checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
        save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
+       resume_from_checkpoint (str | None): The path to the checkpoint to resume training from. (default is None).
        output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
        verbose (bool): Whether to print verbose output. (default is True)
        model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
@@ -261,6 +263,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     checkpoint_path: str | None = "checkpoints"
     save_final_checkpoint: bool = False
     output_path: str | None = "output"
+    resume_from_checkpoint: str | None = None
 
     # Misc
     verbose: bool = True
@@ -385,8 +388,11 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         return self.sae.to_dict()
 
     def to_dict(self) -> dict[str, Any]:
-
-
+        """
+        Convert the config to a dictionary.
+        """
+
+        d = asdict(self)
 
         d["logger"] = asdict(self.logger)
         d["sae"] = self.sae.to_dict()
@@ -396,6 +402,37 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         d["act_store_device"] = str(self.act_store_device)
         return d
 
+    @classmethod
+    def from_dict(cls, cfg_dict: dict[str, Any]) -> "LanguageModelSAERunnerConfig[Any]":
+        """
+        Load a LanguageModelSAERunnerConfig from a dictionary given by `to_dict`.
+
+        Args:
+            cfg_dict (dict[str, Any]): The dictionary to load the config from.
+
+        Returns:
+            LanguageModelSAERunnerConfig: The loaded config.
+        """
+        if "sae" not in cfg_dict:
+            raise ValueError("sae field is required in the config dictionary")
+        if "architecture" not in cfg_dict["sae"]:
+            raise ValueError("architecture field is required in the sae dictionary")
+        if "logger" not in cfg_dict:
+            raise ValueError("logger field is required in the config dictionary")
+        sae_config_class = get_sae_training_class(cfg_dict["sae"]["architecture"])[1]
+        sae_cfg = sae_config_class.from_dict(cfg_dict["sae"])
+        logger_cfg = LoggingConfig(**cfg_dict["logger"])
+        updated_cfg_dict: dict[str, Any] = {
+            **cfg_dict,
+            "sae": sae_cfg,
+            "logger": logger_cfg,
+        }
+        output = cls(**updated_cfg_dict)
+        # the post_init always appends to checkpoint path, so we need to set it explicitly here.
+        if "checkpoint_path" in cfg_dict:
+            output.checkpoint_path = cfg_dict["checkpoint_path"]
+        return output
+
     def to_sae_trainer_config(self) -> "SAETrainerConfig":
         return SAETrainerConfig(
             n_checkpoints=self.n_checkpoints,
sae_lens/constants.py
CHANGED
@@ -17,5 +17,6 @@ SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
 SAE_CFG_FILENAME = "cfg.json"
 RUNNER_CFG_FILENAME = "runner_cfg.json"
 SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+TRAINER_STATE_FILENAME = "trainer_state.pt"
 ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
 ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
sae_lens/llm_sae_training_runner.py
CHANGED
@@ -16,7 +16,6 @@ from typing_extensions import deprecated
 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
 from sae_lens.constants import (
-    ACTIVATIONS_STORE_STATE_FILENAME,
     RUNNER_CFG_FILENAME,
     SPARSITY_FILENAME,
 )
@@ -112,6 +111,7 @@ class LanguageModelSAETrainingRunner:
         override_dataset: HfDataset | None = None,
         override_model: HookedRootModule | None = None,
         override_sae: TrainingSAE[Any] | None = None,
+        resume_from_checkpoint: Path | str | None = None,
     ):
         if override_dataset is not None:
             logger.warning(
@@ -153,6 +153,7 @@ class LanguageModelSAETrainingRunner:
             )
         else:
             self.sae = override_sae
+
         self.sae.to(self.cfg.device)
 
     def run(self):
@@ -185,6 +186,12 @@ class LanguageModelSAETrainingRunner:
             cfg=self.cfg.to_sae_trainer_config(),
         )
 
+        if self.cfg.resume_from_checkpoint is not None:
+            logger.info(f"Resuming from checkpoint: {self.cfg.resume_from_checkpoint}")
+            trainer.load_trainer_state(self.cfg.resume_from_checkpoint)
+            self.sae.load_weights_from_checkpoint(self.cfg.resume_from_checkpoint)
+            self.activations_store.load_from_checkpoint(self.cfg.resume_from_checkpoint)
+
         self._compile_if_needed()
         sae = self.run_trainer_with_interruption_handling(trainer)
 
@@ -304,9 +311,7 @@ class LanguageModelSAETrainingRunner:
         if checkpoint_path is None:
             return
 
-        self.activations_store.save(
-            str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
-        )
+        self.activations_store.save_to_checkpoint(checkpoint_path)
 
         runner_config = self.cfg.to_dict()
         with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
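With resume_from_checkpoint set, run() restores the trainer state, the SAE weights, and the activations-store state from the checkpoint directory before training continues. A rough usage sketch, with a hypothetical checkpoint path and reusing the config round trip shown earlier:

import json
from pathlib import Path

from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.llm_sae_training_runner import LanguageModelSAETrainingRunner

ckpt = Path("checkpoints/my_run/12288")  # hypothetical checkpoint directory

# Rebuild the original run's config, point it at the checkpoint, and resume.
cfg = LanguageModelSAERunnerConfig.from_dict(
    json.loads((ckpt / "runner_cfg.json").read_text())
)
cfg.resume_from_checkpoint = str(ckpt)
sae = LanguageModelSAETrainingRunner(cfg).run()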
sae_lens/loading/pretrained_sae_loaders.py
CHANGED
@@ -523,6 +523,82 @@ def gemma_2_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity
 
 
+def get_goodfire_config_from_hf(
+    repo_id: str,
+    folder_name: str,  # noqa: ARG001
+    device: str,
+    force_download: bool = False,  # noqa: ARG001
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    cfg_dict = None
+    if repo_id == "Goodfire/Llama-3.3-70B-Instruct-SAE-l50":
+        if folder_name != "Llama-3.3-70B-Instruct-SAE-l50.pt":
+            raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+        cfg_dict = {
+            "architecture": "standard",
+            "d_in": 8192,
+            "d_sae": 65536,
+            "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+            "hook_name": "blocks.50.hook_resid_post",
+            "hook_head_index": None,
+            "dataset_path": "lmsys/lmsys-chat-1m",
+            "apply_b_dec_to_input": False,
+        }
+    elif repo_id == "Goodfire/Llama-3.1-8B-Instruct-SAE-l19":
+        if folder_name != "Llama-3.1-8B-Instruct-SAE-l19.pth":
+            raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+        cfg_dict = {
+            "architecture": "standard",
+            "d_in": 4096,
+            "d_sae": 65536,
+            "model_name": "meta-llama/Llama-3.1-8B-Instruct",
+            "hook_name": "blocks.19.hook_resid_post",
+            "hook_head_index": None,
+            "dataset_path": "lmsys/lmsys-chat-1m",
+            "apply_b_dec_to_input": False,
+        }
+    if cfg_dict is None:
+        raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+    if device is not None:
+        cfg_dict["device"] = device
+    if cfg_overrides is not None:
+        cfg_dict.update(cfg_overrides)
+    return cfg_dict
+
+
+def get_goodfire_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    cfg_dict = get_goodfire_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
+    )
+
+    # Download the SAE weights
+    sae_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=folder_name,
+        force_download=force_download,
+    )
+    raw_state_dict = torch.load(sae_path, map_location=device)
+
+    state_dict = {
+        "W_enc": raw_state_dict["encoder_linear.weight"].T,
+        "W_dec": raw_state_dict["decoder_linear.weight"].T,
+        "b_enc": raw_state_dict["encoder_linear.bias"],
+        "b_dec": raw_state_dict["decoder_linear.bias"],
+    }
+
+    return cfg_dict, state_dict, None
+
+
 def get_llama_scope_config_from_hf(
     repo_id: str,
     folder_name: str,
@@ -1475,6 +1551,114 @@ def get_mntss_clt_layer_config_from_hf(
 }
 
 
+def get_temporal_sae_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Get TemporalSAE config without loading weights."""
+    # Download config file
+    conf_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=f"{folder_name}/conf.yaml",
+        force_download=force_download,
+    )
+
+    # Load and parse config
+    with open(conf_path) as f:
+        yaml_config = yaml.safe_load(f)
+
+    # Extract parameters
+    d_in = yaml_config["llm"]["dimin"]
+    exp_factor = yaml_config["sae"]["exp_factor"]
+    d_sae = int(d_in * exp_factor)
+
+    # extract layer from folder_name eg : "layer_12/temporal"
+    layer = re.search(r"layer_(\d+)", folder_name)
+    if layer is None:
+        raise ValueError(f"Could not find layer in folder_name: {folder_name}")
+    layer = int(layer.group(1))
+
+    # Build config dict
+    cfg_dict = {
+        "architecture": "temporal",
+        "hook_name": f"blocks.{layer}.hook_resid_post",
+        "d_in": d_in,
+        "d_sae": d_sae,
+        "n_heads": yaml_config["sae"]["n_heads"],
+        "n_attn_layers": yaml_config["sae"]["n_attn_layers"],
+        "bottleneck_factor": yaml_config["sae"]["bottleneck_factor"],
+        "sae_diff_type": yaml_config["sae"]["sae_diff_type"],
+        "kval_topk": yaml_config["sae"]["kval_topk"],
+        "tied_weights": yaml_config["sae"]["tied_weights"],
+        "dtype": yaml_config["data"]["dtype"],
+        "device": device,
+        "normalize_activations": "constant_scalar_rescale",
+        "activation_normalization_factor": yaml_config["sae"]["scaling_factor"],
+        "apply_b_dec_to_input": True,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def temporal_sae_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """
+    Load TemporalSAE from canrager/temporalSAEs format (safetensors version).
+
+    Expects folder_name to contain:
+    - conf.yaml (configuration)
+    - latest_ckpt.safetensors (model weights)
+    """
+
+    cfg_dict = get_temporal_sae_config_from_hf(
+        repo_id=repo_id,
+        folder_name=folder_name,
+        device=device,
+        force_download=force_download,
+        cfg_overrides=cfg_overrides,
+    )
+
+    # Download checkpoint (safetensors format)
+    ckpt_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=f"{folder_name}/latest_ckpt.safetensors",
+        force_download=force_download,
+    )
+
+    # Load checkpoint from safetensors
+    state_dict_raw = load_file(ckpt_path, device=device)
+
+    # Convert to SAELens naming convention
+    # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
+    state_dict = {}
+
+    # Copy attention layers as-is
+    for key, value in state_dict_raw.items():
+        if key.startswith("attn_layers."):
+            state_dict[key] = value.to(device)
+
+    # Main parameters
+    state_dict["W_dec"] = state_dict_raw["D"].to(device)
+    state_dict["b_dec"] = state_dict_raw["b"].to(device)
+
+    # Handle tied/untied weights
+    if "E" in state_dict_raw:
+        state_dict["W_enc"] = state_dict_raw["E"].to(device)
+
+    return cfg_dict, state_dict, None
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1487,6 +1671,8 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
     "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
     "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
+    "temporal": temporal_sae_huggingface_loader,
+    "goodfire": get_goodfire_huggingface_loader,
 }
 
 
@@ -1502,4 +1688,6 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoader] = {
     "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
     "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
     "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
+    "temporal": get_temporal_sae_config_from_hf,
+    "goodfire": get_goodfire_config_from_hf,
 }
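The two new loader names are wired into both registries above, so they can be driven through the normal pretrained-SAE machinery or called directly. A small sketch of the Goodfire path, using only values from the diff (the weights are a sizeable download from the Hugging Face Hub):

from sae_lens.loading.pretrained_sae_loaders import get_goodfire_huggingface_loader

# Fetch the Goodfire Llama-3.1-8B-Instruct SAE and remap its parameters to the
# SAE Lens naming convention (W_enc / W_dec / b_enc / b_dec).
cfg_dict, state_dict, _ = get_goodfire_huggingface_loader(
    repo_id="Goodfire/Llama-3.1-8B-Instruct-SAE-l19",
    folder_name="Llama-3.1-8B-Instruct-SAE-l19.pth",
    device="cpu",
)
print(cfg_dict["hook_name"])             # blocks.19.hook_resid_post
print(tuple(state_dict["W_enc"].shape))  # (4096, 65536)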
sae_lens/loading/pretrained_saes_directory.py
CHANGED
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from functools import cache
-from importlib import resources
+from importlib.resources import files
 from typing import Any
 
 import yaml
@@ -24,7 +24,8 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
     package = "sae_lens"
     # Access the file within the package using importlib.resources
     directory: dict[str, PretrainedSAELookup] = {}
-
+    yaml_file = files(package).joinpath("pretrained_saes.yaml")
+    with yaml_file.open("r") as file:
         # Load the YAML file content
         data = yaml.safe_load(file)
         for release, value in data.items():
@@ -68,7 +69,8 @@ def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
         float | None: The norm_scaling_factor if it exists, None otherwise.
     """
     package = "sae_lens"
-
+    yaml_file = files(package).joinpath("pretrained_saes.yaml")
+    with yaml_file.open("r") as file:
         data = yaml.safe_load(file)
         if release in data:
             for sae_info in data[release]["saes"]:
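From the caller's side the directory helpers are unchanged; only the file access now goes through importlib.resources.files. A quick check that the new releases show up, using names from the pretrained_saes.yaml additions below:

from sae_lens.loading.pretrained_saes_directory import (
    get_norm_scaling_factor,
    get_pretrained_saes_directory,
)

directory = get_pretrained_saes_directory()
print("temporal-sae-gemma-2-2b" in directory)  # True
print(get_norm_scaling_factor("temporal-sae-gemma-2-2b", "blocks.12.hook_resid_post"))
# 0.00666666667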
sae_lens/pretrained_saes.yaml
CHANGED
@@ -1,3 +1,35 @@
+temporal-sae-gemma-2-2b:
+  conversion_func: temporal
+  model: gemma-2-2b
+  repo_id: canrager/temporalSAEs
+  config_overrides:
+    model_name: gemma-2-2b
+    hook_name: blocks.12.hook_resid_post
+    dataset_path: monology/pile-uncopyrighted
+  saes:
+  - id: blocks.12.hook_resid_post
+    l0: 192
+    norm_scaling_factor: 0.00666666667
+    path: gemma-2-2B/layer_12/temporal
+    neuronpedia: gemma-2-2b/12-temporal-res
+temporal-sae-llama-3.1-8b:
+  conversion_func: temporal
+  model: meta-llama/Llama-3.1-8B
+  repo_id: canrager/temporalSAEs
+  config_overrides:
+    model_name: meta-llama/Llama-3.1-8B
+    dataset_path: monology/pile-uncopyrighted
+  saes:
+  - id: blocks.15.hook_resid_post
+    l0: 256
+    norm_scaling_factor: 0.029
+    path: llama-3.1-8B/layer_15/temporal
+    neuronpedia: llama3.1-8b/15-temporal-res
+  - id: blocks.26.hook_resid_post
+    l0: 256
+    norm_scaling_factor: 0.029
+    path: llama-3.1-8B/layer_26/temporal
+    neuronpedia: llama3.1-8b/26-temporal-res
 deepseek-r1-distill-llama-8b-qresearch:
   conversion_func: deepseek_r1
   model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
@@ -14882,4 +14914,46 @@ qwen2.5-7b-instruct-andyrdt:
     neuronpedia: qwen2.5-7b-it/23-resid-post-aa
   - id: resid_post_layer_27_trainer_1
     path: resid_post_layer_27/trainer_1
-    neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+    neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+
+gpt-oss-20b-andyrdt:
+  conversion_func: dictionary_learning_1
+  model: openai/gpt-oss-20b
+  repo_id: andyrdt/saes-gpt-oss-20b
+  saes:
+  - id: resid_post_layer_3_trainer_0
+    path: resid_post_layer_3/trainer_0
+    neuronpedia: gpt-oss-20b/3-resid-post-aa
+  - id: resid_post_layer_7_trainer_0
+    path: resid_post_layer_7/trainer_0
+    neuronpedia: gpt-oss-20b/7-resid-post-aa
+  - id: resid_post_layer_11_trainer_0
+    path: resid_post_layer_11/trainer_0
+    neuronpedia: gpt-oss-20b/11-resid-post-aa
+  - id: resid_post_layer_15_trainer_0
+    path: resid_post_layer_15/trainer_0
+    neuronpedia: gpt-oss-20b/15-resid-post-aa
+  - id: resid_post_layer_19_trainer_0
+    path: resid_post_layer_19/trainer_0
+    neuronpedia: gpt-oss-20b/19-resid-post-aa
+  - id: resid_post_layer_23_trainer_0
+    path: resid_post_layer_23/trainer_0
+    neuronpedia: gpt-oss-20b/23-resid-post-aa
+
+goodfire-llama-3.3-70b-instruct:
+  conversion_func: goodfire
+  model: meta-llama/Llama-3.3-70B-Instruct
+  repo_id: Goodfire/Llama-3.3-70B-Instruct-SAE-l50
+  saes:
+  - id: layer_50
+    path: Llama-3.3-70B-Instruct-SAE-l50.pt
+    l0: 121
+
+goodfire-llama-3.1-8b-instruct:
+  conversion_func: goodfire
+  model: meta-llama/Llama-3.1-8B-Instruct
+  repo_id: Goodfire/Llama-3.1-8B-Instruct-SAE-l19
+  saes:
+  - id: layer_19
+    path: Llama-3.1-8B-Instruct-SAE-l19.pth
+    l0: 91
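Each new YAML entry maps a release name and SAE id to a conversion_func registered in the loader tables above, so the standard loading entry point picks them up. A sketch, assuming the v6-style SAE.from_pretrained that returns the SAE directly:

from sae_lens import SAE

# "temporal" routes through temporal_sae_huggingface_loader, which reads
# conf.yaml and latest_ckpt.safetensors from canrager/temporalSAEs.
sae = SAE.from_pretrained(
    release="temporal-sae-gemma-2-2b",
    sae_id="blocks.12.hook_resid_post",
    device="cpu",
)
print(sae.cfg.d_in, sae.cfg.d_sae)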
sae_lens/saes/__init__.py
CHANGED
@@ -25,6 +25,7 @@ from .standard_sae import (
     StandardTrainingSAE,
     StandardTrainingSAEConfig,
 )
+from .temporal_sae import TemporalSAE, TemporalSAEConfig
 from .topk_sae import (
     TopKSAE,
     TopKSAEConfig,
@@ -71,4 +72,6 @@ __all__ = [
     "JumpReLUTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
+    "TemporalSAE",
+    "TemporalSAEConfig",
 ]
sae_lens/saes/sae.py
CHANGED
@@ -21,7 +21,7 @@ import einops
 import torch
 from jaxtyping import Float
 from numpy.typing import NDArray
-from safetensors.torch import save_file
+from safetensors.torch import load_file, save_file
 from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
 from typing_extensions import deprecated, overload, override
@@ -155,9 +155,9 @@ class SAEConfig(ABC):
     dtype: str = "float32"
     device: str = "cpu"
     apply_b_dec_to_input: bool = True
-    normalize_activations: Literal[
-        "none",
-
+    normalize_activations: Literal["none", "expected_average_only_in", "layer_norm"] = (
+        "none"  # none, expected_average_only_in (Anthropic April Update)
+    )
     reshape_activations: Literal["none", "hook_z"] = "none"
     metadata: SAEMetadata = field(default_factory=SAEMetadata)
 
@@ -309,6 +309,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
             self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
             self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+
         elif self.cfg.normalize_activations == "layer_norm":
             # we need to scale the norm of the input and store the scaling factor
             def run_time_activation_ln_in(
@@ -452,23 +453,14 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     def process_sae_in(
         self, sae_in: Float[torch.Tensor, "... d_in"]
     ) -> Float[torch.Tensor, "... d_in"]:
-        # print(f"Input shape to process_sae_in: {sae_in.shape}")
-        # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
-        # print(f"self.b_dec shape: {self.b_dec.shape}")
-        # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")
-
         sae_in = sae_in.to(self.dtype)
-
-        # print(f"Shape before reshape_fn_in: {sae_in.shape}")
         sae_in = self.reshape_fn_in(sae_in)
-        # print(f"Shape after reshape_fn_in: {sae_in.shape}")
 
         sae_in = self.hook_sae_input(sae_in)
         sae_in = self.run_time_activation_norm_fn_in(sae_in)
 
         # Here's where the error happens
         bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
-        # print(f"Bias term shape: {bias_term.shape}")
 
         return sae_in - bias_term
 
@@ -1018,6 +1010,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
     ) -> type[TrainingSAEConfig]:
         return get_sae_training_class(architecture)[1]
 
+    def load_weights_from_checkpoint(self, checkpoint_path: Path | str) -> None:
+        checkpoint_path = Path(checkpoint_path)
+        state_dict = load_file(checkpoint_path / SAE_WEIGHTS_FILENAME)
+        self.process_state_dict_for_loading(state_dict)
+        self.load_state_dict(state_dict)
+
 
 _blank_hook = nn.Identity()
 
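load_weights_from_checkpoint is what the training runner calls when resuming, but it can also be used on its own to pull sae_weights.safetensors back into a freshly constructed training SAE. A minimal sketch with hypothetical dimensions and path, assuming the remaining StandardTrainingSAEConfig fields have defaults:

from pathlib import Path

from sae_lens import StandardTrainingSAE, StandardTrainingSAEConfig

# Build an SAE with the same shape as the checkpointed one, then load weights
# from a checkpoint directory containing sae_weights.safetensors.
sae = StandardTrainingSAE(StandardTrainingSAEConfig(d_in=768, d_sae=16384))
sae.load_weights_from_checkpoint(Path("checkpoints/my_run/12288"))  # hypothetical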