sae-lens 6.31.0__tar.gz → 6.34.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.31.0 → sae_lens-6.34.0}/PKG-INFO +2 -2
- {sae_lens-6.31.0 → sae_lens-6.34.0}/pyproject.toml +3 -2
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/__init__.py +14 -1
- sae_lens-6.34.0/sae_lens/analysis/__init__.py +15 -0
- sae_lens-6.34.0/sae_lens/analysis/compat.py +16 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/analysis/hooked_sae_transformer.py +175 -76
- sae_lens-6.34.0/sae_lens/analysis/sae_transformer_bridge.py +379 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/loading/pretrained_sae_loaders.py +2 -1
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/loading/pretrained_saes_directory.py +0 -22
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/sae.py +20 -33
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/__init__.py +13 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/correlation.py +12 -14
- sae_lens-6.34.0/sae_lens/synthetic/stats.py +205 -0
- sae_lens-6.31.0/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/LICENSE +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/README.md +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/config.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/llm_sae_training_runner.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.31.0/sae_lens/analysis → sae_lens-6.34.0/sae_lens/loading}/__init__.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/matching_pursuit_sae.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/temporal_sae.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/activation_generator.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/evals.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/feature_dictionary.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/firing_probabilities.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/hierarchy.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/initialization.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/plotting.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/synthetic/training.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.31.0/sae_lens/loading → sae_lens-6.34.0/sae_lens/training}/__init__.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/training/sae_trainer.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/util.py +0 -0

{sae_lens-6.31.0 → sae_lens-6.34.0}/PKG-INFO

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.31.0
+Version: 6.34.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -27,7 +27,7 @@ Requires-Dist: pyyaml (>=6.0.1,<7.0.0)
 Requires-Dist: safetensors (>=0.4.2,<1.0.0)
 Requires-Dist: simple-parsing (>=0.1.6,<0.2.0)
 Requires-Dist: tenacity (>=9.0.0)
-Requires-Dist: transformer-lens (>=2.16.1
+Requires-Dist: transformer-lens (>=2.16.1)
 Requires-Dist: transformers (>=4.38.1,<5.0.0)
 Requires-Dist: typing-extensions (>=4.10.0,<5.0.0)
 Project-URL: Homepage, https://decoderesearch.github.io/SAELens
```

{sae_lens-6.31.0 → sae_lens-6.34.0}/pyproject.toml

```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.31.0"
+version = "6.34.0"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"
@@ -19,7 +19,7 @@ classifiers = ["Topic :: Scientific/Engineering :: Artificial Intelligence"]
 
 [tool.poetry.dependencies]
 python = "^3.10"
-transformer-lens = "
+transformer-lens = ">=2.16.1"
 transformers = "^4.38.1"
 plotly = ">=5.19.0"
 plotly-express = ">=0.4.1"
@@ -59,6 +59,7 @@ mike = "^2.0.0"
 trio = "^0.30.0"
 dictionary-learning = "^0.1.0"
 kaleido = "^1.2.0"
+transformer-lens = { version = "3.0.0b1", allow-prereleases = true }
 
 [tool.poetry.extras]
 mamba = ["mamba-lens"]
```

{sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/__init__.py

```diff
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.31.0"
+__version__ = "6.34.0"
 
 import logging
 
@@ -125,6 +125,19 @@ __all__ = [
     "MatchingPursuitTrainingSAEConfig",
 ]
 
+# Conditional export for SAETransformerBridge (requires transformer-lens v3+)
+try:
+    from sae_lens.analysis.compat import has_transformer_bridge
+
+    if has_transformer_bridge():
+        from sae_lens.analysis.sae_transformer_bridge import (  # noqa: F401
+            SAETransformerBridge,
+        )
+
+        __all__.append("SAETransformerBridge")
+except ImportError:
+    pass
+
 
 register_sae_class("standard", StandardSAE, StandardSAEConfig)
 register_sae_training_class("standard", StandardTrainingSAE, StandardTrainingSAEConfig)
```
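
Because the export is conditional, downstream code cannot assume the name exists. A minimal sketch (not from the package) of how a consumer might guard the import; the `None` fallback is illustrative:

```python
# Hypothetical consumer-side guard: on transformer-lens 2.x the name is simply
# not exported from sae_lens, so the from-import raises ImportError at lookup time.
try:
    from sae_lens import SAETransformerBridge
except ImportError:
    SAETransformerBridge = None  # fall back to HookedSAETransformer on transformer-lens < 3
```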

sae_lens-6.34.0/sae_lens/analysis/__init__.py (new file)

```diff
@@ -0,0 +1,15 @@
+from sae_lens.analysis.hooked_sae_transformer import HookedSAETransformer
+
+__all__ = ["HookedSAETransformer"]
+
+try:
+    from sae_lens.analysis.compat import has_transformer_bridge
+
+    if has_transformer_bridge():
+        from sae_lens.analysis.sae_transformer_bridge import (  # noqa: F401
+            SAETransformerBridge,
+        )
+
+        __all__.append("SAETransformerBridge")
+except ImportError:
+    pass
```

sae_lens-6.34.0/sae_lens/analysis/compat.py (new file)

```diff
@@ -0,0 +1,16 @@
+import importlib.metadata
+
+from packaging.version import parse as parse_version
+
+
+def get_transformer_lens_version() -> tuple[int, int, int]:
+    """Get transformer-lens version as (major, minor, patch)."""
+    version_str = importlib.metadata.version("transformer-lens")
+    version = parse_version(version_str)
+    return (version.major, version.minor, version.micro)
+
+
+def has_transformer_bridge() -> bool:
+    """Check if TransformerBridge is available (v3+)."""
+    major, _, _ = get_transformer_lens_version()
+    return major >= 3
```
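
A quick check of the gating logic, assuming `packaging` is installed: `packaging.version.parse` reports the `3.0.0b1` prerelease pinned in the dev dependencies as major version 3, so `has_transformer_bridge()` treats it the same as a final 3.x release, while any 2.x version is rejected.

```python
from packaging.version import parse as parse_version

# Mirrors compat.get_transformer_lens_version() + has_transformer_bridge()
for v in ["2.16.1", "3.0.0b1", "3.0.0"]:
    parsed = parse_version(v)
    print(v, (parsed.major, parsed.minor, parsed.micro), parsed.major >= 3)
# 2.16.1  (2, 16, 1) False
# 3.0.0b1 (3, 0, 0)  True
# 3.0.0   (3, 0, 0)  True
```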

{sae_lens-6.31.0 → sae_lens-6.34.0}/sae_lens/analysis/hooked_sae_transformer.py

```diff
@@ -1,13 +1,14 @@
-import logging
 from contextlib import contextmanager
 from typing import Any, Callable
 
 import torch
+from torch import nn
 from transformer_lens.ActivationCache import ActivationCache
 from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
 from transformer_lens.hook_points import HookPoint  # Hooking utilities
 from transformer_lens.HookedTransformer import HookedTransformer
 
+from sae_lens import logger
 from sae_lens.saes.sae import SAE
 
 SingleLoss = torch.Tensor  # Type alias for a single element tensor
```
```diff
@@ -15,6 +16,71 @@ LossPerToken = torch.Tensor
 Loss = SingleLoss | LossPerToken
 
 
+class _SAEWrapper(nn.Module):
+    """Wrapper for SAE/Transcoder that handles error term and hook coordination.
+
+    For SAEs (input_hook == output_hook), _captured_input stays None and we use
+    the forward argument directly. For transcoders, _captured_input is set at
+    the input hook via capture_input().
+
+    Implementation Note:
+        The SAE is stored in __dict__ directly rather than as a registered submodule.
+        This is intentional: PyTorch's module registration would add a ".sae." prefix
+        to all hook names in the cache (e.g., "blocks.0.hook_mlp_out.sae.hook_sae_input"
+        instead of "blocks.0.hook_mlp_out.hook_sae_input"). By storing in __dict__ and
+        copying hooks directly to the wrapper, we preserve the expected cache paths
+        for backwards compatibility.
+    """
+
+    def __init__(self, sae: SAE[Any], use_error_term: bool = False):
+        super().__init__()
+        # Store SAE in __dict__ to avoid registering as submodule. This keeps cache
+        # paths clean by avoiding a ".sae." prefix on hook names. See class docstring.
+        self.__dict__["_sae"] = sae
+        # Copy SAE's hooks directly to wrapper so they appear at the right path
+        for name, hook in sae.hook_dict.items():
+            setattr(self, name, hook)
+        self.use_error_term = use_error_term
+        self._captured_input: torch.Tensor | None = None
+
+    @property
+    def sae(self) -> SAE[Any]:
+        return self.__dict__["_sae"]
+
+    def capture_input(self, x: torch.Tensor) -> None:
+        """Capture input at input hook (for transcoders).
+
+        Note: We don't clone the tensor here - the input should not be modified
+        in-place between capture and use, and avoiding clone preserves memory.
+        """
+        self._captured_input = x
+
+    def forward(self, original_output: torch.Tensor) -> torch.Tensor:
+        """Run SAE/transcoder at output hook location."""
+        # For SAE: use original_output as input (same hook for input/output)
+        # For transcoder: use captured input from earlier hook
+        sae_input = (
+            self._captured_input
+            if self._captured_input is not None
+            else original_output
+        )
+
+        # Temporarily disable SAE's internal use_error_term - we handle it here
+        sae_use_error_term = self.sae.use_error_term
+        self.sae.use_error_term = False
+        try:
+            sae_out = self.sae(sae_input)
+
+            if self.use_error_term:
+                error = original_output - sae_out.detach()
+                sae_out = sae_out + error
+
+            return sae_out
+        finally:
+            self.sae.use_error_term = sae_use_error_term
+            self._captured_input = None
+
+
 def get_deep_attr(obj: Any, path: str):
     """Helper function to get a nested attribute from a object.
     In practice used to access HookedTransformer HookPoints (eg model.blocks[0].attn.hook_z)
```
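
The error-term branch is the usual "splice in the reconstruction without changing the forward value" trick: `sae_out + (original - sae_out.detach())` equals the original activation numerically, while gradients still flow through the SAE reconstruction. A minimal standalone sketch (a linear map stands in for the SAE; nothing here comes from the package):

```python
import torch

original = torch.randn(4, 8)                 # stand-in for the hooked activation
W = torch.nn.Parameter(0.9 * torch.eye(8))   # stand-in for SAE parameters
sae_out = original @ W                       # stand-in reconstruction

error = original - sae_out.detach()          # no gradient flows through the error
patched = sae_out + error                    # numerically identical to `original`

assert torch.allclose(patched, original)
patched.sum().backward()
assert W.grad is not None                    # but gradients still reach the "SAE"
```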
```diff
@@ -78,88 +144,130 @@ class HookedSAETransformer(HookedTransformer):
                 add_hook_in_to_mlp(block.mlp)  # type: ignore
         self.setup()
 
-        self.
+        self._acts_to_saes: dict[str, _SAEWrapper] = {}
+        # Track output hooks used by transcoders for cleanup
+        self._transcoder_output_hooks: dict[str, str] = {}
+
+    @property
+    def acts_to_saes(self) -> dict[str, SAE[Any]]:
+        """Returns a dict mapping hook names to attached SAEs."""
+        return {name: wrapper.sae for name, wrapper in self._acts_to_saes.items()}
 
     def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
-        """Attaches an SAE to the model
+        """Attaches an SAE or Transcoder to the model.
 
-        WARNING: This
+        WARNING: This SAE will be permanently attached until you remove it with
+        reset_saes. This function will also overwrite any existing SAE attached
+        to the same hook point.
 
         Args:
-            sae:
-            use_error_term:
+            sae: The SAE or Transcoder to attach to the model.
+            use_error_term: If True, computes error term so output matches what the
+                model would have produced without the SAE. This works for both SAEs
+                (where input==output hook) and transcoders (where they differ).
+                Defaults to None (uses SAE's existing setting).
         """
+        input_hook = sae.cfg.metadata.hook_name
+        output_hook = sae.cfg.metadata.hook_name_out or input_hook
+
+        if (input_hook not in self._acts_to_saes) and (
+            input_hook not in self.hook_dict
+        ):
+            logger.warning(
+                f"No hook found for {input_hook}. Skipping. Check model.hook_dict for available hooks."
             )
             return
 
-        if
+        # Check if output hook exists (either as hook_dict entry or already has SAE attached)
+        output_hook_exists = (
+            output_hook in self.hook_dict
+            or output_hook in self._acts_to_saes
+            or any(v == output_hook for v in self._transcoder_output_hooks.values())
+        )
+        if not output_hook_exists:
+            logger.warning(f"No hook found for output {output_hook}. Skipping.")
+            return
+
+        # Always use wrapper - it handles both SAEs and transcoders uniformly
+        # If use_error_term not specified, respect SAE's existing setting
+        effective_use_error_term = (
+            use_error_term if use_error_term is not None else sae.use_error_term
+        )
+        wrapper = _SAEWrapper(sae, use_error_term=effective_use_error_term)
+
+        # For transcoders (input != output), capture input at input hook
+        if input_hook != output_hook:
+            input_hook_point = get_deep_attr(self, input_hook)
+            if isinstance(input_hook_point, HookPoint):
+                input_hook_point.add_hook(
+                    lambda tensor, hook: (wrapper.capture_input(tensor), tensor)[1],  # noqa: ARG005
+                    dir="fwd",
+                    is_permanent=True,
+                )
+            self._transcoder_output_hooks[input_hook] = output_hook
+
+        # Store wrapper in _acts_to_saes and at output hook
+        self._acts_to_saes[input_hook] = wrapper
+        set_deep_attr(self, output_hook, wrapper)
         self.setup()
 
-    def _reset_sae(
+    def _reset_sae(
+        self, act_name: str, prev_wrapper: _SAEWrapper | None = None
+    ) -> None:
+        """Resets an SAE that was attached to the model.
 
         By default will remove the SAE from that hook_point.
-        If
-        This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes)
+        If prev_wrapper is provided, will restore that wrapper's SAE with its settings.
 
         Args:
-            act_name:
+            act_name: The hook_name of the SAE to reset.
+            prev_wrapper: The previous wrapper to restore. If None, will just
+                remove the SAE from this hook point. Defaults to None.
         """
-        if act_name not in self.
+        if act_name not in self._acts_to_saes:
+            logger.warning(
                 f"No SAE is attached to {act_name}. There's nothing to reset."
             )
            return
 
+        # Determine output hook location (different from input for transcoders)
+        output_hook = self._transcoder_output_hooks.pop(act_name, act_name)
+
+        # For transcoders, clear permanent hooks from input hook point
+        if output_hook != act_name:
+            input_hook_point = get_deep_attr(self, act_name)
+            if isinstance(input_hook_point, HookPoint):
+                input_hook_point.remove_hooks(dir="fwd", including_permanent=True)
+
+        # Reset output hook location
+        set_deep_attr(self, output_hook, HookPoint())
+        del self._acts_to_saes[act_name]
 
-        if
-            self.
-        set_deep_attr(self, act_name, HookPoint())
-        del self.acts_to_saes[act_name]
+        if prev_wrapper is not None:
+            # Rebuild hook_dict before adding new SAE
+            self.setup()
+            self.add_sae(prev_wrapper.sae, use_error_term=prev_wrapper.use_error_term)
 
     def reset_saes(
         self,
         act_names: str | list[str] | None = None,
-        """Reset the SAEs attached to the model
+    ) -> None:
+        """Reset the SAEs attached to the model.
 
-        If act_names are provided will just reset SAEs attached to those hooks.
+        If act_names are provided will just reset SAEs attached to those hooks.
+        Otherwise will reset all SAEs attached to the model.
 
         Args:
-            act_names
+            act_names: The act_names of the SAEs to reset. If None, will reset
+                all SAEs attached to the model. Defaults to None.
         """
         if isinstance(act_names, str):
             act_names = [act_names]
         elif act_names is None:
-            act_names = list(self.
-        if prev_saes:
-            if len(act_names) != len(prev_saes):
-                raise ValueError("act_names and prev_saes must have the same length")
-        else:
-            prev_saes = [None] * len(act_names)  # type: ignore
+            act_names = list(self._acts_to_saes.keys())
 
-        for act_name
-            self._reset_sae(act_name
+        for act_name in act_names:
+            self._reset_sae(act_name)
 
         self.setup()
 
```
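
A sketch of the attach/detach flow from the caller's side; it assumes `sae` is an already-loaded `SAE` whose `cfg.metadata.hook_name` points at an existing hook, and the model name and hook name are illustrative only:

```python
from sae_lens import HookedSAETransformer

model = HookedSAETransformer.from_pretrained("gpt2")
sae = ...  # an SAE loaded elsewhere, e.g. with hook_name "blocks.8.hook_resid_pre"

model.add_sae(sae, use_error_term=True)       # wrapped in _SAEWrapper at the output hook
_, cache = model.run_with_cache("Hello world")
# hook names keep their un-prefixed cache paths thanks to the __dict__ trick
sae_in = cache[f"{sae.cfg.metadata.hook_name}.hook_sae_input"]
model.reset_saes()                            # restores the original HookPoint(s)
```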
```diff
@@ -269,41 +377,32 @@ class HookedSAETransformer(HookedTransformer):
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
     ):
-        """
-        A context manager for adding temporary SAEs to the model.
-        See HookedTransformer.hooks for a similar context manager for hooks.
-        By default will keep track of previously attached SAEs, and restore them when the context manager exits.
-
-        Example:
-
-        .. code-block:: python
-
-            from transformer_lens import HookedSAETransformer
-            from sae_lens.saes.sae import SAE
-
-            model = HookedSAETransformer.from_pretrained('gpt2-small')
-            sae_cfg = SAEConfig(...)
-            sae = SAE(sae_cfg)
-            with model.saes(saes=[sae]):
-                spliced_logits = model(text)
+        """A context manager for adding temporary SAEs to the model.
 
+        See HookedTransformer.hooks for a similar context manager for hooks.
+        By default will keep track of previously attached SAEs, and restore
+        them when the context manager exits.
 
         Args:
-            saes
-            reset_saes_end
+            saes: SAEs to be attached.
+            reset_saes_end: If True, removes all SAEs added by this context
+                manager when the context manager exits, returning previously
+                attached SAEs to their original state.
+            use_error_term: If provided, will set the use_error_term attribute
+                of all SAEs attached during this run to this value.
         """
-        prev_saes = []
+        saes_to_restore: list[tuple[str, _SAEWrapper | None]] = []
         if isinstance(saes, SAE):
             saes = [saes]
         try:
             for sae in saes:
+                act_name = sae.cfg.metadata.hook_name
+                prev_wrapper = self._acts_to_saes.get(act_name, None)
+                saes_to_restore.append((act_name, prev_wrapper))
                 self.add_sae(sae, use_error_term=use_error_term)
             yield self
         finally:
             if reset_saes_end:
+                for act_name, prev_wrapper in saes_to_restore:
+                    self._reset_sae(act_name, prev_wrapper)
+                self.setup()
```
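
For comparison with the docstring example that was removed above, a sketch of the temporary-attachment flow; it assumes the same `model` and `sae` as in the previous sketch:

```python
with model.saes(saes=[sae], use_error_term=False):
    spliced_logits = model("Hello world")    # activation replaced by the SAE output
# on exit, previously attached wrappers (if any) are restored and hooks are rebuilt
clean_logits = model("Hello world")          # SAE no longer attached
```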