sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
sae_lens/saes/sae.py CHANGED
@@ -4,9 +4,18 @@ import json
  import warnings
  from abc import ABC, abstractmethod
  from contextlib import contextmanager
- from dataclasses import dataclass, field, fields
+ from dataclasses import asdict, dataclass, field, fields, replace
  from pathlib import Path
- from typing import Any, Callable, Type, TypeVar
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     Callable,
+     Generic,
+     Literal,
+     NamedTuple,
+     Type,
+     TypeVar,
+ )

  import einops
  import torch
@@ -15,16 +24,19 @@ from numpy.typing import NDArray
  from safetensors.torch import save_file
  from torch import nn
  from transformer_lens.hook_points import HookedRootModule, HookPoint
- from typing_extensions import deprecated, overload
+ from typing_extensions import deprecated, overload, override

- from sae_lens import logger
- from sae_lens.config import (
+ from sae_lens import __version__, logger
+ from sae_lens.constants import (
      DTYPE_MAP,
      SAE_CFG_FILENAME,
      SAE_WEIGHTS_FILENAME,
-     SPARSITY_FILENAME,
-     LanguageModelSAERunnerConfig,
  )
+ from sae_lens.util import filter_valid_dataclass_fields
+
+ if TYPE_CHECKING:
+     from sae_lens.config import LanguageModelSAERunnerConfig
+
  from sae_lens.loading.pretrained_sae_loaders import (
      NAMED_PRETRAINED_SAE_LOADERS,
      PretrainedSaeDiskLoader,
@@ -39,57 +51,82 @@ from sae_lens.loading.pretrained_saes_directory import (
      get_pretrained_saes_directory,
      get_repo_id_and_folder_name,
  )
- from sae_lens.regsitry import get_sae_class, get_sae_training_class
+ from sae_lens.registry import get_sae_class, get_sae_training_class

- T = TypeVar("T", bound="SAE")
+ T_SAE_CONFIG = TypeVar("T_SAE_CONFIG", bound="SAEConfig")
+ T_TRAINING_SAE_CONFIG = TypeVar("T_TRAINING_SAE_CONFIG", bound="TrainingSAEConfig")
+ T_SAE = TypeVar("T_SAE", bound="SAE")  # type: ignore
+ T_TRAINING_SAE = TypeVar("T_TRAINING_SAE", bound="TrainingSAE")  # type: ignore
+
+
+ @dataclass
+ class SAEMetadata:
+     """Core metadata about how this SAE should be used, if known."""
+
+     model_name: str | None = None
+     hook_name: str | None = None
+     model_class_name: str | None = None
+     hook_layer: int | None = None
+     hook_head_index: int | None = None
+     model_from_pretrained_kwargs: dict[str, Any] | None = None
+     prepend_bos: bool | None = None
+     exclude_special_tokens: bool | list[int] | None = None
+     neuronpedia_id: str | None = None
+     context_size: int | None = None
+     seqpos_slice: tuple[int | None, ...] | None = None
+     dataset_path: str | None = None
+     sae_lens_version: str = field(default_factory=lambda: __version__)
+     sae_lens_training_version: str = field(default_factory=lambda: __version__)


  @dataclass
- class SAEConfig:
+ class SAEConfig(ABC):
      """Base configuration for SAE models."""

-     architecture: str
      d_in: int
      d_sae: int
-     dtype: str
-     device: str
-     model_name: str
-     hook_name: str
-     hook_layer: int
-     hook_head_index: int | None
-     activation_fn: str
-     activation_fn_kwargs: dict[str, Any]
-     apply_b_dec_to_input: bool
-     finetuning_scaling_factor: bool
-     normalize_activations: str
-     context_size: int
-     dataset_path: str
-     dataset_trust_remote_code: bool
-     sae_lens_training_version: str
-     model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
-     seqpos_slice: tuple[int, ...] | None = None
-     prepend_bos: bool = False
-     neuronpedia_id: str | None = None
-
-     def to_dict(self) -> dict[str, Any]:
-         return {field.name: getattr(self, field.name) for field in fields(self)}
+     dtype: str = "float32"
+     device: str = "cpu"
+     apply_b_dec_to_input: bool = True
+     normalize_activations: Literal[
+         "none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"
+     ] = "none"  # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
+     reshape_activations: Literal["none", "hook_z"] = "none"
+     metadata: SAEMetadata = field(default_factory=SAEMetadata)

      @classmethod
-     def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
-         valid_field_names = {field.name for field in fields(cls)}
-         valid_config_dict = {
-             key: val for key, val in config_dict.items() if key in valid_field_names
-         }
+     @abstractmethod
+     def architecture(cls) -> str: ...

-         # Ensure seqpos_slice is a tuple
-         if (
-             "seqpos_slice" in valid_config_dict
-             and valid_config_dict["seqpos_slice"] is not None
-             and isinstance(valid_config_dict["seqpos_slice"], list)
-         ):
-             valid_config_dict["seqpos_slice"] = tuple(valid_config_dict["seqpos_slice"])
+     def to_dict(self) -> dict[str, Any]:
+         res = {field.name: getattr(self, field.name) for field in fields(self)}
+         res["metadata"] = asdict(self.metadata)
+         res["architecture"] = self.architecture()
+         return res

-         return cls(**valid_config_dict)
+     @classmethod
+     def from_dict(cls: type[T_SAE_CONFIG], config_dict: dict[str, Any]) -> T_SAE_CONFIG:
+         cfg_class = get_sae_class(config_dict["architecture"])[1]
+         filtered_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
+         res = cfg_class(**filtered_config_dict)
+         if "metadata" in config_dict:
+             res.metadata = SAEMetadata(**config_dict["metadata"])
+         if not isinstance(res, cls):
+             raise ValueError(
+                 f"SAE config class {cls} does not match dict config class {type(res)}"
+             )
+         return res
+
+     def __post_init__(self):
+         if self.normalize_activations not in [
+             "none",
+             "expected_average_only_in",
+             "constant_norm_rescale",
+             "layer_norm",
+         ]:
+             raise ValueError(
+                 f"normalize_activations must be none, expected_average_only_in, constant_norm_rescale, or layer_norm. Got {self.normalize_activations}"
+             )


  @dataclass
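Under the new scheme the architecture string is no longer a config field: each concrete config subclass reports it through the abstract `architecture()` classmethod, and run-level details (model name, hook name, dataset, etc.) live on `cfg.metadata`. A minimal sketch of a concrete config under these rules; the `DemoSAEConfig` name is hypothetical and the import path is assumed to follow the file shown in this diff:

    from dataclasses import dataclass

    from sae_lens.saes.sae import SAEConfig, SAEMetadata


    @dataclass
    class DemoSAEConfig(SAEConfig):  # hypothetical subclass, for illustration only
        @classmethod
        def architecture(cls) -> str:
            return "demo"


    cfg = DemoSAEConfig(
        d_in=768,
        d_sae=16384,
        metadata=SAEMetadata(model_name="gpt2", hook_name="blocks.8.hook_resid_pre"),
    )
    # to_dict() injects the architecture and serialises the metadata dataclass.
    assert cfg.to_dict()["architecture"] == "demo"
    assert cfg.to_dict()["metadata"]["model_name"] == "gpt2"

Note that `SAEConfig.from_dict` resolves the config class through the SAE registry (`get_sae_class`), so round-tripping a custom architecture additionally assumes it has been registered.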
@@ -109,14 +146,19 @@ class TrainStepInput:
      """Input to a training step."""

      sae_in: torch.Tensor
-     current_l1_coefficient: float
+     coefficients: dict[str, float]
      dead_neuron_mask: torch.Tensor | None


- class SAE(HookedRootModule, ABC):
+ class TrainCoefficientConfig(NamedTuple):
+     value: float
+     warm_up_steps: int
+
+
+ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
      """Abstract base class for all SAE architectures."""

-     cfg: SAEConfig
+     cfg: T_SAE_CONFIG
      dtype: torch.dtype
      device: torch.device
      use_error_term: bool
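`TrainStepInput` now carries a dict of named coefficients instead of a single `current_l1_coefficient`, and `TrainCoefficientConfig` is the (value, warm-up steps) pair training SAEs use to describe each coefficient's schedule. A small illustrative sketch; the `"l1"` key is an example, not mandated by the diff:

    import torch

    from sae_lens.saes.sae import TrainStepInput

    # The trainer is expected to resolve each scheduled coefficient to a float
    # for the current step before building the step input.
    step_input = TrainStepInput(
        sae_in=torch.randn(32, 768),
        coefficients={"l1": 5.0},
        dead_neuron_mask=None,
    )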
@@ -127,13 +169,13 @@ class SAE(HookedRootModule, ABC):
      W_dec: nn.Parameter
      b_dec: nn.Parameter

-     def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+     def __init__(self, cfg: T_SAE_CONFIG, use_error_term: bool = False):
          """Initialize the SAE."""
          super().__init__()

          self.cfg = cfg

-         if cfg.model_from_pretrained_kwargs:
+         if cfg.metadata and cfg.metadata.model_from_pretrained_kwargs:
              warnings.warn(
                  "\nThis SAE has non-empty model_from_pretrained_kwargs. "
                  "\nFor optimal performance, load the model like so:\n"
@@ -147,19 +189,11 @@ class SAE(HookedRootModule, ABC):
          self.use_error_term = use_error_term

          # Set up activation function
-         self.activation_fn = self._get_activation_fn()
+         self.activation_fn = self.get_activation_fn()

          # Initialize weights
          self.initialize_weights()

-         # Handle presence / absence of scaling factor
-         if self.cfg.finetuning_scaling_factor:
-             self.apply_finetuning_scaling_factor = (
-                 lambda x: x * self.finetuning_scaling_factor
-             )
-         else:
-             self.apply_finetuning_scaling_factor = lambda x: x
-
          # Set up hooks
          self.hook_sae_input = HookPoint()
          self.hook_sae_acts_pre = HookPoint()
@@ -169,11 +203,9 @@ class SAE(HookedRootModule, ABC):
          self.hook_sae_error = HookPoint()

          # handle hook_z reshaping if needed.
-         if self.cfg.hook_name.endswith("_z"):
-             # print(f"Setting up hook_z reshaping for {self.cfg.hook_name}")
+         if self.cfg.reshape_activations == "hook_z":
              self.turn_on_forward_pass_hook_z_reshaping()
          else:
-             # print(f"No hook_z reshaping needed for {self.cfg.hook_name}")
              self.turn_off_forward_pass_hook_z_reshaping()

          # Set up activation normalization
@@ -188,40 +220,9 @@ class SAE(HookedRootModule, ABC):
          self.b_dec.data /= scaling_factor  # type: ignore
          self.cfg.normalize_activations = "none"

-     def _get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
          """Get the activation function specified in config."""
-         return self._get_activation_fn_static(
-             self.cfg.activation_fn, **(self.cfg.activation_fn_kwargs or {})
-         )
-
-     @staticmethod
-     def _get_activation_fn_static(
-         activation_fn: str, **kwargs: Any
-     ) -> Callable[[torch.Tensor], torch.Tensor]:
-         """Get the activation function from a string specification."""
-         if activation_fn == "relu":
-             return torch.nn.ReLU()
-         if activation_fn == "tanh-relu":
-
-             def tanh_relu(input: torch.Tensor) -> torch.Tensor:
-                 input = torch.relu(input)
-                 return torch.tanh(input)
-
-             return tanh_relu
-         if activation_fn == "topk":
-             if "k" not in kwargs:
-                 raise ValueError("TopK activation function requires a k value.")
-             k = kwargs.get("k", 1)  # Default k to 1 if not provided
-
-             def topk_fn(x: torch.Tensor) -> torch.Tensor:
-                 topk = torch.topk(x.flatten(start_dim=-1), k=k, dim=-1)
-                 values = torch.relu(topk.values)
-                 result = torch.zeros_like(x.flatten(start_dim=-1))
-                 result.scatter_(-1, topk.indices, values)
-                 return result.view_as(x)
-
-             return topk_fn
-         raise ValueError(f"Unknown activation function: {activation_fn}")
+         return nn.ReLU()

      def _setup_activation_normalization(self):
          """Set up activation normalization functions based on config."""
@@ -264,10 +265,20 @@ class SAE(HookedRootModule, ABC):
          self.run_time_activation_norm_fn_in = lambda x: x
          self.run_time_activation_norm_fn_out = lambda x: x

-     @abstractmethod
      def initialize_weights(self):
          """Initialize model weights."""
-         pass
+         self.b_dec = nn.Parameter(
+             torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+         )
+
+         w_dec_data = torch.empty(
+             self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+         )
+         nn.init.kaiming_uniform_(w_dec_data)
+         self.W_dec = nn.Parameter(w_dec_data)
+
+         w_enc_data = self.W_dec.data.T.clone().detach().contiguous()
+         self.W_enc = nn.Parameter(w_enc_data)

      @abstractmethod
      def encode(
@@ -284,7 +295,10 @@ class SAE(HookedRootModule, ABC):
          pass

      def turn_on_forward_pass_hook_z_reshaping(self):
-         if not self.cfg.hook_name.endswith("_z"):
+         if (
+             self.cfg.metadata.hook_name is not None
+             and not self.cfg.metadata.hook_name.endswith("_z")
+         ):
              raise ValueError("This method should only be called for hook_z SAEs.")

          # print(f"Turning on hook_z reshaping for {self.cfg.hook_name}")
@@ -313,19 +327,19 @@ class SAE(HookedRootModule, ABC):

      @overload
      def to(
-         self: T,
+         self: T_SAE,
          device: torch.device | str | None = ...,
          dtype: torch.dtype | None = ...,
          non_blocking: bool = ...,
-     ) -> T: ...
+     ) -> T_SAE: ...

      @overload
-     def to(self: T, dtype: torch.dtype, non_blocking: bool = ...) -> T: ...
+     def to(self: T_SAE, dtype: torch.dtype, non_blocking: bool = ...) -> T_SAE: ...

      @overload
-     def to(self: T, tensor: torch.Tensor, non_blocking: bool = ...) -> T: ...
+     def to(self: T_SAE, tensor: torch.Tensor, non_blocking: bool = ...) -> T_SAE: ...

-     def to(self: T, *args: Any, **kwargs: Any) -> T:  # type: ignore
+     def to(self: T_SAE, *args: Any, **kwargs: Any) -> T_SAE:  # type: ignore
          device_arg = None
          dtype_arg = None

@@ -426,15 +440,12 @@ class SAE(HookedRootModule, ABC):

      def get_name(self):
          """Generate a name for this SAE."""
-         return f"sae_{self.cfg.model_name}_{self.cfg.hook_name}_{self.cfg.d_sae}"
+         return f"sae_{self.cfg.metadata.model_name}_{self.cfg.metadata.hook_name}_{self.cfg.d_sae}"

-     def save_model(
-         self, path: str | Path, sparsity: torch.Tensor | None = None
-     ) -> tuple[Path, Path, Path | None]:
-         """Save model weights, config, and optional sparsity tensor to disk."""
+     def save_model(self, path: str | Path) -> tuple[Path, Path]:
+         """Save model weights and config to disk."""
          path = Path(path)
-         if not path.exists():
-             path.mkdir(parents=True)
+         path.mkdir(parents=True, exist_ok=True)

          # Generate the weights
          state_dict = self.state_dict()  # Use internal SAE state dict
@@ -448,13 +459,7 @@ class SAE(HookedRootModule, ABC):
          with open(cfg_path, "w") as f:
              json.dump(config, f)

-         if sparsity is not None:
-             sparsity_in_dict = {"sparsity": sparsity}
-             sparsity_path = path / SPARSITY_FILENAME
-             save_file(sparsity_in_dict, sparsity_path)
-             return model_weights_path, cfg_path, sparsity_path
-
-         return model_weights_path, cfg_path, None
+         return model_weights_path, cfg_path

      ## Initialization Methods
      @torch.no_grad()
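`save_model` no longer accepts an optional sparsity tensor and now returns only the weights and config paths. A usage sketch, assuming `sae` is an already-constructed SAE instance:

    from sae_lens import SAE

    # `sae` is an existing SAE instance (e.g. loaded via SAE.from_pretrained).
    weights_path, cfg_path = sae.save_model("checkpoints/my_sae")
    reloaded = SAE.load_from_disk("checkpoints/my_sae", device="cpu")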
@@ -482,18 +487,21 @@ class SAE(HookedRootModule, ABC):
      @classmethod
      @deprecated("Use load_from_disk instead")
      def load_from_pretrained(
-         cls: Type[T], path: str | Path, device: str = "cpu", dtype: str | None = None
-     ) -> T:
+         cls: Type[T_SAE],
+         path: str | Path,
+         device: str = "cpu",
+         dtype: str | None = None,
+     ) -> T_SAE:
          return cls.load_from_disk(path, device=device, dtype=dtype)

      @classmethod
      def load_from_disk(
-         cls: Type[T],
+         cls: Type[T_SAE],
          path: str | Path,
          device: str = "cpu",
          dtype: str | None = None,
          converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
-     ) -> T:
+     ) -> T_SAE:
          overrides = {"dtype": dtype} if dtype is not None else None
          cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
          cfg_dict = handle_config_defaulting(cfg_dict)
@@ -501,7 +509,7 @@ class SAE(HookedRootModule, ABC):
              cfg_dict["architecture"]
          )
          sae_cfg = sae_config_cls.from_dict(cfg_dict)
-         sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture)
+         sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
          sae = sae_cls(sae_cfg)
          sae.process_state_dict_for_loading(state_dict)
          sae.load_state_dict(state_dict)
@@ -509,13 +517,13 @@ class SAE(HookedRootModule, ABC):

      @classmethod
      def from_pretrained(
-         cls,
+         cls: Type[T_SAE],
          release: str,
          sae_id: str,
          device: str = "cpu",
          force_download: bool = False,
          converter: PretrainedSaeHuggingfaceLoader | None = None,
-     ) -> tuple["SAE", dict[str, Any], torch.Tensor | None]:
+     ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
          """
          Load a pretrained SAE from the Hugging Face model hub.

@@ -584,47 +592,31 @@ class SAE(HookedRootModule, ABC):
          )
          cfg_dict = handle_config_defaulting(cfg_dict)

-         # Rename keys to match SAEConfig field names
-         renamed_cfg_dict = {}
-         rename_map = {
-             "hook_point": "hook_name",
-             "hook_point_layer": "hook_layer",
-             "hook_point_head_index": "hook_head_index",
-             "activation_fn": "activation_fn",
-         }
-
-         for k, v in cfg_dict.items():
-             renamed_cfg_dict[rename_map.get(k, k)] = v
-
-         # Set default values for required fields
-         renamed_cfg_dict.setdefault("activation_fn_kwargs", {})
-         renamed_cfg_dict.setdefault("seqpos_slice", None)
-
          # Create SAE with appropriate architecture
          sae_config_cls = cls.get_sae_config_class_for_architecture(
-             renamed_cfg_dict["architecture"]
+             cfg_dict["architecture"]
          )
-         sae_cfg = sae_config_cls.from_dict(renamed_cfg_dict)
-         sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture)
+         sae_cfg = sae_config_cls.from_dict(cfg_dict)
+         sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
          sae = sae_cls(sae_cfg)
          sae.process_state_dict_for_loading(state_dict)
          sae.load_state_dict(state_dict)

          # Apply normalization if needed
-         if renamed_cfg_dict.get("normalize_activations") == "expected_average_only_in":
+         if cfg_dict.get("normalize_activations") == "expected_average_only_in":
              norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
              if norm_scaling_factor is not None:
                  sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
-                 renamed_cfg_dict["normalize_activations"] = "none"
+                 cfg_dict["normalize_activations"] = "none"
              else:
                  warnings.warn(
                      f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
                  )

-         return sae, renamed_cfg_dict, log_sparsities
+         return sae, cfg_dict, log_sparsities

      @classmethod
-     def from_dict(cls: Type[T], config_dict: dict[str, Any]) -> T:
+     def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
          """Create an SAE from a config dictionary."""
          sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
          sae_config_cls = cls.get_sae_config_class_for_architecture(
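`from_pretrained` no longer renames legacy config keys before building the config and is now typed against the class it is called on, still returning the SAE together with its config dict and optional log sparsities. A usage sketch; the release and id below are long-standing entries in the pretrained SAE directory and are assumed to still be available:

    from sae_lens import SAE

    sae, cfg_dict, log_sparsities = SAE.from_pretrained(
        release="gpt2-small-res-jb",
        sae_id="blocks.8.hook_resid_pre",
        device="cpu",
    )
    print(sae.cfg.architecture(), sae.cfg.d_sae)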
@@ -633,9 +625,11 @@ class SAE(HookedRootModule, ABC):
          return sae_cls(sae_config_cls.from_dict(config_dict))

      @classmethod
-     def get_sae_class_for_architecture(cls: Type[T], architecture: str) -> Type[T]:
+     def get_sae_class_for_architecture(
+         cls: Type[T_SAE], architecture: str
+     ) -> Type[T_SAE]:
          """Get the SAE class for a given architecture."""
-         sae_cls = get_sae_class(architecture)
+         sae_cls, _ = get_sae_class(architecture)
          if not issubclass(sae_cls, cls):
              raise ValueError(
                  f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
@@ -645,161 +639,107 @@ class SAE(HookedRootModule, ABC):
      # in the future, this can be used to load different config classes for different architectures
      @classmethod
      def get_sae_config_class_for_architecture(
-         cls: Type[T],
+         cls,
          architecture: str,  # noqa: ARG003
      ) -> type[SAEConfig]:
          return SAEConfig


  @dataclass(kw_only=True)
- class TrainingSAEConfig(SAEConfig):
-     # Sparsity Loss Calculations
-     l1_coefficient: float
-     lp_norm: float
-     use_ghost_grads: bool
-     normalize_sae_decoder: bool
-     noise_scale: float
-     decoder_orthogonal_init: bool
-     mse_loss_normalization: str | None
-     jumprelu_init_threshold: float
-     jumprelu_bandwidth: float
-     decoder_heuristic_init: bool
-     init_encoder_as_decoder_transpose: bool
-     scale_sparsity_penalty_by_decoder_norm: bool
+ class TrainingSAEConfig(SAEConfig, ABC):
+     noise_scale: float = 0.0
+     mse_loss_normalization: str | None = None
+     b_dec_init_method: Literal["zeros", "geometric_median", "mean"] = "zeros"
+     # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
+     # 0.1 corresponds to the "heuristic" initialization, use None to disable
+     decoder_init_norm: float | None = 0.1
+
+     @classmethod
+     @abstractmethod
+     def architecture(cls) -> str: ...

      @classmethod
      def from_sae_runner_config(
-         cls, cfg: LanguageModelSAERunnerConfig
-     ) -> "TrainingSAEConfig":
-         return cls(
-             # base config
-             architecture=cfg.architecture,
-             d_in=cfg.d_in,
-             d_sae=cfg.d_sae,  # type: ignore
-             dtype=cfg.dtype,
-             device=cfg.device,
+         cls: type[T_TRAINING_SAE_CONFIG],
+         cfg: "LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]",
+     ) -> T_TRAINING_SAE_CONFIG:
+         metadata = SAEMetadata(
              model_name=cfg.model_name,
              hook_name=cfg.hook_name,
              hook_layer=cfg.hook_layer,
              hook_head_index=cfg.hook_head_index,
-             activation_fn=cfg.activation_fn,
-             activation_fn_kwargs=cfg.activation_fn_kwargs,
-             apply_b_dec_to_input=cfg.apply_b_dec_to_input,
-             finetuning_scaling_factor=cfg.finetuning_method is not None,
-             sae_lens_training_version=cfg.sae_lens_training_version,
              context_size=cfg.context_size,
-             dataset_path=cfg.dataset_path,
              prepend_bos=cfg.prepend_bos,
-             seqpos_slice=tuple(x for x in cfg.seqpos_slice if x is not None)
-             if cfg.seqpos_slice is not None
-             else None,
-             # Training cfg
-             l1_coefficient=cfg.l1_coefficient,
-             lp_norm=cfg.lp_norm,
-             use_ghost_grads=cfg.use_ghost_grads,
-             normalize_sae_decoder=cfg.normalize_sae_decoder,
-             noise_scale=cfg.noise_scale,
-             decoder_orthogonal_init=cfg.decoder_orthogonal_init,
-             mse_loss_normalization=cfg.mse_loss_normalization,
-             decoder_heuristic_init=cfg.decoder_heuristic_init,
-             init_encoder_as_decoder_transpose=cfg.init_encoder_as_decoder_transpose,
-             scale_sparsity_penalty_by_decoder_norm=cfg.scale_sparsity_penalty_by_decoder_norm,
-             normalize_activations=cfg.normalize_activations,
-             dataset_trust_remote_code=cfg.dataset_trust_remote_code,
+             seqpos_slice=cfg.seqpos_slice,
              model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
-             jumprelu_init_threshold=cfg.jumprelu_init_threshold,
-             jumprelu_bandwidth=cfg.jumprelu_bandwidth,
          )
+         if not isinstance(cfg.sae, cls):
+             raise ValueError(
+                 f"SAE config class {cls} does not match SAE runner config class {type(cfg.sae)}"
+             )
+         return replace(cfg.sae, metadata=metadata)

      @classmethod
-     def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig":
+     def from_dict(
+         cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
+     ) -> T_TRAINING_SAE_CONFIG:
          # remove any keys that are not in the dataclass
          # since we sometimes enhance the config with the whole LM runner config
-         valid_field_names = {field.name for field in fields(cls)}
-         valid_config_dict = {
-             key: val for key, val in config_dict.items() if key in valid_field_names
-         }
-
-         # ensure seqpos slice is tuple
-         # ensure that seqpos slices is a tuple
-         # Ensure seqpos_slice is a tuple
-         if "seqpos_slice" in valid_config_dict:
-             if isinstance(valid_config_dict["seqpos_slice"], list):
-                 valid_config_dict["seqpos_slice"] = tuple(
-                     valid_config_dict["seqpos_slice"]
-                 )
-             elif not isinstance(valid_config_dict["seqpos_slice"], tuple):
-                 valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],)
-
-         return TrainingSAEConfig(**valid_config_dict)
+         valid_config_dict = filter_valid_dataclass_fields(config_dict, cls)
+         cfg_class = cls
+         if "architecture" in config_dict:
+             cfg_class = get_sae_training_class(config_dict["architecture"])[1]
+         if not issubclass(cfg_class, cls):
+             raise ValueError(
+                 f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
+             )
+         if "metadata" in config_dict:
+             valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
+         return cfg_class(**valid_config_dict)

      def to_dict(self) -> dict[str, Any]:
          return {
              **super().to_dict(),
-             "l1_coefficient": self.l1_coefficient,
-             "lp_norm": self.lp_norm,
-             "use_ghost_grads": self.use_ghost_grads,
-             "normalize_sae_decoder": self.normalize_sae_decoder,
-             "noise_scale": self.noise_scale,
-             "decoder_orthogonal_init": self.decoder_orthogonal_init,
-             "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
-             "mse_loss_normalization": self.mse_loss_normalization,
-             "decoder_heuristic_init": self.decoder_heuristic_init,
-             "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
-             "normalize_activations": self.normalize_activations,
-             "jumprelu_init_threshold": self.jumprelu_init_threshold,
-             "jumprelu_bandwidth": self.jumprelu_bandwidth,
+             **asdict(self),
+             "architecture": self.architecture(),
          }

      # this needs to exist so we can initialize the parent sae cfg without the training specific
      # parameters. Maybe there's a cleaner way to do this
      def get_base_sae_cfg_dict(self) -> dict[str, Any]:
-         return {
-             "architecture": self.architecture,
-             "d_in": self.d_in,
-             "d_sae": self.d_sae,
-             "activation_fn": self.activation_fn,
-             "activation_fn_kwargs": self.activation_fn_kwargs,
-             "apply_b_dec_to_input": self.apply_b_dec_to_input,
-             "dtype": self.dtype,
-             "model_name": self.model_name,
-             "hook_name": self.hook_name,
-             "hook_layer": self.hook_layer,
-             "hook_head_index": self.hook_head_index,
-             "device": self.device,
-             "context_size": self.context_size,
-             "prepend_bos": self.prepend_bos,
-             "finetuning_scaling_factor": self.finetuning_scaling_factor,
-             "normalize_activations": self.normalize_activations,
-             "dataset_path": self.dataset_path,
-             "dataset_trust_remote_code": self.dataset_trust_remote_code,
-             "sae_lens_training_version": self.sae_lens_training_version,
-             "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
-             "seqpos_slice": self.seqpos_slice,
-             "neuronpedia_id": self.neuronpedia_id,
+         """
+         Creates a dictionary containing attributes corresponding to the fields
+         defined in the base SAEConfig class.
+         """
+         base_config_field_names = {f.name for f in fields(SAEConfig)}
+         result_dict = {
+             field_name: getattr(self, field_name)
+             for field_name in base_config_field_names
          }
+         result_dict["architecture"] = self.architecture()
+         return result_dict


- class TrainingSAE(SAE, ABC):
+ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
      """Abstract base class for training versions of SAEs."""

-     cfg: "TrainingSAEConfig"  # type: ignore
-
-     def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
+     def __init__(self, cfg: T_TRAINING_SAE_CONFIG, use_error_term: bool = False):
          super().__init__(cfg, use_error_term)

          # Turn off hook_z reshaping for training mode - the activation store
          # is expected to handle reshaping before passing data to the SAE
          self.turn_off_forward_pass_hook_z_reshaping()
-
          self.mse_loss_fn = self._get_mse_loss_fn()

+     @abstractmethod
+     def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...
+
      @abstractmethod
      def encode_with_hidden_pre(
          self, x: Float[torch.Tensor, "... d_in"]
      ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
          """Encode with access to pre-activation values for training."""
-         pass
+         ...

      def encode(
          self, x: Float[torch.Tensor, "... d_in"]
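Training configs follow the same pattern: `TrainingSAEConfig` keeps only generic options such as `noise_scale`, `b_dec_init_method` and `decoder_init_norm`, architecture-specific hyperparameters move into per-architecture subclasses, and every `TrainingSAE` must name its schedulable coefficients via `get_coefficients`. A rough sketch of what a concrete pair might look like; the class and field names are hypothetical, loosely modelled on an L1-penalised SAE, and the remaining abstract methods are omitted:

    from dataclasses import dataclass

    from sae_lens.saes.sae import TrainCoefficientConfig, TrainingSAE, TrainingSAEConfig


    @dataclass(kw_only=True)
    class DemoTrainingSAEConfig(TrainingSAEConfig):  # hypothetical
        l1_coefficient: float = 1.0
        l1_warm_up_steps: int = 0

        @classmethod
        def architecture(cls) -> str:
            return "demo"


    class DemoTrainingSAE(TrainingSAE[DemoTrainingSAEConfig]):  # hypothetical, partial
        def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
            # One warm-up-scheduled coefficient, keyed by the name used in the losses dict.
            return {
                "l1": TrainCoefficientConfig(
                    value=self.cfg.l1_coefficient,
                    warm_up_steps=self.cfg.l1_warm_up_steps,
                )
            }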
@@ -818,12 +758,20 @@ class TrainingSAE(SAE, ABC):
          Decodes feature activations back into input space,
          applying optional finetuning scale, hooking, out normalization, etc.
          """
-         scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
-         sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+         sae_out_pre = feature_acts @ self.W_dec + self.b_dec
          sae_out_pre = self.hook_sae_recons(sae_out_pre)
          sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
          return self.reshape_fn_out(sae_out_pre, self.d_head)

+     @override
+     def initialize_weights(self):
+         super().initialize_weights()
+         if self.cfg.decoder_init_norm is not None:
+             with torch.no_grad():
+                 self.W_dec.data /= self.W_dec.norm(dim=-1, keepdim=True)
+                 self.W_dec.data *= self.cfg.decoder_init_norm
+                 self.W_enc.data = self.W_dec.data.T.clone().detach().contiguous()
+
      @abstractmethod
      def calculate_aux_loss(
          self,
@@ -833,7 +781,7 @@ class TrainingSAE(SAE, ABC):
          sae_out: torch.Tensor,
      ) -> torch.Tensor | dict[str, torch.Tensor]:
          """Calculate architecture-specific auxiliary loss terms."""
-         pass
+         ...

      def training_forward_pass(
          self,
@@ -883,6 +831,39 @@ class TrainingSAE(SAE, ABC):
              losses=losses,
          )

+     def save_inference_model(self, path: str | Path) -> tuple[Path, Path]:
+         """Save inference version of model weights and config to disk."""
+         path = Path(path)
+         path.mkdir(parents=True, exist_ok=True)
+
+         # Generate the weights
+         state_dict = self.state_dict()  # Use internal SAE state dict
+         self.process_state_dict_for_saving_inference(state_dict)
+         model_weights_path = path / SAE_WEIGHTS_FILENAME
+         save_file(state_dict, model_weights_path)
+
+         # Save the config
+         config = self.to_inference_config_dict()
+         cfg_path = path / SAE_CFG_FILENAME
+         with open(cfg_path, "w") as f:
+             json.dump(config, f)
+
+         return model_weights_path, cfg_path
+
+     @abstractmethod
+     def to_inference_config_dict(self) -> dict[str, Any]:
+         """Convert the config into an inference SAE config dict."""
+         ...
+
+     def process_state_dict_for_saving_inference(
+         self, state_dict: dict[str, Any]
+     ) -> None:
+         """
+         Process the state dict for saving the inference model.
+         This is a hook that can be overridden to change how the state dict is processed for the inference model.
+         """
+         return self.process_state_dict_for_saving(state_dict)
+
      def _get_mse_loss_fn(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
          """Get the MSE loss function based on config."""

@@ -921,11 +902,6 @@ class TrainingSAE(SAE, ABC):
              "d_sae, d_sae d_in -> d_sae d_in",
          )

-     @torch.no_grad()
-     def set_decoder_norm_to_unit_norm(self):
-         """Normalize decoder columns to unit norm."""
-         self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-
      @torch.no_grad()
      def log_histograms(self) -> dict[str, NDArray[Any]]:
          """Log histograms of the weights and biases."""
@@ -935,9 +911,11 @@ class TrainingSAE(SAE, ABC):
          }

      @classmethod
-     def get_sae_class_for_architecture(cls: Type[T], architecture: str) -> Type[T]:
+     def get_sae_class_for_architecture(
+         cls: Type[T_TRAINING_SAE], architecture: str
+     ) -> Type[T_TRAINING_SAE]:
          """Get the SAE class for a given architecture."""
-         sae_cls = get_sae_training_class(architecture)
+         sae_cls, _ = get_sae_training_class(architecture)
          if not issubclass(sae_cls, cls):
              raise ValueError(
                  f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
@@ -947,17 +925,17 @@ class TrainingSAE(SAE, ABC):
      # in the future, this can be used to load different config classes for different architectures
      @classmethod
      def get_sae_config_class_for_architecture(
-         cls: Type[T],
+         cls,
          architecture: str,  # noqa: ARG003
-     ) -> type[SAEConfig]:
-         return TrainingSAEConfig
+     ) -> type[TrainingSAEConfig]:
+         return get_sae_training_class(architecture)[1]


  _blank_hook = nn.Identity()


  @contextmanager
- def _disable_hooks(sae: SAE):
+ def _disable_hooks(sae: SAE[Any]):
      """
      Temporarily disable hooks for the SAE. Swaps out all the hooks with a fake modules that does nothing.
      """