sae-lens 6.17.0__tar.gz → 6.20.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {sae_lens-6.17.0 → sae_lens-6.20.1}/PKG-INFO +16 -16
  2. {sae_lens-6.17.0 → sae_lens-6.20.1}/README.md +13 -13
  3. {sae_lens-6.17.0 → sae_lens-6.20.1}/pyproject.toml +3 -3
  4. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/__init__.py +6 -1
  5. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/config.py +37 -2
  6. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/loading/pretrained_sae_loaders.py +188 -0
  7. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/loading/pretrained_saes_directory.py +5 -3
  8. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/pretrained_saes.yaml +51 -1
  9. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/saes/__init__.py +3 -0
  10. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/saes/sae.py +4 -12
  11. sae_lens-6.20.1/sae_lens/saes/temporal_sae.py +372 -0
  12. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/training/activations_store.py +1 -1
  13. {sae_lens-6.17.0 → sae_lens-6.20.1}/LICENSE +0 -0
  14. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/analysis/__init__.py +0 -0
  15. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  16. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  17. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/cache_activations_runner.py +0 -0
  18. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/constants.py +0 -0
  19. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/evals.py +0 -0
  20. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/llm_sae_training_runner.py +0 -0
  21. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/load_model.py +0 -0
  22. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/loading/__init__.py +0 -0
  23. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/pretokenize_runner.py +0 -0
  24. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/registry.py +0 -0
  25. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/saes/batchtopk_sae.py +0 -0
  26. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/saes/gated_sae.py +0 -0
  27. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/saes/jumprelu_sae.py +0 -0
  28. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
  29. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/saes/standard_sae.py +0 -0
  30. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/saes/topk_sae.py +0 -0
  31. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/saes/transcoder.py +0 -0
  32. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/tokenization_and_batching.py +0 -0
  33. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/training/__init__.py +0 -0
  34. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/training/activation_scaler.py +0 -0
  35. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/training/mixing_buffer.py +0 -0
  36. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/training/optim.py +0 -0
  37. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/training/sae_trainer.py +0 -0
  38. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/training/types.py +0 -0
  39. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  40. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/tutorial/tsea.py +0 -0
  41. {sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/util.py +0 -0
{sae_lens-6.17.0 → sae_lens-6.20.1}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sae-lens
- Version: 6.17.0
+ Version: 6.20.1
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
  License: MIT
  License-File: LICENSE
@@ -30,19 +30,19 @@ Requires-Dist: tenacity (>=9.0.0)
  Requires-Dist: transformer-lens (>=2.16.1,<3.0.0)
  Requires-Dist: transformers (>=4.38.1,<5.0.0)
  Requires-Dist: typing-extensions (>=4.10.0,<5.0.0)
- Project-URL: Homepage, https://jbloomaus.github.io/SAELens
- Project-URL: Repository, https://github.com/jbloomAus/SAELens
+ Project-URL: Homepage, https://decoderesearch.github.io/SAELens
+ Project-URL: Repository, https://github.com/decoderesearch/SAELens
  Description-Content-Type: text/markdown

- <img width="1308" alt="Screenshot 2024-03-21 at 3 08 28 pm" src="https://github.com/jbloomAus/mats_sae_training/assets/69127271/209012ec-a779-4036-b4be-7b7739ea87f6">
+ <img width="1308" height="532" alt="saes_pic" src="https://github.com/user-attachments/assets/2a5d752f-b261-4ee4-ad5d-ebf282321371" />

  # SAE Lens

  [![PyPI](https://img.shields.io/pypi/v/sae-lens?color=blue)](https://pypi.org/project/sae-lens/)
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
- [![build](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml)
- [![Deploy Docs](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml)
- [![codecov](https://codecov.io/gh/jbloomAus/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/jbloomAus/SAELens)
+ [![build](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml)
+ [![Deploy Docs](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml)
+ [![codecov](https://codecov.io/gh/decoderesearch/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/decoderesearch/SAELens)

  SAELens exists to help researchers:

@@ -50,7 +50,7 @@ SAELens exists to help researchers:
  - Analyse sparse autoencoders / research mechanistic interpretability.
  - Generate insights which make it easier to create safe and aligned AI systems.

- Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for information on how to:
+ Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:

  - Download and Analyse pre-trained sparse autoencoders.
  - Train your own sparse autoencoders.
@@ -58,25 +58,25 @@ Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for in

  SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to [safeguard humanity from risks posed by artificial intelligence](https://80000hours.org/problem-profiles/artificial-intelligence/).

- This library is maintained by [Joseph Bloom](https://www.jbloomaus.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
+ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).

  ## Loading Pre-trained SAEs.

- Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://jbloomaus.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.
+ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.

  ## Migrating to SAELens v6

- The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://jbloomaus.github.io/SAELens/latest/migrating/) for more details.
+ The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://decoderesearch.github.io/SAELens/latest/migrating/) for more details.

  ## Tutorials

- - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
+ - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
  - [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
  - [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
  - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)

  ## Join the Slack!

@@ -91,7 +91,7 @@ Please cite the package as follows:
  title = {SAELens},
  author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
  year = {2024},
- howpublished = {\url{https://github.com/jbloomAus/SAELens}},
+ howpublished = {\url{https://github.com/decoderesearch/SAELens}},
  }
  ```
{sae_lens-6.17.0 → sae_lens-6.20.1}/README.md

@@ -1,12 +1,12 @@
- <img width="1308" alt="Screenshot 2024-03-21 at 3 08 28 pm" src="https://github.com/jbloomAus/mats_sae_training/assets/69127271/209012ec-a779-4036-b4be-7b7739ea87f6">
+ <img width="1308" height="532" alt="saes_pic" src="https://github.com/user-attachments/assets/2a5d752f-b261-4ee4-ad5d-ebf282321371" />

  # SAE Lens

  [![PyPI](https://img.shields.io/pypi/v/sae-lens?color=blue)](https://pypi.org/project/sae-lens/)
  [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
- [![build](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/build.yml)
- [![Deploy Docs](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/jbloomAus/SAELens/actions/workflows/deploy_docs.yml)
- [![codecov](https://codecov.io/gh/jbloomAus/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/jbloomAus/SAELens)
+ [![build](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/build.yml)
+ [![Deploy Docs](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml/badge.svg)](https://github.com/decoderesearch/SAELens/actions/workflows/deploy_docs.yml)
+ [![codecov](https://codecov.io/gh/decoderesearch/SAELens/graph/badge.svg?token=N83NGH8CGE)](https://codecov.io/gh/decoderesearch/SAELens)

  SAELens exists to help researchers:

@@ -14,7 +14,7 @@ SAELens exists to help researchers:
  - Analyse sparse autoencoders / research mechanistic interpretability.
  - Generate insights which make it easier to create safe and aligned AI systems.

- Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for information on how to:
+ Please refer to the [documentation](https://decoderesearch.github.io/SAELens/) for information on how to:

  - Download and Analyse pre-trained sparse autoencoders.
  - Train your own sparse autoencoders.
@@ -22,25 +22,25 @@ Please refer to the [documentation](https://jbloomaus.github.io/SAELens/) for in

  SAE Lens is the result of many contributors working collectively to improve humanity's understanding of neural networks, many of whom are motivated by a desire to [safeguard humanity from risks posed by artificial intelligence](https://80000hours.org/problem-profiles/artificial-intelligence/).

- This library is maintained by [Joseph Bloom](https://www.jbloomaus.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).
+ This library is maintained by [Joseph Bloom](https://www.decoderesearch.com/), [Curt Tigges](https://curttigges.com/), [Anthony Duong](https://github.com/anthonyduong9) and [David Chanin](https://github.com/chanind).

  ## Loading Pre-trained SAEs.

- Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://jbloomaus.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.
+ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://decoderesearch.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.

  ## Migrating to SAELens v6

- The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://jbloomaus.github.io/SAELens/latest/migrating/) for more details.
+ The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://decoderesearch.github.io/SAELens/latest/migrating/) for more details.

  ## Tutorials

- - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
+ - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
  - [Loading and Analysing Pre-Trained Sparse Autoencoders](tutorials/basic_loading_and_analysing.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/basic_loading_and_analysing.ipynb)
  - [Understanding SAE Features with the Logit Lens](tutorials/logits_lens_with_features.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
  - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)

  ## Join the Slack!

@@ -55,6 +55,6 @@ Please cite the package as follows:
  title = {SAELens},
  author = {Bloom, Joseph and Tigges, Curt and Duong, Anthony and Chanin, David},
  year = {2024},
- howpublished = {\url{https://github.com/jbloomAus/SAELens}},
+ howpublished = {\url{https://github.com/decoderesearch/SAELens}},
  }
  ```
{sae_lens-6.17.0 → sae_lens-6.20.1}/pyproject.toml

@@ -1,13 +1,13 @@
  [tool.poetry]
  name = "sae-lens"
- version = "6.17.0"
+ version = "6.20.1"
  description = "Training and Analyzing Sparse Autoencoders (SAEs)"
  authors = ["Joseph Bloom"]
  readme = "README.md"
  packages = [{ include = "sae_lens" }]
  include = ["pretrained_saes.yaml"]
- repository = "https://github.com/jbloomAus/SAELens"
- homepage = "https://jbloomaus.github.io/SAELens"
+ repository = "https://github.com/decoderesearch/SAELens"
+ homepage = "https://decoderesearch.github.io/SAELens"
  license = "MIT"
  keywords = [
      "deep-learning",
{sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/__init__.py

@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.17.0"
+ __version__ = "6.20.1"

  import logging

@@ -28,6 +28,8 @@ from sae_lens.saes import (
      StandardSAEConfig,
      StandardTrainingSAE,
      StandardTrainingSAEConfig,
+     TemporalSAE,
+     TemporalSAEConfig,
      TopKSAE,
      TopKSAEConfig,
      TopKTrainingSAE,
@@ -105,6 +107,8 @@ __all__ = [
      "JumpReLUTranscoderConfig",
      "MatryoshkaBatchTopKTrainingSAE",
      "MatryoshkaBatchTopKTrainingSAEConfig",
+     "TemporalSAE",
+     "TemporalSAEConfig",
  ]

@@ -127,3 +131,4 @@ register_sae_training_class(
  register_sae_class("transcoder", Transcoder, TranscoderConfig)
  register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
  register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+ register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
{sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/config.py

@@ -18,6 +18,7 @@ from datasets import (

  from sae_lens import __version__, logger
  from sae_lens.constants import DTYPE_MAP
+ from sae_lens.registry import get_sae_training_class
  from sae_lens.saes.sae import TrainingSAEConfig

  if TYPE_CHECKING:
@@ -387,8 +388,11 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
          return self.sae.to_dict()

      def to_dict(self) -> dict[str, Any]:
-         # Make a shallow copy of config's dictionary
-         d = dict(self.__dict__)
+         """
+         Convert the config to a dictionary.
+         """
+
+         d = asdict(self)

          d["logger"] = asdict(self.logger)
          d["sae"] = self.sae.to_dict()
@@ -398,6 +402,37 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
          d["act_store_device"] = str(self.act_store_device)
          return d

+     @classmethod
+     def from_dict(cls, cfg_dict: dict[str, Any]) -> "LanguageModelSAERunnerConfig[Any]":
+         """
+         Load a LanguageModelSAERunnerConfig from a dictionary given by `to_dict`.
+
+         Args:
+             cfg_dict (dict[str, Any]): The dictionary to load the config from.
+
+         Returns:
+             LanguageModelSAERunnerConfig: The loaded config.
+         """
+         if "sae" not in cfg_dict:
+             raise ValueError("sae field is required in the config dictionary")
+         if "architecture" not in cfg_dict["sae"]:
+             raise ValueError("architecture field is required in the sae dictionary")
+         if "logger" not in cfg_dict:
+             raise ValueError("logger field is required in the config dictionary")
+         sae_config_class = get_sae_training_class(cfg_dict["sae"]["architecture"])[1]
+         sae_cfg = sae_config_class.from_dict(cfg_dict["sae"])
+         logger_cfg = LoggingConfig(**cfg_dict["logger"])
+         updated_cfg_dict: dict[str, Any] = {
+             **cfg_dict,
+             "sae": sae_cfg,
+             "logger": logger_cfg,
+         }
+         output = cls(**updated_cfg_dict)
+         # the post_init always appends to checkpoint path, so we need to set it explicitly here.
+         if "checkpoint_path" in cfg_dict:
+             output.checkpoint_path = cfg_dict["checkpoint_path"]
+         return output
+
      def to_sae_trainer_config(self) -> "SAETrainerConfig":
          return SAETrainerConfig(
              n_checkpoints=self.n_checkpoints,
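
The new `from_dict` is the inverse of `to_dict`: it resolves the SAE training config class through the registry, rebuilds the nested `sae` and `logger` configs, and restores `checkpoint_path` verbatim since `__post_init__` would otherwise append to it. A round-trip sketch, assuming an existing config instance `cfg`:

```python
# Hypothetical round trip through the new serialization pair.
cfg_dict = cfg.to_dict()  # now a deep copy via asdict(), safe to mutate
restored = LanguageModelSAERunnerConfig.from_dict(cfg_dict)

assert restored.checkpoint_path == cfg_dict["checkpoint_path"]
assert restored.sae.to_dict() == cfg_dict["sae"]
```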
{sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/loading/pretrained_sae_loaders.py

@@ -523,6 +523,82 @@ def gemma_2_sae_huggingface_loader(
      return cfg_dict, state_dict, log_sparsity


+ def get_goodfire_config_from_hf(
+     repo_id: str,
+     folder_name: str,  # noqa: ARG001
+     device: str,
+     force_download: bool = False,  # noqa: ARG001
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+     cfg_dict = None
+     if repo_id == "Goodfire/Llama-3.3-70B-Instruct-SAE-l50":
+         if folder_name != "Llama-3.3-70B-Instruct-SAE-l50.pt":
+             raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+         cfg_dict = {
+             "architecture": "standard",
+             "d_in": 8192,
+             "d_sae": 65536,
+             "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+             "hook_name": "blocks.50.hook_resid_post",
+             "hook_head_index": None,
+             "dataset_path": "lmsys/lmsys-chat-1m",
+             "apply_b_dec_to_input": False,
+         }
+     elif repo_id == "Goodfire/Llama-3.1-8B-Instruct-SAE-l19":
+         if folder_name != "Llama-3.1-8B-Instruct-SAE-l19.pth":
+             raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+         cfg_dict = {
+             "architecture": "standard",
+             "d_in": 4096,
+             "d_sae": 65536,
+             "model_name": "meta-llama/Llama-3.1-8B-Instruct",
+             "hook_name": "blocks.19.hook_resid_post",
+             "hook_head_index": None,
+             "dataset_path": "lmsys/lmsys-chat-1m",
+             "apply_b_dec_to_input": False,
+         }
+     if cfg_dict is None:
+         raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+     if device is not None:
+         cfg_dict["device"] = device
+     if cfg_overrides is not None:
+         cfg_dict.update(cfg_overrides)
+     return cfg_dict
+
+
+ def get_goodfire_huggingface_loader(
+     repo_id: str,
+     folder_name: str,
+     device: str = "cpu",
+     force_download: bool = False,
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+     cfg_dict = get_goodfire_config_from_hf(
+         repo_id,
+         folder_name,
+         device,
+         force_download,
+         cfg_overrides,
+     )
+
+     # Download the SAE weights
+     sae_path = hf_hub_download(
+         repo_id=repo_id,
+         filename=folder_name,
+         force_download=force_download,
+     )
+     raw_state_dict = torch.load(sae_path, map_location=device)
+
+     state_dict = {
+         "W_enc": raw_state_dict["encoder_linear.weight"].T,
+         "W_dec": raw_state_dict["decoder_linear.weight"].T,
+         "b_enc": raw_state_dict["encoder_linear.bias"],
+         "b_dec": raw_state_dict["decoder_linear.bias"],
+     }
+
+     return cfg_dict, state_dict, None
+
+
  def get_llama_scope_config_from_hf(
      repo_id: str,
      folder_name: str,
@@ -1475,6 +1551,114 @@ def get_mntss_clt_layer_config_from_hf(
      }


+ def get_temporal_sae_config_from_hf(
+     repo_id: str,
+     folder_name: str,
+     device: str,
+     force_download: bool = False,
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+     """Get TemporalSAE config without loading weights."""
+     # Download config file
+     conf_path = hf_hub_download(
+         repo_id=repo_id,
+         filename=f"{folder_name}/conf.yaml",
+         force_download=force_download,
+     )
+
+     # Load and parse config
+     with open(conf_path) as f:
+         yaml_config = yaml.safe_load(f)
+
+     # Extract parameters
+     d_in = yaml_config["llm"]["dimin"]
+     exp_factor = yaml_config["sae"]["exp_factor"]
+     d_sae = int(d_in * exp_factor)
+
+     # extract layer from folder_name eg : "layer_12/temporal"
+     layer = re.search(r"layer_(\d+)", folder_name)
+     if layer is None:
+         raise ValueError(f"Could not find layer in folder_name: {folder_name}")
+     layer = int(layer.group(1))
+
+     # Build config dict
+     cfg_dict = {
+         "architecture": "temporal",
+         "hook_name": f"blocks.{layer}.hook_resid_post",
+         "d_in": d_in,
+         "d_sae": d_sae,
+         "n_heads": yaml_config["sae"]["n_heads"],
+         "n_attn_layers": yaml_config["sae"]["n_attn_layers"],
+         "bottleneck_factor": yaml_config["sae"]["bottleneck_factor"],
+         "sae_diff_type": yaml_config["sae"]["sae_diff_type"],
+         "kval_topk": yaml_config["sae"]["kval_topk"],
+         "tied_weights": yaml_config["sae"]["tied_weights"],
+         "dtype": yaml_config["data"]["dtype"],
+         "device": device,
+         "normalize_activations": "constant_scalar_rescale",
+         "activation_normalization_factor": yaml_config["sae"]["scaling_factor"],
+         "apply_b_dec_to_input": True,
+     }
+
+     if cfg_overrides:
+         cfg_dict.update(cfg_overrides)
+
+     return cfg_dict
+
+
+ def temporal_sae_huggingface_loader(
+     repo_id: str,
+     folder_name: str,
+     device: str = "cpu",
+     force_download: bool = False,
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+     """
+     Load TemporalSAE from canrager/temporalSAEs format (safetensors version).
+
+     Expects folder_name to contain:
+     - conf.yaml (configuration)
+     - latest_ckpt.safetensors (model weights)
+     """
+
+     cfg_dict = get_temporal_sae_config_from_hf(
+         repo_id=repo_id,
+         folder_name=folder_name,
+         device=device,
+         force_download=force_download,
+         cfg_overrides=cfg_overrides,
+     )
+
+     # Download checkpoint (safetensors format)
+     ckpt_path = hf_hub_download(
+         repo_id=repo_id,
+         filename=f"{folder_name}/latest_ckpt.safetensors",
+         force_download=force_download,
+     )
+
+     # Load checkpoint from safetensors
+     state_dict_raw = load_file(ckpt_path, device=device)
+
+     # Convert to SAELens naming convention
+     # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
+     state_dict = {}
+
+     # Copy attention layers as-is
+     for key, value in state_dict_raw.items():
+         if key.startswith("attn_layers."):
+             state_dict[key] = value.to(device)
+
+     # Main parameters
+     state_dict["W_dec"] = state_dict_raw["D"].to(device)
+     state_dict["b_dec"] = state_dict_raw["b"].to(device)
+
+     # Handle tied/untied weights
+     if "E" in state_dict_raw:
+         state_dict["W_enc"] = state_dict_raw["E"].to(device)
+
+     return cfg_dict, state_dict, None
+
+
  NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
      "sae_lens": sae_lens_huggingface_loader,
      "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1487,6 +1671,8 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
      "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
      "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
      "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
+     "temporal": temporal_sae_huggingface_loader,
+     "goodfire": get_goodfire_huggingface_loader,
  }

@@ -1502,4 +1688,6 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
      "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
      "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
      "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
+     "temporal": get_temporal_sae_config_from_hf,
+     "goodfire": get_goodfire_config_from_hf,
  }
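
Both new entries follow the existing loader contract: a loader takes `(repo_id, folder_name, device, force_download, cfg_overrides)` and returns a `(cfg_dict, state_dict, log_sparsity)` tuple. A sketch of invoking the Goodfire loader directly (this downloads the checkpoint from the Hugging Face Hub):

```python
from sae_lens.loading.pretrained_sae_loaders import get_goodfire_huggingface_loader

cfg_dict, state_dict, log_sparsity = get_goodfire_huggingface_loader(
    repo_id="Goodfire/Llama-3.1-8B-Instruct-SAE-l19",
    folder_name="Llama-3.1-8B-Instruct-SAE-l19.pth",
    device="cpu",
)
assert cfg_dict["architecture"] == "standard"
assert state_dict["W_enc"].shape == (4096, 65536)  # (d_in, d_sae) after the transpose above
assert log_sparsity is None  # Goodfire checkpoints ship no sparsity tensor
```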
{sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/loading/pretrained_saes_directory.py

@@ -1,6 +1,6 @@
  from dataclasses import dataclass
  from functools import cache
- from importlib import resources
+ from importlib.resources import files
  from typing import Any

  import yaml
@@ -24,7 +24,8 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
      package = "sae_lens"
      # Access the file within the package using importlib.resources
      directory: dict[str, PretrainedSAELookup] = {}
-     with resources.open_text(package, "pretrained_saes.yaml") as file:
+     yaml_file = files(package).joinpath("pretrained_saes.yaml")
+     with yaml_file.open("r") as file:
          # Load the YAML file content
          data = yaml.safe_load(file)
          for release, value in data.items():
@@ -68,7 +69,8 @@ def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
          float | None: The norm_scaling_factor if it exists, None otherwise.
      """
      package = "sae_lens"
-     with resources.open_text(package, "pretrained_saes.yaml") as file:
+     yaml_file = files(package).joinpath("pretrained_saes.yaml")
+     with yaml_file.open("r") as file:
          data = yaml.safe_load(file)
          if release in data:
              for sae_info in data[release]["saes"]:
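
`importlib.resources.open_text` has been deprecated since Python 3.11; the `files()` API used above is the supported replacement. Usage of the directory itself is unchanged; a sketch (assuming the `repo_id` field on the `PretrainedSAELookup` dataclass defined in this module) using a release added to `pretrained_saes.yaml` below:

```python
from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory

directory = get_pretrained_saes_directory()
lookup = directory["temporal-sae-gemma-2-2b"]  # registered in the yaml diff below
print(lookup.repo_id)  # -> "canrager/temporalSAEs"
```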
{sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/pretrained_saes.yaml

@@ -1,3 +1,35 @@
+ temporal-sae-gemma-2-2b:
+   conversion_func: temporal
+   model: gemma-2-2b
+   repo_id: canrager/temporalSAEs
+   config_overrides:
+     model_name: gemma-2-2b
+     hook_name: blocks.12.hook_resid_post
+     dataset_path: monology/pile-uncopyrighted
+   saes:
+     - id: blocks.12.hook_resid_post
+       l0: 192
+       norm_scaling_factor: 0.00666666667
+       path: gemma-2-2B/layer_12/temporal
+       neuronpedia: gemma-2-2b/12-temporal-res
+ temporal-sae-llama-3.1-8b:
+   conversion_func: temporal
+   model: meta-llama/Llama-3.1-8B
+   repo_id: canrager/temporalSAEs
+   config_overrides:
+     model_name: meta-llama/Llama-3.1-8B
+     dataset_path: monology/pile-uncopyrighted
+   saes:
+     - id: blocks.15.hook_resid_post
+       l0: 256
+       norm_scaling_factor: 0.029
+       path: llama-3.1-8B/layer_15/temporal
+       neuronpedia: llama3.1-8b/15-temporal-res
+     - id: blocks.26.hook_resid_post
+       l0: 256
+       norm_scaling_factor: 0.029
+       path: llama-3.1-8B/layer_26/temporal
+       neuronpedia: llama3.1-8b/26-temporal-res
  deepseek-r1-distill-llama-8b-qresearch:
    conversion_func: deepseek_r1
    model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
@@ -14882,4 +14914,22 @@ qwen2.5-7b-instruct-andyrdt:
        neuronpedia: qwen2.5-7b-it/23-resid-post-aa
      - id: resid_post_layer_27_trainer_1
        path: resid_post_layer_27/trainer_1
-       neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+       neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+
+ goodfire-llama-3.3-70b-instruct:
+   conversion_func: goodfire
+   model: meta-llama/Llama-3.3-70B-Instruct
+   repo_id: Goodfire/Llama-3.3-70B-Instruct-SAE-l50
+   saes:
+     - id: layer_50
+       path: Llama-3.3-70B-Instruct-SAE-l50.pt
+       l0: 121
+
+ goodfire-llama-3.1-8b-instruct:
+   conversion_func: goodfire
+   model: meta-llama/Llama-3.1-8B-Instruct
+   repo_id: Goodfire/Llama-3.1-8B-Instruct-SAE-l19
+   saes:
+     - id: layer_19
+       path: Llama-3.1-8B-Instruct-SAE-l19.pth
+       l0: 91
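
With these yaml entries in place, the new releases load through the standard entry point. A sketch, assuming the SAELens v6 `SAE.from_pretrained` signature (which returns the SAE directly):

```python
from sae_lens import SAE

sae = SAE.from_pretrained(
    release="temporal-sae-gemma-2-2b",
    sae_id="blocks.12.hook_resid_post",
    device="cpu",
)
print(sae.cfg.architecture())  # -> "temporal"
```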
{sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/saes/__init__.py

@@ -25,6 +25,7 @@ from .standard_sae import (
      StandardTrainingSAE,
      StandardTrainingSAEConfig,
  )
+ from .temporal_sae import TemporalSAE, TemporalSAEConfig
  from .topk_sae import (
      TopKSAE,
      TopKSAEConfig,
@@ -71,4 +72,6 @@ __all__ = [
      "JumpReLUTranscoderConfig",
      "MatryoshkaBatchTopKTrainingSAE",
      "MatryoshkaBatchTopKTrainingSAEConfig",
+     "TemporalSAE",
+     "TemporalSAEConfig",
  ]
{sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/saes/sae.py

@@ -155,9 +155,9 @@ class SAEConfig(ABC):
      dtype: str = "float32"
      device: str = "cpu"
      apply_b_dec_to_input: bool = True
-     normalize_activations: Literal[
-         "none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"
-     ] = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
+     normalize_activations: Literal["none", "expected_average_only_in", "layer_norm"] = (
+         "none"  # none, expected_average_only_in (Anthropic April Update)
+     )
      reshape_activations: Literal["none", "hook_z"] = "none"
      metadata: SAEMetadata = field(default_factory=SAEMetadata)

@@ -309,6 +309,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

              self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
              self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+
          elif self.cfg.normalize_activations == "layer_norm":
              # we need to scale the norm of the input and store the scaling factor
              def run_time_activation_ln_in(
@@ -452,23 +453,14 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
      def process_sae_in(
          self, sae_in: Float[torch.Tensor, "... d_in"]
      ) -> Float[torch.Tensor, "... d_in"]:
-         # print(f"Input shape to process_sae_in: {sae_in.shape}")
-         # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
-         # print(f"self.b_dec shape: {self.b_dec.shape}")
-         # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")
-
          sae_in = sae_in.to(self.dtype)
-
-         # print(f"Shape before reshape_fn_in: {sae_in.shape}")
          sae_in = self.reshape_fn_in(sae_in)
-         # print(f"Shape after reshape_fn_in: {sae_in.shape}")

          sae_in = self.hook_sae_input(sae_in)
          sae_in = self.run_time_activation_norm_fn_in(sae_in)

          # Here's where the error happens
          bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
-         # print(f"Bias term shape: {bias_term.shape}")

          return sae_in - bias_term
sae_lens-6.20.1/sae_lens/saes/temporal_sae.py (new file)

@@ -0,0 +1,372 @@
+ """TemporalSAE: A Sparse Autoencoder with temporal attention mechanism.
+
+ TemporalSAE decomposes activations into:
+ 1. Predicted codes (from attention over context)
+ 2. Novel codes (sparse features of the residual)
+
+ See: https://arxiv.org/abs/2410.04185
+ """
+
+ import math
+ from dataclasses import dataclass
+ from typing import Literal
+
+ import torch
+ import torch.nn.functional as F
+ from jaxtyping import Float
+ from torch import nn
+ from typing_extensions import override
+
+ from sae_lens import logger
+ from sae_lens.saes.sae import SAE, SAEConfig
+
+
+ def get_attention(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
+     """Compute causal attention weights."""
+     L, S = query.size(-2), key.size(-2)
+     scale_factor = 1 / math.sqrt(query.size(-1))
+     attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+     temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+     attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+     attn_bias.to(query.dtype)
+
+     attn_weight = query @ key.transpose(-2, -1) * scale_factor
+     attn_weight += attn_bias
+     return torch.softmax(attn_weight, dim=-1)
+
+
+ class ManualAttention(nn.Module):
+     """Manual attention implementation for TemporalSAE."""
+
+     def __init__(
+         self,
+         dimin: int,
+         n_heads: int = 4,
+         bottleneck_factor: int = 64,
+         bias_k: bool = True,
+         bias_q: bool = True,
+         bias_v: bool = True,
+         bias_o: bool = True,
+     ):
+         super().__init__()
+         assert dimin % (bottleneck_factor * n_heads) == 0
+
+         self.n_heads = n_heads
+         self.n_embds = dimin // bottleneck_factor
+         self.dimin = dimin
+
+         # Key, query, value projections
+         self.k_ctx = nn.Linear(dimin, self.n_embds, bias=bias_k)
+         self.q_target = nn.Linear(dimin, self.n_embds, bias=bias_q)
+         self.v_ctx = nn.Linear(dimin, dimin, bias=bias_v)
+         self.c_proj = nn.Linear(dimin, dimin, bias=bias_o)
+
+         # Normalize to match scale with representations
+         with torch.no_grad():
+             scaling = 1 / math.sqrt(self.n_embds // self.n_heads)
+             self.k_ctx.weight.copy_(
+                 scaling
+                 * self.k_ctx.weight
+                 / (1e-6 + torch.linalg.norm(self.k_ctx.weight, dim=1, keepdim=True))
+             )
+             self.q_target.weight.copy_(
+                 scaling
+                 * self.q_target.weight
+                 / (1e-6 + torch.linalg.norm(self.q_target.weight, dim=1, keepdim=True))
+             )
+
+             scaling = 1 / math.sqrt(self.dimin // self.n_heads)
+             self.v_ctx.weight.copy_(
+                 scaling
+                 * self.v_ctx.weight
+                 / (1e-6 + torch.linalg.norm(self.v_ctx.weight, dim=1, keepdim=True))
+             )
+
+             scaling = 1 / math.sqrt(self.dimin)
+             self.c_proj.weight.copy_(
+                 scaling
+                 * self.c_proj.weight
+                 / (1e-6 + torch.linalg.norm(self.c_proj.weight, dim=1, keepdim=True))
+             )
+
+     def forward(
+         self, x_ctx: torch.Tensor, x_target: torch.Tensor, get_attn_map: bool = False
+     ) -> tuple[torch.Tensor, torch.Tensor | None]:
+         """Compute projective attention output."""
+         k = self.k_ctx(x_ctx)
+         v = self.v_ctx(x_ctx)
+         q = self.q_target(x_target)
+
+         # Split into heads
+         B, T, _ = x_ctx.size()
+         k = k.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
+         q = q.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
+         v = v.view(B, T, self.n_heads, self.dimin // self.n_heads).transpose(1, 2)
+
+         # Attention map (optional)
+         attn_map = None
+         if get_attn_map:
+             attn_map = get_attention(query=q, key=k)
+
+         # Scaled dot-product attention
+         attn_output = torch.nn.functional.scaled_dot_product_attention(
+             q, k, v, attn_mask=None, dropout_p=0, is_causal=True
+         )
+
+         # Reshape and project
+         d_target = self.c_proj(
+             attn_output.transpose(1, 2).contiguous().view(B, T, self.dimin)
+         )
+
+         return d_target, attn_map
+
+
+ @dataclass
+ class TemporalSAEConfig(SAEConfig):
+     """Configuration for TemporalSAE inference.
+
+     Args:
+         d_in: Input dimension (dimensionality of the activations being encoded)
+         d_sae: SAE latent dimension (number of features)
+         n_heads: Number of attention heads in temporal attention
+         n_attn_layers: Number of attention layers
+         bottleneck_factor: Bottleneck factor for attention dimension
+         sae_diff_type: Type of SAE for novel codes ('relu' or 'topk')
+         kval_topk: K value for top-k sparsity (if sae_diff_type='topk')
+         tied_weights: Whether to tie encoder and decoder weights
+         activation_normalization_factor: Scalar factor for rescaling activations (used with normalize_activations='constant_scalar_rescale')
+     """
+
+     n_heads: int = 8
+     n_attn_layers: int = 1
+     bottleneck_factor: int = 64
+     sae_diff_type: Literal["relu", "topk"] = "topk"
+     kval_topk: int | None = None
+     tied_weights: bool = True
+     activation_normalization_factor: float = 1.0
+
+     def __post_init__(self):
+         # Call parent's __post_init__ first, but allow constant_scalar_rescale
+         if self.normalize_activations not in [
+             "none",
+             "expected_average_only_in",
+             "constant_norm_rescale",
+             "constant_scalar_rescale",  # Temporal SAEs support this
+             "layer_norm",
+         ]:
+             raise ValueError(
+                 f"normalize_activations must be none, expected_average_only_in, layer_norm, constant_norm_rescale, or constant_scalar_rescale. Got {self.normalize_activations}"
+             )
+
+     @override
+     @classmethod
+     def architecture(cls) -> str:
+         return "temporal"
+
+
+ class TemporalSAE(SAE[TemporalSAEConfig]):
+     """TemporalSAE: Sparse Autoencoder with temporal attention.
+
+     This SAE decomposes each activation x_t into:
+     - x_pred: Information aggregated from context {x_0, ..., x_{t-1}}
+     - x_novel: Novel information at position t (encoded sparsely)
+
+     The forward pass:
+     1. Uses attention layers to predict x_t from context
+     2. Encodes the residual (novel part) with a sparse SAE
+     3. Combines both for reconstruction
+     """
+
+     # Custom parameters (in addition to W_enc, W_dec, b_dec from base)
+     attn_layers: nn.ModuleList  # Attention layers
+     eps: float
+     lam: float
+
+     def __init__(self, cfg: TemporalSAEConfig, use_error_term: bool = False):
+         # Call parent init first
+         super().__init__(cfg, use_error_term)
+
+         # Initialize attention layers after parent init and move to correct device
+         self.attn_layers = nn.ModuleList(
+             [
+                 ManualAttention(
+                     dimin=cfg.d_sae,
+                     n_heads=cfg.n_heads,
+                     bottleneck_factor=cfg.bottleneck_factor,
+                     bias_k=True,
+                     bias_q=True,
+                     bias_v=True,
+                     bias_o=True,
+                 ).to(device=self.device, dtype=self.dtype)
+                 for _ in range(cfg.n_attn_layers)
+             ]
+         )
+
+         self.eps = 1e-6
+         self.lam = 1 / (4 * self.cfg.d_in)
+
+     @override
+     def _setup_activation_normalization(self):
+         """Set up activation normalization functions for TemporalSAE.
+
+         Overrides the base implementation to handle constant_scalar_rescale
+         using the temporal-specific activation_normalization_factor.
+         """
+         if self.cfg.normalize_activations == "constant_scalar_rescale":
+             # Handle constant scalar rescaling for temporal SAEs
+             def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
+                 return x * self.cfg.activation_normalization_factor
+
+             def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
+                 return x / self.cfg.activation_normalization_factor
+
+             self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
+             self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+         else:
+             # Delegate to parent for all other normalization types
+             super()._setup_activation_normalization()
+
+     @override
+     def initialize_weights(self) -> None:
+         """Initialize TemporalSAE weights."""
+         # Initialize D (decoder) and b (bias)
+         self.W_dec = nn.Parameter(
+             torch.randn(
+                 (self.cfg.d_sae, self.cfg.d_in), dtype=self.dtype, device=self.device
+             )
+         )
+         self.b_dec = nn.Parameter(
+             torch.zeros((self.cfg.d_in), dtype=self.dtype, device=self.device)
+         )
+
+         # Initialize E (encoder) if not tied
+         if not self.cfg.tied_weights:
+             self.W_enc = nn.Parameter(
+                 torch.randn(
+                     (self.cfg.d_in, self.cfg.d_sae),
+                     dtype=self.dtype,
+                     device=self.device,
+                 )
+             )
+
+     def encode_with_predictions(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+         """Encode input to novel codes only.
+
+         Returns only the sparse novel codes (not predicted codes).
+         This is the main feature representation for TemporalSAE.
+         """
+         # Process input through SAELens preprocessing
+         x = self.process_sae_in(x)
+
+         B, L, _ = x.shape
+
+         if self.cfg.tied_weights:  # noqa: SIM108
+             W_enc = self.W_dec.T
+         else:
+             W_enc = self.W_enc
+
+         # Compute predicted codes using attention
+         x_residual = x
+         z_pred = torch.zeros((B, L, self.cfg.d_sae), device=x.device, dtype=x.dtype)
+
+         for attn_layer in self.attn_layers:
+             # Encode input to latent space
+             z_input = F.relu(torch.matmul(x_residual * self.lam, W_enc))
+
+             # Shift context (causal masking)
+             z_ctx = torch.cat(
+                 (torch.zeros_like(z_input[:, :1, :]), z_input[:, :-1, :].clone()), dim=1
+             )
+
+             # Apply attention to get predicted codes
+             z_pred_, _ = attn_layer(z_ctx, z_input, get_attn_map=False)
+             z_pred_ = F.relu(z_pred_)
+
+             # Project predicted codes back to input space
+             Dz_pred_ = torch.matmul(z_pred_, self.W_dec)
+             Dz_norm_ = Dz_pred_.norm(dim=-1, keepdim=True) + self.eps
+
+             # Compute projection scale
+             proj_scale = (Dz_pred_ * x_residual).sum(
+                 dim=-1, keepdim=True
+             ) / Dz_norm_.pow(2)
+
+             # Accumulate predicted codes
+             z_pred = z_pred + (z_pred_ * proj_scale)
+
+             # Remove prediction from residual
+             x_residual = x_residual - proj_scale * Dz_pred_
+
+         # Encode residual (novel part) with sparse SAE
+         z_novel = F.relu(torch.matmul(x_residual * self.lam, W_enc))
+         if self.cfg.sae_diff_type == "topk":
+             kval = self.cfg.kval_topk
+             if kval is not None:
+                 _, topk_indices = torch.topk(z_novel, kval, dim=-1)
+                 mask = torch.zeros_like(z_novel)
+                 mask.scatter_(-1, topk_indices, 1)
+                 z_novel = z_novel * mask
+
+         # Return only novel codes (these are the interpretable features)
+         return z_novel, z_pred
+
+     def encode(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> Float[torch.Tensor, "... d_sae"]:
+         return self.encode_with_predictions(x)[0]
+
+     def decode(
+         self, feature_acts: Float[torch.Tensor, "... d_sae"]
+     ) -> Float[torch.Tensor, "... d_in"]:
+         """Decode novel codes to reconstruction.
+
+         Note: This only decodes the novel codes. For full reconstruction,
+         use forward() which includes predicted codes.
+         """
+         # Decode novel codes
+         sae_out = torch.matmul(feature_acts, self.W_dec)
+         sae_out = sae_out + self.b_dec
+
+         # Apply hook
+         sae_out = self.hook_sae_recons(sae_out)
+
+         # Apply output activation normalization (reverses input normalization)
+         sae_out = self.run_time_activation_norm_fn_out(sae_out)
+
+         # Add bias (already removed in process_sae_in)
+         logger.warning(
+             "NOTE this only decodes x_novel. The x_pred is missing, so we're not reconstructing the full x."
+         )
+         return sae_out
+
+     @override
+     def forward(
+         self, x: Float[torch.Tensor, "... d_in"]
+     ) -> Float[torch.Tensor, "... d_in"]:
+         """Full forward pass through TemporalSAE.
+
+         Returns complete reconstruction (predicted + novel).
+         """
+         # Encode
+         z_novel, z_pred = self.encode_with_predictions(x)
+
+         # Decode the sum of predicted and novel codes.
+         x_recons = torch.matmul(z_novel + z_pred, self.W_dec) + self.b_dec
+
+         # Apply output activation normalization (reverses input normalization)
+         x_recons = self.run_time_activation_norm_fn_out(x_recons)
+
+         return self.hook_sae_output(x_recons)
+
+     @override
+     def fold_W_dec_norm(self) -> None:
+         raise NotImplementedError("Folding W_dec_norm is not supported for TemporalSAE")
+
+     @override
+     @torch.no_grad()
+     def fold_activation_norm_scaling_factor(self, scaling_factor: float) -> None:
+         raise NotImplementedError(
+             "Folding activation norm scaling factor is not supported for TemporalSAE"
+         )
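
A usage sketch for the new architecture (shapes and dimensions below are hypothetical). Unlike the other SAE classes, TemporalSAE's attention layers attend over the sequence dimension, so inputs must be batched `(batch, seq, d_in)` sequences rather than flat token activations. Note also that `decode()` reconstructs only the novel part, with a warning; `forward()` decodes `z_novel + z_pred`:

```python
import torch

from sae_lens import TemporalSAE, TemporalSAEConfig

# d_sae must be divisible by bottleneck_factor * n_heads (512 with the defaults).
cfg = TemporalSAEConfig(d_in=64, d_sae=4096, kval_topk=16)  # hypothetical dims
sae = TemporalSAE(cfg)

x = torch.randn(2, 128, 64)  # (batch, seq, d_in)
z_novel, z_pred = sae.encode_with_predictions(x)
recons = sae(x)  # full reconstruction from z_novel + z_pred
assert recons.shape == x.shape
assert (z_novel != 0).sum(dim=-1).max() <= 16  # top-k sparsity on novel codes
```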
{sae_lens-6.17.0 → sae_lens-6.20.1}/sae_lens/training/activations_store.py

@@ -319,7 +319,7 @@ class ActivationsStore:
              )
          else:
              warnings.warn(
-                 "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://jbloomaus.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
+                 "Dataset is not tokenized. Pre-tokenizing will improve performance and allows for more control over special tokens. See https://decoderesearch.github.io/SAELens/training_saes/#pretokenizing-datasets for more info."
              )

          self.iterable_sequences = self._iterate_tokenized_sequences()