sae-lens 5.7.1.tar.gz → 6.25.1.tar.gz

This diff shows the content of publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of sae-lens might be problematic.

Files changed (50)
  1. {sae_lens-5.7.1 → sae_lens-6.25.1}/PKG-INFO +31 -31
  2. {sae_lens-5.7.1 → sae_lens-6.25.1}/README.md +18 -14
  3. {sae_lens-5.7.1 → sae_lens-6.25.1}/pyproject.toml +20 -20
  4. sae_lens-6.25.1/sae_lens/__init__.py +141 -0
  5. {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/analysis/hooked_sae_transformer.py +29 -25
  6. sae_lens-6.25.1/sae_lens/analysis/neuronpedia_integration.py +163 -0
  7. {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/cache_activations_runner.py +13 -12
  8. {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/config.py +254 -271
  9. sae_lens-6.25.1/sae_lens/constants.py +30 -0
  10. {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/evals.py +146 -87
  11. sae_lens-6.25.1/sae_lens/llm_sae_training_runner.py +429 -0
  12. {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/load_model.py +54 -6
  13. sae_lens-6.25.1/sae_lens/loading/pretrained_sae_loaders.py +1911 -0
  14. {sae_lens-5.7.1/sae_lens/toolkit → sae_lens-6.25.1/sae_lens/loading}/pretrained_saes_directory.py +17 -3
  15. {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/pretokenize_runner.py +8 -4
  16. sae_lens-6.25.1/sae_lens/pretrained_saes.yaml +41813 -0
  17. sae_lens-6.25.1/sae_lens/registry.py +49 -0
  18. sae_lens-6.25.1/sae_lens/saes/__init__.py +81 -0
  19. sae_lens-6.25.1/sae_lens/saes/batchtopk_sae.py +134 -0
  20. sae_lens-6.25.1/sae_lens/saes/gated_sae.py +242 -0
  21. sae_lens-6.25.1/sae_lens/saes/jumprelu_sae.py +367 -0
  22. sae_lens-6.25.1/sae_lens/saes/matryoshka_batchtopk_sae.py +136 -0
  23. sae_lens-6.25.1/sae_lens/saes/sae.py +1067 -0
  24. sae_lens-6.25.1/sae_lens/saes/standard_sae.py +165 -0
  25. sae_lens-6.25.1/sae_lens/saes/temporal_sae.py +365 -0
  26. sae_lens-6.25.1/sae_lens/saes/topk_sae.py +538 -0
  27. sae_lens-6.25.1/sae_lens/saes/transcoder.py +411 -0
  28. {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/tokenization_and_batching.py +25 -2
  29. sae_lens-6.25.1/sae_lens/training/activation_scaler.py +60 -0
  30. {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/training/activations_store.py +179 -237
  31. sae_lens-6.25.1/sae_lens/training/mixing_buffer.py +56 -0
  32. {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/training/optim.py +36 -34
  33. sae_lens-6.25.1/sae_lens/training/sae_trainer.py +455 -0
  34. sae_lens-6.25.1/sae_lens/training/types.py +5 -0
  35. {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/training/upload_saes_to_huggingface.py +17 -6
  36. sae_lens-6.25.1/sae_lens/util.py +113 -0
  37. sae_lens-5.7.1/sae_lens/__init__.py +0 -39
  38. sae_lens-5.7.1/sae_lens/analysis/neuronpedia_integration.py +0 -492
  39. sae_lens-5.7.1/sae_lens/pretrained_saes.yaml +0 -13961
  40. sae_lens-5.7.1/sae_lens/sae.py +0 -737
  41. sae_lens-5.7.1/sae_lens/sae_training_runner.py +0 -251
  42. sae_lens-5.7.1/sae_lens/toolkit/pretrained_sae_loaders.py +0 -879
  43. sae_lens-5.7.1/sae_lens/training/geometric_median.py +0 -101
  44. sae_lens-5.7.1/sae_lens/training/sae_trainer.py +0 -444
  45. sae_lens-5.7.1/sae_lens/training/training_sae.py +0 -711
  46. {sae_lens-5.7.1 → sae_lens-6.25.1}/LICENSE +0 -0
  47. {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/analysis/__init__.py +0 -0
  48. {sae_lens-5.7.1/sae_lens/toolkit → sae_lens-6.25.1/sae_lens/loading}/__init__.py +0 -0
  49. {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/training/__init__.py +0 -0
  50. {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/tutorial/tsea.py +0 -0
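The most consequential change in this file list is a package restructure: the monolithic `sae_lens/sae.py`, `sae_lens/sae_training_runner.py`, and `sae_lens/training/training_sae.py` are removed in favour of a per-architecture `sae_lens/saes/` package plus `llm_sae_training_runner.py`, and `sae_lens/toolkit/` is renamed to `sae_lens/loading/`. A minimal import-migration sketch (the old path is taken from the 5.7.1 layout above; the new paths are confirmed by the 6.25.1 `__init__.py` and `hooked_sae_transformer.py` diffs below):

```python
# sae-lens 5.7.1 (module removed in 6.25.1):
# from sae_lens.sae import SAE

# sae-lens 6.25.1: SAE now lives in the saes/ package...
from sae_lens.saes.sae import SAE

# ...and is also re-exported from the package root (see the new __init__.py below):
from sae_lens import SAE  # noqa: F811
```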
{sae_lens-5.7.1 → sae_lens-6.25.1}/PKG-INFO

@@ -1,8 +1,9 @@
- Metadata-Version: 2.3
+ Metadata-Version: 2.4
  Name: sae-lens
- Version: 5.7.1
+ Version: 6.25.1
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
  License: MIT
+ License-File: LICENSE
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
  Author: Joseph Bloom
  Requires-Python: >=3.10,<4.0
@@ -12,41 +13,36 @@ Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Classifier: Programming Language :: Python :: 3.13
+ Classifier: Programming Language :: Python :: 3.14
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Provides-Extra: mamba
- Requires-Dist: automated-interpretability (>=0.0.5,<1.0.0)
  Requires-Dist: babe (>=0.0.7,<0.0.8)
- Requires-Dist: datasets (>=2.17.1,<3.0.0)
+ Requires-Dist: datasets (>=3.1.0)
  Requires-Dist: mamba-lens (>=0.0.4,<0.0.5) ; extra == "mamba"
- Requires-Dist: matplotlib (>=3.8.3,<4.0.0)
- Requires-Dist: matplotlib-inline (>=0.1.6,<0.2.0)
  Requires-Dist: nltk (>=3.8.1,<4.0.0)
- Requires-Dist: plotly (>=5.19.0,<6.0.0)
- Requires-Dist: plotly-express (>=0.4.1,<0.5.0)
- Requires-Dist: pytest-profiling (>=1.7.0,<2.0.0)
- Requires-Dist: python-dotenv (>=1.0.1,<2.0.0)
+ Requires-Dist: plotly (>=5.19.0)
+ Requires-Dist: plotly-express (>=0.4.1)
+ Requires-Dist: python-dotenv (>=1.0.1)
  Requires-Dist: pyyaml (>=6.0.1,<7.0.0)
- Requires-Dist: pyzmq (==26.0.0)
- Requires-Dist: safetensors (>=0.4.2,<0.5.0)
+ Requires-Dist: safetensors (>=0.4.2,<1.0.0)
  Requires-Dist: simple-parsing (>=0.1.6,<0.2.0)
- Requires-Dist: transformer-lens (>=2.0.0,<3.0.0)
+ Requires-Dist: tenacity (>=9.0.0)
+ Requires-Dist: transformer-lens (>=2.16.1,<3.0.0)
  Requires-Dist: transformers (>=4.38.1,<5.0.0)
- Requires-Dist: typer (>=0.12.3,<0.13.0)
  Requires-Dist: typing-extensions (>=4.10.0,<5.0.0)
- Requires-Dist: zstandard (>=0.22.0,<0.23.0)
- Project-URL: Homepage, https://jbloomaus.github.io/SAELens
- Project-URL: Repository, https://github.com/jbloomAus/SAELens
+ Project-URL: Homepage, https://decoderesearch.github.io/SAELens
+ Project-URL: Repository, https://github.com/decoderesearch/SAELens
  Description-Content-Type: text/markdown

- <img width="1308" alt="Screenshot 2024-03-21 at 3 08 28 pm" src="https://github.com/jbloomAus/mats_sae_training/assets/69127271/209012ec-a779-4036-b4be-7b7739ea87f6">
+ <img width="1308" height="532" alt="saes_pic" src="https://github.com/user-attachments/assets/2a5d752f-b261-4ee4-ad5d-ebf282321371" />

  # SAE Lens

  [![PyPI](https://img.shields.io/pypi/v/sae-lens?color=blue)](https://pypi.org/project/sae-lens/)
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
- [![build](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml)
- [![Deploy Docs](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml)
- [![codecov](https://codecov.io/gh/jbloomAus/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/jbloomAus/SAELens)
+ [![build](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml)
+ [![Deploy Docs](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml)
+ [![codecov](https://codecov.io/gh/decoderesearch/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/decoderesearch/SAELens)

  SAELens exists to help researchers:

@@ -54,7 +50,7 @@ SAELens exists to help researchers:
  - Analyse sparse autoencoders / research mechanistic interpretability.
  - Generate insights which make it easier to create safe and aligned AI systems.

- Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for information on how to:
+ Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:

  - Download and Analyse pre-trained sparse autoencoders.
  - Train your own sparse autoencoders.
@@ -62,25 +58,29 @@ Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for in

  SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to [safeguard humanity from risks posed by artificial intelligence](https://80000hours.org/problem-profiles/artificial-intelligence/).

- This library is maintained by [Joseph Bloom](https://www.jbloomaus.com/) and [David Chanin](https://github.com/chanind).
+ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).

  ## Loading Pre-trained SAEs.

- Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://jbloomaus.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.
+ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
+
+ ## Migrating to SAELens v6
+
+ The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://decoderesearch.github.io/SAELens/latest/migrating/) for more details.

  ## Tutorials

- - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
+ - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
  - [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
  - [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
  - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)

  ## Join the Slack!

- Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-2o756ku1c-_yKBeUQMVfS_p_qcK6QLeA) for support!
+ Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!

  ## Citation

@@ -89,9 +89,9 @@ Please cite the package as follows:
  ```
  @misc{bloom2024saetrainingcodebase,
      title = {SAELens},
-     author = {Joseph Bloom, Curt Tigges and David Chanin},
+     author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
      year = {2024},
-     howpublished = {\url{https://github.com/jbloomAus/SAELens}},
+     howpublished = {\url{https://github.com/decoderesearch/SAELens}},
  }
  ```

{sae_lens-5.7.1 → sae_lens-6.25.1}/README.md

@@ -1,12 +1,12 @@
- <img width="1308" alt="Screenshot 2024-03-21 at 3 08 28 pm" src="https://github.com/jbloomAus/mats_sae_training/assets/69127271/209012ec-a779-4036-b4be-7b7739ea87f6">
+ <img width="1308" height="532" alt="saes_pic" src="https://github.com/user-attachments/assets/2a5d752f-b261-4ee4-ad5d-ebf282321371" />

  # SAE Lens

  [![PyPI](https://img.shields.io/pypi/v/sae-lens?color=blue)](https://pypi.org/project/sae-lens/)
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
- [![build](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml)
- [![Deploy Docs](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml)
- [![codecov](https://codecov.io/gh/jbloomAus/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/jbloomAus/SAELens)
+ [![build](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml)
+ [![Deploy Docs](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml)
+ [![codecov](https://codecov.io/gh/decoderesearch/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/decoderesearch/SAELens)

  SAELens exists to help researchers:

@@ -14,7 +14,7 @@ SAELens exists to help researchers:
  - Analyse sparse autoencoders / research mechanistic interpretability.
  - Generate insights which make it easier to create safe and aligned AI systems.

- Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for information on how to:
+ Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:

  - Download and Analyse pre-trained sparse autoencoders.
  - Train your own sparse autoencoders.
@@ -22,25 +22,29 @@ Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for in

  SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to [safeguard humanity from risks posed by artificial intelligence](https://80000hours.org/problem-profiles/artificial-intelligence/).

- This library is maintained by [Joseph Bloom](https://www.jbloomaus.com/) and [David Chanin](https://github.com/chanind).
+ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).

  ## Loading Pre-trained SAEs.

- Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://jbloomaus.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.
+ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
+
+ ## Migrating to SAELens v6
+
+ The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://decoderesearch.github.io/SAELens/latest/migrating/) for more details.

  ## Tutorials

- - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
+ - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
  - [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
  - [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
  - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)

  ## Join the Slack!

- Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-2o756ku1c-_yKBeUQMVfS_p_qcK6QLeA) for support!
+ Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!

  ## Citation

@@ -49,8 +53,8 @@ Please cite the package as follows:
  ```
  @misc{bloom2024saetrainingcodebase,
      title = {SAELens},
-     author = {Joseph Bloom, Curt Tigges and David Chanin},
+     author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
      year = {2024},
-     howpublished = {\url{https://github.com/jbloomAus/SAELens}},
+     howpublished = {\url{https://github.com/decoderesearch/SAELens}},
  }
  ```
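Both README versions point readers at the pretrained-SAE directory for loading. A minimal loading sketch, assuming the `SAE.from_pretrained(release, sae_id, device)` entry point from the SAELens docs (it is not part of this diff) and assuming the v6 behaviour of returning the SAE directly rather than v5's `(sae, cfg_dict, sparsity)` tuple; the release and hook-point names are illustrative entries from the `pretrained_saes.yaml` directory:

```python
import torch

from sae_lens import SAE

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative release/id; valid names are catalogued in pretrained_saes.yaml
# (41,813 lines in 6.25.1).
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device=device,
)

# In v6 the hook point moves from sae.cfg.hook_name to sae.cfg.metadata.hook_name
# (see the hooked_sae_transformer.py diff below).
print(sae.cfg.metadata.hook_name)
```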
{sae_lens-5.7.1 → sae_lens-6.25.1}/pyproject.toml

@@ -1,13 +1,13 @@
  [tool.poetry]
  name = "sae-lens"
- version = "5.7.1"
+ version = "6.25.1"
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
  authors = ["Joseph Bloom"]
  readme = "README.md"
  packages = [{ include = "sae_lens" }]
  include = ["pretrained_saes.yaml"]
- repository = "https://github.com/jbloomAus/SAELens"
- homepage = "https://jbloomaus.github.io/SAELens"
+ repository = "https://github.com/decoderesearch/SAELens"
+ homepage = "https://decoderesearch.github.io/SAELens"
  license = "MIT"
  keywords = [
      "deep-learning",
@@ -19,26 +19,20 @@ classifiers = ["Topic :: Scientific/Engineering :: Artificial Intelligence"]

  [tool.poetry.dependencies]
  python = "^3.10"
- transformer-lens = "^2.0.0"
+ transformer-lens = "^2.16.1"
  transformers = "^4.38.1"
- plotly = "^5.19.0"
- plotly-express = "^0.4.1"
- matplotlib = "^3.8.3"
- matplotlib-inline = "^0.1.6"
- datasets = "^2.17.1"
+ plotly = ">=5.19.0"
+ plotly-express = ">=0.4.1"
+ datasets = ">=3.1.0"
  babe = "^0.0.7"
  nltk = "^3.8.1"
- safetensors = "^0.4.2"
- typer = "^0.12.3"
+ safetensors = ">=0.4.2,<1.0.0"
  mamba-lens = { version = "^0.0.4", optional = true }
- pyzmq = "26.0.0"
- automated-interpretability = ">=0.0.5,<1.0.0"
- python-dotenv = "^1.0.1"
+ python-dotenv = ">=1.0.1"
  pyyaml = "^6.0.1"
- pytest-profiling = "^1.7.0"
- zstandard = "^0.22.0"
  typing-extensions = "^4.10.0"
  simple-parsing = "^0.1.6"
+ tenacity = ">=9.0.0"

  [tool.poetry.group.dev.dependencies]
  pytest = "^8.0.2"
@@ -52,13 +46,17 @@ boto3 = "^1.34.101"
  docstr-coverage = "^2.3.2"
  mkdocs = "^1.6.1"
  mkdocs-material = "^9.5.34"
- mkdocs-autorefs = "^1.1.0"
+ mkdocs-autorefs = "^1.4.2"
+ mkdocs-redirects = "^1.2.1"
  mkdocs-section-index = "^0.3.9"
  mkdocstrings = "^0.25.2"
  mkdocstrings-python = "^1.10.9"
  tabulate = "^0.9.0"
  ruff = "^0.7.4"
- sparsify = {git = "https://github.com/EleutherAI/sparsify"}
+ eai-sparsify = "^1.1.1"
+ mike = "^2.0.0"
+ trio = "^0.30.0"
+ dictionary-learning = "^0.1.0"

  [tool.poetry.extras]
  mamba = ["mamba-lens"]
@@ -69,8 +67,9 @@ ignore = ["E203", "E501", "E731", "F722", "E741", "F821", "F403", "ARG002"]
  select = ["UP", "TID", "I", "F", "E", "ARG", "SIM", "RET", "LOG", "T20"]

  [tool.ruff.lint.per-file-ignores]
- "benchmark/*" = ["T20"]
+ "benchmark/*" = ["T20", "TID251"]
  "scripts/*" = ["T20"]
+ "tests/*" = ["TID251"]

  [tool.ruff.lint.flake8-tidy-imports.banned-api]
  "typing.Union".msg = "Use `|` instead"
@@ -78,6 +77,7 @@ select = ["UP", "TID", "I", "F", "E", "ARG", "SIM", "RET", "LOG", "T20"]
  "typing.Dict".msg = "Use `dict` instead"
  "typing.Tuple".msg = "Use `tuple` instead"
  "typing.List".msg = "Use `list` instead"
+ "tests".msg = "Do not import from tests in the main codebase."

  [tool.pyright]
  typeCheckingMode = "strict"
@@ -102,5 +102,5 @@ build-backend = "poetry.core.masonry.api"
  [tool.semantic_release]
  version_variables = ["sae_lens/__init__.py:__version__"]
  version_toml = ["pyproject.toml:tool.poetry.version"]
- branch = "main"
  build_command = "pip install poetry && poetry build"
+ branches = { main = { match = "main" }, alpha = { match = "alpha", prerelease = true }, beta = { match = "beta", prerelease = true } }
sae_lens-6.25.1/sae_lens/__init__.py

@@ -0,0 +1,141 @@
+ # ruff: noqa: E402
+ __version__ = "6.25.1"
+
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ from sae_lens.saes import (
+     SAE,
+     BatchTopKTrainingSAE,
+     BatchTopKTrainingSAEConfig,
+     GatedSAE,
+     GatedSAEConfig,
+     GatedTrainingSAE,
+     GatedTrainingSAEConfig,
+     JumpReLUSAE,
+     JumpReLUSAEConfig,
+     JumpReLUSkipTranscoder,
+     JumpReLUSkipTranscoderConfig,
+     JumpReLUTrainingSAE,
+     JumpReLUTrainingSAEConfig,
+     JumpReLUTranscoder,
+     JumpReLUTranscoderConfig,
+     MatryoshkaBatchTopKTrainingSAE,
+     MatryoshkaBatchTopKTrainingSAEConfig,
+     SAEConfig,
+     SkipTranscoder,
+     SkipTranscoderConfig,
+     StandardSAE,
+     StandardSAEConfig,
+     StandardTrainingSAE,
+     StandardTrainingSAEConfig,
+     TemporalSAE,
+     TemporalSAEConfig,
+     TopKSAE,
+     TopKSAEConfig,
+     TopKTrainingSAE,
+     TopKTrainingSAEConfig,
+     TrainingSAE,
+     TrainingSAEConfig,
+     Transcoder,
+     TranscoderConfig,
+ )
+
+ from .analysis.hooked_sae_transformer import HookedSAETransformer
+ from .cache_activations_runner import CacheActivationsRunner
+ from .config import (
+     CacheActivationsRunnerConfig,
+     LanguageModelSAERunnerConfig,
+     LoggingConfig,
+     PretokenizeRunnerConfig,
+ )
+ from .evals import run_evals
+ from .llm_sae_training_runner import LanguageModelSAETrainingRunner, SAETrainingRunner
+ from .loading.pretrained_sae_loaders import (
+     PretrainedSaeDiskLoader,
+     PretrainedSaeHuggingfaceLoader,
+ )
+ from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
+ from .registry import register_sae_class, register_sae_training_class
+ from .training.activations_store import ActivationsStore
+ from .training.upload_saes_to_huggingface import upload_saes_to_huggingface
+
+ __all__ = [
+     "SAE",
+     "SAEConfig",
+     "TrainingSAE",
+     "TrainingSAEConfig",
+     "HookedSAETransformer",
+     "ActivationsStore",
+     "LanguageModelSAERunnerConfig",
+     "LanguageModelSAETrainingRunner",
+     "CacheActivationsRunnerConfig",
+     "CacheActivationsRunner",
+     "PretokenizeRunnerConfig",
+     "PretokenizeRunner",
+     "pretokenize_runner",
+     "run_evals",
+     "upload_saes_to_huggingface",
+     "PretrainedSaeHuggingfaceLoader",
+     "PretrainedSaeDiskLoader",
+     "register_sae_class",
+     "register_sae_training_class",
+     "StandardSAE",
+     "StandardSAEConfig",
+     "StandardTrainingSAE",
+     "StandardTrainingSAEConfig",
+     "GatedSAE",
+     "GatedSAEConfig",
+     "GatedTrainingSAE",
+     "GatedTrainingSAEConfig",
+     "TopKSAE",
+     "TopKSAEConfig",
+     "TopKTrainingSAE",
+     "TopKTrainingSAEConfig",
+     "JumpReLUSAE",
+     "JumpReLUSAEConfig",
+     "JumpReLUTrainingSAE",
+     "JumpReLUTrainingSAEConfig",
+     "SAETrainingRunner",
+     "LoggingConfig",
+     "BatchTopKTrainingSAE",
+     "BatchTopKTrainingSAEConfig",
+     "Transcoder",
+     "TranscoderConfig",
+     "SkipTranscoder",
+     "SkipTranscoderConfig",
+     "JumpReLUTranscoder",
+     "JumpReLUTranscoderConfig",
+     "JumpReLUSkipTranscoder",
+     "JumpReLUSkipTranscoderConfig",
+     "MatryoshkaBatchTopKTrainingSAE",
+     "MatryoshkaBatchTopKTrainingSAEConfig",
+     "TemporalSAE",
+     "TemporalSAEConfig",
+ ]
+
+
+ register_sae_class("standard", StandardSAE, StandardSAEConfig)
+ register_sae_training_class("standard", StandardTrainingSAE, StandardTrainingSAEConfig)
+ register_sae_class("gated", GatedSAE, GatedSAEConfig)
+ register_sae_training_class("gated", GatedTrainingSAE, GatedTrainingSAEConfig)
+ register_sae_class("topk", TopKSAE, TopKSAEConfig)
+ register_sae_training_class("topk", TopKTrainingSAE, TopKTrainingSAEConfig)
+ register_sae_class("jumprelu", JumpReLUSAE, JumpReLUSAEConfig)
+ register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAEConfig)
+ register_sae_training_class(
+     "batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
+ )
+ register_sae_training_class(
+     "matryoshka_batchtopk",
+     MatryoshkaBatchTopKTrainingSAE,
+     MatryoshkaBatchTopKTrainingSAEConfig,
+ )
+ register_sae_class("transcoder", Transcoder, TranscoderConfig)
+ register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
+ register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+ register_sae_class(
+     "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
+ )
+ register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
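The `register_sae_class` / `register_sae_training_class` calls at the bottom of this new `__init__.py` are the v6 mechanism that maps an architecture string (as stored in SAE configs) to an SAE class and its config class. A minimal sketch of using the same hook from your own code, registering a built-in class under a made-up alias (`"standard_alias"` is hypothetical, purely for illustration):

```python
from sae_lens import StandardSAE, StandardSAEConfig, register_sae_class

# Mirrors the built-in call register_sae_class("standard", StandardSAE, StandardSAEConfig)
# above, but under a hypothetical alias; a custom SAE subclass would be
# registered the same way so SAELens can resolve it by architecture name.
register_sae_class("standard_alias", StandardSAE, StandardSAEConfig)
```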
{sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/analysis/hooked_sae_transformer.py

@@ -3,15 +3,15 @@ from contextlib import contextmanager
  from typing import Any, Callable

  import torch
- from jaxtyping import Float
  from transformer_lens.ActivationCache import ActivationCache
+ from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
  from transformer_lens.hook_points import HookPoint  # Hooking utilities
  from transformer_lens.HookedTransformer import HookedTransformer

- from sae_lens.sae import SAE
+ from sae_lens.saes.sae import SAE

- SingleLoss = Float[torch.Tensor, ""]  # Type alias for a single element tensor
- LossPerToken = Float[torch.Tensor, "batch pos-1"]
+ SingleLoss = torch.Tensor  # Type alias for a single element tensor
+ LossPerToken = torch.Tensor
  Loss = SingleLoss | LossPerToken


@@ -50,6 +50,13 @@ def set_deep_attr(obj: Any, path: str, value: Any):
      setattr(obj, parts[-1], value)


+ def add_hook_in_to_mlp(mlp: CanBeUsedAsMLP):
+     # Temporary hack to add a `mlp.hook_in` hook to mimic what's in circuit-tracer
+     mlp.hook_in = HookPoint()
+     original_forward = mlp.forward
+     mlp.forward = lambda x: original_forward(mlp.hook_in(x))  # type: ignore
+
+
  class HookedSAETransformer(HookedTransformer):
      def __init__(
          self,
@@ -66,9 +73,14 @@ class HookedSAETransformer(HookedTransformer):
          **model_kwargs: Keyword arguments for HookedTransformer initialization
          """
          super().__init__(*model_args, **model_kwargs)
+
+         for block in self.blocks:
+             add_hook_in_to_mlp(block.mlp)  # type: ignore
+         self.setup()
+
          self.acts_to_saes: dict[str, SAE] = {}  # type: ignore

-     def add_sae(self, sae: SAE, use_error_term: bool | None = None):
+     def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
          """Attaches an SAE to the model

          WARNING: This sae will be permanantly attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.
@@ -77,7 +89,7 @@ class HookedSAETransformer(HookedTransformer):
          sae: SparseAutoencoderBase. The SAE to attach to the model
          use_error_term: (bool | None) If provided, will set the use_error_term attribute of the SAE to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
          """
-         act_name = sae.cfg.hook_name
+         act_name = sae.cfg.metadata.hook_name
          if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
              logging.warning(
                  f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
@@ -92,7 +104,7 @@ class HookedSAETransformer(HookedTransformer):
          set_deep_attr(self, act_name, sae)
          self.setup()

-     def _reset_sae(self, act_name: str, prev_sae: SAE | None = None):
+     def _reset_sae(self, act_name: str, prev_sae: SAE[Any] | None = None):
          """Resets an SAE that was attached to the model

          By default will remove the SAE from that hook_point.
@@ -124,7 +136,7 @@ class HookedSAETransformer(HookedTransformer):
      def reset_saes(
          self,
          act_names: str | list[str] | None = None,
-         prev_saes: list[SAE | None] | None = None,
+         prev_saes: list[SAE[Any] | None] | None = None,
      ):
          """Reset the SAEs attached to the model

@@ -154,16 +166,11 @@
      def run_with_saes(
          self,
          *model_args: Any,
-         saes: SAE | list[SAE] = [],
+         saes: SAE[Any] | list[SAE[Any]] = [],
          reset_saes_end: bool = True,
          use_error_term: bool | None = None,
          **model_kwargs: Any,
-     ) -> (
-         None
-         | Float[torch.Tensor, "batch pos d_vocab"]
-         | Loss
-         | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]
-     ):
+     ) -> None | torch.Tensor | Loss | tuple[torch.Tensor, Loss]:
          """Wrapper around HookedTransformer forward pass.

          Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.
@@ -183,17 +190,14 @@ class HookedSAETransformer(HookedTransformer):
      def run_with_cache_with_saes(
          self,
          *model_args: Any,
-         saes: SAE | list[SAE] = [],
+         saes: SAE[Any] | list[SAE[Any]] = [],
          reset_saes_end: bool = True,
          use_error_term: bool | None = None,
          return_cache_object: bool = True,
          remove_batch_dim: bool = False,
          **kwargs: Any,
      ) -> tuple[
-         None
-         | Float[torch.Tensor, "batch pos d_vocab"]
-         | Loss
-         | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
+         None | torch.Tensor | Loss | tuple[torch.Tensor, Loss],
          ActivationCache | dict[str, torch.Tensor],
      ]:
          """Wrapper around 'run_with_cache' in HookedTransformer.
@@ -225,7 +229,7 @@ class HookedSAETransformer(HookedTransformer):
      def run_with_hooks_with_saes(
          self,
          *model_args: Any,
-         saes: SAE | list[SAE] = [],
+         saes: SAE[Any] | list[SAE[Any]] = [],
          reset_saes_end: bool = True,
          fwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
          bwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
@@ -261,7 +265,7 @@ class HookedSAETransformer(HookedTransformer):
      @contextmanager
      def saes(
          self,
-         saes: SAE | list[SAE] = [],
+         saes: SAE[Any] | list[SAE[Any]] = [],
          reset_saes_end: bool = True,
          use_error_term: bool | None = None,
      ):
@@ -275,7 +279,7 @@ class HookedSAETransformer(HookedTransformer):
          .. code-block:: python

              from transformer_lens import HookedSAETransformer
-             from sae_lens.sae import SAE
+             from sae_lens.saes.sae import SAE

              model = HookedSAETransformer.from_pretrained('gpt2-small')
              sae_cfg = SAEConfig(...)
@@ -295,8 +299,8 @@ class HookedSAETransformer(HookedTransformer):
              saes = [saes]
          try:
              for sae in saes:
-                 act_names_to_reset.append(sae.cfg.hook_name)
-                 prev_sae = self.acts_to_saes.get(sae.cfg.hook_name, None)
+                 act_names_to_reset.append(sae.cfg.metadata.hook_name)
+                 prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
                  prev_saes.append(prev_sae)
                  self.add_sae(sae, use_error_term=use_error_term)
              yield self
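The signatures above show that `run_with_saes`, `run_with_cache_with_saes`, and the `saes(...)` context manager now accept the generic `SAE[Any]` and resolve hook points via `sae.cfg.metadata.hook_name`. A usage sketch based only on those signatures (the model name comes from the docstring example above; the SAE is loaded via the assumed `SAE.from_pretrained` entry point from the earlier sketch, with illustrative release/id names):

```python
from sae_lens import SAE, HookedSAETransformer

model = HookedSAETransformer.from_pretrained("gpt2-small")
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",  # illustrative, as in the earlier sketch
    sae_id="blocks.8.hook_resid_pre",
)

# Attach the SAE for one forward pass; reset_saes_end=True detaches it afterwards.
logits = model.run_with_saes("SAELens is a library for", saes=[sae])

# Or scope the attachment with the context manager shown in the diff above:
with model.saes(saes=[sae]):
    _, cache = model.run_with_cache("SAELens is a library for")
```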