sae-lens 6.3.0__tar.gz → 6.25.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sae-lens might be problematic. Click here for more details.
- {sae_lens-6.3.0 → sae_lens-6.25.1}/PKG-INFO +26 -30
- {sae_lens-6.3.0 → sae_lens-6.25.1}/README.md +13 -13
- {sae_lens-6.3.0 → sae_lens-6.25.1}/pyproject.toml +12 -16
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/__init__.py +37 -1
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/analysis/hooked_sae_transformer.py +17 -13
- sae_lens-6.25.1/sae_lens/analysis/neuronpedia_integration.py +163 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/cache_activations_runner.py +6 -7
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/config.py +64 -10
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/constants.py +9 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/evals.py +52 -29
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/llm_sae_training_runner.py +70 -24
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/load_model.py +1 -1
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/loading/pretrained_sae_loaders.py +851 -56
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/loading/pretrained_saes_directory.py +5 -3
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/pretokenize_runner.py +5 -4
- sae_lens-6.25.1/sae_lens/pretrained_saes.yaml +41813 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/saes/__init__.py +27 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/saes/batchtopk_sae.py +34 -2
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/saes/gated_sae.py +6 -11
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/saes/jumprelu_sae.py +72 -17
- sae_lens-6.25.1/sae_lens/saes/matryoshka_batchtopk_sae.py +136 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/saes/sae.py +81 -54
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/saes/standard_sae.py +4 -9
- sae_lens-6.25.1/sae_lens/saes/temporal_sae.py +365 -0
- sae_lens-6.25.1/sae_lens/saes/topk_sae.py +538 -0
- sae_lens-6.25.1/sae_lens/saes/transcoder.py +411 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/tokenization_and_batching.py +21 -6
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/activation_scaler.py +7 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/activations_store.py +62 -41
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/optim.py +11 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/sae_trainer.py +77 -48
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/types.py +1 -1
- sae_lens-6.25.1/sae_lens/util.py +113 -0
- sae_lens-6.3.0/sae_lens/analysis/neuronpedia_integration.py +0 -494
- sae_lens-6.3.0/sae_lens/pretrained_saes.yaml +0 -13976
- sae_lens-6.3.0/sae_lens/saes/topk_sae.py +0 -271
- sae_lens-6.3.0/sae_lens/util.py +0 -47
- {sae_lens-6.3.0 → sae_lens-6.25.1}/LICENSE +0 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/registry.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.3.0 → sae_lens-6.25.1}/sae_lens/tutorial/tsea.py +0 -0
|
@@ -1,8 +1,9 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: sae-lens
|
|
3
|
-
Version: 6.
|
|
3
|
+
Version: 6.25.1
|
|
4
4
|
Summary: Training and Analyzing Sparse Autoencoders (SAEs)
|
|
5
5
|
License: MIT
|
|
6
|
+
License-File: LICENSE
|
|
6
7
|
Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
|
|
7
8
|
Author: Joseph Bloom
|
|
8
9
|
Requires-Python: >=3.10,<4.0
|
|
@@ -12,41 +13,36 @@ Classifier: Programming Language :: Python :: 3.10
|
|
|
12
13
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
14
|
Classifier: Programming Language :: Python :: 3.12
|
|
14
15
|
Classifier: Programming Language :: Python :: 3.13
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
15
17
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
16
18
|
Provides-Extra: mamba
|
|
17
|
-
Requires-Dist: automated-interpretability (>=0.0.5,<1.0.0)
|
|
18
19
|
Requires-Dist: babe (>=0.0.7,<0.0.8)
|
|
19
|
-
Requires-Dist: datasets (>=
|
|
20
|
+
Requires-Dist: datasets (>=3.1.0)
|
|
20
21
|
Requires-Dist: mamba-lens (>=0.0.4,<0.0.5) ; extra == "mamba"
|
|
21
|
-
Requires-Dist: matplotlib (>=3.8.3,<4.0.0)
|
|
22
|
-
Requires-Dist: matplotlib-inline (>=0.1.6,<0.2.0)
|
|
23
22
|
Requires-Dist: nltk (>=3.8.1,<4.0.0)
|
|
24
|
-
Requires-Dist: plotly (>=5.19.0
|
|
25
|
-
Requires-Dist: plotly-express (>=0.4.1
|
|
26
|
-
Requires-Dist:
|
|
27
|
-
Requires-Dist: python-dotenv (>=1.0.1,<2.0.0)
|
|
23
|
+
Requires-Dist: plotly (>=5.19.0)
|
|
24
|
+
Requires-Dist: plotly-express (>=0.4.1)
|
|
25
|
+
Requires-Dist: python-dotenv (>=1.0.1)
|
|
28
26
|
Requires-Dist: pyyaml (>=6.0.1,<7.0.0)
|
|
29
|
-
Requires-Dist:
|
|
30
|
-
Requires-Dist: safetensors (>=0.4.2,<0.5.0)
|
|
27
|
+
Requires-Dist: safetensors (>=0.4.2,<1.0.0)
|
|
31
28
|
Requires-Dist: simple-parsing (>=0.1.6,<0.2.0)
|
|
32
|
-
Requires-Dist:
|
|
29
|
+
Requires-Dist: tenacity (>=9.0.0)
|
|
30
|
+
Requires-Dist: transformer-lens (>=2.16.1,<3.0.0)
|
|
33
31
|
Requires-Dist: transformers (>=4.38.1,<5.0.0)
|
|
34
|
-
Requires-Dist: typer (>=0.12.3,<0.13.0)
|
|
35
32
|
Requires-Dist: typing-extensions (>=4.10.0,<5.0.0)
|
|
36
|
-
|
|
37
|
-
Project-URL:
|
|
38
|
-
Project-URL: Repository, https://github.com/jbloomAus/SAELens
|
|
33
|
+
Project-URL: Homepage, https://decoderesearch.github.io/SAELens
|
|
34
|
+
Project-URL: Repository, https://github.com/decoderesearch/SAELens
|
|
39
35
|
Description-Content-Type: text/markdown
|
|
40
36
|
|
|
41
|
-
<img width="1308"
|
|
37
|
+
<img width="1308" height="532" alt="saes_pic" src="https://github.com/user-attachments/assets/2a5d752f-b261-4ee4-ad5d-ebf282321371" />
|
|
42
38
|
|
|
43
39
|
# SAE Lens
|
|
44
40
|
|
|
45
41
|
[](https://pypi.org/project/sae-lens/)
|
|
46
42
|
[](https://opensource.org/licenses/MIT)
|
|
47
|
-
[](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml)
|
|
44
|
+
[](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml)
|
|
45
|
+
[](https://codecov.io/gh/decoderesearch/SAELens)
|
|
50
46
|
|
|
51
47
|
SAELens exists to help researchers:
|
|
52
48
|
|
|
@@ -54,7 +50,7 @@ SAELens exists to help researchers:
|
|
|
54
50
|
- Analyse sparse autoencoders / research mechanistic interpretability.
|
|
55
51
|
- Generate insights which make it easier to create safe and aligned AI systems.
|
|
56
52
|
|
|
57
|
-
Please refer to the [documentation](https://
|
|
53
|
+
Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
|
|
58
54
|
|
|
59
55
|
- Download and Analyse pre-trained sparse autoencoders.
|
|
60
56
|
- Train your own sparse autoencoders.
|
|
@@ -62,25 +58,25 @@ Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for in
|
|
|
62
58
|
|
|
63
59
|
SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to [safeguard humanity from risks posed by artificial intelligence](https://80000hours.org/problem-profiles/artificial-intelligence/).
|
|
64
60
|
|
|
65
|
-
This library is maintained by [Joseph Bloom](https://www.
|
|
61
|
+
This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
|
|
66
62
|
|
|
67
63
|
## Loading Pre-trained SAEs.
|
|
68
64
|
|
|
69
|
-
Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://
|
|
65
|
+
Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
|
|
70
66
|
|
|
71
67
|
## Migrating to SAELens v6
|
|
72
68
|
|
|
73
|
-
The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://
|
|
69
|
+
The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://decoderesearch.github.io/SAELens/latest/migrating/) for more details.
|
|
74
70
|
|
|
75
71
|
## Tutorials
|
|
76
72
|
|
|
77
|
-
- [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[](https://githubtocolab.com/
|
|
73
|
+
- [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
|
|
78
74
|
- [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
|
|
79
|
-
[](https://githubtocolab.com/
|
|
75
|
+
[](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
|
|
80
76
|
- [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
|
|
81
|
-
[](https://githubtocolab.com/
|
|
77
|
+
[](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
|
|
82
78
|
- [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
|
|
83
|
-
[](https://githubtocolab.com/
|
|
79
|
+
[](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
|
|
84
80
|
|
|
85
81
|
## Join the Slack!
|
|
86
82
|
|
|
@@ -95,7 +91,7 @@ Please cite the package as follows:
|
|
|
95
91
|
title = {SAELens},
|
|
96
92
|
author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
|
|
97
93
|
year = {2024},
|
|
98
|
-
howpublished = {\url{https://github.com/
|
|
94
|
+
howpublished = {\url{https://github.com/decoderesearch/SAELens}},
|
|
99
95
|
}
|
|
100
96
|
```
|
|
101
97
|
|
|
@@ -1,12 +1,12 @@
|
|
|
1
|
-
<img width="1308"
|
|
1
|
+
<img width="1308" height="532" alt="saes_pic" src="https://github.com/user-attachments/assets/2a5d752f-b261-4ee4-ad5d-ebf282321371" />
|
|
2
2
|
|
|
3
3
|
# SAE Lens
|
|
4
4
|
|
|
5
5
|
[](https://pypi.org/project/sae-lens/)
|
|
6
6
|
[](https://opensource.org/licenses/MIT)
|
|
7
|
-
[](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml)
|
|
8
|
+
[](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml)
|
|
9
|
+
[](https://codecov.io/gh/decoderesearch/SAELens)
|
|
10
10
|
|
|
11
11
|
SAELens exists to help researchers:
|
|
12
12
|
|
|
@@ -14,7 +14,7 @@ SAELens exists to help researchers:
|
|
|
14
14
|
- Analyse sparse autoencoders / research mechanistic interpretability.
|
|
15
15
|
- Generate insights which make it easier to create safe and aligned AI systems.
|
|
16
16
|
|
|
17
|
-
Please refer to the [documentation](https://
|
|
17
|
+
Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:
|
|
18
18
|
|
|
19
19
|
- Download and Analyse pre-trained sparse autoencoders.
|
|
20
20
|
- Train your own sparse autoencoders.
|
|
@@ -22,25 +22,25 @@ Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for in
|
|
|
22
22
|
|
|
23
23
|
SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to [safeguard humanity from risks posed by artificial intelligence](https://80000hours.org/problem-profiles/artificial-intelligence/).
|
|
24
24
|
|
|
25
|
-
This library is maintained by [Joseph Bloom](https://www.
|
|
25
|
+
This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
|
|
26
26
|
|
|
27
27
|
## Loading Pre-trained SAEs.
|
|
28
28
|
|
|
29
|
-
Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://
|
|
29
|
+
Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/pretrained_saes/) for a list of all SAEs.
|
|
30
30
|
|
|
31
31
|
## Migrating to SAELens v6
|
|
32
32
|
|
|
33
|
-
The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://
|
|
33
|
+
The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://decoderesearch.github.io/SAELens/latest/migrating/) for more details.
|
|
34
34
|
|
|
35
35
|
## Tutorials
|
|
36
36
|
|
|
37
|
-
- [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[](https://githubtocolab.com/
|
|
37
|
+
- [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
|
|
38
38
|
- [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
|
|
39
|
-
[](https://githubtocolab.com/
|
|
39
|
+
[](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
|
|
40
40
|
- [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
|
|
41
|
-
[](https://githubtocolab.com/
|
|
41
|
+
[](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
|
|
42
42
|
- [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
|
|
43
|
-
[](https://githubtocolab.com/
|
|
43
|
+
[](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
|
|
44
44
|
|
|
45
45
|
## Join the Slack!
|
|
46
46
|
|
|
@@ -55,6 +55,6 @@ Please cite the package as follows:
|
|
|
55
55
|
title = {SAELens},
|
|
56
56
|
author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
|
|
57
57
|
year = {2024},
|
|
58
|
-
howpublished = {\url{https://github.com/
|
|
58
|
+
howpublished = {\url{https://github.com/decoderesearch/SAELens}},
|
|
59
59
|
}
|
|
60
60
|
```
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "sae-lens"
|
|
3
|
-
version = "6.
|
|
3
|
+
version = "6.25.1"
|
|
4
4
|
description = "Training and Analyzing Sparse Autoencoders (SAEs)"
|
|
5
5
|
authors = ["Joseph Bloom"]
|
|
6
6
|
readme = "README.md"
|
|
7
7
|
packages = [{ include = "sae_lens" }]
|
|
8
8
|
include = ["pretrained_saes.yaml"]
|
|
9
|
-
repository = "https://github.com/
|
|
10
|
-
homepage = "https://
|
|
9
|
+
repository = "https://github.com/decoderesearch/SAELens"
|
|
10
|
+
homepage = "https://decoderesearch.github.io/SAELens"
|
|
11
11
|
license = "MIT"
|
|
12
12
|
keywords = [
|
|
13
13
|
"deep-learning",
|
|
@@ -19,26 +19,20 @@ classifiers = ["Topic :: Scientific/Engineering :: Artificial Intelligence"]
|
|
|
19
19
|
|
|
20
20
|
[tool.poetry.dependencies]
|
|
21
21
|
python = "^3.10"
|
|
22
|
-
transformer-lens = "^2.
|
|
22
|
+
transformer-lens = "^2.16.1"
|
|
23
23
|
transformers = "^4.38.1"
|
|
24
|
-
plotly = "
|
|
25
|
-
plotly-express = "
|
|
26
|
-
|
|
27
|
-
matplotlib-inline = "^0.1.6"
|
|
28
|
-
datasets = "^2.17.1"
|
|
24
|
+
plotly = ">=5.19.0"
|
|
25
|
+
plotly-express = ">=0.4.1"
|
|
26
|
+
datasets = ">=3.1.0"
|
|
29
27
|
babe = "^0.0.7"
|
|
30
28
|
nltk = "^3.8.1"
|
|
31
|
-
safetensors = "
|
|
32
|
-
typer = "^0.12.3"
|
|
29
|
+
safetensors = ">=0.4.2,<1.0.0"
|
|
33
30
|
mamba-lens = { version = "^0.0.4", optional = true }
|
|
34
|
-
|
|
35
|
-
automated-interpretability = ">=0.0.5,<1.0.0"
|
|
36
|
-
python-dotenv = "^1.0.1"
|
|
31
|
+
python-dotenv = ">=1.0.1"
|
|
37
32
|
pyyaml = "^6.0.1"
|
|
38
|
-
pytest-profiling = "^1.7.0"
|
|
39
|
-
zstandard = "^0.22.0"
|
|
40
33
|
typing-extensions = "^4.10.0"
|
|
41
34
|
simple-parsing = "^0.1.6"
|
|
35
|
+
tenacity = ">=9.0.0"
|
|
42
36
|
|
|
43
37
|
[tool.poetry.group.dev.dependencies]
|
|
44
38
|
pytest = "^8.0.2"
|
|
@@ -53,6 +47,7 @@ docstr-coverage = "^2.3.2"
|
|
|
53
47
|
mkdocs = "^1.6.1"
|
|
54
48
|
mkdocs-material = "^9.5.34"
|
|
55
49
|
mkdocs-autorefs = "^1.4.2"
|
|
50
|
+
mkdocs-redirects = "^1.2.1"
|
|
56
51
|
mkdocs-section-index = "^0.3.9"
|
|
57
52
|
mkdocstrings = "^0.25.2"
|
|
58
53
|
mkdocstrings-python = "^1.10.9"
|
|
@@ -61,6 +56,7 @@ ruff = "^0.7.4"
|
|
|
61
56
|
eai-sparsify = "^1.1.1"
|
|
62
57
|
mike = "^2.0.0"
|
|
63
58
|
trio = "^0.30.0"
|
|
59
|
+
dictionary-learning = "^0.1.0"
|
|
64
60
|
|
|
65
61
|
[tool.poetry.extras]
|
|
66
62
|
mamba = ["mamba-lens"]
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
# ruff: noqa: E402
|
|
2
|
-
__version__ = "6.
|
|
2
|
+
__version__ = "6.25.1"
|
|
3
3
|
|
|
4
4
|
import logging
|
|
5
5
|
|
|
@@ -15,19 +15,31 @@ from sae_lens.saes import (
|
|
|
15
15
|
GatedTrainingSAEConfig,
|
|
16
16
|
JumpReLUSAE,
|
|
17
17
|
JumpReLUSAEConfig,
|
|
18
|
+
JumpReLUSkipTranscoder,
|
|
19
|
+
JumpReLUSkipTranscoderConfig,
|
|
18
20
|
JumpReLUTrainingSAE,
|
|
19
21
|
JumpReLUTrainingSAEConfig,
|
|
22
|
+
JumpReLUTranscoder,
|
|
23
|
+
JumpReLUTranscoderConfig,
|
|
24
|
+
MatryoshkaBatchTopKTrainingSAE,
|
|
25
|
+
MatryoshkaBatchTopKTrainingSAEConfig,
|
|
20
26
|
SAEConfig,
|
|
27
|
+
SkipTranscoder,
|
|
28
|
+
SkipTranscoderConfig,
|
|
21
29
|
StandardSAE,
|
|
22
30
|
StandardSAEConfig,
|
|
23
31
|
StandardTrainingSAE,
|
|
24
32
|
StandardTrainingSAEConfig,
|
|
33
|
+
TemporalSAE,
|
|
34
|
+
TemporalSAEConfig,
|
|
25
35
|
TopKSAE,
|
|
26
36
|
TopKSAEConfig,
|
|
27
37
|
TopKTrainingSAE,
|
|
28
38
|
TopKTrainingSAEConfig,
|
|
29
39
|
TrainingSAE,
|
|
30
40
|
TrainingSAEConfig,
|
|
41
|
+
Transcoder,
|
|
42
|
+
TranscoderConfig,
|
|
31
43
|
)
|
|
32
44
|
|
|
33
45
|
from .analysis.hooked_sae_transformer import HookedSAETransformer
|
|
@@ -89,6 +101,18 @@ __all__ = [
|
|
|
89
101
|
"LoggingConfig",
|
|
90
102
|
"BatchTopKTrainingSAE",
|
|
91
103
|
"BatchTopKTrainingSAEConfig",
|
|
104
|
+
"Transcoder",
|
|
105
|
+
"TranscoderConfig",
|
|
106
|
+
"SkipTranscoder",
|
|
107
|
+
"SkipTranscoderConfig",
|
|
108
|
+
"JumpReLUTranscoder",
|
|
109
|
+
"JumpReLUTranscoderConfig",
|
|
110
|
+
"JumpReLUSkipTranscoder",
|
|
111
|
+
"JumpReLUSkipTranscoderConfig",
|
|
112
|
+
"MatryoshkaBatchTopKTrainingSAE",
|
|
113
|
+
"MatryoshkaBatchTopKTrainingSAEConfig",
|
|
114
|
+
"TemporalSAE",
|
|
115
|
+
"TemporalSAEConfig",
|
|
92
116
|
]
|
|
93
117
|
|
|
94
118
|
|
|
@@ -103,3 +127,15 @@ register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAE
|
|
|
103
127
|
register_sae_training_class(
|
|
104
128
|
"batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
|
|
105
129
|
)
|
|
130
|
+
register_sae_training_class(
|
|
131
|
+
"matryoshka_batchtopk",
|
|
132
|
+
MatryoshkaBatchTopKTrainingSAE,
|
|
133
|
+
MatryoshkaBatchTopKTrainingSAEConfig,
|
|
134
|
+
)
|
|
135
|
+
register_sae_class("transcoder", Transcoder, TranscoderConfig)
|
|
136
|
+
register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
|
|
137
|
+
register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
|
|
138
|
+
register_sae_class(
|
|
139
|
+
"jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
|
|
140
|
+
)
|
|
141
|
+
register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
|
|
@@ -3,15 +3,15 @@ from contextlib import contextmanager
|
|
|
3
3
|
from typing import Any, Callable
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
|
-
from jaxtyping import Float
|
|
7
6
|
from transformer_lens.ActivationCache import ActivationCache
|
|
7
|
+
from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
|
|
8
8
|
from transformer_lens.hook_points import HookPoint # Hooking utilities
|
|
9
9
|
from transformer_lens.HookedTransformer import HookedTransformer
|
|
10
10
|
|
|
11
11
|
from sae_lens.saes.sae import SAE
|
|
12
12
|
|
|
13
|
-
SingleLoss =
|
|
14
|
-
LossPerToken =
|
|
13
|
+
SingleLoss = torch.Tensor # Type alias for a single element tensor
|
|
14
|
+
LossPerToken = torch.Tensor
|
|
15
15
|
Loss = SingleLoss | LossPerToken
|
|
16
16
|
|
|
17
17
|
|
|
@@ -50,6 +50,13 @@ def set_deep_attr(obj: Any, path: str, value: Any):
|
|
|
50
50
|
setattr(obj, parts[-1], value)
|
|
51
51
|
|
|
52
52
|
|
|
53
|
+
def add_hook_in_to_mlp(mlp: CanBeUsedAsMLP):
|
|
54
|
+
# Temporary hack to add a `mlp.hook_in` hook to mimic what's in circuit-tracer
|
|
55
|
+
mlp.hook_in = HookPoint()
|
|
56
|
+
original_forward = mlp.forward
|
|
57
|
+
mlp.forward = lambda x: original_forward(mlp.hook_in(x)) # type: ignore
|
|
58
|
+
|
|
59
|
+
|
|
53
60
|
class HookedSAETransformer(HookedTransformer):
|
|
54
61
|
def __init__(
|
|
55
62
|
self,
|
|
@@ -66,6 +73,11 @@ class HookedSAETransformer(HookedTransformer):
|
|
|
66
73
|
**model_kwargs: Keyword arguments for HookedTransformer initialization
|
|
67
74
|
"""
|
|
68
75
|
super().__init__(*model_args, **model_kwargs)
|
|
76
|
+
|
|
77
|
+
for block in self.blocks:
|
|
78
|
+
add_hook_in_to_mlp(block.mlp) # type: ignore
|
|
79
|
+
self.setup()
|
|
80
|
+
|
|
69
81
|
self.acts_to_saes: dict[str, SAE] = {} # type: ignore
|
|
70
82
|
|
|
71
83
|
def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
|
|
@@ -158,12 +170,7 @@ class HookedSAETransformer(HookedTransformer):
|
|
|
158
170
|
reset_saes_end: bool = True,
|
|
159
171
|
use_error_term: bool | None = None,
|
|
160
172
|
**model_kwargs: Any,
|
|
161
|
-
) ->
|
|
162
|
-
None
|
|
163
|
-
| Float[torch.Tensor, "batch pos d_vocab"]
|
|
164
|
-
| Loss
|
|
165
|
-
| tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]
|
|
166
|
-
):
|
|
173
|
+
) -> None | torch.Tensor | Loss | tuple[torch.Tensor, Loss]:
|
|
167
174
|
"""Wrapper around HookedTransformer forward pass.
|
|
168
175
|
|
|
169
176
|
Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.
|
|
@@ -190,10 +197,7 @@ class HookedSAETransformer(HookedTransformer):
|
|
|
190
197
|
remove_batch_dim: bool = False,
|
|
191
198
|
**kwargs: Any,
|
|
192
199
|
) -> tuple[
|
|
193
|
-
None
|
|
194
|
-
| Float[torch.Tensor, "batch pos d_vocab"]
|
|
195
|
-
| Loss
|
|
196
|
-
| tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
|
|
200
|
+
None | torch.Tensor | Loss | tuple[torch.Tensor, Loss],
|
|
197
201
|
ActivationCache | dict[str, torch.Tensor],
|
|
198
202
|
]:
|
|
199
203
|
"""Wrapper around 'run_with_cache' in HookedTransformer.
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import urllib.parse
|
|
3
|
+
import webbrowser
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import requests
|
|
7
|
+
from dotenv import load_dotenv
|
|
8
|
+
|
|
9
|
+
from sae_lens import SAE, logger
|
|
10
|
+
|
|
11
|
+
NEURONPEDIA_DOMAIN = "https://neuronpedia.org"
|
|
12
|
+
|
|
13
|
+
# Constants for replacing NaNs and Infs in outputs
|
|
14
|
+
POSITIVE_INF_REPLACEMENT = 9999
|
|
15
|
+
NEGATIVE_INF_REPLACEMENT = -9999
|
|
16
|
+
NAN_REPLACEMENT = 0
|
|
17
|
+
OTHER_INVALID_REPLACEMENT = -99999
|
|
18
|
+
|
|
19
|
+
# Pick up OPENAI_API_KEY from environment variable
|
|
20
|
+
load_dotenv()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def NanAndInfReplacer(value: str):
|
|
24
|
+
"""Replace NaNs and Infs in outputs."""
|
|
25
|
+
replacements = {
|
|
26
|
+
"-Infinity": NEGATIVE_INF_REPLACEMENT,
|
|
27
|
+
"Infinity": POSITIVE_INF_REPLACEMENT,
|
|
28
|
+
"NaN": NAN_REPLACEMENT,
|
|
29
|
+
}
|
|
30
|
+
if value in replacements:
|
|
31
|
+
replaced_value = replacements[value]
|
|
32
|
+
return float(replaced_value)
|
|
33
|
+
return NAN_REPLACEMENT
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def open_neuronpedia_feature_dashboard(sae: SAE[Any], index: int):
|
|
37
|
+
sae_id = sae.cfg.metadata.neuronpedia_id
|
|
38
|
+
if sae_id is None:
|
|
39
|
+
logger.warning(
|
|
40
|
+
"SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
|
|
41
|
+
)
|
|
42
|
+
else:
|
|
43
|
+
url = f"{NEURONPEDIA_DOMAIN}/{sae_id}/{index}"
|
|
44
|
+
webbrowser.open(url)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_neuronpedia_quick_list(
|
|
48
|
+
sae: SAE[Any],
|
|
49
|
+
features: list[int],
|
|
50
|
+
name: str = "temporary_list",
|
|
51
|
+
):
|
|
52
|
+
sae_id = sae.cfg.metadata.neuronpedia_id
|
|
53
|
+
if sae_id is None:
|
|
54
|
+
logger.warning(
|
|
55
|
+
"SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
|
|
56
|
+
)
|
|
57
|
+
assert sae_id is not None
|
|
58
|
+
|
|
59
|
+
url = NEURONPEDIA_DOMAIN + "/quick-list/"
|
|
60
|
+
name = urllib.parse.quote(name)
|
|
61
|
+
url = url + "?name=" + name
|
|
62
|
+
list_feature = [
|
|
63
|
+
{
|
|
64
|
+
"modelId": sae.cfg.metadata.model_name,
|
|
65
|
+
"layer": sae_id.split("/")[1],
|
|
66
|
+
"index": str(feature),
|
|
67
|
+
}
|
|
68
|
+
for feature in features
|
|
69
|
+
]
|
|
70
|
+
url = url + "&features=" + urllib.parse.quote(json.dumps(list_feature))
|
|
71
|
+
webbrowser.open(url)
|
|
72
|
+
|
|
73
|
+
return url
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def get_neuronpedia_feature(
|
|
77
|
+
feature: int, layer: int, model: str = "gpt2-small", dataset: str = "res-jb"
|
|
78
|
+
) -> dict[str, Any]:
|
|
79
|
+
"""Fetch a feature from Neuronpedia API."""
|
|
80
|
+
url = f"{NEURONPEDIA_DOMAIN}/api/feature/{model}/{layer}-{dataset}/{feature}"
|
|
81
|
+
result = requests.get(url).json()
|
|
82
|
+
result["index"] = int(result["index"])
|
|
83
|
+
return result
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class NeuronpediaActivation:
|
|
87
|
+
"""Represents an activation from Neuronpedia."""
|
|
88
|
+
|
|
89
|
+
def __init__(self, id: str, tokens: list[str], act_values: list[float]):
|
|
90
|
+
self.id = id
|
|
91
|
+
self.tokens = tokens
|
|
92
|
+
self.act_values = act_values
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class NeuronpediaFeature:
|
|
96
|
+
"""Represents a feature from Neuronpedia."""
|
|
97
|
+
|
|
98
|
+
def __init__(
|
|
99
|
+
self,
|
|
100
|
+
modelId: str,
|
|
101
|
+
layer: int,
|
|
102
|
+
dataset: str,
|
|
103
|
+
feature: int,
|
|
104
|
+
description: str = "",
|
|
105
|
+
activations: list[NeuronpediaActivation] | None = None,
|
|
106
|
+
autointerp_explanation: str = "",
|
|
107
|
+
autointerp_explanation_score: float = 0.0,
|
|
108
|
+
):
|
|
109
|
+
self.modelId = modelId
|
|
110
|
+
self.layer = layer
|
|
111
|
+
self.dataset = dataset
|
|
112
|
+
self.feature = feature
|
|
113
|
+
self.description = description
|
|
114
|
+
self.activations = activations
|
|
115
|
+
self.autointerp_explanation = autointerp_explanation
|
|
116
|
+
self.autointerp_explanation_score = autointerp_explanation_score
|
|
117
|
+
|
|
118
|
+
def has_activating_text(self) -> bool:
|
|
119
|
+
"""Check if the feature has activating text."""
|
|
120
|
+
if self.activations is None:
|
|
121
|
+
return False
|
|
122
|
+
return any(max(activation.act_values) > 0 for activation in self.activations)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def make_neuronpedia_list_with_features(
|
|
126
|
+
api_key: str,
|
|
127
|
+
list_name: str,
|
|
128
|
+
features: list[NeuronpediaFeature],
|
|
129
|
+
list_description: str | None = None,
|
|
130
|
+
open_browser: bool = True,
|
|
131
|
+
):
|
|
132
|
+
url = NEURONPEDIA_DOMAIN + "/api/list/new-with-features"
|
|
133
|
+
|
|
134
|
+
# make POST json request with body
|
|
135
|
+
body = {
|
|
136
|
+
"name": list_name,
|
|
137
|
+
"description": list_description,
|
|
138
|
+
"features": [
|
|
139
|
+
{
|
|
140
|
+
"modelId": feature.modelId,
|
|
141
|
+
"layer": f"{feature.layer}-{feature.dataset}",
|
|
142
|
+
"index": feature.feature,
|
|
143
|
+
"description": feature.description,
|
|
144
|
+
}
|
|
145
|
+
for feature in features
|
|
146
|
+
],
|
|
147
|
+
}
|
|
148
|
+
response = requests.post(url, json=body, headers={"x-api-key": api_key})
|
|
149
|
+
result = response.json()
|
|
150
|
+
|
|
151
|
+
if "url" in result and open_browser:
|
|
152
|
+
webbrowser.open(result["url"])
|
|
153
|
+
return result["url"]
|
|
154
|
+
raise Exception("Error in creating list: " + result["message"])
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def test_key(api_key: str):
|
|
158
|
+
"""Test the validity of the Neuronpedia API key."""
|
|
159
|
+
url = f"{NEURONPEDIA_DOMAIN}/api/test"
|
|
160
|
+
body = {"apiKey": api_key}
|
|
161
|
+
response = requests.post(url, json=body)
|
|
162
|
+
if response.status_code != 200:
|
|
163
|
+
raise Exception("Neuronpedia API key is not valid.")
|
|
@@ -9,15 +9,14 @@ import torch
|
|
|
9
9
|
from datasets import Array2D, Dataset, Features, Sequence, Value
|
|
10
10
|
from datasets.fingerprint import generate_fingerprint
|
|
11
11
|
from huggingface_hub import HfApi
|
|
12
|
-
from
|
|
13
|
-
from tqdm import tqdm
|
|
12
|
+
from tqdm.auto import tqdm
|
|
14
13
|
from transformer_lens.HookedTransformer import HookedRootModule
|
|
15
14
|
|
|
16
15
|
from sae_lens import logger
|
|
17
16
|
from sae_lens.config import CacheActivationsRunnerConfig
|
|
18
|
-
from sae_lens.constants import DTYPE_MAP
|
|
19
17
|
from sae_lens.load_model import load_model
|
|
20
18
|
from sae_lens.training.activations_store import ActivationsStore
|
|
19
|
+
from sae_lens.util import str_to_dtype
|
|
21
20
|
|
|
22
21
|
|
|
23
22
|
def _mk_activations_store(
|
|
@@ -82,7 +81,7 @@ class CacheActivationsRunner:
|
|
|
82
81
|
)
|
|
83
82
|
for hook_name in [self.cfg.hook_name]
|
|
84
83
|
}
|
|
85
|
-
features_dict["token_ids"] = Sequence(
|
|
84
|
+
features_dict["token_ids"] = Sequence( # type: ignore
|
|
86
85
|
Value(dtype="int32"), length=self.context_size
|
|
87
86
|
)
|
|
88
87
|
self.features = Features(features_dict)
|
|
@@ -98,7 +97,7 @@ class CacheActivationsRunner:
|
|
|
98
97
|
bytes_per_token = (
|
|
99
98
|
self.cfg.d_in * self.cfg.dtype.itemsize
|
|
100
99
|
if isinstance(self.cfg.dtype, torch.dtype)
|
|
101
|
-
else
|
|
100
|
+
else str_to_dtype(self.cfg.dtype).itemsize
|
|
102
101
|
)
|
|
103
102
|
total_training_tokens = self.cfg.n_seq_in_dataset * self.context_size
|
|
104
103
|
total_disk_space_gb = total_training_tokens * bytes_per_token / 10**9
|
|
@@ -318,8 +317,8 @@ class CacheActivationsRunner:
|
|
|
318
317
|
def _create_shard(
|
|
319
318
|
self,
|
|
320
319
|
buffer: tuple[
|
|
321
|
-
|
|
322
|
-
|
|
320
|
+
torch.Tensor, # shape: (bs context_size) d_in
|
|
321
|
+
torch.Tensor | None, # shape: (bs context_size) or None
|
|
323
322
|
],
|
|
324
323
|
) -> Dataset:
|
|
325
324
|
hook_names = [self.cfg.hook_name]
|