sae-lens 5.10.7__py3-none-any.whl → 6.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +60 -7
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +16 -14
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +170 -257
- sae_lens/constants.py +21 -0
- sae_lens/evals.py +59 -44
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +53 -5
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +228 -32
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +254 -0
- sae_lens/saes/jumprelu_sae.py +348 -0
- sae_lens/saes/sae.py +1076 -0
- sae_lens/saes/standard_sae.py +178 -0
- sae_lens/saes/topk_sae.py +300 -0
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +103 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +155 -177
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +13 -7
- sae_lens/util.py +47 -0
- {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
- sae_lens-6.0.0.dist-info/RECORD +37 -0
- sae_lens/sae.py +0 -747
- sae_lens/sae_training_runner.py +0 -251
- sae_lens/training/geometric_median.py +0 -101
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.10.7.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
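The listing above shows that 6.0.0 reorganises the package: the monolithic sae_lens/sae.py and sae_lens/training/training_sae.py are removed in favour of the new sae_lens/saes/ package, sae_lens/toolkit/ is renamed to sae_lens/loading/, and sae_training_runner.py is replaced by llm_sae_training_runner.py. A minimal import-migration sketch, assuming only the file moves shown in the listing (the updated sae_lens/__init__.py may also re-export these names at the top level):

# 5.10.7 (old module paths; shown for orientation only)
# from sae_lens.sae import SAE
# from sae_lens.toolkit import pretrained_sae_loaders  # toolkit/ is now loading/

# 6.0.0 (new module paths, per the diff below)
from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
from sae_lens.loading.pretrained_sae_loaders import sae_lens_disk_loader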
sae_lens/saes/sae.py
ADDED
@@ -0,0 +1,1076 @@
"""Base classes for Sparse Autoencoders (SAEs)."""

import copy
import json
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import asdict, dataclass, field, fields, replace
from pathlib import Path
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Generic,
    Literal,
    NamedTuple,
    Type,
    TypeVar,
)

import einops
import torch
from jaxtyping import Float
from numpy.typing import NDArray
from safetensors.torch import save_file
from torch import nn
from transformer_lens.hook_points import HookedRootModule, HookPoint
from typing_extensions import deprecated, overload, override

from sae_lens import __version__, logger
from sae_lens.constants import (
    DTYPE_MAP,
    SAE_CFG_FILENAME,
    SAE_WEIGHTS_FILENAME,
)
from sae_lens.util import filter_valid_dataclass_fields

if TYPE_CHECKING:
    from sae_lens.config import LanguageModelSAERunnerConfig

from sae_lens.loading.pretrained_sae_loaders import (
    NAMED_PRETRAINED_SAE_LOADERS,
    PretrainedSaeDiskLoader,
    PretrainedSaeHuggingfaceLoader,
    get_conversion_loader_name,
    handle_config_defaulting,
    sae_lens_disk_loader,
)
from sae_lens.loading.pretrained_saes_directory import (
    get_config_overrides,
    get_norm_scaling_factor,
    get_pretrained_saes_directory,
    get_repo_id_and_folder_name,
)
from sae_lens.registry import get_sae_class, get_sae_training_class

T_SAE_CONFIG = TypeVar("T_SAE_CONFIG", bound="SAEConfig")
T_TRAINING_SAE_CONFIG = TypeVar("T_TRAINING_SAE_CONFIG", bound="TrainingSAEConfig")
T_SAE = TypeVar("T_SAE", bound="SAE")  # type: ignore
T_TRAINING_SAE = TypeVar("T_TRAINING_SAE", bound="TrainingSAE")  # type: ignore


class SAEMetadata:
    """Core metadata about how this SAE should be used, if known."""

    def __init__(self, **kwargs: Any):
        # Set default version fields with their current behavior
        self.sae_lens_version = kwargs.pop("sae_lens_version", __version__)
        self.sae_lens_training_version = kwargs.pop(
            "sae_lens_training_version", __version__
        )

        # Set all other attributes dynamically
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __getattr__(self, name: str) -> None:
        """Return None for any missing attribute (like defaultdict)"""
        return

    def __setattr__(self, name: str, value: Any) -> None:
        """Allow setting any attribute"""
        super().__setattr__(name, value)

    def __getitem__(self, key: str) -> Any:
        """Allow dictionary-style access: metadata['key'] - returns None for missing keys"""
        return getattr(self, key)

    def __setitem__(self, key: str, value: Any) -> None:
        """Allow dictionary-style assignment: metadata['key'] = value"""
        setattr(self, key, value)

    def __contains__(self, key: str) -> bool:
        """Allow 'in' operator: 'key' in metadata"""
        # Only return True if the attribute was explicitly set (not just defaulting to None)
        return key in self.__dict__

    def get(self, key: str, default: Any = None) -> Any:
        """Dictionary-style get with default"""
        value = getattr(self, key)
        # If the attribute wasn't explicitly set and we got None from __getattr__,
        # use the provided default instead
        if key not in self.__dict__ and value is None:
            return default
        return value

    def keys(self):
        """Return all explicitly set attribute names"""
        return self.__dict__.keys()

    def values(self):
        """Return all explicitly set attribute values"""
        return self.__dict__.values()

    def items(self):
        """Return all explicitly set attribute name-value pairs"""
        return self.__dict__.items()

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization"""
        return self.__dict__.copy()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SAEMetadata":
        """Create from dictionary"""
        return cls(**data)

    def __repr__(self) -> str:
        return f"SAEMetadata({self.__dict__})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SAEMetadata):
            return False
        return self.__dict__ == other.__dict__

    def __deepcopy__(self, memo: dict[int, Any]) -> "SAEMetadata":
        """Support for deep copying"""
        return SAEMetadata(**copy.deepcopy(self.__dict__, memo))

    def __getstate__(self) -> dict[str, Any]:
        """Support for pickling"""
        return self.__dict__

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Support for unpickling"""
        self.__dict__.update(state)


@dataclass
class SAEConfig(ABC):
    """Base configuration for SAE models."""

    d_in: int
    d_sae: int
    dtype: str = "float32"
    device: str = "cpu"
    apply_b_dec_to_input: bool = True
    normalize_activations: Literal[
        "none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"
    ] = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
    reshape_activations: Literal["none", "hook_z"] = "none"
    metadata: SAEMetadata = field(default_factory=SAEMetadata)

    @classmethod
    @abstractmethod
    def architecture(cls) -> str: ...

    def to_dict(self) -> dict[str, Any]:
        res = {field.name: getattr(self, field.name) for field in fields(self)}
        res["metadata"] = self.metadata.to_dict()
        res["architecture"] = self.architecture()
        return res

    @classmethod
    def from_dict(cls: type[T_SAE_CONFIG], config_dict: dict[str, Any]) -> T_SAE_CONFIG:
        cfg_class = get_sae_class(config_dict["architecture"])[1]
        filtered_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
        res = cfg_class(**filtered_config_dict)
        if "metadata" in config_dict:
            res.metadata = SAEMetadata(**config_dict["metadata"])
        if not isinstance(res, cls):
            raise ValueError(
                f"SAE config class {cls} does not match dict config class {type(res)}"
            )
        return res

    def __post_init__(self):
        if self.normalize_activations not in [
            "none",
            "expected_average_only_in",
            "constant_norm_rescale",
            "layer_norm",
        ]:
            raise ValueError(
                f"normalize_activations must be none, expected_average_only_in, layer_norm, or constant_norm_rescale. Got {self.normalize_activations}"
            )


@dataclass
class TrainStepOutput:
    """Output from a training step."""

    sae_in: torch.Tensor
    sae_out: torch.Tensor
    feature_acts: torch.Tensor
    hidden_pre: torch.Tensor
    loss: torch.Tensor  # we need to call backwards on this
    losses: dict[str, torch.Tensor]


@dataclass
class TrainStepInput:
    """Input to a training step."""

    sae_in: torch.Tensor
    coefficients: dict[str, float]
    dead_neuron_mask: torch.Tensor | None


class TrainCoefficientConfig(NamedTuple):
    value: float
    warm_up_steps: int


class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
    """Abstract base class for all SAE architectures."""

    cfg: T_SAE_CONFIG
    dtype: torch.dtype
    device: torch.device
    use_error_term: bool

    # For type checking only - don't provide default values
    # These will be initialized by subclasses
    W_enc: nn.Parameter
    W_dec: nn.Parameter
    b_dec: nn.Parameter

    def __init__(self, cfg: T_SAE_CONFIG, use_error_term: bool = False):
        """Initialize the SAE."""
        super().__init__()

        self.cfg = cfg

        if cfg.metadata and cfg.metadata.model_from_pretrained_kwargs:
            warnings.warn(
                "\nThis SAE has non-empty model_from_pretrained_kwargs. "
                "\nFor optimal performance, load the model like so:\n"
                "model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)",
                category=UserWarning,
                stacklevel=1,
            )

        self.dtype = DTYPE_MAP[cfg.dtype]
        self.device = torch.device(cfg.device)
        self.use_error_term = use_error_term

        # Set up activation function
        self.activation_fn = self.get_activation_fn()

        # Initialize weights
        self.initialize_weights()

        # Set up hooks
        self.hook_sae_input = HookPoint()
        self.hook_sae_acts_pre = HookPoint()
        self.hook_sae_acts_post = HookPoint()
        self.hook_sae_output = HookPoint()
        self.hook_sae_recons = HookPoint()
        self.hook_sae_error = HookPoint()

        # handle hook_z reshaping if needed.
        if self.cfg.reshape_activations == "hook_z":
            self.turn_on_forward_pass_hook_z_reshaping()
        else:
            self.turn_off_forward_pass_hook_z_reshaping()

        # Set up activation normalization
        self._setup_activation_normalization()

        self.setup()  # Required for HookedRootModule

    @torch.no_grad()
    def fold_activation_norm_scaling_factor(self, scaling_factor: float):
        self.W_enc.data *= scaling_factor  # type: ignore
        self.W_dec.data /= scaling_factor  # type: ignore
        self.b_dec.data /= scaling_factor  # type: ignore
        self.cfg.normalize_activations = "none"

    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
        """Get the activation function specified in config."""
        return nn.ReLU()

    def _setup_activation_normalization(self):
        """Set up activation normalization functions based on config."""
        if self.cfg.normalize_activations == "constant_norm_rescale":

            def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
                self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
                return x * self.x_norm_coeff

            def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
                x = x / self.x_norm_coeff  # type: ignore
                del self.x_norm_coeff
                return x

            self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
            self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
        elif self.cfg.normalize_activations == "layer_norm":
            # we need to scale the norm of the input and store the scaling factor
            def run_time_activation_ln_in(
                x: torch.Tensor, eps: float = 1e-5
            ) -> torch.Tensor:
                mu = x.mean(dim=-1, keepdim=True)
                x = x - mu
                std = x.std(dim=-1, keepdim=True)
                x = x / (std + eps)
                self.ln_mu = mu
                self.ln_std = std
                return x

            def run_time_activation_ln_out(
                x: torch.Tensor,
                eps: float = 1e-5,  # noqa: ARG001
            ) -> torch.Tensor:
                return x * self.ln_std + self.ln_mu  # type: ignore

            self.run_time_activation_norm_fn_in = run_time_activation_ln_in
            self.run_time_activation_norm_fn_out = run_time_activation_ln_out
        else:
            self.run_time_activation_norm_fn_in = lambda x: x
            self.run_time_activation_norm_fn_out = lambda x: x

    def initialize_weights(self):
        """Initialize model weights."""
        self.b_dec = nn.Parameter(
            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
        )

        w_dec_data = torch.empty(
            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
        )
        nn.init.kaiming_uniform_(w_dec_data)
        self.W_dec = nn.Parameter(w_dec_data)

        w_enc_data = self.W_dec.data.T.clone().detach().contiguous()
        self.W_enc = nn.Parameter(w_enc_data)

    @abstractmethod
    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """Encode input tensor to feature space."""
        pass

    @abstractmethod
    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """Decode feature activations back to input space."""
        pass

    def turn_on_forward_pass_hook_z_reshaping(self):
        if (
            self.cfg.metadata.hook_name is not None
            and not self.cfg.metadata.hook_name.endswith("_z")
        ):
            raise ValueError("This method should only be called for hook_z SAEs.")

        # print(f"Turning on hook_z reshaping for {self.cfg.hook_name}")

        def reshape_fn_in(x: torch.Tensor):
            # print(f"reshape_fn_in input shape: {x.shape}")
            self.d_head = x.shape[-1]
            # print(f"Setting d_head to: {self.d_head}")
            self.reshape_fn_in = lambda x: einops.rearrange(
                x, "... n_heads d_head -> ... (n_heads d_head)"
            )
            return einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")

        self.reshape_fn_in = reshape_fn_in
        self.reshape_fn_out = lambda x, d_head: einops.rearrange(
            x, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
        )
        self.hook_z_reshaping_mode = True
        # print(f"hook_z reshaping turned on, self.d_head={getattr(self, 'd_head', None)}")

    def turn_off_forward_pass_hook_z_reshaping(self):
        self.reshape_fn_in = lambda x: x
        self.reshape_fn_out = lambda x, d_head: x  # noqa: ARG005
        self.d_head = None
        self.hook_z_reshaping_mode = False

    @overload
    def to(
        self: T_SAE,
        device: torch.device | str | None = ...,
        dtype: torch.dtype | None = ...,
        non_blocking: bool = ...,
    ) -> T_SAE: ...

    @overload
    def to(self: T_SAE, dtype: torch.dtype, non_blocking: bool = ...) -> T_SAE: ...

    @overload
    def to(self: T_SAE, tensor: torch.Tensor, non_blocking: bool = ...) -> T_SAE: ...

    def to(self: T_SAE, *args: Any, **kwargs: Any) -> T_SAE:  # type: ignore
        device_arg = None
        dtype_arg = None

        # Check args
        for arg in args:
            if isinstance(arg, (torch.device, str)):
                device_arg = arg
            elif isinstance(arg, torch.dtype):
                dtype_arg = arg
            elif isinstance(arg, torch.Tensor):
                device_arg = arg.device
                dtype_arg = arg.dtype

        # Check kwargs
        device_arg = kwargs.get("device", device_arg)
        dtype_arg = kwargs.get("dtype", dtype_arg)

        # Update device in config if provided
        if device_arg is not None:
            # Convert device to torch.device if it's a string
            device = (
                torch.device(device_arg) if isinstance(device_arg, str) else device_arg
            )

            # Update the cfg.device
            self.cfg.device = str(device)

            # Update the device property
            self.device = device

        # Update dtype in config if provided
        if dtype_arg is not None:
            # Update the cfg.dtype
            self.cfg.dtype = str(dtype_arg)

            # Update the dtype property
            self.dtype = dtype_arg

        return super().to(*args, **kwargs)

    def process_sae_in(
        self, sae_in: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_in"]:
        # print(f"Input shape to process_sae_in: {sae_in.shape}")
        # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
        # print(f"self.b_dec shape: {self.b_dec.shape}")
        # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")

        sae_in = sae_in.to(self.dtype)

        # print(f"Shape before reshape_fn_in: {sae_in.shape}")
        sae_in = self.reshape_fn_in(sae_in)
        # print(f"Shape after reshape_fn_in: {sae_in.shape}")

        sae_in = self.hook_sae_input(sae_in)
        sae_in = self.run_time_activation_norm_fn_in(sae_in)

        # Here's where the error happens
        bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
        # print(f"Bias term shape: {bias_term.shape}")

        return sae_in - bias_term

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the SAE."""
        feature_acts = self.encode(x)
        sae_out = self.decode(feature_acts)

        if self.use_error_term:
            with torch.no_grad():
                # Recompute without hooks for true error term
                with _disable_hooks(self):
                    feature_acts_clean = self.encode(x)
                    x_reconstruct_clean = self.decode(feature_acts_clean)
                sae_error = self.hook_sae_error(x - x_reconstruct_clean)
            sae_out = sae_out + sae_error

        return self.hook_sae_output(sae_out)

    # overwrite this in subclasses to modify the state_dict in-place before saving
    def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
        pass

    # overwrite this in subclasses to modify the state_dict in-place after loading
    def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
        pass

    @torch.no_grad()
    def fold_W_dec_norm(self):
        """Fold decoder norms into encoder."""
        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
        self.W_dec.data = self.W_dec.data / W_dec_norms
        self.W_enc.data = self.W_enc.data * W_dec_norms.T

        # Only update b_enc if it exists (standard/jumprelu architectures)
        if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter):
            self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()

    def get_name(self):
        """Generate a name for this SAE."""
        return f"sae_{self.cfg.metadata.model_name}_{self.cfg.metadata.hook_name}_{self.cfg.d_sae}"

    def save_model(self, path: str | Path) -> tuple[Path, Path]:
        """Save model weights and config to disk."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Generate the weights
        state_dict = self.state_dict()  # Use internal SAE state dict
        self.process_state_dict_for_saving(state_dict)
        model_weights_path = path / SAE_WEIGHTS_FILENAME
        save_file(state_dict, model_weights_path)

        # Save the config
        config = self.cfg.to_dict()
        cfg_path = path / SAE_CFG_FILENAME
        with open(cfg_path, "w") as f:
            json.dump(config, f)

        return model_weights_path, cfg_path

    ## Initialization Methods
    @torch.no_grad()
    def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
        out = torch.tensor(origin, dtype=self.dtype, device=self.device)
        self.b_dec.data = out

    @torch.no_grad()
    def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
        previous_b_dec = self.b_dec.clone().cpu()
        out = all_activations.mean(dim=0)

        previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
        distances = torch.norm(all_activations - out, dim=-1)

        logger.info("Reinitializing b_dec with mean of activations")
        logger.debug(
            f"Previous distances: {previous_distances.median(0).values.mean().item()}"
        )
        logger.debug(f"New distances: {distances.median(0).values.mean().item()}")

        self.b_dec.data = out.to(self.dtype).to(self.device)

    # Class methods for loading models
    @classmethod
    @deprecated("Use load_from_disk instead")
    def load_from_pretrained(
        cls: Type[T_SAE],
        path: str | Path,
        device: str = "cpu",
        dtype: str | None = None,
    ) -> T_SAE:
        return cls.load_from_disk(path, device=device, dtype=dtype)

    @classmethod
    def load_from_disk(
        cls: Type[T_SAE],
        path: str | Path,
        device: str = "cpu",
        dtype: str | None = None,
        converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
    ) -> T_SAE:
        overrides = {"dtype": dtype} if dtype is not None else None
        cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
        cfg_dict = handle_config_defaulting(cfg_dict)
        sae_config_cls = cls.get_sae_config_class_for_architecture(
            cfg_dict["architecture"]
        )
        sae_cfg = sae_config_cls.from_dict(cfg_dict)
        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
        sae = sae_cls(sae_cfg)
        sae.process_state_dict_for_loading(state_dict)
        sae.load_state_dict(state_dict)
        return sae

    @classmethod
    def from_pretrained(
        cls: Type[T_SAE],
        release: str,
        sae_id: str,
        device: str = "cpu",
        force_download: bool = False,
        converter: PretrainedSaeHuggingfaceLoader | None = None,
    ) -> T_SAE:
        """
        Load a pretrained SAE from the Hugging Face model hub.

        Args:
            release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
            id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
            device: The device to load the SAE on.
        """
        return cls.from_pretrained_with_cfg_and_sparsity(
            release, sae_id, device, force_download, converter=converter
        )[0]

    @classmethod
    def from_pretrained_with_cfg_and_sparsity(
        cls: Type[T_SAE],
        release: str,
        sae_id: str,
        device: str = "cpu",
        force_download: bool = False,
        converter: PretrainedSaeHuggingfaceLoader | None = None,
    ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
        """
        Load a pretrained SAE from the Hugging Face model hub, along with its config dict and sparsity, if present.
        In SAELens <= 5.x.x, this was called SAE.from_pretrained().

        Args:
            release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
            id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
            device: The device to load the SAE on.
        """

        # get sae directory
        sae_directory = get_pretrained_saes_directory()

        # Validate release and sae_id
        if release not in sae_directory:
            if "/" not in release:
                raise ValueError(
                    f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
                )
        elif sae_id not in sae_directory[release].saes_map:
            # Handle special cases like Gemma Scope
            if (
                "gemma-scope" in release
                and "canonical" not in release
                and f"{release}-canonical" in sae_directory
            ):
                canonical_ids = list(
                    sae_directory[release + "-canonical"].saes_map.keys()
                )
                # Shorten the lengthy string of valid IDs
                if len(canonical_ids) > 5:
                    str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
                else:
                    str_canonical_ids = str(canonical_ids)
                value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
            else:
                value_suffix = ""

            valid_ids = list(sae_directory[release].saes_map.keys())
            # Shorten the lengthy string of valid IDs
            if len(valid_ids) > 5:
                str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
            else:
                str_valid_ids = str(valid_ids)

            raise ValueError(
                f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
                + value_suffix
            )

        conversion_loader = (
            converter
            or NAMED_PRETRAINED_SAE_LOADERS[get_conversion_loader_name(release)]
        )
        repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
        config_overrides = get_config_overrides(release, sae_id)
        config_overrides["device"] = device

        # Load config and weights
        cfg_dict, state_dict, log_sparsities = conversion_loader(
            repo_id=repo_id,
            folder_name=folder_name,
            device=device,
            force_download=force_download,
            cfg_overrides=config_overrides,
        )
        cfg_dict = handle_config_defaulting(cfg_dict)

        # Create SAE with appropriate architecture
        sae_config_cls = cls.get_sae_config_class_for_architecture(
            cfg_dict["architecture"]
        )
        sae_cfg = sae_config_cls.from_dict(cfg_dict)
        sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
        sae = sae_cls(sae_cfg)
        sae.process_state_dict_for_loading(state_dict)
        sae.load_state_dict(state_dict)

        # Apply normalization if needed
        if cfg_dict.get("normalize_activations") == "expected_average_only_in":
            norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
            if norm_scaling_factor is not None:
                sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
                cfg_dict["normalize_activations"] = "none"
            else:
                warnings.warn(
                    f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
                )

        return sae, cfg_dict, log_sparsities

    @classmethod
    def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
        """Create an SAE from a config dictionary."""
        sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
        sae_config_cls = cls.get_sae_config_class_for_architecture(
            config_dict["architecture"]
        )
        return sae_cls(sae_config_cls.from_dict(config_dict))

    @classmethod
    def get_sae_class_for_architecture(
        cls: Type[T_SAE], architecture: str
    ) -> Type[T_SAE]:
        """Get the SAE class for a given architecture."""
        sae_cls, _ = get_sae_class(architecture)
        if not issubclass(sae_cls, cls):
            raise ValueError(
                f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
            )
        return sae_cls

    # in the future, this can be used to load different config classes for different architectures
    @classmethod
    def get_sae_config_class_for_architecture(
        cls,
        architecture: str,  # noqa: ARG003
    ) -> type[SAEConfig]:
        return SAEConfig

    ### Methods to support deprecated usage of SAE.from_pretrained() ###

    def __getitem__(self, index: int) -> Any:
        """
        Support indexing for backward compatibility with tuple unpacking.
        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
        """
        warnings.warn(
            "Indexing SAE objects is deprecated. SAE.from_pretrained() now returns "
            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
            "to get the config dict and sparsity as well.",
            DeprecationWarning,
            stacklevel=2,
        )

        if index == 0:
            return self
        if index == 1:
            return self.cfg.to_dict()
        if index == 2:
            return None
        raise IndexError(f"SAE tuple index {index} out of range")

    def __iter__(self):
        """
        Support unpacking for backward compatibility with tuple unpacking.
        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
        """
        warnings.warn(
            "Unpacking SAE objects is deprecated. SAE.from_pretrained() now returns "
            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
            "to get the config dict and sparsity as well.",
            DeprecationWarning,
            stacklevel=2,
        )

        yield self
        yield self.cfg.to_dict()
        yield None

    def __len__(self) -> int:
        """
        Support len() for backward compatibility with tuple unpacking.
        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
        """
        warnings.warn(
            "Getting length of SAE objects is deprecated. SAE.from_pretrained() now returns "
            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
            "to get the config dict and sparsity as well.",
            DeprecationWarning,
            stacklevel=2,
        )

        return 3


@dataclass(kw_only=True)
class TrainingSAEConfig(SAEConfig, ABC):
    # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
    # 0.1 corresponds to the "heuristic" initialization, use None to disable
    decoder_init_norm: float | None = 0.1

    @classmethod
    @abstractmethod
    def architecture(cls) -> str: ...

    @classmethod
    def from_sae_runner_config(
        cls: type[T_TRAINING_SAE_CONFIG],
        cfg: "LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]",
    ) -> T_TRAINING_SAE_CONFIG:
        metadata = SAEMetadata(
            model_name=cfg.model_name,
            hook_name=cfg.hook_name,
            hook_head_index=cfg.hook_head_index,
            context_size=cfg.context_size,
            prepend_bos=cfg.prepend_bos,
            seqpos_slice=cfg.seqpos_slice,
            model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
        )
        if not isinstance(cfg.sae, cls):
            raise ValueError(
                f"SAE config class {cls} does not match SAE runner config class {type(cfg.sae)}"
            )
        return replace(cfg.sae, metadata=metadata)

    @classmethod
    def from_dict(
        cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
    ) -> T_TRAINING_SAE_CONFIG:
        cfg_class = cls
        if "architecture" in config_dict:
            cfg_class = get_sae_training_class(config_dict["architecture"])[1]
        if not issubclass(cfg_class, cls):
            raise ValueError(
                f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
            )
        # remove any keys that are not in the dataclass
        # since we sometimes enhance the config with the whole LM runner config
        valid_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
        if "metadata" in config_dict:
            valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
        return cfg_class(**valid_config_dict)

    def to_dict(self) -> dict[str, Any]:
        return {
            **super().to_dict(),
            **asdict(self),
            "metadata": self.metadata.to_dict(),
            "architecture": self.architecture(),
        }

    # this needs to exist so we can initialize the parent sae cfg without the training specific
    # parameters. Maybe there's a cleaner way to do this
    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
        """
        Creates a dictionary containing attributes corresponding to the fields
        defined in the base SAEConfig class.
        """
        base_sae_cfg_class = get_sae_class(self.architecture())[1]
        base_config_field_names = {f.name for f in fields(base_sae_cfg_class)}
        result_dict = {
            field_name: getattr(self, field_name)
            for field_name in base_config_field_names
        }
        result_dict["architecture"] = self.architecture()
        result_dict["metadata"] = self.metadata.to_dict()
        return result_dict


class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
    """Abstract base class for training versions of SAEs."""

    def __init__(self, cfg: T_TRAINING_SAE_CONFIG, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

        # Turn off hook_z reshaping for training mode - the activation store
        # is expected to handle reshaping before passing data to the SAE
        self.turn_off_forward_pass_hook_z_reshaping()
        self.mse_loss_fn = mse_loss

    @abstractmethod
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...

    @abstractmethod
    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        """Encode with access to pre-activation values for training."""
        ...

    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        For inference, just encode without returning hidden_pre.
        (training_forward_pass calls encode_with_hidden_pre).
        """
        feature_acts, _ = self.encode_with_hidden_pre(x)
        return feature_acts

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Decodes feature activations back into input space,
        applying optional finetuning scale, hooking, out normalization, etc.
        """
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

    @override
    def initialize_weights(self):
        super().initialize_weights()
        if self.cfg.decoder_init_norm is not None:
            with torch.no_grad():
                self.W_dec.data /= self.W_dec.norm(dim=-1, keepdim=True)
                self.W_dec.data *= self.cfg.decoder_init_norm
                self.W_enc.data = self.W_dec.data.T.clone().detach().contiguous()

    @abstractmethod
    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Calculate architecture-specific auxiliary loss terms."""
        ...

    def training_forward_pass(
        self,
        step_input: TrainStepInput,
    ) -> TrainStepOutput:
        """Forward pass during training."""
        feature_acts, hidden_pre = self.encode_with_hidden_pre(step_input.sae_in)
        sae_out = self.decode(feature_acts)

        # Calculate MSE loss
        per_item_mse_loss = self.mse_loss_fn(sae_out, step_input.sae_in)
        mse_loss = per_item_mse_loss.sum(dim=-1).mean()

        # Calculate architecture-specific auxiliary losses
        aux_losses = self.calculate_aux_loss(
            step_input=step_input,
            feature_acts=feature_acts,
            hidden_pre=hidden_pre,
            sae_out=sae_out,
        )

        # Total loss is MSE plus all auxiliary losses
        total_loss = mse_loss

        # Create losses dictionary with mse_loss
        losses = {"mse_loss": mse_loss}

        # Add architecture-specific losses to the dictionary
        # Make sure aux_losses is a dictionary with string keys and tensor values
        if isinstance(aux_losses, dict):
            losses.update(aux_losses)

        # Sum all losses for total_loss
        if isinstance(aux_losses, dict):
            for loss_value in aux_losses.values():
                total_loss = total_loss + loss_value
        else:
            # Handle case where aux_losses is a tensor
            total_loss = total_loss + aux_losses

        return TrainStepOutput(
            sae_in=step_input.sae_in,
            sae_out=sae_out,
            feature_acts=feature_acts,
            hidden_pre=hidden_pre,
            loss=total_loss,
            losses=losses,
        )

    def save_inference_model(self, path: str | Path) -> tuple[Path, Path]:
        """Save inference version of model weights and config to disk."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Generate the weights
        state_dict = self.state_dict()  # Use internal SAE state dict
        self.process_state_dict_for_saving_inference(state_dict)
        model_weights_path = path / SAE_WEIGHTS_FILENAME
        save_file(state_dict, model_weights_path)

        # Save the config
        config = self.to_inference_config_dict()
        cfg_path = path / SAE_CFG_FILENAME
        with open(cfg_path, "w") as f:
            json.dump(config, f)

        return model_weights_path, cfg_path

    @abstractmethod
    def to_inference_config_dict(self) -> dict[str, Any]:
        """Convert the config into an inference SAE config dict."""
        ...

    def process_state_dict_for_saving_inference(
        self, state_dict: dict[str, Any]
    ) -> None:
        """
        Process the state dict for saving the inference model.
        This is a hook that can be overridden to change how the state dict is processed for the inference model.
        """
        return self.process_state_dict_for_saving(state_dict)

    @torch.no_grad()
    def remove_gradient_parallel_to_decoder_directions(self) -> None:
        """Remove gradient components parallel to decoder directions."""
        # Implement the original logic since this may not be in the base class
        assert self.W_dec.grad is not None

        parallel_component = einops.einsum(
            self.W_dec.grad,
            self.W_dec.data,
            "d_sae d_in, d_sae d_in -> d_sae",
        )
        self.W_dec.grad -= einops.einsum(
            parallel_component,
            self.W_dec.data,
            "d_sae, d_sae d_in -> d_sae d_in",
        )

    @torch.no_grad()
    def log_histograms(self) -> dict[str, NDArray[Any]]:
        """Log histograms of the weights and biases."""
        W_dec_norm_dist = self.W_dec.detach().float().norm(dim=1).cpu().numpy()
        return {
            "weights/W_dec_norms": W_dec_norm_dist,
        }

    @classmethod
    def get_sae_class_for_architecture(
        cls: Type[T_TRAINING_SAE], architecture: str
    ) -> Type[T_TRAINING_SAE]:
        """Get the SAE class for a given architecture."""
        sae_cls, _ = get_sae_training_class(architecture)
        if not issubclass(sae_cls, cls):
            raise ValueError(
                f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
            )
        return sae_cls

    # in the future, this can be used to load different config classes for different architectures
    @classmethod
    def get_sae_config_class_for_architecture(
        cls,
        architecture: str,  # noqa: ARG003
    ) -> type[TrainingSAEConfig]:
        return get_sae_training_class(architecture)[1]


_blank_hook = nn.Identity()


@contextmanager
def _disable_hooks(sae: SAE[Any]):
    """
    Temporarily disable hooks for the SAE. Swaps out all the hooks with a fake modules that does nothing.
    """
    try:
        for hook_name in sae.hook_dict:
            setattr(sae, hook_name, _blank_hook)
        yield
    finally:
        for hook_name, hook in sae.hook_dict.items():
            setattr(sae, hook_name, hook)


def mse_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.mse_loss(preds, target, reduction="none")