sae-lens 5.9.0__py3-none-any.whl → 6.0.0rc1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- sae_lens/__init__.py +22 -6
- sae_lens/analysis/hooked_sae_transformer.py +2 -2
- sae_lens/config.py +66 -23
- sae_lens/evals.py +6 -5
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +33 -25
- sae_lens/regsitry.py +34 -0
- sae_lens/sae_training_runner.py +18 -33
- sae_lens/saes/gated_sae.py +247 -0
- sae_lens/saes/jumprelu_sae.py +368 -0
- sae_lens/saes/sae.py +970 -0
- sae_lens/saes/standard_sae.py +167 -0
- sae_lens/saes/topk_sae.py +305 -0
- sae_lens/training/activations_store.py +2 -2
- sae_lens/training/sae_trainer.py +13 -19
- sae_lens/training/upload_saes_to_huggingface.py +1 -1
- {sae_lens-5.9.0.dist-info → sae_lens-6.0.0rc1.dist-info}/METADATA +3 -3
- sae_lens-6.0.0rc1.dist-info/RECORD +32 -0
- sae_lens/sae.py +0 -747
- sae_lens/training/training_sae.py +0 -705
- sae_lens-5.9.0.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.9.0.dist-info → sae_lens-6.0.0rc1.dist-info}/LICENSE +0 -0
- {sae_lens-5.9.0.dist-info → sae_lens-6.0.0rc1.dist-info}/WHEEL +0 -0
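
The file list above shows the main 6.0.0rc1 restructuring: the monolithic `sae_lens/sae.py` and `sae_lens/training/training_sae.py` are deleted, per-architecture implementations move under `sae_lens/saes/` (standard, gated, jumprelu, topk), the `toolkit` package is renamed to `loading`, and a new registry module (`sae_lens/regsitry.py`, spelled as shipped) maps architecture names to classes. As a rough orientation sketch inferred only from the file paths and imports visible in this diff (not official migration guidance):

```python
# New module layout in 6.0.0rc1, inferred from the file paths and imports in this diff.
# In 5.9.0 the SAE classes lived in sae_lens/sae.py and sae_lens/training/training_sae.py,
# and the pretrained loaders lived under sae_lens/toolkit/ rather than sae_lens/loading/.
from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
from sae_lens.loading.pretrained_sae_loaders import sae_lens_disk_loader
from sae_lens.regsitry import get_sae_class, get_sae_training_class  # module name as shipped

# The registry resolves an architecture string to its SAE class;
# "standard" is a hypothetical example value, not confirmed by this diff.
standard_sae_cls = get_sae_class("standard")
print(standard_sae_cls)
```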
sae_lens/saes/sae.py
ADDED
@@ -0,0 +1,970 @@
"""Base classes for Sparse Autoencoders (SAEs)."""

import json
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field, fields
from pathlib import Path
from typing import Any, Callable, Type, TypeVar

import einops
import torch
from jaxtyping import Float
from numpy.typing import NDArray
from safetensors.torch import save_file
from torch import nn
from transformer_lens.hook_points import HookedRootModule, HookPoint
from typing_extensions import deprecated, overload

from sae_lens import logger
from sae_lens.config import (
    DTYPE_MAP,
    SAE_CFG_FILENAME,
    SAE_WEIGHTS_FILENAME,
    SPARSITY_FILENAME,
    LanguageModelSAERunnerConfig,
)
from sae_lens.loading.pretrained_sae_loaders import (
    NAMED_PRETRAINED_SAE_LOADERS,
    PretrainedSaeDiskLoader,
    PretrainedSaeHuggingfaceLoader,
    get_conversion_loader_name,
    handle_config_defaulting,
    sae_lens_disk_loader,
)
from sae_lens.loading.pretrained_saes_directory import (
    get_config_overrides,
    get_norm_scaling_factor,
    get_pretrained_saes_directory,
    get_repo_id_and_folder_name,
)
from sae_lens.regsitry import get_sae_class, get_sae_training_class

T = TypeVar("T", bound="SAE")


@dataclass
class SAEConfig:
    """Base configuration for SAE models."""

    architecture: str
    d_in: int
    d_sae: int
    dtype: str
    device: str
    model_name: str
    hook_name: str
    hook_layer: int
    hook_head_index: int | None
    activation_fn: str
    activation_fn_kwargs: dict[str, Any]
    apply_b_dec_to_input: bool
    finetuning_scaling_factor: bool
    normalize_activations: str
    context_size: int
    dataset_path: str
    dataset_trust_remote_code: bool
    sae_lens_training_version: str
    model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
    seqpos_slice: tuple[int, ...] | None = None
    prepend_bos: bool = False
    neuronpedia_id: str | None = None

    def to_dict(self) -> dict[str, Any]:
        return {field.name: getattr(self, field.name) for field in fields(self)}

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
        valid_field_names = {field.name for field in fields(cls)}
        valid_config_dict = {
            key: val for key, val in config_dict.items() if key in valid_field_names
        }

        # Ensure seqpos_slice is a tuple
        if (
            "seqpos_slice" in valid_config_dict
            and valid_config_dict["seqpos_slice"] is not None
            and isinstance(valid_config_dict["seqpos_slice"], list)
        ):
            valid_config_dict["seqpos_slice"] = tuple(valid_config_dict["seqpos_slice"])

        return cls(**valid_config_dict)


@dataclass
class TrainStepOutput:
    """Output from a training step."""

    sae_in: torch.Tensor
    sae_out: torch.Tensor
    feature_acts: torch.Tensor
    hidden_pre: torch.Tensor
    loss: torch.Tensor  # we need to call backwards on this
    losses: dict[str, torch.Tensor]


@dataclass
class TrainStepInput:
    """Input to a training step."""

    sae_in: torch.Tensor
    current_l1_coefficient: float
    dead_neuron_mask: torch.Tensor | None


class SAE(HookedRootModule, ABC):
    """Abstract base class for all SAE architectures."""

    cfg: SAEConfig
    dtype: torch.dtype
    device: torch.device
    use_error_term: bool

    # For type checking only - don't provide default values
    # These will be initialized by subclasses
    W_enc: nn.Parameter
    W_dec: nn.Parameter
    b_dec: nn.Parameter

    def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
        """Initialize the SAE."""
        super().__init__()

        self.cfg = cfg

        if cfg.model_from_pretrained_kwargs:
            warnings.warn(
                "\nThis SAE has non-empty model_from_pretrained_kwargs. "
                "\nFor optimal performance, load the model like so:\n"
                "model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)",
                category=UserWarning,
                stacklevel=1,
            )

        self.dtype = DTYPE_MAP[cfg.dtype]
        self.device = torch.device(cfg.device)
        self.use_error_term = use_error_term

        # Set up activation function
        self.activation_fn = self._get_activation_fn()

        # Initialize weights
        self.initialize_weights()

        # Handle presence / absence of scaling factor
        if self.cfg.finetuning_scaling_factor:
            self.apply_finetuning_scaling_factor = (
                lambda x: x * self.finetuning_scaling_factor
            )
        else:
            self.apply_finetuning_scaling_factor = lambda x: x

        # Set up hooks
        self.hook_sae_input = HookPoint()
        self.hook_sae_acts_pre = HookPoint()
        self.hook_sae_acts_post = HookPoint()
        self.hook_sae_output = HookPoint()
        self.hook_sae_recons = HookPoint()
        self.hook_sae_error = HookPoint()

        # handle hook_z reshaping if needed.
        if self.cfg.hook_name.endswith("_z"):
            # print(f"Setting up hook_z reshaping for {self.cfg.hook_name}")
            self.turn_on_forward_pass_hook_z_reshaping()
        else:
            # print(f"No hook_z reshaping needed for {self.cfg.hook_name}")
            self.turn_off_forward_pass_hook_z_reshaping()

        # Set up activation normalization
        self._setup_activation_normalization()

        self.setup()  # Required for HookedRootModule

    @torch.no_grad()
    def fold_activation_norm_scaling_factor(self, scaling_factor: float):
        self.W_enc.data *= scaling_factor  # type: ignore
        self.W_dec.data /= scaling_factor  # type: ignore
        self.b_dec.data /= scaling_factor  # type: ignore
        self.cfg.normalize_activations = "none"

    def _get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
        """Get the activation function specified in config."""
        return self._get_activation_fn_static(
            self.cfg.activation_fn, **(self.cfg.activation_fn_kwargs or {})
        )

    @staticmethod
    def _get_activation_fn_static(
        activation_fn: str, **kwargs: Any
    ) -> Callable[[torch.Tensor], torch.Tensor]:
        """Get the activation function from a string specification."""
        if activation_fn == "relu":
            return torch.nn.ReLU()
        if activation_fn == "tanh-relu":

            def tanh_relu(input: torch.Tensor) -> torch.Tensor:
                input = torch.relu(input)
                return torch.tanh(input)

            return tanh_relu
        if activation_fn == "topk":
            if "k" not in kwargs:
                raise ValueError("TopK activation function requires a k value.")
            k = kwargs.get("k", 1)  # Default k to 1 if not provided

            def topk_fn(x: torch.Tensor) -> torch.Tensor:
                topk = torch.topk(x.flatten(start_dim=-1), k=k, dim=-1)
                values = torch.relu(topk.values)
                result = torch.zeros_like(x.flatten(start_dim=-1))
                result.scatter_(-1, topk.indices, values)
                return result.view_as(x)

            return topk_fn
        raise ValueError(f"Unknown activation function: {activation_fn}")

    def _setup_activation_normalization(self):
        """Set up activation normalization functions based on config."""
        if self.cfg.normalize_activations == "constant_norm_rescale":

            def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
                self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
                return x * self.x_norm_coeff

            def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
                x = x / self.x_norm_coeff  # type: ignore
                del self.x_norm_coeff
                return x

            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out

        elif self.cfg.normalize_activations == "layer_norm":

            def run_time_activation_ln_in(
                x: torch.Tensor, eps: float = 1e-5
            ) -> torch.Tensor:
                mu = x.mean(dim=-1, keepdim=True)
                x = x - mu
                std = x.std(dim=-1, keepdim=True)
                x = x / (std + eps)
                self.ln_mu = mu
                self.ln_std = std
                return x

            def run_time_activation_ln_out(
                x: torch.Tensor,
                eps: float = 1e-5,  # noqa: ARG001
            ) -> torch.Tensor:
                return x * self.ln_std + self.ln_mu  # type: ignore

            self.run_time_activation_norm_fn_in = run_time_activation_ln_in
            self.run_time_activation_norm_fn_out = run_time_activation_ln_out
        else:
            self.run_time_activation_norm_fn_in = lambda x: x
            self.run_time_activation_norm_fn_out = lambda x: x

    @abstractmethod
    def initialize_weights(self):
        """Initialize model weights."""
        pass

    @abstractmethod
    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """Encode input tensor to feature space."""
        pass

    @abstractmethod
    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """Decode feature activations back to input space."""
        pass

    def turn_on_forward_pass_hook_z_reshaping(self):
        if not self.cfg.hook_name.endswith("_z"):
            raise ValueError("This method should only be called for hook_z SAEs.")

        # print(f"Turning on hook_z reshaping for {self.cfg.hook_name}")

        def reshape_fn_in(x: torch.Tensor):
            # print(f"reshape_fn_in input shape: {x.shape}")
            self.d_head = x.shape[-1]
            # print(f"Setting d_head to: {self.d_head}")
            self.reshape_fn_in = lambda x: einops.rearrange(
                x, "... n_heads d_head -> ... (n_heads d_head)"
            )
            return einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")

        self.reshape_fn_in = reshape_fn_in
        self.reshape_fn_out = lambda x, d_head: einops.rearrange(
            x, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
        )
        self.hook_z_reshaping_mode = True
        # print(f"hook_z reshaping turned on, self.d_head={getattr(self, 'd_head', None)}")

    def turn_off_forward_pass_hook_z_reshaping(self):
        self.reshape_fn_in = lambda x: x
        self.reshape_fn_out = lambda x, d_head: x  # noqa: ARG005
        self.d_head = None
        self.hook_z_reshaping_mode = False

    @overload
    def to(
        self: T,
        device: torch.device | str | None = ...,
        dtype: torch.dtype | None = ...,
        non_blocking: bool = ...,
    ) -> T: ...

    @overload
    def to(self: T, dtype: torch.dtype, non_blocking: bool = ...) -> T: ...

    @overload
    def to(self: T, tensor: torch.Tensor, non_blocking: bool = ...) -> T: ...

    def to(self: T, *args: Any, **kwargs: Any) -> T:  # type: ignore
        device_arg = None
        dtype_arg = None

        # Check args
        for arg in args:
            if isinstance(arg, (torch.device, str)):
                device_arg = arg
            elif isinstance(arg, torch.dtype):
                dtype_arg = arg
            elif isinstance(arg, torch.Tensor):
                device_arg = arg.device
                dtype_arg = arg.dtype

        # Check kwargs
        device_arg = kwargs.get("device", device_arg)
        dtype_arg = kwargs.get("dtype", dtype_arg)

        # Update device in config if provided
        if device_arg is not None:
            # Convert device to torch.device if it's a string
            device = (
                torch.device(device_arg) if isinstance(device_arg, str) else device_arg
            )

            # Update the cfg.device
            self.cfg.device = str(device)

            # Update the device property
            self.device = device

        # Update dtype in config if provided
        if dtype_arg is not None:
            # Update the cfg.dtype
            self.cfg.dtype = str(dtype_arg)

            # Update the dtype property
            self.dtype = dtype_arg

        return super().to(*args, **kwargs)

    def process_sae_in(
        self, sae_in: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_in"]:
        # print(f"Input shape to process_sae_in: {sae_in.shape}")
        # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
        # print(f"self.b_dec shape: {self.b_dec.shape}")
        # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")

        sae_in = sae_in.to(self.dtype)

        # print(f"Shape before reshape_fn_in: {sae_in.shape}")
        sae_in = self.reshape_fn_in(sae_in)
        # print(f"Shape after reshape_fn_in: {sae_in.shape}")

        sae_in = self.hook_sae_input(sae_in)
        sae_in = self.run_time_activation_norm_fn_in(sae_in)

        # Here's where the error happens
        bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
        # print(f"Bias term shape: {bias_term.shape}")

        return sae_in - bias_term

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the SAE."""
        feature_acts = self.encode(x)
        sae_out = self.decode(feature_acts)

        if self.use_error_term:
            with torch.no_grad():
                # Recompute without hooks for true error term
                with _disable_hooks(self):
                    feature_acts_clean = self.encode(x)
                    x_reconstruct_clean = self.decode(feature_acts_clean)
                sae_error = self.hook_sae_error(x - x_reconstruct_clean)
            sae_out = sae_out + sae_error

        return self.hook_sae_output(sae_out)

    # overwrite this in subclasses to modify the state_dict in-place before saving
    def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
        pass

    # overwrite this in subclasses to modify the state_dict in-place after loading
    def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
        pass

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """Fold decoder norms into encoder."""
        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
        self.W_dec.data = self.W_dec.data / W_dec_norms
        self.W_enc.data = self.W_enc.data * W_dec_norms.T

        # Only update b_enc if it exists (standard/jumprelu architectures)
        if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter):
            self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()

    def get_name(self):
        """Generate a name for this SAE."""
        return f"sae_{self.cfg.model_name}_{self.cfg.hook_name}_{self.cfg.d_sae}"

    def save_model(
        self, path: str | Path, sparsity: torch.Tensor | None = None
    ) -> tuple[Path, Path, Path | None]:
        """Save model weights, config, and optional sparsity tensor to disk."""
        path = Path(path)
        if not path.exists():
            path.mkdir(parents=True)

        # Generate the weights
        state_dict = self.state_dict()  # Use internal SAE state dict
        self.process_state_dict_for_saving(state_dict)
        model_weights_path = path / SAE_WEIGHTS_FILENAME
        save_file(state_dict, model_weights_path)

        # Save the config
        config = self.cfg.to_dict()
        cfg_path = path / SAE_CFG_FILENAME
        with open(cfg_path, "w") as f:
            json.dump(config, f)

        if sparsity is not None:
            sparsity_in_dict = {"sparsity": sparsity}
            sparsity_path = path / SPARSITY_FILENAME
            save_file(sparsity_in_dict, sparsity_path)
            return model_weights_path, cfg_path, sparsity_path

        return model_weights_path, cfg_path, None

    ## Initialization Methods
    @torch.no_grad()
    def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
        out = torch.tensor(origin, dtype=self.dtype, device=self.device)
        self.b_dec.data = out

    @torch.no_grad()
    def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
        previous_b_dec = self.b_dec.clone().cpu()
        out = all_activations.mean(dim=0)

        previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
        distances = torch.norm(all_activations - out, dim=-1)

        logger.info("Reinitializing b_dec with mean of activations")
        logger.debug(
            f"Previous distances: {previous_distances.median(0).values.mean().item()}"
        )
        logger.debug(f"New distances: {distances.median(0).values.mean().item()}")

        self.b_dec.data = out.to(self.dtype).to(self.device)

    # Class methods for loading models
    @classmethod
    @deprecated("Use load_from_disk instead")
    def load_from_pretrained(
        cls: Type[T], path: str | Path, device: str = "cpu", dtype: str | None = None
    ) -> T:
        return cls.load_from_disk(path, device=device, dtype=dtype)

    @classmethod
    def load_from_disk(
        cls: Type[T],
        path: str | Path,
        device: str = "cpu",
        dtype: str | None = None,
        converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
    ) -> T:
        overrides = {"dtype": dtype} if dtype is not None else None
        cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
        cfg_dict = handle_config_defaulting(cfg_dict)
        sae_config_cls = cls.get_sae_config_class_for_architecture(
            cfg_dict["architecture"]
        )
        sae_cfg = sae_config_cls.from_dict(cfg_dict)
        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture)
        sae = sae_cls(sae_cfg)
        sae.process_state_dict_for_loading(state_dict)
        sae.load_state_dict(state_dict)
        return sae

    @classmethod
    def from_pretrained(
        cls,
        release: str,
        sae_id: str,
        device: str = "cpu",
        force_download: bool = False,
        converter: PretrainedSaeHuggingfaceLoader | None = None,
    ) -> tuple["SAE", dict[str, Any], torch.Tensor | None]:
        """
        Load a pretrained SAE from the Hugging Face model hub.

        Args:
            release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
            id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
            device: The device to load the SAE on.
            return_sparsity_if_present: If True, will return the log sparsity tensor if it is present in the model directory in the Hugging Face model hub.
        """

        # get sae directory
        sae_directory = get_pretrained_saes_directory()

        # Validate release and sae_id
        if release not in sae_directory:
            if "/" not in release:
                raise ValueError(
                    f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
                )
        elif sae_id not in sae_directory[release].saes_map:
            # Handle special cases like Gemma Scope
            if (
                "gemma-scope" in release
                and "canonical" not in release
                and f"{release}-canonical" in sae_directory
            ):
                canonical_ids = list(
                    sae_directory[release + "-canonical"].saes_map.keys()
                )
                # Shorten the lengthy string of valid IDs
                if len(canonical_ids) > 5:
                    str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
                else:
                    str_canonical_ids = str(canonical_ids)
                value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
            else:
                value_suffix = ""

            valid_ids = list(sae_directory[release].saes_map.keys())
            # Shorten the lengthy string of valid IDs
            if len(valid_ids) > 5:
                str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
            else:
                str_valid_ids = str(valid_ids)

            raise ValueError(
                f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
                + value_suffix
            )

        conversion_loader = (
            converter
            or NAMED_PRETRAINED_SAE_LOADERS[get_conversion_loader_name(release)]
        )
        repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
        config_overrides = get_config_overrides(release, sae_id)
        config_overrides["device"] = device

        # Load config and weights
        cfg_dict, state_dict, log_sparsities = conversion_loader(
            repo_id=repo_id,
            folder_name=folder_name,
            device=device,
            force_download=force_download,
            cfg_overrides=config_overrides,
        )
        cfg_dict = handle_config_defaulting(cfg_dict)

        # Rename keys to match SAEConfig field names
        renamed_cfg_dict = {}
        rename_map = {
            "hook_point": "hook_name",
            "hook_point_layer": "hook_layer",
            "hook_point_head_index": "hook_head_index",
            "activation_fn": "activation_fn",
        }

        for k, v in cfg_dict.items():
            renamed_cfg_dict[rename_map.get(k, k)] = v

        # Set default values for required fields
        renamed_cfg_dict.setdefault("activation_fn_kwargs", {})
        renamed_cfg_dict.setdefault("seqpos_slice", None)

        # Create SAE with appropriate architecture
        sae_config_cls = cls.get_sae_config_class_for_architecture(
            renamed_cfg_dict["architecture"]
        )
        sae_cfg = sae_config_cls.from_dict(renamed_cfg_dict)
        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture)
        sae = sae_cls(sae_cfg)
        sae.process_state_dict_for_loading(state_dict)
        sae.load_state_dict(state_dict)

        # Apply normalization if needed
        if renamed_cfg_dict.get("normalize_activations") == "expected_average_only_in":
            norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
            if norm_scaling_factor is not None:
                sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
                renamed_cfg_dict["normalize_activations"] = "none"
            else:
                warnings.warn(
                    f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
                )

        return sae, renamed_cfg_dict, log_sparsities

    @classmethod
    def from_dict(cls: Type[T], config_dict: dict[str, Any]) -> T:
        """Create an SAE from a config dictionary."""
        sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
        sae_config_cls = cls.get_sae_config_class_for_architecture(
            config_dict["architecture"]
        )
        return sae_cls(sae_config_cls.from_dict(config_dict))

    @classmethod
    def get_sae_class_for_architecture(cls: Type[T], architecture: str) -> Type[T]:
        """Get the SAE class for a given architecture."""
        sae_cls = get_sae_class(architecture)
        if not issubclass(sae_cls, cls):
            raise ValueError(
                f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
            )
        return sae_cls

    # in the future, this can be used to load different config classes for different architectures
    @classmethod
    def get_sae_config_class_for_architecture(
        cls: Type[T],
        architecture: str,  # noqa: ARG003
    ) -> type[SAEConfig]:
        return SAEConfig


@dataclass(kw_only=True)
class TrainingSAEConfig(SAEConfig):
    # Sparsity Loss Calculations
    l1_coefficient: float
    lp_norm: float
    use_ghost_grads: bool
    normalize_sae_decoder: bool
    noise_scale: float
    decoder_orthogonal_init: bool
    mse_loss_normalization: str | None
    jumprelu_init_threshold: float
    jumprelu_bandwidth: float
    decoder_heuristic_init: bool
    init_encoder_as_decoder_transpose: bool
    scale_sparsity_penalty_by_decoder_norm: bool

    @classmethod
    def from_sae_runner_config(
        cls, cfg: LanguageModelSAERunnerConfig
    ) -> "TrainingSAEConfig":
        return cls(
            # base config
            architecture=cfg.architecture,
            d_in=cfg.d_in,
            d_sae=cfg.d_sae,  # type: ignore
            dtype=cfg.dtype,
            device=cfg.device,
            model_name=cfg.model_name,
            hook_name=cfg.hook_name,
            hook_layer=cfg.hook_layer,
            hook_head_index=cfg.hook_head_index,
            activation_fn=cfg.activation_fn,
            activation_fn_kwargs=cfg.activation_fn_kwargs,
            apply_b_dec_to_input=cfg.apply_b_dec_to_input,
            finetuning_scaling_factor=cfg.finetuning_method is not None,
            sae_lens_training_version=cfg.sae_lens_training_version,
            context_size=cfg.context_size,
            dataset_path=cfg.dataset_path,
            prepend_bos=cfg.prepend_bos,
            seqpos_slice=tuple(x for x in cfg.seqpos_slice if x is not None)
            if cfg.seqpos_slice is not None
            else None,
            # Training cfg
            l1_coefficient=cfg.l1_coefficient,
            lp_norm=cfg.lp_norm,
            use_ghost_grads=cfg.use_ghost_grads,
            normalize_sae_decoder=cfg.normalize_sae_decoder,
            noise_scale=cfg.noise_scale,
            decoder_orthogonal_init=cfg.decoder_orthogonal_init,
            mse_loss_normalization=cfg.mse_loss_normalization,
            decoder_heuristic_init=cfg.decoder_heuristic_init,
            init_encoder_as_decoder_transpose=cfg.init_encoder_as_decoder_transpose,
            scale_sparsity_penalty_by_decoder_norm=cfg.scale_sparsity_penalty_by_decoder_norm,
            normalize_activations=cfg.normalize_activations,
            dataset_trust_remote_code=cfg.dataset_trust_remote_code,
            model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
            jumprelu_init_threshold=cfg.jumprelu_init_threshold,
            jumprelu_bandwidth=cfg.jumprelu_bandwidth,
        )

    @classmethod
    def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig":
        # remove any keys that are not in the dataclass
        # since we sometimes enhance the config with the whole LM runner config
        valid_field_names = {field.name for field in fields(cls)}
        valid_config_dict = {
            key: val for key, val in config_dict.items() if key in valid_field_names
        }

        # ensure seqpos slice is tuple
        # ensure that seqpos slices is a tuple
        # Ensure seqpos_slice is a tuple
        if "seqpos_slice" in valid_config_dict:
            if isinstance(valid_config_dict["seqpos_slice"], list):
                valid_config_dict["seqpos_slice"] = tuple(
                    valid_config_dict["seqpos_slice"]
                )
            elif not isinstance(valid_config_dict["seqpos_slice"], tuple):
                valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],)

        return TrainingSAEConfig(**valid_config_dict)

    def to_dict(self) -> dict[str, Any]:
        return {
            **super().to_dict(),
            "l1_coefficient": self.l1_coefficient,
            "lp_norm": self.lp_norm,
            "use_ghost_grads": self.use_ghost_grads,
            "normalize_sae_decoder": self.normalize_sae_decoder,
            "noise_scale": self.noise_scale,
            "decoder_orthogonal_init": self.decoder_orthogonal_init,
            "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
            "mse_loss_normalization": self.mse_loss_normalization,
            "decoder_heuristic_init": self.decoder_heuristic_init,
            "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
            "normalize_activations": self.normalize_activations,
            "jumprelu_init_threshold": self.jumprelu_init_threshold,
            "jumprelu_bandwidth": self.jumprelu_bandwidth,
        }

    # this needs to exist so we can initialize the parent sae cfg without the training specific
    # parameters. Maybe there's a cleaner way to do this
    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
        return {
            "architecture": self.architecture,
            "d_in": self.d_in,
            "d_sae": self.d_sae,
            "activation_fn": self.activation_fn,
            "activation_fn_kwargs": self.activation_fn_kwargs,
            "apply_b_dec_to_input": self.apply_b_dec_to_input,
            "dtype": self.dtype,
            "model_name": self.model_name,
            "hook_name": self.hook_name,
            "hook_layer": self.hook_layer,
            "hook_head_index": self.hook_head_index,
            "device": self.device,
            "context_size": self.context_size,
            "prepend_bos": self.prepend_bos,
            "finetuning_scaling_factor": self.finetuning_scaling_factor,
            "normalize_activations": self.normalize_activations,
            "dataset_path": self.dataset_path,
            "dataset_trust_remote_code": self.dataset_trust_remote_code,
            "sae_lens_training_version": self.sae_lens_training_version,
            "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
            "seqpos_slice": self.seqpos_slice,
            "neuronpedia_id": self.neuronpedia_id,
        }


class TrainingSAE(SAE, ABC):
    """Abstract base class for training versions of SAEs."""

    cfg: "TrainingSAEConfig"  # type: ignore

    def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

        # Turn off hook_z reshaping for training mode - the activation store
        # is expected to handle reshaping before passing data to the SAE
        self.turn_off_forward_pass_hook_z_reshaping()

        self.mse_loss_fn = self._get_mse_loss_fn()

    @abstractmethod
    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        """Encode with access to pre-activation values for training."""
        pass

    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        For inference, just encode without returning hidden_pre.
        (training_forward_pass calls encode_with_hidden_pre).
        """
        feature_acts, _ = self.encode_with_hidden_pre(x)
        return feature_acts

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Decodes feature activations back into input space,
        applying optional finetuning scale, hooking, out normalization, etc.
        """
        scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

    @abstractmethod
    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Calculate architecture-specific auxiliary loss terms."""
        pass

    def training_forward_pass(
        self,
        step_input: TrainStepInput,
    ) -> TrainStepOutput:
        """Forward pass during training."""
        feature_acts, hidden_pre = self.encode_with_hidden_pre(step_input.sae_in)
        sae_out = self.decode(feature_acts)

        # Calculate MSE loss
        per_item_mse_loss = self.mse_loss_fn(sae_out, step_input.sae_in)
        mse_loss = per_item_mse_loss.sum(dim=-1).mean()

        # Calculate architecture-specific auxiliary losses
        aux_losses = self.calculate_aux_loss(
            step_input=step_input,
            feature_acts=feature_acts,
            hidden_pre=hidden_pre,
            sae_out=sae_out,
        )

        # Total loss is MSE plus all auxiliary losses
        total_loss = mse_loss

        # Create losses dictionary with mse_loss
        losses = {"mse_loss": mse_loss}

        # Add architecture-specific losses to the dictionary
        # Make sure aux_losses is a dictionary with string keys and tensor values
        if isinstance(aux_losses, dict):
            losses.update(aux_losses)

        # Sum all losses for total_loss
        if isinstance(aux_losses, dict):
            for loss_value in aux_losses.values():
                total_loss = total_loss + loss_value
        else:
            # Handle case where aux_losses is a tensor
            total_loss = total_loss + aux_losses

        return TrainStepOutput(
            sae_in=step_input.sae_in,
            sae_out=sae_out,
            feature_acts=feature_acts,
            hidden_pre=hidden_pre,
            loss=total_loss,
            losses=losses,
        )

    def _get_mse_loss_fn(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
        """Get the MSE loss function based on config."""

        def standard_mse_loss_fn(
            preds: torch.Tensor, target: torch.Tensor
        ) -> torch.Tensor:
            return torch.nn.functional.mse_loss(preds, target, reduction="none")

        def batch_norm_mse_loss_fn(
            preds: torch.Tensor, target: torch.Tensor
        ) -> torch.Tensor:
            target_centered = target - target.mean(dim=0, keepdim=True)
            normalization = target_centered.norm(dim=-1, keepdim=True)
            return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
                normalization + 1e-6
            )

        if self.cfg.mse_loss_normalization == "dense_batch":
            return batch_norm_mse_loss_fn
        return standard_mse_loss_fn

    @torch.no_grad()
    def remove_gradient_parallel_to_decoder_directions(self) -> None:
        """Remove gradient components parallel to decoder directions."""
        # Implement the original logic since this may not be in the base class
        assert self.W_dec.grad is not None

        parallel_component = einops.einsum(
            self.W_dec.grad,
            self.W_dec.data,
            "d_sae d_in, d_sae d_in -> d_sae",
        )
        self.W_dec.grad -= einops.einsum(
            parallel_component,
            self.W_dec.data,
            "d_sae, d_sae d_in -> d_sae d_in",
        )

    @torch.no_grad()
    def set_decoder_norm_to_unit_norm(self):
        """Normalize decoder columns to unit norm."""
        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)

    @torch.no_grad()
    def log_histograms(self) -> dict[str, NDArray[Any]]:
        """Log histograms of the weights and biases."""
        W_dec_norm_dist = self.W_dec.detach().float().norm(dim=1).cpu().numpy()
        return {
            "weights/W_dec_norms": W_dec_norm_dist,
        }

    @classmethod
    def get_sae_class_for_architecture(cls: Type[T], architecture: str) -> Type[T]:
        """Get the SAE class for a given architecture."""
        sae_cls = get_sae_training_class(architecture)
        if not issubclass(sae_cls, cls):
            raise ValueError(
                f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
            )
        return sae_cls

    # in the future, this can be used to load different config classes for different architectures
    @classmethod
    def get_sae_config_class_for_architecture(
        cls: Type[T],
        architecture: str,  # noqa: ARG003
    ) -> type[SAEConfig]:
        return TrainingSAEConfig


_blank_hook = nn.Identity()


@contextmanager
def _disable_hooks(sae: SAE):
    """
    Temporarily disable hooks for the SAE. Swaps out all the hooks with a fake modules that does nothing.
    """
    try:
        for hook_name in sae.hook_dict:
            setattr(sae, hook_name, _blank_hook)
        yield
    finally:
        for hook_name, hook in sae.hook_dict.items():
            setattr(sae, hook_name, hook)