sae-lens 6.16.3__tar.gz → 6.22.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. {sae_lens-6.16.3 → sae_lens-6.22.0}/PKG-INFO +16 -16
  2. {sae_lens-6.16.3 → sae_lens-6.22.0}/README.md +13 -13
  3. {sae_lens-6.16.3 → sae_lens-6.22.0}/pyproject.toml +3 -3
  4. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/__init__.py +6 -1
  5. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/cache_activations_runner.py +1 -1
  6. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/config.py +39 -2
  7. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/constants.py +1 -0
  8. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/llm_sae_training_runner.py +9 -4
  9. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/loading/pretrained_sae_loaders.py +188 -0
  10. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/loading/pretrained_saes_directory.py +5 -3
  11. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/pretrained_saes.yaml +77 -1
  12. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/saes/__init__.py +3 -0
  13. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/saes/sae.py +11 -13
  14. sae_lens-6.22.0/sae_lens/saes/temporal_sae.py +372 -0
  15. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/training/activation_scaler.py +7 -0
  16. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/training/activations_store.py +47 -4
  17. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/training/optim.py +11 -0
  18. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/training/sae_trainer.py +49 -11
  19. {sae_lens-6.16.3 → sae_lens-6.22.0}/LICENSE +0 -0
  20. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/analysis/__init__.py +0 -0
  21. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  22. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  23. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/evals.py +0 -0
  24. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/load_model.py +0 -0
  25. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/loading/__init__.py +0 -0
  26. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/pretokenize_runner.py +0 -0
  27. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/registry.py +0 -0
  28. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  29. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/saes/gated_sae.py +0 -0
  30. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/saes/jumprelu_sae.py +0 -0
  31. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
  32. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/saes/standard_sae.py +0 -0
  33. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/saes/topk_sae.py +0 -0
  34. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/saes/transcoder.py +0 -0
  35. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/tokenization_and_batching.py +0 -0
  36. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/training/__init__.py +0 -0
  37. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/training/mixing_buffer.py +0 -0
  38. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/training/types.py +0 -0
  39. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  40. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/tutorial/tsea.py +0 -0
  41. {sae_lens-6.16.3 → sae_lens-6.22.0}/sae_lens/util.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.16.3
3
+ Version: 6.22.0
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -30,19 +30,19 @@ Requires-Dist: tenacity (>=9.0.0)
30
30
  Requires-Dist: transformer-lens (>=2.16.1,<3.0.0)
31
31
  Requires-Dist: transformers (>=4.38.1,<5.0.0)
32
32
  Requires-Dist: typing-extensions (>=4.10.0,<5.0.0)
33
- Project-URL: Homepage, https://jbloomaus.github.io/SAELens
34
- Project-URL: Repository, https://github.com/jbloomAus/SAELens
33
+ Project-URL: Homepage, https://decoderesearch.github.io/SAELens
34
+ Project-URL: Repository, https://github.com/decoderesearch/SAELens
35
35
  Description-Content-Type: text/markdown
36
36
 
37
- <img width="1308" alt="Screenshot 2024-03-21 at 3 08 28 pm" src="https://github.com/jbloomAus/mats_sae_training/assets/69127271/209012ec-a779-4036-b4be-7b7739ea87f6">
37
+ <img width="1308" height="532" alt="saes_pic" src="https://github.com/user-attachments/assets/2a5d752f-b261-4ee4-ad5d-ebf282321371" />
38
38
 
39
39
  # SAE Lens
40
40
 
41
41
  [![PyPI](https://img.shields.io/pypi/v/sae-lens?color=blue)](https://pypi.org/project/sae-lens/)
42
42
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
43
- [![build](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml)
44
- [![Deploy Docs](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml)
45
- [![codecov](https://codecov.io/gh/jbloomAus/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/jbloomAus/SAELens)
43
+ [![build](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml)
44
+ [![Deploy Docs](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml)
45
+ [![codecov](https://codecov.io/gh/decoderesearch/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/decoderesearch/SAELens)
46
46
 
47
47
  SAELens exists to help researchers:
48
48
 
@@ -50,7 +50,7 @@ SAELens exists to help researchers:
50
50
  - Analyse sparse autoencoders / research mechanistic interpretability.
51
51
  - Generate insights which make it easier to create safe and aligned AI systems.
52
52
 
53
- Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for information on how to:
53
+ Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
54
54
 
55
55
  - Download and Analyse pre-trained sparse autoencoders.
56
56
  - Train your own sparse autoencoders.
@@ -58,25 +58,25 @@ Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for in
58
58
 
59
59
  SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to [safeguard humanity from risks posed by artificial intelligence](https://80000hours.org/problem-profiles/artificial-intelligence/).
60
60
 
61
- This library is maintained by [Joseph Bloom](https://www.jbloomaus.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
61
+ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
62
62
 
63
63
  ## Loading Pre-trained SAEs.
64
64
 
65
- Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://jbloomaus.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.
65
+ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/sae_table/) in the documentation for a list of all SAEs.
66
66
 
67
67
  ## Migrating to SAELens v6
68
68
 
69
- The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://jbloomaus.github.io/SAELens/latest/migrating/) for more details.
69
+ The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://decoderesearch.github.io/SAELens/latest/migrating/) for more details.
70
70
 
71
71
  ## Tutorials
72
72
 
73
- - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
73
+ - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
74
74
  - [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
75
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
75
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
76
76
  - [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
77
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
77
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
78
78
  - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
79
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
79
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
80
80
 
81
81
  ## Join the Slack!
82
82
 
@@ -91,7 +91,7 @@ Please cite the package as follows:
91
91
  title = {SAELens},
92
92
  author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
93
93
  year = {2024},
94
- howpublished = {\url{https://github.com/jbloomAus/SAELens}},
94
+ howpublished = {\url{https://github.com/decoderesearch/SAELens}},
95
95
  }
96
96
  ```
97
97
 
@@ -1,12 +1,12 @@
1
- <img width="1308" alt="Screenshot 2024-03-21 at 3 08 28 pm" src="https://github.com/jbloomAus/mats_sae_training/assets/69127271/209012ec-a779-4036-b4be-7b7739ea87f6">
1
+ <img width="1308" height="532" alt="saes_pic" src="https://github.com/user-attachments/assets/2a5d752f-b261-4ee4-ad5d-ebf282321371" />
2
2
 
3
3
  # SAE Lens
4
4
 
5
5
  [![PyPI](https://img.shields.io/pypi/v/sae-lens?color=blue)](https://pypi.org/project/sae-lens/)
6
6
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
7
- [![build](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml)
8
- [![Deploy Docs](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml)
9
- [![codecov](https://codecov.io/gh/jbloomAus/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/jbloomAus/SAELens)
7
+ [![build](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml)
8
+ [![Deploy Docs](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml)
9
+ [![codecov](https://codecov.io/gh/decoderesearch/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/decoderesearch/SAELens)
10
10
 
11
11
  SAELens exists to help researchers:
12
12
 
@@ -14,7 +14,7 @@ SAELens exists to help researchers:
14
14
  - Analyse sparse autoencoders / research mechanistic interpretability.
15
15
  - Generate insights which make it easier to create safe and aligned AI systems.
16
16
 
17
- Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for information on how to:
17
+ Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
18
18
 
19
19
  - Download and Analyse pre-trained sparse autoencoders.
20
20
  - Train your own sparse autoencoders.
@@ -22,25 +22,25 @@ Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for in
22
22
 
23
23
  SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to [safeguard humanity from risks posed by artificial intelligence](https://80000hours.org/problem-profiles/artificial-intelligence/).
24
24
 
25
- This library is maintained by [Joseph Bloom](https://www.jbloomaus.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
25
+ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
26
26
 
27
27
  ## Loading Pre-trained SAEs.
28
28
 
29
- Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://jbloomaus.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.
29
+ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/sae_table/) in the documentation for a list of all SAEs.
30
30
 
31
31
  ## Migrating to SAELens v6
32
32
 
33
- The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://jbloomaus.github.io/SAELens/latest/migrating/) for more details.
33
+ The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://decoderesearch.github.io/SAELens/latest/migrating/) for more details.
34
34
 
35
35
  ## Tutorials
36
36
 
37
- - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
37
+ - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
38
38
  - [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
39
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
39
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
40
40
  - [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
41
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
41
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
42
42
  - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
43
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
43
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
44
44
 
45
45
  ## Join the Slack!
46
46
 
@@ -55,6 +55,6 @@ Please cite the package as follows:
55
55
  title = {SAELens},
56
56
  author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
57
57
  year = {2024},
58
- howpublished = {\url{https://github.com/jbloomAus/SAELens}},
58
+ howpublished = {\url{https://github.com/decoderesearch/SAELens}},
59
59
  }
60
60
  ```
@@ -1,13 +1,13 @@
1
1
  [tool.poetry]
2
2
  name = "sae-lens"
3
- version = "6.16.3"
3
+ version = "6.22.0"
4
4
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
5
5
  authors = ["Joseph Bloom"]
6
6
  readme = "README.md"
7
7
  packages = [{ include = "sae_lens" }]
8
8
  include = ["pretrained_saes.yaml"]
9
- repository = "https://github.com/jbloomAus/SAELens"
10
- homepage = "https://jbloomaus.github.io/SAELens"
9
+ repository = "https://github.com/decoderesearch/SAELens"
10
+ homepage = "https://decoderesearch.github.io/SAELens"
11
11
  license = "MIT"
12
12
  keywords = [
13
13
  "deep-learning",
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.16.3"
2
+ __version__ = "6.22.0"
3
3
 
4
4
  import logging
5
5
 
@@ -28,6 +28,8 @@ from sae_lens.saes import (
28
28
  StandardSAEConfig,
29
29
  StandardTrainingSAE,
30
30
  StandardTrainingSAEConfig,
31
+ TemporalSAE,
32
+ TemporalSAEConfig,
31
33
  TopKSAE,
32
34
  TopKSAEConfig,
33
35
  TopKTrainingSAE,
@@ -105,6 +107,8 @@ __all__ = [
105
107
  "JumpReLUTranscoderConfig",
106
108
  "MatryoshkaBatchTopKTrainingSAE",
107
109
  "MatryoshkaBatchTopKTrainingSAEConfig",
110
+ "TemporalSAE",
111
+ "TemporalSAEConfig",
108
112
  ]
109
113
 
110
114
 
@@ -127,3 +131,4 @@ register_sae_training_class(
127
131
  register_sae_class("transcoder", Transcoder, TranscoderConfig)
128
132
  register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
129
133
  register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
134
+ register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
@@ -10,7 +10,7 @@ from datasets import Array2D, Dataset, Features, Sequence, Value
10
10
  from datasets.fingerprint import generate_fingerprint
11
11
  from huggingface_hub import HfApi
12
12
  from jaxtyping import Float, Int
13
- from tqdm import tqdm
13
+ from tqdm.auto import tqdm
14
14
  from transformer_lens.HookedTransformer import HookedRootModule
15
15
 
16
16
  from sae_lens import logger
@@ -18,6 +18,7 @@ from datasets import (
18
18
 
19
19
  from sae_lens import __version__, logger
20
20
  from sae_lens.constants import DTYPE_MAP
21
+ from sae_lens.registry import get_sae_training_class
21
22
  from sae_lens.saes.sae import TrainingSAEConfig
22
23
 
23
24
  if TYPE_CHECKING:
@@ -171,6 +172,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
171
172
  n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
172
173
  checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
173
174
  save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
175
+ resume_from_checkpoint (str | None): The path to the checkpoint to resume training from. (default is None).
174
176
  output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
175
177
  verbose (bool): Whether to print verbose output. (default is True)
176
178
  model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
@@ -261,6 +263,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
261
263
  checkpoint_path: str | None = "checkpoints"
262
264
  save_final_checkpoint: bool = False
263
265
  output_path: str | None = "output"
266
+ resume_from_checkpoint: str | None = None
264
267
 
265
268
  # Misc
266
269
  verbose: bool = True
@@ -385,8 +388,11 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
385
388
  return self.sae.to_dict()
386
389
 
387
390
  def to_dict(self) -> dict[str, Any]:
388
- # Make a shallow copy of config's dictionary
389
- d = dict(self.__dict__)
391
+ """
392
+ Convert the config to a dictionary.
393
+ """
394
+
395
+ d = asdict(self)
390
396
 
391
397
  d["logger"] = asdict(self.logger)
392
398
  d["sae"] = self.sae.to_dict()
@@ -396,6 +402,37 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
396
402
  d["act_store_device"] = str(self.act_store_device)
397
403
  return d
398
404
 
405
@classmethod
def from_dict(cls, cfg_dict: dict[str, Any]) -> "LanguageModelSAERunnerConfig[Any]":
    """
    Rebuild a LanguageModelSAERunnerConfig from a dictionary produced by `to_dict`.

    The nested "sae" and "logger" dictionaries are converted back into config
    objects before the runner config itself is constructed.

    Args:
        cfg_dict (dict[str, Any]): The serialized config. Must contain a "sae"
            entry (with an "architecture" key) and a "logger" entry.

    Returns:
        LanguageModelSAERunnerConfig: The reconstructed config.

    Raises:
        ValueError: If "sae", "sae.architecture" or "logger" is missing.
    """
    if "sae" not in cfg_dict:
        raise ValueError("sae field is required in the config dictionary")
    sae_dict = cfg_dict["sae"]
    if "architecture" not in sae_dict:
        raise ValueError("architecture field is required in the sae dictionary")
    if "logger" not in cfg_dict:
        raise ValueError("logger field is required in the config dictionary")

    # The registry entry is indexed at position 1 to obtain the config class
    # matching the stored architecture name.
    registry_entry = get_sae_training_class(sae_dict["architecture"])
    sae_config_class = registry_entry[1]

    init_kwargs: dict[str, Any] = dict(cfg_dict)
    init_kwargs["sae"] = sae_config_class.from_dict(sae_dict)
    init_kwargs["logger"] = LoggingConfig(**cfg_dict["logger"])
    restored = cls(**init_kwargs)

    # the post_init always appends to checkpoint path, so the stored value is
    # written back explicitly after construction.
    if "checkpoint_path" in cfg_dict:
        restored.checkpoint_path = cfg_dict["checkpoint_path"]
    return restored
435
+
399
436
  def to_sae_trainer_config(self) -> "SAETrainerConfig":
400
437
  return SAETrainerConfig(
401
438
  n_checkpoints=self.n_checkpoints,
@@ -17,5 +17,6 @@ SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
17
17
  SAE_CFG_FILENAME = "cfg.json"
18
18
  RUNNER_CFG_FILENAME = "runner_cfg.json"
19
19
  SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
20
+ TRAINER_STATE_FILENAME = "trainer_state.pt"
20
21
  ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
21
22
  ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
@@ -16,7 +16,6 @@ from typing_extensions import deprecated
16
16
  from sae_lens import logger
17
17
  from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
18
18
  from sae_lens.constants import (
19
- ACTIVATIONS_STORE_STATE_FILENAME,
20
19
  RUNNER_CFG_FILENAME,
21
20
  SPARSITY_FILENAME,
22
21
  )
@@ -112,6 +111,7 @@ class LanguageModelSAETrainingRunner:
112
111
  override_dataset: HfDataset | None = None,
113
112
  override_model: HookedRootModule | None = None,
114
113
  override_sae: TrainingSAE[Any] | None = None,
114
+ resume_from_checkpoint: Path | str | None = None,
115
115
  ):
116
116
  if override_dataset is not None:
117
117
  logger.warning(
@@ -153,6 +153,7 @@ class LanguageModelSAETrainingRunner:
153
153
  )
154
154
  else:
155
155
  self.sae = override_sae
156
+
156
157
  self.sae.to(self.cfg.device)
157
158
 
158
159
  def run(self):
@@ -185,6 +186,12 @@ class LanguageModelSAETrainingRunner:
185
186
  cfg=self.cfg.to_sae_trainer_config(),
186
187
  )
187
188
 
189
+ if self.cfg.resume_from_checkpoint is not None:
190
+ logger.info(f"Resuming from checkpoint: {self.cfg.resume_from_checkpoint}")
191
+ trainer.load_trainer_state(self.cfg.resume_from_checkpoint)
192
+ self.sae.load_weights_from_checkpoint(self.cfg.resume_from_checkpoint)
193
+ self.activations_store.load_from_checkpoint(self.cfg.resume_from_checkpoint)
194
+
188
195
  self._compile_if_needed()
189
196
  sae = self.run_trainer_with_interruption_handling(trainer)
190
197
 
@@ -304,9 +311,7 @@ class LanguageModelSAETrainingRunner:
304
311
  if checkpoint_path is None:
305
312
  return
306
313
 
307
- self.activations_store.save(
308
- str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
309
- )
314
+ self.activations_store.save_to_checkpoint(checkpoint_path)
310
315
 
311
316
  runner_config = self.cfg.to_dict()
312
317
  with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
@@ -523,6 +523,82 @@ def gemma_2_sae_huggingface_loader(
523
523
  return cfg_dict, state_dict, log_sparsity
524
524
 
525
525
 
526
+ def get_goodfire_config_from_hf(
527
+ repo_id: str,
528
+ folder_name: str, # noqa: ARG001
529
+ device: str,
530
+ force_download: bool = False, # noqa: ARG001
531
+ cfg_overrides: dict[str, Any] | None = None,
532
+ ) -> dict[str, Any]:
533
+ cfg_dict = None
534
+ if repo_id == "Goodfire/Llama-3.3-70B-Instruct-SAE-l50":
535
+ if folder_name != "Llama-3.3-70B-Instruct-SAE-l50.pt":
536
+ raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
537
+ cfg_dict = {
538
+ "architecture": "standard",
539
+ "d_in": 8192,
540
+ "d_sae": 65536,
541
+ "model_name": "meta-llama/Llama-3.3-70B-Instruct",
542
+ "hook_name": "blocks.50.hook_resid_post",
543
+ "hook_head_index": None,
544
+ "dataset_path": "lmsys/lmsys-chat-1m",
545
+ "apply_b_dec_to_input": False,
546
+ }
547
+ elif repo_id == "Goodfire/Llama-3.1-8B-Instruct-SAE-l19":
548
+ if folder_name != "Llama-3.1-8B-Instruct-SAE-l19.pth":
549
+ raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
550
+ cfg_dict = {
551
+ "architecture": "standard",
552
+ "d_in": 4096,
553
+ "d_sae": 65536,
554
+ "model_name": "meta-llama/Llama-3.1-8B-Instruct",
555
+ "hook_name": "blocks.19.hook_resid_post",
556
+ "hook_head_index": None,
557
+ "dataset_path": "lmsys/lmsys-chat-1m",
558
+ "apply_b_dec_to_input": False,
559
+ }
560
+ if cfg_dict is None:
561
+ raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
562
+ if device is not None:
563
+ cfg_dict["device"] = device
564
+ if cfg_overrides is not None:
565
+ cfg_dict.update(cfg_overrides)
566
+ return cfg_dict
567
+
568
+
569
def get_goodfire_huggingface_loader(
    repo_id: str,
    folder_name: str,
    device: str = "cpu",
    force_download: bool = False,
    cfg_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
    """
    Load a Goodfire SAE (config and weights) from Hugging Face.

    Args:
        repo_id: Hugging Face repo id of the Goodfire SAE.
        folder_name: Name of the weights file inside the repo (used directly
            as the download filename).
        device: Device the checkpoint tensors are mapped to.
        force_download: Passed through to `hf_hub_download`.
        cfg_overrides: Optional entries merged on top of the built config.

    Returns:
        A `(cfg_dict, state_dict, None)` tuple. `state_dict` uses the keys
        "W_enc", "W_dec", "b_enc" and "b_dec"; the third element (log
        sparsity) is always None for this loader.

    Raises:
        ValueError: If the repo id / filename combination is not supported.
    """
    cfg_dict = get_goodfire_config_from_hf(
        repo_id,
        folder_name,
        device,
        force_download,
        cfg_overrides,
    )

    # Download the SAE weights
    sae_path = hf_hub_download(
        repo_id=repo_id,
        filename=folder_name,
        force_download=force_download,
    )
    # SECURITY: the checkpoint is a pickle-based file fetched from the network.
    # weights_only=True restricts unpickling to tensors and plain containers so
    # a tampered file cannot execute arbitrary code on load.
    raw_state_dict = torch.load(sae_path, map_location=device, weights_only=True)

    # The checkpoint stores the encoder/decoder as linear layers; their weight
    # matrices are transposed to obtain W_enc / W_dec.
    state_dict = {
        "W_enc": raw_state_dict["encoder_linear.weight"].T,
        "W_dec": raw_state_dict["decoder_linear.weight"].T,
        "b_enc": raw_state_dict["encoder_linear.bias"],
        "b_dec": raw_state_dict["decoder_linear.bias"],
    }

    return cfg_dict, state_dict, None
600
+
601
+
526
602
  def get_llama_scope_config_from_hf(
527
603
  repo_id: str,
528
604
  folder_name: str,
@@ -1475,6 +1551,114 @@ def get_mntss_clt_layer_config_from_hf(
1475
1551
  }
1476
1552
 
1477
1553
 
1554
def get_temporal_sae_config_from_hf(
    repo_id: str,
    folder_name: str,
    device: str,
    force_download: bool = False,
    cfg_overrides: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """
    Get TemporalSAE config without loading weights.

    Downloads `{folder_name}/conf.yaml` from the repo and maps its fields onto
    an SAELens config dictionary.

    Args:
        repo_id: Hugging Face repo id hosting the TemporalSAE.
        folder_name: Folder inside the repo; must contain `layer_<N>` so the
            hook layer can be derived, e.g. "layer_12/temporal".
        device: Device string stored under "device".
        force_download: Passed through to `hf_hub_download`.
        cfg_overrides: Optional entries merged on top of the built config.

    Returns:
        The config dictionary for the "temporal" architecture.

    Raises:
        ValueError: If no `layer_<N>` component is found in `folder_name`.
    """
    # Download config file
    conf_path = hf_hub_download(
        repo_id=repo_id,
        filename=f"{folder_name}/conf.yaml",
        force_download=force_download,
    )

    # Load and parse config. The encoding is explicit so parsing does not
    # depend on the platform's default locale encoding.
    with open(conf_path, encoding="utf-8") as f:
        yaml_config = yaml.safe_load(f)

    # Extract parameters
    sae_conf = yaml_config["sae"]
    d_in = yaml_config["llm"]["dimin"]
    d_sae = int(d_in * sae_conf["exp_factor"])

    # Extract the layer index from folder_name, e.g. "layer_12/temporal" -> 12
    layer_match = re.search(r"layer_(\d+)", folder_name)
    if layer_match is None:
        raise ValueError(f"Could not find layer in folder_name: {folder_name}")
    layer = int(layer_match.group(1))

    # Build config dict
    cfg_dict = {
        "architecture": "temporal",
        "hook_name": f"blocks.{layer}.hook_resid_post",
        "d_in": d_in,
        "d_sae": d_sae,
        "n_heads": sae_conf["n_heads"],
        "n_attn_layers": sae_conf["n_attn_layers"],
        "bottleneck_factor": sae_conf["bottleneck_factor"],
        "sae_diff_type": sae_conf["sae_diff_type"],
        "kval_topk": sae_conf["kval_topk"],
        "tied_weights": sae_conf["tied_weights"],
        "dtype": yaml_config["data"]["dtype"],
        "device": device,
        "normalize_activations": "constant_scalar_rescale",
        "activation_normalization_factor": sae_conf["scaling_factor"],
        "apply_b_dec_to_input": True,
    }

    if cfg_overrides:
        cfg_dict.update(cfg_overrides)

    return cfg_dict
+
1608
+
1609
def temporal_sae_huggingface_loader(
    repo_id: str,
    folder_name: str,
    device: str = "cpu",
    force_download: bool = False,
    cfg_overrides: dict[str, Any] | None = None,
) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
    """
    Load a TemporalSAE stored in the canrager/temporalSAEs layout (safetensors).

    `folder_name` is expected to contain:
    - conf.yaml (configuration)
    - latest_ckpt.safetensors (model weights)

    Returns:
        A `(cfg_dict, state_dict, None)` tuple; the third element (log
        sparsity) is always None for this loader.
    """
    cfg_dict = get_temporal_sae_config_from_hf(
        repo_id=repo_id,
        folder_name=folder_name,
        device=device,
        force_download=force_download,
        cfg_overrides=cfg_overrides,
    )

    # Fetch and read the safetensors checkpoint.
    weights_file = hf_hub_download(
        repo_id=repo_id,
        filename=f"{folder_name}/latest_ckpt.safetensors",
        force_download=force_download,
    )
    raw_weights = load_file(weights_file, device=device)

    # Rename to the SAELens convention. The checkpoint uses
    # D (decoder), E (encoder), b (bias) and attn_layers.* keys;
    # attention-layer entries keep their original names.
    state_dict = {
        name: tensor.to(device)
        for name, tensor in raw_weights.items()
        if name.startswith("attn_layers.")
    }
    state_dict["W_dec"] = raw_weights["D"].to(device)
    state_dict["b_dec"] = raw_weights["b"].to(device)

    # "E" is optional: it is only copied when the checkpoint provides it
    # (tied vs. untied weights).
    if "E" in raw_weights:
        state_dict["W_enc"] = raw_weights["E"].to(device)

    return cfg_dict, state_dict, None
+
1661
+
1478
1662
  NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
1479
1663
  "sae_lens": sae_lens_huggingface_loader,
1480
1664
  "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1487,6 +1671,8 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
1487
1671
  "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
1488
1672
  "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
1489
1673
  "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
1674
+ "temporal": temporal_sae_huggingface_loader,
1675
+ "goodfire": get_goodfire_huggingface_loader,
1490
1676
  }
1491
1677
 
1492
1678
 
@@ -1502,4 +1688,6 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
1502
1688
  "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
1503
1689
  "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
1504
1690
  "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
1691
+ "temporal": get_temporal_sae_config_from_hf,
1692
+ "goodfire": get_goodfire_config_from_hf,
1505
1693
  }
@@ -1,6 +1,6 @@
1
1
  from dataclasses import dataclass
2
2
  from functools import cache
3
- from importlib import resources
3
+ from importlib.resources import files
4
4
  from typing import Any
5
5
 
6
6
  import yaml
@@ -24,7 +24,8 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
24
24
  package = "sae_lens"
25
25
  # Access the file within the package using importlib.resources
26
26
  directory: dict[str, PretrainedSAELookup] = {}
27
- with resources.open_text(package, "pretrained_saes.yaml") as file:
27
+ yaml_file = files(package).joinpath("pretrained_saes.yaml")
28
+ with yaml_file.open("r") as file:
28
29
  # Load the YAML file content
29
30
  data = yaml.safe_load(file)
30
31
  for release, value in data.items():
@@ -68,7 +69,8 @@ def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
68
69
  float | None: The norm_scaling_factor if it exists, None otherwise.
69
70
  """
70
71
  package = "sae_lens"
71
- with resources.open_text(package, "pretrained_saes.yaml") as file:
72
+ yaml_file = files(package).joinpath("pretrained_saes.yaml")
73
+ with yaml_file.open("r") as file:
72
74
  data = yaml.safe_load(file)
73
75
  if release in data:
74
76
  for sae_info in data[release]["saes"]: