sae-lens 5.10.3__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/saes/sae.py ADDED
@@ -0,0 +1,948 @@
1
+ """Base classes for Sparse Autoencoders (SAEs)."""
2
+
3
+ import json
4
+ import warnings
5
+ from abc import ABC, abstractmethod
6
+ from contextlib import contextmanager
7
+ from dataclasses import asdict, dataclass, field, fields, replace
8
+ from pathlib import Path
9
+ from typing import (
10
+ TYPE_CHECKING,
11
+ Any,
12
+ Callable,
13
+ Generic,
14
+ Literal,
15
+ NamedTuple,
16
+ Type,
17
+ TypeVar,
18
+ )
19
+
20
+ import einops
21
+ import torch
22
+ from jaxtyping import Float
23
+ from numpy.typing import NDArray
24
+ from safetensors.torch import save_file
25
+ from torch import nn
26
+ from transformer_lens.hook_points import HookedRootModule, HookPoint
27
+ from typing_extensions import deprecated, overload, override
28
+
29
+ from sae_lens import __version__, logger
30
+ from sae_lens.constants import (
31
+ DTYPE_MAP,
32
+ SAE_CFG_FILENAME,
33
+ SAE_WEIGHTS_FILENAME,
34
+ )
35
+ from sae_lens.util import filter_valid_dataclass_fields
36
+
37
+ if TYPE_CHECKING:
38
+ from sae_lens.config import LanguageModelSAERunnerConfig
39
+
40
+ from sae_lens.loading.pretrained_sae_loaders import (
41
+ NAMED_PRETRAINED_SAE_LOADERS,
42
+ PretrainedSaeDiskLoader,
43
+ PretrainedSaeHuggingfaceLoader,
44
+ get_conversion_loader_name,
45
+ handle_config_defaulting,
46
+ sae_lens_disk_loader,
47
+ )
48
+ from sae_lens.loading.pretrained_saes_directory import (
49
+ get_config_overrides,
50
+ get_norm_scaling_factor,
51
+ get_pretrained_saes_directory,
52
+ get_repo_id_and_folder_name,
53
+ )
54
+ from sae_lens.registry import get_sae_class, get_sae_training_class
55
+
56
+ T_SAE_CONFIG = TypeVar("T_SAE_CONFIG", bound="SAEConfig")
57
+ T_TRAINING_SAE_CONFIG = TypeVar("T_TRAINING_SAE_CONFIG", bound="TrainingSAEConfig")
58
+ T_SAE = TypeVar("T_SAE", bound="SAE") # type: ignore
59
+ T_TRAINING_SAE = TypeVar("T_TRAINING_SAE", bound="TrainingSAE") # type: ignore
60
+
61
+
62
+ @dataclass
63
+ class SAEMetadata:
64
+ """Core metadata about how this SAE should be used, if known."""
65
+
66
+ model_name: str | None = None
67
+ hook_name: str | None = None
68
+ model_class_name: str | None = None
69
+ hook_layer: int | None = None
70
+ hook_head_index: int | None = None
71
+ model_from_pretrained_kwargs: dict[str, Any] | None = None
72
+ prepend_bos: bool | None = None
73
+ exclude_special_tokens: bool | list[int] | None = None
74
+ neuronpedia_id: str | None = None
75
+ context_size: int | None = None
76
+ seqpos_slice: tuple[int | None, ...] | None = None
77
+ dataset_path: str | None = None
78
+ sae_lens_version: str = field(default_factory=lambda: __version__)
79
+ sae_lens_training_version: str = field(default_factory=lambda: __version__)
80
+
81
+
82
+ @dataclass
83
+ class SAEConfig(ABC):
84
+ """Base configuration for SAE models."""
85
+
86
+ d_in: int
87
+ d_sae: int
88
+ dtype: str = "float32"
89
+ device: str = "cpu"
90
+ apply_b_dec_to_input: bool = True
91
+ normalize_activations: Literal[
92
+ "none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"
93
+ ] = "none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
94
+ reshape_activations: Literal["none", "hook_z"] = "none"
95
+ metadata: SAEMetadata = field(default_factory=SAEMetadata)
96
+
97
+ @classmethod
98
+ @abstractmethod
99
+ def architecture(cls) -> str: ...
100
+
101
+ def to_dict(self) -> dict[str, Any]:
102
+ res = {field.name: getattr(self, field.name) for field in fields(self)}
103
+ res["metadata"] = asdict(self.metadata)
104
+ res["architecture"] = self.architecture()
105
+ return res
106
+
107
+ @classmethod
108
+ def from_dict(cls: type[T_SAE_CONFIG], config_dict: dict[str, Any]) -> T_SAE_CONFIG:
109
+ cfg_class = get_sae_class(config_dict["architecture"])[1]
110
+ filtered_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
111
+ res = cfg_class(**filtered_config_dict)
112
+ if "metadata" in config_dict:
113
+ res.metadata = SAEMetadata(**config_dict["metadata"])
114
+ if not isinstance(res, cls):
115
+ raise ValueError(
116
+ f"SAE config class {cls} does not match dict config class {type(res)}"
117
+ )
118
+ return res
119
+
120
+ def __post_init__(self):
121
+ if self.normalize_activations not in [
122
+ "none",
123
+ "expected_average_only_in",
124
+ "constant_norm_rescale",
125
+ "layer_norm",
126
+ ]:
127
+ raise ValueError(
128
+ f"normalize_activations must be none, expected_average_only_in, constant_norm_rescale, or layer_norm. Got {self.normalize_activations}"
129
+ )
130
+
131
+
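To make the abstract config concrete: a subclass only has to name its architecture, and from_dict() then dispatches on that string via the registry. A minimal sketch, assuming a hypothetical architecture name and class that are not part of this release:

from dataclasses import dataclass

from sae_lens.saes.sae import SAEConfig


@dataclass
class ExampleSAEConfig(SAEConfig):
    """Hypothetical concrete config for a plain ReLU SAE (illustrative only)."""

    @classmethod
    def architecture(cls) -> str:
        # from_dict() resolves this string through sae_lens.registry, so the
        # architecture is assumed to be registered there before round-tripping.
        return "example"


cfg = ExampleSAEConfig(d_in=768, d_sae=768 * 16)
assert cfg.to_dict()["architecture"] == "example"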
132
+ @dataclass
133
+ class TrainStepOutput:
134
+ """Output from a training step."""
135
+
136
+ sae_in: torch.Tensor
137
+ sae_out: torch.Tensor
138
+ feature_acts: torch.Tensor
139
+ hidden_pre: torch.Tensor
140
+ loss: torch.Tensor  # total loss; the training loop calls backward() on this
141
+ losses: dict[str, torch.Tensor]
142
+
143
+
144
+ @dataclass
145
+ class TrainStepInput:
146
+ """Input to a training step."""
147
+
148
+ sae_in: torch.Tensor
149
+ coefficients: dict[str, float]
150
+ dead_neuron_mask: torch.Tensor | None
151
+
152
+
153
+ class TrainCoefficientConfig(NamedTuple):
154
+ value: float
155
+ warm_up_steps: int
156
+
157
+
158
+ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
159
+ """Abstract base class for all SAE architectures."""
160
+
161
+ cfg: T_SAE_CONFIG
162
+ dtype: torch.dtype
163
+ device: torch.device
164
+ use_error_term: bool
165
+
166
+ # For type checking only - don't provide default values
167
+ # These will be initialized by subclasses
168
+ W_enc: nn.Parameter
169
+ W_dec: nn.Parameter
170
+ b_dec: nn.Parameter
171
+
172
+ def __init__(self, cfg: T_SAE_CONFIG, use_error_term: bool = False):
173
+ """Initialize the SAE."""
174
+ super().__init__()
175
+
176
+ self.cfg = cfg
177
+
178
+ if cfg.metadata and cfg.metadata.model_from_pretrained_kwargs:
179
+ warnings.warn(
180
+ "\nThis SAE has non-empty model_from_pretrained_kwargs. "
181
+ "\nFor optimal performance, load the model like so:\n"
182
+ "model = HookedSAETransformer.from_pretrained_no_processing(..., **cfg.model_from_pretrained_kwargs)",
183
+ category=UserWarning,
184
+ stacklevel=1,
185
+ )
186
+
187
+ self.dtype = DTYPE_MAP[cfg.dtype]
188
+ self.device = torch.device(cfg.device)
189
+ self.use_error_term = use_error_term
190
+
191
+ # Set up activation function
192
+ self.activation_fn = self.get_activation_fn()
193
+
194
+ # Initialize weights
195
+ self.initialize_weights()
196
+
197
+ # Set up hooks
198
+ self.hook_sae_input = HookPoint()
199
+ self.hook_sae_acts_pre = HookPoint()
200
+ self.hook_sae_acts_post = HookPoint()
201
+ self.hook_sae_output = HookPoint()
202
+ self.hook_sae_recons = HookPoint()
203
+ self.hook_sae_error = HookPoint()
204
+
205
+ # handle hook_z reshaping if needed.
206
+ if self.cfg.reshape_activations == "hook_z":
207
+ self.turn_on_forward_pass_hook_z_reshaping()
208
+ else:
209
+ self.turn_off_forward_pass_hook_z_reshaping()
210
+
211
+ # Set up activation normalization
212
+ self._setup_activation_normalization()
213
+
214
+ self.setup() # Required for HookedRootModule
215
+
216
+ @torch.no_grad()
217
+ def fold_activation_norm_scaling_factor(self, scaling_factor: float):
218
+ self.W_enc.data *= scaling_factor # type: ignore
219
+ self.W_dec.data /= scaling_factor # type: ignore
220
+ self.b_dec.data /= scaling_factor # type: ignore
221
+ self.cfg.normalize_activations = "none"
222
+
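The folding above works because scaling the encoder weights is equivalent to scaling the incoming activations, while dividing the decoder weights and bias undoes the scaling on the output side. A standalone check of the encoder half of that identity (a sketch, not package code):

import torch

k = 3.0                      # stand-in for the activation norm scaling factor
x = torch.randn(4, 16)       # unscaled activations
W_enc = torch.randn(16, 64)
# Encoding activations that were pre-scaled by k with the original weights equals
# encoding the raw activations with the folded weights (W_enc * k).
torch.testing.assert_close((k * x) @ W_enc, x @ (W_enc * k), rtol=1e-5, atol=1e-5)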
223
+ def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
224
+ """Get the activation function specified in config."""
225
+ return nn.ReLU()
226
+
227
+ def _setup_activation_normalization(self):
228
+ """Set up activation normalization functions based on config."""
229
+ if self.cfg.normalize_activations == "constant_norm_rescale":
230
+
231
+ def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor:
232
+ self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True)
233
+ return x * self.x_norm_coeff
234
+
235
+ def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor:
236
+ x = x / self.x_norm_coeff # type: ignore
237
+ del self.x_norm_coeff
238
+ return x
239
+
240
+ self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
241
+ self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
242
+
243
+ elif self.cfg.normalize_activations == "layer_norm":
244
+
245
+ def run_time_activation_ln_in(
246
+ x: torch.Tensor, eps: float = 1e-5
247
+ ) -> torch.Tensor:
248
+ mu = x.mean(dim=-1, keepdim=True)
249
+ x = x - mu
250
+ std = x.std(dim=-1, keepdim=True)
251
+ x = x / (std + eps)
252
+ self.ln_mu = mu
253
+ self.ln_std = std
254
+ return x
255
+
256
+ def run_time_activation_ln_out(
257
+ x: torch.Tensor,
258
+ eps: float = 1e-5, # noqa: ARG001
259
+ ) -> torch.Tensor:
260
+ return x * self.ln_std + self.ln_mu # type: ignore
261
+
262
+ self.run_time_activation_norm_fn_in = run_time_activation_ln_in
263
+ self.run_time_activation_norm_fn_out = run_time_activation_ln_out
264
+ else:
265
+ self.run_time_activation_norm_fn_in = lambda x: x
266
+ self.run_time_activation_norm_fn_out = lambda x: x
267
+
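The "constant_norm_rescale" pair above rescales each input to norm sqrt(d_in) before encoding, and the matching output function divides by the same coefficient so reconstructions return to the original scale. A standalone round-trip check (illustrative only):

import torch

d_in = 16
x = torch.randn(4, d_in)
coeff = (d_in**0.5) / x.norm(dim=-1, keepdim=True)  # run_time_activation_norm_fn_in's coefficient
x_scaled = x * coeff                                # every row now has norm sqrt(d_in)
torch.testing.assert_close(x_scaled.norm(dim=-1), torch.full((4,), d_in**0.5))
torch.testing.assert_close(x_scaled / coeff, x)     # the *_out function restores the input scale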
268
+ def initialize_weights(self):
269
+ """Initialize model weights."""
270
+ self.b_dec = nn.Parameter(
271
+ torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
272
+ )
273
+
274
+ w_dec_data = torch.empty(
275
+ self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
276
+ )
277
+ nn.init.kaiming_uniform_(w_dec_data)
278
+ self.W_dec = nn.Parameter(w_dec_data)
279
+
280
+ w_enc_data = self.W_dec.data.T.clone().detach().contiguous()
281
+ self.W_enc = nn.Parameter(w_enc_data)
282
+
283
+ @abstractmethod
284
+ def encode(
285
+ self, x: Float[torch.Tensor, "... d_in"]
286
+ ) -> Float[torch.Tensor, "... d_sae"]:
287
+ """Encode input tensor to feature space."""
288
+ pass
289
+
290
+ @abstractmethod
291
+ def decode(
292
+ self, feature_acts: Float[torch.Tensor, "... d_sae"]
293
+ ) -> Float[torch.Tensor, "... d_in"]:
294
+ """Decode feature activations back to input space."""
295
+ pass
296
+
297
+ def turn_on_forward_pass_hook_z_reshaping(self):
298
+ if (
299
+ self.cfg.metadata.hook_name is not None
300
+ and not self.cfg.metadata.hook_name.endswith("_z")
301
+ ):
302
+ raise ValueError("This method should only be called for hook_z SAEs.")
303
+
304
+ # print(f"Turning on hook_z reshaping for {self.cfg.hook_name}")
305
+
306
+ def reshape_fn_in(x: torch.Tensor):
307
+ # print(f"reshape_fn_in input shape: {x.shape}")
308
+ self.d_head = x.shape[-1]
309
+ # print(f"Setting d_head to: {self.d_head}")
310
+ self.reshape_fn_in = lambda x: einops.rearrange(
311
+ x, "... n_heads d_head -> ... (n_heads d_head)"
312
+ )
313
+ return einops.rearrange(x, "... n_heads d_head -> ... (n_heads d_head)")
314
+
315
+ self.reshape_fn_in = reshape_fn_in
316
+ self.reshape_fn_out = lambda x, d_head: einops.rearrange(
317
+ x, "... (n_heads d_head) -> ... n_heads d_head", d_head=d_head
318
+ )
319
+ self.hook_z_reshaping_mode = True
320
+ # print(f"hook_z reshaping turned on, self.d_head={getattr(self, 'd_head', None)}")
321
+
322
+ def turn_off_forward_pass_hook_z_reshaping(self):
323
+ self.reshape_fn_in = lambda x: x
324
+ self.reshape_fn_out = lambda x, d_head: x # noqa: ARG005
325
+ self.d_head = None
326
+ self.hook_z_reshaping_mode = False
327
+
328
+ @overload
329
+ def to(
330
+ self: T_SAE,
331
+ device: torch.device | str | None = ...,
332
+ dtype: torch.dtype | None = ...,
333
+ non_blocking: bool = ...,
334
+ ) -> T_SAE: ...
335
+
336
+ @overload
337
+ def to(self: T_SAE, dtype: torch.dtype, non_blocking: bool = ...) -> T_SAE: ...
338
+
339
+ @overload
340
+ def to(self: T_SAE, tensor: torch.Tensor, non_blocking: bool = ...) -> T_SAE: ...
341
+
342
+ def to(self: T_SAE, *args: Any, **kwargs: Any) -> T_SAE: # type: ignore
343
+ device_arg = None
344
+ dtype_arg = None
345
+
346
+ # Check args
347
+ for arg in args:
348
+ if isinstance(arg, (torch.device, str)):
349
+ device_arg = arg
350
+ elif isinstance(arg, torch.dtype):
351
+ dtype_arg = arg
352
+ elif isinstance(arg, torch.Tensor):
353
+ device_arg = arg.device
354
+ dtype_arg = arg.dtype
355
+
356
+ # Check kwargs
357
+ device_arg = kwargs.get("device", device_arg)
358
+ dtype_arg = kwargs.get("dtype", dtype_arg)
359
+
360
+ # Update device in config if provided
361
+ if device_arg is not None:
362
+ # Convert device to torch.device if it's a string
363
+ device = (
364
+ torch.device(device_arg) if isinstance(device_arg, str) else device_arg
365
+ )
366
+
367
+ # Update the cfg.device
368
+ self.cfg.device = str(device)
369
+
370
+ # Update the device property
371
+ self.device = device
372
+
373
+ # Update dtype in config if provided
374
+ if dtype_arg is not None:
375
+ # Update the cfg.dtype
376
+ self.cfg.dtype = str(dtype_arg)
377
+
378
+ # Update the dtype property
379
+ self.dtype = dtype_arg
380
+
381
+ return super().to(*args, **kwargs)
382
+
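The override above keeps cfg.device and cfg.dtype in sync with wherever the parameters are actually moved. A brief usage sketch, where my_sae stands in for any concrete SAE instance and an available CUDA device is assumed:

import torch

my_sae = my_sae.to("cuda", torch.float16)   # moves parameters and updates the config
print(my_sae.cfg.device, my_sae.cfg.dtype)  # now reflects the new device and dtype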
383
+ def process_sae_in(
384
+ self, sae_in: Float[torch.Tensor, "... d_in"]
385
+ ) -> Float[torch.Tensor, "... d_in"]:
386
+ # print(f"Input shape to process_sae_in: {sae_in.shape}")
387
+ # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
388
+ # print(f"self.b_dec shape: {self.b_dec.shape}")
389
+ # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")
390
+
391
+ sae_in = sae_in.to(self.dtype)
392
+
393
+ # print(f"Shape before reshape_fn_in: {sae_in.shape}")
394
+ sae_in = self.reshape_fn_in(sae_in)
395
+ # print(f"Shape after reshape_fn_in: {sae_in.shape}")
396
+
397
+ sae_in = self.hook_sae_input(sae_in)
398
+ sae_in = self.run_time_activation_norm_fn_in(sae_in)
399
+
400
+ # Subtract the decoder bias from the input if apply_b_dec_to_input is set
401
+ bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
402
+ # print(f"Bias term shape: {bias_term.shape}")
403
+
404
+ return sae_in - bias_term
405
+
406
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
407
+ """Forward pass through the SAE."""
408
+ feature_acts = self.encode(x)
409
+ sae_out = self.decode(feature_acts)
410
+
411
+ if self.use_error_term:
412
+ with torch.no_grad():
413
+ # Recompute without hooks for true error term
414
+ with _disable_hooks(self):
415
+ feature_acts_clean = self.encode(x)
416
+ x_reconstruct_clean = self.decode(feature_acts_clean)
417
+ sae_error = self.hook_sae_error(x - x_reconstruct_clean)
418
+ sae_out = sae_out + sae_error
419
+
420
+ return self.hook_sae_output(sae_out)
421
+
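With use_error_term enabled, the forward pass returns the reconstruction plus the reconstruction error, so its output is numerically identical to its input; this is what makes error-term SAEs transparent when spliced into a model while still exposing feature activations through the hooks. A tiny standalone illustration of the identity (the tensors are stand-ins, not package code):

import torch

x = torch.randn(8, 16)        # stand-in for the SAE input
recons = torch.randn(8, 16)   # stand-in for decode(encode(x))
sae_error = x - recons        # the quantity seen by hook_sae_error
torch.testing.assert_close(recons + sae_error, x)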
422
+ # override this in subclasses to modify the state_dict in-place before saving
423
+ def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
424
+ pass
425
+
426
+ # override this in subclasses to modify the state_dict in-place after loading
427
+ def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
428
+ pass
429
+
430
+ @torch.no_grad()
431
+ def fold_W_dec_norm(self):
432
+ """Fold decoder norms into encoder."""
433
+ W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
434
+ self.W_dec.data = self.W_dec.data / W_dec_norms
435
+ self.W_enc.data = self.W_enc.data * W_dec_norms.T
436
+
437
+ # Only update b_enc if it exists (standard/jumprelu architectures)
438
+ if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter):
439
+ self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()
440
+
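Folding the decoder norms into the encoder leaves the composed encoder-decoder map unchanged, because scaling encoder column j by ||W_dec[j]|| cancels against dividing decoder row j by the same norm. A standalone check of that identity for the linear part (illustrative only):

import torch

W_dec = torch.randn(64, 16)                 # (d_sae, d_in)
W_enc = torch.randn(16, 64)                 # (d_in, d_sae)
norms = W_dec.norm(dim=-1, keepdim=True)    # (d_sae, 1)
W_dec_folded = W_dec / norms
W_enc_folded = W_enc * norms.T
x = torch.randn(4, 16)
torch.testing.assert_close(
    x @ W_enc_folded @ W_dec_folded, x @ W_enc @ W_dec, rtol=1e-4, atol=1e-4
)

For ReLU-style encoders the bias is scaled the same way (as in the method above), and because the norms are positive the nonlinearity commutes with the per-feature scaling, so feature activations rescale consistently as well.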
441
+ def get_name(self):
442
+ """Generate a name for this SAE."""
443
+ return f"sae_{self.cfg.metadata.model_name}_{self.cfg.metadata.hook_name}_{self.cfg.d_sae}"
444
+
445
+ def save_model(self, path: str | Path) -> tuple[Path, Path]:
446
+ """Save model weights and config to disk."""
447
+ path = Path(path)
448
+ path.mkdir(parents=True, exist_ok=True)
449
+
450
+ # Generate the weights
451
+ state_dict = self.state_dict() # Use internal SAE state dict
452
+ self.process_state_dict_for_saving(state_dict)
453
+ model_weights_path = path / SAE_WEIGHTS_FILENAME
454
+ save_file(state_dict, model_weights_path)
455
+
456
+ # Save the config
457
+ config = self.cfg.to_dict()
458
+ cfg_path = path / SAE_CFG_FILENAME
459
+ with open(cfg_path, "w") as f:
460
+ json.dump(config, f)
461
+
462
+ return model_weights_path, cfg_path
463
+
464
+ ## Initialization Methods
465
+ @torch.no_grad()
466
+ def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
467
+ out = origin.clone().detach().to(dtype=self.dtype, device=self.device)
468
+ self.b_dec.data = out
469
+
470
+ @torch.no_grad()
471
+ def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
472
+ previous_b_dec = self.b_dec.clone().cpu()
473
+ out = all_activations.mean(dim=0)
474
+
475
+ previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
476
+ distances = torch.norm(all_activations - out, dim=-1)
477
+
478
+ logger.info("Reinitializing b_dec with mean of activations")
479
+ logger.debug(
480
+ f"Previous distances: {previous_distances.median(0).values.mean().item()}"
481
+ )
482
+ logger.debug(f"New distances: {distances.median(0).values.mean().item()}")
483
+
484
+ self.b_dec.data = out.to(self.dtype).to(self.device)
485
+
486
+ # Class methods for loading models
487
+ @classmethod
488
+ @deprecated("Use load_from_disk instead")
489
+ def load_from_pretrained(
490
+ cls: Type[T_SAE],
491
+ path: str | Path,
492
+ device: str = "cpu",
493
+ dtype: str | None = None,
494
+ ) -> T_SAE:
495
+ return cls.load_from_disk(path, device=device, dtype=dtype)
496
+
497
+ @classmethod
498
+ def load_from_disk(
499
+ cls: Type[T_SAE],
500
+ path: str | Path,
501
+ device: str = "cpu",
502
+ dtype: str | None = None,
503
+ converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
504
+ ) -> T_SAE:
505
+ overrides = {"dtype": dtype} if dtype is not None else None
506
+ cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
507
+ cfg_dict = handle_config_defaulting(cfg_dict)
508
+ sae_config_cls = cls.get_sae_config_class_for_architecture(
509
+ cfg_dict["architecture"]
510
+ )
511
+ sae_cfg = sae_config_cls.from_dict(cfg_dict)
512
+ sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
513
+ sae = sae_cls(sae_cfg)
514
+ sae.process_state_dict_for_loading(state_dict)
515
+ sae.load_state_dict(state_dict)
516
+ return sae
517
+
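A round trip through save_model and load_from_disk might look like the following, where my_sae stands in for any concrete SAE instance and the directory path is arbitrary (a usage sketch, not package code):

from pathlib import Path

save_dir = Path("./checkpoints/my_sae")
weights_path, cfg_path = my_sae.save_model(save_dir)            # writes the weights and config files
reloaded = type(my_sae).load_from_disk(save_dir, device="cpu")
assert reloaded.cfg.d_in == my_sae.cfg.d_in and reloaded.cfg.d_sae == my_sae.cfg.d_sae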
518
+ @classmethod
519
+ def from_pretrained(
520
+ cls: Type[T_SAE],
521
+ release: str,
522
+ sae_id: str,
523
+ device: str = "cpu",
524
+ force_download: bool = False,
525
+ converter: PretrainedSaeHuggingfaceLoader | None = None,
526
+ ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
527
+ """
528
+ Load a pretrained SAE from the Hugging Face model hub. Returns the SAE together with its config dict and, if present in the repo, the log sparsity tensor.
529
+
530
+ Args:
531
+ release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
532
+ sae_id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
533
+ device: The device to load the SAE on.
534
+ force_download: If True, re-download the SAE files from the Hugging Face Hub even if they are cached locally.
535
+ """
536
+
537
+ # get sae directory
538
+ sae_directory = get_pretrained_saes_directory()
539
+
540
+ # Validate release and sae_id
541
+ if release not in sae_directory:
542
+ if "/" not in release:
543
+ raise ValueError(
544
+ f"Release {release} not found in pretrained SAEs directory, and is not a valid huggingface repo."
545
+ )
546
+ elif sae_id not in sae_directory[release].saes_map:
547
+ # Handle special cases like Gemma Scope
548
+ if (
549
+ "gemma-scope" in release
550
+ and "canonical" not in release
551
+ and f"{release}-canonical" in sae_directory
552
+ ):
553
+ canonical_ids = list(
554
+ sae_directory[release + "-canonical"].saes_map.keys()
555
+ )
556
+ # Shorten the lengthy string of valid IDs
557
+ if len(canonical_ids) > 5:
558
+ str_canonical_ids = str(canonical_ids[:5])[:-1] + ", ...]"
559
+ else:
560
+ str_canonical_ids = str(canonical_ids)
561
+ value_suffix = f" If you don't want to specify an L0 value, consider using release {release}-canonical which has valid IDs {str_canonical_ids}"
562
+ else:
563
+ value_suffix = ""
564
+
565
+ valid_ids = list(sae_directory[release].saes_map.keys())
566
+ # Shorten the lengthy string of valid IDs
567
+ if len(valid_ids) > 5:
568
+ str_valid_ids = str(valid_ids[:5])[:-1] + ", ...]"
569
+ else:
570
+ str_valid_ids = str(valid_ids)
571
+
572
+ raise ValueError(
573
+ f"ID {sae_id} not found in release {release}. Valid IDs are {str_valid_ids}."
574
+ + value_suffix
575
+ )
576
+
577
+ conversion_loader = (
578
+ converter
579
+ or NAMED_PRETRAINED_SAE_LOADERS[get_conversion_loader_name(release)]
580
+ )
581
+ repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id)
582
+ config_overrides = get_config_overrides(release, sae_id)
583
+ config_overrides["device"] = device
584
+
585
+ # Load config and weights
586
+ cfg_dict, state_dict, log_sparsities = conversion_loader(
587
+ repo_id=repo_id,
588
+ folder_name=folder_name,
589
+ device=device,
590
+ force_download=force_download,
591
+ cfg_overrides=config_overrides,
592
+ )
593
+ cfg_dict = handle_config_defaulting(cfg_dict)
594
+
595
+ # Create SAE with appropriate architecture
596
+ sae_config_cls = cls.get_sae_config_class_for_architecture(
597
+ cfg_dict["architecture"]
598
+ )
599
+ sae_cfg = sae_config_cls.from_dict(cfg_dict)
600
+ sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
601
+ sae = sae_cls(sae_cfg)
602
+ sae.process_state_dict_for_loading(state_dict)
603
+ sae.load_state_dict(state_dict)
604
+
605
+ # Apply normalization if needed
606
+ if cfg_dict.get("normalize_activations") == "expected_average_only_in":
607
+ norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
608
+ if norm_scaling_factor is not None:
609
+ sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
610
+ cfg_dict["normalize_activations"] = "none"
611
+ else:
612
+ warnings.warn(
613
+ f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
614
+ )
615
+
616
+ return sae, cfg_dict, log_sparsities
617
+
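Typical usage of the loader above; the release and SAE id strings are examples of entries in the pretrained SAE directory, and the top-level SAE re-export is assumed to match earlier sae-lens releases:

from sae_lens import SAE

sae, cfg_dict, log_sparsities = SAE.from_pretrained(
    release="gpt2-small-res-jb",          # example release name from pretrained_saes.yaml
    sae_id="blocks.8.hook_resid_pre",     # example SAE id within that release
    device="cpu",
)
print(sae.cfg.d_in, sae.cfg.d_sae, sae.cfg.metadata.hook_name)

log_sparsities is None for releases that do not ship a log sparsity tensor.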
618
+ @classmethod
619
+ def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
620
+ """Create an SAE from a config dictionary."""
621
+ sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
622
+ sae_config_cls = cls.get_sae_config_class_for_architecture(
623
+ config_dict["architecture"]
624
+ )
625
+ return sae_cls(sae_config_cls.from_dict(config_dict))
626
+
627
+ @classmethod
628
+ def get_sae_class_for_architecture(
629
+ cls: Type[T_SAE], architecture: str
630
+ ) -> Type[T_SAE]:
631
+ """Get the SAE class for a given architecture."""
632
+ sae_cls, _ = get_sae_class(architecture)
633
+ if not issubclass(sae_cls, cls):
634
+ raise ValueError(
635
+ f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
636
+ )
637
+ return sae_cls
638
+
639
+ # in the future, this can be used to load different config classes for different architectures
640
+ @classmethod
641
+ def get_sae_config_class_for_architecture(
642
+ cls,
643
+ architecture: str, # noqa: ARG003
644
+ ) -> type[SAEConfig]:
645
+ return SAEConfig
646
+
647
+
648
+ @dataclass(kw_only=True)
649
+ class TrainingSAEConfig(SAEConfig, ABC):
650
+ noise_scale: float = 0.0
651
+ mse_loss_normalization: str | None = None
652
+ b_dec_init_method: Literal["zeros", "geometric_median", "mean"] = "zeros"
653
+ # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
654
+ # 0.1 corresponds to the "heuristic" initialization, use None to disable
655
+ decoder_init_norm: float | None = 0.1
656
+
657
+ @classmethod
658
+ @abstractmethod
659
+ def architecture(cls) -> str: ...
660
+
661
+ @classmethod
662
+ def from_sae_runner_config(
663
+ cls: type[T_TRAINING_SAE_CONFIG],
664
+ cfg: "LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]",
665
+ ) -> T_TRAINING_SAE_CONFIG:
666
+ metadata = SAEMetadata(
667
+ model_name=cfg.model_name,
668
+ hook_name=cfg.hook_name,
669
+ hook_layer=cfg.hook_layer,
670
+ hook_head_index=cfg.hook_head_index,
671
+ context_size=cfg.context_size,
672
+ prepend_bos=cfg.prepend_bos,
673
+ seqpos_slice=cfg.seqpos_slice,
674
+ model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
675
+ )
676
+ if not isinstance(cfg.sae, cls):
677
+ raise ValueError(
678
+ f"SAE config class {cls} does not match SAE runner config class {type(cfg.sae)}"
679
+ )
680
+ return replace(cfg.sae, metadata=metadata)
681
+
682
+ @classmethod
683
+ def from_dict(
684
+ cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
685
+ ) -> T_TRAINING_SAE_CONFIG:
686
+ # remove any keys that are not in the dataclass
687
+ # since we sometimes enhance the config with the whole LM runner config
688
+ valid_config_dict = filter_valid_dataclass_fields(config_dict, cls)
689
+ cfg_class = cls
690
+ if "architecture" in config_dict:
691
+ cfg_class = get_sae_training_class(config_dict["architecture"])[1]
692
+ if not issubclass(cfg_class, cls):
693
+ raise ValueError(
694
+ f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
695
+ )
696
+ if "metadata" in config_dict:
697
+ valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
698
+ return cfg_class(**valid_config_dict)
699
+
700
+ def to_dict(self) -> dict[str, Any]:
701
+ return {
702
+ **super().to_dict(),
703
+ **asdict(self),
704
+ "architecture": self.architecture(),
705
+ }
706
+
707
+ # this needs to exist so we can initialize the parent sae cfg without the training specific
708
+ # parameters. Maybe there's a cleaner way to do this
709
+ def get_base_sae_cfg_dict(self) -> dict[str, Any]:
710
+ """
711
+ Creates a dictionary containing attributes corresponding to the fields
712
+ defined in the base SAEConfig class.
713
+ """
714
+ base_config_field_names = {f.name for f in fields(SAEConfig)}
715
+ result_dict = {
716
+ field_name: getattr(self, field_name)
717
+ for field_name in base_config_field_names
718
+ }
719
+ result_dict["architecture"] = self.architecture()
720
+ return result_dict
721
+
722
+
723
+ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
724
+ """Abstract base class for training versions of SAEs."""
725
+
726
+ def __init__(self, cfg: T_TRAINING_SAE_CONFIG, use_error_term: bool = False):
727
+ super().__init__(cfg, use_error_term)
728
+
729
+ # Turn off hook_z reshaping for training mode - the activation store
730
+ # is expected to handle reshaping before passing data to the SAE
731
+ self.turn_off_forward_pass_hook_z_reshaping()
732
+ self.mse_loss_fn = self._get_mse_loss_fn()
733
+
734
+ @abstractmethod
735
+ def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...
736
+
737
+ @abstractmethod
738
+ def encode_with_hidden_pre(
739
+ self, x: Float[torch.Tensor, "... d_in"]
740
+ ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
741
+ """Encode with access to pre-activation values for training."""
742
+ ...
743
+
744
+ def encode(
745
+ self, x: Float[torch.Tensor, "... d_in"]
746
+ ) -> Float[torch.Tensor, "... d_sae"]:
747
+ """
748
+ For inference, just encode without returning hidden_pre.
749
+ (training_forward_pass calls encode_with_hidden_pre).
750
+ """
751
+ feature_acts, _ = self.encode_with_hidden_pre(x)
752
+ return feature_acts
753
+
754
+ def decode(
755
+ self, feature_acts: Float[torch.Tensor, "... d_sae"]
756
+ ) -> Float[torch.Tensor, "... d_in"]:
757
+ """
758
+ Decodes feature activations back into input space,
759
+ applying the reconstruction hook, output activation denormalization, and any output reshaping.
760
+ """
761
+ sae_out_pre = feature_acts @ self.W_dec + self.b_dec
762
+ sae_out_pre = self.hook_sae_recons(sae_out_pre)
763
+ sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
764
+ return self.reshape_fn_out(sae_out_pre, self.d_head)
765
+
766
+ @override
767
+ def initialize_weights(self):
768
+ super().initialize_weights()
769
+ if self.cfg.decoder_init_norm is not None:
770
+ with torch.no_grad():
771
+ self.W_dec.data /= self.W_dec.norm(dim=-1, keepdim=True)
772
+ self.W_dec.data *= self.cfg.decoder_init_norm
773
+ self.W_enc.data = self.W_dec.data.T.clone().detach().contiguous()
774
+
775
+ @abstractmethod
776
+ def calculate_aux_loss(
777
+ self,
778
+ step_input: TrainStepInput,
779
+ feature_acts: torch.Tensor,
780
+ hidden_pre: torch.Tensor,
781
+ sae_out: torch.Tensor,
782
+ ) -> torch.Tensor | dict[str, torch.Tensor]:
783
+ """Calculate architecture-specific auxiliary loss terms."""
784
+ ...
785
+
786
+ def training_forward_pass(
787
+ self,
788
+ step_input: TrainStepInput,
789
+ ) -> TrainStepOutput:
790
+ """Forward pass during training."""
791
+ feature_acts, hidden_pre = self.encode_with_hidden_pre(step_input.sae_in)
792
+ sae_out = self.decode(feature_acts)
793
+
794
+ # Calculate MSE loss
795
+ per_item_mse_loss = self.mse_loss_fn(sae_out, step_input.sae_in)
796
+ mse_loss = per_item_mse_loss.sum(dim=-1).mean()
797
+
798
+ # Calculate architecture-specific auxiliary losses
799
+ aux_losses = self.calculate_aux_loss(
800
+ step_input=step_input,
801
+ feature_acts=feature_acts,
802
+ hidden_pre=hidden_pre,
803
+ sae_out=sae_out,
804
+ )
805
+
806
+ # Total loss is MSE plus all auxiliary losses
807
+ total_loss = mse_loss
808
+
809
+ # Create losses dictionary with mse_loss
810
+ losses = {"mse_loss": mse_loss}
811
+
812
+ # Add architecture-specific losses to the dictionary
813
+ # aux_losses may be a dict of named loss tensors or a single tensor
814
+ if isinstance(aux_losses, dict):
815
+ losses.update(aux_losses)
816
+
817
+ # Sum all losses for total_loss
818
+ if isinstance(aux_losses, dict):
819
+ for loss_value in aux_losses.values():
820
+ total_loss = total_loss + loss_value
821
+ else:
822
+ # Handle case where aux_losses is a tensor
823
+ total_loss = total_loss + aux_losses
824
+
825
+ return TrainStepOutput(
826
+ sae_in=step_input.sae_in,
827
+ sae_out=sae_out,
828
+ feature_acts=feature_acts,
829
+ hidden_pre=hidden_pre,
830
+ loss=total_loss,
831
+ losses=losses,
832
+ )
833
+
834
+ def save_inference_model(self, path: str | Path) -> tuple[Path, Path]:
835
+ """Save inference version of model weights and config to disk."""
836
+ path = Path(path)
837
+ path.mkdir(parents=True, exist_ok=True)
838
+
839
+ # Generate the weights
840
+ state_dict = self.state_dict() # Use internal SAE state dict
841
+ self.process_state_dict_for_saving_inference(state_dict)
842
+ model_weights_path = path / SAE_WEIGHTS_FILENAME
843
+ save_file(state_dict, model_weights_path)
844
+
845
+ # Save the config
846
+ config = self.to_inference_config_dict()
847
+ cfg_path = path / SAE_CFG_FILENAME
848
+ with open(cfg_path, "w") as f:
849
+ json.dump(config, f)
850
+
851
+ return model_weights_path, cfg_path
852
+
853
+ @abstractmethod
854
+ def to_inference_config_dict(self) -> dict[str, Any]:
855
+ """Convert the config into an inference SAE config dict."""
856
+ ...
857
+
858
+ def process_state_dict_for_saving_inference(
859
+ self, state_dict: dict[str, Any]
860
+ ) -> None:
861
+ """
862
+ Process the state dict for saving the inference model.
863
+ This is a hook that can be overridden to change how the state dict is processed for the inference model.
864
+ """
865
+ return self.process_state_dict_for_saving(state_dict)
866
+
867
+ def _get_mse_loss_fn(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
868
+ """Get the MSE loss function based on config."""
869
+
870
+ def standard_mse_loss_fn(
871
+ preds: torch.Tensor, target: torch.Tensor
872
+ ) -> torch.Tensor:
873
+ return torch.nn.functional.mse_loss(preds, target, reduction="none")
874
+
875
+ def batch_norm_mse_loss_fn(
876
+ preds: torch.Tensor, target: torch.Tensor
877
+ ) -> torch.Tensor:
878
+ target_centered = target - target.mean(dim=0, keepdim=True)
879
+ normalization = target_centered.norm(dim=-1, keepdim=True)
880
+ return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
881
+ normalization + 1e-6
882
+ )
883
+
884
+ if self.cfg.mse_loss_normalization == "dense_batch":
885
+ return batch_norm_mse_loss_fn
886
+ return standard_mse_loss_fn
887
+
888
+ @torch.no_grad()
889
+ def remove_gradient_parallel_to_decoder_directions(self) -> None:
890
+ """Remove gradient components parallel to decoder directions."""
891
+ # Project each row's gradient onto the subspace orthogonal to its decoder direction
892
+ assert self.W_dec.grad is not None
893
+
894
+ parallel_component = einops.einsum(
895
+ self.W_dec.grad,
896
+ self.W_dec.data,
897
+ "d_sae d_in, d_sae d_in -> d_sae",
898
+ )
899
+ self.W_dec.grad -= einops.einsum(
900
+ parallel_component,
901
+ self.W_dec.data,
902
+ "d_sae, d_sae d_in -> d_sae d_in",
903
+ )
904
+
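The projection above removes the full component of the gradient parallel to each decoder direction exactly when the decoder rows have unit norm (as after fold_W_dec_norm, or when the training loop keeps W_dec row-normalized). A standalone check of the orthogonality it enforces (illustrative only):

import einops
import torch

W_dec = torch.nn.functional.normalize(torch.randn(64, 16), dim=-1)  # unit-norm decoder rows
grad = torch.randn(64, 16)                                          # stand-in for W_dec.grad
parallel = einops.einsum(grad, W_dec, "d_sae d_in, d_sae d_in -> d_sae")
grad = grad - einops.einsum(parallel, W_dec, "d_sae, d_sae d_in -> d_sae d_in")
# Each gradient row is now orthogonal to its decoder direction.
torch.testing.assert_close((grad * W_dec).sum(dim=-1), torch.zeros(64), rtol=0, atol=1e-5)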
905
+ @torch.no_grad()
906
+ def log_histograms(self) -> dict[str, NDArray[Any]]:
907
+ """Log histograms of the weights and biases."""
908
+ W_dec_norm_dist = self.W_dec.detach().float().norm(dim=1).cpu().numpy()
909
+ return {
910
+ "weights/W_dec_norms": W_dec_norm_dist,
911
+ }
912
+
913
+ @classmethod
914
+ def get_sae_class_for_architecture(
915
+ cls: Type[T_TRAINING_SAE], architecture: str
916
+ ) -> Type[T_TRAINING_SAE]:
917
+ """Get the SAE class for a given architecture."""
918
+ sae_cls, _ = get_sae_training_class(architecture)
919
+ if not issubclass(sae_cls, cls):
920
+ raise ValueError(
921
+ f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
922
+ )
923
+ return sae_cls
924
+
925
+ # in the future, this can be used to load different config classes for different architectures
926
+ @classmethod
927
+ def get_sae_config_class_for_architecture(
928
+ cls,
929
+ architecture: str, # noqa: ARG003
930
+ ) -> type[TrainingSAEConfig]:
931
+ return get_sae_training_class(architecture)[1]
932
+
933
+
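To show how the abstract pieces fit together, here is a minimal hypothetical TrainingSAE subclass and a single training step. The class, its field names, and the "l1" coefficient key are illustrative assumptions, not part of this release:

from dataclasses import dataclass

import torch

from sae_lens.saes.sae import (
    TrainCoefficientConfig,
    TrainStepInput,
    TrainingSAE,
    TrainingSAEConfig,
)


@dataclass(kw_only=True)
class ToyTrainingSAEConfig(TrainingSAEConfig):
    l1_coefficient: float = 1.0

    @classmethod
    def architecture(cls) -> str:
        return "toy"  # would need registering for from_dict()/from_pretrained() dispatch


class ToyTrainingSAE(TrainingSAE[ToyTrainingSAEConfig]):
    def initialize_weights(self) -> None:
        super().initialize_weights()
        self.b_enc = torch.nn.Parameter(
            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
        )

    def encode_with_hidden_pre(self, x):
        sae_in = self.process_sae_in(x)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
        return self.hook_sae_acts_post(self.activation_fn(hidden_pre)), hidden_pre

    def calculate_aux_loss(self, step_input, feature_acts, hidden_pre, sae_out):
        l1 = feature_acts.abs().sum(dim=-1).mean()
        return {"l1_loss": step_input.coefficients["l1"] * l1}

    def get_coefficients(self):
        return {"l1": TrainCoefficientConfig(value=self.cfg.l1_coefficient, warm_up_steps=0)}

    def to_inference_config_dict(self):
        return self.get_base_sae_cfg_dict()


sae = ToyTrainingSAE(ToyTrainingSAEConfig(d_in=16, d_sae=64))
step = TrainStepInput(sae_in=torch.randn(8, 16), coefficients={"l1": 5.0}, dead_neuron_mask=None)
output = sae.training_forward_pass(step)
output.loss.backward()  # total loss = mse_loss + l1_loss

A real architecture would also register its SAE and config classes in sae_lens.registry so that the "toy" architecture string can be resolved when loading from disk or from the hub.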
934
+ _blank_hook = nn.Identity()
935
+
936
+
937
+ @contextmanager
938
+ def _disable_hooks(sae: SAE[Any]):
939
+ """
940
+ Temporarily disable hooks for the SAE by swapping each hook out for a no-op module.
941
+ """
942
+ try:
943
+ for hook_name in sae.hook_dict:
944
+ setattr(sae, hook_name, _blank_hook)
945
+ yield
946
+ finally:
947
+ for hook_name, hook in sae.hook_dict.items():
948
+ setattr(sae, hook_name, hook)