sae-lens 6.29.1-py3-none-any.whl → 6.33.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +14 -1
- sae_lens/analysis/__init__.py +15 -0
- sae_lens/analysis/compat.py +16 -0
- sae_lens/analysis/hooked_sae_transformer.py +1 -1
- sae_lens/analysis/sae_transformer_bridge.py +348 -0
- sae_lens/config.py +9 -1
- sae_lens/evals.py +2 -2
- sae_lens/loading/pretrained_sae_loaders.py +11 -4
- sae_lens/loading/pretrained_saes_directory.py +0 -22
- sae_lens/pretrained_saes.yaml +36 -0
- sae_lens/saes/sae.py +0 -31
- sae_lens/saes/temporal_sae.py +1 -1
- sae_lens/synthetic/__init__.py +13 -0
- sae_lens/synthetic/correlation.py +12 -14
- sae_lens/synthetic/stats.py +205 -0
- sae_lens/training/activation_scaler.py +3 -1
- {sae_lens-6.29.1.dist-info → sae_lens-6.33.0.dist-info}/METADATA +2 -2
- {sae_lens-6.29.1.dist-info → sae_lens-6.33.0.dist-info}/RECORD +20 -17
- {sae_lens-6.29.1.dist-info → sae_lens-6.33.0.dist-info}/WHEEL +1 -1
- {sae_lens-6.29.1.dist-info → sae_lens-6.33.0.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.29.1"
+__version__ = "6.33.0"
 
 import logging
 
@@ -125,6 +125,19 @@ __all__ = [
     "MatchingPursuitTrainingSAEConfig",
 ]
 
+# Conditional export for SAETransformerBridge (requires transformer-lens v3+)
+try:
+    from sae_lens.analysis.compat import has_transformer_bridge
+
+    if has_transformer_bridge():
+        from sae_lens.analysis.sae_transformer_bridge import (  # noqa: F401
+            SAETransformerBridge,
+        )
+
+        __all__.append("SAETransformerBridge")
+except ImportError:
+    pass
+
 
 register_sae_class("standard", StandardSAE, StandardSAEConfig)
 register_sae_training_class("standard", StandardTrainingSAE, StandardTrainingSAEConfig)
sae_lens/analysis/__init__.py
CHANGED
@@ -0,0 +1,15 @@
+from sae_lens.analysis.hooked_sae_transformer import HookedSAETransformer
+
+__all__ = ["HookedSAETransformer"]
+
+try:
+    from sae_lens.analysis.compat import has_transformer_bridge
+
+    if has_transformer_bridge():
+        from sae_lens.analysis.sae_transformer_bridge import (  # noqa: F401
+            SAETransformerBridge,
+        )
+
+        __all__.append("SAETransformerBridge")
+except ImportError:
+    pass
sae_lens/analysis/compat.py
ADDED
@@ -0,0 +1,16 @@
+import importlib.metadata
+
+from packaging.version import parse as parse_version
+
+
+def get_transformer_lens_version() -> tuple[int, int, int]:
+    """Get transformer-lens version as (major, minor, patch)."""
+    version_str = importlib.metadata.version("transformer-lens")
+    version = parse_version(version_str)
+    return (version.major, version.minor, version.micro)
+
+
+def has_transformer_bridge() -> bool:
+    """Check if TransformerBridge is available (v3+)."""
+    major, _, _ = get_transformer_lens_version()
+    return major >= 3
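For reference, a minimal sketch of how the new compat helpers can gate TransformerBridge-dependent imports in downstream code. The usage below is illustrative, not part of the diff; it only calls the two functions defined in compat.py above.

# Sketch: guard v3-only code on the installed transformer-lens version.
from sae_lens.analysis.compat import get_transformer_lens_version, has_transformer_bridge

major, minor, patch = get_transformer_lens_version()
print(f"transformer-lens {major}.{minor}.{patch}")

if has_transformer_bridge():
    # Only importable when transformer-lens v3+ is installed.
    from sae_lens.analysis.sae_transformer_bridge import SAETransformerBridge  # noqa: F401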
sae_lens/analysis/hooked_sae_transformer.py
CHANGED
@@ -126,7 +126,7 @@ class HookedSAETransformer(HookedTransformer):
             current_sae.use_error_term = current_sae._original_use_error_term  # type: ignore
             delattr(current_sae, "_original_use_error_term")
 
-        if prev_sae:
+        if prev_sae is not None:
            set_deep_attr(self, act_name, prev_sae)
            self.acts_to_saes[act_name] = prev_sae
        else:
sae_lens/analysis/sae_transformer_bridge.py
ADDED
@@ -0,0 +1,348 @@
+from collections.abc import Callable
+from contextlib import contextmanager
+from typing import Any
+
+import torch
+from transformer_lens.ActivationCache import ActivationCache
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.model_bridge import TransformerBridge
+
+from sae_lens import logger
+from sae_lens.analysis.hooked_sae_transformer import set_deep_attr
+from sae_lens.saes.sae import SAE
+
+SingleLoss = torch.Tensor  # Type alias for a single element tensor
+LossPerToken = torch.Tensor
+Loss = SingleLoss | LossPerToken
+
+
+class SAETransformerBridge(TransformerBridge):  # type: ignore[misc,no-untyped-call]
+    """A TransformerBridge subclass that supports attaching SAEs.
+
+    .. warning::
+        This class is in **beta**. The API may change in future versions.
+
+    This class provides the same SAE attachment functionality as HookedSAETransformer,
+    but for transformer-lens v3's TransformerBridge instead of HookedTransformer.
+
+    TransformerBridge is a lightweight wrapper around HuggingFace models that provides
+    hook points without the overhead of HookedTransformer's weight processing. This is
+    useful for models not natively supported by HookedTransformer, such as Gemma 3.
+    """
+
+    acts_to_saes: dict[str, SAE[Any]]
+
+    def __init__(self, *args: Any, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        self.acts_to_saes = {}
+
+    @classmethod
+    def boot_transformers(  # type: ignore[override]
+        cls,
+        model_name: str,
+        **kwargs: Any,
+    ) -> "SAETransformerBridge":
+        """Factory method to boot a model and return SAETransformerBridge instance.
+
+        Args:
+            model_name: The name of the model to load (e.g., "gpt2", "gemma-2-2b")
+            **kwargs: Additional arguments passed to TransformerBridge.boot_transformers
+
+        Returns:
+            SAETransformerBridge instance with the loaded model
+        """
+        # Boot parent TransformerBridge
+        bridge = TransformerBridge.boot_transformers(model_name, **kwargs)
+        # Convert to our class
+        # NOTE: this is super hacky and scary, but I don't know how else to achieve this given TLens' internal code
+        bridge.__class__ = cls
+        bridge.acts_to_saes = {}  # type: ignore[attr-defined]
+        return bridge  # type: ignore[return-value]
+
+    def _resolve_hook_name(self, hook_name: str) -> str:
+        """Resolve alias to actual hook name.
+
+        TransformerBridge supports hook aliases like 'blocks.0.hook_mlp_out'
+        that map to actual paths like 'blocks.0.mlp.hook_out'.
+        """
+        # Combine static and dynamic aliases
+        aliases: dict[str, Any] = {
+            **self.hook_aliases,
+            **self._collect_hook_aliases_from_registry(),
+        }
+        resolved = aliases.get(hook_name, hook_name)
+        # aliases values are always strings, but type checker doesn't know this
+        return resolved if isinstance(resolved, str) else hook_name
+
+    def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None) -> None:
+        """Attaches an SAE to the model.
+
+        WARNING: This SAE will be permanently attached until you remove it with
+        reset_saes. This function will also overwrite any existing SAE attached
+        to the same hook point.
+
+        Args:
+            sae: The SAE to attach to the model
+            use_error_term: If provided, will set the use_error_term attribute of
+                the SAE to this value. Determines whether the SAE returns input
+                or reconstruction. Defaults to None.
+        """
+        alias_name = sae.cfg.metadata.hook_name
+        actual_name = self._resolve_hook_name(alias_name)
+
+        # Check if hook exists (either as alias or actual name)
+        if (alias_name not in self.acts_to_saes) and (
+            actual_name not in self._hook_registry
+        ):
+            logger.warning(
+                f"No hook found for {alias_name}. Skipping. "
+                f"Check model._hook_registry for available hooks."
+            )
+            return
+
+        if use_error_term is not None:
+            if not hasattr(sae, "_original_use_error_term"):
+                sae._original_use_error_term = sae.use_error_term  # type: ignore[attr-defined]
+            sae.use_error_term = use_error_term
+
+        # Replace hook and update registry
+        set_deep_attr(self, actual_name, sae)
+        self._hook_registry[actual_name] = sae  # type: ignore[assignment]
+        self.acts_to_saes[alias_name] = sae
+
+    def _reset_sae(self, act_name: str, prev_sae: SAE[Any] | None = None) -> None:
+        """Resets an SAE that was attached to the model.
+
+        By default will remove the SAE from that hook_point.
+        If prev_sae is provided, will replace the current SAE with the provided one.
+        This is mainly used to restore previously attached SAEs after temporarily
+        running with different SAEs (e.g., with run_with_saes).
+
+        Args:
+            act_name: The hook_name of the SAE to reset
+            prev_sae: The SAE to replace the current one with. If None, will just
+                remove the SAE from this hook point. Defaults to None.
+        """
+        if act_name not in self.acts_to_saes:
+            logger.warning(
+                f"No SAE is attached to {act_name}. There's nothing to reset."
+            )
+            return
+
+        actual_name = self._resolve_hook_name(act_name)
+        current_sae = self.acts_to_saes[act_name]
+
+        if hasattr(current_sae, "_original_use_error_term"):
+            current_sae.use_error_term = current_sae._original_use_error_term  # type: ignore[attr-defined]
+            delattr(current_sae, "_original_use_error_term")
+
+        if prev_sae is not None:
+            set_deep_attr(self, actual_name, prev_sae)
+            self._hook_registry[actual_name] = prev_sae  # type: ignore[assignment]
+            self.acts_to_saes[act_name] = prev_sae
+        else:
+            new_hook = HookPoint()
+            new_hook.name = actual_name
+            set_deep_attr(self, actual_name, new_hook)
+            self._hook_registry[actual_name] = new_hook
+            del self.acts_to_saes[act_name]
+
+    def reset_saes(
+        self,
+        act_names: str | list[str] | None = None,
+        prev_saes: list[SAE[Any] | None] | None = None,
+    ) -> None:
+        """Reset the SAEs attached to the model.
+
+        If act_names are provided will just reset SAEs attached to those hooks.
+        Otherwise will reset all SAEs attached to the model.
+        Optionally can provide a list of prev_saes to reset to. This is mainly
+        used to restore previously attached SAEs after temporarily running with
+        different SAEs (e.g., with run_with_saes).
+
+        Args:
+            act_names: The act_names of the SAEs to reset. If None, will reset all
+                SAEs attached to the model. Defaults to None.
+            prev_saes: List of SAEs to replace the current ones with. If None, will
+                just remove the SAEs. Defaults to None.
+        """
+        if isinstance(act_names, str):
+            act_names = [act_names]
+        elif act_names is None:
+            act_names = list(self.acts_to_saes.keys())
+
+        if prev_saes:
+            if len(act_names) != len(prev_saes):
+                raise ValueError("act_names and prev_saes must have the same length")
+        else:
+            prev_saes = [None] * len(act_names)  # type: ignore[assignment]
+
+        for act_name, prev_sae in zip(act_names, prev_saes):  # type: ignore[arg-type]
+            self._reset_sae(act_name, prev_sae)
+
+    def run_with_saes(
+        self,
+        *model_args: Any,
+        saes: SAE[Any] | list[SAE[Any]] = [],
+        reset_saes_end: bool = True,
+        use_error_term: bool | None = None,
+        **model_kwargs: Any,
+    ) -> torch.Tensor | Loss | tuple[torch.Tensor, Loss] | None:
+        """Wrapper around forward pass.
+
+        Runs the model with the given SAEs attached for one forward pass, then
+        removes them. By default, will reset all SAEs to original state after.
+
+        Args:
+            *model_args: Positional arguments for the model forward pass
+            saes: The SAEs to be attached for this forward pass
+            reset_saes_end: If True, all SAEs added during this run are removed
+                at the end, and previously attached SAEs are restored to their
+                original state. Default is True.
+            use_error_term: If provided, will set the use_error_term attribute
+                of all SAEs attached during this run to this value. Defaults to None.
+            **model_kwargs: Keyword arguments for the model forward pass
+        """
+        with self.saes(
+            saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
+        ):
+            return self(*model_args, **model_kwargs)
+
+    def run_with_cache_with_saes(
+        self,
+        *model_args: Any,
+        saes: SAE[Any] | list[SAE[Any]] = [],
+        reset_saes_end: bool = True,
+        use_error_term: bool | None = None,
+        return_cache_object: bool = True,
+        remove_batch_dim: bool = False,
+        **kwargs: Any,
+    ) -> tuple[
+        torch.Tensor | Loss | tuple[torch.Tensor, Loss] | None,
+        ActivationCache | dict[str, torch.Tensor],
+    ]:
+        """Wrapper around 'run_with_cache'.
+
+        Attaches given SAEs before running the model with cache and then removes them.
+        By default, will reset all SAEs to original state after.
+
+        Args:
+            *model_args: Positional arguments for the model forward pass
+            saes: The SAEs to be attached for this forward pass
+            reset_saes_end: If True, all SAEs added during this run are removed
+                at the end, and previously attached SAEs are restored to their
+                original state. Default is True.
+            use_error_term: If provided, will set the use_error_term attribute
+                of all SAEs attached during this run to this value. Defaults to None.
+            return_cache_object: If True, returns an ActivationCache object with
+                useful methods, otherwise returns a dictionary of activations.
+            remove_batch_dim: Whether to remove the batch dimension
+                (only works for batch_size==1). Defaults to False.
+            **kwargs: Keyword arguments for the model forward pass
+        """
+        with self.saes(
+            saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
+        ):
+            return self.run_with_cache(
+                *model_args,
+                return_cache_object=return_cache_object,  # type: ignore[arg-type]
+                remove_batch_dim=remove_batch_dim,
+                **kwargs,
+            )  # type: ignore[return-value]
+
+    def run_with_hooks_with_saes(
+        self,
+        *model_args: Any,
+        saes: SAE[Any] | list[SAE[Any]] = [],
+        reset_saes_end: bool = True,
+        fwd_hooks: list[tuple[str | Callable[..., Any], Callable[..., Any]]] = [],
+        bwd_hooks: list[tuple[str | Callable[..., Any], Callable[..., Any]]] = [],
+        reset_hooks_end: bool = True,
+        clear_contexts: bool = False,
+        **model_kwargs: Any,
+    ) -> Any:
+        """Wrapper around 'run_with_hooks'.
+
+        Attaches the given SAEs to the model before running the model with hooks
+        and then removes them. By default, will reset all SAEs to original state after.
+
+        Args:
+            *model_args: Positional arguments for the model forward pass
+            saes: The SAEs to be attached for this forward pass
+            reset_saes_end: If True, all SAEs added during this run are removed
+                at the end, and previously attached SAEs are restored to their
+                original state. Default is True.
+            fwd_hooks: List of forward hooks to apply
+            bwd_hooks: List of backward hooks to apply
+            reset_hooks_end: Whether to reset the hooks at the end of the forward
+                pass. Default is True.
+            clear_contexts: Whether to clear the contexts at the end of the forward
+                pass. Default is False.
+            **model_kwargs: Keyword arguments for the model forward pass
+        """
+        with self.saes(saes=saes, reset_saes_end=reset_saes_end):
+            return self.run_with_hooks(
+                *model_args,
+                fwd_hooks=fwd_hooks,
+                bwd_hooks=bwd_hooks,
+                reset_hooks_end=reset_hooks_end,
+                clear_contexts=clear_contexts,
+                **model_kwargs,
+            )
+
+    @contextmanager
+    def saes(
+        self,
+        saes: SAE[Any] | list[SAE[Any]] = [],
+        reset_saes_end: bool = True,
+        use_error_term: bool | None = None,
+    ):  # type: ignore[no-untyped-def]
+        """A context manager for adding temporary SAEs to the model.
+
+        By default will keep track of previously attached SAEs, and restore them
+        when the context manager exits.
+
+        Args:
+            saes: SAEs to be attached.
+            reset_saes_end: If True, removes all SAEs added by this context manager
+                when the context manager exits, returning previously attached SAEs
+                to their original state.
+            use_error_term: If provided, will set the use_error_term attribute of
+                all SAEs attached during this run to this value. Defaults to None.
+        """
+        act_names_to_reset: list[str] = []
+        prev_saes: list[SAE[Any] | None] = []
+        if isinstance(saes, SAE):
+            saes = [saes]
+        try:
+            for sae in saes:
+                act_names_to_reset.append(sae.cfg.metadata.hook_name)
+                prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
+                prev_saes.append(prev_sae)
+                self.add_sae(sae, use_error_term=use_error_term)
+            yield self
+        finally:
+            if reset_saes_end:
+                self.reset_saes(act_names_to_reset, prev_saes)
+
+    @property
+    def hook_dict(self) -> dict[str, HookPoint]:
+        """Return combined hook registry including SAE internal hooks.
+
+        When SAEs are attached, they replace HookPoint entries in the registry.
+        This property returns both the base hooks and any internal hooks from
+        attached SAEs (like hook_sae_acts_post, hook_sae_input, etc.) with
+        their full path names.
+        """
+        hooks: dict[str, HookPoint] = {}
+
+        for name, hook_or_sae in self._hook_registry.items():
+            if isinstance(hook_or_sae, SAE):
+                # Include SAE's internal hooks with full path names
+                for sae_hook_name, sae_hook in hook_or_sae.hook_dict.items():
+                    full_name = f"{name}.{sae_hook_name}"
+                    hooks[full_name] = sae_hook
+            else:
+                hooks[name] = hook_or_sae
+
+        return hooks
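A minimal usage sketch for the new class. Assumptions not taken from this diff: transformer-lens v3+ is installed, SAE.from_pretrained(release, sae_id) returns the SAE object as in recent SAELens versions, and the "gpt2-small-res-jb" / "blocks.8.hook_resid_pre" strings name an existing pretrained SAE used purely for illustration.

import torch
from sae_lens import SAE
from sae_lens.analysis.sae_transformer_bridge import SAETransformerBridge

model = SAETransformerBridge.boot_transformers("gpt2")
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

tokens = torch.tensor([[464, 3290, 318, 257, 922, 3290]])  # arbitrary GPT-2 token ids
# Splice the SAE in at its configured hook point for one forward pass; it is detached
# again when the call returns (reset_saes_end=True by default). The returned cache also
# contains the SAE's internal hooks (hook_sae_input, hook_sae_acts_post, ...).
logits, cache = model.run_with_cache_with_saes(tokens, saes=[sae])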
sae_lens/config.py
CHANGED
@@ -82,6 +82,7 @@ class LoggingConfig:
     log_to_wandb: bool = True
     log_activations_store_to_wandb: bool = False
     log_optimizer_state_to_wandb: bool = False
+    log_weights_to_wandb: bool = True
     wandb_project: str = "sae_lens_training"
     wandb_id: str | None = None
     run_name: str | None = None
@@ -107,7 +108,8 @@ class LoggingConfig:
            type="model",
            metadata=dict(trainer.cfg.__dict__),
        )
-        model_artifact.add_file(str(weights_path))
+        if self.log_weights_to_wandb:
+            model_artifact.add_file(str(weights_path))
        model_artifact.add_file(str(cfg_path))
        wandb.log_artifact(model_artifact, aliases=wandb_aliases)
 
@@ -557,6 +559,12 @@ class CacheActivationsRunnerConfig:
            context_size=self.context_size,
        )
 
+        if self.context_size > self.training_tokens:
+            raise ValueError(
+                f"context_size ({self.context_size}) is greater than training_tokens "
+                f"({self.training_tokens}). Please reduce context_size or increase training_tokens."
+            )
+
        if self.new_cached_activations_path is None:
            self.new_cached_activations_path = _default_cached_activations_path(  # type: ignore
                self.dataset_path, self.model_name, self.hook_name, None
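The new log_weights_to_wandb flag lets training runs log the config artifact without uploading the (potentially large) weights file. A sketch of how it might be set, assuming LoggingConfig is instantiated with keyword arguments like the other dataclass configs in config.py; the field names come from the diff above.

from sae_lens.config import LoggingConfig

logging_cfg = LoggingConfig(
    log_to_wandb=True,
    log_weights_to_wandb=False,  # skip model_artifact.add_file(weights_path)
    wandb_project="sae_lens_training",
)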
sae_lens/evals.py
CHANGED
@@ -335,7 +335,7 @@ def get_downstream_reconstruction_metrics(
 
     batch_iter = range(n_batches)
     if verbose:
-        batch_iter = tqdm(batch_iter, desc="Reconstruction Batches")
+        batch_iter = tqdm(batch_iter, desc="Reconstruction Batches", leave=False)
 
     for _ in batch_iter:
         batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
@@ -430,7 +430,7 @@ def get_sparsity_and_variance_metrics(
 
     batch_iter = range(n_batches)
     if verbose:
-        batch_iter = tqdm(batch_iter, desc="Sparsity and Variance Batches")
+        batch_iter = tqdm(batch_iter, desc="Sparsity and Variance Batches", leave=False)
 
     for _ in batch_iter:
         batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
sae_lens/loading/pretrained_sae_loaders.py
CHANGED
@@ -575,6 +575,8 @@ def _infer_gemma_3_raw_cfg_dict(repo_id: str, folder_name: str) -> dict[str, Any
        "model_name": model_name,
        "hf_hook_point_in": hf_hook_point_in,
    }
+    if "transcoder" in folder_name or "clt" in folder_name:
+        cfg["affine_connection"] = "affine" in folder_name
    if hf_hook_point_out is not None:
        cfg["hf_hook_point_out"] = hf_hook_point_out
 
@@ -614,11 +616,11 @@ def get_gemma_3_config_from_hf(
    if "resid_post" in folder_name:
        hook_name = f"blocks.{layer}.hook_resid_post"
    elif "attn_out" in folder_name:
-        hook_name = f"blocks.{layer}.
+        hook_name = f"blocks.{layer}.attn.hook_z"
    elif "mlp_out" in folder_name:
        hook_name = f"blocks.{layer}.hook_mlp_out"
    elif "transcoder" in folder_name or "clt" in folder_name:
-        hook_name = f"blocks.{layer}.
+        hook_name = f"blocks.{layer}.hook_mlp_in"
        hook_name_out = f"blocks.{layer}.hook_mlp_out"
    else:
        raise ValueError("Hook name not found in folder_name.")
@@ -643,7 +645,11 @@
 
    architecture = "jumprelu"
    if "transcoder" in folder_name or "clt" in folder_name:
-        architecture =
+        architecture = (
+            "jumprelu_skip_transcoder"
+            if raw_cfg_dict.get("affine_connection", False)
+            else "jumprelu_transcoder"
+        )
    d_out = shapes_dict["w_dec"][-1]
 
    cfg = {
@@ -660,7 +666,8 @@
        "dataset_path": "monology/pile-uncopyrighted",
        "context_size": 1024,
        "apply_b_dec_to_input": False,
-        "normalize_activations":
+        "normalize_activations": "none",
+        "reshape_activations": "none",
        "hf_hook_name": raw_cfg_dict.get("hf_hook_point_in"),
    }
    if hook_name_out is not None:
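The Gemma 3 loader keys entirely on folder-name conventions: "transcoder"/"clt" selects a transcoder architecture, and "affine" in the folder name flips affine_connection, which in turn selects the skip-transcoder variant. The standalone helper below is illustrative only (it is not a function in the library) and just mirrors the branching shown in the hunks above.

def _pick_architecture(folder_name: str) -> str:
    # Mirrors _infer_gemma_3_raw_cfg_dict + get_gemma_3_config_from_hf architecture selection.
    if "transcoder" in folder_name or "clt" in folder_name:
        return "jumprelu_skip_transcoder" if "affine" in folder_name else "jumprelu_transcoder"
    return "jumprelu"

assert _pick_architecture("transcoder_affine/layer_9") == "jumprelu_skip_transcoder"
assert _pick_architecture("resid_post/layer_9_width_16k") == "jumprelu"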
sae_lens/loading/pretrained_saes_directory.py
CHANGED
@@ -57,28 +57,6 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
     return directory
 
 
-def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
-    """
-    Retrieve the norm_scaling_factor for a specific SAE if it exists.
-
-    Args:
-        release (str): The release name of the SAE.
-        sae_id (str): The ID of the specific SAE.
-
-    Returns:
-        float | None: The norm_scaling_factor if it exists, None otherwise.
-    """
-    package = "sae_lens"
-    yaml_file = files(package).joinpath("pretrained_saes.yaml")
-    with yaml_file.open("r") as file:
-        data = yaml.safe_load(file)
-    if release in data:
-        for sae_info in data[release]["saes"]:
-            if sae_info["id"] == sae_id:
-                return sae_info.get("norm_scaling_factor")
-    return None
-
-
 def get_repo_id_and_folder_name(release: str, sae_id: str) -> tuple[str, str]:
     saes_directory = get_pretrained_saes_directory()
     sae_info = saes_directory.get(release, None)
sae_lens/pretrained_saes.yaml
CHANGED
@@ -4148,6 +4148,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_17_width_16k_l0_medium
     path: resid_post/layer_17_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-16k
   - id: layer_17_width_16k_l0_small
     path: resid_post/layer_17_width_16k_l0_small
     l0: 20
@@ -4166,6 +4167,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_17_width_262k_l0_medium
     path: resid_post/layer_17_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-262k
   - id: layer_17_width_262k_l0_medium_seed_1
     path: resid_post/layer_17_width_262k_l0_medium_seed_1
     l0: 60
@@ -4178,6 +4180,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_17_width_65k_l0_medium
     path: resid_post/layer_17_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/17-gemmascope-2-res-65k
   - id: layer_17_width_65k_l0_small
     path: resid_post/layer_17_width_65k_l0_small
     l0: 20
@@ -4187,6 +4190,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_22_width_16k_l0_medium
     path: resid_post/layer_22_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-16k
   - id: layer_22_width_16k_l0_small
     path: resid_post/layer_22_width_16k_l0_small
     l0: 20
@@ -4205,6 +4209,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_22_width_262k_l0_medium
     path: resid_post/layer_22_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-262k
   - id: layer_22_width_262k_l0_medium_seed_1
     path: resid_post/layer_22_width_262k_l0_medium_seed_1
     l0: 60
@@ -4217,6 +4222,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_22_width_65k_l0_medium
     path: resid_post/layer_22_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/22-gemmascope-2-res-65k
   - id: layer_22_width_65k_l0_small
     path: resid_post/layer_22_width_65k_l0_small
     l0: 20
@@ -4226,6 +4232,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_29_width_16k_l0_medium
     path: resid_post/layer_29_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-16k
   - id: layer_29_width_16k_l0_small
     path: resid_post/layer_29_width_16k_l0_small
     l0: 20
@@ -4244,6 +4251,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_29_width_262k_l0_medium
     path: resid_post/layer_29_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-262k
   - id: layer_29_width_262k_l0_medium_seed_1
     path: resid_post/layer_29_width_262k_l0_medium_seed_1
     l0: 60
@@ -4256,6 +4264,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_29_width_65k_l0_medium
     path: resid_post/layer_29_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-4b-it/29-gemmascope-2-res-65k
   - id: layer_29_width_65k_l0_small
     path: resid_post/layer_29_width_65k_l0_small
     l0: 20
@@ -4265,6 +4274,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_9_width_16k_l0_medium
     path: resid_post/layer_9_width_16k_l0_medium
     l0: 53
+    neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-16k
   - id: layer_9_width_16k_l0_small
     path: resid_post/layer_9_width_16k_l0_small
     l0: 17
@@ -4283,6 +4293,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_9_width_262k_l0_medium
     path: resid_post/layer_9_width_262k_l0_medium
     l0: 53
+    neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-262k
   - id: layer_9_width_262k_l0_medium_seed_1
     path: resid_post/layer_9_width_262k_l0_medium_seed_1
     l0: 53
@@ -4295,6 +4306,7 @@ gemma-scope-2-4b-it-res:
   - id: layer_9_width_65k_l0_medium
     path: resid_post/layer_9_width_65k_l0_medium
     l0: 53
+    neuronpedia: gemma-3-4b-it/9-gemmascope-2-res-65k
   - id: layer_9_width_65k_l0_small
     path: resid_post/layer_9_width_65k_l0_small
     l0: 17
@@ -14491,6 +14503,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_12_width_16k_l0_medium
     path: resid_post/layer_12_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-16k
   - id: layer_12_width_16k_l0_small
     path: resid_post/layer_12_width_16k_l0_small
     l0: 20
@@ -14509,6 +14522,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_12_width_262k_l0_medium
     path: resid_post/layer_12_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-262k
   - id: layer_12_width_262k_l0_medium_seed_1
     path: resid_post/layer_12_width_262k_l0_medium_seed_1
     l0: 60
@@ -14521,6 +14535,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_12_width_65k_l0_medium
     path: resid_post/layer_12_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/12-gemmascope-2-res-65k
   - id: layer_12_width_65k_l0_small
     path: resid_post/layer_12_width_65k_l0_small
     l0: 20
@@ -14530,6 +14545,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_15_width_16k_l0_medium
     path: resid_post/layer_15_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-16k
   - id: layer_15_width_16k_l0_small
     path: resid_post/layer_15_width_16k_l0_small
     l0: 20
@@ -14548,6 +14564,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_15_width_262k_l0_medium
     path: resid_post/layer_15_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-262k
   - id: layer_15_width_262k_l0_medium_seed_1
     path: resid_post/layer_15_width_262k_l0_medium_seed_1
     l0: 60
@@ -14560,6 +14577,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_15_width_65k_l0_medium
     path: resid_post/layer_15_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/15-gemmascope-2-res-65k
   - id: layer_15_width_65k_l0_small
     path: resid_post/layer_15_width_65k_l0_small
     l0: 20
@@ -14569,6 +14587,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_5_width_16k_l0_medium
     path: resid_post/layer_5_width_16k_l0_medium
     l0: 55
+    neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-16k
   - id: layer_5_width_16k_l0_small
     path: resid_post/layer_5_width_16k_l0_small
     l0: 18
@@ -14587,6 +14606,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_5_width_262k_l0_medium
     path: resid_post/layer_5_width_262k_l0_medium
     l0: 55
+    neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-262k
   - id: layer_5_width_262k_l0_medium_seed_1
     path: resid_post/layer_5_width_262k_l0_medium_seed_1
     l0: 55
@@ -14599,6 +14619,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_5_width_65k_l0_medium
     path: resid_post/layer_5_width_65k_l0_medium
     l0: 55
+    neuronpedia: gemma-3-270m-it/5-gemmascope-2-res-65k
   - id: layer_5_width_65k_l0_small
     path: resid_post/layer_5_width_65k_l0_small
     l0: 18
@@ -14608,6 +14629,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_9_width_16k_l0_medium
     path: resid_post/layer_9_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-16k
   - id: layer_9_width_16k_l0_small
     path: resid_post/layer_9_width_16k_l0_small
     l0: 20
@@ -14626,6 +14648,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_9_width_262k_l0_medium
     path: resid_post/layer_9_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-262k
   - id: layer_9_width_262k_l0_medium_seed_1
     path: resid_post/layer_9_width_262k_l0_medium_seed_1
     l0: 60
@@ -14638,6 +14661,7 @@ gemma-scope-2-270m-it-res:
   - id: layer_9_width_65k_l0_medium
     path: resid_post/layer_9_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-270m-it/9-gemmascope-2-res-65k
   - id: layer_9_width_65k_l0_small
     path: resid_post/layer_9_width_65k_l0_small
     l0: 20
@@ -18727,6 +18751,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_13_width_16k_l0_medium
     path: resid_post/layer_13_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-16k
   - id: layer_13_width_16k_l0_small
     path: resid_post/layer_13_width_16k_l0_small
     l0: 20
@@ -18745,6 +18770,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_13_width_262k_l0_medium
     path: resid_post/layer_13_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-262k
   - id: layer_13_width_262k_l0_medium_seed_1
     path: resid_post/layer_13_width_262k_l0_medium_seed_1
     l0: 60
@@ -18757,6 +18783,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_13_width_65k_l0_medium
     path: resid_post/layer_13_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/13-gemmascope-2-res-65k
   - id: layer_13_width_65k_l0_small
     path: resid_post/layer_13_width_65k_l0_small
     l0: 20
@@ -18766,6 +18793,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_17_width_16k_l0_medium
     path: resid_post/layer_17_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-16k
   - id: layer_17_width_16k_l0_small
     path: resid_post/layer_17_width_16k_l0_small
     l0: 20
@@ -18784,6 +18812,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_17_width_262k_l0_medium
     path: resid_post/layer_17_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-262k
   - id: layer_17_width_262k_l0_medium_seed_1
     path: resid_post/layer_17_width_262k_l0_medium_seed_1
     l0: 60
@@ -18796,6 +18825,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_17_width_65k_l0_medium
     path: resid_post/layer_17_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/17-gemmascope-2-res-65k
   - id: layer_17_width_65k_l0_small
     path: resid_post/layer_17_width_65k_l0_small
     l0: 20
@@ -18805,6 +18835,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_22_width_16k_l0_medium
     path: resid_post/layer_22_width_16k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-16k
   - id: layer_22_width_16k_l0_small
     path: resid_post/layer_22_width_16k_l0_small
     l0: 20
@@ -18823,6 +18854,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_22_width_262k_l0_medium
     path: resid_post/layer_22_width_262k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-262k
   - id: layer_22_width_262k_l0_medium_seed_1
     path: resid_post/layer_22_width_262k_l0_medium_seed_1
     l0: 60
@@ -18835,6 +18867,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_22_width_65k_l0_medium
     path: resid_post/layer_22_width_65k_l0_medium
     l0: 60
+    neuronpedia: gemma-3-1b-it/22-gemmascope-2-res-65k
   - id: layer_22_width_65k_l0_small
     path: resid_post/layer_22_width_65k_l0_small
     l0: 20
@@ -18844,6 +18877,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_7_width_16k_l0_medium
     path: resid_post/layer_7_width_16k_l0_medium
     l0: 54
+    neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-16k
   - id: layer_7_width_16k_l0_small
     path: resid_post/layer_7_width_16k_l0_small
     l0: 18
@@ -18862,6 +18896,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_7_width_262k_l0_medium
     path: resid_post/layer_7_width_262k_l0_medium
     l0: 54
+    neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-262k
   - id: layer_7_width_262k_l0_medium_seed_1
     path: resid_post/layer_7_width_262k_l0_medium_seed_1
     l0: 54
@@ -18874,6 +18909,7 @@ gemma-scope-2-1b-it-res:
   - id: layer_7_width_65k_l0_medium
     path: resid_post/layer_7_width_65k_l0_medium
     l0: 54
+    neuronpedia: gemma-3-1b-it/7-gemmascope-2-res-65k
   - id: layer_7_width_65k_l0_small
     path: resid_post/layer_7_width_65k_l0_small
     l0: 18
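These YAML entries are what the loading helpers expose at runtime. A small sketch, assuming the directory layout follows existing releases (saes_map maps SAE id to its path within the release, as referenced in the sae.py diff below); the release name is one of those listed above.

from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory

directory = get_pretrained_saes_directory()
release = directory["gemma-scope-2-4b-it-res"]
print(list(release.saes_map)[:3])  # e.g. ["layer_17_width_16k_l0_medium", ...]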
sae_lens/saes/sae.py
CHANGED
@@ -45,7 +45,6 @@ from sae_lens.loading.pretrained_sae_loaders import (
 )
 from sae_lens.loading.pretrained_saes_directory import (
     get_config_overrides,
-    get_norm_scaling_factor,
     get_pretrained_saes_directory,
     get_releases_for_repo_id,
     get_repo_id_and_folder_name,
@@ -638,24 +637,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
                stacklevel=2,
            )
        elif sae_id not in sae_directory[release].saes_map:
-            # Handle special cases like Gemma Scope
-            if (
-                "gemma-scope" in release
-                and "canonical" not in release
-                and f"{release}-canonical" in sae_directory
-            ):
-                canonical_ids = list(
-                    sae_directory[release + "-canonical"].saes_map.keys()
-                )
-                # Shorten the lengthy string of valid IDs
-                if len(canonical_ids) > 5:
-                    str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
-                else:
-                    str_canonical_ids = str(canonical_ids)
-                value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
-            else:
-                value_suffix = ""
-
            valid_ids = list(sae_directory[release].saes_map.keys())
            # Shorten the lengthy string of valid IDs
            if len(valid_ids) > 5:
@@ -665,7 +646,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
            raise ValueError(
                f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
-                + value_suffix
            )
 
        conversion_loader = (
@@ -702,17 +682,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
        sae.process_state_dict_for_loading(state_dict)
        sae.load_state_dict(state_dict, assign=True)
 
-        # Apply normalization if needed
-        if cfg_dict.get("normalize_activations") == "expected_average_only_in":
-            norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
-            if norm_scaling_factor is not None:
-                sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
-                cfg_dict["normalize_activations"] = "none"
-            else:
-                warnings.warn(
-                    f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
-                )
-
        # the loaders should already handle the dtype / device conversion
        # but this is a fallback to guarantee the SAE is on the correct device and dtype
        return (
sae_lens/saes/temporal_sae.py
CHANGED
sae_lens/synthetic/__init__.py
CHANGED
@@ -50,6 +50,13 @@ from sae_lens.synthetic.plotting import (
     find_best_feature_ordering_from_sae,
     plot_sae_feature_similarity,
 )
+from sae_lens.synthetic.stats import (
+    CorrelationMatrixStats,
+    SuperpositionStats,
+    compute_correlation_matrix_stats,
+    compute_low_rank_correlation_matrix_stats,
+    compute_superposition_stats,
+)
 from sae_lens.synthetic.training import (
     SyntheticActivationIterator,
     train_toy_sae,
@@ -80,6 +87,12 @@ __all__ = [
     "orthogonal_initializer",
     "FeatureDictionaryInitializer",
     "cosine_similarities",
+    # Statistics
+    "compute_correlation_matrix_stats",
+    "compute_low_rank_correlation_matrix_stats",
+    "compute_superposition_stats",
+    "CorrelationMatrixStats",
+    "SuperpositionStats",
     # Training utilities
     "SyntheticActivationIterator",
     "SyntheticDataEvalResult",
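With these exports, the new statistics are reachable from the sae_lens.synthetic namespace. A sketch of calling compute_superposition_stats: per the stats.py diff further down it only reads .feature_vectors from its argument, so a duck-typed stand-in is used here instead of constructing a real FeatureDictionary, whose constructor is not shown in this diff.

from types import SimpleNamespace

import torch

from sae_lens.synthetic import compute_superposition_stats

features = SimpleNamespace(feature_vectors=torch.randn(512, 64))  # stand-in, not a FeatureDictionary
stats = compute_superposition_stats(features, percentiles=[95, 99])
print(stats.mean_max_abs_cos_sim, stats.mean_percentile_abs_cos_sim[99])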
sae_lens/synthetic/correlation.py
CHANGED
@@ -3,6 +3,7 @@ from typing import NamedTuple
 
 import torch
 
+from sae_lens import logger
 from sae_lens.util import str_to_dtype
 
 
@@ -268,7 +269,7 @@
 def generate_random_low_rank_correlation_matrix(
     num_features: int,
     rank: int,
-    correlation_scale: float = 0.
+    correlation_scale: float = 0.075,
     seed: int | None = None,
     device: torch.device | str = "cpu",
     dtype: torch.dtype | str = torch.float32,
@@ -331,20 +332,17 @@
     factor_sq_sum = (factor**2).sum(dim=1)
     diag_term = 1 - factor_sq_sum
 
-    #
-
-
-
-
-
-
-
-
-        /
+    # alternatively, we can rescale each row independently to ensure the diagonal is 1
+    mask = diag_term < _MIN_DIAG
+    factor[mask, :] *= torch.sqrt((1 - _MIN_DIAG) / factor_sq_sum[mask].unsqueeze(1))
+    factor_sq_sum = (factor**2).sum(dim=1)
+    diag_term = 1 - factor_sq_sum
+
+    total_rescaled = mask.sum().item()
+    if total_rescaled > 0:
+        logger.warning(
+            f"{total_rescaled} / {num_features} rows were capped. Either reduce the rank or reduce the correlation_scale to avoid rescaling."
         )
-    factor = factor * scale
-    factor_sq_sum = (factor**2).sum(dim=1)
-    diag_term = 1 - factor_sq_sum
 
     return LowRankCorrelationMatrix(
         correlation_factor=factor, correlation_diag=diag_term
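The new capping step rescales any row of the low-rank factor whose squared norm would push the implied diagonal term below the module's minimum. A tiny standalone illustration of that arithmetic (0.05 is an assumed stand-in for the _MIN_DIAG constant, which is not shown in this diff):

import torch

_MIN_DIAG = 0.05  # assumed value for illustration only
factor = torch.tensor([[0.8, 0.7], [0.1, 0.2]])  # row 0 has squared norm 1.13 > 1 - _MIN_DIAG
factor_sq_sum = (factor**2).sum(dim=1)
mask = (1 - factor_sq_sum) < _MIN_DIAG
factor[mask, :] *= torch.sqrt((1 - _MIN_DIAG) / factor_sq_sum[mask].unsqueeze(1))
# After rescaling, every row satisfies 1 - ||factor[i]||^2 >= _MIN_DIAG, so the implied
# correlation matrix keeps a non-negative diagonal term.
print(1 - (factor**2).sum(dim=1))  # tensor([0.0500, 0.9500])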
sae_lens/synthetic/stats.py
ADDED
@@ -0,0 +1,205 @@
+from dataclasses import dataclass
+
+import torch
+
+from sae_lens.synthetic.correlation import LowRankCorrelationMatrix
+from sae_lens.synthetic.feature_dictionary import FeatureDictionary
+
+
+@dataclass
+class CorrelationMatrixStats:
+    """Statistics computed from a correlation matrix."""
+
+    rms_correlation: float  # Root mean square of off-diagonal correlations
+    mean_correlation: float  # Mean of off-diagonal correlations (not absolute)
+    num_features: int
+
+
+@torch.no_grad()
+def compute_correlation_matrix_stats(
+    correlation_matrix: torch.Tensor,
+) -> CorrelationMatrixStats:
+    """Compute correlation statistics from a dense correlation matrix.
+
+    Args:
+        correlation_matrix: Dense correlation matrix of shape (n, n)
+
+    Returns:
+        CorrelationMatrixStats with correlation statistics
+    """
+    num_features = correlation_matrix.shape[0]
+
+    # Extract off-diagonal elements
+    mask = ~torch.eye(num_features, dtype=torch.bool, device=correlation_matrix.device)
+    off_diag = correlation_matrix[mask]
+
+    rms_correlation = (off_diag**2).mean().sqrt().item()
+    mean_correlation = off_diag.mean().item()
+
+    return CorrelationMatrixStats(
+        rms_correlation=rms_correlation,
+        mean_correlation=mean_correlation,
+        num_features=num_features,
+    )
+
+
+@torch.no_grad()
+def compute_low_rank_correlation_matrix_stats(
+    correlation_matrix: LowRankCorrelationMatrix,
+) -> CorrelationMatrixStats:
+    """Compute correlation statistics from a LowRankCorrelationMatrix.
+
+    The correlation matrix is represented as:
+        correlation = factor @ factor.T + diag(diag_term)
+
+    The off-diagonal elements are simply factor @ factor.T (the diagonal term
+    only affects the diagonal).
+
+    All statistics are computed efficiently in O(n*r²) time and O(r²) memory
+    without materializing the full n×n correlation matrix.
+
+    Args:
+        correlation_matrix: Low-rank correlation matrix
+
+    Returns:
+        CorrelationMatrixStats with correlation statistics
+    """
+
+    factor = correlation_matrix.correlation_factor
+    num_features = factor.shape[0]
+    num_off_diag = num_features * (num_features - 1)
+
+    # RMS correlation: uses ||F @ F.T||_F² = ||F.T @ F||_F²
+    # This avoids computing the (num_features, num_features) matrix
+    G = factor.T @ factor  # (rank, rank) - small!
+    frobenius_sq = (G**2).sum()
+    row_norms_sq = (factor**2).sum(dim=1)  # ||F[i]||² for each row
+    diag_sq_sum = (row_norms_sq**2).sum()  # Σᵢ ||F[i]||⁴
+    off_diag_sq_sum = frobenius_sq - diag_sq_sum
+    rms_correlation = (off_diag_sq_sum / num_off_diag).sqrt().item()
+
+    # Mean correlation (not absolute): sum(C) = ||col_sums(F)||², trace(C) = Σ||F[i]||²
+    col_sums = factor.sum(dim=0)  # (rank,)
+    sum_all = (col_sums**2).sum()  # 1ᵀ C 1
+    trace_C = row_norms_sq.sum()
+    mean_correlation = ((sum_all - trace_C) / num_off_diag).item()
+
+    return CorrelationMatrixStats(
+        rms_correlation=rms_correlation,
+        mean_correlation=mean_correlation,
+        num_features=num_features,
+    )
+
+
+@dataclass
+class SuperpositionStats:
+    """Statistics measuring superposition in a feature dictionary."""
+
+    # Per-latent statistics: for each latent, max and percentile of |cos_sim| with others
+    max_abs_cos_sims: torch.Tensor  # Shape: (num_features,)
+    percentile_abs_cos_sims: dict[int, torch.Tensor]  # {percentile: (num_features,)}
+
+    # Summary statistics (means of the per-latent values)
+    mean_max_abs_cos_sim: float
+    mean_percentile_abs_cos_sim: dict[int, float]
+    mean_abs_cos_sim: float  # Mean |cos_sim| across all pairs
+
+    # Metadata
+    num_features: int
+    hidden_dim: int
+
+
+@torch.no_grad()
+def compute_superposition_stats(
+    feature_dictionary: FeatureDictionary,
+    batch_size: int = 1024,
+    device: str | torch.device | None = None,
+    percentiles: list[int] | None = None,
+) -> SuperpositionStats:
+    """Compute superposition statistics for a feature dictionary.
+
+    Computes pairwise cosine similarities in batches to handle large dictionaries.
+
+    For each latent i, computes:
+
+    - max |cos_sim(i, j)| over all j != i
+    - kth percentile of |cos_sim(i, j)| over all j != i (for each k in percentiles)
+
+    Args:
+        feature_dictionary: FeatureDictionary containing the feature vectors
+        batch_size: Number of features to process per batch
+        device: Device for computation (defaults to feature dictionary's device)
+        percentiles: List of percentiles to compute per latent (default: [95, 99])
+
+    Returns:
+        SuperpositionStats with superposition metrics
+    """
+    if percentiles is None:
+        percentiles = [95, 99]
+
+    feature_vectors = feature_dictionary.feature_vectors
+    num_features, hidden_dim = feature_vectors.shape
+
+    if num_features < 2:
+        raise ValueError("Need at least 2 features to compute superposition stats")
+    if device is None:
+        device = feature_vectors.device
+
+    # Normalize features to unit norm for cosine similarity
+    features_normalized = feature_vectors.to(device).float()
+    norms = torch.linalg.norm(features_normalized, dim=1, keepdim=True)
+    features_normalized = features_normalized / norms.clamp(min=1e-8)
+
+    # Track per-latent statistics
+    max_abs_cos_sims = torch.zeros(num_features, device=device)
+    percentile_abs_cos_sims = {
+        p: torch.zeros(num_features, device=device) for p in percentiles
+    }
+    sum_abs_cos_sim = 0.0
+    n_pairs = 0
+
+    # Process in batches: for each batch of features, compute similarities with all others
+    for i in range(0, num_features, batch_size):
+        batch_end = min(i + batch_size, num_features)
+        batch = features_normalized[i:batch_end]  # (batch_size, hidden_dim)
+
+        # Compute cosine similarities with all features: (batch_size, num_features)
+        cos_sims = batch @ features_normalized.T
+
+        # Absolute cosine similarities
+        abs_cos_sims = cos_sims.abs()
+
+        # Process each latent in the batch
+        for j, idx in enumerate(range(i, batch_end)):
+            # Get similarities with all other features (exclude self)
+            row = abs_cos_sims[j].clone()
+            row[idx] = 0.0  # Exclude self for max
+            max_abs_cos_sims[idx] = row.max()
+
+            # For percentiles, exclude self and compute
+            other_sims = torch.cat([abs_cos_sims[j, :idx], abs_cos_sims[j, idx + 1 :]])
+            for p in percentiles:
+                percentile_abs_cos_sims[p][idx] = torch.quantile(other_sims, p / 100.0)
+
+            # Sum for mean computation (only count pairs once - with features after this one)
+            sum_abs_cos_sim += abs_cos_sims[j, idx + 1 :].sum().item()
+            n_pairs += num_features - idx - 1
+
+    # Compute summary statistics
+    mean_max_abs_cos_sim = max_abs_cos_sims.mean().item()
+    mean_percentile_abs_cos_sim = {
+        p: percentile_abs_cos_sims[p].mean().item() for p in percentiles
+    }
+    mean_abs_cos_sim = sum_abs_cos_sim / n_pairs if n_pairs > 0 else 0.0
+
+    return SuperpositionStats(
+        max_abs_cos_sims=max_abs_cos_sims.cpu(),
+        percentile_abs_cos_sims={
+            p: v.cpu() for p, v in percentile_abs_cos_sims.items()
+        },
+        mean_max_abs_cos_sim=mean_max_abs_cos_sim,
+        mean_percentile_abs_cos_sim=mean_percentile_abs_cos_sim,
+        mean_abs_cos_sim=mean_abs_cos_sim,
+        num_features=num_features,
+        hidden_dim=hidden_dim,
+    )
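The low-rank path relies on the identities ||F Fᵀ||_F² = ||Fᵀ F||_F² and 1ᵀ(F Fᵀ)1 = ||Fᵀ1||², so it never materializes the n×n matrix. A sketch that checks it against the dense computation, using only functions and attributes shown in this diff (sizes are arbitrary):

import torch

from sae_lens.synthetic.correlation import generate_random_low_rank_correlation_matrix
from sae_lens.synthetic.stats import (
    compute_correlation_matrix_stats,
    compute_low_rank_correlation_matrix_stats,
)

low_rank = generate_random_low_rank_correlation_matrix(num_features=64, rank=4, seed=0)
dense = (
    low_rank.correlation_factor @ low_rank.correlation_factor.T
    + torch.diag(low_rank.correlation_diag)
)

fast = compute_low_rank_correlation_matrix_stats(low_rank)
slow = compute_correlation_matrix_stats(dense)
assert abs(fast.rms_correlation - slow.rms_correlation) < 1e-4
assert abs(fast.mean_correlation - slow.mean_correlation) < 1e-4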
sae_lens/training/activation_scaler.py
CHANGED
@@ -28,7 +28,9 @@ class ActivationScaler:
     ) -> float:
         norms_per_batch: list[float] = []
         for _ in tqdm(
-            range(n_batches_for_norm_estimate),
+            range(n_batches_for_norm_estimate),
+            desc="Estimating norm scaling factor",
+            leave=False,
         ):
             acts = next(data_provider)
             norms_per_batch.append(acts.norm(dim=-1).mean().item())
{sae_lens-6.29.1.dist-info → sae_lens-6.33.0.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.29.1
+Version: 6.33.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -27,7 +27,7 @@ Requires-Dist: pyyaml (>=6.0.1,<7.0.0)
 Requires-Dist: safetensors (>=0.4.2,<1.0.0)
 Requires-Dist: simple-parsing (>=0.1.6,<0.2.0)
 Requires-Dist: tenacity (>=9.0.0)
-Requires-Dist: transformer-lens (>=2.16.1
+Requires-Dist: transformer-lens (>=2.16.1)
 Requires-Dist: transformers (>=4.38.1,<5.0.0)
 Requires-Dist: typing-extensions (>=4.10.0,<5.0.0)
 Project-URL: Homepage, https://decoderesearch.github.io/SAELens
{sae_lens-6.29.1.dist-info → sae_lens-6.33.0.dist-info}/RECORD
CHANGED
@@ -1,18 +1,20 @@
-sae_lens/__init__.py,sha256=
-sae_lens/analysis/__init__.py,sha256=
-sae_lens/analysis/
+sae_lens/__init__.py,sha256=gHaxlySzLskrAUg2oUZ3aOpnI3U_AVIHce-agGJL9rI,5168
+sae_lens/analysis/__init__.py,sha256=FZExlMviNwWR7OGUSGRbd0l-yUDGSp80gglI_ivILrY,412
+sae_lens/analysis/compat.py,sha256=cgE3nhFcJTcuhppxbL71VanJS7YqVEOefuneB5eOaPw,538
+sae_lens/analysis/hooked_sae_transformer.py,sha256=LpnjxSAcItqqXA4SJyZuxY4Ki0UOuWV683wg9laYAsY,14050
 sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
+sae_lens/analysis/sae_transformer_bridge.py,sha256=xpJRRcB0g47EOQcmNCwMyrJJsbqMsGxVViDrV6C3upU,14916
 sae_lens/cache_activations_runner.py,sha256=TjqNWIc46Nw09jHWFjzQzgzG5wdu_87Ahe-iFjI5_0Q,13117
-sae_lens/config.py,sha256=
+sae_lens/config.py,sha256=V0BXV8rvpbm5YuVukow9FURPpdyE4HSflbdymAo0Ycg,31205
 sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
-sae_lens/evals.py,sha256=
+sae_lens/evals.py,sha256=nEZpUfEUN-plw6Mj9GEqm-cU_tb1qrIF9km9ktQ0vVU,39624
 sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
 sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
 sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/loading/pretrained_sae_loaders.py,sha256=
-sae_lens/loading/pretrained_saes_directory.py,sha256=
+sae_lens/loading/pretrained_sae_loaders.py,sha256=kshvA0NivOc7B3sL19lHr_zrC_DDfW2T6YWb5j0hgAk,63930
+sae_lens/loading/pretrained_saes_directory.py,sha256=lSnHl77IO5dd7iO21ynCzZNMrzuJAT8Za4W5THNq0qw,3554
 sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
-sae_lens/pretrained_saes.yaml,sha256=
+sae_lens/pretrained_saes.yaml,sha256=IVBLLR8_XNllJ1O-kVv9ED4u0u44Yn8UOL9R-f8Idp4,1511936
 sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
 sae_lens/saes/__init__.py,sha256=SBqPaP6Gl5uPFwHlumAZATC4Wd26xKIYLAAAo4MSa5Q,2200
 sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
@@ -20,24 +22,25 @@ sae_lens/saes/gated_sae.py,sha256=V_2ZNlV4gRD-rX5JSx1xqY7idT8ChfdQ5yxWDdu_6hg,88
 sae_lens/saes/jumprelu_sae.py,sha256=miiF-xI_yXdV9EkKjwAbU9zSMsx9KtKCz5YdXEzkN8g,13313
 sae_lens/saes/matching_pursuit_sae.py,sha256=08_G9p1YMLnE5qZVCPp6gll-iG6nHRbMMASf4_bkFt8,13207
 sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
-sae_lens/saes/sae.py,sha256=
+sae_lens/saes/sae.py,sha256=wkwqzNragj-1189cV52S3_XeRtEgBd2ZNwvL2EsKkWw,39429
 sae_lens/saes/standard_sae.py,sha256=_hldNZkFPAf9VGrxouR1-tN8T2OEk8IkWBcXoatrC1o,5749
-sae_lens/saes/temporal_sae.py,sha256=
+sae_lens/saes/temporal_sae.py,sha256=S44sPddVj2xujA02CC8gT1tG0in7c_CSAhspu9FHbaA,13273
 sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
 sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
-sae_lens/synthetic/__init__.py,sha256=
+sae_lens/synthetic/__init__.py,sha256=hRRA3xhEQUacGyFbJXkLVYg_8A1bbSYYWlVovb0g4KU,3503
 sae_lens/synthetic/activation_generator.py,sha256=8L9nwC4jFRv_wg3QN-n1sFwX8w1NqwJMysWaJ41lLlY,15197
-sae_lens/synthetic/correlation.py,sha256=
+sae_lens/synthetic/correlation.py,sha256=tD8J9abWfuFtGZrEbbFn4P8FeTcNKF2V5JhBLwDUmkg,13146
 sae_lens/synthetic/evals.py,sha256=Nhi314ZnRgLfhBj-3tm_zzI-pGyFTcwllDXbIpPFXeU,4584
 sae_lens/synthetic/feature_dictionary.py,sha256=Nd4xjSTxKMnKilZ3uYi8Gv5SS5D4bv4wHiSL1uGB69E,6933
 sae_lens/synthetic/firing_probabilities.py,sha256=yclz1pWl5gE1r8LAxFvzQS88Lxwk5-3r8BCX9HLVejA,3370
 sae_lens/synthetic/hierarchy.py,sha256=nm7nwnTswktVJeKUsRZ0hLOdXcFWGbxnA1b6lefHm-4,33592
 sae_lens/synthetic/initialization.py,sha256=orMGW-786wRDHIS2W7bEH0HmlVFQ4g2z4bnnwdv5w4s,1386
 sae_lens/synthetic/plotting.py,sha256=5lFrej1QOkGAcImFNo5-o-8mI_rUVqvEI57KzUQPPtQ,8208
+sae_lens/synthetic/stats.py,sha256=BoDPKDx8pgFF5Ko_IaBRZTczm7-ANUIRjjF5W5Qh3Lk,7441
 sae_lens/synthetic/training.py,sha256=fHcX2cZ6nDupr71GX0Gk17f1NvQ0SKIVXIA6IuAb2dw,5692
 sae_lens/tokenization_and_batching.py,sha256=uoHtAs9z3XqG0Fh-iQVYVlrbyB_E3kFFhrKU30BosCo,5438
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/training/activation_scaler.py,sha256=
+sae_lens/training/activation_scaler.py,sha256=SJZzIMX1TGdeN_wT_wqgx2ij6f4p5Dm5lWH6DGNSt5g,2011
 sae_lens/training/activations_store.py,sha256=kp4-6R4rTJUSt-g-Ifg5B1h7iIe7jZj-XQSKDvDpQMI,32187
 sae_lens/training/mixing_buffer.py,sha256=1Z-S2CcQXMWGxRZJFnXeZFxbZcALkO_fP6VO37XdJQQ,2519
 sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
@@ -46,7 +49,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
 sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
-sae_lens-6.
+sae_lens-6.33.0.dist-info/METADATA,sha256=X6XqngWTNEsfdaPPWXxtF8Kvdp8fAk8i68sfRtDb2xo,6566
+sae_lens-6.33.0.dist-info/WHEEL,sha256=3ny-bZhpXrU6vSQ1UPG34FoxZBp3lVcvK0LkgUz6VLk,88
+sae_lens-6.33.0.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.33.0.dist-info/RECORD,,
{sae_lens-6.29.1.dist-info → sae_lens-6.33.0.dist-info}/licenses/LICENSE
File without changes