sae-lens 5.9.1__py3-none-any.whl → 6.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
sae_lens/saes/sae.py ADDED
@@ -0,0 +1,970 @@
1
+ """Base classes for Sparse Autoencoders (SAEs)."""
2
+
3
+ import json
4
+ import warnings
5
+ from abc import ABC, abstractmethod
6
+ from contextlib import contextmanager
7
+ from dataclasses import dataclass, field, fields
8
+ from pathlib import Path
9
+ from typing import Any, Callable, Type, TypeVar
10
+
11
+ import einops
12
+ import torch
13
+ from jaxtyping import Float
14
+ from numpy.typing import NDArray
15
+ from safetensors.torch import save_file
16
+ from torch import nn
17
+ from transformer_lens.hook_points import HookedRootModule, HookPoint
18
+ from typing_extensions import deprecated, overload
19
+
20
+ from sae_lens import logger
21
+ from sae_lens.config import (
22
+ DTYPE_MAP,
23
+ SAE_CFG_FILENAME,
24
+ SAE_WEIGHTS_FILENAME,
25
+ SPARSITY_FILENAME,
26
+ LanguageModelSAERunnerConfig,
27
+ )
28
+ from sae_lens.loading.pretrained_sae_loaders import (
29
+ NAMED_PRETRAINED_SAE_LOADERS,
30
+ PretrainedSaeDiskLoader,
31
+ PretrainedSaeHuggingfaceLoader,
32
+ get_conversion_loader_name,
33
+ handle_config_defaulting,
34
+ sae_lens_disk_loader,
35
+ )
36
+ from sae_lens.loading.pretrained_saes_directory import (
37
+ get_config_overrides,
38
+ get_norm_scaling_factor,
39
+ get_pretrained_saes_directory,
40
+ get_repo_id_and_folder_name,
41
+ )
42
+ from sae_lens.regsitry import get_sae_class, get_sae_training_class
43
+
44
+ T = TypeVar("T", bound="SAE")
45
+
46
+
47
+ @dataclass
48
+ class SAEConfig:
49
+ """Base configuration for SAE models."""
50
+
51
+ architecture: str
52
+ d_in: int
53
+ d_sae: int
54
+ dtype: str
55
+ device: str
56
+ model_name: str
57
+ hook_name: str
58
+ hook_layer: int
59
+ hook_head_index: int | None
60
+ activation_fn: str
61
+ activation_fn_kwargs: dict[str, Any]
62
+ apply_b_dec_to_input: bool
63
+ finetuning_scaling_factor: bool
64
+ normalize_activations: str
65
+ context_size: int
66
+ dataset_path: str
67
+ dataset_trust_remote_code: bool
68
+ sae_lens_training_version: str
69
+ model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
70
+ seqpos_slice: tuple[int, ...] | None = None
71
+ prepend_bos: bool = False
72
+ neuronpedia_id: str | None = None
73
+
74
+ def to_dict(self) -> dict[str, Any]:
75
+ return {field.name: getattr(self, field.name) for field in fields(self)}
76
+
77
+ @classmethod
78
+ def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
79
+ valid_field_names = {field.name for field in fields(cls)}
80
+ valid_config_dict = {
81
+ key: val for key, val in config_dict.items() if key in valid_field_names
82
+ }
83
+
84
+ # Ensure seqpos_slice is a tuple
85
+ if (
86
+ "seqpos_slice" in valid_config_dict
87
+ and valid_config_dict["seqpos_slice"] is not None
88
+ and isinstance(valid_config_dict["seqpos_slice"], list)
89
+ ):
90
+ valid_config_dict["seqpos_slice"] = tuple(valid_config_dict["seqpos_slice"])
91
+
92
+ return cls(**valid_config_dict)
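
# --- Usage sketch (not part of sae_lens; for illustration only) ---
# How SAEConfig.from_dict behaves, based on the code above: unknown keys are
# silently dropped and a list-valued seqpos_slice is coerced to a tuple. Every
# concrete value below is a placeholder chosen for this sketch.
from sae_lens.saes.sae import SAEConfig

example_cfg = SAEConfig.from_dict(
    {
        "architecture": "standard", "d_in": 768, "d_sae": 16384,
        "dtype": "float32", "device": "cpu", "model_name": "gpt2",
        "hook_name": "blocks.8.hook_resid_pre", "hook_layer": 8,
        "hook_head_index": None, "activation_fn": "relu",
        "activation_fn_kwargs": {}, "apply_b_dec_to_input": True,
        "finetuning_scaling_factor": False, "normalize_activations": "none",
        "context_size": 128, "dataset_path": "some/dataset",
        "dataset_trust_remote_code": False, "sae_lens_training_version": "6.0.0rc1",
        "seqpos_slice": [0, 64],      # list is coerced to a tuple
        "not_a_real_field": 123,      # unknown keys are ignored
    }
)
assert example_cfg.seqpos_slice == (0, 64)
assert not hasattr(example_cfg, "not_a_real_field")
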
93
+
94
+
95
+ @dataclass
96
+ class TrainStepOutput:
97
+ """Output from a training step."""
98
+
99
+ sae_in: torch.Tensor
100
+ sae_out: torch.Tensor
101
+ feature_acts: torch.Tensor
102
+ hidden_pre: torch.Tensor
103
+ loss: torch.Tensor # we need to call .backward() on this
104
+ losses: dict[str, torch.Tensor]
105
+
106
+
107
+ @dataclass
108
+ class TrainStepInput:
109
+ """Input to a training step."""
110
+
111
+ sae_in: torch.Tensor
112
+ current_l1_coefficient: float
113
+ dead_neuron_mask: torch.Tensor | None
114
+
115
+
116
+ class SAE(HookedRootModule, ABC):
117
+ """Abstract base class for all SAE architectures."""
118
+
119
+ cfg: SAEConfig
120
+ dtype: torch.dtype
121
+ device: torch.device
122
+ use_error_term: bool
123
+
124
+ # For type checking only - don't provide default values
125
+ # These will be initialized by subclasses
126
+ W_enc: nn.Parameter
127
+ W_dec: nn.Parameter
128
+ b_dec: nn.Parameter
129
+
130
+ def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
131
+ """Initialize the SAE."""
132
+ super().__init__()
133
+
134
+ self.cfg = cfg
135
+
136
+ if cfg.model_from_pretrained_kwargs:
137
+ warnings.warn(
138
+ "\nThis SAE has non-empty model_from_pretrained_kwargs. "
139
+ "\nFor optimal performance, load the model like so:\n"
140
+ "model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)",
141
+ category=UserWarning,
142
+ stacklevel=1,
143
+ )
144
+
145
+ self.dtype = DTYPE_MAP[cfg.dtype]
146
+ self.device = torch.device(cfg.device)
147
+ self.use_error_term = use_error_term
148
+
149
+ # Set up activation function
150
+ self.activation_fn = self._get_activation_fn()
151
+
152
+ # Initialize weights
153
+ self.initialize_weights()
154
+
155
+ # Handle presence / absence of scaling factor
156
+ if self.cfg.finetuning_scaling_factor:
157
+ self.apply_finetuning_scaling_factor = (
158
+ lambda x: x * self.finetuning_scaling_factor
159
+ )
160
+ else:
161
+ self.apply_finetuning_scaling_factor = lambda x: x
162
+
163
+ # Set up hooks
164
+ self.hook_sae_input = HookPoint()
165
+ self.hook_sae_acts_pre = HookPoint()
166
+ self.hook_sae_acts_post = HookPoint()
167
+ self.hook_sae_output = HookPoint()
168
+ self.hook_sae_recons = HookPoint()
169
+ self.hook_sae_error = HookPoint()
170
+
171
+ # handle hook_z reshaping if needed.
172
+ if self.cfg.hook_name.endswith("_z"):
173
+ # print(f"Setting up hook_z reshaping for {self.cfg.hook_name}")
174
+ self.turn_on_forward_pass_hook_z_reshaping()
175
+ else:
176
+ # print(f"No hook_z reshaping needed for {self.cfg.hook_name}")
177
+ self.turn_off_forward_pass_hook_z_reshaping()
178
+
179
+ # Set up activation normalization
180
+ self._setup_activation_normalization()
181
+
182
+ self.setup() # Required for HookedRootModule
183
+
184
+ @torch.no_grad()
185
+ def fold_activation_norm_scaling_factor(self, scaling_factor: float):
186
+ self.W_enc.data *= scaling_factor # type: ignore
187
+ self.W_dec.data /= scaling_factor # type: ignore
188
+ self.b_dec.data /= scaling_factor # type: ignore
189
+ self.cfg.normalize_activations = "none"
190
+
191
+ def _get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
192
+ """Get the activation function specified in config."""
193
+ return self._get_activation_fn_static(
194
+ self.cfg.activation_fn, **(self.cfg.activation_fn_kwargs or {})
195
+ )
196
+
197
+ @staticmethod
198
+ def _get_activation_fn_static(
199
+ activation_fn: str, **kwargs: Any
200
+ ) -> Callable[[torch.Tensor], torch.Tensor]:
201
+ """Get the activation function from a string specification."""
202
+ if activation_fn == "relu":
203
+ return torch.nn.ReLU()
204
+ if activation_fn == "tanh-relu":
205
+
206
+ def tanh_relu(input: torch.Tensor) -> torch.Tensor:
207
+ input = torch.relu(input)
208
+ return torch.tanh(input)
209
+
210
+ return tanh_relu
211
+ if activation_fn == "topk":
212
+ if "k" not in kwargs:
213
+ raise ValueError("TopK activation function requires a k value.")
214
+ k = kwargs.get("k", 1) # k is guaranteed present by the check above
215
+
216
+ def topk_fn(x: torch.Tensor) -> torch.Tensor:
217
+ topk = torch.topk(x.flatten(start_dim=-1), k=k, dim=-1)
218
+ values = torch.relu(topk.values)
219
+ result = torch.zeros_like(x.flatten(start_dim=-1))
220
+ result.scatter_(-1, topk.indices, values)
221
+ return result.view_as(x)
222
+
223
+ return topk_fn
224
+ raise ValueError(f"Unknown activation function: {activation_fn}")
225
+
226
+ def _setup_activation_normalization(self):
227
+ """Set up activation normalization functions based on config."""
228
+ if self.cfg.normalize_activations == "constant_norm_rescale":
229
+
230
+ def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
231
+ self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
232
+ return x * self.x_norm_coeff
233
+
234
+ def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
235
+ x = x / self.x_norm_coeff # type: ignore
236
+ del self.x_norm_coeff
237
+ return x
238
+
239
+ self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
240
+ self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
241
+
242
+ elif self.cfg.normalize_activations == "layer_norm":
243
+
244
+ def run_time_activation_ln_in(
245
+ x: torch.Tensor, eps: float = 1e-5
246
+ ) -> torch.Tensor:
247
+ mu = x.mean(dim=-1, keepdim=True)
248
+ x = x - mu
249
+ std = x.std(dim=-1, keepdim=True)
250
+ x = x / (std + eps)
251
+ self.ln_mu = mu
252
+ self.ln_std = std
253
+ return x
254
+
255
+ def run_time_activation_ln_out(
256
+ x: torch.Tensor,
257
+ eps: float = 1e-5, # noqa: ARG001
258
+ ) -> torch.Tensor:
259
+ return x * self.ln_std + self.ln_mu # type: ignore
260
+
261
+ self.run_time_activation_norm_fn_in = run_time_activation_ln_in
262
+ self.run_time_activation_norm_fn_out = run_time_activation_ln_out
263
+ else:
264
+ self.run_time_activation_norm_fn_in = lambda x: x
265
+ self.run_time_activation_norm_fn_out = lambda x: x
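
# --- Usage sketch (not part of sae_lens; for illustration only) ---
# The arithmetic behind "constant_norm_rescale" above: inputs are rescaled to
# norm sqrt(d_in) before encoding, and the same per-example coefficient is
# divided back out on the way out, so reconstructions land in the original
# scale. A standalone check of that round trip (d_in = 4 is a placeholder):
import torch

d_in = 4
x = torch.tensor([[1.0, 2.0, 2.0, 4.0]])
coeff = (d_in**0.5) / x.norm(dim=-1, keepdim=True)
x_scaled = x * coeff                                   # what the encoder sees
assert torch.allclose(x_scaled.norm(dim=-1), torch.tensor([d_in**0.5]))
assert torch.allclose(x_scaled / coeff, x)             # the "out" function undoes it
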
266
+
267
+ @abstractmethod
268
+ def initialize_weights(self):
269
+ """Initialize model weights."""
270
+ pass
271
+
272
+ @abstractmethod
273
+ def encode(
274
+ self, x: Float[torch.Tensor, "... d_in"]
275
+ ) -> Float[torch.Tensor, "... d_sae"]:
276
+ """Encode input tensor to feature space."""
277
+ pass
278
+
279
+ @abstractmethod
280
+ def decode(
281
+ self, feature_acts: Float[torch.Tensor, "... d_sae"]
282
+ ) -> Float[torch.Tensor, "... d_in"]:
283
+ """Decode feature activations back to input space."""
284
+ pass
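
# --- Usage sketch (not part of sae_lens; for illustration only) ---
# A hypothetical minimal concrete subclass showing what initialize_weights,
# encode, and decode are expected to provide. "TinyStandardSAE" and every
# config value below are placeholders invented for this sketch; real
# architectures live in the sae_lens registry.
import torch
from torch import nn

from sae_lens.saes.sae import SAE, SAEConfig


class TinyStandardSAE(SAE):
    def initialize_weights(self) -> None:
        self.W_enc = nn.Parameter(0.1 * torch.randn(self.cfg.d_in, self.cfg.d_sae))
        self.W_dec = nn.Parameter(0.1 * torch.randn(self.cfg.d_sae, self.cfg.d_in))
        self.b_enc = nn.Parameter(torch.zeros(self.cfg.d_sae))
        self.b_dec = nn.Parameter(torch.zeros(self.cfg.d_in))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        sae_in = self.process_sae_in(x)  # dtype cast, reshape, input hook, b_dec subtraction
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        sae_out = self.hook_sae_recons(feature_acts @ self.W_dec + self.b_dec)
        sae_out = self.run_time_activation_norm_fn_out(sae_out)
        return self.reshape_fn_out(sae_out, self.d_head)


demo_cfg = SAEConfig(
    architecture="standard", d_in=16, d_sae=32, dtype="float32", device="cpu",
    model_name="demo-model", hook_name="blocks.0.hook_resid_pre", hook_layer=0,
    hook_head_index=None, activation_fn="relu", activation_fn_kwargs={},
    apply_b_dec_to_input=True, finetuning_scaling_factor=False,
    normalize_activations="none", context_size=128, dataset_path="",
    dataset_trust_remote_code=False, sae_lens_training_version="6.0.0rc1",
)
sae = TinyStandardSAE(demo_cfg)
x = torch.randn(2, 16)
assert sae(x).shape == (2, 16)

# With use_error_term=True (and no hooks intervening), forward() adds back the
# reconstruction error, so the SAE acts as an exact pass-through.
sae.use_error_term = True
assert torch.allclose(sae(x), x, atol=1e-5)
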
285
+
286
+ def turn_on_forward_pass_hook_z_reshaping(self):
287
+ if not self.cfg.hook_name.endswith("_z"):
288
+ raise ValueError("This method should only be called for hook_z SAEs.")
289
+
290
+ # print(f"Turning on hook_z reshaping for {self.cfg.hook_name}")
291
+
292
+ def reshape_fn_in(x: torch.Tensor):
293
+ # print(f"reshape_fn_in input shape: {x.shape}")
294
+ self.d_head = x.shape[-1]
295
+ # print(f"Setting d_head to: {self.d_head}")
296
+ self.reshape_fn_in = lambda x: einops.rearrange(
297
+ x, "... n_heads d_head -> ... (n_heads d_head)"
298
+ )
299
+ return einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")
300
+
301
+ self.reshape_fn_in = reshape_fn_in
302
+ self.reshape_fn_out = lambda x, d_head: einops.rearrange(
303
+ x, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
304
+ )
305
+ self.hook_z_reshaping_mode = True
306
+ # print(f"hook_z reshaping turned on, self.d_head={getattr(self, 'd_head', None)}")
307
+
308
+ def turn_off_forward_pass_hook_z_reshaping(self):
309
+ self.reshape_fn_in = lambda x: x
310
+ self.reshape_fn_out = lambda x, d_head: x # noqa: ARG005
311
+ self.d_head = None
312
+ self.hook_z_reshaping_mode = False
313
+
314
+ @overload
315
+ def to(
316
+ self: T,
317
+ device: torch.device | str | None = ...,
318
+ dtype: torch.dtype | None = ...,
319
+ non_blocking: bool = ...,
320
+ ) -> T: ...
321
+
322
+ @overload
323
+ def to(self: T, dtype: torch.dtype, non_blocking: bool = ...) -> T: ...
324
+
325
+ @overload
326
+ def to(self: T, tensor: torch.Tensor, non_blocking: bool = ...) -> T: ...
327
+
328
+ def to(self: T, *args: Any, **kwargs: Any) -> T: # type: ignore
329
+ device_arg = None
330
+ dtype_arg = None
331
+
332
+ # Check args
333
+ for arg in args:
334
+ if isinstance(arg, (torch.device, str)):
335
+ device_arg = arg
336
+ elif isinstance(arg, torch.dtype):
337
+ dtype_arg = arg
338
+ elif isinstance(arg, torch.Tensor):
339
+ device_arg = arg.device
340
+ dtype_arg = arg.dtype
341
+
342
+ # Check kwargs
343
+ device_arg = kwargs.get("device", device_arg)
344
+ dtype_arg = kwargs.get("dtype", dtype_arg)
345
+
346
+ # Update device in config if provided
347
+ if device_arg is not None:
348
+ # Convert device to torch.device if it's a string
349
+ device = (
350
+ torch.device(device_arg) if isinstance(device_arg, str) else device_arg
351
+ )
352
+
353
+ # Update the cfg.device
354
+ self.cfg.device = str(device)
355
+
356
+ # Update the device property
357
+ self.device = device
358
+
359
+ # Update dtype in config if provided
360
+ if dtype_arg is not None:
361
+ # Update the cfg.dtype
362
+ self.cfg.dtype = str(dtype_arg)
363
+
364
+ # Update the dtype property
365
+ self.dtype = dtype_arg
366
+
367
+ return super().to(*args, **kwargs)
368
+
369
+ def process_sae_in(
370
+ self, sae_in: Float[torch.Tensor, "... d_in"]
371
+ ) -> Float[torch.Tensor, "... d_in"]:
372
+ # print(f"Input shape to process_sae_in: {sae_in.shape}")
373
+ # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
374
+ # print(f"self.b_dec shape: {self.b_dec.shape}")
375
+ # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")
376
+
377
+ sae_in = sae_in.to(self.dtype)
378
+
379
+ # print(f"Shape before reshape_fn_in: {sae_in.shape}")
380
+ sae_in = self.reshape_fn_in(sae_in)
381
+ # print(f"Shape after reshape_fn_in: {sae_in.shape}")
382
+
383
+ sae_in = self.hook_sae_input(sae_in)
384
+ sae_in = self.run_time_activation_norm_fn_in(sae_in)
385
+
386
+ # Subtract the decoder bias from the input when apply_b_dec_to_input is set
387
+ bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
388
+ # print(f"Bias term shape: {bias_term.shape}")
389
+
390
+ return sae_in - bias_term
391
+
392
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
393
+ """Forward pass through the SAE."""
394
+ feature_acts = self.encode(x)
395
+ sae_out = self.decode(feature_acts)
396
+
397
+ if self.use_error_term:
398
+ with torch.no_grad():
399
+ # Recompute without hooks for true error term
400
+ with _disable_hooks(self):
401
+ feature_acts_clean = self.encode(x)
402
+ x_reconstruct_clean = self.decode(feature_acts_clean)
403
+ sae_error = self.hook_sae_error(x - x_reconstruct_clean)
404
+ sae_out = sae_out + sae_error
405
+
406
+ return self.hook_sae_output(sae_out)
407
+
408
+ # overwrite this in subclasses to modify the state_dict in-place before saving
409
+ def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
410
+ pass
411
+
412
+ # overwrite this in subclasses to modify the state_dict in-place after loading
413
+ def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
414
+ pass
415
+
416
+ @torch.no_grad()
417
+ def fold_W_dec_norm(self):
418
+ """Fold decoder norms into encoder."""
419
+ W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
420
+ self.W_dec.data = self.W_dec.data / W_dec_norms
421
+ self.W_enc.data = self.W_enc.data * W_dec_norms.T
422
+
423
+ # Only update b_enc if it exists (standard/jumprelu architectures)
424
+ if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter):
425
+ self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()
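
# --- Usage sketch (not part of sae_lens; for illustration only) ---
# Why fold_W_dec_norm leaves a ReLU SAE's output unchanged: scaling encoder
# weights and biases by the positive per-feature decoder norms n, while
# dividing the decoder rows by n, cancels out (relu(a * n) == relu(a) * n for
# n > 0). A standalone check with placeholder shapes (the decoder bias is
# untouched by the fold, so it is omitted here):
import torch

d_in, d_sae = 8, 16
W_enc, W_dec = torch.randn(d_in, d_sae), torch.randn(d_sae, d_in)
b_enc, x = torch.randn(d_sae), torch.randn(4, d_in)
n = W_dec.norm(dim=-1)                                  # per-feature decoder norms
before = torch.relu(x @ W_enc + b_enc) @ W_dec
after = torch.relu(x @ (W_enc * n) + b_enc * n) @ (W_dec / n.unsqueeze(1))
assert torch.allclose(before, after, atol=1e-4)
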
426
+
427
+ def get_name(self):
428
+ """Generate a name for this SAE."""
429
+ return f"sae_{self.cfg.model_name}_{self.cfg.hook_name}_{self.cfg.d_sae}"
430
+
431
+ def save_model(
432
+ self, path: str | Path, sparsity: torch.Tensor | None = None
433
+ ) -> tuple[Path, Path, Path | None]:
434
+ """Save model weights, config, and optional sparsity tensor to disk."""
435
+ path = Path(path)
436
+ if not path.exists():
437
+ path.mkdir(parents=True)
438
+
439
+ # Generate the weights
440
+ state_dict = self.state_dict() # Use internal SAE state dict
441
+ self.process_state_dict_for_saving(state_dict)
442
+ model_weights_path = path / SAE_WEIGHTS_FILENAME
443
+ save_file(state_dict, model_weights_path)
444
+
445
+ # Save the config
446
+ config = self.cfg.to_dict()
447
+ cfg_path = path / SAE_CFG_FILENAME
448
+ with open(cfg_path, "w") as f:
449
+ json.dump(config, f)
450
+
451
+ if sparsity is not None:
452
+ sparsity_in_dict = {"sparsity": sparsity}
453
+ sparsity_path = path / SPARSITY_FILENAME
454
+ save_file(sparsity_in_dict, sparsity_path)
455
+ return model_weights_path, cfg_path, sparsity_path
456
+
457
+ return model_weights_path, cfg_path, None
458
+
459
+ ## Initialization Methods
460
+ @torch.no_grad()
461
+ def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
462
+ out = torch.tensor(origin, dtype=self.dtype, device=self.device)
463
+ self.b_dec.data = out
464
+
465
+ @torch.no_grad()
466
+ def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
467
+ previous_b_dec = self.b_dec.clone().cpu()
468
+ out = all_activations.mean(dim=0)
469
+
470
+ previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
471
+ distances = torch.norm(all_activations - out, dim=-1)
472
+
473
+ logger.info("Reinitializing b_dec with mean of activations")
474
+ logger.debug(
475
+ f"Previous distances: {previous_distances.median(0).values.mean().item()}"
476
+ )
477
+ logger.debug(f"New distances: {distances.median(0).values.mean().item()}")
478
+
479
+ self.b_dec.data = out.to(self.dtype).to(self.device)
480
+
481
+ # Class methods for loading models
482
+ @classmethod
483
+ @deprecated("Use load_from_disk instead")
484
+ def load_from_pretrained(
485
+ cls: Type[T], path: str | Path, device: str = "cpu", dtype: str | None = None
486
+ ) -> T:
487
+ return cls.load_from_disk(path, device=device, dtype=dtype)
488
+
489
+ @classmethod
490
+ def load_from_disk(
491
+ cls: Type[T],
492
+ path: str | Path,
493
+ device: str = "cpu",
494
+ dtype: str | None = None,
495
+ converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
496
+ ) -> T:
497
+ overrides = {"dtype": dtype} if dtype is not None else None
498
+ cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
499
+ cfg_dict = handle_config_defaulting(cfg_dict)
500
+ sae_config_cls = cls.get_sae_config_class_for_architecture(
501
+ cfg_dict["architecture"]
502
+ )
503
+ sae_cfg = sae_config_cls.from_dict(cfg_dict)
504
+ sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture)
505
+ sae = sae_cls(sae_cfg)
506
+ sae.process_state_dict_for_loading(state_dict)
507
+ sae.load_state_dict(state_dict)
508
+ return sae
509
+
510
+ @classmethod
511
+ def from_pretrained(
512
+ cls,
513
+ release: str,
514
+ sae_id: str,
515
+ device: str = "cpu",
516
+ force_download: bool = False,
517
+ converter: PretrainedSaeHuggingfaceLoader | None = None,
518
+ ) -> tuple["SAE", dict[str, Any], torch.Tensor | None]:
519
+ """
520
+ Load a pretrained SAE from the Hugging Face model hub.
521
+
522
+ Args:
523
+ release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
524
+ sae_id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
525
+ device: The device to load the SAE on.
526
+ force_download: If True, re-download the SAE weights and config even if they are already cached.
+ converter: Optional custom Hugging Face loader to use instead of the loader registered for this release.
+
+ Returns:
+ A tuple of (sae, config dict, log sparsity tensor or None).
527
+ """
528
+
529
+ # get sae directory
530
+ sae_directory = get_pretrained_saes_directory()
531
+
532
+ # Validate release and sae_id
533
+ if release not in sae_directory:
534
+ if "/" not in release:
535
+ raise ValueError(
536
+ f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
537
+ )
538
+ elif sae_id not in sae_directory[release].saes_map:
539
+ # Handle special cases like Gemma Scope
540
+ if (
541
+ "gemma-scope" in release
542
+ and "canonical" not in release
543
+ and f"{release}-canonical" in sae_directory
544
+ ):
545
+ canonical_ids = list(
546
+ sae_directory[release + "-canonical"].saes_map.keys()
547
+ )
548
+ # Shorten the lengthy string of valid IDs
549
+ if len(canonical_ids) > 5:
550
+ str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
551
+ else:
552
+ str_canonical_ids = str(canonical_ids)
553
+ value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
554
+ else:
555
+ value_suffix = ""
556
+
557
+ valid_ids = list(sae_directory[release].saes_map.keys())
558
+ # Shorten the lengthy string of valid IDs
559
+ if len(valid_ids) > 5:
560
+ str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
561
+ else:
562
+ str_valid_ids = str(valid_ids)
563
+
564
+ raise ValueError(
565
+ f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
566
+ + value_suffix
567
+ )
568
+
569
+ conversion_loader = (
570
+ converter
571
+ or NAMED_PRETRAINED_SAE_LOADERS[get_conversion_loader_name(release)]
572
+ )
573
+ repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
574
+ config_overrides = get_config_overrides(release, sae_id)
575
+ config_overrides["device"] = device
576
+
577
+ # Load config and weights
578
+ cfg_dict, state_dict, log_sparsities = conversion_loader(
579
+ repo_id=repo_id,
580
+ folder_name=folder_name,
581
+ device=device,
582
+ force_download=force_download,
583
+ cfg_overrides=config_overrides,
584
+ )
585
+ cfg_dict = handle_config_defaulting(cfg_dict)
586
+
587
+ # Rename keys to match SAEConfig field names
588
+ renamed_cfg_dict = {}
589
+ rename_map = {
590
+ "hook_point": "hook_name",
591
+ "hook_point_layer": "hook_layer",
592
+ "hook_point_head_index": "hook_head_index",
593
+ "activation_fn": "activation_fn",
594
+ }
595
+
596
+ for k, v in cfg_dict.items():
597
+ renamed_cfg_dict[rename_map.get(k, k)] = v
598
+
599
+ # Set default values for required fields
600
+ renamed_cfg_dict.setdefault("activation_fn_kwargs", {})
601
+ renamed_cfg_dict.setdefault("seqpos_slice", None)
602
+
603
+ # Create SAE with appropriate architecture
604
+ sae_config_cls = cls.get_sae_config_class_for_architecture(
605
+ renamed_cfg_dict["architecture"]
606
+ )
607
+ sae_cfg = sae_config_cls.from_dict(renamed_cfg_dict)
608
+ sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture)
609
+ sae = sae_cls(sae_cfg)
610
+ sae.process_state_dict_for_loading(state_dict)
611
+ sae.load_state_dict(state_dict)
612
+
613
+ # Apply normalization if needed
614
+ if renamed_cfg_dict.get("normalize_activations") == "expected_average_only_in":
615
+ norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
616
+ if norm_scaling_factor is not None:
617
+ sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
618
+ renamed_cfg_dict["normalize_activations"] = "none"
619
+ else:
620
+ warnings.warn(
621
+ f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
622
+ )
623
+
624
+ return sae, renamed_cfg_dict, log_sparsities
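
# --- Usage sketch (not part of sae_lens; for illustration only) ---
# Loading a pretrained SAE by release / ID and round-tripping it through
# save_model and load_from_disk. The release and ID below are the GPT-2 small
# residual-stream SAEs used in the SAE Lens docs; this sketch assumes they are
# listed in pretrained_saes.yaml and that the weights can be downloaded.
from sae_lens.saes.sae import SAE

sae, cfg_dict, log_sparsities = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cpu",
)
print(type(sae).__name__, sae.cfg.d_in, sae.cfg.d_sae)

weights_path, cfg_path, _ = sae.save_model("./gpt2_small_layer8_sae")
reloaded = SAE.load_from_disk("./gpt2_small_layer8_sae", device="cpu")
assert reloaded.cfg.d_sae == sae.cfg.d_sae
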
625
+
626
+ @classmethod
627
+ def from_dict(cls: Type[T], config_dict: dict[str, Any]) -> T:
628
+ """Create an SAE from a config dictionary."""
629
+ sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
630
+ sae_config_cls = cls.get_sae_config_class_for_architecture(
631
+ config_dict["architecture"]
632
+ )
633
+ return sae_cls(sae_config_cls.from_dict(config_dict))
634
+
635
+ @classmethod
636
+ def get_sae_class_for_architecture(cls: Type[T], architecture: str) -> Type[T]:
637
+ """Get the SAE class for a given architecture."""
638
+ sae_cls = get_sae_class(architecture)
639
+ if not issubclass(sae_cls, cls):
640
+ raise ValueError(
641
+ f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
642
+ )
643
+ return sae_cls
644
+
645
+ # in the future, this can be used to load different config classes for different architectures
646
+ @classmethod
647
+ def get_sae_config_class_for_architecture(
648
+ cls: Type[T],
649
+ architecture: str, # noqa: ARG003
650
+ ) -> type[SAEConfig]:
651
+ return SAEConfig
652
+
653
+
654
+ @dataclass(kw_only=True)
655
+ class TrainingSAEConfig(SAEConfig):
656
+ # Sparsity Loss Calculations
657
+ l1_coefficient: float
658
+ lp_norm: float
659
+ use_ghost_grads: bool
660
+ normalize_sae_decoder: bool
661
+ noise_scale: float
662
+ decoder_orthogonal_init: bool
663
+ mse_loss_normalization: str | None
664
+ jumprelu_init_threshold: float
665
+ jumprelu_bandwidth: float
666
+ decoder_heuristic_init: bool
667
+ init_encoder_as_decoder_transpose: bool
668
+ scale_sparsity_penalty_by_decoder_norm: bool
669
+
670
+ @classmethod
671
+ def from_sae_runner_config(
672
+ cls, cfg: LanguageModelSAERunnerConfig
673
+ ) -> "TrainingSAEConfig":
674
+ return cls(
675
+ # base config
676
+ architecture=cfg.architecture,
677
+ d_in=cfg.d_in,
678
+ d_sae=cfg.d_sae, # type: ignore
679
+ dtype=cfg.dtype,
680
+ device=cfg.device,
681
+ model_name=cfg.model_name,
682
+ hook_name=cfg.hook_name,
683
+ hook_layer=cfg.hook_layer,
684
+ hook_head_index=cfg.hook_head_index,
685
+ activation_fn=cfg.activation_fn,
686
+ activation_fn_kwargs=cfg.activation_fn_kwargs,
687
+ apply_b_dec_to_input=cfg.apply_b_dec_to_input,
688
+ finetuning_scaling_factor=cfg.finetuning_method is not None,
689
+ sae_lens_training_version=cfg.sae_lens_training_version,
690
+ context_size=cfg.context_size,
691
+ dataset_path=cfg.dataset_path,
692
+ prepend_bos=cfg.prepend_bos,
693
+ seqpos_slice=tuple(x for x in cfg.seqpos_slice if x is not None)
694
+ if cfg.seqpos_slice is not None
695
+ else None,
696
+ # Training cfg
697
+ l1_coefficient=cfg.l1_coefficient,
698
+ lp_norm=cfg.lp_norm,
699
+ use_ghost_grads=cfg.use_ghost_grads,
700
+ normalize_sae_decoder=cfg.normalize_sae_decoder,
701
+ noise_scale=cfg.noise_scale,
702
+ decoder_orthogonal_init=cfg.decoder_orthogonal_init,
703
+ mse_loss_normalization=cfg.mse_loss_normalization,
704
+ decoder_heuristic_init=cfg.decoder_heuristic_init,
705
+ init_encoder_as_decoder_transpose=cfg.init_encoder_as_decoder_transpose,
706
+ scale_sparsity_penalty_by_decoder_norm=cfg.scale_sparsity_penalty_by_decoder_norm,
707
+ normalize_activations=cfg.normalize_activations,
708
+ dataset_trust_remote_code=cfg.dataset_trust_remote_code,
709
+ model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
710
+ jumprelu_init_threshold=cfg.jumprelu_init_threshold,
711
+ jumprelu_bandwidth=cfg.jumprelu_bandwidth,
712
+ )
713
+
714
+ @classmethod
715
+ def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig":
716
+ # remove any keys that are not in the dataclass
717
+ # since we sometimes enhance the config with the whole LM runner config
718
+ valid_field_names = {field.name for field in fields(cls)}
719
+ valid_config_dict = {
720
+ key: val for key, val in config_dict.items() if key in valid_field_names
721
+ }
722
+
723
+ # Ensure seqpos_slice is a tuple
726
+ if "seqpos_slice" in valid_config_dict:
727
+ if isinstance(valid_config_dict["seqpos_slice"], list):
728
+ valid_config_dict["seqpos_slice"] = tuple(
729
+ valid_config_dict["seqpos_slice"]
730
+ )
731
+ elif not isinstance(valid_config_dict["seqpos_slice"], tuple):
732
+ valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],)
733
+
734
+ return TrainingSAEConfig(**valid_config_dict)
735
+
736
+ def to_dict(self) -> dict[str, Any]:
737
+ return {
738
+ **super().to_dict(),
739
+ "l1_coefficient": self.l1_coefficient,
740
+ "lp_norm": self.lp_norm,
741
+ "use_ghost_grads": self.use_ghost_grads,
742
+ "normalize_sae_decoder": self.normalize_sae_decoder,
743
+ "noise_scale": self.noise_scale,
744
+ "decoder_orthogonal_init": self.decoder_orthogonal_init,
745
+ "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
746
+ "mse_loss_normalization": self.mse_loss_normalization,
747
+ "decoder_heuristic_init": self.decoder_heuristic_init,
748
+ "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
749
+ "normalize_activations": self.normalize_activations,
750
+ "jumprelu_init_threshold": self.jumprelu_init_threshold,
751
+ "jumprelu_bandwidth": self.jumprelu_bandwidth,
752
+ }
753
+
754
+ # this needs to exist so we can initialize the parent sae cfg without the training specific
755
+ # parameters. Maybe there's a cleaner way to do this
756
+ def get_base_sae_cfg_dict(self) -> dict[str, Any]:
757
+ return {
758
+ "architecture": self.architecture,
759
+ "d_in": self.d_in,
760
+ "d_sae": self.d_sae,
761
+ "activation_fn": self.activation_fn,
762
+ "activation_fn_kwargs": self.activation_fn_kwargs,
763
+ "apply_b_dec_to_input": self.apply_b_dec_to_input,
764
+ "dtype": self.dtype,
765
+ "model_name": self.model_name,
766
+ "hook_name": self.hook_name,
767
+ "hook_layer": self.hook_layer,
768
+ "hook_head_index": self.hook_head_index,
769
+ "device": self.device,
770
+ "context_size": self.context_size,
771
+ "prepend_bos": self.prepend_bos,
772
+ "finetuning_scaling_factor": self.finetuning_scaling_factor,
773
+ "normalize_activations": self.normalize_activations,
774
+ "dataset_path": self.dataset_path,
775
+ "dataset_trust_remote_code": self.dataset_trust_remote_code,
776
+ "sae_lens_training_version": self.sae_lens_training_version,
777
+ "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
778
+ "seqpos_slice": self.seqpos_slice,
779
+ "neuronpedia_id": self.neuronpedia_id,
780
+ }
781
+
782
+
783
+ class TrainingSAE(SAE, ABC):
784
+ """Abstract base class for training versions of SAEs."""
785
+
786
+ cfg: "TrainingSAEConfig" # type: ignore
787
+
788
+ def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
789
+ super().__init__(cfg, use_error_term)
790
+
791
+ # Turn off hook_z reshaping for training mode - the activation store
792
+ # is expected to handle reshaping before passing data to the SAE
793
+ self.turn_off_forward_pass_hook_z_reshaping()
794
+
795
+ self.mse_loss_fn = self._get_mse_loss_fn()
796
+
797
+ @abstractmethod
798
+ def encode_with_hidden_pre(
799
+ self, x: Float[torch.Tensor, "... d_in"]
800
+ ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
801
+ """Encode with access to pre-activation values for training."""
802
+ pass
803
+
804
+ def encode(
805
+ self, x: Float[torch.Tensor, "... d_in"]
806
+ ) -> Float[torch.Tensor, "... d_sae"]:
807
+ """
808
+ For inference, just encode without returning hidden_pre.
809
+ (training_forward_pass calls encode_with_hidden_pre).
810
+ """
811
+ feature_acts, _ = self.encode_with_hidden_pre(x)
812
+ return feature_acts
813
+
814
+ def decode(
815
+ self, feature_acts: Float[torch.Tensor, "... d_sae"]
816
+ ) -> Float[torch.Tensor, "... d_in"]:
817
+ """
818
+ Decodes feature activations back into input space,
819
+ applying optional finetuning scale, hooking, out normalization, etc.
820
+ """
821
+ scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
822
+ sae_out_pre = scaled_features @ self.W_dec + self.b_dec
823
+ sae_out_pre = self.hook_sae_recons(sae_out_pre)
824
+ sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
825
+ return self.reshape_fn_out(sae_out_pre, self.d_head)
826
+
827
+ @abstractmethod
828
+ def calculate_aux_loss(
829
+ self,
830
+ step_input: TrainStepInput,
831
+ feature_acts: torch.Tensor,
832
+ hidden_pre: torch.Tensor,
833
+ sae_out: torch.Tensor,
834
+ ) -> torch.Tensor | dict[str, torch.Tensor]:
835
+ """Calculate architecture-specific auxiliary loss terms."""
836
+ pass
837
+
838
+ def training_forward_pass(
839
+ self,
840
+ step_input: TrainStepInput,
841
+ ) -> TrainStepOutput:
842
+ """Forward pass during training."""
843
+ feature_acts, hidden_pre = self.encode_with_hidden_pre(step_input.sae_in)
844
+ sae_out = self.decode(feature_acts)
845
+
846
+ # Calculate MSE loss
847
+ per_item_mse_loss = self.mse_loss_fn(sae_out, step_input.sae_in)
848
+ mse_loss = per_item_mse_loss.sum(dim=-1).mean()
849
+
850
+ # Calculate architecture-specific auxiliary losses
851
+ aux_losses = self.calculate_aux_loss(
852
+ step_input=step_input,
853
+ feature_acts=feature_acts,
854
+ hidden_pre=hidden_pre,
855
+ sae_out=sae_out,
856
+ )
857
+
858
+ # Total loss is MSE plus all auxiliary losses
859
+ total_loss = mse_loss
860
+
861
+ # Create losses dictionary with mse_loss
862
+ losses = {"mse_loss": mse_loss}
863
+
864
+ # Add architecture-specific losses to the dictionary
865
+ # Make sure aux_losses is a dictionary with string keys and tensor values
866
+ if isinstance(aux_losses, dict):
867
+ losses.update(aux_losses)
868
+
869
+ # Sum all losses for total_loss
870
+ if isinstance(aux_losses, dict):
871
+ for loss_value in aux_losses.values():
872
+ total_loss = total_loss + loss_value
873
+ else:
874
+ # Handle case where aux_losses is a tensor
875
+ total_loss = total_loss + aux_losses
876
+
877
+ return TrainStepOutput(
878
+ sae_in=step_input.sae_in,
879
+ sae_out=sae_out,
880
+ feature_acts=feature_acts,
881
+ hidden_pre=hidden_pre,
882
+ loss=total_loss,
883
+ losses=losses,
884
+ )
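
# --- Usage sketch (not part of sae_lens; for illustration only) ---
# A hypothetical minimal TrainingSAE subclass wired into training_forward_pass.
# "TinyTrainingSAE" and every config value below are placeholders invented for
# this sketch; real trainable architectures live in the sae_lens registry.
import torch
from torch import nn

from sae_lens.saes.sae import TrainStepInput, TrainingSAE, TrainingSAEConfig


class TinyTrainingSAE(TrainingSAE):
    def initialize_weights(self) -> None:
        self.W_enc = nn.Parameter(0.1 * torch.randn(self.cfg.d_in, self.cfg.d_sae))
        self.W_dec = nn.Parameter(0.1 * torch.randn(self.cfg.d_sae, self.cfg.d_in))
        self.b_enc = nn.Parameter(torch.zeros(self.cfg.d_sae))
        self.b_dec = nn.Parameter(torch.zeros(self.cfg.d_in))

    def encode_with_hidden_pre(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        sae_in = self.process_sae_in(x)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
        return self.hook_sae_acts_post(self.activation_fn(hidden_pre)), hidden_pre

    def calculate_aux_loss(self, step_input, feature_acts, hidden_pre, sae_out):
        # A plain L1 sparsity penalty, scaled by the current coefficient.
        return {
            "l1_loss": step_input.current_l1_coefficient
            * feature_acts.abs().sum(dim=-1).mean()
        }


training_cfg = TrainingSAEConfig(
    architecture="standard", d_in=16, d_sae=32, dtype="float32", device="cpu",
    model_name="demo-model", hook_name="blocks.0.hook_resid_pre", hook_layer=0,
    hook_head_index=None, activation_fn="relu", activation_fn_kwargs={},
    apply_b_dec_to_input=True, finetuning_scaling_factor=False,
    normalize_activations="none", context_size=128, dataset_path="",
    dataset_trust_remote_code=False, sae_lens_training_version="6.0.0rc1",
    l1_coefficient=1e-3, lp_norm=1.0, use_ghost_grads=False,
    normalize_sae_decoder=False, noise_scale=0.0, decoder_orthogonal_init=False,
    mse_loss_normalization=None, jumprelu_init_threshold=0.001,
    jumprelu_bandwidth=0.001, decoder_heuristic_init=False,
    init_encoder_as_decoder_transpose=False,
    scale_sparsity_penalty_by_decoder_norm=False,
)
training_sae = TinyTrainingSAE(training_cfg)
step = TrainStepInput(
    sae_in=torch.randn(8, 16), current_l1_coefficient=1e-3, dead_neuron_mask=None
)
output = training_sae.training_forward_pass(step)
output.loss.backward()                       # loss == mse_loss + l1_loss
print({name: value.item() for name, value in output.losses.items()})
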
885
+
886
+ def _get_mse_loss_fn(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
887
+ """Get the MSE loss function based on config."""
888
+
889
+ def standard_mse_loss_fn(
890
+ preds: torch.Tensor, target: torch.Tensor
891
+ ) -> torch.Tensor:
892
+ return torch.nn.functional.mse_loss(preds, target, reduction="none")
893
+
894
+ def batch_norm_mse_loss_fn(
895
+ preds: torch.Tensor, target: torch.Tensor
896
+ ) -> torch.Tensor:
897
+ target_centered = target - target.mean(dim=0, keepdim=True)
898
+ normalization = target_centered.norm(dim=-1, keepdim=True)
899
+ return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
900
+ normalization + 1e-6
901
+ )
902
+
903
+ if self.cfg.mse_loss_normalization == "dense_batch":
904
+ return batch_norm_mse_loss_fn
905
+ return standard_mse_loss_fn
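
# --- Usage sketch (not part of sae_lens; for illustration only) ---
# Both loss functions above return an unreduced, per-element tensor; the
# reduction to a scalar happens in training_forward_pass (sum over d_in, mean
# over the batch). The "dense_batch" variant additionally divides by the norm
# of the batch-centered target so large-magnitude examples do not dominate.
import torch

preds, target = torch.randn(8, 16), torch.randn(8, 16)
per_element = torch.nn.functional.mse_loss(preds, target, reduction="none")
assert per_element.shape == (8, 16)
scalar_mse = per_element.sum(dim=-1).mean()   # the reduction used in training_forward_pass
assert scalar_mse.ndim == 0
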
906
+
907
+ @torch.no_grad()
908
+ def remove_gradient_parallel_to_decoder_directions(self) -> None:
909
+ """Remove gradient components parallel to decoder directions."""
910
+ # Remove the component of the gradient parallel to each decoder row
911
+ assert self.W_dec.grad is not None
912
+
913
+ parallel_component = einops.einsum(
914
+ self.W_dec.grad,
915
+ self.W_dec.data,
916
+ "d_sae d_in, d_sae d_in -> d_sae",
917
+ )
918
+ self.W_dec.grad -= einops.einsum(
919
+ parallel_component,
920
+ self.W_dec.data,
921
+ "d_sae, d_sae d_in -> d_sae d_in",
922
+ )
923
+
924
+ @torch.no_grad()
925
+ def set_decoder_norm_to_unit_norm(self):
926
+ """Normalize decoder columns to unit norm."""
927
+ self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
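
# --- Usage sketch (not part of sae_lens; for illustration only) ---
# How the two methods above work together when the decoder is kept unit-norm:
# once decoder rows have unit norm, subtracting the parallel component leaves a
# gradient orthogonal to each decoder direction, so gradient steps do not
# change the row norms to first order. A standalone check with placeholder
# shapes:
import einops
import torch

W_dec = torch.randn(32, 16)
W_dec = W_dec / W_dec.norm(dim=-1, keepdim=True)        # unit-norm rows
grad = torch.randn(32, 16)
parallel = einops.einsum(grad, W_dec, "d_sae d_in, d_sae d_in -> d_sae")
grad = grad - einops.einsum(parallel, W_dec, "d_sae, d_sae d_in -> d_sae d_in")
assert torch.allclose(
    einops.einsum(grad, W_dec, "d_sae d_in, d_sae d_in -> d_sae"),
    torch.zeros(32),
    atol=1e-5,
)
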
928
+
929
+ @torch.no_grad()
930
+ def log_histograms(self) -> dict[str, NDArray[Any]]:
931
+ """Log histograms of the weights and biases."""
932
+ W_dec_norm_dist = self.W_dec.detach().float().norm(dim=1).cpu().numpy()
933
+ return {
934
+ "weights/W_dec_norms": W_dec_norm_dist,
935
+ }
936
+
937
+ @classmethod
938
+ def get_sae_class_for_architecture(cls: Type[T], architecture: str) -> Type[T]:
939
+ """Get the SAE class for a given architecture."""
940
+ sae_cls = get_sae_training_class(architecture)
941
+ if not issubclass(sae_cls, cls):
942
+ raise ValueError(
943
+ f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
944
+ )
945
+ return sae_cls
946
+
947
+ # in the future, this can be used to load different config classes for different architectures
948
+ @classmethod
949
+ def get_sae_config_class_for_architecture(
950
+ cls: Type[T],
951
+ architecture: str, # noqa: ARG003
952
+ ) -> type[SAEConfig]:
953
+ return TrainingSAEConfig
954
+
955
+
956
+ _blank_hook = nn.Identity()
957
+
958
+
959
+ @contextmanager
960
+ def _disable_hooks(sae: SAE):
961
+ """
962
+ Temporarily disable hooks for the SAE. Swaps out all the hooks with a no-op module that does nothing.
963
+ """
964
+ try:
965
+ for hook_name in sae.hook_dict:
966
+ setattr(sae, hook_name, _blank_hook)
967
+ yield
968
+ finally:
969
+ for hook_name, hook in sae.hook_dict.items():
970
+ setattr(sae, hook_name, hook)
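
# --- Usage sketch (not part of sae_lens; for illustration only) ---
# Fragment continuing the hypothetical TinyStandardSAE sketch shown earlier
# (any concrete SAE instance would behave the same way): while the context
# manager is active the SAE's hook points are replaced with nn.Identity, so
# user hooks do not fire.
import torch

from sae_lens.saes.sae import _disable_hooks

sae.use_error_term = False          # keep forward() to a single, hooked pass
calls = []
sae.add_hook("hook_sae_output", lambda acts, hook: calls.append(acts))
with _disable_hooks(sae):
    sae(torch.randn(2, sae.cfg.d_in))
assert len(calls) == 0              # hook was silenced inside the context
sae(torch.randn(2, sae.cfg.d_in))
assert len(calls) == 1              # and fires again once hooks are restored
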