sae-lens 6.3.0__tar.gz → 6.25.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sae-lens has been flagged as potentially problematic; see the package registry's advisory page for more details.

Files changed (45)
  1. {sae_lens-6.3.0 → sae_lens-6.25.1}/PKG-INFO +26 -30
  2. {sae_lens-6.3.0 → sae_lens-6.25.1}/README.md +13 -13
  3. {sae_lens-6.3.0 → sae_lens-6.25.1}/pyproject.toml +12 -16
  4. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/__init__.py +37 -1
  5. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/analysis/hooked_sae_transformer.py +17 -13
  6. sae_lens-6.25.1/sae_lens/analysis/neuronpedia_integration.py +163 -0
  7. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/cache_activations_runner.py +6 -7
  8. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/config.py +64 -10
  9. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/constants.py +9 -0
  10. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/evals.py +52 -29
  11. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/llm_sae_training_runner.py +70 -24
  12. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/load_model.py +1 -1
  13. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/loading/pretrained_sae_loaders.py +851 -56
  14. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/loading/pretrained_saes_directory.py +5 -3
  15. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/pretokenize_runner.py +5 -4
  16. sae_lens-6.25.1/sae_lens/pretrained_saes.yaml +41813 -0
  17. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/saes/__init__.py +27 -0
  18. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/saes/batchtopk_sae.py +34 -2
  19. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/saes/gated_sae.py +6 -11
  20. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/saes/jumprelu_sae.py +72 -17
  21. sae_lens-6.25.1/sae_lens/saes/matryoshka_batchtopk_sae.py +136 -0
  22. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/saes/sae.py +81 -54
  23. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/saes/standard_sae.py +4 -9
  24. sae_lens-6.25.1/sae_lens/saes/temporal_sae.py +365 -0
  25. sae_lens-6.25.1/sae_lens/saes/topk_sae.py +538 -0
  26. sae_lens-6.25.1/sae_lens/saes/transcoder.py +411 -0
  27. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/tokenization_and_batching.py +21 -6
  28. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/activation_scaler.py +7 -0
  29. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/activations_store.py +62 -41
  30. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/optim.py +11 -0
  31. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/sae_trainer.py +77 -48
  32. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/types.py +1 -1
  33. sae_lens-6.25.1/sae_lens/util.py +113 -0
  34. sae_lens-6.3.0/sae_lens/analysis/neuronpedia_integration.py +0 -494
  35. sae_lens-6.3.0/sae_lens/pretrained_saes.yaml +0 -13976
  36. sae_lens-6.3.0/sae_lens/saes/topk_sae.py +0 -271
  37. sae_lens-6.3.0/sae_lens/util.py +0 -47
  38. {sae_lens-6.3.0 → sae_lens-6.25.1}/LICENSE +0 -0
  39. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/analysis/__init__.py +0 -0
  40. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/loading/__init__.py +0 -0
  41. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/registry.py +0 -0
  42. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/__init__.py +0 -0
  43. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/mixing_buffer.py +0 -0
  44. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  45. {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/tutorial/tsea.py +0 -0
@@ -1,8 +1,9 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.3.0
3
+ Version: 6.25.1
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
+ License-File: LICENSE
6
7
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
7
8
  Author: Joseph Bloom
8
9
  Requires-Python: >=3.10,<4.0
@@ -12,41 +13,36 @@ Classifier: Programming Language :: Python :: 3.10
12
13
  Classifier: Programming Language :: Python :: 3.11
13
14
  Classifier: Programming Language :: Python :: 3.12
14
15
  Classifier: Programming Language :: Python :: 3.13
16
+ Classifier: Programming Language :: Python :: 3.14
15
17
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
16
18
  Provides-Extra: mamba
17
- Requires-Dist: automated-interpretability (>=0.0.5,<1.0.0)
18
19
  Requires-Dist: babe (>=0.0.7,<0.0.8)
19
- Requires-Dist: datasets (>=2.17.1,<3.0.0)
20
+ Requires-Dist: datasets (>=3.1.0)
20
21
  Requires-Dist: mamba-lens (>=0.0.4,<0.0.5) ; extra == "mamba"
21
- Requires-Dist: matplotlib (>=3.8.3,<4.0.0)
22
- Requires-Dist: matplotlib-inline (>=0.1.6,<0.2.0)
23
22
  Requires-Dist: nltk (>=3.8.1,<4.0.0)
24
- Requires-Dist: plotly (>=5.19.0,<6.0.0)
25
- Requires-Dist: plotly-express (>=0.4.1,<0.5.0)
26
- Requires-Dist: pytest-profiling (>=1.7.0,<2.0.0)
27
- Requires-Dist: python-dotenv (>=1.0.1,<2.0.0)
23
+ Requires-Dist: plotly (>=5.19.0)
24
+ Requires-Dist: plotly-express (>=0.4.1)
25
+ Requires-Dist: python-dotenv (>=1.0.1)
28
26
  Requires-Dist: pyyaml (>=6.0.1,<7.0.0)
29
- Requires-Dist: pyzmq (==26.0.0)
30
- Requires-Dist: safetensors (>=0.4.2,<0.5.0)
27
+ Requires-Dist: safetensors (>=0.4.2,<1.0.0)
31
28
  Requires-Dist: simple-parsing (>=0.1.6,<0.2.0)
32
- Requires-Dist: transformer-lens (>=2.0.0,<3.0.0)
29
+ Requires-Dist: tenacity (>=9.0.0)
30
+ Requires-Dist: transformer-lens (>=2.16.1,<3.0.0)
33
31
  Requires-Dist: transformers (>=4.38.1,<5.0.0)
34
- Requires-Dist: typer (>=0.12.3,<0.13.0)
35
32
  Requires-Dist: typing-extensions (>=4.10.0,<5.0.0)
36
- Requires-Dist: zstandard (>=0.22.0,<0.23.0)
37
- Project-URL: Homepage, https://jbloomaus.github.io/SAELens
38
- Project-URL: Repository, https://github.com/jbloomAus/SAELens
33
+ Project-URL: Homepage, https://decoderesearch.github.io/SAELens
34
+ Project-URL: Repository, https://github.com/decoderesearch/SAELens
39
35
  Description-Content-Type: text/markdown
40
36
 
41
- <img width="1308" alt="Screenshot 2024-03-21 at 3 08 28 pm" src="https://github.com/jbloomAus/mats_sae_training/assets/69127271/209012ec-a779-4036-b4be-7b7739ea87f6">
37
+ <img width="1308" height="532" alt="saes_pic" src="https://github.com/user-attachments/assets/2a5d752f-b261-4ee4-ad5d-ebf282321371" />
42
38
 
43
39
  # SAE Lens
44
40
 
45
41
  [![PyPI](https://img.shields.io/pypi/v/sae-lens?color=blue)](https://pypi.org/project/sae-lens/)
46
42
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
47
- [![build](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml)
48
- [![Deploy Docs](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml)
49
- [![codecov](https://codecov.io/gh/jbloomAus/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/jbloomAus/SAELens)
43
+ [![build](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml)
44
+ [![Deploy Docs](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml)
45
+ [![codecov](https://codecov.io/gh/decoderesearch/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/decoderesearch/SAELens)
50
46
 
51
47
  SAELens exists to help researchers:
52
48
 
@@ -54,7 +50,7 @@ SAELens exists to help researchers:
54
50
  - Analyse sparse autoencoders / research mechanistic interpretability.
55
51
  - Generate insights which make it easier to create safe and aligned AI systems.
56
52
 
57
- Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for information on how to:
53
+ Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
58
54
 
59
55
  - Download and Analyse pre-trained sparse autoencoders.
60
56
  - Train your own sparse autoencoders.
@@ -62,25 +58,25 @@ Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for in
62
58
 
63
59
  SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to [safeguard humanity from risks posed by artificial intelligence](https://80000hours.org/problem-profiles/artificial-intelligence/).
64
60
 
65
- This library is maintained by [Joseph Bloom](https://www.jbloomaus.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
61
+ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
66
62
 
67
63
  ## Loading Pre-trained SAEs.
68
64
 
69
- Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://jbloomaus.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.
65
+ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
70
66
 
71
67
  ## Migrating to SAELens v6
72
68
 
73
- The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://jbloomaus.github.io/SAELens/latest/migrating/) for more details.
69
+ The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://decoderesearch.github.io/SAELens/latest/migrating/) for more details.
74
70
 
75
71
  ## Tutorials
76
72
 
77
- - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
73
+ - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
78
74
  - [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
79
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
75
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
80
76
  - [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
81
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
77
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
82
78
  - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
83
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
79
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
84
80
 
85
81
  ## Join the Slack!
86
82
 
@@ -95,7 +91,7 @@ Please cite the package as follows:
95
91
  title = {SAELens},
96
92
  author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
97
93
  year = {2024},
98
- howpublished = {\url{https://github.com/jbloomAus/SAELens}},
94
+ howpublished = {\url{https://github.com/decoderesearch/SAELens}},
99
95
  }
100
96
  ```
101
97
 
@@ -1,12 +1,12 @@
1
- <img width="1308" alt="Screenshot 2024-03-21 at 3 08 28 pm" src="https://github.com/jbloomAus/mats_sae_training/assets/69127271/209012ec-a779-4036-b4be-7b7739ea87f6">
1
+ <img width="1308" height="532" alt="saes_pic" src="https://github.com/user-attachments/assets/2a5d752f-b261-4ee4-ad5d-ebf282321371" />
2
2
 
3
3
  # SAE Lens
4
4
 
5
5
  [![PyPI](https://img.shields.io/pypi/v/sae-lens?color=blue)](https://pypi.org/project/sae-lens/)
6
6
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
7
- [![build](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml)
8
- [![Deploy Docs](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml)
9
- [![codecov](https://codecov.io/gh/jbloomAus/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/jbloomAus/SAELens)
7
+ [![build](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml)
8
+ [![Deploy Docs](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml)
9
+ [![codecov](https://codecov.io/gh/decoderesearch/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/decoderesearch/SAELens)
10
10
 
11
11
  SAELens exists to help researchers:
12
12
 
@@ -14,7 +14,7 @@ SAELens exists to help researchers:
14
14
  - Analyse sparse autoencoders / research mechanistic interpretability.
15
15
  - Generate insights which make it easier to create safe and aligned AI systems.
16
16
 
17
- Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for information on how to:
17
+ Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
18
18
 
19
19
  - Download and Analyse pre-trained sparse autoencoders.
20
20
  - Train your own sparse autoencoders.
@@ -22,25 +22,25 @@ Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for in
22
22
 
23
23
  SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to [safeguard humanity from risks posed by artificial intelligence](https://80000hours.org/problem-profiles/artificial-intelligence/).
24
24
 
25
- This library is maintained by [Joseph Bloom](https://www.jbloomaus.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
25
+ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
26
26
 
27
27
  ## Loading Pre-trained SAEs.
28
28
 
29
- Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://jbloomaus.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.
29
+ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
30
30
 
31
31
  ## Migrating to SAELens v6
32
32
 
33
- The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://jbloomaus.github.io/SAELens/latest/migrating/) for more details.
33
+ The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://decoderesearch.github.io/SAELens/latest/migrating/) for more details.
34
34
 
35
35
  ## Tutorials
36
36
 
37
- - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
37
+ - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
38
38
  - [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
39
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
39
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
40
40
  - [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
41
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
41
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
42
42
  - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
43
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
43
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
44
44
 
45
45
  ## Join the Slack!
46
46
 
@@ -55,6 +55,6 @@ Please cite the package as follows:
55
55
  title = {SAELens},
56
56
  author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
57
57
  year = {2024},
58
- howpublished = {\url{https://github.com/jbloomAus/SAELens}},
58
+ howpublished = {\url{https://github.com/decoderesearch/SAELens}},
59
59
  }
60
60
  ```
@@ -1,13 +1,13 @@
1
1
  [tool.poetry]
2
2
  name = "sae-lens"
3
- version = "6.3.0"
3
+ version = "6.25.1"
4
4
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
5
5
  authors = ["Joseph Bloom"]
6
6
  readme = "README.md"
7
7
  packages = [{ include = "sae_lens" }]
8
8
  include = ["pretrained_saes.yaml"]
9
- repository = "https://github.com/jbloomAus/SAELens"
10
- homepage = "https://jbloomaus.github.io/SAELens"
9
+ repository = "https://github.com/decoderesearch/SAELens"
10
+ homepage = "https://decoderesearch.github.io/SAELens"
11
11
  license = "MIT"
12
12
  keywords = [
13
13
  "deep-learning",
@@ -19,26 +19,20 @@ classifiers = ["Topic :: Scientific/Engineering :: Artificial Intelligence"]
19
19
 
20
20
  [tool.poetry.dependencies]
21
21
  python = "^3.10"
22
- transformer-lens = "^2.0.0"
22
+ transformer-lens = "^2.16.1"
23
23
  transformers = "^4.38.1"
24
- plotly = "^5.19.0"
25
- plotly-express = "^0.4.1"
26
- matplotlib = "^3.8.3"
27
- matplotlib-inline = "^0.1.6"
28
- datasets = "^2.17.1"
24
+ plotly = ">=5.19.0"
25
+ plotly-express = ">=0.4.1"
26
+ datasets = ">=3.1.0"
29
27
  babe = "^0.0.7"
30
28
  nltk = "^3.8.1"
31
- safetensors = "^0.4.2"
32
- typer = "^0.12.3"
29
+ safetensors = ">=0.4.2,<1.0.0"
33
30
  mamba-lens = { version = "^0.0.4", optional = true }
34
- pyzmq = "26.0.0"
35
- automated-interpretability = ">=0.0.5,<1.0.0"
36
- python-dotenv = "^1.0.1"
31
+ python-dotenv = ">=1.0.1"
37
32
  pyyaml = "^6.0.1"
38
- pytest-profiling = "^1.7.0"
39
- zstandard = "^0.22.0"
40
33
  typing-extensions = "^4.10.0"
41
34
  simple-parsing = "^0.1.6"
35
+ tenacity = ">=9.0.0"
42
36
 
43
37
  [tool.poetry.group.dev.dependencies]
44
38
  pytest = "^8.0.2"
@@ -53,6 +47,7 @@ docstr-coverage = "^2.3.2"
53
47
  mkdocs = "^1.6.1"
54
48
  mkdocs-material = "^9.5.34"
55
49
  mkdocs-autorefs = "^1.4.2"
50
+ mkdocs-redirects = "^1.2.1"
56
51
  mkdocs-section-index = "^0.3.9"
57
52
  mkdocstrings = "^0.25.2"
58
53
  mkdocstrings-python = "^1.10.9"
@@ -61,6 +56,7 @@ ruff = "^0.7.4"
61
56
  eai-sparsify = "^1.1.1"
62
57
  mike = "^2.0.0"
63
58
  trio = "^0.30.0"
59
+ dictionary-learning = "^0.1.0"
64
60
 
65
61
  [tool.poetry.extras]
66
62
  mamba = ["mamba-lens"]
@@ -1,5 +1,5 @@
1
1
  # ruff: noqa: E402
2
- __version__ = "6.3.0"
2
+ __version__ = "6.25.1"
3
3
 
4
4
  import logging
5
5
 
@@ -15,19 +15,31 @@ from sae_lens.saes import (
15
15
  GatedTrainingSAEConfig,
16
16
  JumpReLUSAE,
17
17
  JumpReLUSAEConfig,
18
+ JumpReLUSkipTranscoder,
19
+ JumpReLUSkipTranscoderConfig,
18
20
  JumpReLUTrainingSAE,
19
21
  JumpReLUTrainingSAEConfig,
22
+ JumpReLUTranscoder,
23
+ JumpReLUTranscoderConfig,
24
+ MatryoshkaBatchTopKTrainingSAE,
25
+ MatryoshkaBatchTopKTrainingSAEConfig,
20
26
  SAEConfig,
27
+ SkipTranscoder,
28
+ SkipTranscoderConfig,
21
29
  StandardSAE,
22
30
  StandardSAEConfig,
23
31
  StandardTrainingSAE,
24
32
  StandardTrainingSAEConfig,
33
+ TemporalSAE,
34
+ TemporalSAEConfig,
25
35
  TopKSAE,
26
36
  TopKSAEConfig,
27
37
  TopKTrainingSAE,
28
38
  TopKTrainingSAEConfig,
29
39
  TrainingSAE,
30
40
  TrainingSAEConfig,
41
+ Transcoder,
42
+ TranscoderConfig,
31
43
  )
32
44
 
33
45
  from .analysis.hooked_sae_transformer import HookedSAETransformer
@@ -89,6 +101,18 @@ __all__ = [
89
101
  "LoggingConfig",
90
102
  "BatchTopKTrainingSAE",
91
103
  "BatchTopKTrainingSAEConfig",
104
+ "Transcoder",
105
+ "TranscoderConfig",
106
+ "SkipTranscoder",
107
+ "SkipTranscoderConfig",
108
+ "JumpReLUTranscoder",
109
+ "JumpReLUTranscoderConfig",
110
+ "JumpReLUSkipTranscoder",
111
+ "JumpReLUSkipTranscoderConfig",
112
+ "MatryoshkaBatchTopKTrainingSAE",
113
+ "MatryoshkaBatchTopKTrainingSAEConfig",
114
+ "TemporalSAE",
115
+ "TemporalSAEConfig",
92
116
  ]
93
117
 
94
118
 
@@ -103,3 +127,15 @@ register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAE
103
127
  register_sae_training_class(
104
128
  "batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
105
129
  )
130
+ register_sae_training_class(
131
+ "matryoshka_batchtopk",
132
+ MatryoshkaBatchTopKTrainingSAE,
133
+ MatryoshkaBatchTopKTrainingSAEConfig,
134
+ )
135
+ register_sae_class("transcoder", Transcoder, TranscoderConfig)
136
+ register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
137
+ register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
138
+ register_sae_class(
139
+ "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
140
+ )
141
+ register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
@@ -3,15 +3,15 @@ from contextlib import contextmanager
3
3
  from typing import Any, Callable
4
4
 
5
5
  import torch
6
- from jaxtyping import Float
7
6
  from transformer_lens.ActivationCache import ActivationCache
7
+ from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
8
8
  from transformer_lens.hook_points import HookPoint # Hooking utilities
9
9
  from transformer_lens.HookedTransformer import HookedTransformer
10
10
 
11
11
  from sae_lens.saes.sae import SAE
12
12
 
13
- SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor
14
- LossPerToken = Float[torch.Tensor, "batch pos-1"]
13
+ SingleLoss = torch.Tensor # Type alias for a single element tensor
14
+ LossPerToken = torch.Tensor
15
15
  Loss = SingleLoss | LossPerToken
16
16
 
17
17
 
@@ -50,6 +50,13 @@ def set_deep_attr(obj: Any, path: str, value: Any):
50
50
  setattr(obj, parts[-1], value)
51
51
 
52
52
 
53
+ def add_hook_in_to_mlp(mlp: CanBeUsedAsMLP):
54
+ # Temporary hack to add a `mlp.hook_in` hook to mimic what's in circuit-tracer
55
+ mlp.hook_in = HookPoint()
56
+ original_forward = mlp.forward
57
+ mlp.forward = lambda x: original_forward(mlp.hook_in(x)) # type: ignore
58
+
59
+
53
60
  class HookedSAETransformer(HookedTransformer):
54
61
  def __init__(
55
62
  self,
@@ -66,6 +73,11 @@ class HookedSAETransformer(HookedTransformer):
66
73
  **model_kwargs: Keyword arguments for HookedTransformer initialization
67
74
  """
68
75
  super().__init__(*model_args, **model_kwargs)
76
+
77
+ for block in self.blocks:
78
+ add_hook_in_to_mlp(block.mlp) # type: ignore
79
+ self.setup()
80
+
69
81
  self.acts_to_saes: dict[str, SAE] = {} # type: ignore
70
82
 
71
83
  def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
@@ -158,12 +170,7 @@ class HookedSAETransformer(HookedTransformer):
158
170
  reset_saes_end: bool = True,
159
171
  use_error_term: bool | None = None,
160
172
  **model_kwargs: Any,
161
- ) -> (
162
- None
163
- | Float[torch.Tensor, "batch pos d_vocab"]
164
- | Loss
165
- | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]
166
- ):
173
+ ) -> None | torch.Tensor | Loss | tuple[torch.Tensor, Loss]:
167
174
  """Wrapper around HookedTransformer forward pass.
168
175
 
169
176
  Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.
@@ -190,10 +197,7 @@ class HookedSAETransformer(HookedTransformer):
190
197
  remove_batch_dim: bool = False,
191
198
  **kwargs: Any,
192
199
  ) -> tuple[
193
- None
194
- | Float[torch.Tensor, "batch pos d_vocab"]
195
- | Loss
196
- | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
200
+ None | torch.Tensor | Loss | tuple[torch.Tensor, Loss],
197
201
  ActivationCache | dict[str, torch.Tensor],
198
202
  ]:
199
203
  """Wrapper around 'run_with_cache' in HookedTransformer.
@@ -0,0 +1,163 @@
1
+ import json
2
+ import urllib.parse
3
+ import webbrowser
4
+ from typing import Any
5
+
6
+ import requests
7
+ from dotenv import load_dotenv
8
+
9
+ from sae_lens import SAE, logger
10
+
11
+ NEURONPEDIA_DOMAIN = "https://neuronpedia.org"
12
+
13
+ # Constants for replacing NaNs and Infs in outputs
14
+ POSITIVE_INF_REPLACEMENT = 9999
15
+ NEGATIVE_INF_REPLACEMENT = -9999
16
+ NAN_REPLACEMENT = 0
17
+ OTHER_INVALID_REPLACEMENT = -99999
18
+
19
+ # Pick up OPENAI_API_KEY from environment variable
20
+ load_dotenv()
21
+
22
+
23
+ def NanAndInfReplacer(value: str):
24
+ """Replace NaNs and Infs in outputs."""
25
+ replacements = {
26
+ "-Infinity": NEGATIVE_INF_REPLACEMENT,
27
+ "Infinity": POSITIVE_INF_REPLACEMENT,
28
+ "NaN": NAN_REPLACEMENT,
29
+ }
30
+ if value in replacements:
31
+ replaced_value = replacements[value]
32
+ return float(replaced_value)
33
+ return NAN_REPLACEMENT
34
+
35
+
36
+ def open_neuronpedia_feature_dashboard(sae: SAE[Any], index: int):
37
+ sae_id = sae.cfg.metadata.neuronpedia_id
38
+ if sae_id is None:
39
+ logger.warning(
40
+ "SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
41
+ )
42
+ else:
43
+ url = f"{NEURONPEDIA_DOMAIN}/{sae_id}/{index}"
44
+ webbrowser.open(url)
45
+
46
+
47
+ def get_neuronpedia_quick_list(
48
+ sae: SAE[Any],
49
+ features: list[int],
50
+ name: str = "temporary_list",
51
+ ):
52
+ sae_id = sae.cfg.metadata.neuronpedia_id
53
+ if sae_id is None:
54
+ logger.warning(
55
+ "SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
56
+ )
57
+ assert sae_id is not None
58
+
59
+ url = NEURONPEDIA_DOMAIN + "/quick-list/"
60
+ name = urllib.parse.quote(name)
61
+ url = url + "?name=" + name
62
+ list_feature = [
63
+ {
64
+ "modelId": sae.cfg.metadata.model_name,
65
+ "layer": sae_id.split("/")[1],
66
+ "index": str(feature),
67
+ }
68
+ for feature in features
69
+ ]
70
+ url = url + "&features=" + urllib.parse.quote(json.dumps(list_feature))
71
+ webbrowser.open(url)
72
+
73
+ return url
74
+
75
+
76
+ def get_neuronpedia_feature(
77
+ feature: int, layer: int, model: str = "gpt2-small", dataset: str = "res-jb"
78
+ ) -> dict[str, Any]:
79
+ """Fetch a feature from Neuronpedia API."""
80
+ url = f"{NEURONPEDIA_DOMAIN}/api/feature/{model}/{layer}-{dataset}/{feature}"
81
+ result = requests.get(url).json()
82
+ result["index"] = int(result["index"])
83
+ return result
84
+
85
+
86
+ class NeuronpediaActivation:
87
+ """Represents an activation from Neuronpedia."""
88
+
89
+ def __init__(self, id: str, tokens: list[str], act_values: list[float]):
90
+ self.id = id
91
+ self.tokens = tokens
92
+ self.act_values = act_values
93
+
94
+
95
+ class NeuronpediaFeature:
96
+ """Represents a feature from Neuronpedia."""
97
+
98
+ def __init__(
99
+ self,
100
+ modelId: str,
101
+ layer: int,
102
+ dataset: str,
103
+ feature: int,
104
+ description: str = "",
105
+ activations: list[NeuronpediaActivation] | None = None,
106
+ autointerp_explanation: str = "",
107
+ autointerp_explanation_score: float = 0.0,
108
+ ):
109
+ self.modelId = modelId
110
+ self.layer = layer
111
+ self.dataset = dataset
112
+ self.feature = feature
113
+ self.description = description
114
+ self.activations = activations
115
+ self.autointerp_explanation = autointerp_explanation
116
+ self.autointerp_explanation_score = autointerp_explanation_score
117
+
118
+ def has_activating_text(self) -> bool:
119
+ """Check if the feature has activating text."""
120
+ if self.activations is None:
121
+ return False
122
+ return any(max(activation.act_values) > 0 for activation in self.activations)
123
+
124
+
125
+ def make_neuronpedia_list_with_features(
126
+ api_key: str,
127
+ list_name: str,
128
+ features: list[NeuronpediaFeature],
129
+ list_description: str | None = None,
130
+ open_browser: bool = True,
131
+ ):
132
+ url = NEURONPEDIA_DOMAIN + "/api/list/new-with-features"
133
+
134
+ # make POST json request with body
135
+ body = {
136
+ "name": list_name,
137
+ "description": list_description,
138
+ "features": [
139
+ {
140
+ "modelId": feature.modelId,
141
+ "layer": f"{feature.layer}-{feature.dataset}",
142
+ "index": feature.feature,
143
+ "description": feature.description,
144
+ }
145
+ for feature in features
146
+ ],
147
+ }
148
+ response = requests.post(url, json=body, headers={"x-api-key": api_key})
149
+ result = response.json()
150
+
151
+ if "url" in result and open_browser:
152
+ webbrowser.open(result["url"])
153
+ return result["url"]
154
+ raise Exception("Error in creating list: " + result["message"])
155
+
156
+
157
+ def test_key(api_key: str):
158
+ """Test the validity of the Neuronpedia API key."""
159
+ url = f"{NEURONPEDIA_DOMAIN}/api/test"
160
+ body = {"apiKey": api_key}
161
+ response = requests.post(url, json=body)
162
+ if response.status_code != 200:
163
+ raise Exception("Neuronpedia API key is not valid.")
@@ -9,15 +9,14 @@ import torch
9
9
  from datasets import Array2D, Dataset, Features, Sequence, Value
10
10
  from datasets.fingerprint import generate_fingerprint
11
11
  from huggingface_hub import HfApi
12
- from jaxtyping import Float, Int
13
- from tqdm import tqdm
12
+ from tqdm.auto import tqdm
14
13
  from transformer_lens.HookedTransformer import HookedRootModule
15
14
 
16
15
  from sae_lens import logger
17
16
  from sae_lens.config import CacheActivationsRunnerConfig
18
- from sae_lens.constants import DTYPE_MAP
19
17
  from sae_lens.load_model import load_model
20
18
  from sae_lens.training.activations_store import ActivationsStore
19
+ from sae_lens.util import str_to_dtype
21
20
 
22
21
 
23
22
  def _mk_activations_store(
@@ -82,7 +81,7 @@ class CacheActivationsRunner:
82
81
  )
83
82
  for hook_name in [self.cfg.hook_name]
84
83
  }
85
- features_dict["token_ids"] = Sequence(
84
+ features_dict["token_ids"] = Sequence( # type: ignore
86
85
  Value(dtype="int32"), length=self.context_size
87
86
  )
88
87
  self.features = Features(features_dict)
@@ -98,7 +97,7 @@ class CacheActivationsRunner:
98
97
  bytes_per_token = (
99
98
  self.cfg.d_in * self.cfg.dtype.itemsize
100
99
  if isinstance(self.cfg.dtype, torch.dtype)
101
- else DTYPE_MAP[self.cfg.dtype].itemsize
100
+ else str_to_dtype(self.cfg.dtype).itemsize
102
101
  )
103
102
  total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
104
103
  total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9
@@ -318,8 +317,8 @@ class CacheActivationsRunner:
318
317
  def _create_shard(
319
318
  self,
320
319
  buffer: tuple[
321
- Float[torch.Tensor, "(bs context_size) d_in"],
322
- Int[torch.Tensor, "(bs context_size)"] | None,
320
+ torch.Tensor, # shape: (bs context_size) d_in
321
+ torch.Tensor | None, # shape: (bs context_size) or None
323
322
  ],
324
323
  ) -> Dataset:
325
324
  hook_names = [self.cfg.hook_name]