sae-lens 6.12.1__py3-none-any.whl → 6.21.0__py3-none-any.whl
This diff shows the content of publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between the two package versions.
- sae_lens/__init__.py +15 -1
- sae_lens/cache_activations_runner.py +1 -1
- sae_lens/config.py +39 -2
- sae_lens/constants.py +1 -0
- sae_lens/evals.py +20 -14
- sae_lens/llm_sae_training_runner.py +17 -18
- sae_lens/loading/pretrained_sae_loaders.py +194 -0
- sae_lens/loading/pretrained_saes_directory.py +5 -3
- sae_lens/pretokenize_runner.py +2 -1
- sae_lens/pretrained_saes.yaml +75 -1
- sae_lens/saes/__init__.py +9 -0
- sae_lens/saes/batchtopk_sae.py +32 -1
- sae_lens/saes/matryoshka_batchtopk_sae.py +137 -0
- sae_lens/saes/sae.py +22 -24
- sae_lens/saes/temporal_sae.py +372 -0
- sae_lens/saes/topk_sae.py +287 -17
- sae_lens/tokenization_and_batching.py +21 -6
- sae_lens/training/activation_scaler.py +7 -0
- sae_lens/training/activations_store.py +52 -31
- sae_lens/training/optim.py +11 -0
- sae_lens/training/sae_trainer.py +57 -16
- sae_lens/training/types.py +1 -1
- sae_lens/util.py +27 -0
- {sae_lens-6.12.1.dist-info → sae_lens-6.21.0.dist-info}/METADATA +19 -17
- sae_lens-6.21.0.dist-info/RECORD +41 -0
- {sae_lens-6.12.1.dist-info → sae_lens-6.21.0.dist-info}/WHEEL +1 -1
- sae_lens-6.12.1.dist-info/RECORD +0 -39
- {sae_lens-6.12.1.dist-info → sae_lens-6.21.0.dist-info/licenses}/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.12.1"
+__version__ = "6.21.0"
 
 import logging
 
@@ -19,6 +19,8 @@ from sae_lens.saes import (
     JumpReLUTrainingSAEConfig,
     JumpReLUTranscoder,
     JumpReLUTranscoderConfig,
+    MatryoshkaBatchTopKTrainingSAE,
+    MatryoshkaBatchTopKTrainingSAEConfig,
     SAEConfig,
     SkipTranscoder,
     SkipTranscoderConfig,
@@ -26,6 +28,8 @@ from sae_lens.saes import (
     StandardSAEConfig,
     StandardTrainingSAE,
     StandardTrainingSAEConfig,
+    TemporalSAE,
+    TemporalSAEConfig,
     TopKSAE,
     TopKSAEConfig,
     TopKTrainingSAE,
@@ -101,6 +105,10 @@ __all__ = [
     "SkipTranscoderConfig",
     "JumpReLUTranscoder",
     "JumpReLUTranscoderConfig",
+    "MatryoshkaBatchTopKTrainingSAE",
+    "MatryoshkaBatchTopKTrainingSAEConfig",
+    "TemporalSAE",
+    "TemporalSAEConfig",
 ]
 
 
@@ -115,6 +123,12 @@ register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAE
 register_sae_training_class(
     "batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
 )
+register_sae_training_class(
+    "matryoshka_batchtopk",
+    MatryoshkaBatchTopKTrainingSAE,
+    MatryoshkaBatchTopKTrainingSAEConfig,
+)
 register_sae_class("transcoder", Transcoder, TranscoderConfig)
 register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
 register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)

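The registrations above make the new architectures resolvable by name. A minimal sketch of how that registry lookup can be used, based on the `get_sae_training_class` helper imported in `sae_lens/config.py` below and the `(class, config class)` pairing seen in `SAE_TRAINING_CLASS_REGISTRY`; the printed names are only illustrative:

    from sae_lens.registry import get_sae_training_class

    # The registry maps an architecture name to a (training SAE class, config class) pair.
    sae_cls, sae_cfg_cls = get_sae_training_class("matryoshka_batchtopk")
    print(sae_cls.__name__)      # MatryoshkaBatchTopKTrainingSAE
    print(sae_cfg_cls.__name__)  # MatryoshkaBatchTopKTrainingSAEConfig
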
sae_lens/cache_activations_runner.py
CHANGED
@@ -10,7 +10,7 @@ from datasets import Array2D, Dataset, Features, Sequence, Value
 from datasets.fingerprint import generate_fingerprint
 from huggingface_hub import HfApi
 from jaxtyping import Float, Int
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from transformer_lens.HookedTransformer import HookedRootModule
 
 from sae_lens import logger

sae_lens/config.py
CHANGED
@@ -18,6 +18,7 @@ from datasets import (
 
 from sae_lens import __version__, logger
 from sae_lens.constants import DTYPE_MAP
+from sae_lens.registry import get_sae_training_class
 from sae_lens.saes.sae import TrainingSAEConfig
 
 if TYPE_CHECKING:
@@ -171,6 +172,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
         checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
         save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
+        resume_from_checkpoint (str | None): The path to the checkpoint to resume training from. (default is None).
         output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
         verbose (bool): Whether to print verbose output. (default is True)
         model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
@@ -261,6 +263,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     checkpoint_path: str | None = "checkpoints"
     save_final_checkpoint: bool = False
     output_path: str | None = "output"
+    resume_from_checkpoint: str | None = None
 
     # Misc
     verbose: bool = True
@@ -385,8 +388,11 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         return self.sae.to_dict()
 
     def to_dict(self) -> dict[str, Any]:
-
-
+        """
+        Convert the config to a dictionary.
+        """
+
+        d = asdict(self)
 
         d["logger"] = asdict(self.logger)
         d["sae"] = self.sae.to_dict()
@@ -396,6 +402,37 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         d["act_store_device"] = str(self.act_store_device)
         return d
 
+    @classmethod
+    def from_dict(cls, cfg_dict: dict[str, Any]) -> "LanguageModelSAERunnerConfig[Any]":
+        """
+        Load a LanguageModelSAERunnerConfig from a dictionary given by `to_dict`.
+
+        Args:
+            cfg_dict (dict[str, Any]): The dictionary to load the config from.
+
+        Returns:
+            LanguageModelSAERunnerConfig: The loaded config.
+        """
+        if "sae" not in cfg_dict:
+            raise ValueError("sae field is required in the config dictionary")
+        if "architecture" not in cfg_dict["sae"]:
+            raise ValueError("architecture field is required in the sae dictionary")
+        if "logger" not in cfg_dict:
+            raise ValueError("logger field is required in the config dictionary")
+        sae_config_class = get_sae_training_class(cfg_dict["sae"]["architecture"])[1]
+        sae_cfg = sae_config_class.from_dict(cfg_dict["sae"])
+        logger_cfg = LoggingConfig(**cfg_dict["logger"])
+        updated_cfg_dict: dict[str, Any] = {
+            **cfg_dict,
+            "sae": sae_cfg,
+            "logger": logger_cfg,
+        }
+        output = cls(**updated_cfg_dict)
+        # the post_init always appends to checkpoint path, so we need to set it explicitly here.
+        if "checkpoint_path" in cfg_dict:
+            output.checkpoint_path = cfg_dict["checkpoint_path"]
+        return output
+
     def to_sae_trainer_config(self) -> "SAETrainerConfig":
         return SAETrainerConfig(
             n_checkpoints=self.n_checkpoints,

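Taken together with `RUNNER_CFG_FILENAME = "runner_cfg.json"` (see `constants.py` below), the new `from_dict` lets a saved runner config be reloaded and pointed at a checkpoint for resumption. A minimal sketch; the checkpoint path is illustrative:

    import json

    from sae_lens.config import LanguageModelSAERunnerConfig

    # runner_cfg.json is written by the training runner when it saves a checkpoint.
    with open("checkpoints/my_run/runner_cfg.json") as f:
        cfg_dict = json.load(f)

    cfg = LanguageModelSAERunnerConfig.from_dict(cfg_dict)
    cfg.resume_from_checkpoint = "checkpoints/my_run"  # new field in this release
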
sae_lens/constants.py
CHANGED
@@ -17,5 +17,6 @@ SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
 SAE_CFG_FILENAME = "cfg.json"
 RUNNER_CFG_FILENAME = "runner_cfg.json"
 SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+TRAINER_STATE_FILENAME = "trainer_state.pt"
 ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
 ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"

sae_lens/evals.py
CHANGED
@@ -11,7 +11,7 @@ from dataclasses import dataclass, field
 from functools import partial
 from importlib.metadata import PackageNotFoundError, version
 from pathlib import Path
-from typing import Any
+from typing import Any, Iterable
 
 import einops
 import pandas as pd
@@ -24,7 +24,10 @@ from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_direc
 from sae_lens.saes.sae import SAE, SAEConfig
 from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
-from sae_lens.util import
+from sae_lens.util import (
+    extract_stop_at_layer_from_tlens_hook_name,
+    get_special_token_ids,
+)
 
 
 def get_library_version() -> str:
@@ -109,9 +112,15 @@ def run_evals(
     activation_scaler: ActivationScaler,
     eval_config: EvalConfig = EvalConfig(),
     model_kwargs: Mapping[str, Any] = {},
-
+    exclude_special_tokens: Iterable[int] | bool = True,
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
+    ignore_tokens = None
+    if exclude_special_tokens is True:
+        ignore_tokens = list(get_special_token_ids(model.tokenizer))  # type: ignore
+    elif exclude_special_tokens:
+        ignore_tokens = list(exclude_special_tokens)
+
     hook_name = sae.cfg.metadata.hook_name
     actual_batch_size = (
         eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
@@ -312,7 +321,7 @@ def get_downstream_reconstruction_metrics(
     compute_ce_loss: bool,
     n_batches: int,
     eval_batch_size_prompts: int,
-    ignore_tokens:
+    ignore_tokens: list[int] | None = None,
     verbose: bool = False,
 ):
     metrics_dict = {}
@@ -339,7 +348,7 @@ def get_downstream_reconstruction_metrics(
             compute_ce_loss=compute_ce_loss,
             ignore_tokens=ignore_tokens,
         ).items():
-            if
+            if ignore_tokens:
                 mask = torch.logical_not(
                     torch.any(
                         torch.stack(
@@ -384,7 +393,7 @@ def get_sparsity_and_variance_metrics(
     compute_featurewise_density_statistics: bool,
     eval_batch_size_prompts: int,
     model_kwargs: Mapping[str, Any],
-    ignore_tokens:
+    ignore_tokens: list[int] | None = None,
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     hook_name = sae.cfg.metadata.hook_name
@@ -426,7 +435,7 @@ def get_sparsity_and_variance_metrics(
     for _ in batch_iter:
         batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
 
-        if
+        if ignore_tokens:
             mask = torch.logical_not(
                 torch.any(
                     torch.stack(
@@ -466,6 +475,8 @@ def get_sparsity_and_variance_metrics(
         sae_out_scaled = sae.decode(sae_feature_activations).to(
             original_act_scaled.device
         )
+        if sae_feature_activations.is_sparse:
+            sae_feature_activations = sae_feature_activations.to_dense()
         del cache
 
         sae_out = activation_scaler.unscale(sae_out_scaled)
@@ -594,7 +605,7 @@ def get_recons_loss(
     batch_tokens: torch.Tensor,
     compute_kl: bool,
     compute_ce_loss: bool,
-    ignore_tokens:
+    ignore_tokens: list[int] | None = None,
     model_kwargs: Mapping[str, Any] = {},
     hook_name: str | None = None,
 ) -> dict[str, Any]:
@@ -608,7 +619,7 @@ def get_recons_loss(
         batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
     )
 
-    if
+    if ignore_tokens:
         mask = torch.logical_not(
             torch.any(
                 torch.stack([batch_tokens == token for token in ignore_tokens], dim=0),
@@ -854,11 +865,6 @@ def multiple_evals(
             activation_scaler=ActivationScaler(),
             model=current_model,
             eval_config=eval_config,
-            ignore_tokens={
-                current_model.tokenizer.pad_token_id,  # type: ignore
-                current_model.tokenizer.eos_token_id,  # type: ignore
-                current_model.tokenizer.bos_token_id,  # type: ignore
-            },
             verbose=verbose,
         )
         eval_metrics["metrics"] = scalar_metrics

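The `exclude_special_tokens` argument replaces the explicit `ignore_tokens` set that `multiple_evals` used to pass. A sketch of a call with the new signature, assuming `sae`, `model`, and `activation_store` are already constructed (the leading parameters are inferred from the surrounding code rather than shown in this hunk):

    from sae_lens.evals import EvalConfig, run_evals
    from sae_lens.training.activation_scaler import ActivationScaler

    scalar_metrics, feature_metrics = run_evals(
        sae=sae,
        activation_store=activation_store,
        model=model,
        activation_scaler=ActivationScaler(),
        eval_config=EvalConfig(),
        exclude_special_tokens=True,  # True: ignore the tokenizer's special token ids
    )
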
sae_lens/llm_sae_training_runner.py
CHANGED
@@ -16,23 +16,18 @@ from typing_extensions import deprecated
 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
 from sae_lens.constants import (
-    ACTIVATIONS_STORE_STATE_FILENAME,
     RUNNER_CFG_FILENAME,
     SPARSITY_FILENAME,
 )
 from sae_lens.evals import EvalConfig, run_evals
 from sae_lens.load_model import load_model
-from sae_lens.
-from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
-from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
+from sae_lens.registry import SAE_TRAINING_CLASS_REGISTRY
 from sae_lens.saes.sae import (
     T_TRAINING_SAE,
     T_TRAINING_SAE_CONFIG,
     TrainingSAE,
     TrainingSAEConfig,
 )
-from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
-from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
 from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.sae_trainer import SAETrainer
@@ -61,9 +56,11 @@ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
         data_provider: DataProvider,
         activation_scaler: ActivationScaler,
     ) -> dict[str, Any]:
-
+        exclude_special_tokens = False
         if self.activations_store.exclude_special_tokens is not None:
-
+            exclude_special_tokens = (
+                self.activations_store.exclude_special_tokens.tolist()
+            )
 
         eval_config = EvalConfig(
             batch_size_prompts=self.eval_batch_size_prompts,
@@ -81,7 +78,7 @@ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
             model=self.model,
             activation_scaler=activation_scaler,
             eval_config=eval_config,
-
+            exclude_special_tokens=exclude_special_tokens,
             model_kwargs=self.model_kwargs,
         )  # not calculating featurwise metrics here.
 
@@ -114,6 +111,7 @@ class LanguageModelSAETrainingRunner:
         override_dataset: HfDataset | None = None,
         override_model: HookedRootModule | None = None,
         override_sae: TrainingSAE[Any] | None = None,
+        resume_from_checkpoint: Path | str | None = None,
     ):
         if override_dataset is not None:
             logger.warning(
@@ -155,6 +153,7 @@ class LanguageModelSAETrainingRunner:
             )
         else:
             self.sae = override_sae
+
         self.sae.to(self.cfg.device)
 
     def run(self):
@@ -187,6 +186,12 @@ class LanguageModelSAETrainingRunner:
             cfg=self.cfg.to_sae_trainer_config(),
         )
 
+        if self.cfg.resume_from_checkpoint is not None:
+            logger.info(f"Resuming from checkpoint: {self.cfg.resume_from_checkpoint}")
+            trainer.load_trainer_state(self.cfg.resume_from_checkpoint)
+            self.sae.load_weights_from_checkpoint(self.cfg.resume_from_checkpoint)
+            self.activations_store.load_from_checkpoint(self.cfg.resume_from_checkpoint)
+
         self._compile_if_needed()
         sae = self.run_trainer_with_interruption_handling(trainer)
 
@@ -306,9 +311,7 @@ class LanguageModelSAETrainingRunner:
         if checkpoint_path is None:
             return
 
-        self.activations_store.
-            str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
-        )
+        self.activations_store.save_to_checkpoint(checkpoint_path)
 
         runner_config = self.cfg.to_dict()
         with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
@@ -393,12 +396,8 @@ def _parse_cfg_args(
     )
 
     # Map architecture to concrete config class
-    sae_config_map = {
-
-        "gated": GatedTrainingSAEConfig,
-        "jumprelu": JumpReLUTrainingSAEConfig,
-        "topk": TopKTrainingSAEConfig,
-        "batchtopk": BatchTopKTrainingSAEConfig,
+    sae_config_map: dict[str, type[TrainingSAEConfig]] = {
+        name: cfg for name, (_, cfg) in SAE_TRAINING_CLASS_REGISTRY.items()
     }
 
     sae_config_type = sae_config_map[architecture]

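Resumption is driven by the config field checked in `run()` above. A sketch, assuming an interrupted run left a checkpoint directory behind; the path and the constructor arguments beyond those shown in this diff are assumptions:

    from sae_lens.llm_sae_training_runner import LanguageModelSAETrainingRunner

    cfg.resume_from_checkpoint = "checkpoints/my_run/final_10000"  # illustrative path
    runner = LanguageModelSAETrainingRunner(cfg)
    sae = runner.run()  # restores trainer state, SAE weights, and the activations store
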
sae_lens/loading/pretrained_sae_loaders.py
CHANGED
@@ -233,6 +233,12 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
         "reshape_activations",
         "hook_z" if "hook_z" in new_cfg.get("hook_name", "") else "none",
     )
+    if (
+        new_cfg.get("activation_fn") == "topk"
+        and new_cfg.get("activation_fn_kwargs", {}).get("k") is not None
+    ):
+        new_cfg["architecture"] = "topk"
+        new_cfg["k"] = new_cfg["activation_fn_kwargs"]["k"]
 
     if "normalize_activations" in new_cfg and isinstance(
         new_cfg["normalize_activations"], bool
@@ -517,6 +523,82 @@ def gemma_2_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity
 
 
+def get_goodfire_config_from_hf(
+    repo_id: str,
+    folder_name: str,  # noqa: ARG001
+    device: str,
+    force_download: bool = False,  # noqa: ARG001
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    cfg_dict = None
+    if repo_id == "Goodfire/Llama-3.3-70B-Instruct-SAE-l50":
+        if folder_name != "Llama-3.3-70B-Instruct-SAE-l50.pt":
+            raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+        cfg_dict = {
+            "architecture": "standard",
+            "d_in": 8192,
+            "d_sae": 65536,
+            "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+            "hook_name": "blocks.50.hook_resid_post",
+            "hook_head_index": None,
+            "dataset_path": "lmsys/lmsys-chat-1m",
+            "apply_b_dec_to_input": False,
+        }
+    elif repo_id == "Goodfire/Llama-3.1-8B-Instruct-SAE-l19":
+        if folder_name != "Llama-3.1-8B-Instruct-SAE-l19.pth":
+            raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+        cfg_dict = {
+            "architecture": "standard",
+            "d_in": 4096,
+            "d_sae": 65536,
+            "model_name": "meta-llama/Llama-3.1-8B-Instruct",
+            "hook_name": "blocks.19.hook_resid_post",
+            "hook_head_index": None,
+            "dataset_path": "lmsys/lmsys-chat-1m",
+            "apply_b_dec_to_input": False,
+        }
+    if cfg_dict is None:
+        raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+    if device is not None:
+        cfg_dict["device"] = device
+    if cfg_overrides is not None:
+        cfg_dict.update(cfg_overrides)
+    return cfg_dict
+
+
+def get_goodfire_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    cfg_dict = get_goodfire_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
+    )
+
+    # Download the SAE weights
+    sae_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=folder_name,
+        force_download=force_download,
+    )
+    raw_state_dict = torch.load(sae_path, map_location=device)
+
+    state_dict = {
+        "W_enc": raw_state_dict["encoder_linear.weight"].T,
+        "W_dec": raw_state_dict["decoder_linear.weight"].T,
+        "b_enc": raw_state_dict["encoder_linear.bias"],
+        "b_dec": raw_state_dict["decoder_linear.bias"],
+    }
+
+    return cfg_dict, state_dict, None
+
+
 def get_llama_scope_config_from_hf(
     repo_id: str,
     folder_name: str,
@@ -1469,6 +1551,114 @@ def get_mntss_clt_layer_config_from_hf(
     }
 
 
+def get_temporal_sae_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Get TemporalSAE config without loading weights."""
+    # Download config file
+    conf_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=f"{folder_name}/conf.yaml",
+        force_download=force_download,
+    )
+
+    # Load and parse config
+    with open(conf_path) as f:
+        yaml_config = yaml.safe_load(f)
+
+    # Extract parameters
+    d_in = yaml_config["llm"]["dimin"]
+    exp_factor = yaml_config["sae"]["exp_factor"]
+    d_sae = int(d_in * exp_factor)
+
+    # extract layer from folder_name eg : "layer_12/temporal"
+    layer = re.search(r"layer_(\d+)", folder_name)
+    if layer is None:
+        raise ValueError(f"Could not find layer in folder_name: {folder_name}")
+    layer = int(layer.group(1))
+
+    # Build config dict
+    cfg_dict = {
+        "architecture": "temporal",
+        "hook_name": f"blocks.{layer}.hook_resid_post",
+        "d_in": d_in,
+        "d_sae": d_sae,
+        "n_heads": yaml_config["sae"]["n_heads"],
+        "n_attn_layers": yaml_config["sae"]["n_attn_layers"],
+        "bottleneck_factor": yaml_config["sae"]["bottleneck_factor"],
+        "sae_diff_type": yaml_config["sae"]["sae_diff_type"],
+        "kval_topk": yaml_config["sae"]["kval_topk"],
+        "tied_weights": yaml_config["sae"]["tied_weights"],
+        "dtype": yaml_config["data"]["dtype"],
+        "device": device,
+        "normalize_activations": "constant_scalar_rescale",
+        "activation_normalization_factor": yaml_config["sae"]["scaling_factor"],
+        "apply_b_dec_to_input": True,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def temporal_sae_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """
+    Load TemporalSAE from canrager/temporalSAEs format (safetensors version).
+
+    Expects folder_name to contain:
+    - conf.yaml (configuration)
+    - latest_ckpt.safetensors (model weights)
+    """
+
+    cfg_dict = get_temporal_sae_config_from_hf(
+        repo_id=repo_id,
+        folder_name=folder_name,
+        device=device,
+        force_download=force_download,
+        cfg_overrides=cfg_overrides,
+    )
+
+    # Download checkpoint (safetensors format)
+    ckpt_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=f"{folder_name}/latest_ckpt.safetensors",
+        force_download=force_download,
+    )
+
+    # Load checkpoint from safetensors
+    state_dict_raw = load_file(ckpt_path, device=device)
+
+    # Convert to SAELens naming convention
+    # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
+    state_dict = {}
+
+    # Copy attention layers as-is
+    for key, value in state_dict_raw.items():
+        if key.startswith("attn_layers."):
+            state_dict[key] = value.to(device)
+
+    # Main parameters
+    state_dict["W_dec"] = state_dict_raw["D"].to(device)
+    state_dict["b_dec"] = state_dict_raw["b"].to(device)
+
+    # Handle tied/untied weights
+    if "E" in state_dict_raw:
+        state_dict["W_enc"] = state_dict_raw["E"].to(device)
+
+    return cfg_dict, state_dict, None
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1481,6 +1671,8 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
     "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
     "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
+    "temporal": temporal_sae_huggingface_loader,
+    "goodfire": get_goodfire_huggingface_loader,
 }
 
 
@@ -1496,4 +1688,6 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
     "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
     "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
+    "temporal": get_temporal_sae_config_from_hf,
+    "goodfire": get_goodfire_config_from_hf,
 }

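Once matching entries exist in `pretrained_saes.yaml`, the new "temporal" and "goodfire" loaders are reached through the ordinary `SAE.from_pretrained` path. A sketch with a hypothetical release and SAE id (the real identifiers live in the YAML directory, not in this diff, and the exact return shape of `from_pretrained` is not shown here):

    from sae_lens import SAE

    sae = SAE.from_pretrained(
        release="goodfire-llama-3.1-8b-instruct",  # hypothetical release name
        sae_id="blocks.19.hook_resid_post",        # hypothetical SAE id
        device="cpu",
    )
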
sae_lens/loading/pretrained_saes_directory.py
CHANGED
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from functools import cache
-from importlib import
+from importlib.resources import files
 from typing import Any
 
 import yaml
@@ -24,7 +24,8 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
     package = "sae_lens"
     # Access the file within the package using importlib.resources
     directory: dict[str, PretrainedSAELookup] = {}
-
+    yaml_file = files(package).joinpath("pretrained_saes.yaml")
+    with yaml_file.open("r") as file:
         # Load the YAML file content
         data = yaml.safe_load(file)
         for release, value in data.items():
@@ -68,7 +69,8 @@ def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
         float | None: The norm_scaling_factor if it exists, None otherwise.
     """
     package = "sae_lens"
-
+    yaml_file = files(package).joinpath("pretrained_saes.yaml")
+    with yaml_file.open("r") as file:
         data = yaml.safe_load(file)
         if release in data:
            for sae_info in data[release]["saes"]:

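The directory helpers keep their public behaviour; only the YAML loading moved to `importlib.resources.files`. Usage stays as before, e.g.:

    from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory

    directory = get_pretrained_saes_directory()  # dict[str, PretrainedSAELookup]
    print(f"{len(directory)} SAE releases listed in pretrained_saes.yaml")
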
sae_lens/pretokenize_runner.py
CHANGED
@@ -1,9 +1,10 @@
 import io
 import json
 import sys
+from collections.abc import Iterator
 from dataclasses import dataclass
 from pathlib import Path
-from typing import
+from typing import Literal, cast
 
 import torch
 from datasets import Dataset, DatasetDict, load_dataset