sae-lens 6.15.0__py3-none-any.whl → 6.24.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +13 -1
- sae_lens/analysis/hooked_sae_transformer.py +4 -13
- sae_lens/cache_activations_runner.py +3 -4
- sae_lens/config.py +39 -2
- sae_lens/constants.py +1 -0
- sae_lens/llm_sae_training_runner.py +9 -4
- sae_lens/loading/pretrained_sae_loaders.py +430 -24
- sae_lens/loading/pretrained_saes_directory.py +5 -3
- sae_lens/pretokenize_runner.py +3 -3
- sae_lens/pretrained_saes.yaml +26977 -65
- sae_lens/saes/__init__.py +7 -0
- sae_lens/saes/batchtopk_sae.py +3 -1
- sae_lens/saes/gated_sae.py +6 -11
- sae_lens/saes/jumprelu_sae.py +8 -13
- sae_lens/saes/matryoshka_batchtopk_sae.py +8 -15
- sae_lens/saes/sae.py +20 -32
- sae_lens/saes/standard_sae.py +4 -9
- sae_lens/saes/temporal_sae.py +365 -0
- sae_lens/saes/topk_sae.py +8 -11
- sae_lens/saes/transcoder.py +41 -0
- sae_lens/training/activation_scaler.py +7 -0
- sae_lens/training/activations_store.py +54 -12
- sae_lens/training/optim.py +11 -0
- sae_lens/training/sae_trainer.py +50 -11
- {sae_lens-6.15.0.dist-info → sae_lens-6.24.1.dist-info}/METADATA +16 -16
- sae_lens-6.24.1.dist-info/RECORD +41 -0
- sae_lens-6.15.0.dist-info/RECORD +0 -40
- {sae_lens-6.15.0.dist-info → sae_lens-6.24.1.dist-info}/WHEEL +0 -0
- {sae_lens-6.15.0.dist-info → sae_lens-6.24.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.15.0"
+__version__ = "6.24.1"
 
 import logging
 
@@ -15,6 +15,8 @@ from sae_lens.saes import (
     GatedTrainingSAEConfig,
     JumpReLUSAE,
     JumpReLUSAEConfig,
+    JumpReLUSkipTranscoder,
+    JumpReLUSkipTranscoderConfig,
     JumpReLUTrainingSAE,
     JumpReLUTrainingSAEConfig,
     JumpReLUTranscoder,
@@ -28,6 +30,8 @@ from sae_lens.saes import (
     StandardSAEConfig,
     StandardTrainingSAE,
     StandardTrainingSAEConfig,
+    TemporalSAE,
+    TemporalSAEConfig,
     TopKSAE,
     TopKSAEConfig,
     TopKTrainingSAE,
@@ -103,8 +107,12 @@ __all__ = [
     "SkipTranscoderConfig",
     "JumpReLUTranscoder",
     "JumpReLUTranscoderConfig",
+    "JumpReLUSkipTranscoder",
+    "JumpReLUSkipTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
+    "TemporalSAE",
+    "TemporalSAEConfig",
 ]
 
 
@@ -127,3 +135,7 @@ register_sae_training_class(
 register_sae_class("transcoder", Transcoder, TranscoderConfig)
 register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
 register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+register_sae_class(
+    "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
+)
+register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
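Two architectures are new in this range: JumpReLUSkipTranscoder and TemporalSAE, both exported at the package root and registered under the keys "jumprelu_skip_transcoder" and "temporal". A minimal illustrative sketch of the new surface, using only names that appear in the diff above:

    from sae_lens import (
        JumpReLUSkipTranscoder,
        JumpReLUSkipTranscoderConfig,
        TemporalSAE,
        TemporalSAEConfig,
    )

    # Because of the register_sae_class calls above, saved configs whose
    # "architecture" field is "temporal" or "jumprelu_skip_transcoder" should
    # now resolve to these classes when SAEs are loaded from disk.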
sae_lens/analysis/hooked_sae_transformer.py
CHANGED
@@ -3,7 +3,6 @@ from contextlib import contextmanager
 from typing import Any, Callable
 
 import torch
-from jaxtyping import Float
 from transformer_lens.ActivationCache import ActivationCache
 from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
 from transformer_lens.hook_points import HookPoint  # Hooking utilities
@@ -11,8 +10,8 @@ from transformer_lens.HookedTransformer import HookedTransformer
 
 from sae_lens.saes.sae import SAE
 
-SingleLoss = Float[torch.Tensor, ""]  # Type alias for a single element tensor
-LossPerToken = Float[torch.Tensor, "batch pos-1"]
+SingleLoss = torch.Tensor  # Type alias for a single element tensor
+LossPerToken = torch.Tensor
 Loss = SingleLoss | LossPerToken
 
 
@@ -171,12 +170,7 @@ class HookedSAETransformer(HookedTransformer):
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
         **model_kwargs: Any,
-    ) -> (
-        None
-        | Float[torch.Tensor, "batch pos d_vocab"]
-        | Loss
-        | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]
-    ):
+    ) -> None | torch.Tensor | Loss | tuple[torch.Tensor, Loss]:
         """Wrapper around HookedTransformer forward pass.
 
         Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.
@@ -203,10 +197,7 @@ class HookedSAETransformer(HookedTransformer):
         remove_batch_dim: bool = False,
         **kwargs: Any,
     ) -> tuple[
-        None
-        | Float[torch.Tensor, "batch pos d_vocab"]
-        | Loss
-        | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
+        None | torch.Tensor | Loss | tuple[torch.Tensor, Loss],
         ActivationCache | dict[str, torch.Tensor],
     ]:
         """Wrapper around 'run_with_cache' in HookedTransformer.
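The signatures above belong to HookedSAETransformer's wrappers that splice SAEs into a single forward pass; only the jaxtyping shape annotations were dropped, the behaviour is unchanged. A hedged usage sketch, not part of the diff: the pretrained release/ID and the assumption that SAE.from_pretrained returns the SAE directly in 6.x are illustrative.

    from sae_lens import SAE, HookedSAETransformer

    model = HookedSAETransformer.from_pretrained("gpt2")  # standard HookedTransformer loading
    # Release/ID below come from the public SAE directory, not from this diff.
    sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.0.hook_resid_pre", device="cpu")

    tokens = model.to_tokens("Hello world")
    # Attaches the SAE for one forward pass; reset_saes_end=True (the default shown
    # above) restores the model afterwards.
    logits = model.run_with_saes(tokens, saes=[sae])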
sae_lens/cache_activations_runner.py
CHANGED
@@ -9,8 +9,7 @@ import torch
 from datasets import Array2D, Dataset, Features, Sequence, Value
 from datasets.fingerprint import generate_fingerprint
 from huggingface_hub import HfApi
-from jaxtyping import Float, Int
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from transformer_lens.HookedTransformer import HookedRootModule
 
 from sae_lens import logger
@@ -318,8 +317,8 @@ class CacheActivationsRunner:
     def _create_shard(
         self,
         buffer: tuple[
-            Float[torch.Tensor, "(bs context_size) d_in"],
-            Int[torch.Tensor, "(bs context_size)"] | None,
+            torch.Tensor,  # shape: (bs context_size) d_in
+            torch.Tensor | None,  # shape: (bs context_size) or None
         ],
     ) -> Dataset:
         hook_names = [self.cfg.hook_name]
sae_lens/config.py
CHANGED
@@ -18,6 +18,7 @@ from datasets import (
 
 from sae_lens import __version__, logger
 from sae_lens.constants import DTYPE_MAP
+from sae_lens.registry import get_sae_training_class
 from sae_lens.saes.sae import TrainingSAEConfig
 
 if TYPE_CHECKING:
@@ -171,6 +172,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
         checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
         save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
+        resume_from_checkpoint (str | None): The path to the checkpoint to resume training from. (default is None).
         output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
         verbose (bool): Whether to print verbose output. (default is True)
         model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
@@ -261,6 +263,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     checkpoint_path: str | None = "checkpoints"
     save_final_checkpoint: bool = False
     output_path: str | None = "output"
+    resume_from_checkpoint: str | None = None
 
     # Misc
     verbose: bool = True
@@ -385,8 +388,11 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         return self.sae.to_dict()
 
     def to_dict(self) -> dict[str, Any]:
-
-
+        """
+        Convert the config to a dictionary.
+        """
+
+        d = asdict(self)
 
         d["logger"] = asdict(self.logger)
         d["sae"] = self.sae.to_dict()
@@ -396,6 +402,37 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         d["act_store_device"] = str(self.act_store_device)
         return d
 
+    @classmethod
+    def from_dict(cls, cfg_dict: dict[str, Any]) -> "LanguageModelSAERunnerConfig[Any]":
+        """
+        Load a LanguageModelSAERunnerConfig from a dictionary given by `to_dict`.
+
+        Args:
+            cfg_dict (dict[str, Any]): The dictionary to load the config from.
+
+        Returns:
+            LanguageModelSAERunnerConfig: The loaded config.
+        """
+        if "sae" not in cfg_dict:
+            raise ValueError("sae field is required in the config dictionary")
+        if "architecture" not in cfg_dict["sae"]:
+            raise ValueError("architecture field is required in the sae dictionary")
+        if "logger" not in cfg_dict:
+            raise ValueError("logger field is required in the config dictionary")
+        sae_config_class = get_sae_training_class(cfg_dict["sae"]["architecture"])[1]
+        sae_cfg = sae_config_class.from_dict(cfg_dict["sae"])
+        logger_cfg = LoggingConfig(**cfg_dict["logger"])
+        updated_cfg_dict: dict[str, Any] = {
+            **cfg_dict,
+            "sae": sae_cfg,
+            "logger": logger_cfg,
+        }
+        output = cls(**updated_cfg_dict)
+        # the post_init always appends to checkpoint path, so we need to set it explicitly here.
+        if "checkpoint_path" in cfg_dict:
+            output.checkpoint_path = cfg_dict["checkpoint_path"]
+        return output
+
     def to_sae_trainer_config(self) -> "SAETrainerConfig":
         return SAETrainerConfig(
             n_checkpoints=self.n_checkpoints,
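The new from_dict classmethod is the inverse of to_dict, so runner configs (including the runner_cfg.json written at checkpoints) can be rebuilt from plain dictionaries. A minimal sketch; the SAE field values are illustrative, not defaults taken from this diff:

    from sae_lens import LanguageModelSAERunnerConfig, StandardTrainingSAEConfig

    cfg = LanguageModelSAERunnerConfig(
        sae=StandardTrainingSAEConfig(d_in=512, d_sae=4096),  # illustrative sizes
    )

    cfg_dict = cfg.to_dict()  # the "sae" entry must carry an "architecture" key; "logger" is a plain dict
    # from_dict uses the registry to pick the right TrainingSAEConfig subclass and
    # restores checkpoint_path verbatim (post_init would otherwise append a unique ID).
    restored = LanguageModelSAERunnerConfig.from_dict(cfg_dict)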
sae_lens/constants.py
CHANGED
@@ -17,5 +17,6 @@ SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
 SAE_CFG_FILENAME = "cfg.json"
 RUNNER_CFG_FILENAME = "runner_cfg.json"
 SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+TRAINER_STATE_FILENAME = "trainer_state.pt"
 ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
 ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
sae_lens/llm_sae_training_runner.py
CHANGED
@@ -16,7 +16,6 @@ from typing_extensions import deprecated
 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
 from sae_lens.constants import (
-    ACTIVATIONS_STORE_STATE_FILENAME,
     RUNNER_CFG_FILENAME,
     SPARSITY_FILENAME,
 )
@@ -112,6 +111,7 @@ class LanguageModelSAETrainingRunner:
         override_dataset: HfDataset | None = None,
         override_model: HookedRootModule | None = None,
         override_sae: TrainingSAE[Any] | None = None,
+        resume_from_checkpoint: Path | str | None = None,
     ):
         if override_dataset is not None:
             logger.warning(
@@ -153,6 +153,7 @@ class LanguageModelSAETrainingRunner:
             )
         else:
             self.sae = override_sae
+
         self.sae.to(self.cfg.device)
 
     def run(self):
@@ -185,6 +186,12 @@ class LanguageModelSAETrainingRunner:
             cfg=self.cfg.to_sae_trainer_config(),
         )
 
+        if self.cfg.resume_from_checkpoint is not None:
+            logger.info(f"Resuming from checkpoint: {self.cfg.resume_from_checkpoint}")
+            trainer.load_trainer_state(self.cfg.resume_from_checkpoint)
+            self.sae.load_weights_from_checkpoint(self.cfg.resume_from_checkpoint)
+            self.activations_store.load_from_checkpoint(self.cfg.resume_from_checkpoint)
+
         self._compile_if_needed()
         sae = self.run_trainer_with_interruption_handling(trainer)
 
@@ -304,9 +311,7 @@ class LanguageModelSAETrainingRunner:
         if checkpoint_path is None:
            return
 
-        self.activations_store.save(
-            str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
-        )
+        self.activations_store.save_to_checkpoint(checkpoint_path)
 
        runner_config = self.cfg.to_dict()
        with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
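Combined with the new resume_from_checkpoint field and the save_to_checkpoint/load_from_checkpoint pairs above, an interrupted run can be restarted from its checkpoint directory. A hedged sketch; the model and SAE values are placeholders, and the checkpoint path is whatever an earlier run wrote under its checkpoint_path:

    from sae_lens import LanguageModelSAERunnerConfig, StandardTrainingSAEConfig
    from sae_lens.llm_sae_training_runner import LanguageModelSAETrainingRunner

    cfg = LanguageModelSAERunnerConfig(
        model_name="gpt2",                                     # placeholder
        sae=StandardTrainingSAEConfig(d_in=768, d_sae=8192),   # placeholder sizes
        # Directory produced by a previous run; it holds trainer_state.pt,
        # the SAE weights, the activations-store state, and runner_cfg.json.
        resume_from_checkpoint="checkpoints/<run-id>/<step>",
    )

    runner = LanguageModelSAETrainingRunner(cfg)
    sae = runner.run()  # trainer, SAE, and store state are restored before training continues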