sae-lens 6.0.0rc1__tar.gz → 6.0.0rc2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/PKG-INFO +1 -1
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/pyproject.toml +3 -2
- sae_lens-6.0.0rc2/sae_lens/__init__.py +95 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/analysis/hooked_sae_transformer.py +10 -10
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/analysis/neuronpedia_integration.py +13 -11
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/cache_activations_runner.py +2 -1
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/config.py +59 -231
- sae_lens-6.0.0rc2/sae_lens/constants.py +18 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/evals.py +16 -13
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/loading/pretrained_sae_loaders.py +36 -3
- sae_lens-6.0.0rc2/sae_lens/registry.py +49 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/sae_training_runner.py +22 -21
- sae_lens-6.0.0rc2/sae_lens/saes/__init__.py +48 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/saes/gated_sae.py +70 -59
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/saes/jumprelu_sae.py +58 -72
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/saes/sae.py +250 -272
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/saes/standard_sae.py +75 -57
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/saes/topk_sae.py +72 -83
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/training/activations_store.py +31 -15
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/training/optim.py +60 -36
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/training/sae_trainer.py +44 -69
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/training/upload_saes_to_huggingface.py +11 -5
- sae_lens-6.0.0rc2/sae_lens/util.py +28 -0
- sae_lens-6.0.0rc1/sae_lens/__init__.py +0 -61
- sae_lens-6.0.0rc1/sae_lens/regsitry.py +0 -34
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/LICENSE +0 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/README.md +0 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/load_model.py +0 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/training/geometric_median.py +0 -0
- {sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/tutorial/tsea.py +0 -0
{sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/pyproject.toml

```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.0.0-rc.1"
+version = "6.0.0-rc.2"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
@@ -52,13 +52,14 @@ boto3 = "^1.34.101"
 docstr-coverage = "^2.3.2"
 mkdocs = "^1.6.1"
 mkdocs-material = "^9.5.34"
-mkdocs-autorefs = "^1.
+mkdocs-autorefs = "^1.4.2"
 mkdocs-section-index = "^0.3.9"
 mkdocstrings = "^0.25.2"
 mkdocstrings-python = "^1.10.9"
 tabulate = "^0.9.0"
 ruff = "^0.7.4"
 eai-sparsify = "^1.1.1"
+mike = "^2.0.0"
 
 [tool.poetry.extras]
 mamba = ["mamba-lens"]
```
sae_lens-6.0.0rc2/sae_lens/__init__.py (new file)

```diff
@@ -0,0 +1,95 @@
+# ruff: noqa: E402
+__version__ = "6.0.0-rc.2"
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+from sae_lens.saes import (
+    SAE,
+    GatedSAE,
+    GatedSAEConfig,
+    GatedTrainingSAE,
+    GatedTrainingSAEConfig,
+    JumpReLUSAE,
+    JumpReLUSAEConfig,
+    JumpReLUTrainingSAE,
+    JumpReLUTrainingSAEConfig,
+    SAEConfig,
+    StandardSAE,
+    StandardSAEConfig,
+    StandardTrainingSAE,
+    StandardTrainingSAEConfig,
+    TopKSAE,
+    TopKSAEConfig,
+    TopKTrainingSAE,
+    TopKTrainingSAEConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+)
+
+from .analysis.hooked_sae_transformer import HookedSAETransformer
+from .cache_activations_runner import CacheActivationsRunner
+from .config import (
+    CacheActivationsRunnerConfig,
+    LanguageModelSAERunnerConfig,
+    PretokenizeRunnerConfig,
+)
+from .evals import run_evals
+from .loading.pretrained_sae_loaders import (
+    PretrainedSaeDiskLoader,
+    PretrainedSaeHuggingfaceLoader,
+)
+from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
+from .registry import register_sae_class, register_sae_training_class
+from .sae_training_runner import SAETrainingRunner
+from .training.activations_store import ActivationsStore
+from .training.upload_saes_to_huggingface import upload_saes_to_huggingface
+
+__all__ = [
+    "SAE",
+    "SAEConfig",
+    "TrainingSAE",
+    "TrainingSAEConfig",
+    "HookedSAETransformer",
+    "ActivationsStore",
+    "LanguageModelSAERunnerConfig",
+    "SAETrainingRunner",
+    "CacheActivationsRunnerConfig",
+    "CacheActivationsRunner",
+    "PretokenizeRunnerConfig",
+    "PretokenizeRunner",
+    "pretokenize_runner",
+    "run_evals",
+    "upload_saes_to_huggingface",
+    "PretrainedSaeHuggingfaceLoader",
+    "PretrainedSaeDiskLoader",
+    "register_sae_class",
+    "register_sae_training_class",
+    "StandardSAE",
+    "StandardSAEConfig",
+    "StandardTrainingSAE",
+    "StandardTrainingSAEConfig",
+    "GatedSAE",
+    "GatedSAEConfig",
+    "GatedTrainingSAE",
+    "GatedTrainingSAEConfig",
+    "TopKSAE",
+    "TopKSAEConfig",
+    "TopKTrainingSAE",
+    "TopKTrainingSAEConfig",
+    "JumpReLUSAE",
+    "JumpReLUSAEConfig",
+    "JumpReLUTrainingSAE",
+    "JumpReLUTrainingSAEConfig",
+]
+
+
+register_sae_class("standard", StandardSAE, StandardSAEConfig)
+register_sae_training_class("standard", StandardTrainingSAE, StandardTrainingSAEConfig)
+register_sae_class("gated", GatedSAE, GatedSAEConfig)
+register_sae_training_class("gated", GatedTrainingSAE, GatedTrainingSAEConfig)
+register_sae_class("topk", TopKSAE, TopKSAEConfig)
+register_sae_training_class("topk", TopKTrainingSAE, TopKTrainingSAEConfig)
+register_sae_class("jumprelu", JumpReLUSAE, JumpReLUSAEConfig)
+register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAEConfig)
```
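The new `__init__.py` wires every built-in architecture into the new registry module (`sae_lens/registry.py`, which replaces rc1's misspelled `regsitry.py`) and exports `register_sae_class` / `register_sae_training_class` publicly. As a hedged sketch only: third-party code could plausibly register a custom architecture the same way the built-in calls above do. `MyCustomSAE` and `MyCustomSAEConfig` are illustrative names, not part of the package.

```python
# Hypothetical sketch: registering a custom SAE architecture via the
# registry API that rc2 exports. Mirrors the built-in calls such as
# register_sae_class("standard", StandardSAE, StandardSAEConfig).
from typing import Any

from sae_lens import SAE, SAEConfig, register_sae_class


class MyCustomSAEConfig(SAEConfig):  # illustrative subclass, not in sae-lens
    ...


class MyCustomSAE(SAE[MyCustomSAEConfig]):  # SAE is generic over its config type
    ...


# Register under a new architecture key alongside "standard", "gated",
# "topk", and "jumprelu".
register_sae_class("my_custom", MyCustomSAE, MyCustomSAEConfig)
```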
{sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/analysis/hooked_sae_transformer.py

```diff
@@ -68,7 +68,7 @@ class HookedSAETransformer(HookedTransformer):
         super().__init__(*model_args, **model_kwargs)
         self.acts_to_saes: dict[str, SAE] = {}  # type: ignore
 
-    def add_sae(self, sae: SAE, use_error_term: bool | None = None):
+    def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
         """Attaches an SAE to the model
 
         WARNING: This sae will be permanantly attached until you remove it with reset_saes. This function will also overwrite any existing SAE attached to the same hook point.
@@ -77,7 +77,7 @@ class HookedSAETransformer(HookedTransformer):
             sae: SparseAutoencoderBase. The SAE to attach to the model
             use_error_term: (bool | None) If provided, will set the use_error_term attribute of the SAE to this value. Determines whether the SAE returns input or reconstruction. Defaults to None.
         """
-        act_name = sae.cfg.hook_name
+        act_name = sae.cfg.metadata.hook_name
         if (act_name not in self.acts_to_saes) and (act_name not in self.hook_dict):
             logging.warning(
                 f"No hook found for {act_name}. Skipping. Check model.hook_dict for available hooks."
@@ -92,7 +92,7 @@ class HookedSAETransformer(HookedTransformer):
         set_deep_attr(self, act_name, sae)
         self.setup()
 
-    def _reset_sae(self, act_name: str, prev_sae: SAE | None = None):
+    def _reset_sae(self, act_name: str, prev_sae: SAE[Any] | None = None):
         """Resets an SAE that was attached to the model
 
         By default will remove the SAE from that hook_point.
@@ -124,7 +124,7 @@ class HookedSAETransformer(HookedTransformer):
     def reset_saes(
         self,
         act_names: str | list[str] | None = None,
-        prev_saes: list[SAE | None] | None = None,
+        prev_saes: list[SAE[Any] | None] | None = None,
     ):
         """Reset the SAEs attached to the model
 
@@ -154,7 +154,7 @@ class HookedSAETransformer(HookedTransformer):
     def run_with_saes(
         self,
         *model_args: Any,
-        saes: SAE | list[SAE] = [],
+        saes: SAE[Any] | list[SAE[Any]] = [],
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
         **model_kwargs: Any,
@@ -183,7 +183,7 @@ class HookedSAETransformer(HookedTransformer):
     def run_with_cache_with_saes(
         self,
         *model_args: Any,
-        saes: SAE | list[SAE] = [],
+        saes: SAE[Any] | list[SAE[Any]] = [],
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
         return_cache_object: bool = True,
@@ -225,7 +225,7 @@ class HookedSAETransformer(HookedTransformer):
     def run_with_hooks_with_saes(
         self,
         *model_args: Any,
-        saes: SAE | list[SAE] = [],
+        saes: SAE[Any] | list[SAE[Any]] = [],
         reset_saes_end: bool = True,
         fwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
         bwd_hooks: list[tuple[str | Callable, Callable]] = [],  # type: ignore
@@ -261,7 +261,7 @@ class HookedSAETransformer(HookedTransformer):
     @contextmanager
     def saes(
         self,
-        saes: SAE | list[SAE] = [],
+        saes: SAE[Any] | list[SAE[Any]] = [],
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
     ):
@@ -295,8 +295,8 @@ class HookedSAETransformer(HookedTransformer):
             saes = [saes]
         try:
             for sae in saes:
-                act_names_to_reset.append(sae.cfg.hook_name)
-                prev_sae = self.acts_to_saes.get(sae.cfg.hook_name, None)
+                act_names_to_reset.append(sae.cfg.metadata.hook_name)
+                prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
                 prev_saes.append(prev_sae)
                 self.add_sae(sae, use_error_term=use_error_term)
             yield self
```
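Taken together, these hunks show two related changes for downstream callers: `SAE` is now generic over its config type (hence the `SAE[Any]` annotations), and the hook point has moved from `sae.cfg.hook_name` to `sae.cfg.metadata.hook_name`. A minimal sketch of an updated call site, assuming an already-loaded model and SAE:

```python
# Sketch of a downstream call site adapted to rc2; `model` and `sae` are
# assumed to be a loaded HookedSAETransformer and SAE respectively.
from typing import Any

from sae_lens import SAE, HookedSAETransformer


def attach_sae(model: HookedSAETransformer, sae: SAE[Any]) -> None:
    # rc1 read sae.cfg.hook_name; rc2 nests the hook point under cfg.metadata.
    print(f"Attaching SAE at {sae.cfg.metadata.hook_name}")
    model.add_sae(sae, use_error_term=True)
```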
{sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/analysis/neuronpedia_integration.py

```diff
@@ -58,7 +58,7 @@ def NanAndInfReplacer(value: str):
     return NAN_REPLACEMENT
 
 
-def open_neuronpedia_feature_dashboard(sae: SAE, index: int):
+def open_neuronpedia_feature_dashboard(sae: SAE[Any], index: int):
     sae_id = sae.cfg.neuronpedia_id
     if sae_id is None:
         logger.warning(
@@ -70,7 +70,7 @@ def open_neuronpedia_feature_dashboard(sae: SAE, index: int):
 
 
 def get_neuronpedia_quick_list(
-    sae: SAE,
+    sae: SAE[Any],
     features: list[int],
     name: str = "temporary_list",
 ):
@@ -157,9 +157,10 @@ def sleep_identity(x: T) -> T:
 
 
 @retry(wait=wait_random_exponential(min=1, max=500), stop=stop_after_attempt(10))
-async def simulate_and_score(
-    simulator: NeuronSimulator,
-
+async def simulate_and_score(  # type: ignore
+    simulator: NeuronSimulator,
+    activation_records: list[ActivationRecord],  # type: ignore
+) -> ScoredSimulation:  # type: ignore
     """Score an explanation of a neuron by how well it predicts activations on the given text sequences."""
     scored_sequence_simulations = await asyncio.gather(
         *[
@@ -330,8 +331,9 @@ async def autointerp_neuronpedia_features(  # noqa: C901
             feature.activations = []
             activation_records = [
                 ActivationRecord(
-                    tokens=activation.tokens,
-
+                    tokens=activation.tokens,  # type: ignore
+                    activations=activation.act_values,  # type: ignore
+                )  # type: ignore
                 for activation in feature.activations
             ]
 
@@ -384,15 +386,15 @@ async def autointerp_neuronpedia_features(  # noqa: C901
 
         temp_activation_records = [
             ActivationRecord(
-                tokens=[
+                tokens=[  # type: ignore
                     token.replace("<|endoftext|>", "<|not_endoftext|>")
                     .replace(" 55", "_55")
                     .encode("ascii", errors="backslashreplace")
                     .decode("ascii")
-                    for token in activation_record.tokens
+                    for token in activation_record.tokens  # type: ignore
                 ],
-                activations=activation_record.activations,
-            )
+                activations=activation_record.activations,  # type: ignore
+            )  # type: ignore
             for activation_record in activation_records
         ]
 
```
{sae_lens-6.0.0rc1 → sae_lens-6.0.0rc2}/sae_lens/cache_activations_runner.py

```diff
@@ -14,7 +14,8 @@ from tqdm import tqdm
 from transformer_lens.HookedTransformer import HookedRootModule
 
 from sae_lens import logger
-from sae_lens.config import
+from sae_lens.config import CacheActivationsRunnerConfig
+from sae_lens.constants import DTYPE_MAP
 from sae_lens.load_model import load_model
 from sae_lens.training.activations_store import ActivationsStore
 
```
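This import rewrite reflects the new `sae_lens/constants.py` module listed at the top of the diff: `DTYPE_MAP` now lives there rather than in `sae_lens.config` (the exact rc1 import line is truncated in this diff, so the "before" form below is an inference). A sketch of the corresponding change in downstream code:

```python
# rc1 (inferred, no longer available in rc2):
# from sae_lens.config import DTYPE_MAP
from sae_lens.constants import DTYPE_MAP  # rc2 location

# Assumption: DTYPE_MAP maps dtype-name strings to torch dtypes.
dtype = DTYPE_MAP["float32"]
```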