sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl
This diff shows the content differences between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- sae_lens/__init__.py +60 -7
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +16 -14
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +170 -258
- sae_lens/constants.py +21 -0
- sae_lens/evals.py +59 -44
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +52 -4
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +254 -0
- sae_lens/saes/jumprelu_sae.py +348 -0
- sae_lens/saes/sae.py +1076 -0
- sae_lens/saes/standard_sae.py +178 -0
- sae_lens/saes/topk_sae.py +300 -0
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +103 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +155 -177
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +13 -7
- sae_lens/util.py +47 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
- sae_lens-6.0.0.dist-info/RECORD +37 -0
- sae_lens/sae.py +0 -747
- sae_lens/sae_training_runner.py +0 -251
- sae_lens/training/geometric_median.py +0 -101
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.11.0.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
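Note on the restructure implied by the file list: `sae_lens/toolkit/` is renamed to `sae_lens/loading/`, the monolithic `sae_lens/sae.py` is removed, and per-architecture SAE modules now live under `sae_lens/saes/`. A rough sketch of how import paths shift follows; the exact public names exported in 6.0.0 are assumptions based only on the moved/added files above.

# sae-lens 5.11.0 (old layout)
from sae_lens.sae import SAE, SAEConfig
from sae_lens.toolkit.pretrained_sae_loaders import sae_lens_disk_loader

# sae-lens 6.0.0 (assumed new layout, inferred from the file moves above)
from sae_lens.saes.sae import SAE
from sae_lens.loading.pretrained_sae_loaders import sae_lens_disk_loader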
sae_lens/sae.py
DELETED
@@ -1,747 +0,0 @@
-"""Most of this is just copied over from Arthur's code and slightly simplified:
-https://github.com/ArthurConmy/sae/blob/main/sae/model.py
-"""
-
-import json
-import warnings
-from contextlib import contextmanager
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import Any, Callable, Literal, TypeVar, overload
-
-import einops
-import torch
-from jaxtyping import Float
-from safetensors.torch import save_file
-from torch import nn
-from transformer_lens.hook_points import HookedRootModule, HookPoint
-from typing_extensions import deprecated
-
-from sae_lens.config import (
-    DTYPE_MAP,
-    SAE_CFG_FILENAME,
-    SAE_WEIGHTS_FILENAME,
-    SPARSITY_FILENAME,
-)
-from sae_lens.toolkit.pretrained_sae_loaders import (
-    NAMED_PRETRAINED_SAE_LOADERS,
-    PretrainedSaeDiskLoader,
-    PretrainedSaeHuggingfaceLoader,
-    get_conversion_loader_name,
-    handle_config_defaulting,
-    sae_lens_disk_loader,
-)
-from sae_lens.toolkit.pretrained_saes_directory import (
-    get_config_overrides,
-    get_norm_scaling_factor,
-    get_pretrained_saes_directory,
-    get_repo_id_and_folder_name,
-)
-
-T = TypeVar("T", bound="SAE")
-
-
-@dataclass
-class SAEConfig:
-    # architecture details
-    architecture: Literal["standard", "gated", "jumprelu", "topk"]
-
-    # forward pass details.
-    d_in: int
-    d_sae: int
-    activation_fn_str: str
-    apply_b_dec_to_input: bool
-    finetuning_scaling_factor: bool
-
-    # dataset it was trained on details.
-    context_size: int
-    model_name: str
-    hook_name: str
-    hook_layer: int
-    hook_head_index: int | None
-    prepend_bos: bool
-    dataset_path: str
-    dataset_trust_remote_code: bool
-    normalize_activations: str
-
-    # misc
-    dtype: str
-    device: str
-    sae_lens_training_version: str | None
-    activation_fn_kwargs: dict[str, Any] = field(default_factory=dict)
-    neuronpedia_id: str | None = None
-    model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
-    seqpos_slice: tuple[int | None, ...] = (None,)
-
-    @classmethod
-    def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
-        # rename dict:
-        rename_dict = { # old : new
-            "hook_point": "hook_name",
-            "hook_point_head_index": "hook_head_index",
-            "hook_point_layer": "hook_layer",
-            "activation_fn": "activation_fn_str",
-        }
-        config_dict = {rename_dict.get(k, k): v for k, v in config_dict.items()}
-
-        # use only config terms that are in the dataclass
-        config_dict = {
-            k: v
-            for k, v in config_dict.items()
-            if k in cls.__dataclass_fields__ # pylint: disable=no-member
-        }
-
-        if "seqpos_slice" in config_dict:
-            config_dict["seqpos_slice"] = tuple(config_dict["seqpos_slice"])
-
-        return cls(**config_dict)
-
-    # def __post_init__(self):
-
-    def to_dict(self) -> dict[str, Any]:
-        return {
-            "architecture": self.architecture,
-            "d_in": self.d_in,
-            "d_sae": self.d_sae,
-            "dtype": self.dtype,
-            "device": self.device,
-            "model_name": self.model_name,
-            "hook_name": self.hook_name,
-            "hook_layer": self.hook_layer,
-            "hook_head_index": self.hook_head_index,
-            "activation_fn_str": self.activation_fn_str, # use string for serialization
-            "activation_fn_kwargs": self.activation_fn_kwargs or {},
-            "apply_b_dec_to_input": self.apply_b_dec_to_input,
-            "finetuning_scaling_factor": self.finetuning_scaling_factor,
-            "sae_lens_training_version": self.sae_lens_training_version,
-            "prepend_bos": self.prepend_bos,
-            "dataset_path": self.dataset_path,
-            "dataset_trust_remote_code": self.dataset_trust_remote_code,
-            "context_size": self.context_size,
-            "normalize_activations": self.normalize_activations,
-            "neuronpedia_id": self.neuronpedia_id,
-            "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
-            "seqpos_slice": self.seqpos_slice,
-        }
-
-
-class SAE(HookedRootModule):
-    """
-    Core Sparse Autoencoder (SAE) class used for inference. For training, see `TrainingSAE`.
-    """
-
-    cfg: SAEConfig
-    dtype: torch.dtype
-    device: torch.device
-    x_norm_coeff: torch.Tensor
-
-    # analysis
-    use_error_term: bool
-
-    def __init__(
-        self,
-        cfg: SAEConfig,
-        use_error_term: bool = False,
-    ):
-        super().__init__()
-
-        self.cfg = cfg
-
-        if cfg.model_from_pretrained_kwargs:
-            warnings.warn(
-                "\nThis SAE has non-empty model_from_pretrained_kwargs. "
-                "\nFor optimal performance, load the model like so:\n"
-                "model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)",
-                category=UserWarning,
-                stacklevel=1,
-            )
-
-        self.activation_fn = get_activation_fn(
-            cfg.activation_fn_str, **cfg.activation_fn_kwargs or {}
-        )
-        self.dtype = DTYPE_MAP[cfg.dtype]
-        self.device = torch.device(cfg.device)
-        self.use_error_term = use_error_term
-
-        if self.cfg.architecture == "standard" or self.cfg.architecture == "topk":
-            self.initialize_weights_basic()
-            self.encode = self.encode_standard
-        elif self.cfg.architecture == "gated":
-            self.initialize_weights_gated()
-            self.encode = self.encode_gated
-        elif self.cfg.architecture == "jumprelu":
-            self.initialize_weights_jumprelu()
-            self.encode = self.encode_jumprelu
-        else:
-            raise ValueError(f"Invalid architecture: {self.cfg.architecture}")
-
-        # handle presence / absence of scaling factor.
-        if self.cfg.finetuning_scaling_factor:
-            self.apply_finetuning_scaling_factor = (
-                lambda x: x * self.finetuning_scaling_factor
-            )
-        else:
-            self.apply_finetuning_scaling_factor = lambda x: x
-
-        # set up hooks
-        self.hook_sae_input = HookPoint()
-        self.hook_sae_acts_pre = HookPoint()
-        self.hook_sae_acts_post = HookPoint()
-        self.hook_sae_output = HookPoint()
-        self.hook_sae_recons = HookPoint()
-        self.hook_sae_error = HookPoint()
-
-        # handle hook_z reshaping if needed.
-        # this is very cursed and should be refactored. it exists so that we can reshape out
-        # the z activations for hook_z SAEs. but don't know d_head if we split up the forward pass
-        # into a separate encode and decode function.
-        # this will cause errors if we call decode before encode.
-        if self.cfg.hook_name.endswith("_z"):
-            self.turn_on_forward_pass_hook_z_reshaping()
-        else:
-            # need to default the reshape fns
-            self.turn_off_forward_pass_hook_z_reshaping()
-
-        # handle run time activation normalization if needed:
-        if self.cfg.normalize_activations == "constant_norm_rescale":
-            # we need to scale the norm of the input and store the scaling factor
-            def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
-                self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
-                return x * self.x_norm_coeff
-
-            def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor: #
-                x = x / self.x_norm_coeff
-                del self.x_norm_coeff # prevents reusing
-                return x
-
-            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
-            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
-
-        elif self.cfg.normalize_activations == "layer_norm":
-            # we need to scale the norm of the input and store the scaling factor
-            def run_time_activation_ln_in(
-                x: torch.Tensor, eps: float = 1e-5
-            ) -> torch.Tensor:
-                mu = x.mean(dim=-1, keepdim=True)
-                x = x - mu
-                std = x.std(dim=-1, keepdim=True)
-                x = x / (std + eps)
-                self.ln_mu = mu
-                self.ln_std = std
-                return x
-
-            def run_time_activation_ln_out(
-                x: torch.Tensor,
-                eps: float = 1e-5, # noqa: ARG001
-            ) -> torch.Tensor:
-                return x * self.ln_std + self.ln_mu # type: ignore
-
-            self.run_time_activation_norm_fn_in = run_time_activation_ln_in
-            self.run_time_activation_norm_fn_out = run_time_activation_ln_out
-        else:
-            self.run_time_activation_norm_fn_in = lambda x: x
-            self.run_time_activation_norm_fn_out = lambda x: x
-
-        self.setup() # Required for `HookedRootModule`s
-
-    def initialize_weights_basic(self):
-        # no config changes encoder bias init for now.
-        self.b_enc = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-
-        # Start with the default init strategy:
-        self.W_dec = nn.Parameter(
-            torch.nn.init.kaiming_uniform_(
-                torch.empty(
-                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-                )
-            )
-        )
-
-        self.W_enc = nn.Parameter(
-            torch.nn.init.kaiming_uniform_(
-                torch.empty(
-                    self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-                )
-            )
-        )
-
-        # methdods which change b_dec as a function of the dataset are implemented after init.
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
-
-        # scaling factor for fine-tuning (not to be used in initial training)
-        # TODO: Make this optional and not included with all SAEs by default (but maintain backwards compatibility)
-        if self.cfg.finetuning_scaling_factor:
-            self.finetuning_scaling_factor = nn.Parameter(
-                torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-            )
-
-    def initialize_weights_gated(self):
-        # Initialize the weights and biases for the gated encoder
-        self.W_enc = nn.Parameter(
-            torch.nn.init.kaiming_uniform_(
-                torch.empty(
-                    self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-                )
-            )
-        )
-
-        self.b_gate = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-
-        self.r_mag = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-
-        self.b_mag = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-
-        self.W_dec = nn.Parameter(
-            torch.nn.init.kaiming_uniform_(
-                torch.empty(
-                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-                )
-            )
-        )
-
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
-
-    def initialize_weights_jumprelu(self):
-        # The params are identical to the standard SAE
-        # except we use a threshold parameter too
-        self.threshold = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-        self.initialize_weights_basic()
-
-    @overload
-    def to(
-        self: T,
-        device: torch.device | str | None = ...,
-        dtype: torch.dtype | None = ...,
-        non_blocking: bool = ...,
-    ) -> T: ...
-
-    @overload
-    def to(self: T, dtype: torch.dtype, non_blocking: bool = ...) -> T: ...
-
-    @overload
-    def to(self: T, tensor: torch.Tensor, non_blocking: bool = ...) -> T: ...
-
-    def to(self, *args: Any, **kwargs: Any) -> "SAE": # type: ignore
-        device_arg = None
-        dtype_arg = None
-
-        # Check args
-        for arg in args:
-            if isinstance(arg, (torch.device, str)):
-                device_arg = arg
-            elif isinstance(arg, torch.dtype):
-                dtype_arg = arg
-            elif isinstance(arg, torch.Tensor):
-                device_arg = arg.device
-                dtype_arg = arg.dtype
-
-        # Check kwargs
-        device_arg = kwargs.get("device", device_arg)
-        dtype_arg = kwargs.get("dtype", dtype_arg)
-
-        if device_arg is not None:
-            # Convert device to torch.device if it's a string
-            device = (
-                torch.device(device_arg) if isinstance(device_arg, str) else device_arg
-            )
-
-            # Update the cfg.device
-            self.cfg.device = str(device)
-
-            # Update the .device property
-            self.device = device
-
-        if dtype_arg is not None:
-            # Update the cfg.dtype
-            self.cfg.dtype = str(dtype_arg)
-
-            # Update the .dtype property
-            self.dtype = dtype_arg
-
-        # Call the parent class's to() method to handle all cases (device, dtype, tensor)
-        return super().to(*args, **kwargs)
-
-    # Basic Forward Pass Functionality.
-    def forward(
-        self,
-        x: torch.Tensor,
-    ) -> torch.Tensor:
-        feature_acts = self.encode(x)
-        sae_out = self.decode(feature_acts)
-
-        # TEMP
-        if self.use_error_term:
-            with torch.no_grad():
-                # Recompute everything without hooks to get true error term
-                # Otherwise, the output with error term will always equal input, even for causal interventions that affect x_reconstruct
-                # This is in a no_grad context to detach the error, so we can compute SAE feature gradients (eg for attribution patching). See A.3 in https://arxiv.org/pdf/2403.19647.pdf for more detail
-                # NOTE: we can't just use `sae_error = input - x_reconstruct.detach()` or something simpler, since this would mean intervening on features would mean ablating features still results in perfect reconstruction.
-                with _disable_hooks(self):
-                    feature_acts_clean = self.encode(x)
-                    x_reconstruct_clean = self.decode(feature_acts_clean)
-                sae_error = self.hook_sae_error(x - x_reconstruct_clean)
-            sae_out = sae_out + sae_error
-        return self.hook_sae_output(sae_out)
-
-    def encode_gated(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
-        sae_in = self.process_sae_in(x)
-
-        # Gating path
-        gating_pre_activation = sae_in @ self.W_enc + self.b_gate
-        active_features = (gating_pre_activation > 0).to(self.dtype)
-
-        # Magnitude path with weight sharing
-        magnitude_pre_activation = self.hook_sae_acts_pre(
-            sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
-        )
-        feature_magnitudes = self.activation_fn(magnitude_pre_activation)
-
-        return self.hook_sae_acts_post(active_features * feature_magnitudes)
-
-    def encode_jumprelu(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
-        """
-        Calculate SAE features from inputs
-        """
-        sae_in = self.process_sae_in(x)
-
-        # "... d_in, d_in d_sae -> ... d_sae",
-        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
-
-        return self.hook_sae_acts_post(
-            self.activation_fn(hidden_pre) * (hidden_pre > self.threshold)
-        )
-
-    def encode_standard(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
-        """
-        Calculate SAE features from inputs
-        """
-        sae_in = self.process_sae_in(x)
-
-        # "... d_in, d_in d_sae -> ... d_sae",
-        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
-        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
-
-    def process_sae_in(
-        self, sae_in: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
-        sae_in = sae_in.to(self.dtype)
-        sae_in = self.reshape_fn_in(sae_in)
-        sae_in = self.hook_sae_input(sae_in)
-        sae_in = self.run_time_activation_norm_fn_in(sae_in)
-        return sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input)
-
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
-        """Decodes SAE feature activation tensor into a reconstructed input activation tensor."""
-        # "... d_sae, d_sae d_in -> ... d_in",
-        sae_out = self.hook_sae_recons(
-            self.apply_finetuning_scaling_factor(feature_acts) @ self.W_dec + self.b_dec
-        )
-
-        # handle run time activation normalization if needed
-        # will fail if you call this twice without calling encode in between.
-        sae_out = self.run_time_activation_norm_fn_out(sae_out)
-
-        # handle hook z reshaping if needed.
-        return self.reshape_fn_out(sae_out, self.d_head) # type: ignore
-
-    @torch.no_grad()
-    def fold_W_dec_norm(self):
-        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
-        self.W_dec.data = self.W_dec.data / W_dec_norms
-        self.W_enc.data = self.W_enc.data * W_dec_norms.T
-        if self.cfg.architecture == "gated":
-            self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
-            self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
-            self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
-        elif self.cfg.architecture == "jumprelu":
-            self.threshold.data = self.threshold.data * W_dec_norms.squeeze()
-            self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()
-        else:
-            self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()
-
-    @torch.no_grad()
-    def fold_activation_norm_scaling_factor(
-        self, activation_norm_scaling_factor: float
-    ):
-        self.W_enc.data = self.W_enc.data * activation_norm_scaling_factor
-        # previously weren't doing this.
-        self.W_dec.data = self.W_dec.data / activation_norm_scaling_factor
-        self.b_dec.data = self.b_dec.data / activation_norm_scaling_factor
-
-        # once we normalize, we shouldn't need to scale activations.
-        self.cfg.normalize_activations = "none"
-
-    @overload
-    def save_model(self, path: str | Path) -> tuple[Path, Path]: ...
-
-    @overload
-    def save_model(
-        self, path: str | Path, sparsity: torch.Tensor
-    ) -> tuple[Path, Path, Path]: ...
-
-    def save_model(self, path: str | Path, sparsity: torch.Tensor | None = None):
-        path = Path(path)
-
-        if not path.exists():
-            path.mkdir(parents=True)
-
-        # generate the weights
-        state_dict = self.state_dict()
-        self.process_state_dict_for_saving(state_dict)
-        model_weights_path = path / SAE_WEIGHTS_FILENAME
-        save_file(state_dict, model_weights_path)
-
-        # save the config
-        config = self.cfg.to_dict()
-
-        cfg_path = path / SAE_CFG_FILENAME
-        with open(cfg_path, "w") as f:
-            json.dump(config, f)
-
-        if sparsity is not None:
-            sparsity_in_dict = {"sparsity": sparsity}
-            sparsity_path = path / SPARSITY_FILENAME
-            save_file(sparsity_in_dict, sparsity_path)
-            return model_weights_path, cfg_path, sparsity_path
-
-        return model_weights_path, cfg_path
-
-    # overwrite this in subclasses to modify the state_dict in-place before saving
-    def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
-        pass
-
-    # overwrite this in subclasses to modify the state_dict in-place after loading
-    def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
-        pass
-
-    @classmethod
-    @deprecated("Use load_from_disk instead")
-    def load_from_pretrained(
-        cls, path: str, device: str = "cpu", dtype: str | None = None
-    ) -> "SAE":
-        sae = cls.load_from_disk(path, device)
-        if dtype is not None:
-            sae.cfg.dtype = dtype
-            sae = sae.to(dtype)
-        return sae
-
-    @classmethod
-    def load_from_disk(
-        cls,
-        path: str,
-        device: str = "cpu",
-        dtype: str | None = None,
-        converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
-    ) -> "SAE":
-        overrides = {"dtype": dtype} if dtype is not None else None
-        cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
-        cfg_dict = handle_config_defaulting(cfg_dict)
-        sae_cfg = SAEConfig.from_dict(cfg_dict)
-        sae = cls(sae_cfg)
-        sae.process_state_dict_for_loading(state_dict)
-        sae.load_state_dict(state_dict)
-        return sae
-
-    @classmethod
-    def from_pretrained(
-        cls,
-        release: str,
-        sae_id: str,
-        device: str = "cpu",
-        force_download: bool = False,
-        converter: PretrainedSaeHuggingfaceLoader | None = None,
-    ) -> tuple["SAE", dict[str, Any], torch.Tensor | None]:
-        """
-        Load a pretrained SAE from the Hugging Face model hub.
-
-        Args:
-            release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
-            id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
-            device: The device to load the SAE on.
-            return_sparsity_if_present: If True, will return the log sparsity tensor if it is present in the model directory in the Hugging Face model hub.
-        """
-
-        # get sae directory
-        sae_directory = get_pretrained_saes_directory()
-
-        # get the repo id and path to the SAE
-        if release not in sae_directory:
-            if "/" not in release:
-                raise ValueError(
-                    f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
-                )
-        elif sae_id not in sae_directory[release].saes_map:
-            # If using Gemma Scope and not the canonical release, give a hint to use it
-            if (
-                "gemma-scope" in release
-                and "canonical" not in release
-                and f"{release}-canonical" in sae_directory
-            ):
-                canonical_ids = list(
-                    sae_directory[release + "-canonical"].saes_map.keys()
-                )
-                # Shorten the lengthy string of valid IDs
-                if len(canonical_ids) > 5:
-                    str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
-                else:
-                    str_canonical_ids = str(canonical_ids)
-                value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
-            else:
-                value_suffix = ""
-
-            valid_ids = list(sae_directory[release].saes_map.keys())
-            # Shorten the lengthy string of valid IDs
-            if len(valid_ids) > 5:
-                str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
-            else:
-                str_valid_ids = str(valid_ids)
-
-            raise ValueError(
-                f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
-                + value_suffix
-            )
-
-        conversion_loader = (
-            converter
-            or NAMED_PRETRAINED_SAE_LOADERS[get_conversion_loader_name(release)]
-        )
-        repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
-        config_overrides = get_config_overrides(release, sae_id)
-        config_overrides["device"] = device
-
-        cfg_dict, state_dict, log_sparsities = conversion_loader(
-            repo_id=repo_id,
-            folder_name=folder_name,
-            device=device,
-            force_download=force_download,
-            cfg_overrides=config_overrides,
-        )
-        cfg_dict = handle_config_defaulting(cfg_dict)
-
-        sae = cls(SAEConfig.from_dict(cfg_dict))
-        sae.process_state_dict_for_loading(state_dict)
-        sae.load_state_dict(state_dict)
-
-        # Check if normalization is 'expected_average_only_in'
-        if cfg_dict.get("normalize_activations") == "expected_average_only_in":
-            norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
-            if norm_scaling_factor is not None:
-                sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
-                cfg_dict["normalize_activations"] = "none"
-            else:
-                warnings.warn(
-                    f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
-                )
-
-        return sae, cfg_dict, log_sparsities
-
-    def get_name(self):
-        return f"sae_{self.cfg.model_name}_{self.cfg.hook_name}_{self.cfg.d_sae}"
-
-    @classmethod
-    def from_dict(cls, config_dict: dict[str, Any]) -> "SAE":
-        return cls(SAEConfig.from_dict(config_dict))
-
-    def turn_on_forward_pass_hook_z_reshaping(self):
-        if not self.cfg.hook_name.endswith("_z"):
-            raise ValueError("This method should only be called for hook_z SAEs.")
-
-        def reshape_fn_in(x: torch.Tensor):
-            self.d_head = x.shape[-1] # type: ignore
-            self.reshape_fn_in = lambda x: einops.rearrange(
-                x, "... n_heads d_head -> ... (n_heads d_head)"
-            )
-            return einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")
-
-        self.reshape_fn_in = reshape_fn_in
-
-        self.reshape_fn_out = lambda x, d_head: einops.rearrange(
-            x, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
-        )
-        self.hook_z_reshaping_mode = True
-
-    def turn_off_forward_pass_hook_z_reshaping(self):
-        self.reshape_fn_in = lambda x: x
-        self.reshape_fn_out = lambda x, d_head: x # noqa: ARG005
-        self.d_head = None
-        self.hook_z_reshaping_mode = False
-
-
-class TopK(nn.Module):
-    def __init__(
-        self, k: int, postact_fn: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU()
-    ):
-        super().__init__()
-        self.k = k
-        self.postact_fn = postact_fn
-
-    # TODO: Use a fused kernel to speed up topk decoding like https://github.com/EleutherAI/sae/blob/main/sae/kernels.py
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        topk = torch.topk(x, k=self.k, dim=-1)
-        values = self.postact_fn(topk.values)
-        result = torch.zeros_like(x)
-        result.scatter_(-1, topk.indices, values)
-        return result
-
-
-def get_activation_fn(
-    activation_fn: str, **kwargs: Any
-) -> Callable[[torch.Tensor], torch.Tensor]:
-    if activation_fn == "relu":
-        return torch.nn.ReLU()
-    if activation_fn == "tanh-relu":
-
-        def tanh_relu(input: torch.Tensor) -> torch.Tensor:
-            input = torch.relu(input)
-            return torch.tanh(input)
-
-        return tanh_relu
-    if activation_fn == "topk":
-        if "k" not in kwargs:
-            raise ValueError("TopK activation function requires a k value.")
-        k = kwargs.get("k", 1) # Default k to 1 if not provided
-        postact_fn = kwargs.get(
-            "postact_fn", nn.ReLU()
-        ) # Default post-activation to ReLU if not provided
-
-        return TopK(k, postact_fn)
-    raise ValueError(f"Unknown activation function: {activation_fn}")
-
-
-_blank_hook = nn.Identity()
-
-
-@contextmanager
-def _disable_hooks(sae: SAE):
-    """
-    Temporarily disable hooks for the SAE. Swaps out all the hooks with a fake modules that does nothing.
-    """
-    try:
-        for hook_name in sae.hook_dict:
-            setattr(sae, hook_name, _blank_hook)
-        yield
-    finally:
-        for hook_name, hook in sae.hook_dict.items():
-            setattr(sae, hook_name, hook)