sae-lens 5.7.1.tar.gz → 6.25.1.tar.gz
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
This release of sae-lens has been flagged as potentially problematic.
- {sae_lens-5.7.1 → sae_lens-6.25.1}/PKG-INFO +31 -31
- {sae_lens-5.7.1 → sae_lens-6.25.1}/README.md +18 -14
- {sae_lens-5.7.1 → sae_lens-6.25.1}/pyproject.toml +20 -20
- sae_lens-6.25.1/sae_lens/__init__.py +141 -0
- {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/analysis/hooked_sae_transformer.py +29 -25
- sae_lens-6.25.1/sae_lens/analysis/neuronpedia_integration.py +163 -0
- {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/cache_activations_runner.py +13 -12
- {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/config.py +254 -271
- sae_lens-6.25.1/sae_lens/constants.py +30 -0
- {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/evals.py +146 -87
- sae_lens-6.25.1/sae_lens/llm_sae_training_runner.py +429 -0
- {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/load_model.py +54 -6
- sae_lens-6.25.1/sae_lens/loading/pretrained_sae_loaders.py +1911 -0
- {sae_lens-5.7.1/sae_lens/toolkit → sae_lens-6.25.1/sae_lens/loading}/pretrained_saes_directory.py +17 -3
- {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/pretokenize_runner.py +8 -4
- sae_lens-6.25.1/sae_lens/pretrained_saes.yaml +41813 -0
- sae_lens-6.25.1/sae_lens/registry.py +49 -0
- sae_lens-6.25.1/sae_lens/saes/__init__.py +81 -0
- sae_lens-6.25.1/sae_lens/saes/batchtopk_sae.py +134 -0
- sae_lens-6.25.1/sae_lens/saes/gated_sae.py +242 -0
- sae_lens-6.25.1/sae_lens/saes/jumprelu_sae.py +367 -0
- sae_lens-6.25.1/sae_lens/saes/matryoshka_batchtopk_sae.py +136 -0
- sae_lens-6.25.1/sae_lens/saes/sae.py +1067 -0
- sae_lens-6.25.1/sae_lens/saes/standard_sae.py +165 -0
- sae_lens-6.25.1/sae_lens/saes/temporal_sae.py +365 -0
- sae_lens-6.25.1/sae_lens/saes/topk_sae.py +538 -0
- sae_lens-6.25.1/sae_lens/saes/transcoder.py +411 -0
- {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/tokenization_and_batching.py +25 -2
- sae_lens-6.25.1/sae_lens/training/activation_scaler.py +60 -0
- {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/training/activations_store.py +179 -237
- sae_lens-6.25.1/sae_lens/training/mixing_buffer.py +56 -0
- {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/training/optim.py +36 -34
- sae_lens-6.25.1/sae_lens/training/sae_trainer.py +455 -0
- sae_lens-6.25.1/sae_lens/training/types.py +5 -0
- {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/training/upload_saes_to_huggingface.py +17 -6
- sae_lens-6.25.1/sae_lens/util.py +113 -0
- sae_lens-5.7.1/sae_lens/__init__.py +0 -39
- sae_lens-5.7.1/sae_lens/analysis/neuronpedia_integration.py +0 -492
- sae_lens-5.7.1/sae_lens/pretrained_saes.yaml +0 -13961
- sae_lens-5.7.1/sae_lens/sae.py +0 -737
- sae_lens-5.7.1/sae_lens/sae_training_runner.py +0 -251
- sae_lens-5.7.1/sae_lens/toolkit/pretrained_sae_loaders.py +0 -879
- sae_lens-5.7.1/sae_lens/training/geometric_median.py +0 -101
- sae_lens-5.7.1/sae_lens/training/sae_trainer.py +0 -444
- sae_lens-5.7.1/sae_lens/training/training_sae.py +0 -711
- {sae_lens-5.7.1 → sae_lens-6.25.1}/LICENSE +0 -0
- {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-5.7.1/sae_lens/toolkit → sae_lens-6.25.1/sae_lens/loading}/__init__.py +0 -0
- {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/training/__init__.py +0 -0
- {sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/tutorial/tsea.py +0 -0
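The deepest breaking changes in this release are already visible in the file list: `sae_lens/sae.py` moves to `sae_lens/saes/sae.py`, `sae_lens/toolkit/` becomes `sae_lens/loading/`, and `sae_training_runner.py` is superseded by `llm_sae_training_runner.py`. A minimal sketch of the corresponding import updates, using only paths and names that appear in this diff (top-level `from sae_lens import ...` re-exports are unaffected):

```python
# sae-lens 5.7.1 deep imports (modules removed or moved in 6.x):
#   from sae_lens.sae import SAE
#   from sae_lens.toolkit import pretrained_sae_loaders
#   from sae_lens.sae_training_runner import SAETrainingRunner

# sae-lens 6.25.1 equivalents, per the moves above and the new __init__.py:
from sae_lens.saes.sae import SAE
from sae_lens.loading.pretrained_sae_loaders import (
    PretrainedSaeDiskLoader,
    PretrainedSaeHuggingfaceLoader,
)
from sae_lens.llm_sae_training_runner import (
    LanguageModelSAETrainingRunner,
    SAETrainingRunner,
)
```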
{sae_lens-5.7.1 → sae_lens-6.25.1}/PKG-INFO

````diff
@@ -1,8 +1,9 @@
-Metadata-Version: 2.
+Metadata-Version: 2.4
 Name: sae-lens
-Version:
+Version: 6.25.1
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
+License-File: LICENSE
 Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
 Author: Joseph Bloom
 Requires-Python: >=3.10,<4.0
@@ -12,41 +13,36 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
+Classifier: Programming Language :: Python :: 3.14
 Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Provides-Extra: mamba
-Requires-Dist: automated-interpretability (>=0.0.5,<1.0.0)
 Requires-Dist: babe (>=0.0.7,<0.0.8)
-Requires-Dist: datasets (>=
+Requires-Dist: datasets (>=3.1.0)
 Requires-Dist: mamba-lens (>=0.0.4,<0.0.5) ; extra == "mamba"
-Requires-Dist: matplotlib (>=3.8.3,<4.0.0)
-Requires-Dist: matplotlib-inline (>=0.1.6,<0.2.0)
 Requires-Dist: nltk (>=3.8.1,<4.0.0)
-Requires-Dist: plotly (>=5.19.0
-Requires-Dist: plotly-express (>=0.4.1
-Requires-Dist:
-Requires-Dist: python-dotenv (>=1.0.1,<2.0.0)
+Requires-Dist: plotly (>=5.19.0)
+Requires-Dist: plotly-express (>=0.4.1)
+Requires-Dist: python-dotenv (>=1.0.1)
 Requires-Dist: pyyaml (>=6.0.1,<7.0.0)
-Requires-Dist:
-Requires-Dist: safetensors (>=0.4.2,<0.5.0)
+Requires-Dist: safetensors (>=0.4.2,<1.0.0)
 Requires-Dist: simple-parsing (>=0.1.6,<0.2.0)
-Requires-Dist:
+Requires-Dist: tenacity (>=9.0.0)
+Requires-Dist: transformer-lens (>=2.16.1,<3.0.0)
 Requires-Dist: transformers (>=4.38.1,<5.0.0)
-Requires-Dist: typer (>=0.12.3,<0.13.0)
 Requires-Dist: typing-extensions (>=4.10.0,<5.0.0)
-
-Project-URL:
-Project-URL: Repository, https://github.com/jbloomAus/SAELens
+Project-URL: Homepage, https://decoderesearch.github.io/SAELens
+Project-URL: Repository, https://github.com/decoderesearch/SAELens
 Description-Content-Type: text/markdown
 
-<img width="1308"
+<img width="1308" height="532" alt="saes_pic" src="https://github.com/user-attachments/assets/2a5d752f-b261-4ee4-ad5d-ebf282321371" />
 
 # SAE Lens
 
 [](https://pypi.org/project/sae-lens/)
 [](https://opensource.org/licenses/MIT)
-[](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml)
+[](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml)
+[](https://codecov.io/gh/decoderesearch/SAELens)
 
 SAELens exists to help researchers:
 
@@ -54,7 +50,7 @@ SAELens exists to help researchers:
 - Analyse sparse autoencoders / research mechanistic interpretability.
 - Generate insights which make it easier to create safe and aligned AI systems.
 
-Please refer to the [documentation](https://
+Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
 
 - Download and Analyse pre-trained sparse autoencoders.
 - Train your own sparse autoencoders.
@@ -62,25 +58,29 @@ Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for in
 
 SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to [safeguard humanity from risks posed by artificial intelligence](https://80000hours.org/problem-profiles/artificial-intelligence/).
 
-This library is maintained by [Joseph Bloom](https://www.
+This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
 
 ## Loading Pre-trained SAEs.
 
-Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://
+Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
+
+## Migrating to SAELens v6
+
+The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://decoderesearch.github.io/SAELens/latest/migrating/) for more details.
 
 ## Tutorials
 
-- [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[](https://githubtocolab.com/
+- [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
 - [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
-  [](https://githubtocolab.com/
+  [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
 - [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
-  [](https://githubtocolab.com/
+  [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
 - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
-  [](https://githubtocolab.com/
+  [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
 
 ## Join the Slack!
 
-Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-
+Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
 
 ## Citation
 
@@ -89,9 +89,9 @@ Please cite the package as follows:
 ```
 @misc{bloom2024saetrainingcodebase,
     title = {SAELens},
-    author = {Joseph
+    author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
     year = {2024},
-    howpublished = {\url{https://github.com/
+    howpublished = {\url{https://github.com/decoderesearch/SAELens}},
 }
 ```
 
````
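Both the PKG-INFO long description above and README.md below point users at the pretrained SAEs directory and the new v6 migration guide. For orientation, a minimal loading sketch, assuming the v6 convention that `SAE.from_pretrained` returns the SAE itself (in 5.x it returned a `(sae, cfg_dict, sparsity)` tuple); the release and SAE id are illustrative directory entries, not prescriptions:

```python
from sae_lens import SAE

sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
)

# v6 nests hook information under cfg.metadata (see the
# hooked_sae_transformer.py hunks later in this diff).
print(sae.cfg.metadata.hook_name)
```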
{sae_lens-5.7.1 → sae_lens-6.25.1}/README.md

````diff
@@ -1,12 +1,12 @@
-<img width="1308"
+<img width="1308" height="532" alt="saes_pic" src="https://github.com/user-attachments/assets/2a5d752f-b261-4ee4-ad5d-ebf282321371" />
 
 # SAE Lens
 
 [](https://pypi.org/project/sae-lens/)
 [](https://opensource.org/licenses/MIT)
-[](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml)
+[](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml)
+[](https://codecov.io/gh/decoderesearch/SAELens)
 
 SAELens exists to help researchers:
 
@@ -14,7 +14,7 @@ SAELens exists to help researchers:
 - Analyse sparse autoencoders / research mechanistic interpretability.
 - Generate insights which make it easier to create safe and aligned AI systems.
 
-Please refer to the [documentation](https://
+Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
 
 - Download and Analyse pre-trained sparse autoencoders.
 - Train your own sparse autoencoders.
@@ -22,25 +22,29 @@ Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for in
 
 SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to [safeguard humanity from risks posed by artificial intelligence](https://80000hours.org/problem-profiles/artificial-intelligence/).
 
-This library is maintained by [Joseph Bloom](https://www.
+This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
 
 ## Loading Pre-trained SAEs.
 
-Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://
+Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
+
+## Migrating to SAELens v6
+
+The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://decoderesearch.github.io/SAELens/latest/migrating/) for more details.
 
 ## Tutorials
 
-- [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[](https://githubtocolab.com/
+- [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
 - [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
-  [](https://githubtocolab.com/
+  [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
 - [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
-  [](https://githubtocolab.com/
+  [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
 - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
-  [](https://githubtocolab.com/
+  [](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
 
 ## Join the Slack!
 
-Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-
+Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
 
 ## Citation
 
@@ -49,8 +53,8 @@ Please cite the package as follows:
 ```
 @misc{bloom2024saetrainingcodebase,
     title = {SAELens},
-    author = {Joseph
+    author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
     year = {2024},
-    howpublished = {\url{https://github.com/
+    howpublished = {\url{https://github.com/decoderesearch/SAELens}},
 }
 ```
````
{sae_lens-5.7.1 → sae_lens-6.25.1}/pyproject.toml

````diff
@@ -1,13 +1,13 @@
 [tool.poetry]
 name = "sae-lens"
-version = "
+version = "6.25.1"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
 packages = [{ include = "sae_lens" }]
 include = ["pretrained_saes.yaml"]
-repository = "https://github.com/
-homepage = "https://
+repository = "https://github.com/decoderesearch/SAELens"
+homepage = "https://decoderesearch.github.io/SAELens"
 license = "MIT"
 keywords = [
     "deep-learning",
@@ -19,26 +19,20 @@ classifiers = ["Topic :: Scientific/Engineering :: Artificial Intelligence"]
 
 [tool.poetry.dependencies]
 python = "^3.10"
-transformer-lens = "^2.
+transformer-lens = "^2.16.1"
 transformers = "^4.38.1"
-plotly = "
-plotly-express = "
-matplotlib = "^3.8.3"
-matplotlib-inline = "^0.1.6"
-datasets = "^2.17.1"
+plotly = ">=5.19.0"
+plotly-express = ">=0.4.1"
+datasets = ">=3.1.0"
 babe = "^0.0.7"
 nltk = "^3.8.1"
-safetensors = "
-typer = "^0.12.3"
+safetensors = ">=0.4.2,<1.0.0"
 mamba-lens = { version = "^0.0.4", optional = true }
-
-automated-interpretability = ">=0.0.5,<1.0.0"
-python-dotenv = "^1.0.1"
+python-dotenv = ">=1.0.1"
 pyyaml = "^6.0.1"
-pytest-profiling = "^1.7.0"
-zstandard = "^0.22.0"
 typing-extensions = "^4.10.0"
 simple-parsing = "^0.1.6"
+tenacity = ">=9.0.0"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^8.0.2"
@@ -52,13 +46,17 @@ boto3 = "^1.34.101"
 docstr-coverage = "^2.3.2"
 mkdocs = "^1.6.1"
 mkdocs-material = "^9.5.34"
-mkdocs-autorefs = "^1.
+mkdocs-autorefs = "^1.4.2"
+mkdocs-redirects = "^1.2.1"
 mkdocs-section-index = "^0.3.9"
 mkdocstrings = "^0.25.2"
 mkdocstrings-python = "^1.10.9"
 tabulate = "^0.9.0"
 ruff = "^0.7.4"
-sparsify =
+eai-sparsify = "^1.1.1"
+mike = "^2.0.0"
+trio = "^0.30.0"
+dictionary-learning = "^0.1.0"
 
 [tool.poetry.extras]
 mamba = ["mamba-lens"]
@@ -69,8 +67,9 @@ ignore = ["E203", "E501", "E731", "F722", "E741", "F821", "F403", "ARG002"]
 select = ["UP", "TID", "I", "F", "E", "ARG", "SIM", "RET", "LOG", "T20"]
 
 [tool.ruff.lint.per-file-ignores]
-"benchmark/*" = ["T20"]
+"benchmark/*" = ["T20", "TID251"]
 "scripts/*" = ["T20"]
+"tests/*" = ["TID251"]
 
 [tool.ruff.lint.flake8-tidy-imports.banned-api]
 "typing.Union".msg = "Use `|` instead"
@@ -78,6 +77,7 @@ select = ["UP", "TID", "I", "F", "E", "ARG", "SIM", "RET", "LOG", "T20"]
 "typing.Dict".msg = "Use `dict` instead"
 "typing.Tuple".msg = "Use `tuple` instead"
 "typing.List".msg = "Use `list` instead"
+"tests".msg = "Do not import from tests in the main codebase."
 
 [tool.pyright]
 typeCheckingMode = "strict"
@@ -102,5 +102,5 @@ build-backend = "poetry.core.masonry.api"
 [tool.semantic_release]
 version_variables = ["sae_lens/__init__.py:__version__"]
 version_toml = ["pyproject.toml:tool.poetry.version"]
-branch = "main"
 build_command = "pip install poetry && poetry build"
+branches = { main = { match = "main" }, alpha = { match = "alpha", prerelease = true }, beta = { match = "beta", prerelease = true } }
````
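A recurring pattern in the dependency hunks above is loosened version pins: caret constraints like `python-dotenv = "^1.0.1"` (which Poetry expands to `>=1.0.1,<2.0.0`) become bare lower bounds like `>=1.0.1` that also admit future major versions. A small illustration of the difference, using the third-party `packaging` library (not a dependency added by this diff):

```python
from packaging.specifiers import SpecifierSet

caret = SpecifierSet(">=1.0.1,<2.0.0")  # what Poetry's "^1.0.1" expands to
loosened = SpecifierSet(">=1.0.1")      # the bare bound used in 6.25.1

# A hypothetical future major release is excluded by the caret pin
# but accepted once the upper bound is dropped.
print("2.3.0" in caret)     # False
print("2.3.0" in loosened)  # True
```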
sae_lens-6.25.1/sae_lens/__init__.py (new file)

````diff
@@ -0,0 +1,141 @@
+# ruff: noqa: E402
+__version__ = "6.25.1"
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+from sae_lens.saes import (
+    SAE,
+    BatchTopKTrainingSAE,
+    BatchTopKTrainingSAEConfig,
+    GatedSAE,
+    GatedSAEConfig,
+    GatedTrainingSAE,
+    GatedTrainingSAEConfig,
+    JumpReLUSAE,
+    JumpReLUSAEConfig,
+    JumpReLUSkipTranscoder,
+    JumpReLUSkipTranscoderConfig,
+    JumpReLUTrainingSAE,
+    JumpReLUTrainingSAEConfig,
+    JumpReLUTranscoder,
+    JumpReLUTranscoderConfig,
+    MatryoshkaBatchTopKTrainingSAE,
+    MatryoshkaBatchTopKTrainingSAEConfig,
+    SAEConfig,
+    SkipTranscoder,
+    SkipTranscoderConfig,
+    StandardSAE,
+    StandardSAEConfig,
+    StandardTrainingSAE,
+    StandardTrainingSAEConfig,
+    TemporalSAE,
+    TemporalSAEConfig,
+    TopKSAE,
+    TopKSAEConfig,
+    TopKTrainingSAE,
+    TopKTrainingSAEConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    Transcoder,
+    TranscoderConfig,
+)
+
+from .analysis.hooked_sae_transformer import HookedSAETransformer
+from .cache_activations_runner import CacheActivationsRunner
+from .config import (
+    CacheActivationsRunnerConfig,
+    LanguageModelSAERunnerConfig,
+    LoggingConfig,
+    PretokenizeRunnerConfig,
+)
+from .evals import run_evals
+from .llm_sae_training_runner import LanguageModelSAETrainingRunner, SAETrainingRunner
+from .loading.pretrained_sae_loaders import (
+    PretrainedSaeDiskLoader,
+    PretrainedSaeHuggingfaceLoader,
+)
+from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
+from .registry import register_sae_class, register_sae_training_class
+from .training.activations_store import ActivationsStore
+from .training.upload_saes_to_huggingface import upload_saes_to_huggingface
+
+__all__ = [
+    "SAE",
+    "SAEConfig",
+    "TrainingSAE",
+    "TrainingSAEConfig",
+    "HookedSAETransformer",
+    "ActivationsStore",
+    "LanguageModelSAERunnerConfig",
+    "LanguageModelSAETrainingRunner",
+    "CacheActivationsRunnerConfig",
+    "CacheActivationsRunner",
+    "PretokenizeRunnerConfig",
+    "PretokenizeRunner",
+    "pretokenize_runner",
+    "run_evals",
+    "upload_saes_to_huggingface",
+    "PretrainedSaeHuggingfaceLoader",
+    "PretrainedSaeDiskLoader",
+    "register_sae_class",
+    "register_sae_training_class",
+    "StandardSAE",
+    "StandardSAEConfig",
+    "StandardTrainingSAE",
+    "StandardTrainingSAEConfig",
+    "GatedSAE",
+    "GatedSAEConfig",
+    "GatedTrainingSAE",
+    "GatedTrainingSAEConfig",
+    "TopKSAE",
+    "TopKSAEConfig",
+    "TopKTrainingSAE",
+    "TopKTrainingSAEConfig",
+    "JumpReLUSAE",
+    "JumpReLUSAEConfig",
+    "JumpReLUTrainingSAE",
+    "JumpReLUTrainingSAEConfig",
+    "SAETrainingRunner",
+    "LoggingConfig",
+    "BatchTopKTrainingSAE",
+    "BatchTopKTrainingSAEConfig",
+    "Transcoder",
+    "TranscoderConfig",
+    "SkipTranscoder",
+    "SkipTranscoderConfig",
+    "JumpReLUTranscoder",
+    "JumpReLUTranscoderConfig",
+    "JumpReLUSkipTranscoder",
+    "JumpReLUSkipTranscoderConfig",
+    "MatryoshkaBatchTopKTrainingSAE",
+    "MatryoshkaBatchTopKTrainingSAEConfig",
+    "TemporalSAE",
+    "TemporalSAEConfig",
+]
+
+
+register_sae_class("standard", StandardSAE, StandardSAEConfig)
+register_sae_training_class("standard", StandardTrainingSAE, StandardTrainingSAEConfig)
+register_sae_class("gated", GatedSAE, GatedSAEConfig)
+register_sae_training_class("gated", GatedTrainingSAE, GatedTrainingSAEConfig)
+register_sae_class("topk", TopKSAE, TopKSAEConfig)
+register_sae_training_class("topk", TopKTrainingSAE, TopKTrainingSAEConfig)
+register_sae_class("jumprelu", JumpReLUSAE, JumpReLUSAEConfig)
+register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAEConfig)
+register_sae_training_class(
+    "batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
+)
+register_sae_training_class(
+    "matryoshka_batchtopk",
+    MatryoshkaBatchTopKTrainingSAE,
+    MatryoshkaBatchTopKTrainingSAEConfig,
+)
+register_sae_class("transcoder", Transcoder, TranscoderConfig)
+register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
+register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+register_sae_class(
+    "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
+)
+register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
````
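The new `__init__.py` ends by registering every built-in architecture under a string key via `register_sae_class` / `register_sae_training_class`, and both functions are exported. A hedged sketch of what registering a custom architecture might look like, where `MyCustomSAE` and `MyCustomSAEConfig` are hypothetical names and the abstract encode/decode methods are elided:

```python
from dataclasses import dataclass

from sae_lens import SAE, SAEConfig, register_sae_class


@dataclass
class MyCustomSAEConfig(SAEConfig):  # hypothetical config subclass
    ...


class MyCustomSAE(SAE[MyCustomSAEConfig]):  # hypothetical SAE subclass
    ...  # encode/decode/forward implementations elided


# Mirrors the built-in calls above: configs whose architecture field is
# "my_custom" would resolve to these classes.
register_sae_class("my_custom", MyCustomSAE, MyCustomSAEConfig)
```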
{sae_lens-5.7.1 → sae_lens-6.25.1}/sae_lens/analysis/hooked_sae_transformer.py

````diff
@@ -3,15 +3,15 @@ from contextlib import contextmanager
 from typing import Any, Callable
 
 import torch
-from jaxtyping import Float
 from transformer_lens.ActivationCache import ActivationCache
+from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
 from transformer_lens.hook_points import HookPoint  # Hooking utilities
 from transformer_lens.HookedTransformer import HookedTransformer
 
-from sae_lens.sae import SAE
+from sae_lens.saes.sae import SAE
 
-SingleLoss =
-LossPerToken =
+SingleLoss = torch.Tensor  # Type alias for a single element tensor
+LossPerToken = torch.Tensor
 Loss = SingleLoss | LossPerToken
 
 
@@ -50,6 +50,13 @@ def set_deep_attr(obj: Any, path: str, value: Any):
     setattr(obj, parts[-1], value)
 
 
+def add_hook_in_to_mlp(mlp: CanBeUsedAsMLP):
+    # Temporary hack to add a `mlp.hook_in` hook to mimic what's in circuit-tracer
+    mlp.hook_in = HookPoint()
+    original_forward = mlp.forward
+    mlp.forward = lambda x: original_forward(mlp.hook_in(x))  # type: ignore
+
+
 class HookedSAETransformer(HookedTransformer):
     def __init__(
         self,
@@ -66,9 +73,14 @@ class HookedSAETransformer(HookedTransformer):
             **model_kwargs: Keyword arguments for HookedTransformer initialization
         """
         super().__init__(*model_args, **model_kwargs)
+
+        for block in self.blocks:
+            add_hook_in_to_mlp(block.mlp)  # type: ignore
+        self.setup()
+
         self.acts_to_saes: dict[str, SAE] = {}  # type: ignore
 
-    def add_sae(self, sae: SAE, use_error_term: bool | None = None):
+    def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
         """Attaches an SAE to the model
 
         WARNING: This sae will be permanantly attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.
@@ -77,7 +89,7 @@ class HookedSAETransformer(HookedTransformer):
             sae: SparseAutoencoderBase. The SAE to attach to the model
             use_error_term: (bool | None) If provided, will set the use_error_term attribute of the SAE to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
         """
-        act_name = sae.cfg.hook_name
+        act_name = sae.cfg.metadata.hook_name
         if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
             logging.warning(
                 f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
@@ -92,7 +104,7 @@ class HookedSAETransformer(HookedTransformer):
         set_deep_attr(self, act_name, sae)
         self.setup()
 
-    def _reset_sae(self, act_name: str, prev_sae: SAE | None = None):
+    def _reset_sae(self, act_name: str, prev_sae: SAE[Any] | None = None):
         """Resets an SAE that was attached to the model
 
         By default will remove the SAE from that hook_point.
@@ -124,7 +136,7 @@ class HookedSAETransformer(HookedTransformer):
     def reset_saes(
         self,
         act_names: str | list[str] | None = None,
-        prev_saes: list[SAE | None] | None = None,
+        prev_saes: list[SAE[Any] | None] | None = None,
     ):
         """Reset the SAEs attached to the model
 
@@ -154,16 +166,11 @@ class HookedSAETransformer(HookedTransformer):
     def run_with_saes(
         self,
         *model_args: Any,
-        saes: SAE | list[SAE] = [],
+        saes: SAE[Any] | list[SAE[Any]] = [],
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
         **model_kwargs: Any,
-    ) ->
-        None
-        | Float[torch.Tensor, "batch pos d_vocab"]
-        | Loss
-        | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]
-    ):
+    ) -> None | torch.Tensor | Loss | tuple[torch.Tensor, Loss]:
         """Wrapper around HookedTransformer forward pass.
 
         Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.
@@ -183,17 +190,14 @@ class HookedSAETransformer(HookedTransformer):
     def run_with_cache_with_saes(
         self,
         *model_args: Any,
-        saes: SAE | list[SAE] = [],
+        saes: SAE[Any] | list[SAE[Any]] = [],
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
         return_cache_object: bool = True,
         remove_batch_dim: bool = False,
         **kwargs: Any,
     ) -> tuple[
-        None
-        | Float[torch.Tensor, "batch pos d_vocab"]
-        | Loss
-        | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
+        None | torch.Tensor | Loss | tuple[torch.Tensor, Loss],
         ActivationCache | dict[str, torch.Tensor],
     ]:
         """Wrapper around 'run_with_cache' in HookedTransformer.
@@ -225,7 +229,7 @@ class HookedSAETransformer(HookedTransformer):
     def run_with_hooks_with_saes(
         self,
         *model_args: Any,
-        saes: SAE | list[SAE] = [],
+        saes: SAE[Any] | list[SAE[Any]] = [],
         reset_saes_end: bool = True,
         fwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
         bwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
@@ -261,7 +265,7 @@ class HookedSAETransformer(HookedTransformer):
     @contextmanager
     def saes(
         self,
-        saes: SAE | list[SAE] = [],
+        saes: SAE[Any] | list[SAE[Any]] = [],
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
     ):
@@ -275,7 +279,7 @@ class HookedSAETransformer(HookedTransformer):
         .. code-block:: python
 
             from transformer_lens import HookedSAETransformer
-            from sae_lens.sae import SAE
+            from sae_lens.saes.sae import SAE
 
             model = HookedSAETransformer.from_pretrained('gpt2-small')
             sae_cfg = SAEConfig(...)
@@ -295,8 +299,8 @@ class HookedSAETransformer(HookedTransformer):
             saes = [saes]
         try:
             for sae in saes:
-                act_names_to_reset.append(sae.cfg.hook_name)
-                prev_sae = self.acts_to_saes.get(sae.cfg.hook_name, None)
+                act_names_to_reset.append(sae.cfg.metadata.hook_name)
+                prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
                 prev_saes.append(prev_sae)
                 self.add_sae(sae, use_error_term=use_error_term)
             yield self
````
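Taken together, these hunks change two things a downstream user sees: `SAE` is now generic (`SAE[Any]`), and the attach point moved from `sae.cfg.hook_name` to `sae.cfg.metadata.hook_name`. A minimal end-to-end sketch under those assumptions; the model name, release, and prompt are illustrative:

```python
from sae_lens import SAE, HookedSAETransformer

model = HookedSAETransformer.from_pretrained("gpt2-small")
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

# v6: the attach point lives under cfg.metadata (v5 used sae.cfg.hook_name).
print(sae.cfg.metadata.hook_name)

# One forward pass with the SAE spliced in, then it is detached again.
logits = model.run_with_saes("The cat sat on the mat", saes=[sae])

# Or keep it attached across several calls and reset on exit.
with model.saes(saes=[sae]):
    logits = model("The cat sat on the mat")
```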