sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/saes/sae.py CHANGED
@@ -4,9 +4,18 @@ import json
  import warnings
  from abc import ABC, abstractmethod
  from contextlib import contextmanager
- from dataclasses import dataclass, field, fields
+ from dataclasses import asdict, dataclass, field, fields, replace
  from pathlib import Path
- from typing import Any, Callable, Type, TypeVar
+ from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Generic,
+ Literal,
+ NamedTuple,
+ Type,
+ TypeVar,
+ )

  import einops
  import torch
@@ -15,16 +24,19 @@ from numpy.typing import NDArray
  from safetensors.torch import save_file
  from torch import nn
  from transformer_lens.hook_points import HookedRootModule, HookPoint
- from typing_extensions import deprecated, overload
+ from typing_extensions import deprecated, overload, override

- from sae_lens import logger
- from sae_lens.config import (
+ from sae_lens import __version__, logger
+ from sae_lens.constants import (
  DTYPE_MAP,
  SAE_CFG_FILENAME,
  SAE_WEIGHTS_FILENAME,
- SPARSITY_FILENAME,
- LanguageModelSAERunnerConfig,
  )
+ from sae_lens.util import filter_valid_dataclass_fields
+
+ if TYPE_CHECKING:
+ from sae_lens.config import LanguageModelSAERunnerConfig
+
  from sae_lens.loading.pretrained_sae_loaders import (
  NAMED_PRETRAINED_SAE_LOADERS,
  PretrainedSaeDiskLoader,
@@ -39,57 +51,81 @@ from sae_lens.loading.pretrained_saes_directory import (
  get_pretrained_saes_directory,
  get_repo_id_and_folder_name,
  )
- from sae_lens.regsitry import get_sae_class, get_sae_training_class
+ from sae_lens.registry import get_sae_class, get_sae_training_class

- T = TypeVar("T", bound="SAE")
+ T_SAE_CONFIG = TypeVar("T_SAE_CONFIG", bound="SAEConfig")
+ T_TRAINING_SAE_CONFIG = TypeVar("T_TRAINING_SAE_CONFIG", bound="TrainingSAEConfig")
+ T_SAE = TypeVar("T_SAE", bound="SAE") # type: ignore
+ T_TRAINING_SAE = TypeVar("T_TRAINING_SAE", bound="TrainingSAE") # type: ignore
+
+
+ @dataclass
+ class SAEMetadata:
+ """Core metadata about how this SAE should be used, if known."""
+
+ model_name: str | None = None
+ hook_name: str | None = None
+ model_class_name: str | None = None
+ hook_head_index: int | None = None
+ model_from_pretrained_kwargs: dict[str, Any] | None = None
+ prepend_bos: bool | None = None
+ exclude_special_tokens: bool | list[int] | None = None
+ neuronpedia_id: str | None = None
+ context_size: int | None = None
+ seqpos_slice: tuple[int | None, ...] | None = None
+ dataset_path: str | None = None
+ sae_lens_version: str = field(default_factory=lambda: __version__)
+ sae_lens_training_version: str = field(default_factory=lambda: __version__)


  @dataclass
- class SAEConfig:
+ class SAEConfig(ABC):
  """Base configuration for SAE models."""

- architecture: str
  d_in: int
  d_sae: int
- dtype: str
- device: str
- model_name: str
- hook_name: str
- hook_layer: int
- hook_head_index: int | None
- activation_fn: str
- activation_fn_kwargs: dict[str, Any]
- apply_b_dec_to_input: bool
- finetuning_scaling_factor: bool
- normalize_activations: str
- context_size: int
- dataset_path: str
- dataset_trust_remote_code: bool
- sae_lens_training_version: str
- model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
- seqpos_slice: tuple[int, ...] | None = None
- prepend_bos: bool = False
- neuronpedia_id: str | None = None
-
- def to_dict(self) -> dict[str, Any]:
- return {field.name: getattr(self, field.name) for field in fields(self)}
+ dtype: str = "float32"
+ device: str = "cpu"
+ apply_b_dec_to_input: bool = True
+ normalize_activations: Literal[
+ "none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"
+ ] = "none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
+ reshape_activations: Literal["none", "hook_z"] = "none"
+ metadata: SAEMetadata = field(default_factory=SAEMetadata)

  @classmethod
- def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
- valid_field_names = {field.name for field in fields(cls)}
- valid_config_dict = {
- key: val for key, val in config_dict.items() if key in valid_field_names
- }
+ @abstractmethod
+ def architecture(cls) -> str: ...

- # Ensure seqpos_slice is a tuple
- if (
- "seqpos_slice" in valid_config_dict
- and valid_config_dict["seqpos_slice"] is not None
- and isinstance(valid_config_dict["seqpos_slice"], list)
- ):
- valid_config_dict["seqpos_slice"] = tuple(valid_config_dict["seqpos_slice"])
+ def to_dict(self) -> dict[str, Any]:
+ res = {field.name: getattr(self, field.name) for field in fields(self)}
+ res["metadata"] = asdict(self.metadata)
+ res["architecture"] = self.architecture()
+ return res

- return cls(**valid_config_dict)
+ @classmethod
+ def from_dict(cls: type[T_SAE_CONFIG], config_dict: dict[str, Any]) -> T_SAE_CONFIG:
+ cfg_class = get_sae_class(config_dict["architecture"])[1]
+ filtered_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
+ res = cfg_class(**filtered_config_dict)
+ if "metadata" in config_dict:
+ res.metadata = SAEMetadata(**config_dict["metadata"])
+ if not isinstance(res, cls):
+ raise ValueError(
+ f"SAE config class {cls} does not match dict config class {type(res)}"
+ )
+ return res
+
+ def __post_init__(self):
+ if self.normalize_activations not in [
+ "none",
+ "expected_average_only_in",
+ "constant_norm_rescale",
+ "layer_norm",
+ ]:
+ raise ValueError(
+ f"normalize_activations must be none, expected_average_only_in, constant_norm_rescale, or layer_norm. Got {self.normalize_activations}"
+ )


  @dataclass
@@ -109,14 +145,19 @@ class TrainStepInput:
  """Input to a training step."""

  sae_in: torch.Tensor
- current_l1_coefficient: float
+ coefficients: dict[str, float]
  dead_neuron_mask: torch.Tensor | None


- class SAE(HookedRootModule, ABC):
+ class TrainCoefficientConfig(NamedTuple):
+ value: float
+ warm_up_steps: int
+
+
+ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  """Abstract base class for all SAE architectures."""

- cfg: SAEConfig
+ cfg: T_SAE_CONFIG
  dtype: torch.dtype
  device: torch.device
  use_error_term: bool
@@ -127,13 +168,13 @@ class SAE(HookedRootModule, ABC):
  W_dec: nn.Parameter
  b_dec: nn.Parameter

- def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+ def __init__(self, cfg: T_SAE_CONFIG, use_error_term: bool = False):
  """Initialize the SAE."""
  super().__init__()

  self.cfg = cfg

- if cfg.model_from_pretrained_kwargs:
+ if cfg.metadata and cfg.metadata.model_from_pretrained_kwargs:
  warnings.warn(
  "\nThis SAE has non-empty model_from_pretrained_kwargs. "
  "\nFor optimal performance, load the model like so:\n"
@@ -147,19 +188,11 @@ class SAE(HookedRootModule, ABC):
  self.use_error_term = use_error_term

  # Set up activation function
- self.activation_fn = self._get_activation_fn()
+ self.activation_fn = self.get_activation_fn()

  # Initialize weights
  self.initialize_weights()

- # Handle presence / absence of scaling factor
- if self.cfg.finetuning_scaling_factor:
- self.apply_finetuning_scaling_factor = (
- lambda x: x * self.finetuning_scaling_factor
- )
- else:
- self.apply_finetuning_scaling_factor = lambda x: x
-
  # Set up hooks
  self.hook_sae_input = HookPoint()
  self.hook_sae_acts_pre = HookPoint()
  self.hook_sae_acts_pre = HookPoint()
@@ -169,11 +202,9 @@ class SAE(HookedRootModule, ABC):
169
202
  self.hook_sae_error = HookPoint()
170
203
 
171
204
  # handle hook_z reshaping if needed.
172
- if self.cfg.hook_name.endswith("_z"):
173
- # print(f"Setting up hook_z reshaping for {self.cfg.hook_name}")
205
+ if self.cfg.reshape_activations == "hook_z":
174
206
  self.turn_on_forward_pass_hook_z_reshaping()
175
207
  else:
176
- # print(f"No hook_z reshaping needed for {self.cfg.hook_name}")
177
208
  self.turn_off_forward_pass_hook_z_reshaping()
178
209
 
179
210
  # Set up activation normalization
@@ -188,40 +219,9 @@ class SAE(HookedRootModule, ABC):
188
219
  self.b_dec.data /= scaling_factor # type: ignore
189
220
  self.cfg.normalize_activations = "none"
190
221
 
191
- def _get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
222
+ def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
192
223
  """Get the activation function specified in config."""
193
- return self._get_activation_fn_static(
194
- self.cfg.activation_fn, **(self.cfg.activation_fn_kwargs or {})
195
- )
196
-
197
- @staticmethod
198
- def _get_activation_fn_static(
199
- activation_fn: str, **kwargs: Any
200
- ) -> Callable[[torch.Tensor], torch.Tensor]:
201
- """Get the activation function from a string specification."""
202
- if activation_fn == "relu":
203
- return torch.nn.ReLU()
204
- if activation_fn == "tanh-relu":
205
-
206
- def tanh_relu(input: torch.Tensor) -> torch.Tensor:
207
- input = torch.relu(input)
208
- return torch.tanh(input)
209
-
210
- return tanh_relu
211
- if activation_fn == "topk":
212
- if "k" not in kwargs:
213
- raise ValueError("TopK activation function requires a k value.")
214
- k = kwargs.get("k", 1) # Default k to 1 if not provided
215
-
216
- def topk_fn(x: torch.Tensor) -> torch.Tensor:
217
- topk = torch.topk(x.flatten(start_dim=-1), k=k, dim=-1)
218
- values = torch.relu(topk.values)
219
- result = torch.zeros_like(x.flatten(start_dim=-1))
220
- result.scatter_(-1, topk.indices, values)
221
- return result.view_as(x)
222
-
223
- return topk_fn
224
- raise ValueError(f"Unknown activation function: {activation_fn}")
224
+ return nn.ReLU()
225
225
 
226
226
  def _setup_activation_normalization(self):
227
227
  """Set up activation normalization functions based on config."""
@@ -264,10 +264,20 @@ class SAE(HookedRootModule, ABC):
  self.run_time_activation_norm_fn_in = lambda x: x
  self.run_time_activation_norm_fn_out = lambda x: x

- @abstractmethod
  def initialize_weights(self):
  """Initialize model weights."""
- pass
+ self.b_dec = nn.Parameter(
+ torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+ )
+
+ w_dec_data = torch.empty(
+ self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
+ )
+ nn.init.kaiming_uniform_(w_dec_data)
+ self.W_dec = nn.Parameter(w_dec_data)
+
+ w_enc_data = self.W_dec.data.T.clone().detach().contiguous()
+ self.W_enc = nn.Parameter(w_enc_data)

  @abstractmethod
  def encode(
@@ -284,7 +294,10 @@ class SAE(HookedRootModule, ABC):
  pass

  def turn_on_forward_pass_hook_z_reshaping(self):
- if not self.cfg.hook_name.endswith("_z"):
+ if (
+ self.cfg.metadata.hook_name is not None
+ and not self.cfg.metadata.hook_name.endswith("_z")
+ ):
  raise ValueError("This method should only be called for hook_z SAEs.")

  # print(f"Turning on hook_z reshaping for {self.cfg.hook_name}")
@@ -313,19 +326,19 @@ class SAE(HookedRootModule, ABC):

  @overload
  def to(
- self: T,
+ self: T_SAE,
  device: torch.device | str | None = ...,
  dtype: torch.dtype | None = ...,
  non_blocking: bool = ...,
- ) -> T: ...
+ ) -> T_SAE: ...

  @overload
- def to(self: T, dtype: torch.dtype, non_blocking: bool = ...) -> T: ...
+ def to(self: T_SAE, dtype: torch.dtype, non_blocking: bool = ...) -> T_SAE: ...

  @overload
- def to(self: T, tensor: torch.Tensor, non_blocking: bool = ...) -> T: ...
+ def to(self: T_SAE, tensor: torch.Tensor, non_blocking: bool = ...) -> T_SAE: ...

- def to(self: T, *args: Any, **kwargs: Any) -> T: # type: ignore
+ def to(self: T_SAE, *args: Any, **kwargs: Any) -> T_SAE: # type: ignore
  device_arg = None
  dtype_arg = None

@@ -426,15 +439,12 @@ class SAE(HookedRootModule, ABC):

  def get_name(self):
  """Generate a name for this SAE."""
- return f"sae_{self.cfg.model_name}_{self.cfg.hook_name}_{self.cfg.d_sae}"
+ return f"sae_{self.cfg.metadata.model_name}_{self.cfg.metadata.hook_name}_{self.cfg.d_sae}"

- def save_model(
- self, path: str | Path, sparsity: torch.Tensor | None = None
- ) -> tuple[Path, Path, Path | None]:
- """Save model weights, config, and optional sparsity tensor to disk."""
+ def save_model(self, path: str | Path) -> tuple[Path, Path]:
+ """Save model weights and config to disk."""
  path = Path(path)
- if not path.exists():
- path.mkdir(parents=True)
+ path.mkdir(parents=True, exist_ok=True)

  # Generate the weights
  state_dict = self.state_dict() # Use internal SAE state dict
@@ -448,13 +458,7 @@ class SAE(HookedRootModule, ABC):
  with open(cfg_path, "w") as f:
  json.dump(config, f)

- if sparsity is not None:
- sparsity_in_dict = {"sparsity": sparsity}
- sparsity_path = path / SPARSITY_FILENAME
- save_file(sparsity_in_dict, sparsity_path)
- return model_weights_path, cfg_path, sparsity_path
-
- return model_weights_path, cfg_path, None
+ return model_weights_path, cfg_path

  ## Initialization Methods
  @torch.no_grad()
@@ -482,18 +486,21 @@ class SAE(HookedRootModule, ABC):
  @classmethod
  @deprecated("Use load_from_disk instead")
  def load_from_pretrained(
- cls: Type[T], path: str | Path, device: str = "cpu", dtype: str | None = None
- ) -> T:
+ cls: Type[T_SAE],
+ path: str | Path,
+ device: str = "cpu",
+ dtype: str | None = None,
+ ) -> T_SAE:
  return cls.load_from_disk(path, device=device, dtype=dtype)

  @classmethod
  def load_from_disk(
- cls: Type[T],
+ cls: Type[T_SAE],
  path: str | Path,
  device: str = "cpu",
  dtype: str | None = None,
  converter: PretrainedSaeDiskLoader = sae_lens_disk_loader,
- ) -> T:
+ ) -> T_SAE:
  overrides = {"dtype": dtype} if dtype is not None else None
  cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides)
  cfg_dict = handle_config_defaulting(cfg_dict)
@@ -501,7 +508,7 @@ class SAE(HookedRootModule, ABC):
  cfg_dict["architecture"]
  )
  sae_cfg = sae_config_cls.from_dict(cfg_dict)
- sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture)
+ sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
  sae = sae_cls(sae_cfg)
  sae.process_state_dict_for_loading(state_dict)
  sae.load_state_dict(state_dict)
@@ -509,13 +516,13 @@ class SAE(HookedRootModule, ABC):

  @classmethod
  def from_pretrained(
- cls,
+ cls: Type[T_SAE],
  release: str,
  sae_id: str,
  device: str = "cpu",
  force_download: bool = False,
  converter: PretrainedSaeHuggingfaceLoader | None = None,
- ) -> tuple["SAE", dict[str, Any], torch.Tensor | None]:
+ ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
  """
  Load a pretrained SAE from the Hugging Face model hub.

@@ -584,47 +591,31 @@ class SAE(HookedRootModule, ABC):
  )
  cfg_dict = handle_config_defaulting(cfg_dict)

- # Rename keys to match SAEConfig field names
- renamed_cfg_dict = {}
- rename_map = {
- "hook_point": "hook_name",
- "hook_point_layer": "hook_layer",
- "hook_point_head_index": "hook_head_index",
- "activation_fn": "activation_fn",
- }
-
- for k, v in cfg_dict.items():
- renamed_cfg_dict[rename_map.get(k, k)] = v
-
- # Set default values for required fields
- renamed_cfg_dict.setdefault("activation_fn_kwargs", {})
- renamed_cfg_dict.setdefault("seqpos_slice", None)
-
  # Create SAE with appropriate architecture
  sae_config_cls = cls.get_sae_config_class_for_architecture(
- renamed_cfg_dict["architecture"]
+ cfg_dict["architecture"]
  )
- sae_cfg = sae_config_cls.from_dict(renamed_cfg_dict)
- sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture)
+ sae_cfg = sae_config_cls.from_dict(cfg_dict)
+ sae_cls = cls.get_sae_class_for_architecture(sae_cfg.architecture())
  sae = sae_cls(sae_cfg)
  sae.process_state_dict_for_loading(state_dict)
  sae.load_state_dict(state_dict)

  # Apply normalization if needed
- if renamed_cfg_dict.get("normalize_activations") == "expected_average_only_in":
+ if cfg_dict.get("normalize_activations") == "expected_average_only_in":
  norm_scaling_factor = get_norm_scaling_factor(release, sae_id)
  if norm_scaling_factor is not None:
  sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
- renamed_cfg_dict["normalize_activations"] = "none"
+ cfg_dict["normalize_activations"] = "none"
  else:
  warnings.warn(
  f"norm_scaling_factor not found for {release} and {sae_id}, but normalize_activations is 'expected_average_only_in'. Skipping normalization folding."
  )

- return sae, renamed_cfg_dict, log_sparsities
+ return sae, cfg_dict, log_sparsities

  @classmethod
- def from_dict(cls: Type[T], config_dict: dict[str, Any]) -> T:
+ def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
  """Create an SAE from a config dictionary."""
  sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
  sae_config_cls = cls.get_sae_config_class_for_architecture(
@@ -633,9 +624,11 @@ class SAE(HookedRootModule, ABC):
  return sae_cls(sae_config_cls.from_dict(config_dict))

  @classmethod
- def get_sae_class_for_architecture(cls: Type[T], architecture: str) -> Type[T]:
+ def get_sae_class_for_architecture(
+ cls: Type[T_SAE], architecture: str
+ ) -> Type[T_SAE]:
  """Get the SAE class for a given architecture."""
- sae_cls = get_sae_class(architecture)
+ sae_cls, _ = get_sae_class(architecture)
  if not issubclass(sae_cls, cls):
  raise ValueError(
  f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
@@ -645,161 +638,105 @@ class SAE(HookedRootModule, ABC):
  # in the future, this can be used to load different config classes for different architectures
  @classmethod
  def get_sae_config_class_for_architecture(
- cls: Type[T],
+ cls,
  architecture: str, # noqa: ARG003
  ) -> type[SAEConfig]:
  return SAEConfig


  @dataclass(kw_only=True)
- class TrainingSAEConfig(SAEConfig):
- # Sparsity Loss Calculations
- l1_coefficient: float
- lp_norm: float
- use_ghost_grads: bool
- normalize_sae_decoder: bool
- noise_scale: float
- decoder_orthogonal_init: bool
- mse_loss_normalization: str | None
- jumprelu_init_threshold: float
- jumprelu_bandwidth: float
- decoder_heuristic_init: bool
- init_encoder_as_decoder_transpose: bool
- scale_sparsity_penalty_by_decoder_norm: bool
+ class TrainingSAEConfig(SAEConfig, ABC):
+ noise_scale: float = 0.0
+ mse_loss_normalization: str | None = None
+ # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
+ # 0.1 corresponds to the "heuristic" initialization, use None to disable
+ decoder_init_norm: float | None = 0.1
+
+ @classmethod
+ @abstractmethod
+ def architecture(cls) -> str: ...

  @classmethod
  def from_sae_runner_config(
- cls, cfg: LanguageModelSAERunnerConfig
- ) -> "TrainingSAEConfig":
- return cls(
- # base config
- architecture=cfg.architecture,
- d_in=cfg.d_in,
- d_sae=cfg.d_sae, # type: ignore
- dtype=cfg.dtype,
- device=cfg.device,
+ cls: type[T_TRAINING_SAE_CONFIG],
+ cfg: "LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]",
+ ) -> T_TRAINING_SAE_CONFIG:
+ metadata = SAEMetadata(
  model_name=cfg.model_name,
  hook_name=cfg.hook_name,
- hook_layer=cfg.hook_layer,
  hook_head_index=cfg.hook_head_index,
- activation_fn=cfg.activation_fn,
- activation_fn_kwargs=cfg.activation_fn_kwargs,
- apply_b_dec_to_input=cfg.apply_b_dec_to_input,
- finetuning_scaling_factor=cfg.finetuning_method is not None,
- sae_lens_training_version=cfg.sae_lens_training_version,
  context_size=cfg.context_size,
- dataset_path=cfg.dataset_path,
  prepend_bos=cfg.prepend_bos,
- seqpos_slice=tuple(x for x in cfg.seqpos_slice if x is not None)
- if cfg.seqpos_slice is not None
- else None,
- # Training cfg
- l1_coefficient=cfg.l1_coefficient,
- lp_norm=cfg.lp_norm,
- use_ghost_grads=cfg.use_ghost_grads,
- normalize_sae_decoder=cfg.normalize_sae_decoder,
- noise_scale=cfg.noise_scale,
- decoder_orthogonal_init=cfg.decoder_orthogonal_init,
- mse_loss_normalization=cfg.mse_loss_normalization,
- decoder_heuristic_init=cfg.decoder_heuristic_init,
- init_encoder_as_decoder_transpose=cfg.init_encoder_as_decoder_transpose,
- scale_sparsity_penalty_by_decoder_norm=cfg.scale_sparsity_penalty_by_decoder_norm,
- normalize_activations=cfg.normalize_activations,
- dataset_trust_remote_code=cfg.dataset_trust_remote_code,
+ seqpos_slice=cfg.seqpos_slice,
  model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {},
- jumprelu_init_threshold=cfg.jumprelu_init_threshold,
- jumprelu_bandwidth=cfg.jumprelu_bandwidth,
  )
+ if not isinstance(cfg.sae, cls):
+ raise ValueError(
+ f"SAE config class {cls} does not match SAE runner config class {type(cfg.sae)}"
+ )
+ return replace(cfg.sae, metadata=metadata)

  @classmethod
- def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig":
+ def from_dict(
+ cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
+ ) -> T_TRAINING_SAE_CONFIG:
  # remove any keys that are not in the dataclass
  # since we sometimes enhance the config with the whole LM runner config
- valid_field_names = {field.name for field in fields(cls)}
- valid_config_dict = {
- key: val for key, val in config_dict.items() if key in valid_field_names
- }
-
- # ensure seqpos slice is tuple
- # ensure that seqpos slices is a tuple
- # Ensure seqpos_slice is a tuple
- if "seqpos_slice" in valid_config_dict:
- if isinstance(valid_config_dict["seqpos_slice"], list):
- valid_config_dict["seqpos_slice"] = tuple(
- valid_config_dict["seqpos_slice"]
- )
- elif not isinstance(valid_config_dict["seqpos_slice"], tuple):
- valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],)
-
- return TrainingSAEConfig(**valid_config_dict)
+ valid_config_dict = filter_valid_dataclass_fields(config_dict, cls)
+ cfg_class = cls
+ if "architecture" in config_dict:
+ cfg_class = get_sae_training_class(config_dict["architecture"])[1]
+ if not issubclass(cfg_class, cls):
+ raise ValueError(
+ f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
+ )
+ if "metadata" in config_dict:
+ valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
+ return cfg_class(**valid_config_dict)

  def to_dict(self) -> dict[str, Any]:
  return {
  **super().to_dict(),
- "l1_coefficient": self.l1_coefficient,
- "lp_norm": self.lp_norm,
- "use_ghost_grads": self.use_ghost_grads,
- "normalize_sae_decoder": self.normalize_sae_decoder,
- "noise_scale": self.noise_scale,
- "decoder_orthogonal_init": self.decoder_orthogonal_init,
- "init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
- "mse_loss_normalization": self.mse_loss_normalization,
- "decoder_heuristic_init": self.decoder_heuristic_init,
- "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
- "normalize_activations": self.normalize_activations,
- "jumprelu_init_threshold": self.jumprelu_init_threshold,
- "jumprelu_bandwidth": self.jumprelu_bandwidth,
+ **asdict(self),
+ "architecture": self.architecture(),
  }

  # this needs to exist so we can initialize the parent sae cfg without the training specific
  # parameters. Maybe there's a cleaner way to do this
  def get_base_sae_cfg_dict(self) -> dict[str, Any]:
- return {
- "architecture": self.architecture,
- "d_in": self.d_in,
- "d_sae": self.d_sae,
- "activation_fn": self.activation_fn,
- "activation_fn_kwargs": self.activation_fn_kwargs,
- "apply_b_dec_to_input": self.apply_b_dec_to_input,
- "dtype": self.dtype,
- "model_name": self.model_name,
- "hook_name": self.hook_name,
- "hook_layer": self.hook_layer,
- "hook_head_index": self.hook_head_index,
- "device": self.device,
- "context_size": self.context_size,
- "prepend_bos": self.prepend_bos,
- "finetuning_scaling_factor": self.finetuning_scaling_factor,
- "normalize_activations": self.normalize_activations,
- "dataset_path": self.dataset_path,
- "dataset_trust_remote_code": self.dataset_trust_remote_code,
- "sae_lens_training_version": self.sae_lens_training_version,
- "model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
- "seqpos_slice": self.seqpos_slice,
- "neuronpedia_id": self.neuronpedia_id,
+ """
+ Creates a dictionary containing attributes corresponding to the fields
+ defined in the base SAEConfig class.
+ """
+ base_config_field_names = {f.name for f in fields(SAEConfig)}
+ result_dict = {
+ field_name: getattr(self, field_name)
+ for field_name in base_config_field_names
  }
+ result_dict["architecture"] = self.architecture()
+ return result_dict


- class TrainingSAE(SAE, ABC):
+ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
  """Abstract base class for training versions of SAEs."""

- cfg: "TrainingSAEConfig" # type: ignore
-
- def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
+ def __init__(self, cfg: T_TRAINING_SAE_CONFIG, use_error_term: bool = False):
  super().__init__(cfg, use_error_term)

  # Turn off hook_z reshaping for training mode - the activation store
  # is expected to handle reshaping before passing data to the SAE
  self.turn_off_forward_pass_hook_z_reshaping()
-
  self.mse_loss_fn = self._get_mse_loss_fn()

+ @abstractmethod
+ def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...
+
  @abstractmethod
  def encode_with_hidden_pre(
  self, x: Float[torch.Tensor, "... d_in"]
  ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
  """Encode with access to pre-activation values for training."""
- pass
+ ...

  def encode(
  self, x: Float[torch.Tensor, "... d_in"]
@@ -818,12 +755,20 @@ class TrainingSAE(SAE, ABC):
  Decodes feature activations back into input space,
  applying optional finetuning scale, hooking, out normalization, etc.
  """
- scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
- sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+ sae_out_pre = feature_acts @ self.W_dec + self.b_dec
  sae_out_pre = self.hook_sae_recons(sae_out_pre)
  sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
  return self.reshape_fn_out(sae_out_pre, self.d_head)

+ @override
+ def initialize_weights(self):
+ super().initialize_weights()
+ if self.cfg.decoder_init_norm is not None:
+ with torch.no_grad():
+ self.W_dec.data /= self.W_dec.norm(dim=-1, keepdim=True)
+ self.W_dec.data *= self.cfg.decoder_init_norm
+ self.W_enc.data = self.W_dec.data.T.clone().detach().contiguous()
+
  @abstractmethod
  def calculate_aux_loss(
  self,
@@ -833,7 +778,7 @@ class TrainingSAE(SAE, ABC):
  sae_out: torch.Tensor,
  ) -> torch.Tensor | dict[str, torch.Tensor]:
  """Calculate architecture-specific auxiliary loss terms."""
- pass
+ ...

  def training_forward_pass(
  self,
@@ -883,6 +828,39 @@ class TrainingSAE(SAE, ABC):
  losses=losses,
  )

+ def save_inference_model(self, path: str | Path) -> tuple[Path, Path]:
+ """Save inference version of model weights and config to disk."""
+ path = Path(path)
+ path.mkdir(parents=True, exist_ok=True)
+
+ # Generate the weights
+ state_dict = self.state_dict() # Use internal SAE state dict
+ self.process_state_dict_for_saving_inference(state_dict)
+ model_weights_path = path / SAE_WEIGHTS_FILENAME
+ save_file(state_dict, model_weights_path)
+
+ # Save the config
+ config = self.to_inference_config_dict()
+ cfg_path = path / SAE_CFG_FILENAME
+ with open(cfg_path, "w") as f:
+ json.dump(config, f)
+
+ return model_weights_path, cfg_path
+
+ @abstractmethod
+ def to_inference_config_dict(self) -> dict[str, Any]:
+ """Convert the config into an inference SAE config dict."""
+ ...
+
+ def process_state_dict_for_saving_inference(
+ self, state_dict: dict[str, Any]
+ ) -> None:
+ """
+ Process the state dict for saving the inference model.
+ This is a hook that can be overridden to change how the state dict is processed for the inference model.
+ """
+ return self.process_state_dict_for_saving(state_dict)
+
  def _get_mse_loss_fn(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
  """Get the MSE loss function based on config."""

@@ -921,11 +899,6 @@ class TrainingSAE(SAE, ABC):
  "d_sae, d_sae d_in -> d_sae d_in",
  )

- @torch.no_grad()
- def set_decoder_norm_to_unit_norm(self):
- """Normalize decoder columns to unit norm."""
- self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-
  @torch.no_grad()
  def log_histograms(self) -> dict[str, NDArray[Any]]:
  """Log histograms of the weights and biases."""
  """Log histograms of the weights and biases."""
@@ -935,9 +908,11 @@ class TrainingSAE(SAE, ABC):
935
908
  }
936
909
 
937
910
  @classmethod
938
- def get_sae_class_for_architecture(cls: Type[T], architecture: str) -> Type[T]:
911
+ def get_sae_class_for_architecture(
912
+ cls: Type[T_TRAINING_SAE], architecture: str
913
+ ) -> Type[T_TRAINING_SAE]:
939
914
  """Get the SAE class for a given architecture."""
940
- sae_cls = get_sae_training_class(architecture)
915
+ sae_cls, _ = get_sae_training_class(architecture)
941
916
  if not issubclass(sae_cls, cls):
942
917
  raise ValueError(
943
918
  f"Loaded SAE is not of type {cls.__name__}. Use {sae_cls.__name__} instead"
@@ -947,17 +922,17 @@ class TrainingSAE(SAE, ABC):
947
922
  # in the future, this can be used to load different config classes for different architectures
948
923
  @classmethod
949
924
  def get_sae_config_class_for_architecture(
950
- cls: Type[T],
925
+ cls,
951
926
  architecture: str, # noqa: ARG003
952
- ) -> type[SAEConfig]:
953
- return TrainingSAEConfig
927
+ ) -> type[TrainingSAEConfig]:
928
+ return get_sae_training_class(architecture)[1]
954
929
 
955
930
 
956
931
  _blank_hook = nn.Identity()
957
932
 
958
933
 
959
934
  @contextmanager
960
- def _disable_hooks(sae: SAE):
935
+ def _disable_hooks(sae: SAE[Any]):
961
936
  """
962
937
  Temporarily disable hooks for the SAE. Swaps out all the hooks with a fake modules that does nothing.
963
938
  """