sae-lens 6.28.2__py3-none-any.whl → 6.32.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +14 -1
- sae_lens/analysis/__init__.py +15 -0
- sae_lens/analysis/compat.py +16 -0
- sae_lens/analysis/hooked_sae_transformer.py +1 -1
- sae_lens/analysis/sae_transformer_bridge.py +348 -0
- sae_lens/config.py +9 -1
- sae_lens/evals.py +2 -2
- sae_lens/loading/pretrained_sae_loaders.py +11 -4
- sae_lens/pretrained_saes.yaml +36 -0
- sae_lens/saes/temporal_sae.py +1 -1
- sae_lens/synthetic/__init__.py +6 -0
- sae_lens/synthetic/activation_generator.py +197 -25
- sae_lens/synthetic/correlation.py +217 -36
- sae_lens/synthetic/feature_dictionary.py +11 -2
- sae_lens/synthetic/hierarchy.py +314 -2
- sae_lens/synthetic/training.py +16 -3
- sae_lens/training/activation_scaler.py +3 -1
- {sae_lens-6.28.2.dist-info → sae_lens-6.32.1.dist-info}/METADATA +2 -2
- {sae_lens-6.28.2.dist-info → sae_lens-6.32.1.dist-info}/RECORD +21 -19
- {sae_lens-6.28.2.dist-info → sae_lens-6.32.1.dist-info}/WHEEL +1 -1
- {sae_lens-6.28.2.dist-info → sae_lens-6.32.1.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.28.2"
+__version__ = "6.32.1"
 
 import logging
 
@@ -125,6 +125,19 @@ __all__ = [
     "MatchingPursuitTrainingSAEConfig",
 ]
 
+# Conditional export for SAETransformerBridge (requires transformer-lens v3+)
+try:
+    from sae_lens.analysis.compat import has_transformer_bridge
+
+    if has_transformer_bridge():
+        from sae_lens.analysis.sae_transformer_bridge import (  # noqa: F401
+            SAETransformerBridge,
+        )
+
+        __all__.append("SAETransformerBridge")
+except ImportError:
+    pass
+
 
 register_sae_class("standard", StandardSAE, StandardSAEConfig)
 register_sae_training_class("standard", StandardTrainingSAE, StandardTrainingSAEConfig)
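The conditional export means downstream code cannot assume SAETransformerBridge is importable from the top-level package. A minimal defensive-import sketch (illustrative, not part of the package), relying only on the behaviour shown above: the name is exported only when transformer-lens v3+ is installed.

# Hypothetical downstream usage: fall back gracefully on transformer-lens < 3,
# where sae_lens does not export SAETransformerBridge.
try:
    from sae_lens import SAETransformerBridge
except ImportError:
    SAETransformerBridge = None

if SAETransformerBridge is None:
    print("TransformerBridge support unavailable; use HookedSAETransformer instead.")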
sae_lens/analysis/__init__.py
CHANGED
@@ -0,0 +1,15 @@
+from sae_lens.analysis.hooked_sae_transformer import HookedSAETransformer
+
+__all__ = ["HookedSAETransformer"]
+
+try:
+    from sae_lens.analysis.compat import has_transformer_bridge
+
+    if has_transformer_bridge():
+        from sae_lens.analysis.sae_transformer_bridge import (  # noqa: F401
+            SAETransformerBridge,
+        )
+
+        __all__.append("SAETransformerBridge")
+except ImportError:
+    pass
sae_lens/analysis/compat.py
ADDED
@@ -0,0 +1,16 @@
+import importlib.metadata
+
+from packaging.version import parse as parse_version
+
+
+def get_transformer_lens_version() -> tuple[int, int, int]:
+    """Get transformer-lens version as (major, minor, patch)."""
+    version_str = importlib.metadata.version("transformer-lens")
+    version = parse_version(version_str)
+    return (version.major, version.minor, version.micro)
+
+
+def has_transformer_bridge() -> bool:
+    """Check if TransformerBridge is available (v3+)."""
+    major, _, _ = get_transformer_lens_version()
+    return major >= 3
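A short sketch of how these helpers can gate v3-only code paths (illustrative usage, not package code); both functions come straight from the new module above.

from sae_lens.analysis.compat import get_transformer_lens_version, has_transformer_bridge

major, minor, patch = get_transformer_lens_version()
if has_transformer_bridge():
    # transformer-lens >= 3: bridge-based classes can be imported safely.
    from sae_lens.analysis.sae_transformer_bridge import SAETransformerBridge  # noqa: F401
else:
    print(f"transformer-lens {major}.{minor}.{patch} has no TransformerBridge; skipping bridge features.")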
sae_lens/analysis/hooked_sae_transformer.py
CHANGED
@@ -126,7 +126,7 @@ class HookedSAETransformer(HookedTransformer):
             current_sae.use_error_term = current_sae._original_use_error_term  # type: ignore
             delattr(current_sae, "_original_use_error_term")
 
-        if prev_sae:
+        if prev_sae is not None:
             set_deep_attr(self, act_name, prev_sae)
             self.acts_to_saes[act_name] = prev_sae
         else:
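The switch to an identity check guards against objects that are valid but evaluate as falsy; truthiness is the wrong test when None is the sentinel for "nothing to restore". A generic illustration (not sae-lens code) of the difference:

class FalsyThing:
    """Stand-in for a valid object that happens to be falsy."""
    def __len__(self) -> int:
        return 0  # len() == 0 makes the instance falsy

prev = FalsyThing()
print(bool(prev))        # False -> `if prev:` would wrongly treat it as absent
print(prev is not None)  # True  -> identity check restores it as intended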
sae_lens/analysis/sae_transformer_bridge.py
ADDED
@@ -0,0 +1,348 @@
+from collections.abc import Callable
+from contextlib import contextmanager
+from typing import Any
+
+import torch
+from transformer_lens.ActivationCache import ActivationCache
+from transformer_lens.hook_points import HookPoint
+from transformer_lens.model_bridge import TransformerBridge
+
+from sae_lens import logger
+from sae_lens.analysis.hooked_sae_transformer import set_deep_attr
+from sae_lens.saes.sae import SAE
+
+SingleLoss = torch.Tensor  # Type alias for a single element tensor
+LossPerToken = torch.Tensor
+Loss = SingleLoss | LossPerToken
+
+
+class SAETransformerBridge(TransformerBridge):  # type: ignore[misc,no-untyped-call]
+    """A TransformerBridge subclass that supports attaching SAEs.
+
+    .. warning::
+        This class is in **beta**. The API may change in future versions.
+
+    This class provides the same SAE attachment functionality as HookedSAETransformer,
+    but for transformer-lens v3's TransformerBridge instead of HookedTransformer.
+
+    TransformerBridge is a lightweight wrapper around HuggingFace models that provides
+    hook points without the overhead of HookedTransformer's weight processing. This is
+    useful for models not natively supported by HookedTransformer, such as Gemma 3.
+    """
+
+    acts_to_saes: dict[str, SAE[Any]]
+
+    def __init__(self, *args: Any, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        self.acts_to_saes = {}
+
+    @classmethod
+    def boot_transformers(  # type: ignore[override]
+        cls,
+        model_name: str,
+        **kwargs: Any,
+    ) -> "SAETransformerBridge":
+        """Factory method to boot a model and return SAETransformerBridge instance.
+
+        Args:
+            model_name: The name of the model to load (e.g., "gpt2", "gemma-2-2b")
+            **kwargs: Additional arguments passed to TransformerBridge.boot_transformers
+
+        Returns:
+            SAETransformerBridge instance with the loaded model
+        """
+        # Boot parent TransformerBridge
+        bridge = TransformerBridge.boot_transformers(model_name, **kwargs)
+        # Convert to our class
+        # NOTE: this is super hacky and scary, but I don't know how else to achieve this given TLens' internal code
+        bridge.__class__ = cls
+        bridge.acts_to_saes = {}  # type: ignore[attr-defined]
+        return bridge  # type: ignore[return-value]
+
+    def _resolve_hook_name(self, hook_name: str) -> str:
+        """Resolve alias to actual hook name.
+
+        TransformerBridge supports hook aliases like 'blocks.0.hook_mlp_out'
+        that map to actual paths like 'blocks.0.mlp.hook_out'.
+        """
+        # Combine static and dynamic aliases
+        aliases: dict[str, Any] = {
+            **self.hook_aliases,
+            **self._collect_hook_aliases_from_registry(),
+        }
+        resolved = aliases.get(hook_name, hook_name)
+        # aliases values are always strings, but type checker doesn't know this
+        return resolved if isinstance(resolved, str) else hook_name
+
+    def add_sae(self, sae: SAE[Any], use_error_term: bool | None = None) -> None:
+        """Attaches an SAE to the model.
+
+        WARNING: This SAE will be permanently attached until you remove it with
+        reset_saes. This function will also overwrite any existing SAE attached
+        to the same hook point.
+
+        Args:
+            sae: The SAE to attach to the model
+            use_error_term: If provided, will set the use_error_term attribute of
+                the SAE to this value. Determines whether the SAE returns input
+                or reconstruction. Defaults to None.
+        """
+        alias_name = sae.cfg.metadata.hook_name
+        actual_name = self._resolve_hook_name(alias_name)
+
+        # Check if hook exists (either as alias or actual name)
+        if (alias_name not in self.acts_to_saes) and (
+            actual_name not in self._hook_registry
+        ):
+            logger.warning(
+                f"No hook found for {alias_name}. Skipping. "
+                f"Check model._hook_registry for available hooks."
+            )
+            return
+
+        if use_error_term is not None:
+            if not hasattr(sae, "_original_use_error_term"):
+                sae._original_use_error_term = sae.use_error_term  # type: ignore[attr-defined]
+            sae.use_error_term = use_error_term
+
+        # Replace hook and update registry
+        set_deep_attr(self, actual_name, sae)
+        self._hook_registry[actual_name] = sae  # type: ignore[assignment]
+        self.acts_to_saes[alias_name] = sae
+
+    def _reset_sae(self, act_name: str, prev_sae: SAE[Any] | None = None) -> None:
+        """Resets an SAE that was attached to the model.
+
+        By default will remove the SAE from that hook_point.
+        If prev_sae is provided, will replace the current SAE with the provided one.
+        This is mainly used to restore previously attached SAEs after temporarily
+        running with different SAEs (e.g., with run_with_saes).
+
+        Args:
+            act_name: The hook_name of the SAE to reset
+            prev_sae: The SAE to replace the current one with. If None, will just
+                remove the SAE from this hook point. Defaults to None.
+        """
+        if act_name not in self.acts_to_saes:
+            logger.warning(
+                f"No SAE is attached to {act_name}. There's nothing to reset."
+            )
+            return
+
+        actual_name = self._resolve_hook_name(act_name)
+        current_sae = self.acts_to_saes[act_name]
+
+        if hasattr(current_sae, "_original_use_error_term"):
+            current_sae.use_error_term = current_sae._original_use_error_term  # type: ignore[attr-defined]
+            delattr(current_sae, "_original_use_error_term")
+
+        if prev_sae is not None:
+            set_deep_attr(self, actual_name, prev_sae)
+            self._hook_registry[actual_name] = prev_sae  # type: ignore[assignment]
+            self.acts_to_saes[act_name] = prev_sae
+        else:
+            new_hook = HookPoint()
+            new_hook.name = actual_name
+            set_deep_attr(self, actual_name, new_hook)
+            self._hook_registry[actual_name] = new_hook
+            del self.acts_to_saes[act_name]
+
+    def reset_saes(
+        self,
+        act_names: str | list[str] | None = None,
+        prev_saes: list[SAE[Any] | None] | None = None,
+    ) -> None:
+        """Reset the SAEs attached to the model.
+
+        If act_names are provided will just reset SAEs attached to those hooks.
+        Otherwise will reset all SAEs attached to the model.
+        Optionally can provide a list of prev_saes to reset to. This is mainly
+        used to restore previously attached SAEs after temporarily running with
+        different SAEs (e.g., with run_with_saes).
+
+        Args:
+            act_names: The act_names of the SAEs to reset. If None, will reset all
+                SAEs attached to the model. Defaults to None.
+            prev_saes: List of SAEs to replace the current ones with. If None, will
+                just remove the SAEs. Defaults to None.
+        """
+        if isinstance(act_names, str):
+            act_names = [act_names]
+        elif act_names is None:
+            act_names = list(self.acts_to_saes.keys())
+
+        if prev_saes:
+            if len(act_names) != len(prev_saes):
+                raise ValueError("act_names and prev_saes must have the same length")
+        else:
+            prev_saes = [None] * len(act_names)  # type: ignore[assignment]
+
+        for act_name, prev_sae in zip(act_names, prev_saes):  # type: ignore[arg-type]
+            self._reset_sae(act_name, prev_sae)
+
+    def run_with_saes(
+        self,
+        *model_args: Any,
+        saes: SAE[Any] | list[SAE[Any]] = [],
+        reset_saes_end: bool = True,
+        use_error_term: bool | None = None,
+        **model_kwargs: Any,
+    ) -> torch.Tensor | Loss | tuple[torch.Tensor, Loss] | None:
+        """Wrapper around forward pass.
+
+        Runs the model with the given SAEs attached for one forward pass, then
+        removes them. By default, will reset all SAEs to original state after.
+
+        Args:
+            *model_args: Positional arguments for the model forward pass
+            saes: The SAEs to be attached for this forward pass
+            reset_saes_end: If True, all SAEs added during this run are removed
+                at the end, and previously attached SAEs are restored to their
+                original state. Default is True.
+            use_error_term: If provided, will set the use_error_term attribute
+                of all SAEs attached during this run to this value. Defaults to None.
+            **model_kwargs: Keyword arguments for the model forward pass
+        """
+        with self.saes(
+            saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
+        ):
+            return self(*model_args, **model_kwargs)
+
+    def run_with_cache_with_saes(
+        self,
+        *model_args: Any,
+        saes: SAE[Any] | list[SAE[Any]] = [],
+        reset_saes_end: bool = True,
+        use_error_term: bool | None = None,
+        return_cache_object: bool = True,
+        remove_batch_dim: bool = False,
+        **kwargs: Any,
+    ) -> tuple[
+        torch.Tensor | Loss | tuple[torch.Tensor, Loss] | None,
+        ActivationCache | dict[str, torch.Tensor],
+    ]:
+        """Wrapper around 'run_with_cache'.
+
+        Attaches given SAEs before running the model with cache and then removes them.
+        By default, will reset all SAEs to original state after.
+
+        Args:
+            *model_args: Positional arguments for the model forward pass
+            saes: The SAEs to be attached for this forward pass
+            reset_saes_end: If True, all SAEs added during this run are removed
+                at the end, and previously attached SAEs are restored to their
+                original state. Default is True.
+            use_error_term: If provided, will set the use_error_term attribute
+                of all SAEs attached during this run to this value. Defaults to None.
+            return_cache_object: If True, returns an ActivationCache object with
+                useful methods, otherwise returns a dictionary of activations.
+            remove_batch_dim: Whether to remove the batch dimension
+                (only works for batch_size==1). Defaults to False.
+            **kwargs: Keyword arguments for the model forward pass
+        """
+        with self.saes(
+            saes=saes, reset_saes_end=reset_saes_end, use_error_term=use_error_term
+        ):
+            return self.run_with_cache(
+                *model_args,
+                return_cache_object=return_cache_object,  # type: ignore[arg-type]
+                remove_batch_dim=remove_batch_dim,
+                **kwargs,
+            )  # type: ignore[return-value]
+
+    def run_with_hooks_with_saes(
+        self,
+        *model_args: Any,
+        saes: SAE[Any] | list[SAE[Any]] = [],
+        reset_saes_end: bool = True,
+        fwd_hooks: list[tuple[str | Callable[..., Any], Callable[..., Any]]] = [],
+        bwd_hooks: list[tuple[str | Callable[..., Any], Callable[..., Any]]] = [],
+        reset_hooks_end: bool = True,
+        clear_contexts: bool = False,
+        **model_kwargs: Any,
+    ) -> Any:
+        """Wrapper around 'run_with_hooks'.
+
+        Attaches the given SAEs to the model before running the model with hooks
+        and then removes them. By default, will reset all SAEs to original state after.
+
+        Args:
+            *model_args: Positional arguments for the model forward pass
+            saes: The SAEs to be attached for this forward pass
+            reset_saes_end: If True, all SAEs added during this run are removed
+                at the end, and previously attached SAEs are restored to their
+                original state. Default is True.
+            fwd_hooks: List of forward hooks to apply
+            bwd_hooks: List of backward hooks to apply
+            reset_hooks_end: Whether to reset the hooks at the end of the forward
+                pass. Default is True.
+            clear_contexts: Whether to clear the contexts at the end of the forward
+                pass. Default is False.
+            **model_kwargs: Keyword arguments for the model forward pass
+        """
+        with self.saes(saes=saes, reset_saes_end=reset_saes_end):
+            return self.run_with_hooks(
+                *model_args,
+                fwd_hooks=fwd_hooks,
+                bwd_hooks=bwd_hooks,
+                reset_hooks_end=reset_hooks_end,
+                clear_contexts=clear_contexts,
+                **model_kwargs,
+            )
+
+    @contextmanager
+    def saes(
+        self,
+        saes: SAE[Any] | list[SAE[Any]] = [],
+        reset_saes_end: bool = True,
+        use_error_term: bool | None = None,
+    ):  # type: ignore[no-untyped-def]
+        """A context manager for adding temporary SAEs to the model.
+
+        By default will keep track of previously attached SAEs, and restore them
+        when the context manager exits.
+
+        Args:
+            saes: SAEs to be attached.
+            reset_saes_end: If True, removes all SAEs added by this context manager
+                when the context manager exits, returning previously attached SAEs
+                to their original state.
+            use_error_term: If provided, will set the use_error_term attribute of
+                all SAEs attached during this run to this value. Defaults to None.
+        """
+        act_names_to_reset: list[str] = []
+        prev_saes: list[SAE[Any] | None] = []
+        if isinstance(saes, SAE):
+            saes = [saes]
+        try:
+            for sae in saes:
+                act_names_to_reset.append(sae.cfg.metadata.hook_name)
+                prev_sae = self.acts_to_saes.get(sae.cfg.metadata.hook_name, None)
+                prev_saes.append(prev_sae)
+                self.add_sae(sae, use_error_term=use_error_term)
+            yield self
+        finally:
+            if reset_saes_end:
+                self.reset_saes(act_names_to_reset, prev_saes)
+
+    @property
+    def hook_dict(self) -> dict[str, HookPoint]:
+        """Return combined hook registry including SAE internal hooks.
+
+        When SAEs are attached, they replace HookPoint entries in the registry.
+        This property returns both the base hooks and any internal hooks from
+        attached SAEs (like hook_sae_acts_post, hook_sae_input, etc.) with
+        their full path names.
+        """
+        hooks: dict[str, HookPoint] = {}
+
+        for name, hook_or_sae in self._hook_registry.items():
+            if isinstance(hook_or_sae, SAE):
+                # Include SAE's internal hooks with full path names
+                for sae_hook_name, sae_hook in hook_or_sae.hook_dict.items():
+                    full_name = f"{name}.{sae_hook_name}"
+                    hooks[full_name] = sae_hook
+            else:
+                hooks[name] = hook_or_sae
+
+        return hooks
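A minimal usage sketch for the new class. The model name and prompt are illustrative, and `sae` stands in for an already-loaded sae_lens SAE whose cfg.metadata.hook_name matches a hook (or hook alias) on the booted model.

import torch

from sae_lens import SAETransformerBridge

model = SAETransformerBridge.boot_transformers("gpt2")
sae = ...  # an SAE instance loaded elsewhere; its hook_name must exist on this model

# One forward pass with the SAE spliced in; it is detached again afterwards
# because reset_saes_end defaults to True.
with torch.no_grad():
    logits = model.run_with_saes("Hello world", saes=[sae], use_error_term=True)

# Persistent attachment, then explicit cleanup.
model.add_sae(sae)
print(list(model.acts_to_saes))  # hook names that currently have SAEs attached
model.reset_saes()               # restores the original HookPoint objects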
sae_lens/config.py
CHANGED
@@ -82,6 +82,7 @@ class LoggingConfig:
     log_to_wandb: bool = True
     log_activations_store_to_wandb: bool = False
     log_optimizer_state_to_wandb: bool = False
+    log_weights_to_wandb: bool = True
     wandb_project: str = "sae_lens_training"
     wandb_id: str | None = None
     run_name: str | None = None
@@ -107,7 +108,8 @@ class LoggingConfig:
             type="model",
             metadata=dict(trainer.cfg.__dict__),
         )
-        model_artifact.add_file(str(weights_path))
+        if self.log_weights_to_wandb:
+            model_artifact.add_file(str(weights_path))
         model_artifact.add_file(str(cfg_path))
         wandb.log_artifact(model_artifact, aliases=wandb_aliases)
 
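With the new flag, weight files can be kept out of the W&B model artifact while the config file is still uploaded. A small sketch, assuming LoggingConfig is constructed as a dataclass with the keyword fields shown in this diff:

from sae_lens.config import LoggingConfig

# Log the run to W&B but skip uploading SAE weights as part of the model artifact;
# the config file is still added and logged.
logging_cfg = LoggingConfig(
    log_to_wandb=True,
    log_weights_to_wandb=False,
    wandb_project="sae_lens_training",
)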
@@ -557,6 +559,12 @@ class CacheActivationsRunnerConfig:
             context_size=self.context_size,
         )
 
+        if self.context_size > self.training_tokens:
+            raise ValueError(
+                f"context_size ({self.context_size}) is greater than training_tokens "
+                f"({self.training_tokens}). Please reduce context_size or increase training_tokens."
+            )
+
         if self.new_cached_activations_path is None:
             self.new_cached_activations_path = _default_cached_activations_path(  # type: ignore
                 self.dataset_path, self.model_name, self.hook_name, None
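The new check rejects configurations whose single context window exceeds the total token budget, which previously failed later and less clearly. An illustrative trigger; the field values are hypothetical and only context_size and training_tokens matter here (any other required fields would still need to be supplied):

from sae_lens.config import CacheActivationsRunnerConfig

# context_size (1024) > training_tokens (512), so the validation added above
# raises ValueError while the config is being constructed.
cfg = CacheActivationsRunnerConfig(
    dataset_path="monology/pile-uncopyrighted",
    model_name="gpt2",
    hook_name="blocks.0.hook_resid_post",
    context_size=1024,
    training_tokens=512,
)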
sae_lens/evals.py
CHANGED
@@ -335,7 +335,7 @@ def get_downstream_reconstruction_metrics(
 
     batch_iter = range(n_batches)
     if verbose:
-        batch_iter = tqdm(batch_iter, desc="Reconstruction Batches")
+        batch_iter = tqdm(batch_iter, desc="Reconstruction Batches", leave=False)
 
     for _ in batch_iter:
         batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
@@ -430,7 +430,7 @@ def get_sparsity_and_variance_metrics(
 
     batch_iter = range(n_batches)
    if verbose:
-        batch_iter = tqdm(batch_iter, desc="Sparsity and Variance Batches")
+        batch_iter = tqdm(batch_iter, desc="Sparsity and Variance Batches", leave=False)
 
     for _ in batch_iter:
         batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
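Passing leave=False makes these per-eval progress bars clear themselves when finished, so repeated eval calls inside an outer training or sweep loop no longer leave one stale bar per call. A generic tqdm illustration of the behaviour:

from tqdm import tqdm

for step in tqdm(range(3), desc="Outer loop"):
    # Inner bars with leave=False disappear on completion instead of
    # leaving a line behind, matching the change to the eval loops above.
    for _ in tqdm(range(100), desc="Inner batches", leave=False):
        pass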
sae_lens/loading/pretrained_sae_loaders.py
CHANGED
@@ -575,6 +575,8 @@ def _infer_gemma_3_raw_cfg_dict(repo_id: str, folder_name: str) -> dict[str, Any]:
         "model_name": model_name,
         "hf_hook_point_in": hf_hook_point_in,
     }
+    if "transcoder" in folder_name or "clt" in folder_name:
+        cfg["affine_connection"] = "affine" in folder_name
     if hf_hook_point_out is not None:
         cfg["hf_hook_point_out"] = hf_hook_point_out
 
@@ -614,11 +616,11 @@ def get_gemma_3_config_from_hf(
     if "resid_post" in folder_name:
         hook_name = f"blocks.{layer}.hook_resid_post"
     elif "attn_out" in folder_name:
-        hook_name = f"blocks.{layer}.
+        hook_name = f"blocks.{layer}.attn.hook_z"
     elif "mlp_out" in folder_name:
         hook_name = f"blocks.{layer}.hook_mlp_out"
     elif "transcoder" in folder_name or "clt" in folder_name:
-        hook_name = f"blocks.{layer}.
+        hook_name = f"blocks.{layer}.hook_mlp_in"
         hook_name_out = f"blocks.{layer}.hook_mlp_out"
     else:
         raise ValueError("Hook name not found in folder_name.")
@@ -643,7 +645,11 @@
 
     architecture = "jumprelu"
     if "transcoder" in folder_name or "clt" in folder_name:
-        architecture =
+        architecture = (
+            "jumprelu_skip_transcoder"
+            if raw_cfg_dict.get("affine_connection", False)
+            else "jumprelu_transcoder"
+        )
     d_out = shapes_dict["w_dec"][-1]
 
     cfg = {
@@ -660,7 +666,8 @@ def get_gemma_3_config_from_hf(
         "dataset_path": "monology/pile-uncopyrighted",
         "context_size": 1024,
         "apply_b_dec_to_input": False,
-        "normalize_activations":
+        "normalize_activations": "none",
+        "reshape_activations": "none",
         "hf_hook_name": raw_cfg_dict.get("hf_hook_point_in"),
     }
     if hook_name_out is not None: