sae-lens 5.10.7__py3-none-any.whl → 6.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. sae_lens/__init__.py +60 -7
  2. sae_lens/analysis/hooked_sae_transformer.py +12 -12
  3. sae_lens/analysis/neuronpedia_integration.py +16 -14
  4. sae_lens/cache_activations_runner.py +9 -7
  5. sae_lens/config.py +170 -257
  6. sae_lens/constants.py +21 -0
  7. sae_lens/evals.py +59 -44
  8. sae_lens/llm_sae_training_runner.py +377 -0
  9. sae_lens/load_model.py +53 -5
  10. sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +228 -32
  11. sae_lens/registry.py +49 -0
  12. sae_lens/saes/__init__.py +48 -0
  13. sae_lens/saes/gated_sae.py +254 -0
  14. sae_lens/saes/jumprelu_sae.py +348 -0
  15. sae_lens/saes/sae.py +1076 -0
  16. sae_lens/saes/standard_sae.py +178 -0
  17. sae_lens/saes/topk_sae.py +300 -0
  18. sae_lens/training/activation_scaler.py +53 -0
  19. sae_lens/training/activations_store.py +103 -184
  20. sae_lens/training/mixing_buffer.py +56 -0
  21. sae_lens/training/optim.py +60 -36
  22. sae_lens/training/sae_trainer.py +155 -177
  23. sae_lens/training/types.py +5 -0
  24. sae_lens/training/upload_saes_to_huggingface.py +13 -7
  25. sae_lens/util.py +47 -0
  26. {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
  27. sae_lens-6.0.0.dist-info/RECORD +37 -0
  28. sae_lens/sae.py +0 -747
  29. sae_lens/sae_training_runner.py +0 -251
  30. sae_lens/training/geometric_median.py +0 -101
  31. sae_lens/training/training_sae.py +0 -710
  32. sae_lens-5.10.7.dist-info/RECORD +0 -28
  33. /sae_lens/{toolkit → loading}/__init__.py +0 -0
  34. /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
  35. {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
  36. {sae_lens-5.10.7.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
sae_lens/saes/sae.py ADDED
@@ -0,0 +1,1076 @@
1
+ """Base classes for Sparse Autoencoders (SAEs)."""
2
+
3
+ import copy
4
+ import json
5
+ import warnings
6
+ from abc import ABC, abstractmethod
7
+ from contextlib import contextmanager
8
+ from dataclasses import asdict, dataclass, field, fields, replace
9
+ from pathlib import Path
10
+ from typing import (
11
+ TYPE_CHECKING,
12
+ Any,
13
+ Callable,
14
+ Generic,
15
+ Literal,
16
+ NamedTuple,
17
+ Type,
18
+ TypeVar,
19
+ )
20
+
21
+ import einops
22
+ import torch
23
+ from jaxtyping import Float
24
+ from numpy.typing import NDArray
25
+ from safetensors.torch import save_file
26
+ from torch import nn
27
+ from transformer_lens.hook_points import HookedRootModule, HookPoint
28
+ from typing_extensions import deprecated, overload, override
29
+
30
+ from sae_lens import __version__, logger
31
+ from sae_lens.constants import (
32
+ DTYPE_MAP,
33
+ SAE_CFG_FILENAME,
34
+ SAE_WEIGHTS_FILENAME,
35
+ )
36
+ from sae_lens.util import filter_valid_dataclass_fields
37
+
38
+ if TYPE_CHECKING:
39
+ from sae_lens.config import LanguageModelSAERunnerConfig
40
+
41
+ from sae_lens.loading.pretrained_sae_loaders import (
42
+ NAMED_PRETRAINED_SAE_LOADERS,
43
+ PretrainedSaeDiskLoader,
44
+ PretrainedSaeHuggingfaceLoader,
45
+ get_conversion_loader_name,
46
+ handle_config_defaulting,
47
+ sae_lens_disk_loader,
48
+ )
49
+ from sae_lens.loading.pretrained_saes_directory import (
50
+ get_config_overrides,
51
+ get_norm_scaling_factor,
52
+ get_pretrained_saes_directory,
53
+ get_repo_id_and_folder_name,
54
+ )
55
+ from sae_lens.registry import get_sae_class, get_sae_training_class
56
+
57
+ T_SAE_CONFIG = TypeVar("T_SAE_CONFIG", bound="SAEConfig")
58
+ T_TRAINING_SAE_CONFIG = TypeVar("T_TRAINING_SAE_CONFIG", bound="TrainingSAEConfig")
59
+ T_SAE = TypeVar("T_SAE", bound="SAE") # type: ignore
60
+ T_TRAINING_SAE = TypeVar("T_TRAINING_SAE", bound="TrainingSAE") # type: ignore
61
+
62
+
63
+ class SAEMetadata:
64
+ """Core metadata about how this SAE should be used, if known."""
65
+
66
+ def __init__(self, **kwargs: Any):
67
+ # Set default version fields with their current behavior
68
+ self.sae_lens_version = kwargs.pop("sae_lens_version", __version__)
69
+ self.sae_lens_training_version = kwargs.pop(
70
+ "sae_lens_training_version", __version__
71
+ )
72
+
73
+ # Set all other attributes dynamically
74
+ for key, value in kwargs.items():
75
+ setattr(self, key, value)
76
+
77
+ def __getattr__(self, name: str) -> None:
78
+ """Return None for any missing attribute (like defaultdict)"""
79
+ return
80
+
81
+ def __setattr__(self, name: str, value: Any) -> None:
82
+ """Allow setting any attribute"""
83
+ super().__setattr__(name, value)
84
+
85
+ def __getitem__(self, key: str) -> Any:
86
+ """Allow dictionary-style access: metadata['key'] - returns None for missing keys"""
87
+ return getattr(self, key)
88
+
89
+ def __setitem__(self, key: str, value: Any) -> None:
90
+ """Allow dictionary-style assignment: metadata['key'] = value"""
91
+ setattr(self, key, value)
92
+
93
+ def __contains__(self, key: str) -> bool:
94
+ """Allow 'in' operator: 'key' in metadata"""
95
+ # Only return True if the attribute was explicitly set (not just defaulting to None)
96
+ return key in self.__dict__
97
+
98
+ def get(self, key: str, default: Any = None) -> Any:
99
+ """Dictionary-style get with default"""
100
+ value = getattr(self, key)
101
+ # If the attribute wasn't explicitly set and we got None from __getattr__,
102
+ # use the provided default instead
103
+ if key not in self.__dict__ and value is None:
104
+ return default
105
+ return value
106
+
107
+ def keys(self):
108
+ """Return all explicitly set attribute names"""
109
+ return self.__dict__.keys()
110
+
111
+ def values(self):
112
+ """Return all explicitly set attribute values"""
113
+ return self.__dict__.values()
114
+
115
+ def items(self):
116
+ """Return all explicitly set attribute name-value pairs"""
117
+ return self.__dict__.items()
118
+
119
+ def to_dict(self) -> dict[str, Any]:
120
+ """Convert to dictionary for serialization"""
121
+ return self.__dict__.copy()
122
+
123
+ @classmethod
124
+ def from_dict(cls, data: dict[str, Any]) -> "SAEMetadata":
125
+ """Create from dictionary"""
126
+ return cls(**data)
127
+
128
+ def __repr__(self) -> str:
129
+ return f"SAEMetadata({self.__dict__})"
130
+
131
+ def __eq__(self, other: object) -> bool:
132
+ if not isinstance(other, SAEMetadata):
133
+ return False
134
+ return self.__dict__ == other.__dict__
135
+
136
+ def __deepcopy__(self, memo: dict[int, Any]) -> "SAEMetadata":
137
+ """Support for deep copying"""
138
+
139
+ return SAEMetadata(**copy.deepcopy(self.__dict__, memo))
140
+
141
+ def __getstate__(self) -> dict[str, Any]:
142
+ """Support for pickling"""
143
+ return self.__dict__
144
+
145
+ def __setstate__(self, state: dict[str, Any]) -> None:
146
+ """Support for unpickling"""
147
+ self.__dict__.update(state)
148
+
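
A short illustrative sketch of the permissive attribute/dict behaviour defined above; apart from the auto-set version fields, the key names used here (model_name, hook_name, hook_layer, dataset_path) are arbitrary examples rather than a fixed schema:

    from sae_lens.saes.sae import SAEMetadata

    meta = SAEMetadata(model_name="gpt2", hook_name="blocks.0.hook_resid_pre")
    meta.model_name                        # "gpt2"
    meta.hook_layer                        # None: unset attributes default to None
    meta["hook_name"]                      # dict-style access via __getitem__
    meta.get("dataset_path", "unknown")    # "unknown": .get() falls back to its default
    "hook_layer" in meta                   # False: only explicitly set keys count
    sorted(meta.keys())                    # includes sae_lens_version / sae_lens_training_version
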
149
+
150
+ @dataclass
151
+ class SAEConfig(ABC):
152
+ """Base configuration for SAE models."""
153
+
154
+ d_in: int
155
+ d_sae: int
156
+ dtype: str = "float32"
157
+ device: str = "cpu"
158
+ apply_b_dec_to_input: bool = True
159
+ normalize_activations: Literal[
160
+ "none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"
161
+ ] = "none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
162
+ reshape_activations: Literal["none", "hook_z"] = "none"
163
+ metadata: SAEMetadata = field(default_factory=SAEMetadata)
164
+
165
+ @classmethod
166
+ @abstractmethod
167
+ def architecture(cls) -> str: ...
168
+
169
+ def to_dict(self) -> dict[str, Any]:
170
+ res = {field.name: getattr(self, field.name) for field in fields(self)}
171
+ res["metadata"] = self.metadata.to_dict()
172
+ res["architecture"] = self.architecture()
173
+ return res
174
+
175
+ @classmethod
176
+ def from_dict(cls: type[T_SAE_CONFIG], config_dict: dict[str, Any]) -> T_SAE_CONFIG:
177
+ cfg_class = get_sae_class(config_dict["architecture"])[1]
178
+ filtered_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
179
+ res = cfg_class(**filtered_config_dict)
180
+ if "metadata" in config_dict:
181
+ res.metadata = SAEMetadata(**config_dict["metadata"])
182
+ if not isinstance(res, cls):
183
+ raise ValueError(
184
+ f"SAE config class {cls} does not match dict config class {type(res)}"
185
+ )
186
+ return res
187
+
188
+ def __post_init__(self):
189
+ if self.normalize_activations not in [
190
+ "none",
191
+ "expected_average_only_in",
192
+ "constant_norm_rescale",
193
+ "layer_norm",
194
+ ]:
195
+ raise ValueError(
196
+ f"normalize_activations must be none, expected_average_only_in, layer_norm, or constant_norm_rescale. Got {self.normalize_activations}"
197
+ )
198
+
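
A sketch of the architecture-keyed (de)serialization above, assuming a concrete StandardSAEConfig is registered by sae_lens/saes/standard_sae.py (class name assumed from that file, not confirmed here):

    from sae_lens.saes.sae import SAEConfig
    from sae_lens.saes.standard_sae import StandardSAEConfig  # assumed concrete config class

    cfg = StandardSAEConfig(d_in=768, d_sae=12288)
    d = cfg.to_dict()                  # adds "architecture" and a nested "metadata" dict
    restored = SAEConfig.from_dict(d)  # dispatches on d["architecture"] via the registry
    assert isinstance(restored, StandardSAEConfig) and restored.d_sae == 12288
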
199
+
200
+ @dataclass
201
+ class TrainStepOutput:
202
+ """Output from a training step."""
203
+
204
+ sae_in: torch.Tensor
205
+ sae_out: torch.Tensor
206
+ feature_acts: torch.Tensor
207
+ hidden_pre: torch.Tensor
208
+ loss: torch.Tensor # we need to call backwards on this
209
+ losses: dict[str, torch.Tensor]
210
+
211
+
212
+ @dataclass
213
+ class TrainStepInput:
214
+ """Input to a training step."""
215
+
216
+ sae_in: torch.Tensor
217
+ coefficients: dict[str, float]
218
+ dead_neuron_mask: torch.Tensor | None
219
+
220
+
221
+ class TrainCoefficientConfig(NamedTuple):
222
+ value: float
223
+ warm_up_steps: int
224
+
225
+
226
+ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
227
+ """Abstract base class for all SAE architectures."""
228
+
229
+ cfg: T_SAE_CONFIG
230
+ dtype: torch.dtype
231
+ device: torch.device
232
+ use_error_term: bool
233
+
234
+ # For type checking only - don't provide default values
235
+ # These will be initialized by subclasses
236
+ W_enc: nn.Parameter
237
+ W_dec: nn.Parameter
238
+ b_dec: nn.Parameter
239
+
240
+ def __init__(self, cfg: T_SAE_CONFIG, use_error_term: bool = False):
241
+ """Initialize the SAE."""
242
+ super().__init__()
243
+
244
+ self.cfg = cfg
245
+
246
+ if cfg.metadata and cfg.metadata.model_from_pretrained_kwargs:
247
+ warnings.warn(
248
+ "\nThis SAE has non-empty model_from_pretrained_kwargs. "
249
+ "\nFor optimal performance, load the model like so:\n"
250
+ "model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)",
251
+ category=UserWarning,
252
+ stacklevel=1,
253
+ )
254
+
255
+ self.dtype = DTYPE_MAP[cfg.dtype]
256
+ self.device = torch.device(cfg.device)
257
+ self.use_error_term = use_error_term
258
+
259
+ # Set up activation function
260
+ self.activation_fn = self.get_activation_fn()
261
+
262
+ # Initialize weights
263
+ self.initialize_weights()
264
+
265
+ # Set up hooks
266
+ self.hook_sae_input = HookPoint()
267
+ self.hook_sae_acts_pre = HookPoint()
268
+ self.hook_sae_acts_post = HookPoint()
269
+ self.hook_sae_output = HookPoint()
270
+ self.hook_sae_recons = HookPoint()
271
+ self.hook_sae_error = HookPoint()
272
+
273
+ # handle hook_z reshaping if needed.
274
+ if self.cfg.reshape_activations == "hook_z":
275
+ self.turn_on_forward_pass_hook_z_reshaping()
276
+ else:
277
+ self.turn_off_forward_pass_hook_z_reshaping()
278
+
279
+ # Set up activation normalization
280
+ self._setup_activation_normalization()
281
+
282
+ self.setup() # Required for HookedRootModule
283
+
284
+ @torch.no_grad()
285
+ def fold_activation_norm_scaling_factor(self, scaling_factor: float):
286
+ self.W_enc.data *= scaling_factor # type: ignore
287
+ self.W_dec.data /= scaling_factor # type: ignore
288
+ self.b_dec.data /= scaling_factor # type: ignore
289
+ self.cfg.normalize_activations = "none"
290
+
291
+ def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
292
+ """Get the activation function specified in config."""
293
+ return nn.ReLU()
294
+
295
+ def _setup_activation_normalization(self):
296
+ """Set up activation normalization functions based on config."""
297
+ if self.cfg.normalize_activations == "constant_norm_rescale":
298
+
299
+ def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
300
+ self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
301
+ return x * self.x_norm_coeff
302
+
303
+ def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
304
+ x = x / self.x_norm_coeff # type: ignore
305
+ del self.x_norm_coeff
306
+ return x
307
+
308
+ self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
309
+ self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
310
+ elif self.cfg.normalize_activations == "layer_norm":
311
+ # layer-norm the input and store mu/std so the output can be un-normalized
312
+ def run_time_activation_ln_in(
313
+ x: torch.Tensor, eps: float = 1e-5
314
+ ) -> torch.Tensor:
315
+ mu = x.mean(dim=-1, keepdim=True)
316
+ x = x - mu
317
+ std = x.std(dim=-1, keepdim=True)
318
+ x = x / (std + eps)
319
+ self.ln_mu = mu
320
+ self.ln_std = std
321
+ return x
322
+
323
+ def run_time_activation_ln_out(
324
+ x: torch.Tensor,
325
+ eps: float = 1e-5, # noqa: ARG001
326
+ ) -> torch.Tensor:
327
+ return x * self.ln_std + self.ln_mu # type: ignore
328
+
329
+ self.run_time_activation_norm_fn_in = run_time_activation_ln_in
330
+ self.run_time_activation_norm_fn_out = run_time_activation_ln_out
331
+ else:
332
+ self.run_time_activation_norm_fn_in = lambda x: x
333
+ self.run_time_activation_norm_fn_out = lambda x: x
334
+
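
A worked sketch of what the constant_norm_rescale branch above does to a batch of inputs:

    import torch

    d_in = 768
    x = torch.randn(4, d_in)
    coeff = (d_in ** 0.5) / x.norm(dim=-1, keepdim=True)
    print((x * coeff).norm(dim=-1))               # every row now has norm ~sqrt(768) ~= 27.7
    print(((x * coeff) / coeff - x).abs().max())  # ~0: the *_out function undoes the scaling
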
335
+ def initialize_weights(self):
336
+ """Initialize model weights."""
337
+ self.b_dec = nn.Parameter(
338
+ torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
339
+ )
340
+
341
+ w_dec_data = torch.empty(
342
+ self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
343
+ )
344
+ nn.init.kaiming_uniform_(w_dec_data)
345
+ self.W_dec = nn.Parameter(w_dec_data)
346
+
347
+ w_enc_data = self.W_dec.data.T.clone().detach().contiguous()
348
+ self.W_enc = nn.Parameter(w_enc_data)
349
+
350
+ @abstractmethod
351
+ def encode(
352
+ self, x: Float[torch.Tensor, "... d_in"]
353
+ ) -> Float[torch.Tensor, "... d_sae"]:
354
+ """Encode input tensor to feature space."""
355
+ pass
356
+
357
+ @abstractmethod
358
+ def decode(
359
+ self, feature_acts: Float[torch.Tensor, "... d_sae"]
360
+ ) -> Float[torch.Tensor, "... d_in"]:
361
+ """Decode feature activations back to input space."""
362
+ pass
363
+
364
+ def turn_on_forward_pass_hook_z_reshaping(self):
365
+ if (
366
+ self.cfg.metadata.hook_name is not None
367
+ and not self.cfg.metadata.hook_name.endswith("_z")
368
+ ):
369
+ raise ValueError("This method should only be called for hook_z SAEs.")
370
+
371
+ # print(f"Turning on hook_z reshaping for {self.cfg.hook_name}")
372
+
373
+ def reshape_fn_in(x: torch.Tensor):
374
+ # print(f"reshape_fn_in input shape: {x.shape}")
375
+ self.d_head = x.shape[-1]
376
+ # print(f"Setting d_head to: {self.d_head}")
377
+ self.reshape_fn_in = lambda x: einops.rearrange(
378
+ x, "... n_heads d_head -> ... (n_heads d_head)"
379
+ )
380
+ return einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")
381
+
382
+ self.reshape_fn_in = reshape_fn_in
383
+ self.reshape_fn_out = lambda x, d_head: einops.rearrange(
384
+ x, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
385
+ )
386
+ self.hook_z_reshaping_mode = True
387
+ # print(f"hook_z reshaping turned on, self.d_head={getattr(self, 'd_head', None)}")
388
+
389
+ def turn_off_forward_pass_hook_z_reshaping(self):
390
+ self.reshape_fn_in = lambda x: x
391
+ self.reshape_fn_out = lambda x, d_head: x # noqa: ARG005
392
+ self.d_head = None
393
+ self.hook_z_reshaping_mode = False
394
+
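
A sketch of the reshaping above: hook_z activations arrive as (..., n_heads, d_head) and the SAE sees them flattened to (..., n_heads * d_head); the shapes below assume GPT-2 small (12 heads of size 64):

    import einops
    import torch

    z = torch.randn(2, 128, 12, 64)   # batch, seq, n_heads, d_head
    flat = einops.rearrange(z, "... n_heads d_head -> ... (n_heads d_head)")
    print(flat.shape)                 # torch.Size([2, 128, 768]) -- this is the SAE's d_in
    back = einops.rearrange(flat, "... (n_heads d_head) -> ... n_heads d_head", d_head=64)
    print(torch.equal(back, z))       # True: reshape_fn_out inverts reshape_fn_in
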
395
+ @overload
396
+ def to(
397
+ self: T_SAE,
398
+ device: torch.device | str | None = ...,
399
+ dtype: torch.dtype | None = ...,
400
+ non_blocking: bool = ...,
401
+ ) -> T_SAE: ...
402
+
403
+ @overload
404
+ def to(self: T_SAE, dtype: torch.dtype, non_blocking: bool = ...) -> T_SAE: ...
405
+
406
+ @overload
407
+ def to(self: T_SAE, tensor: torch.Tensor, non_blocking: bool = ...) -> T_SAE: ...
408
+
409
+ def to(self: T_SAE, *args: Any, **kwargs: Any) -> T_SAE: # type: ignore
410
+ device_arg = None
411
+ dtype_arg = None
412
+
413
+ # Check args
414
+ for arg in args:
415
+ if isinstance(arg, (torch.device, str)):
416
+ device_arg = arg
417
+ elif isinstance(arg, torch.dtype):
418
+ dtype_arg = arg
419
+ elif isinstance(arg, torch.Tensor):
420
+ device_arg = arg.device
421
+ dtype_arg = arg.dtype
422
+
423
+ # Check kwargs
424
+ device_arg = kwargs.get("device", device_arg)
425
+ dtype_arg = kwargs.get("dtype", dtype_arg)
426
+
427
+ # Update device in config if provided
428
+ if device_arg is not None:
429
+ # Convert device to torch.device if it's a string
430
+ device = (
431
+ torch.device(device_arg) if isinstance(device_arg, str) else device_arg
432
+ )
433
+
434
+ # Update the cfg.device
435
+ self.cfg.device = str(device)
436
+
437
+ # Update the device property
438
+ self.device = device
439
+
440
+ # Update dtype in config if provided
441
+ if dtype_arg is not None:
442
+ # Update the cfg.dtype
443
+ self.cfg.dtype = str(dtype_arg)
444
+
445
+ # Update the dtype property
446
+ self.dtype = dtype_arg
447
+
448
+ return super().to(*args, **kwargs)
449
+
450
+ def process_sae_in(
451
+ self, sae_in: Float[torch.Tensor, "... d_in"]
452
+ ) -> Float[torch.Tensor, "... d_in"]:
453
+ # print(f"Input shape to process_sae_in: {sae_in.shape}")
454
+ # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
455
+ # print(f"self.b_dec shape: {self.b_dec.shape}")
456
+ # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")
457
+
458
+ sae_in = sae_in.to(self.dtype)
459
+
460
+ # print(f"Shape before reshape_fn_in: {sae_in.shape}")
461
+ sae_in = self.reshape_fn_in(sae_in)
462
+ # print(f"Shape after reshape_fn_in: {sae_in.shape}")
463
+
464
+ sae_in = self.hook_sae_input(sae_in)
465
+ sae_in = self.run_time_activation_norm_fn_in(sae_in)
466
+
467
+ # subtract the decoder bias from the input if apply_b_dec_to_input is set
468
+ bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
469
+ # print(f"Bias term shape: {bias_term.shape}")
470
+
471
+ return sae_in - bias_term
472
+
473
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
474
+ """Forward pass through the SAE."""
475
+ feature_acts = self.encode(x)
476
+ sae_out = self.decode(feature_acts)
477
+
478
+ if self.use_error_term:
479
+ with torch.no_grad():
480
+ # Recompute without hooks for true error term
481
+ with _disable_hooks(self):
482
+ feature_acts_clean = self.encode(x)
483
+ x_reconstruct_clean = self.decode(feature_acts_clean)
484
+ sae_error = self.hook_sae_error(x - x_reconstruct_clean)
485
+ sae_out = sae_out + sae_error
486
+
487
+ return self.hook_sae_output(sae_out)
488
+
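
A sketch of the error-term behaviour in forward() above: with use_error_term=True and no hooks intervening, the output equals the input exactly, which is what makes SAE-splicing experiments lossless. The release/id strings are illustrative; substitute any entry from the pretrained SAEs directory:

    import torch
    from sae_lens.saes.sae import SAE

    sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")
    x = torch.randn(2, 32, sae.cfg.d_in)

    recon = sae(x)               # plain reconstruction: generally != x
    sae.use_error_term = True
    out = sae(x)                 # reconstruction + (x - clean reconstruction)
    print(torch.allclose(out, x, atol=1e-5))   # True
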
489
+ # overwrite this in subclasses to modify the state_dict in-place before saving
490
+ def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
491
+ pass
492
+
493
+ # overwrite this in subclasses to modify the state_dict in-place after loading
494
+ def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
495
+ pass
496
+
497
+ @torch.no_grad()
498
+ def fold_W_dec_norm(self):
499
+ """Fold decoder norms into encoder."""
500
+ W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
501
+ self.W_dec.data = self.W_dec.data / W_dec_norms
502
+ self.W_enc.data = self.W_enc.data * W_dec_norms.T
503
+
504
+ # Only update b_enc if it exists (standard/jumprelu architectures)
505
+ if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter):
506
+ self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()
507
+
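
A small sanity-check sketch for the folding above: afterwards every decoder row has unit norm, and for a standard ReLU SAE the reconstruction is unchanged because the scaling moves into the feature activations (release/id illustrative):

    import torch
    from sae_lens.saes.sae import SAE

    sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")
    x = torch.randn(4, sae.cfg.d_in)
    before = sae.decode(sae.encode(x))
    sae.fold_W_dec_norm()
    print(sae.W_dec.norm(dim=-1))   # ~1.0 for every latent
    print(torch.allclose(sae.decode(sae.encode(x)), before, atol=1e-4))   # True
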
508
+ def get_name(self):
509
+ """Generate a name for this SAE."""
510
+ return f"sae_{self.cfg.metadata.model_name}_{self.cfg.metadata.hook_name}_{self.cfg.d_sae}"
511
+
512
+ def save_model(self, path: str | Path) -> tuple[Path, Path]:
513
+ """Save model weights and config to disk."""
514
+ path = Path(path)
515
+ path.mkdir(parents=True, exist_ok=True)
516
+
517
+ # Generate the weights
518
+ state_dict = self.state_dict() # Use internal SAE state dict
519
+ self.process_state_dict_for_saving(state_dict)
520
+ model_weights_path = path / SAE_WEIGHTS_FILENAME
521
+ save_file(state_dict, model_weights_path)
522
+
523
+ # Save the config
524
+ config = self.cfg.to_dict()
525
+ cfg_path = path / SAE_CFG_FILENAME
526
+ with open(cfg_path, "w") as f:
527
+ json.dump(config, f)
528
+
529
+ return model_weights_path, cfg_path
530
+
531
+ ## Initialization Methods
532
+ @torch.no_grad()
533
+ def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
534
+ out = torch.tensor(origin, dtype=self.dtype, device=self.device)
535
+ self.b_dec.data = out
536
+
537
+ @torch.no_grad()
538
+ def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
539
+ previous_b_dec = self.b_dec.clone().cpu()
540
+ out = all_activations.mean(dim=0)
541
+
542
+ previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
543
+ distances = torch.norm(all_activations - out, dim=-1)
544
+
545
+ logger.info("Reinitializing b_dec with mean of activations")
546
+ logger.debug(
547
+ f"Previous distances: {previous_distances.median(0).values.mean().item()}"
548
+ )
549
+ logger.debug(f"New distances: {distances.median(0).values.mean().item()}")
550
+
551
+ self.b_dec.data = out.to(self.dtype).to(self.device)
552
+
553
+ # Class methods for loading models
554
+ @classmethod
555
+ @deprecated("Use load_from_disk instead")
556
+ def load_from_pretrained(
557
+ cls: Type[T_SAE],
558
+ path: str | Path,
559
+ device: str = "cpu",
560
+ dtype: str | None = None,
561
+ ) -> T_SAE:
562
+ return cls.load_from_disk(path, device=device, dtype=dtype)
563
+
564
+ @classmethod
565
+ def load_from_disk(
566
+ cls: Type[T_SAE],
567
+ path: str | Path,
568
+ device: str = "cpu",
569
+ dtype: str | None = None,
570
+ converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
571
+ ) -> T_SAE:
572
+ overrides = {"dtype": dtype} if dtype is not None else None
573
+ cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
574
+ cfg_dict = handle_config_defaulting(cfg_dict)
575
+ sae_config_cls = cls.get_sae_config_class_for_architecture(
576
+ cfg_dict["architecture"]
577
+ )
578
+ sae_cfg = sae_config_cls.from_dict(cfg_dict)
579
+ sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
580
+ sae = sae_cls(sae_cfg)
581
+ sae.process_state_dict_for_loading(state_dict)
582
+ sae.load_state_dict(state_dict)
583
+ return sae
584
+
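
A round-trip sketch combining save_model above with load_from_disk (the directory name is arbitrary, release/id illustrative):

    import torch
    from sae_lens.saes.sae import SAE

    sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")
    weights_path, cfg_path = sae.save_model("./my_sae")   # writes SAE_WEIGHTS_FILENAME / SAE_CFG_FILENAME
    reloaded = SAE.load_from_disk("./my_sae", device="cpu")
    x = torch.randn(2, sae.cfg.d_in)
    print(torch.allclose(sae(x), reloaded(x)))   # True
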
585
+ @classmethod
586
+ def from_pretrained(
587
+ cls: Type[T_SAE],
588
+ release: str,
589
+ sae_id: str,
590
+ device: str = "cpu",
591
+ force_download: bool = False,
592
+ converter: PretrainedSaeHuggingfaceLoader | None = None,
593
+ ) -> T_SAE:
594
+ """
595
+ Load a pretrained SAE from the Hugging Face model hub.
596
+
597
+ Args:
598
+ release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
599
+ sae_id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
600
+ device: The device to load the SAE on.
601
+ """
602
+ return cls.from_pretrained_with_cfg_and_sparsity(
603
+ release, sae_id, device, force_download, converter=converter
604
+ )[0]
605
+
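
Because SAE subclasses HookedRootModule, the hook points registered in __init__ work with run_with_cache; a sketch (release/id illustrative, as above):

    import torch
    from sae_lens.saes.sae import SAE

    sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")
    acts = torch.randn(1, 64, sae.cfg.d_in)
    recon, cache = sae.run_with_cache(acts)
    print(cache["hook_sae_acts_post"].shape)                   # (1, 64, d_sae) latent activations
    print((cache["hook_sae_acts_post"] > 0).float().mean())    # fraction of latents active
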
606
+ @classmethod
607
+ def from_pretrained_with_cfg_and_sparsity(
608
+ cls: Type[T_SAE],
609
+ release: str,
610
+ sae_id: str,
611
+ device: str = "cpu",
612
+ force_download: bool = False,
613
+ converter: PretrainedSaeHuggingfaceLoader | None = None,
614
+ ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
615
+ """
616
+ Load a pretrained SAE from the Hugging Face model hub, along with its config dict and sparsity, if present.
617
+ In SAELens <= 5.x.x, this was called SAE.from_pretrained().
618
+
619
+ Args:
620
+ release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
621
+ sae_id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
622
+ device: The device to load the SAE on.
623
+ """
624
+
625
+ # get sae directory
626
+ sae_directory = get_pretrained_saes_directory()
627
+
628
+ # Validate release and sae_id
629
+ if release not in sae_directory:
630
+ if "/" not in release:
631
+ raise ValueError(
632
+ f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
633
+ )
634
+ elif sae_id not in sae_directory[release].saes_map:
635
+ # Handle special cases like Gemma Scope
636
+ if (
637
+ "gemma-scope" in release
638
+ and "canonical" not in release
639
+ and f"{release}-canonical" in sae_directory
640
+ ):
641
+ canonical_ids = list(
642
+ sae_directory[release + "-canonical"].saes_map.keys()
643
+ )
644
+ # Shorten the lengthy string of valid IDs
645
+ if len(canonical_ids) > 5:
646
+ str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
647
+ else:
648
+ str_canonical_ids = str(canonical_ids)
649
+ value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
650
+ else:
651
+ value_suffix = ""
652
+
653
+ valid_ids = list(sae_directory[release].saes_map.keys())
654
+ # Shorten the lengthy string of valid IDs
655
+ if len(valid_ids) > 5:
656
+ str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
657
+ else:
658
+ str_valid_ids = str(valid_ids)
659
+
660
+ raise ValueError(
661
+ f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
662
+ + value_suffix
663
+ )
664
+
665
+ conversion_loader = (
666
+ converter
667
+ or NAMED_PRETRAINED_SAE_LOADERS[get_conversion_loader_name(release)]
668
+ )
669
+ repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
670
+ config_overrides = get_config_overrides(release, sae_id)
671
+ config_overrides["device"] = device
672
+
673
+ # Load config and weights
674
+ cfg_dict, state_dict, log_sparsities = conversion_loader(
675
+ repo_id=repo_id,
676
+ folder_name=folder_name,
677
+ device=device,
678
+ force_download=force_download,
679
+ cfg_overrides=config_overrides,
680
+ )
681
+ cfg_dict = handle_config_defaulting(cfg_dict)
682
+
683
+ # Create SAE with appropriate architecture
684
+ sae_config_cls = cls.get_sae_config_class_for_architecture(
685
+ cfg_dict["architecture"]
686
+ )
687
+ sae_cfg = sae_config_cls.from_dict(cfg_dict)
688
+ sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
689
+ sae = sae_cls(sae_cfg)
690
+ sae.process_state_dict_for_loading(state_dict)
691
+ sae.load_state_dict(state_dict)
692
+
693
+ # Apply normalization if needed
694
+ if cfg_dict.get("normalize_activations") == "expected_average_only_in":
695
+ norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
696
+ if norm_scaling_factor is not None:
697
+ sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
698
+ cfg_dict["normalize_activations"] = "none"
699
+ else:
700
+ warnings.warn(
701
+ f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
702
+ )
703
+
704
+ return sae, cfg_dict, log_sparsities
705
+
706
+ @classmethod
707
+ def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
708
+ """Create an SAE from a config dictionary."""
709
+ sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
710
+ sae_config_cls = cls.get_sae_config_class_for_architecture(
711
+ config_dict["architecture"]
712
+ )
713
+ return sae_cls(sae_config_cls.from_dict(config_dict))
714
+
715
+ @classmethod
716
+ def get_sae_class_for_architecture(
717
+ cls: Type[T_SAE], architecture: str
718
+ ) -> Type[T_SAE]:
719
+ """Get the SAE class for a given architecture."""
720
+ sae_cls, _ = get_sae_class(architecture)
721
+ if not issubclass(sae_cls, cls):
722
+ raise ValueError(
723
+ f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
724
+ )
725
+ return sae_cls
726
+
727
+ # in the future, this can be used to load different config classes for different architectures
728
+ @classmethod
729
+ def get_sae_config_class_for_architecture(
730
+ cls,
731
+ architecture: str, # noqa: ARG003
732
+ ) -> type[SAEConfig]:
733
+ return SAEConfig
734
+
735
+ ### Methods to support deprecated usage of SAE.from_pretrained() ###
736
+
737
+ def __getitem__(self, index: int) -> Any:
738
+ """
739
+ Support indexing for backward compatibility with tuple unpacking.
740
+ DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
741
+ Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
742
+ """
743
+ warnings.warn(
744
+ "Indexing SAE objects is deprecated. SAE.from_pretrained() now returns "
745
+ "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
746
+ "to get the config dict and sparsity as well.",
747
+ DeprecationWarning,
748
+ stacklevel=2,
749
+ )
750
+
751
+ if index == 0:
752
+ return self
753
+ if index == 1:
754
+ return self.cfg.to_dict()
755
+ if index == 2:
756
+ return None
757
+ raise IndexError(f"SAE tuple index {index} out of range")
758
+
759
+ def __iter__(self):
760
+ """
761
+ Support unpacking for backward compatibility with tuple unpacking.
762
+ DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
763
+ Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
764
+ """
765
+ warnings.warn(
766
+ "Unpacking SAE objects is deprecated. SAE.from_pretrained() now returns "
767
+ "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
768
+ "to get the config dict and sparsity as well.",
769
+ DeprecationWarning,
770
+ stacklevel=2,
771
+ )
772
+
773
+ yield self
774
+ yield self.cfg.to_dict()
775
+ yield None
776
+
777
+ def __len__(self) -> int:
778
+ """
779
+ Support len() for backward compatibility with tuple unpacking.
780
+ DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
781
+ Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
782
+ """
783
+ warnings.warn(
784
+ "Getting length of SAE objects is deprecated. SAE.from_pretrained() now returns "
785
+ "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
786
+ "to get the config dict and sparsity as well.",
787
+ DeprecationWarning,
788
+ stacklevel=2,
789
+ )
790
+
791
+ return 3
792
+
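
The shims above keep 5.x-style tuple unpacking working (with a DeprecationWarning); a sketch of the old and new call sites, with illustrative release/id strings:

    from sae_lens.saes.sae import SAE

    # 6.x: from_pretrained returns only the SAE
    sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

    # 5.x call sites still unpack via __iter__ above, but warn; the sparsity slot is always None here
    sae, cfg_dict, sparsity = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

    # preferred replacement when the config dict and sparsity are needed
    sae, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(
        "gpt2-small-res-jb", "blocks.8.hook_resid_pre"
    )
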
793
+
794
+ @dataclass(kw_only=True)
795
+ class TrainingSAEConfig(SAEConfig, ABC):
796
+ # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
797
+ # 0.1 corresponds to the "heuristic" initialization, use None to disable
798
+ decoder_init_norm: float | None = 0.1
799
+
800
+ @classmethod
801
+ @abstractmethod
802
+ def architecture(cls) -> str: ...
803
+
804
+ @classmethod
805
+ def from_sae_runner_config(
806
+ cls: type[T_TRAINING_SAE_CONFIG],
807
+ cfg: "LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]",
808
+ ) -> T_TRAINING_SAE_CONFIG:
809
+ metadata = SAEMetadata(
810
+ model_name=cfg.model_name,
811
+ hook_name=cfg.hook_name,
812
+ hook_head_index=cfg.hook_head_index,
813
+ context_size=cfg.context_size,
814
+ prepend_bos=cfg.prepend_bos,
815
+ seqpos_slice=cfg.seqpos_slice,
816
+ model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
817
+ )
818
+ if not isinstance(cfg.sae, cls):
819
+ raise ValueError(
820
+ f"SAE config class {cls} does not match SAE runner config class {type(cfg.sae)}"
821
+ )
822
+ return replace(cfg.sae, metadata=metadata)
823
+
824
+ @classmethod
825
+ def from_dict(
826
+ cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
827
+ ) -> T_TRAINING_SAE_CONFIG:
828
+ cfg_class = cls
829
+ if "architecture" in config_dict:
830
+ cfg_class = get_sae_training_class(config_dict["architecture"])[1]
831
+ if not issubclass(cfg_class, cls):
832
+ raise ValueError(
833
+ f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
834
+ )
835
+ # remove any keys that are not in the dataclass
836
+ # since we sometimes enhance the config with the whole LM runner config
837
+ valid_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
838
+ if "metadata" in config_dict:
839
+ valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
840
+ return cfg_class(**valid_config_dict)
841
+
842
+ def to_dict(self) -> dict[str, Any]:
843
+ return {
844
+ **super().to_dict(),
845
+ **asdict(self),
846
+ "metadata": self.metadata.to_dict(),
847
+ "architecture": self.architecture(),
848
+ }
849
+
850
+ # this needs to exist so we can initialize the parent sae cfg without the training specific
851
+ # parameters. Maybe there's a cleaner way to do this
852
+ def get_base_sae_cfg_dict(self) -> dict[str, Any]:
853
+ """
854
+ Creates a dictionary containing attributes corresponding to the fields
855
+ defined in the base SAEConfig class.
856
+ """
857
+ base_sae_cfg_class = get_sae_class(self.architecture())[1]
858
+ base_config_field_names = {f.name for f in fields(base_sae_cfg_class)}
859
+ result_dict = {
860
+ field_name: getattr(self, field_name)
861
+ for field_name in base_config_field_names
862
+ }
863
+ result_dict["architecture"] = self.architecture()
864
+ result_dict["metadata"] = self.metadata.to_dict()
865
+ return result_dict
866
+
867
+
868
+ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
869
+ """Abstract base class for training versions of SAEs."""
870
+
871
+ def __init__(self, cfg: T_TRAINING_SAE_CONFIG, use_error_term: bool = False):
872
+ super().__init__(cfg, use_error_term)
873
+
874
+ # Turn off hook_z reshaping for training mode - the activation store
875
+ # is expected to handle reshaping before passing data to the SAE
876
+ self.turn_off_forward_pass_hook_z_reshaping()
877
+ self.mse_loss_fn = mse_loss
878
+
879
+ @abstractmethod
880
+ def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...
881
+
882
+ @abstractmethod
883
+ def encode_with_hidden_pre(
884
+ self, x: Float[torch.Tensor, "... d_in"]
885
+ ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
886
+ """Encode with access to pre-activation values for training."""
887
+ ...
888
+
889
+ def encode(
890
+ self, x: Float[torch.Tensor, "... d_in"]
891
+ ) -> Float[torch.Tensor, "... d_sae"]:
892
+ """
893
+ For inference, just encode without returning hidden_pre.
894
+ (training_forward_pass calls encode_with_hidden_pre).
895
+ """
896
+ feature_acts, _ = self.encode_with_hidden_pre(x)
897
+ return feature_acts
898
+
899
+ def decode(
900
+ self, feature_acts: Float[torch.Tensor, "... d_sae"]
901
+ ) -> Float[torch.Tensor, "... d_in"]:
902
+ """
903
+ Decodes feature activations back into input space,
904
+ applying hooks, output de-normalization, and any hook_z reshaping.
905
+ """
906
+ sae_out_pre = feature_acts @ self.W_dec + self.b_dec
907
+ sae_out_pre = self.hook_sae_recons(sae_out_pre)
908
+ sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
909
+ return self.reshape_fn_out(sae_out_pre, self.d_head)
910
+
911
+ @override
912
+ def initialize_weights(self):
913
+ super().initialize_weights()
914
+ if self.cfg.decoder_init_norm is not None:
915
+ with torch.no_grad():
916
+ self.W_dec.data /= self.W_dec.norm(dim=-1, keepdim=True)
917
+ self.W_dec.data *= self.cfg.decoder_init_norm
918
+ self.W_enc.data = self.W_dec.data.T.clone().detach().contiguous()
919
+
920
+ @abstractmethod
921
+ def calculate_aux_loss(
922
+ self,
923
+ step_input: TrainStepInput,
924
+ feature_acts: torch.Tensor,
925
+ hidden_pre: torch.Tensor,
926
+ sae_out: torch.Tensor,
927
+ ) -> torch.Tensor | dict[str, torch.Tensor]:
928
+ """Calculate architecture-specific auxiliary loss terms."""
929
+ ...
930
+
931
+ def training_forward_pass(
932
+ self,
933
+ step_input: TrainStepInput,
934
+ ) -> TrainStepOutput:
935
+ """Forward pass during training."""
936
+ feature_acts, hidden_pre = self.encode_with_hidden_pre(step_input.sae_in)
937
+ sae_out = self.decode(feature_acts)
938
+
939
+ # Calculate MSE loss
940
+ per_item_mse_loss = self.mse_loss_fn(sae_out, step_input.sae_in)
941
+ mse_loss = per_item_mse_loss.sum(dim=-1).mean()
942
+
943
+ # Calculate architecture-specific auxiliary losses
944
+ aux_losses = self.calculate_aux_loss(
945
+ step_input=step_input,
946
+ feature_acts=feature_acts,
947
+ hidden_pre=hidden_pre,
948
+ sae_out=sae_out,
949
+ )
950
+
951
+ # Total loss is MSE plus all auxiliary losses
952
+ total_loss = mse_loss
953
+
954
+ # Create losses dictionary with mse_loss
955
+ losses = {"mse_loss": mse_loss}
956
+
957
+ # Add architecture-specific losses to the dictionary
958
+ # Make sure aux_losses is a dictionary with string keys and tensor values
959
+ if isinstance(aux_losses, dict):
960
+ losses.update(aux_losses)
961
+
962
+ # Sum all losses for total_loss
963
+ if isinstance(aux_losses, dict):
964
+ for loss_value in aux_losses.values():
965
+ total_loss = total_loss + loss_value
966
+ else:
967
+ # Handle case where aux_losses is a tensor
968
+ total_loss = total_loss + aux_losses
969
+
970
+ return TrainStepOutput(
971
+ sae_in=step_input.sae_in,
972
+ sae_out=sae_out,
973
+ feature_acts=feature_acts,
974
+ hidden_pre=hidden_pre,
975
+ loss=total_loss,
976
+ losses=losses,
977
+ )
978
+
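
A sketch of a single training step with the dataclasses above. Here `training_sae` stands for an instance of any concrete TrainingSAE subclass (e.g. the standard architecture in standard_sae.py), and the coefficient names are read from the SAE itself rather than guessed:

    import torch
    from sae_lens.saes.sae import TrainCoefficientConfig, TrainStepInput

    coefficients = {
        name: c.value if isinstance(c, TrainCoefficientConfig) else c
        for name, c in training_sae.get_coefficients().items()
    }
    batch = torch.randn(32, training_sae.cfg.d_in)
    step_out = training_sae.training_forward_pass(
        TrainStepInput(sae_in=batch, coefficients=coefficients, dead_neuron_mask=None)
    )
    step_out.loss.backward()                     # loss == mse_loss + sum of aux losses
    print({k: round(v.item(), 4) for k, v in step_out.losses.items()})
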
979
+ def save_inference_model(self, path: str | Path) -> tuple[Path, Path]:
980
+ """Save inference version of model weights and config to disk."""
981
+ path = Path(path)
982
+ path.mkdir(parents=True, exist_ok=True)
983
+
984
+ # Generate the weights
985
+ state_dict = self.state_dict() # Use internal SAE state dict
986
+ self.process_state_dict_for_saving_inference(state_dict)
987
+ model_weights_path = path / SAE_WEIGHTS_FILENAME
988
+ save_file(state_dict, model_weights_path)
989
+
990
+ # Save the config
991
+ config = self.to_inference_config_dict()
992
+ cfg_path = path / SAE_CFG_FILENAME
993
+ with open(cfg_path, "w") as f:
994
+ json.dump(config, f)
995
+
996
+ return model_weights_path, cfg_path
997
+
998
+ @abstractmethod
999
+ def to_inference_config_dict(self) -> dict[str, Any]:
1000
+ """Convert the config into an inference SAE config dict."""
1001
+ ...
1002
+
1003
+ def process_state_dict_for_saving_inference(
1004
+ self, state_dict: dict[str, Any]
1005
+ ) -> None:
1006
+ """
1007
+ Process the state dict for saving the inference model.
1008
+ This is a hook that can be overridden to change how the state dict is processed for the inference model.
1009
+ """
1010
+ return self.process_state_dict_for_saving(state_dict)
1011
+
1012
+ @torch.no_grad()
1013
+ def remove_gradient_parallel_to_decoder_directions(self) -> None:
1014
+ """Remove gradient components parallel to decoder directions."""
1015
+ # Remove, for each (unit-norm) decoder row, the gradient component parallel to that row
1016
+ assert self.W_dec.grad is not None
1017
+
1018
+ parallel_component = einops.einsum(
1019
+ self.W_dec.grad,
1020
+ self.W_dec.data,
1021
+ "d_sae d_in, d_sae d_in -> d_sae",
1022
+ )
1023
+ self.W_dec.grad -= einops.einsum(
1024
+ parallel_component,
1025
+ self.W_dec.data,
1026
+ "d_sae, d_sae d_in -> d_sae d_in",
1027
+ )
1028
+
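
A schematic of where this projection typically sits in a training step; `optimizer`, `training_sae`, and `step_out` are assumed to exist already, and the real loop lives in sae_lens/training/sae_trainer.py:

    optimizer.zero_grad()
    step_out.loss.backward()
    # keep optimizer updates from changing decoder row norms (assumes unit-norm rows)
    training_sae.remove_gradient_parallel_to_decoder_directions()
    optimizer.step()
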
1029
+ @torch.no_grad()
1030
+ def log_histograms(self) -> dict[str, NDArray[Any]]:
1031
+ """Log histograms of the weights and biases."""
1032
+ W_dec_norm_dist = self.W_dec.detach().float().norm(dim=1).cpu().numpy()
1033
+ return {
1034
+ "weights/W_dec_norms": W_dec_norm_dist,
1035
+ }
1036
+
1037
+ @classmethod
1038
+ def get_sae_class_for_architecture(
1039
+ cls: Type[T_TRAINING_SAE], architecture: str
1040
+ ) -> Type[T_TRAINING_SAE]:
1041
+ """Get the SAE class for a given architecture."""
1042
+ sae_cls, _ = get_sae_training_class(architecture)
1043
+ if not issubclass(sae_cls, cls):
1044
+ raise ValueError(
1045
+ f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
1046
+ )
1047
+ return sae_cls
1048
+
1049
+ # in the future, this can be used to load different config classes for different architectures
1050
+ @classmethod
1051
+ def get_sae_config_class_for_architecture(
1052
+ cls,
1053
+ architecture: str, # noqa: ARG003
1054
+ ) -> type[TrainingSAEConfig]:
1055
+ return get_sae_training_class(architecture)[1]
1056
+
1057
+
1058
+ _blank_hook = nn.Identity()
1059
+
1060
+
1061
+ @contextmanager
1062
+ def _disable_hooks(sae: SAE[Any]):
1063
+ """
1064
+ Temporarily disable hooks for the SAE. Swaps out every hook with a no-op module.
1065
+ """
1066
+ try:
1067
+ for hook_name in sae.hook_dict:
1068
+ setattr(sae, hook_name, _blank_hook)
1069
+ yield
1070
+ finally:
1071
+ for hook_name, hook in sae.hook_dict.items():
1072
+ setattr(sae, hook_name, hook)
1073
+
1074
+
1075
+ def mse_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
1076
+ return torch.nn.functional.mse_loss(preds, target, reduction="none")
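
A tiny sketch of the reduction convention: mse_loss keeps per-element values, and training_forward_pass sums over d_in then averages over the batch:

    import torch
    from sae_lens.saes.sae import mse_loss

    preds, target = torch.randn(8, 768), torch.randn(8, 768)
    per_item = mse_loss(preds, target)        # shape (8, 768), reduction="none"
    scalar = per_item.sum(dim=-1).mean()      # the scalar logged as "mse_loss" above
    print(per_item.shape, scalar.item())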