sae-lens 6.32.1__py3-none-any.whl → 6.34.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +1 -1
- sae_lens/analysis/hooked_sae_transformer.py +176 -76
- sae_lens/analysis/sae_transformer_bridge.py +101 -70
- sae_lens/loading/pretrained_saes_directory.py +0 -22
- sae_lens/saes/sae.py +20 -33
- sae_lens/synthetic/__init__.py +13 -0
- sae_lens/synthetic/correlation.py +12 -14
- sae_lens/synthetic/stats.py +205 -0
- {sae_lens-6.32.1.dist-info → sae_lens-6.34.1.dist-info}/METADATA +1 -1
- {sae_lens-6.32.1.dist-info → sae_lens-6.34.1.dist-info}/RECORD +12 -11
- {sae_lens-6.32.1.dist-info → sae_lens-6.34.1.dist-info}/WHEEL +1 -1
- {sae_lens-6.32.1.dist-info → sae_lens-6.34.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
sae_lens/analysis/hooked_sae_transformer.py
CHANGED
@@ -1,13 +1,14 @@
-import logging
 from contextlib import contextmanager
 from typing import Any, Callable

 import torch
+from torch import nn
 from transformer_lens.ActivationCache import ActivationCache
 from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
 from transformer_lens.hook_points import HookPoint  # Hooking utilities
 from transformer_lens.HookedTransformer import HookedTransformer

+from sae_lens import logger
 from sae_lens.saes.sae import SAE

 SingleLoss = torch.Tensor  # Type alias for a single element tensor
@@ -15,6 +16,72 @@ LossPerToken = torch.Tensor
 Loss = SingleLoss | LossPerToken


+class _SAEWrapper(nn.Module):
+    """Wrapper for SAE/Transcoder that handles error term and hook coordination.
+
+    For SAEs (input_hook == output_hook), _captured_input stays None and we use
+    the forward argument directly. For transcoders, _captured_input is set at
+    the input hook via capture_input().
+
+    Implementation Note:
+    The SAE is stored in __dict__ directly rather than as a registered submodule.
+    This is intentional: PyTorch's module registration would add a ".sae." prefix
+    to all hook names in the cache (e.g., "blocks.0.hook_mlp_out.sae.hook_sae_input"
+    instead of "blocks.0.hook_mlp_out.hook_sae_input"). By storing in __dict__ and
+    copying hooks directly to the wrapper, we preserve the expected cache paths
+    for backwards compatibility.
+    """
+
+    def __init__(self, sae: SAE[Any], use_error_term: bool = False):
+        super().__init__()
+        # Store SAE in __dict__ to avoid registering as submodule. This keeps cache
+        # paths clean by avoiding a ".sae." prefix on hook names. See class docstring.
+        self.__dict__["_sae"] = sae
+        # Copy SAE's hooks directly to wrapper so they appear at the right path
+        for name, hook in sae.hook_dict.items():
+            setattr(self, name, hook)
+        self.use_error_term = use_error_term
+        self._captured_input: torch.Tensor | None = None
+
+    @property
+    def sae(self) -> SAE[Any]:
+        return self.__dict__["_sae"]
+
+    def capture_input(self, x: torch.Tensor) -> None:
+        """Capture input at input hook (for transcoders).
+
+        Note: We don't clone the tensor here - the input should not be modified
+        in-place between capture and use, and avoiding clone preserves memory.
+        """
+        self._captured_input = x
+
+    def forward(self, original_output: torch.Tensor) -> torch.Tensor:
+        """Run SAE/transcoder at output hook location."""
+        # For SAE: use original_output as input (same hook for input/output)
+        # For transcoder: use captured input from earlier hook
+        sae_input = (
+            self._captured_input
+            if self._captured_input is not None
+            else original_output
+        )
+
+        # Temporarily disable SAE's internal use_error_term - we handle it here
+        # Use _use_error_term directly to avoid triggering deprecation warning
+        sae_use_error_term = self.sae._use_error_term
+        self.sae._use_error_term = False
+        try:
+            sae_out = self.sae(sae_input)
+
+            if self.use_error_term:
+                error = original_output - sae_out.detach()
+                sae_out = sae_out + error
+
+            return sae_out
+        finally:
+            self.sae._use_error_term = sae_use_error_term
+            self._captured_input = None
+
+
 def get_deep_attr(obj: Any, path: str):
     """Helper function to get a nested attribute from a object.
     In practice used to access HookedTransformer HookPoints (eg model.blocks[0].attn.hook_z)
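The `__dict__` trick described in the `_SAEWrapper` docstring is plain `nn.Module` behavior: attribute assignment registers child modules and prefixes their names, while writing into `__dict__` bypasses registration. A standalone sketch (generic PyTorch, not sae-lens code; the class and hook names here are illustrative) of why this keeps cache paths flat:

```python
import torch.nn as nn

class Inner(nn.Module):
    """Stand-in for an SAE that owns a hook point."""
    def __init__(self):
        super().__init__()
        self.hook_sae_input = nn.Identity()  # stand-in for a HookPoint

class RegisteredWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.sae = Inner()  # normal registration: children pick up a ".sae." prefix

class DictWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.__dict__["_sae"] = Inner()  # bypasses nn.Module registration
        # re-expose the child's hook directly on the wrapper so it keeps a flat name
        self.hook_sae_input = self.__dict__["_sae"].hook_sae_input

print([name for name, _ in RegisteredWrapper().named_modules()])
# ['', 'sae', 'sae.hook_sae_input']
print([name for name, _ in DictWrapper().named_modules()])
# ['', 'hook_sae_input']
```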
@@ -78,88 +145,130 @@ class HookedSAETransformer(HookedTransformer):
             add_hook_in_to_mlp(block.mlp)  # type: ignore
         self.setup()

-        self.
+        self._acts_to_saes: dict[str, _SAEWrapper] = {}
+        # Track output hooks used by transcoders for cleanup
+        self._transcoder_output_hooks: dict[str, str] = {}
+
+    @property
+    def acts_to_saes(self) -> dict[str, SAE[Any]]:
+        """Returns a dict mapping hook names to attached SAEs."""
+        return {name: wrapper.sae for name, wrapper in self._acts_to_saes.items()}

     def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None):
-        """Attaches an SAE to the model
+        """Attaches an SAE or Transcoder to the model.

-        WARNING: This
+        WARNING: This SAE will be permanently attached until you remove it with
+        reset_saes. This function will also overwrite any existing SAE attached
+        to the same hook point.

         Args:
-            sae:
-            use_error_term:
+            sae: The SAE or Transcoder to attach to the model.
+            use_error_term: If True, computes error term so output matches what the
+                model would have produced without the SAE. This works for both SAEs
+                (where input==output hook) and transcoders (where they differ).
+                Defaults to None (uses SAE's existing setting).
         """
-
-
-
-
+        input_hook = sae.cfg.metadata.hook_name
+        output_hook = sae.cfg.metadata.hook_name_out or input_hook
+
+        if (input_hook not in self._acts_to_saes) and (
+            input_hook not in self.hook_dict
+        ):
+            logger.warning(
+                f"No hook found for {input_hook}. Skipping. Check model.hook_dict for available hooks."
             )
             return

-        if
-
-
-
-
-
+        # Check if output hook exists (either as hook_dict entry or already has SAE attached)
+        output_hook_exists = (
+            output_hook in self.hook_dict
+            or output_hook in self._acts_to_saes
+            or any(v == output_hook for v in self._transcoder_output_hooks.values())
+        )
+        if not output_hook_exists:
+            logger.warning(f"No hook found for output {output_hook}. Skipping.")
+            return
+
+        # Always use wrapper - it handles both SAEs and transcoders uniformly
+        # If use_error_term not specified, respect SAE's existing setting
+        effective_use_error_term = (
+            use_error_term if use_error_term is not None else sae.use_error_term
+        )
+        wrapper = _SAEWrapper(sae, use_error_term=effective_use_error_term)
+
+        # For transcoders (input != output), capture input at input hook
+        if input_hook != output_hook:
+            input_hook_point = get_deep_attr(self, input_hook)
+            if isinstance(input_hook_point, HookPoint):
+                input_hook_point.add_hook(
+                    lambda tensor, hook: (wrapper.capture_input(tensor), tensor)[1],  # noqa: ARG005
+                    dir="fwd",
+                    is_permanent=True,
+                )
+            self._transcoder_output_hooks[input_hook] = output_hook
+
+        # Store wrapper in _acts_to_saes and at output hook
+        self._acts_to_saes[input_hook] = wrapper
+        set_deep_attr(self, output_hook, wrapper)
         self.setup()

-    def _reset_sae(
-
+    def _reset_sae(
+        self, act_name: str, prev_wrapper: _SAEWrapper | None = None
+    ) -> None:
+        """Resets an SAE that was attached to the model.

         By default will remove the SAE from that hook_point.
-        If
-        This is mainly used to restore previously attached SAEs after temporarily running with different SAEs (eg with run_with_saes)
+        If prev_wrapper is provided, will restore that wrapper's SAE with its settings.

         Args:
-            act_name:
-
+            act_name: The hook_name of the SAE to reset.
+            prev_wrapper: The previous wrapper to restore. If None, will just
+                remove the SAE from this hook point. Defaults to None.
         """
-        if act_name not in self.
-
+        if act_name not in self._acts_to_saes:
+            logger.warning(
                 f"No SAE is attached to {act_name}. There's nothing to reset."
             )
             return

-
-
-
-
+        # Determine output hook location (different from input for transcoders)
+        output_hook = self._transcoder_output_hooks.pop(act_name, act_name)
+
+        # For transcoders, clear permanent hooks from input hook point
+        if output_hook != act_name:
+            input_hook_point = get_deep_attr(self, act_name)
+            if isinstance(input_hook_point, HookPoint):
+                input_hook_point.remove_hooks(dir="fwd", including_permanent=True)
+
+        # Reset output hook location
+        set_deep_attr(self, output_hook, HookPoint())
+        del self._acts_to_saes[act_name]

-        if
-
-            self.
-
-            set_deep_attr(self, act_name, HookPoint())
-            del self.acts_to_saes[act_name]
+        if prev_wrapper is not None:
+            # Rebuild hook_dict before adding new SAE
+            self.setup()
+            self.add_sae(prev_wrapper.sae, use_error_term=prev_wrapper.use_error_term)

     def reset_saes(
         self,
         act_names: str | list[str] | None = None,
-
-
-        """Reset the SAEs attached to the model
+    ) -> None:
+        """Reset the SAEs attached to the model.

-        If act_names are provided will just reset SAEs attached to those hooks.
-
+        If act_names are provided will just reset SAEs attached to those hooks.
+        Otherwise will reset all SAEs attached to the model.

         Args:
-            act_names
-
+            act_names: The act_names of the SAEs to reset. If None, will reset
+                all SAEs attached to the model. Defaults to None.
         """
         if isinstance(act_names, str):
             act_names = [act_names]
         elif act_names is None:
-            act_names = list(self.
-
-        if prev_saes:
-            if len(act_names) != len(prev_saes):
-                raise ValueError("act_names and prev_saes must have the same length")
-        else:
-            prev_saes = [None] * len(act_names)  # type: ignore
+            act_names = list(self._acts_to_saes.keys())

-        for act_name
-            self._reset_sae(act_name
+        for act_name in act_names:
+            self._reset_sae(act_name)

         self.setup()

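The splice performed by `_SAEWrapper.forward` is the same for SAEs and transcoders: the replacement is `sae_out + (original_output - sae_out.detach())`, which equals the original activation in value while keeping gradients flowing through the SAE/transcoder. A self-contained sketch of that algebra, using a plain `nn.Linear` as a stand-in transcoder (illustrative only, not sae-lens code):

```python
import torch
from torch import nn

transcoder = nn.Linear(16, 16)                      # stand-in for a transcoder
x_in = torch.randn(4, 16)                           # activation captured at hook_name
orig_out = torch.randn(4, 16, requires_grad=True)   # original activation at hook_name_out

tc_out = transcoder(x_in)
spliced = tc_out + (orig_out - tc_out.detach())     # what the wrapper computes with use_error_term

print(torch.allclose(spliced, orig_out))            # True: values match the clean forward pass
spliced.sum().backward()
print(transcoder.weight.grad is not None)           # True: gradients still reach the transcoder
```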
@@ -269,41 +378,32 @@
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
     ):
-        """
-        A context manager for adding temporary SAEs to the model.
-        See HookedTransformer.hooks for a similar context manager for hooks.
-        By default will keep track of previously attached SAEs, and restore them when the context manager exits.
-
-        Example:
-
-        .. code-block:: python
-
-            from transformer_lens import HookedSAETransformer
-            from sae_lens.saes.sae import SAE
-
-            model = HookedSAETransformer.from_pretrained('gpt2-small')
-            sae_cfg = SAEConfig(...)
-            sae = SAE(sae_cfg)
-            with model.saes(saes=[sae]):
-                spliced_logits = model(text)
+        """A context manager for adding temporary SAEs to the model.

+        See HookedTransformer.hooks for a similar context manager for hooks.
+        By default will keep track of previously attached SAEs, and restore
+        them when the context manager exits.

         Args:
-            saes
-            reset_saes_end
-
+            saes: SAEs to be attached.
+            reset_saes_end: If True, removes all SAEs added by this context
+                manager when the context manager exits, returning previously
+                attached SAEs to their original state.
+            use_error_term: If provided, will set the use_error_term attribute
+                of all SAEs attached during this run to this value.
         """
-
-        prev_saes = []
+        saes_to_restore: list[tuple[str, _SAEWrapper | None]] = []
         if isinstance(saes, SAE):
             saes = [saes]
         try:
             for sae in saes:
-
-
-
+                act_name = sae.cfg.metadata.hook_name
+                prev_wrapper = self._acts_to_saes.get(act_name, None)
+                saes_to_restore.append((act_name, prev_wrapper))
                 self.add_sae(sae, use_error_term=use_error_term)
             yield self
         finally:
             if reset_saes_end:
-
+                for act_name, prev_wrapper in saes_to_restore:
+                    self._reset_sae(act_name, prev_wrapper)
+                self.setup()
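From the user's side nothing changes except that error terms and transcoders now work through the same `add_sae`/`saes` entry points. A usage sketch (hedged: the release and SAE id below are illustrative, and `SAE.from_pretrained` is assumed to return a single SAE object as in recent sae-lens releases):

```python
import torch
from sae_lens import SAE, HookedSAETransformer

model = HookedSAETransformer.from_pretrained("gpt2-small")
# Hypothetical release / id; use any SAE whose cfg.metadata.hook_name exists on the model.
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

tokens = model.to_tokens("The quick brown fox")
with model.saes(saes=[sae], use_error_term=True):
    spliced_logits = model(tokens)   # wrapper adds back (original - sae_out.detach())
clean_logits = model(tokens)         # SAE is detached again once the context exits

print(torch.allclose(spliced_logits, clean_logits, atol=1e-4))  # True
```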
sae_lens/analysis/sae_transformer_bridge.py
CHANGED
@@ -8,7 +8,11 @@ from transformer_lens.hook_points import HookPoint
 from transformer_lens.model_bridge import TransformerBridge

 from sae_lens import logger
-from sae_lens.analysis.hooked_sae_transformer import
+from sae_lens.analysis.hooked_sae_transformer import (
+    _SAEWrapper,
+    get_deep_attr,
+    set_deep_attr,
+)
 from sae_lens.saes.sae import SAE

 SingleLoss = torch.Tensor  # Type alias for a single element tensor
@@ -30,11 +34,18 @@ class SAETransformerBridge(TransformerBridge):  # type: ignore[misc,no-untyped-c
     useful for models not natively supported by HookedTransformer, such as Gemma 3.
     """

-
+    _acts_to_saes: dict[str, _SAEWrapper]

     def __init__(self, *args: Any, **kwargs: Any):
         super().__init__(*args, **kwargs)
-        self.
+        self._acts_to_saes = {}
+        # Track output hooks used by transcoders for cleanup
+        self._transcoder_output_hooks: dict[str, str] = {}
+
+    @property
+    def acts_to_saes(self) -> dict[str, SAE[Any]]:
+        """Returns a dict mapping hook names to attached SAEs."""
+        return {name: wrapper.sae for name, wrapper in self._acts_to_saes.items()}

     @classmethod
     def boot_transformers(  # type: ignore[override]
@@ -56,7 +67,8 @@ class SAETransformerBridge(TransformerBridge):  # type: ignore[misc,no-untyped-c
         # Convert to our class
         # NOTE: this is super hacky and scary, but I don't know how else to achieve this given TLens' internal code
         bridge.__class__ = cls
-        bridge.
+        bridge._acts_to_saes = {}  # type: ignore[attr-defined]
+        bridge._transcoder_output_hooks = {}  # type: ignore[attr-defined]
         return bridge  # type: ignore[return-value]

     def _resolve_hook_name(self, hook_name: str) -> str:
@@ -75,110 +87,129 @@ class SAETransformerBridge(TransformerBridge):  # type: ignore[misc,no-untyped-c
         return resolved if isinstance(resolved, str) else hook_name

     def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None) -> None:
-        """Attaches an SAE to the model.
+        """Attaches an SAE or Transcoder to the model.

         WARNING: This SAE will be permanently attached until you remove it with
         reset_saes. This function will also overwrite any existing SAE attached
         to the same hook point.

         Args:
-            sae: The SAE to attach to the model
-            use_error_term: If
-
-
+            sae: The SAE or Transcoder to attach to the model.
+            use_error_term: If True, computes error term so output matches what the
+                model would have produced without the SAE. This works for both SAEs
+                (where input==output hook) and transcoders (where they differ).
+                Defaults to None (uses SAE's existing setting).
         """
-
-
-
-
-
-
+        input_hook_alias = sae.cfg.metadata.hook_name
+        output_hook_alias = sae.cfg.metadata.hook_name_out or input_hook_alias
+        input_hook_actual = self._resolve_hook_name(input_hook_alias)
+        output_hook_actual = self._resolve_hook_name(output_hook_alias)
+
+        # Check if hooks exist
+        if (input_hook_alias not in self._acts_to_saes) and (
+            input_hook_actual not in self._hook_registry
         ):
             logger.warning(
-                f"No hook found for {
+                f"No hook found for {input_hook_alias}. Skipping. "
                 f"Check model._hook_registry for available hooks."
             )
             return

-        if
-
-
-
-
-
-
-
-
+        # Check if output hook exists (either as registry entry or already has SAE attached)
+        output_hook_exists = (
+            output_hook_actual in self._hook_registry
+            or input_hook_alias in self._acts_to_saes
+            or any(
+                v == output_hook_actual for v in self._transcoder_output_hooks.values()
+            )
+        )
+        if not output_hook_exists:
+            logger.warning(f"No hook found for output {output_hook_alias}. Skipping.")
+            return

-
+        # Always use wrapper - it handles both SAEs and transcoders uniformly
+        # If use_error_term not specified, respect SAE's existing setting
+        effective_use_error_term = (
+            use_error_term if use_error_term is not None else sae.use_error_term
+        )
+        wrapper = _SAEWrapper(sae, use_error_term=effective_use_error_term)
+
+        # For transcoders (input != output), capture input at input hook
+        if input_hook_alias != output_hook_alias:
+            input_hook_point = get_deep_attr(self, input_hook_actual)
+            if isinstance(input_hook_point, HookPoint):
+                input_hook_point.add_hook(
+                    lambda tensor, hook: (wrapper.capture_input(tensor), tensor)[1],  # noqa: ARG005
+                    dir="fwd",
+                    is_permanent=True,
+                )
+            self._transcoder_output_hooks[input_hook_alias] = output_hook_actual
+
+        # Store wrapper in _acts_to_saes and at output hook
+        set_deep_attr(self, output_hook_actual, wrapper)
+        self._hook_registry[output_hook_actual] = wrapper  # type: ignore[assignment]
+        self._acts_to_saes[input_hook_alias] = wrapper
+
+    def _reset_sae(
+        self, act_name: str, prev_wrapper: _SAEWrapper | None = None
+    ) -> None:
         """Resets an SAE that was attached to the model.

         By default will remove the SAE from that hook_point.
-        If
-        This is mainly used to restore previously attached SAEs after temporarily
-        running with different SAEs (e.g., with run_with_saes).
+        If prev_wrapper is provided, will restore that wrapper's SAE with its settings.

         Args:
             act_name: The hook_name of the SAE to reset
-
+            prev_wrapper: The previous wrapper to restore. If None, will just
                 remove the SAE from this hook point. Defaults to None.
         """
-        if act_name not in self.
+        if act_name not in self._acts_to_saes:
             logger.warning(
                 f"No SAE is attached to {act_name}. There's nothing to reset."
             )
             return

         actual_name = self._resolve_hook_name(act_name)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        # Determine output hook location (different from input for transcoders)
+        output_hook = self._transcoder_output_hooks.pop(act_name, actual_name)
+
+        # For transcoders, clear permanent hooks from input hook point
+        if output_hook != actual_name:
+            input_hook_point = get_deep_attr(self, actual_name)
+            if isinstance(input_hook_point, HookPoint):
+                input_hook_point.remove_hooks(dir="fwd", including_permanent=True)
+
+        # Reset output hook location
+        new_hook = HookPoint()
+        new_hook.name = output_hook
+        set_deep_attr(self, output_hook, new_hook)
+        self._hook_registry[output_hook] = new_hook
+        del self._acts_to_saes[act_name]
+
+        if prev_wrapper is not None:
+            self.add_sae(prev_wrapper.sae, use_error_term=prev_wrapper.use_error_term)

     def reset_saes(
         self,
         act_names: str | list[str] | None = None,
-        prev_saes: list[SAE[Any] | None] | None = None,
     ) -> None:
         """Reset the SAEs attached to the model.

         If act_names are provided will just reset SAEs attached to those hooks.
         Otherwise will reset all SAEs attached to the model.
-        Optionally can provide a list of prev_saes to reset to. This is mainly
-        used to restore previously attached SAEs after temporarily running with
-        different SAEs (e.g., with run_with_saes).

         Args:
             act_names: The act_names of the SAEs to reset. If None, will reset all
                 SAEs attached to the model. Defaults to None.
-            prev_saes: List of SAEs to replace the current ones with. If None, will
-                just remove the SAEs. Defaults to None.
         """
         if isinstance(act_names, str):
             act_names = [act_names]
         elif act_names is None:
-            act_names = list(self.
-
-        if prev_saes:
-            if len(act_names) != len(prev_saes):
-                raise ValueError("act_names and prev_saes must have the same length")
-        else:
-            prev_saes = [None] * len(act_names)  # type: ignore[assignment]
+            act_names = list(self._acts_to_saes.keys())

-        for act_name
-            self._reset_sae(act_name
+        for act_name in act_names:
+            self._reset_sae(act_name)

     def run_with_saes(
         self,
@@ -310,20 +341,20 @@ class SAETransformerBridge(TransformerBridge):  # type: ignore[misc,no-untyped-c
             use_error_term: If provided, will set the use_error_term attribute of
                 all SAEs attached during this run to this value. Defaults to None.
         """
-
-        prev_saes: list[SAE[Any] | None] = []
+        saes_to_restore: list[tuple[str, _SAEWrapper | None]] = []
         if isinstance(saes, SAE):
             saes = [saes]
         try:
             for sae in saes:
-
-
-
+                act_name = sae.cfg.metadata.hook_name
+                prev_wrapper = self._acts_to_saes.get(act_name, None)
+                saes_to_restore.append((act_name, prev_wrapper))
                 self.add_sae(sae, use_error_term=use_error_term)
             yield self
         finally:
             if reset_saes_end:
-
+                for act_name, prev_wrapper in saes_to_restore:
+                    self._reset_sae(act_name, prev_wrapper)

     @property
     def hook_dict(self) -> dict[str, HookPoint]:
@@ -337,9 +368,9 @@ class SAETransformerBridge(TransformerBridge):  # type: ignore[misc,no-untyped-c
         hooks: dict[str, HookPoint] = {}

         for name, hook_or_sae in self._hook_registry.items():
-            if isinstance(hook_or_sae,
+            if isinstance(hook_or_sae, _SAEWrapper):
                 # Include SAE's internal hooks with full path names
-                for sae_hook_name, sae_hook in hook_or_sae.hook_dict.items():
+                for sae_hook_name, sae_hook in hook_or_sae.sae.hook_dict.items():
                     full_name = f"{name}.{sae_hook_name}"
                     hooks[full_name] = sae_hook
             else:
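Because `hook_dict` splices each wrapper's SAE hooks in under the attachment point's name, and `_SAEWrapper` copies those hooks onto itself, cached SAE activations keep the flat names the wrapper docstring promises (no `.sae.` segment). A rough sketch of what this looks like from user code (hook point and release ids are illustrative; the same flat naming applies to both `HookedSAETransformer` and `SAETransformerBridge`):

```python
from sae_lens import SAE, HookedSAETransformer

model = HookedSAETransformer.from_pretrained("gpt2-small")
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")  # hypothetical ids
model.add_sae(sae)

_, cache = model.run_with_cache("Hello world")
# SAE hooks appear under the attachment point, without a ".sae." prefix:
print([k for k in cache.keys() if "hook_sae" in k][:2])
# e.g. ['blocks.8.hook_resid_pre.hook_sae_input', 'blocks.8.hook_resid_pre.hook_sae_acts_pre']

model.reset_saes()  # detach when done
```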
sae_lens/loading/pretrained_saes_directory.py
CHANGED
@@ -57,28 +57,6 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
     return directory


-def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
-    """
-    Retrieve the norm_scaling_factor for a specific SAE if it exists.
-
-    Args:
-        release (str): The release name of the SAE.
-        sae_id (str): The ID of the specific SAE.
-
-    Returns:
-        float | None: The norm_scaling_factor if it exists, None otherwise.
-    """
-    package = "sae_lens"
-    yaml_file = files(package).joinpath("pretrained_saes.yaml")
-    with yaml_file.open("r") as file:
-        data = yaml.safe_load(file)
-    if release in data:
-        for sae_info in data[release]["saes"]:
-            if sae_info["id"] == sae_id:
-                return sae_info.get("norm_scaling_factor")
-    return None
-
-
 def get_repo_id_and_folder_name(release: str, sae_id: str) -> tuple[str, str]:
     saes_directory = get_pretrained_saes_directory()
     sae_info = saes_directory.get(release, None)
sae_lens/saes/sae.py
CHANGED
@@ -45,7 +45,6 @@ from sae_lens.loading.pretrained_sae_loaders import (
 )
 from sae_lens.loading.pretrained_saes_directory import (
     get_config_overrides,
-    get_norm_scaling_factor,
     get_pretrained_saes_directory,
     get_releases_for_repo_id,
     get_repo_id_and_folder_name,
@@ -230,7 +229,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     cfg: T_SAE_CONFIG
     dtype: torch.dtype
     device: torch.device
-
+    _use_error_term: bool

     # For type checking only - don't provide default values
     # These will be initialized by subclasses
@@ -255,7 +254,9 @@

         self.dtype = str_to_dtype(cfg.dtype)
         self.device = torch.device(cfg.device)
-        self.
+        self._use_error_term = False  # Set directly to avoid warning during init
+        if use_error_term:
+            self.use_error_term = True  # Use property setter to trigger warning

         # Set up activation function
         self.activation_fn = self.get_activation_fn()
@@ -282,6 +283,22 @@

         self.setup()  # Required for HookedRootModule

+    @property
+    def use_error_term(self) -> bool:
+        return self._use_error_term
+
+    @use_error_term.setter
+    def use_error_term(self, value: bool) -> None:
+        if value and not self._use_error_term:
+            warnings.warn(
+                "Setting use_error_term directly on SAE is deprecated. "
+                "Use HookedSAETransformer.add_sae(sae, use_error_term=True) instead. "
+                "This will be removed in a future version.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+        self._use_error_term = value
+
     @torch.no_grad()
     def fold_activation_norm_scaling_factor(self, scaling_factor: float):
         self.W_enc.data *= scaling_factor  # type: ignore
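The property keeps `sae.use_error_term = True` working while steering users toward `add_sae(sae, use_error_term=True)`. A minimal sketch of the setter's behavior (the release and id are illustrative):

```python
import warnings
from sae_lens import SAE

sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")  # hypothetical ids

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sae.use_error_term = True   # first flip to True: setter emits a DeprecationWarning
    sae.use_error_term = True   # already True, so the guard suppresses a second warning

print(sum(issubclass(w.category, DeprecationWarning) for w in caught))  # 1
```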
@@ -638,24 +655,6 @@
                 stacklevel=2,
             )
         elif sae_id not in sae_directory[release].saes_map:
-            # Handle special cases like Gemma Scope
-            if (
-                "gemma-scope" in release
-                and "canonical" not in release
-                and f"{release}-canonical" in sae_directory
-            ):
-                canonical_ids = list(
-                    sae_directory[release + "-canonical"].saes_map.keys()
-                )
-                # Shorten the lengthy string of valid IDs
-                if len(canonical_ids) > 5:
-                    str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
-                else:
-                    str_canonical_ids = str(canonical_ids)
-                value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
-            else:
-                value_suffix = ""
-
             valid_ids = list(sae_directory[release].saes_map.keys())
             # Shorten the lengthy string of valid IDs
             if len(valid_ids) > 5:
@@ -665,7 +664,6 @@

             raise ValueError(
                 f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
-                + value_suffix
             )

         conversion_loader = (
@@ -702,17 +700,6 @@
         sae.process_state_dict_for_loading(state_dict)
         sae.load_state_dict(state_dict, assign=True)

-        # Apply normalization if needed
-        if cfg_dict.get("normalize_activations") == "expected_average_only_in":
-            norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
-            if norm_scaling_factor is not None:
-                sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
-                cfg_dict["normalize_activations"] = "none"
-            else:
-                warnings.warn(
-                    f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
-                )
-
         # the loaders should already handle the dtype / device conversion
         # but this is a fallback to guarantee the SAE is on the correct device and dtype
         return (
sae_lens/synthetic/__init__.py
CHANGED
@@ -50,6 +50,13 @@ from sae_lens.synthetic.plotting import (
     find_best_feature_ordering_from_sae,
     plot_sae_feature_similarity,
 )
+from sae_lens.synthetic.stats import (
+    CorrelationMatrixStats,
+    SuperpositionStats,
+    compute_correlation_matrix_stats,
+    compute_low_rank_correlation_matrix_stats,
+    compute_superposition_stats,
+)
 from sae_lens.synthetic.training import (
     SyntheticActivationIterator,
     train_toy_sae,
@@ -80,6 +87,12 @@ __all__ = [
     "orthogonal_initializer",
     "FeatureDictionaryInitializer",
     "cosine_similarities",
+    # Statistics
+    "compute_correlation_matrix_stats",
+    "compute_low_rank_correlation_matrix_stats",
+    "compute_superposition_stats",
+    "CorrelationMatrixStats",
+    "SuperpositionStats",
     # Training utilities
     "SyntheticActivationIterator",
     "SyntheticDataEvalResult",
sae_lens/synthetic/correlation.py
CHANGED
@@ -3,6 +3,7 @@ from typing import NamedTuple

 import torch

+from sae_lens import logger
 from sae_lens.util import str_to_dtype


@@ -268,7 +269,7 @@
 def generate_random_low_rank_correlation_matrix(
     num_features: int,
     rank: int,
-    correlation_scale: float = 0.
+    correlation_scale: float = 0.075,
     seed: int | None = None,
     device: torch.device | str = "cpu",
     dtype: torch.dtype | str = torch.float32,
@@ -331,20 +332,17 @@
     factor_sq_sum = (factor**2).sum(dim=1)
     diag_term = 1 - factor_sq_sum

-    #
-
-
-
-
-
-
-
-    /
+    # alternatively, we can rescale each row independently to ensure the diagonal is 1
+    mask = diag_term < _MIN_DIAG
+    factor[mask, :] *= torch.sqrt((1 - _MIN_DIAG) / factor_sq_sum[mask].unsqueeze(1))
+    factor_sq_sum = (factor**2).sum(dim=1)
+    diag_term = 1 - factor_sq_sum
+
+    total_rescaled = mask.sum().item()
+    if total_rescaled > 0:
+        logger.warning(
+            f"{total_rescaled} / {num_features} rows were capped. Either reduce the rank or reduce the correlation_scale to avoid rescaling."
         )
-    factor = factor * scale
-    factor_sq_sum = (factor**2).sum(dim=1)
-    diag_term = 1 - factor_sq_sum

     return LowRankCorrelationMatrix(
         correlation_factor=factor, correlation_diag=diag_term
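The capping exists because the implied dense matrix is `factor @ factor.T + diag(diag_term)` with `diag_term = 1 - ||F[i]||²`; rescaling rows whose squared norm grows too large keeps `diag_term` positive and the diagonal exactly 1. A quick check of that property (import path taken from this release's layout; sizes are arbitrary):

```python
import torch
from sae_lens.synthetic.correlation import generate_random_low_rank_correlation_matrix

low_rank = generate_random_low_rank_correlation_matrix(
    num_features=512, rank=8, correlation_scale=0.075, seed=0
)
F, d = low_rank.correlation_factor, low_rank.correlation_diag

dense = F @ F.T + torch.diag(d)  # the implied full correlation matrix
print(torch.allclose(torch.diagonal(dense), torch.ones(512), atol=1e-6))  # True: unit diagonal
```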
sae_lens/synthetic/stats.py
ADDED
@@ -0,0 +1,205 @@
+from dataclasses import dataclass
+
+import torch
+
+from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
+from sae_lens.synthetic.feature_dictionary import FeatureDictionary
+
+
+@dataclass
+class CorrelationMatrixStats:
+    """Statistics computed from a correlation matrix."""
+
+    rms_correlation: float  # Root mean square of off-diagonal correlations
+    mean_correlation: float  # Mean of off-diagonal correlations (not absolute)
+    num_features: int
+
+
+@torch.no_grad()
+def compute_correlation_matrix_stats(
+    correlation_matrix: torch.Tensor,
+) -> CorrelationMatrixStats:
+    """Compute correlation statistics from a dense correlation matrix.
+
+    Args:
+        correlation_matrix: Dense correlation matrix of shape (n, n)
+
+    Returns:
+        CorrelationMatrixStats with correlation statistics
+    """
+    num_features = correlation_matrix.shape[0]
+
+    # Extract off-diagonal elements
+    mask = ~torch.eye(num_features, dtype=torch.bool, device=correlation_matrix.device)
+    off_diag = correlation_matrix[mask]
+
+    rms_correlation = (off_diag**2).mean().sqrt().item()
+    mean_correlation = off_diag.mean().item()
+
+    return CorrelationMatrixStats(
+        rms_correlation=rms_correlation,
+        mean_correlation=mean_correlation,
+        num_features=num_features,
+    )
+
+
+@torch.no_grad()
+def compute_low_rank_correlation_matrix_stats(
+    correlation_matrix: LowRankCorrelationMatrix,
+) -> CorrelationMatrixStats:
+    """Compute correlation statistics from a LowRankCorrelationMatrix.
+
+    The correlation matrix is represented as:
+        correlation = factor @ factor.T + diag(diag_term)
+
+    The off-diagonal elements are simply factor @ factor.T (the diagonal term
+    only affects the diagonal).
+
+    All statistics are computed efficiently in O(n*r²) time and O(r²) memory
+    without materializing the full n×n correlation matrix.
+
+    Args:
+        correlation_matrix: Low-rank correlation matrix
+
+    Returns:
+        CorrelationMatrixStats with correlation statistics
+    """
+    factor = correlation_matrix.correlation_factor
+    num_features = factor.shape[0]
+    num_off_diag = num_features * (num_features - 1)
+
+    # RMS correlation: uses ||F @ F.T||_F² = ||F.T @ F||_F²
+    # This avoids computing the (num_features, num_features) matrix
+    G = factor.T @ factor  # (rank, rank) - small!
+    frobenius_sq = (G**2).sum()
+    row_norms_sq = (factor**2).sum(dim=1)  # ||F[i]||² for each row
+    diag_sq_sum = (row_norms_sq**2).sum()  # Σᵢ ||F[i]||⁴
+    off_diag_sq_sum = frobenius_sq - diag_sq_sum
+    rms_correlation = (off_diag_sq_sum / num_off_diag).sqrt().item()
+
+    # Mean correlation (not absolute): sum(C) = ||col_sums(F)||², trace(C) = Σ||F[i]||²
+    col_sums = factor.sum(dim=0)  # (rank,)
+    sum_all = (col_sums**2).sum()  # 1ᵀ C 1
+    trace_C = row_norms_sq.sum()
+    mean_correlation = ((sum_all - trace_C) / num_off_diag).item()
+
+    return CorrelationMatrixStats(
+        rms_correlation=rms_correlation,
+        mean_correlation=mean_correlation,
+        num_features=num_features,
+    )
+
+
+@dataclass
+class SuperpositionStats:
+    """Statistics measuring superposition in a feature dictionary."""
+
+    # Per-latent statistics: for each latent, max and percentile of |cos_sim| with others
+    max_abs_cos_sims: torch.Tensor  # Shape: (num_features,)
+    percentile_abs_cos_sims: dict[int, torch.Tensor]  # {percentile: (num_features,)}
+
+    # Summary statistics (means of the per-latent values)
+    mean_max_abs_cos_sim: float
+    mean_percentile_abs_cos_sim: dict[int, float]
+    mean_abs_cos_sim: float  # Mean |cos_sim| across all pairs
+
+    # Metadata
+    num_features: int
+    hidden_dim: int
+
+
+@torch.no_grad()
+def compute_superposition_stats(
+    feature_dictionary: FeatureDictionary,
+    batch_size: int = 1024,
+    device: str | torch.device | None = None,
+    percentiles: list[int] | None = None,
+) -> SuperpositionStats:
+    """Compute superposition statistics for a feature dictionary.
+
+    Computes pairwise cosine similarities in batches to handle large dictionaries.
+
+    For each latent i, computes:
+
+    - max |cos_sim(i, j)| over all j != i
+    - kth percentile of |cos_sim(i, j)| over all j != i (for each k in percentiles)
+
+    Args:
+        feature_dictionary: FeatureDictionary containing the feature vectors
+        batch_size: Number of features to process per batch
+        device: Device for computation (defaults to feature dictionary's device)
+        percentiles: List of percentiles to compute per latent (default: [95, 99])
+
+    Returns:
+        SuperpositionStats with superposition metrics
+    """
+    if percentiles is None:
+        percentiles = [95, 99]
+
+    feature_vectors = feature_dictionary.feature_vectors
+    num_features, hidden_dim = feature_vectors.shape
+
+    if num_features < 2:
+        raise ValueError("Need at least 2 features to compute superposition stats")
+    if device is None:
+        device = feature_vectors.device
+
+    # Normalize features to unit norm for cosine similarity
+    features_normalized = feature_vectors.to(device).float()
+    norms = torch.linalg.norm(features_normalized, dim=1, keepdim=True)
+    features_normalized = features_normalized / norms.clamp(min=1e-8)
+
+    # Track per-latent statistics
+    max_abs_cos_sims = torch.zeros(num_features, device=device)
+    percentile_abs_cos_sims = {
+        p: torch.zeros(num_features, device=device) for p in percentiles
+    }
+    sum_abs_cos_sim = 0.0
+    n_pairs = 0
+
+    # Process in batches: for each batch of features, compute similarities with all others
+    for i in range(0, num_features, batch_size):
+        batch_end = min(i + batch_size, num_features)
+        batch = features_normalized[i:batch_end]  # (batch_size, hidden_dim)
+
+        # Compute cosine similarities with all features: (batch_size, num_features)
+        cos_sims = batch @ features_normalized.T
+
+        # Absolute cosine similarities
+        abs_cos_sims = cos_sims.abs()
+
+        # Process each latent in the batch
+        for j, idx in enumerate(range(i, batch_end)):
+            # Get similarities with all other features (exclude self)
+            row = abs_cos_sims[j].clone()
+            row[idx] = 0.0  # Exclude self for max
+            max_abs_cos_sims[idx] = row.max()
+
+            # For percentiles, exclude self and compute
+            other_sims = torch.cat([abs_cos_sims[j, :idx], abs_cos_sims[j, idx + 1 :]])
+            for p in percentiles:
+                percentile_abs_cos_sims[p][idx] = torch.quantile(other_sims, p / 100.0)
+
+            # Sum for mean computation (only count pairs once - with features after this one)
+            sum_abs_cos_sim += abs_cos_sims[j, idx + 1 :].sum().item()
+            n_pairs += num_features - idx - 1
+
+    # Compute summary statistics
+    mean_max_abs_cos_sim = max_abs_cos_sims.mean().item()
+    mean_percentile_abs_cos_sim = {
+        p: percentile_abs_cos_sims[p].mean().item() for p in percentiles
+    }
+    mean_abs_cos_sim = sum_abs_cos_sim / n_pairs if n_pairs > 0 else 0.0
+
+    return SuperpositionStats(
+        max_abs_cos_sims=max_abs_cos_sims.cpu(),
+        percentile_abs_cos_sims={
+            p: v.cpu() for p, v in percentile_abs_cos_sims.items()
+        },
+        mean_max_abs_cos_sim=mean_max_abs_cos_sim,
+        mean_percentile_abs_cos_sim=mean_percentile_abs_cos_sim,
+        mean_abs_cos_sim=mean_abs_cos_sim,
+        num_features=num_features,
+        hidden_dim=hidden_dim,
+    )
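The low-rank path relies on two identities: `||F @ F.T||_F² = ||F.T @ F||_F²` for the RMS term, and `sum(C) = ||F.sum(dim=0)||²`, `trace(C) = Σᵢ ||F[i]||²` for the mean term. A sketch comparing the O(n·r²) implementation against the dense computation (import paths as introduced in this diff; sizes are arbitrary):

```python
import torch
from sae_lens.synthetic.correlation import generate_random_low_rank_correlation_matrix
from sae_lens.synthetic.stats import (
    compute_correlation_matrix_stats,
    compute_low_rank_correlation_matrix_stats,
)

low_rank = generate_random_low_rank_correlation_matrix(num_features=256, rank=4, seed=0)
F, d = low_rank.correlation_factor, low_rank.correlation_diag

# Materialize the dense matrix only for the comparison; the low-rank path never does.
dense_stats = compute_correlation_matrix_stats(F @ F.T + torch.diag(d))
fast_stats = compute_low_rank_correlation_matrix_stats(low_rank)

print(abs(dense_stats.rms_correlation - fast_stats.rms_correlation) < 1e-4)    # True
print(abs(dense_stats.mean_correlation - fast_stats.mean_correlation) < 1e-4)  # True
```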
{sae_lens-6.32.1.dist-info → sae_lens-6.34.1.dist-info}/RECORD
CHANGED
@@ -1,9 +1,9 @@
-sae_lens/__init__.py,sha256=
+sae_lens/__init__.py,sha256=U5I6dWw2_x6ZXwL8BF72vB0HGPEUt-_BzHq-btLq2pw,5168
 sae_lens/analysis/__init__.py,sha256=FZExlMviNwWR7OGUSGRbd0l-yUDGSp80gglI_ivILrY,412
 sae_lens/analysis/compat.py,sha256=cgE3nhFcJTcuhppxbL71VanJS7YqVEOefuneB5eOaPw,538
-sae_lens/analysis/hooked_sae_transformer.py,sha256
+sae_lens/analysis/hooked_sae_transformer.py,sha256=-LY9CKYEziSTt-H7MeLdTx6ErfvMqgNEF709wV7tEs4,17826
 sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
-sae_lens/analysis/sae_transformer_bridge.py,sha256=
+sae_lens/analysis/sae_transformer_bridge.py,sha256=y_ZdvaxuUM_-7ywSCeFl6f6cq1_FiqRstMAo8mxZtYE,16191
 sae_lens/cache_activations_runner.py,sha256=TjqNWIc46Nw09jHWFjzQzgzG5wdu_87Ahe-iFjI5_0Q,13117
 sae_lens/config.py,sha256=V0BXV8rvpbm5YuVukow9FURPpdyE4HSflbdymAo0Ycg,31205
 sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
@@ -12,7 +12,7 @@ sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHF
 sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
 sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/loading/pretrained_sae_loaders.py,sha256=kshvA0NivOc7B3sL19lHr_zrC_DDfW2T6YWb5j0hgAk,63930
-sae_lens/loading/pretrained_saes_directory.py,sha256=
+sae_lens/loading/pretrained_saes_directory.py,sha256=lSnHl77IO5dd7iO21ynCzZNMrzuJAT8Za4W5THNq0qw,3554
 sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
 sae_lens/pretrained_saes.yaml,sha256=IVBLLR8_XNllJ1O-kVv9ED4u0u44Yn8UOL9R-f8Idp4,1511936
 sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
@@ -22,20 +22,21 @@ sae_lens/saes/gated_sae.py,sha256=V_2ZNlV4gRD-rX5JSx1xqY7idT8ChfdQ5yxWDdu_6hg,88
 sae_lens/saes/jumprelu_sae.py,sha256=miiF-xI_yXdV9EkKjwAbU9zSMsx9KtKCz5YdXEzkN8g,13313
 sae_lens/saes/matching_pursuit_sae.py,sha256=08_G9p1YMLnE5qZVCPp6gll-iG6nHRbMMASf4_bkFt8,13207
 sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
-sae_lens/saes/sae.py,sha256=
+sae_lens/saes/sae.py,sha256=2FHhLoTZYOJUGjkWH7eF2EAc1fPxQsF4KYr5TZ8IUIU,40155
 sae_lens/saes/standard_sae.py,sha256=_hldNZkFPAf9VGrxouR1-tN8T2OEk8IkWBcXoatrC1o,5749
 sae_lens/saes/temporal_sae.py,sha256=S44sPddVj2xujA02CC8gT1tG0in7c_CSAhspu9FHbaA,13273
 sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
 sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
-sae_lens/synthetic/__init__.py,sha256=
+sae_lens/synthetic/__init__.py,sha256=hRRA3xhEQUacGyFbJXkLVYg_8A1bbSYYWlVovb0g4KU,3503
 sae_lens/synthetic/activation_generator.py,sha256=8L9nwC4jFRv_wg3QN-n1sFwX8w1NqwJMysWaJ41lLlY,15197
-sae_lens/synthetic/correlation.py,sha256=
+sae_lens/synthetic/correlation.py,sha256=tD8J9abWfuFtGZrEbbFn4P8FeTcNKF2V5JhBLwDUmkg,13146
 sae_lens/synthetic/evals.py,sha256=Nhi314ZnRgLfhBj-3tm_zzI-pGyFTcwllDXbIpPFXeU,4584
 sae_lens/synthetic/feature_dictionary.py,sha256=Nd4xjSTxKMnKilZ3uYi8Gv5SS5D4bv4wHiSL1uGB69E,6933
 sae_lens/synthetic/firing_probabilities.py,sha256=yclz1pWl5gE1r8LAxFvzQS88Lxwk5-3r8BCX9HLVejA,3370
 sae_lens/synthetic/hierarchy.py,sha256=nm7nwnTswktVJeKUsRZ0hLOdXcFWGbxnA1b6lefHm-4,33592
 sae_lens/synthetic/initialization.py,sha256=orMGW-786wRDHIS2W7bEH0HmlVFQ4g2z4bnnwdv5w4s,1386
 sae_lens/synthetic/plotting.py,sha256=5lFrej1QOkGAcImFNo5-o-8mI_rUVqvEI57KzUQPPtQ,8208
+sae_lens/synthetic/stats.py,sha256=BoDPKDx8pgFF5Ko_IaBRZTczm7-ANUIRjjF5W5Qh3Lk,7441
 sae_lens/synthetic/training.py,sha256=fHcX2cZ6nDupr71GX0Gk17f1NvQ0SKIVXIA6IuAb2dw,5692
 sae_lens/tokenization_and_batching.py,sha256=uoHtAs9z3XqG0Fh-iQVYVlrbyB_E3kFFhrKU30BosCo,5438
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -48,7 +49,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
 sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
+sae_lens-6.34.1.dist-info/METADATA,sha256=vAlbWT90NggKoxjII6dnEj7wlC-PmuBn8zzqL6r8dRg,6566
+sae_lens-6.34.1.dist-info/WHEEL,sha256=kJCRJT_g0adfAJzTx2GUMmS80rTJIVHRCfG0DQgLq3o,88
+sae_lens-6.34.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.34.1.dist-info/RECORD,,
{sae_lens-6.32.1.dist-info → sae_lens-6.34.1.dist-info}/licenses/LICENSE
File without changes