sae-lens 6.33.0__py3-none-any.whl → 6.34.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +1 -1
- sae_lens/analysis/hooked_sae_transformer.py +176 -76
- sae_lens/analysis/sae_transformer_bridge.py +101 -70
- sae_lens/saes/sae.py +20 -2
- {sae_lens-6.33.0.dist-info → sae_lens-6.34.1.dist-info}/METADATA +1 -1
- {sae_lens-6.33.0.dist-info → sae_lens-6.34.1.dist-info}/RECORD +8 -8
- {sae_lens-6.33.0.dist-info → sae_lens-6.34.1.dist-info}/WHEEL +1 -1
- {sae_lens-6.33.0.dist-info → sae_lens-6.34.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
|
1
|
-
import logging
|
|
2
1
|
from contextlib import contextmanager
|
|
3
2
|
from typing import Any, Callable
|
|
4
3
|
|
|
5
4
|
import torch
|
|
5
|
+
from torch import nn
|
|
6
6
|
from transformer_lens.ActivationCache import ActivationCache
|
|
7
7
|
from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
|
|
8
8
|
from transformer_lens.hook_points import HookPoint # Hooking utilities
|
|
9
9
|
from transformer_lens.HookedTransformer import HookedTransformer
|
|
10
10
|
|
|
11
|
+
from sae_lens import logger
|
|
11
12
|
from sae_lens.saes.sae import SAE
|
|
12
13
|
|
|
13
14
|
SingleLoss = torch.Tensor # Type alias for a single element tensor
|
|
@@ -15,6 +16,72 @@ LossPerToken = torch.Tensor
|
|
|
15
16
|
Loss = SingleLoss | LossPerToken
|
|
16
17
|
|
|
17
18
|
|
|
19
|
+
class _SAEWrapper(nn.Module):
|
|
20
|
+
"""Wrapper for SAE/Transcoder that handles error term and hook coordination.
|
|
21
|
+
|
|
22
|
+
For SAEs (input_hook == output_hook), _captured_input stays None and we use
|
|
23
|
+
the forward argument directly. For transcoders, _captured_input is set at
|
|
24
|
+
the input hook via capture_input().
|
|
25
|
+
|
|
26
|
+
Implementation Note:
|
|
27
|
+
The SAE is stored in __dict__ directly rather than as a registered submodule.
|
|
28
|
+
This is intentional: PyTorch's module registration would add a ".sae." prefix
|
|
29
|
+
to all hook names in the cache (e.g., "blocks.0.hook_mlp_out.sae.hook_sae_input"
|
|
30
|
+
instead of "blocks.0.hook_mlp_out.hook_sae_input"). By storing in __dict__ and
|
|
31
|
+
copying hooks directly to the wrapper, we preserve the expected cache paths
|
|
32
|
+
for backwards compatibility.
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
def __init__(self, sae: SAE[Any], use_error_term: bool = False):
|
|
36
|
+
super().__init__()
|
|
37
|
+
# Store SAE in __dict__ to avoid registering as submodule. This keeps cache
|
|
38
|
+
# paths clean by avoiding a ".sae." prefix on hook names. See class docstring.
|
|
39
|
+
self.__dict__["_sae"] = sae
|
|
40
|
+
# Copy SAE's hooks directly to wrapper so they appear at the right path
|
|
41
|
+
for name, hook in sae.hook_dict.items():
|
|
42
|
+
setattr(self, name, hook)
|
|
43
|
+
self.use_error_term = use_error_term
|
|
44
|
+
self._captured_input: torch.Tensor | None = None
|
|
45
|
+
|
|
46
|
+
@property
|
|
47
|
+
def sae(self) -> SAE[Any]:
|
|
48
|
+
return self.__dict__["_sae"]
|
|
49
|
+
|
|
50
|
+
def capture_input(self, x: torch.Tensor) -> None:
|
|
51
|
+
"""Capture input at input hook (for transcoders).
|
|
52
|
+
|
|
53
|
+
Note: We don't clone the tensor here - the input should not be modified
|
|
54
|
+
in-place between capture and use, and avoiding clone preserves memory.
|
|
55
|
+
"""
|
|
56
|
+
self._captured_input = x
|
|
57
|
+
|
|
58
|
+
def forward(self, original_output: torch.Tensor) -> torch.Tensor:
|
|
59
|
+
"""Run SAE/transcoder at output hook location."""
|
|
60
|
+
# For SAE: use original_output as input (same hook for input/output)
|
|
61
|
+
# For transcoder: use captured input from earlier hook
|
|
62
|
+
sae_input = (
|
|
63
|
+
self._captured_input
|
|
64
|
+
if self._captured_input is not None
|
|
65
|
+
else original_output
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
# Temporarily disable SAE's internal use_error_term - we handle it here
|
|
69
|
+
# Use _use_error_term directly to avoid triggering deprecation warning
|
|
70
|
+
sae_use_error_term = self.sae._use_error_term
|
|
71
|
+
self.sae._use_error_term = False
|
|
72
|
+
try:
|
|
73
|
+
sae_out = self.sae(sae_input)
|
|
74
|
+
|
|
75
|
+
if self.use_error_term:
|
|
76
|
+
error = original_output - sae_out.detach()
|
|
77
|
+
sae_out = sae_out + error
|
|
78
|
+
|
|
79
|
+
return sae_out
|
|
80
|
+
finally:
|
|
81
|
+
self.sae._use_error_term = sae_use_error_term
|
|
82
|
+
self._captured_input = None
|
|
83
|
+
|
|
84
|
+
|
|
18
85
|
def get_deep_attr(obj: Any, path: str):
|
|
19
86
|
"""Helper function to get a nested attribute from a object.
|
|
20
87
|
In practice used to access HookedTransformer HookPoints (eg model.blocks[0].attn.hook_z)
|
|
@@ -78,88 +145,130 @@ class HookedSAETransformer(HookedTransformer):
|
|
|
78
145
|
add_hook_in_to_mlp(block.mlp) # type: ignore
|
|
79
146
|
self.setup()
|
|
80
147
|
|
|
81
|
-
self.
|
|
148
|
+
self._acts_to_saes: dict[str, _SAEWrapper] = {}
|
|
149
|
+
# Track output hooks used by transcoders for cleanup
|
|
150
|
+
self._transcoder_output_hooks: dict[str, str] = {}
|
|
151
|
+
|
|
152
|
+
@property
|
|
153
|
+
def acts_to_saes(self) -> dict[str, SAE[Any]]:
|
|
154
|
+
"""Returns a dict mapping hook names to attached SAEs."""
|
|
155
|
+
return {name: wrapper.sae for name, wrapper in self._acts_to_saes.items()}
|
|
82
156
|
|
|
83
157
|
def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
|
|
84
|
-
"""Attaches an SAE to the model
|
|
158
|
+
"""Attaches an SAE or Transcoder to the model.
|
|
85
159
|
|
|
86
|
-
WARNING: This
|
|
160
|
+
WARNING: This SAE will be permanently attached until you remove it with
|
|
161
|
+
reset_saes. This function will also overwrite any existing SAE attached
|
|
162
|
+
to the same hook point.
|
|
87
163
|
|
|
88
164
|
Args:
|
|
89
|
-
sae:
|
|
90
|
-
use_error_term:
|
|
165
|
+
sae: The SAE or Transcoder to attach to the model.
|
|
166
|
+
use_error_term: If True, computes error term so output matches what the
|
|
167
|
+
model would have produced without the SAE. This works for both SAEs
|
|
168
|
+
(where input==output hook) and transcoders (where they differ).
|
|
169
|
+
Defaults to None (uses SAE's existing setting).
|
|
91
170
|
"""
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
171
|
+
input_hook = sae.cfg.metadata.hook_name
|
|
172
|
+
output_hook = sae.cfg.metadata.hook_name_out or input_hook
|
|
173
|
+
|
|
174
|
+
if (input_hook not in self._acts_to_saes) and (
|
|
175
|
+
input_hook not in self.hook_dict
|
|
176
|
+
):
|
|
177
|
+
logger.warning(
|
|
178
|
+
f"No hook found for {input_hook}. Skipping. Check model.hook_dict for available hooks."
|
|
96
179
|
)
|
|
97
180
|
return
|
|
98
181
|
|
|
99
|
-
if
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
182
|
+
# Check if output hook exists (either as hook_dict entry or already has SAE attached)
|
|
183
|
+
output_hook_exists = (
|
|
184
|
+
output_hook in self.hook_dict
|
|
185
|
+
or output_hook in self._acts_to_saes
|
|
186
|
+
or any(v == output_hook for v in self._transcoder_output_hooks.values())
|
|
187
|
+
)
|
|
188
|
+
if not output_hook_exists:
|
|
189
|
+
logger.warning(f"No hook found for output {output_hook}. Skipping.")
|
|
190
|
+
return
|
|
191
|
+
|
|
192
|
+
# Always use wrapper - it handles both SAEs and transcoders uniformly
|
|
193
|
+
# If use_error_term not specified, respect SAE's existing setting
|
|
194
|
+
effective_use_error_term = (
|
|
195
|
+
use_error_term if use_error_term is not None else sae.use_error_term
|
|
196
|
+
)
|
|
197
|
+
wrapper = _SAEWrapper(sae, use_error_term=effective_use_error_term)
|
|
198
|
+
|
|
199
|
+
# For transcoders (input != output), capture input at input hook
|
|
200
|
+
if input_hook != output_hook:
|
|
201
|
+
input_hook_point = get_deep_attr(self, input_hook)
|
|
202
|
+
if isinstance(input_hook_point, HookPoint):
|
|
203
|
+
input_hook_point.add_hook(
|
|
204
|
+
lambda tensor, hook: (wrapper.capture_input(tensor), tensor)[1], # noqa: ARG005
|
|
205
|
+
dir="fwd",
|
|
206
|
+
is_permanent=True,
|
|
207
|
+
)
|
|
208
|
+
self._transcoder_output_hooks[input_hook] = output_hook
|
|
209
|
+
|
|
210
|
+
# Store wrapper in _acts_to_saes and at output hook
|
|
211
|
+
self._acts_to_saes[input_hook] = wrapper
|
|
212
|
+
set_deep_attr(self, output_hook, wrapper)
|
|
105
213
|
self.setup()
|
|
106
214
|
|
|
107
|
-
def _reset_sae(
|
|
108
|
-
|
|
215
|
+
def _reset_sae(
|
|
216
|
+
self, act_name: str, prev_wrapper: _SAEWrapper | None = None
|
|
217
|
+
) -> None:
|
|
218
|
+
"""Resets an SAE that was attached to the model.
|
|
109
219
|
|
|
110
220
|
By default will remove the SAE from that hook_point.
|
|
111
|
-
If
|
|
112
|
-
This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes)
|
|
221
|
+
If prev_wrapper is provided, will restore that wrapper's SAE with its settings.
|
|
113
222
|
|
|
114
223
|
Args:
|
|
115
|
-
act_name:
|
|
116
|
-
|
|
224
|
+
act_name: The hook_name of the SAE to reset.
|
|
225
|
+
prev_wrapper: The previous wrapper to restore. If None, will just
|
|
226
|
+
remove the SAE from this hook point. Defaults to None.
|
|
117
227
|
"""
|
|
118
|
-
if act_name not in self.
|
|
119
|
-
|
|
228
|
+
if act_name not in self._acts_to_saes:
|
|
229
|
+
logger.warning(
|
|
120
230
|
f"No SAE is attached to {act_name}. There's nothing to reset."
|
|
121
231
|
)
|
|
122
232
|
return
|
|
123
233
|
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
234
|
+
# Determine output hook location (different from input for transcoders)
|
|
235
|
+
output_hook = self._transcoder_output_hooks.pop(act_name, act_name)
|
|
236
|
+
|
|
237
|
+
# For transcoders, clear permanent hooks from input hook point
|
|
238
|
+
if output_hook != act_name:
|
|
239
|
+
input_hook_point = get_deep_attr(self, act_name)
|
|
240
|
+
if isinstance(input_hook_point, HookPoint):
|
|
241
|
+
input_hook_point.remove_hooks(dir="fwd", including_permanent=True)
|
|
242
|
+
|
|
243
|
+
# Reset output hook location
|
|
244
|
+
set_deep_attr(self, output_hook, HookPoint())
|
|
245
|
+
del self._acts_to_saes[act_name]
|
|
128
246
|
|
|
129
|
-
if
|
|
130
|
-
|
|
131
|
-
self.
|
|
132
|
-
|
|
133
|
-
set_deep_attr(self, act_name, HookPoint())
|
|
134
|
-
del self.acts_to_saes[act_name]
|
|
247
|
+
if prev_wrapper is not None:
|
|
248
|
+
# Rebuild hook_dict before adding new SAE
|
|
249
|
+
self.setup()
|
|
250
|
+
self.add_sae(prev_wrapper.sae, use_error_term=prev_wrapper.use_error_term)
|
|
135
251
|
|
|
136
252
|
def reset_saes(
|
|
137
253
|
self,
|
|
138
254
|
act_names: str | list[str] | None = None,
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
"""Reset the SAEs attached to the model
|
|
255
|
+
) -> None:
|
|
256
|
+
"""Reset the SAEs attached to the model.
|
|
142
257
|
|
|
143
|
-
If act_names are provided will just reset SAEs attached to those hooks.
|
|
144
|
-
|
|
258
|
+
If act_names are provided will just reset SAEs attached to those hooks.
|
|
259
|
+
Otherwise will reset all SAEs attached to the model.
|
|
145
260
|
|
|
146
261
|
Args:
|
|
147
|
-
act_names
|
|
148
|
-
|
|
262
|
+
act_names: The act_names of the SAEs to reset. If None, will reset
|
|
263
|
+
all SAEs attached to the model. Defaults to None.
|
|
149
264
|
"""
|
|
150
265
|
if isinstance(act_names, str):
|
|
151
266
|
act_names = [act_names]
|
|
152
267
|
elif act_names is None:
|
|
153
|
-
act_names = list(self.
|
|
154
|
-
|
|
155
|
-
if prev_saes:
|
|
156
|
-
if len(act_names) != len(prev_saes):
|
|
157
|
-
raise ValueError("act_names and prev_saes must have the same length")
|
|
158
|
-
else:
|
|
159
|
-
prev_saes = [None] * len(act_names) # type: ignore
|
|
268
|
+
act_names = list(self._acts_to_saes.keys())
|
|
160
269
|
|
|
161
|
-
for act_name
|
|
162
|
-
self._reset_sae(act_name
|
|
270
|
+
for act_name in act_names:
|
|
271
|
+
self._reset_sae(act_name)
|
|
163
272
|
|
|
164
273
|
self.setup()
|
|
165
274
|
|
|
@@ -269,41 +378,32 @@ class HookedSAETransformer(HookedTransformer):
|
|
|
269
378
|
reset_saes_end: bool = True,
|
|
270
379
|
use_error_term: bool | None = None,
|
|
271
380
|
):
|
|
272
|
-
"""
|
|
273
|
-
A context manager for adding temporary SAEs to the model.
|
|
274
|
-
See HookedTransformer.hooks for a similar context manager for hooks.
|
|
275
|
-
By default will keep track of previously attached SAEs, and restore them when the context manager exits.
|
|
276
|
-
|
|
277
|
-
Example:
|
|
278
|
-
|
|
279
|
-
.. code-block:: python
|
|
280
|
-
|
|
281
|
-
from transformer_lens import HookedSAETransformer
|
|
282
|
-
from sae_lens.saes.sae import SAE
|
|
283
|
-
|
|
284
|
-
model = HookedSAETransformer.from_pretrained('gpt2-small')
|
|
285
|
-
sae_cfg = SAEConfig(...)
|
|
286
|
-
sae = SAE(sae_cfg)
|
|
287
|
-
with model.saes(saes=[sae]):
|
|
288
|
-
spliced_logits = model(text)
|
|
381
|
+
"""A context manager for adding temporary SAEs to the model.
|
|
289
382
|
|
|
383
|
+
See HookedTransformer.hooks for a similar context manager for hooks.
|
|
384
|
+
By default will keep track of previously attached SAEs, and restore
|
|
385
|
+
them when the context manager exits.
|
|
290
386
|
|
|
291
387
|
Args:
|
|
292
|
-
saes
|
|
293
|
-
reset_saes_end
|
|
294
|
-
|
|
388
|
+
saes: SAEs to be attached.
|
|
389
|
+
reset_saes_end: If True, removes all SAEs added by this context
|
|
390
|
+
manager when the context manager exits, returning previously
|
|
391
|
+
attached SAEs to their original state.
|
|
392
|
+
use_error_term: If provided, will set the use_error_term attribute
|
|
393
|
+
of all SAEs attached during this run to this value.
|
|
295
394
|
"""
|
|
296
|
-
|
|
297
|
-
prev_saes = []
|
|
395
|
+
saes_to_restore: list[tuple[str, _SAEWrapper | None]] = []
|
|
298
396
|
if isinstance(saes, SAE):
|
|
299
397
|
saes = [saes]
|
|
300
398
|
try:
|
|
301
399
|
for sae in saes:
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
400
|
+
act_name = sae.cfg.metadata.hook_name
|
|
401
|
+
prev_wrapper = self._acts_to_saes.get(act_name, None)
|
|
402
|
+
saes_to_restore.append((act_name, prev_wrapper))
|
|
305
403
|
self.add_sae(sae, use_error_term=use_error_term)
|
|
306
404
|
yield self
|
|
307
405
|
finally:
|
|
308
406
|
if reset_saes_end:
|
|
309
|
-
|
|
407
|
+
for act_name, prev_wrapper in saes_to_restore:
|
|
408
|
+
self._reset_sae(act_name, prev_wrapper)
|
|
409
|
+
self.setup()
|
|
@@ -8,7 +8,11 @@ from transformer_lens.hook_points import HookPoint
|
|
|
8
8
|
from transformer_lens.model_bridge import TransformerBridge
|
|
9
9
|
|
|
10
10
|
from sae_lens import logger
|
|
11
|
-
from sae_lens.analysis.hooked_sae_transformer import
|
|
11
|
+
from sae_lens.analysis.hooked_sae_transformer import (
|
|
12
|
+
_SAEWrapper,
|
|
13
|
+
get_deep_attr,
|
|
14
|
+
set_deep_attr,
|
|
15
|
+
)
|
|
12
16
|
from sae_lens.saes.sae import SAE
|
|
13
17
|
|
|
14
18
|
SingleLoss = torch.Tensor # Type alias for a single element tensor
|
|
@@ -30,11 +34,18 @@ class SAETransformerBridge(TransformerBridge): # type: ignore[misc,no-untyped-c
|
|
|
30
34
|
useful for models not natively supported by HookedTransformer, such as Gemma 3.
|
|
31
35
|
"""
|
|
32
36
|
|
|
33
|
-
|
|
37
|
+
_acts_to_saes: dict[str, _SAEWrapper]
|
|
34
38
|
|
|
35
39
|
def __init__(self, *args: Any, **kwargs: Any):
|
|
36
40
|
super().__init__(*args, **kwargs)
|
|
37
|
-
self.
|
|
41
|
+
self._acts_to_saes = {}
|
|
42
|
+
# Track output hooks used by transcoders for cleanup
|
|
43
|
+
self._transcoder_output_hooks: dict[str, str] = {}
|
|
44
|
+
|
|
45
|
+
@property
|
|
46
|
+
def acts_to_saes(self) -> dict[str, SAE[Any]]:
|
|
47
|
+
"""Returns a dict mapping hook names to attached SAEs."""
|
|
48
|
+
return {name: wrapper.sae for name, wrapper in self._acts_to_saes.items()}
|
|
38
49
|
|
|
39
50
|
@classmethod
|
|
40
51
|
def boot_transformers( # type: ignore[override]
|
|
@@ -56,7 +67,8 @@ class SAETransformerBridge(TransformerBridge): # type: ignore[misc,no-untyped-c
|
|
|
56
67
|
# Convert to our class
|
|
57
68
|
# NOTE: this is super hacky and scary, but I don't know how else to achieve this given TLens' internal code
|
|
58
69
|
bridge.__class__ = cls
|
|
59
|
-
bridge.
|
|
70
|
+
bridge._acts_to_saes = {} # type: ignore[attr-defined]
|
|
71
|
+
bridge._transcoder_output_hooks = {} # type: ignore[attr-defined]
|
|
60
72
|
return bridge # type: ignore[return-value]
|
|
61
73
|
|
|
62
74
|
def _resolve_hook_name(self, hook_name: str) -> str:
|
|
@@ -75,110 +87,129 @@ class SAETransformerBridge(TransformerBridge): # type: ignore[misc,no-untyped-c
|
|
|
75
87
|
return resolved if isinstance(resolved, str) else hook_name
|
|
76
88
|
|
|
77
89
|
def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None) -> None:
|
|
78
|
-
"""Attaches an SAE to the model.
|
|
90
|
+
"""Attaches an SAE or Transcoder to the model.
|
|
79
91
|
|
|
80
92
|
WARNING: This SAE will be permanently attached until you remove it with
|
|
81
93
|
reset_saes. This function will also overwrite any existing SAE attached
|
|
82
94
|
to the same hook point.
|
|
83
95
|
|
|
84
96
|
Args:
|
|
85
|
-
sae: The SAE to attach to the model
|
|
86
|
-
use_error_term: If
|
|
87
|
-
|
|
88
|
-
|
|
97
|
+
sae: The SAE or Transcoder to attach to the model.
|
|
98
|
+
use_error_term: If True, computes error term so output matches what the
|
|
99
|
+
model would have produced without the SAE. This works for both SAEs
|
|
100
|
+
(where input==output hook) and transcoders (where they differ).
|
|
101
|
+
Defaults to None (uses SAE's existing setting).
|
|
89
102
|
"""
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
103
|
+
input_hook_alias = sae.cfg.metadata.hook_name
|
|
104
|
+
output_hook_alias = sae.cfg.metadata.hook_name_out or input_hook_alias
|
|
105
|
+
input_hook_actual = self._resolve_hook_name(input_hook_alias)
|
|
106
|
+
output_hook_actual = self._resolve_hook_name(output_hook_alias)
|
|
107
|
+
|
|
108
|
+
# Check if hooks exist
|
|
109
|
+
if (input_hook_alias not in self._acts_to_saes) and (
|
|
110
|
+
input_hook_actual not in self._hook_registry
|
|
96
111
|
):
|
|
97
112
|
logger.warning(
|
|
98
|
-
f"No hook found for {
|
|
113
|
+
f"No hook found for {input_hook_alias}. Skipping. "
|
|
99
114
|
f"Check model._hook_registry for available hooks."
|
|
100
115
|
)
|
|
101
116
|
return
|
|
102
117
|
|
|
103
|
-
if
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
118
|
+
# Check if output hook exists (either as registry entry or already has SAE attached)
|
|
119
|
+
output_hook_exists = (
|
|
120
|
+
output_hook_actual in self._hook_registry
|
|
121
|
+
or input_hook_alias in self._acts_to_saes
|
|
122
|
+
or any(
|
|
123
|
+
v == output_hook_actual for v in self._transcoder_output_hooks.values()
|
|
124
|
+
)
|
|
125
|
+
)
|
|
126
|
+
if not output_hook_exists:
|
|
127
|
+
logger.warning(f"No hook found for output {output_hook_alias}. Skipping.")
|
|
128
|
+
return
|
|
112
129
|
|
|
113
|
-
|
|
130
|
+
# Always use wrapper - it handles both SAEs and transcoders uniformly
|
|
131
|
+
# If use_error_term not specified, respect SAE's existing setting
|
|
132
|
+
effective_use_error_term = (
|
|
133
|
+
use_error_term if use_error_term is not None else sae.use_error_term
|
|
134
|
+
)
|
|
135
|
+
wrapper = _SAEWrapper(sae, use_error_term=effective_use_error_term)
|
|
136
|
+
|
|
137
|
+
# For transcoders (input != output), capture input at input hook
|
|
138
|
+
if input_hook_alias != output_hook_alias:
|
|
139
|
+
input_hook_point = get_deep_attr(self, input_hook_actual)
|
|
140
|
+
if isinstance(input_hook_point, HookPoint):
|
|
141
|
+
input_hook_point.add_hook(
|
|
142
|
+
lambda tensor, hook: (wrapper.capture_input(tensor), tensor)[1], # noqa: ARG005
|
|
143
|
+
dir="fwd",
|
|
144
|
+
is_permanent=True,
|
|
145
|
+
)
|
|
146
|
+
self._transcoder_output_hooks[input_hook_alias] = output_hook_actual
|
|
147
|
+
|
|
148
|
+
# Store wrapper in _acts_to_saes and at output hook
|
|
149
|
+
set_deep_attr(self, output_hook_actual, wrapper)
|
|
150
|
+
self._hook_registry[output_hook_actual] = wrapper # type: ignore[assignment]
|
|
151
|
+
self._acts_to_saes[input_hook_alias] = wrapper
|
|
152
|
+
|
|
153
|
+
def _reset_sae(
|
|
154
|
+
self, act_name: str, prev_wrapper: _SAEWrapper | None = None
|
|
155
|
+
) -> None:
|
|
114
156
|
"""Resets an SAE that was attached to the model.
|
|
115
157
|
|
|
116
158
|
By default will remove the SAE from that hook_point.
|
|
117
|
-
If
|
|
118
|
-
This is mainly used to restore previously attached SAEs after temporarily
|
|
119
|
-
running with different SAEs (e.g., with run_with_saes).
|
|
159
|
+
If prev_wrapper is provided, will restore that wrapper's SAE with its settings.
|
|
120
160
|
|
|
121
161
|
Args:
|
|
122
162
|
act_name: The hook_name of the SAE to reset
|
|
123
|
-
|
|
163
|
+
prev_wrapper: The previous wrapper to restore. If None, will just
|
|
124
164
|
remove the SAE from this hook point. Defaults to None.
|
|
125
165
|
"""
|
|
126
|
-
if act_name not in self.
|
|
166
|
+
if act_name not in self._acts_to_saes:
|
|
127
167
|
logger.warning(
|
|
128
168
|
f"No SAE is attached to {act_name}. There's nothing to reset."
|
|
129
169
|
)
|
|
130
170
|
return
|
|
131
171
|
|
|
132
172
|
actual_name = self._resolve_hook_name(act_name)
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
173
|
+
|
|
174
|
+
# Determine output hook location (different from input for transcoders)
|
|
175
|
+
output_hook = self._transcoder_output_hooks.pop(act_name, actual_name)
|
|
176
|
+
|
|
177
|
+
# For transcoders, clear permanent hooks from input hook point
|
|
178
|
+
if output_hook != actual_name:
|
|
179
|
+
input_hook_point = get_deep_attr(self, actual_name)
|
|
180
|
+
if isinstance(input_hook_point, HookPoint):
|
|
181
|
+
input_hook_point.remove_hooks(dir="fwd", including_permanent=True)
|
|
182
|
+
|
|
183
|
+
# Reset output hook location
|
|
184
|
+
new_hook = HookPoint()
|
|
185
|
+
new_hook.name = output_hook
|
|
186
|
+
set_deep_attr(self, output_hook, new_hook)
|
|
187
|
+
self._hook_registry[output_hook] = new_hook
|
|
188
|
+
del self._acts_to_saes[act_name]
|
|
189
|
+
|
|
190
|
+
if prev_wrapper is not None:
|
|
191
|
+
self.add_sae(prev_wrapper.sae, use_error_term=prev_wrapper.use_error_term)
|
|
149
192
|
|
|
150
193
|
def reset_saes(
|
|
151
194
|
self,
|
|
152
195
|
act_names: str | list[str] | None = None,
|
|
153
|
-
prev_saes: list[SAE[Any] | None] | None = None,
|
|
154
196
|
) -> None:
|
|
155
197
|
"""Reset the SAEs attached to the model.
|
|
156
198
|
|
|
157
199
|
If act_names are provided will just reset SAEs attached to those hooks.
|
|
158
200
|
Otherwise will reset all SAEs attached to the model.
|
|
159
|
-
Optionally can provide a list of prev_saes to reset to. This is mainly
|
|
160
|
-
used to restore previously attached SAEs after temporarily running with
|
|
161
|
-
different SAEs (e.g., with run_with_saes).
|
|
162
201
|
|
|
163
202
|
Args:
|
|
164
203
|
act_names: The act_names of the SAEs to reset. If None, will reset all
|
|
165
204
|
SAEs attached to the model. Defaults to None.
|
|
166
|
-
prev_saes: List of SAEs to replace the current ones with. If None, will
|
|
167
|
-
just remove the SAEs. Defaults to None.
|
|
168
205
|
"""
|
|
169
206
|
if isinstance(act_names, str):
|
|
170
207
|
act_names = [act_names]
|
|
171
208
|
elif act_names is None:
|
|
172
|
-
act_names = list(self.
|
|
173
|
-
|
|
174
|
-
if prev_saes:
|
|
175
|
-
if len(act_names) != len(prev_saes):
|
|
176
|
-
raise ValueError("act_names and prev_saes must have the same length")
|
|
177
|
-
else:
|
|
178
|
-
prev_saes = [None] * len(act_names) # type: ignore[assignment]
|
|
209
|
+
act_names = list(self._acts_to_saes.keys())
|
|
179
210
|
|
|
180
|
-
for act_name
|
|
181
|
-
self._reset_sae(act_name
|
|
211
|
+
for act_name in act_names:
|
|
212
|
+
self._reset_sae(act_name)
|
|
182
213
|
|
|
183
214
|
def run_with_saes(
|
|
184
215
|
self,
|
|
@@ -310,20 +341,20 @@ class SAETransformerBridge(TransformerBridge): # type: ignore[misc,no-untyped-c
|
|
|
310
341
|
use_error_term: If provided, will set the use_error_term attribute of
|
|
311
342
|
all SAEs attached during this run to this value. Defaults to None.
|
|
312
343
|
"""
|
|
313
|
-
|
|
314
|
-
prev_saes: list[SAE[Any] | None] = []
|
|
344
|
+
saes_to_restore: list[tuple[str, _SAEWrapper | None]] = []
|
|
315
345
|
if isinstance(saes, SAE):
|
|
316
346
|
saes = [saes]
|
|
317
347
|
try:
|
|
318
348
|
for sae in saes:
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
349
|
+
act_name = sae.cfg.metadata.hook_name
|
|
350
|
+
prev_wrapper = self._acts_to_saes.get(act_name, None)
|
|
351
|
+
saes_to_restore.append((act_name, prev_wrapper))
|
|
322
352
|
self.add_sae(sae, use_error_term=use_error_term)
|
|
323
353
|
yield self
|
|
324
354
|
finally:
|
|
325
355
|
if reset_saes_end:
|
|
326
|
-
|
|
356
|
+
for act_name, prev_wrapper in saes_to_restore:
|
|
357
|
+
self._reset_sae(act_name, prev_wrapper)
|
|
327
358
|
|
|
328
359
|
@property
|
|
329
360
|
def hook_dict(self) -> dict[str, HookPoint]:
|
|
@@ -337,9 +368,9 @@ class SAETransformerBridge(TransformerBridge): # type: ignore[misc,no-untyped-c
|
|
|
337
368
|
hooks: dict[str, HookPoint] = {}
|
|
338
369
|
|
|
339
370
|
for name, hook_or_sae in self._hook_registry.items():
|
|
340
|
-
if isinstance(hook_or_sae,
|
|
371
|
+
if isinstance(hook_or_sae, _SAEWrapper):
|
|
341
372
|
# Include SAE's internal hooks with full path names
|
|
342
|
-
for sae_hook_name, sae_hook in hook_or_sae.hook_dict.items():
|
|
373
|
+
for sae_hook_name, sae_hook in hook_or_sae.sae.hook_dict.items():
|
|
343
374
|
full_name = f"{name}.{sae_hook_name}"
|
|
344
375
|
hooks[full_name] = sae_hook
|
|
345
376
|
else:
|
sae_lens/saes/sae.py
CHANGED
|
@@ -229,7 +229,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
|
|
|
229
229
|
cfg: T_SAE_CONFIG
|
|
230
230
|
dtype: torch.dtype
|
|
231
231
|
device: torch.device
|
|
232
|
-
|
|
232
|
+
_use_error_term: bool
|
|
233
233
|
|
|
234
234
|
# For type checking only - don't provide default values
|
|
235
235
|
# These will be initialized by subclasses
|
|
@@ -254,7 +254,9 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
|
|
|
254
254
|
|
|
255
255
|
self.dtype = str_to_dtype(cfg.dtype)
|
|
256
256
|
self.device = torch.device(cfg.device)
|
|
257
|
-
self.
|
|
257
|
+
self._use_error_term = False # Set directly to avoid warning during init
|
|
258
|
+
if use_error_term:
|
|
259
|
+
self.use_error_term = True # Use property setter to trigger warning
|
|
258
260
|
|
|
259
261
|
# Set up activation function
|
|
260
262
|
self.activation_fn = self.get_activation_fn()
|
|
@@ -281,6 +283,22 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
|
|
|
281
283
|
|
|
282
284
|
self.setup() # Required for HookedRootModule
|
|
283
285
|
|
|
286
|
+
@property
|
|
287
|
+
def use_error_term(self) -> bool:
|
|
288
|
+
return self._use_error_term
|
|
289
|
+
|
|
290
|
+
@use_error_term.setter
|
|
291
|
+
def use_error_term(self, value: bool) -> None:
|
|
292
|
+
if value and not self._use_error_term:
|
|
293
|
+
warnings.warn(
|
|
294
|
+
"Setting use_error_term directly on SAE is deprecated. "
|
|
295
|
+
"Use HookedSAETransformer.add_sae(sae, use_error_term=True) instead. "
|
|
296
|
+
"This will be removed in a future version.",
|
|
297
|
+
DeprecationWarning,
|
|
298
|
+
stacklevel=2,
|
|
299
|
+
)
|
|
300
|
+
self._use_error_term = value
|
|
301
|
+
|
|
284
302
|
@torch.no_grad()
|
|
285
303
|
def fold_activation_norm_scaling_factor(self, scaling_factor: float):
|
|
286
304
|
self.W_enc.data *= scaling_factor # type: ignore
|
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
sae_lens/__init__.py,sha256=
|
|
1
|
+
sae_lens/__init__.py,sha256=U5I6dWw2_x6ZXwL8BF72vB0HGPEUt-_BzHq-btLq2pw,5168
|
|
2
2
|
sae_lens/analysis/__init__.py,sha256=FZExlMviNwWR7OGUSGRbd0l-yUDGSp80gglI_ivILrY,412
|
|
3
3
|
sae_lens/analysis/compat.py,sha256=cgE3nhFcJTcuhppxbL71VanJS7YqVEOefuneB5eOaPw,538
|
|
4
|
-
sae_lens/analysis/hooked_sae_transformer.py,sha256
|
|
4
|
+
sae_lens/analysis/hooked_sae_transformer.py,sha256=-LY9CKYEziSTt-H7MeLdTx6ErfvMqgNEF709wV7tEs4,17826
|
|
5
5
|
sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
|
|
6
|
-
sae_lens/analysis/sae_transformer_bridge.py,sha256=
|
|
6
|
+
sae_lens/analysis/sae_transformer_bridge.py,sha256=y_ZdvaxuUM_-7ywSCeFl6f6cq1_FiqRstMAo8mxZtYE,16191
|
|
7
7
|
sae_lens/cache_activations_runner.py,sha256=TjqNWIc46Nw09jHWFjzQzgzG5wdu_87Ahe-iFjI5_0Q,13117
|
|
8
8
|
sae_lens/config.py,sha256=V0BXV8rvpbm5YuVukow9FURPpdyE4HSflbdymAo0Ycg,31205
|
|
9
9
|
sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
|
|
@@ -22,7 +22,7 @@ sae_lens/saes/gated_sae.py,sha256=V_2ZNlV4gRD-rX5JSx1xqY7idT8ChfdQ5yxWDdu_6hg,88
|
|
|
22
22
|
sae_lens/saes/jumprelu_sae.py,sha256=miiF-xI_yXdV9EkKjwAbU9zSMsx9KtKCz5YdXEzkN8g,13313
|
|
23
23
|
sae_lens/saes/matching_pursuit_sae.py,sha256=08_G9p1YMLnE5qZVCPp6gll-iG6nHRbMMASf4_bkFt8,13207
|
|
24
24
|
sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
|
|
25
|
-
sae_lens/saes/sae.py,sha256=
|
|
25
|
+
sae_lens/saes/sae.py,sha256=2FHhLoTZYOJUGjkWH7eF2EAc1fPxQsF4KYr5TZ8IUIU,40155
|
|
26
26
|
sae_lens/saes/standard_sae.py,sha256=_hldNZkFPAf9VGrxouR1-tN8T2OEk8IkWBcXoatrC1o,5749
|
|
27
27
|
sae_lens/saes/temporal_sae.py,sha256=S44sPddVj2xujA02CC8gT1tG0in7c_CSAhspu9FHbaA,13273
|
|
28
28
|
sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
|
|
@@ -49,7 +49,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
|
|
|
49
49
|
sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
|
|
50
50
|
sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
|
|
51
51
|
sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
|
|
52
|
-
sae_lens-6.
|
|
53
|
-
sae_lens-6.
|
|
54
|
-
sae_lens-6.
|
|
55
|
-
sae_lens-6.
|
|
52
|
+
sae_lens-6.34.1.dist-info/METADATA,sha256=vAlbWT90NggKoxjII6dnEj7wlC-PmuBn8zzqL6r8dRg,6566
|
|
53
|
+
sae_lens-6.34.1.dist-info/WHEEL,sha256=kJCRJT_g0adfAJzTx2GUMmS80rTJIVHRCfG0DQgLq3o,88
|
|
54
|
+
sae_lens-6.34.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
|
|
55
|
+
sae_lens-6.34.1.dist-info/RECORD,,
|
|
File without changes
|