sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/saes/sae.py CHANGED
@@ -1,5 +1,6 @@
  """Base classes for Sparse Autoencoders (SAEs)."""

+ import copy
  import json
  import warnings
  from abc import ABC, abstractmethod
@@ -59,24 +60,91 @@ T_SAE = TypeVar("T_SAE", bound="SAE") # type: ignore
  T_TRAINING_SAE = TypeVar("T_TRAINING_SAE", bound="TrainingSAE") # type: ignore


- @dataclass
  class SAEMetadata:
      """Core metadata about how this SAE should be used, if known."""

-     model_name: str | None = None
-     hook_name: str | None = None
-     model_class_name: str | None = None
-     hook_layer: int | None = None
-     hook_head_index: int | None = None
-     model_from_pretrained_kwargs: dict[str, Any] | None = None
-     prepend_bos: bool | None = None
-     exclude_special_tokens: bool | list[int] | None = None
-     neuronpedia_id: str | None = None
-     context_size: int | None = None
-     seqpos_slice: tuple[int | None, ...] | None = None
-     dataset_path: str | None = None
-     sae_lens_version: str = field(default_factory=lambda: __version__)
-     sae_lens_training_version: str = field(default_factory=lambda: __version__)
+     def __init__(self, **kwargs: Any):
+         # Set default version fields with their current behavior
+         self.sae_lens_version = kwargs.pop("sae_lens_version", __version__)
+         self.sae_lens_training_version = kwargs.pop(
+             "sae_lens_training_version", __version__
+         )
+
+         # Set all other attributes dynamically
+         for key, value in kwargs.items():
+             setattr(self, key, value)
+
+     def __getattr__(self, name: str) -> None:
+         """Return None for any missing attribute (like defaultdict)"""
+         return
+
+     def __setattr__(self, name: str, value: Any) -> None:
+         """Allow setting any attribute"""
+         super().__setattr__(name, value)
+
+     def __getitem__(self, key: str) -> Any:
+         """Allow dictionary-style access: metadata['key'] - returns None for missing keys"""
+         return getattr(self, key)
+
+     def __setitem__(self, key: str, value: Any) -> None:
+         """Allow dictionary-style assignment: metadata['key'] = value"""
+         setattr(self, key, value)
+
+     def __contains__(self, key: str) -> bool:
+         """Allow 'in' operator: 'key' in metadata"""
+         # Only return True if the attribute was explicitly set (not just defaulting to None)
+         return key in self.__dict__
+
+     def get(self, key: str, default: Any = None) -> Any:
+         """Dictionary-style get with default"""
+         value = getattr(self, key)
+         # If the attribute wasn't explicitly set and we got None from __getattr__,
+         # use the provided default instead
+         if key not in self.__dict__ and value is None:
+             return default
+         return value
+
+     def keys(self):
+         """Return all explicitly set attribute names"""
+         return self.__dict__.keys()
+
+     def values(self):
+         """Return all explicitly set attribute values"""
+         return self.__dict__.values()
+
+     def items(self):
+         """Return all explicitly set attribute name-value pairs"""
+         return self.__dict__.items()
+
+     def to_dict(self) -> dict[str, Any]:
+         """Convert to dictionary for serialization"""
+         return self.__dict__.copy()
+
+     @classmethod
+     def from_dict(cls, data: dict[str, Any]) -> "SAEMetadata":
+         """Create from dictionary"""
+         return cls(**data)
+
+     def __repr__(self) -> str:
+         return f"SAEMetadata({self.__dict__})"
+
+     def __eq__(self, other: object) -> bool:
+         if not isinstance(other, SAEMetadata):
+             return False
+         return self.__dict__ == other.__dict__
+
+     def __deepcopy__(self, memo: dict[int, Any]) -> "SAEMetadata":
+         """Support for deep copying"""
+
+         return SAEMetadata(**copy.deepcopy(self.__dict__, memo))
+
+     def __getstate__(self) -> dict[str, Any]:
+         """Support for pickling"""
+         return self.__dict__
+
+     def __setstate__(self, state: dict[str, Any]) -> None:
+         """Support for unpickling"""
+         self.__dict__.update(state)


  @dataclass
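For reference, the rewritten SAEMetadata above behaves like a defaultdict-style attribute container rather than a fixed dataclass: unset attributes read as None, but only explicitly set keys count for `in`, `keys()`, and serialization. A minimal usage sketch of that behavior (the metadata values are illustrative, not taken from this diff):

```python
from sae_lens.saes.sae import SAEMetadata

metadata = SAEMetadata(model_name="gpt2", hook_name="blocks.0.hook_resid_pre")

print(metadata.model_name)                 # "gpt2" (attribute access)
print(metadata["hook_name"])               # "blocks.0.hook_resid_pre" (dict-style access)
print(metadata.hook_head_index)            # None -- any unset attribute defaults to None
print("hook_head_index" in metadata)       # False -- only explicitly set keys count
print(metadata.get("context_size", 128))   # 128 -- default used because the key was never set

metadata["context_size"] = 256             # dict-style assignment
print(sorted(metadata.keys()))
# ['context_size', 'hook_name', 'model_name', 'sae_lens_training_version', 'sae_lens_version']

# to_dict()/from_dict() round-trip, as used by the config serialization below.
assert SAEMetadata.from_dict(metadata.to_dict()) == metadata
```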
@@ -100,7 +168,7 @@ class SAEConfig(ABC):
      def to_dict(self) -> dict[str, Any]:
          res = {field.name: getattr(self, field.name) for field in fields(self)}
-         res["metadata"] = asdict(self.metadata)
+         res["metadata"] = self.metadata.to_dict()
          res["architecture"] = self.architecture()
          return res

@@ -125,7 +193,7 @@ class SAEConfig(ABC):
              "layer_norm",
          ]:
              raise ValueError(
-                 f"normalize_activations must be none, expected_average_only_in, constant_norm_rescale, or layer_norm. Got {self.normalize_activations}"
+                 f"normalize_activations must be none, expected_average_only_in, layer_norm, or constant_norm_rescale. Got {self.normalize_activations}"
              )


@@ -239,9 +307,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
              self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
              self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
-
          elif self.cfg.normalize_activations == "layer_norm":
-
+             # we need to scale the norm of the input and store the scaling factor
              def run_time_activation_ln_in(
                  x: torch.Tensor, eps: float = 1e-5
              ) -> torch.Tensor:
@@ -523,7 +590,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
          device: str = "cpu",
          force_download: bool = False,
          converter: PretrainedSaeHuggingfaceLoader | None = None,
-     ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
+     ) -> T_SAE:
          """
          Load a pretrained SAE from the Hugging Face model hub.

@@ -531,7 +598,28 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
              release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
              id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
              device: The device to load the SAE on.
-             return_sparsity_if_present: If True, will return the log sparsity tensor if it is present in the model directory in the Hugging Face model hub.
+         """
+         return cls.from_pretrained_with_cfg_and_sparsity(
+             release, sae_id, device, force_download, converter=converter
+         )[0]
+
+     @classmethod
+     def from_pretrained_with_cfg_and_sparsity(
+         cls: Type[T_SAE],
+         release: str,
+         sae_id: str,
+         device: str = "cpu",
+         force_download: bool = False,
+         converter: PretrainedSaeHuggingfaceLoader | None = None,
+     ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
+         """
+         Load a pretrained SAE from the Hugging Face model hub, along with its config dict and sparsity, if present.
+         In SAELens <= 5.x.x, this was called SAE.from_pretrained().
+
+         Args:
+             release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
+             id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
+             device: The device to load the SAE on.
          """

          # get sae directory
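The two hunks above split the old loader in two: `SAE.from_pretrained()` now returns only the SAE instance, while the previous tuple-returning behavior moves to `from_pretrained_with_cfg_and_sparsity()`. A minimal calling sketch (the release and SAE id below are illustrative placeholders, not taken from this diff):

```python
from sae_lens import SAE

# New in 6.x: from_pretrained() returns just the SAE instance.
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",       # placeholder release name
    sae_id="blocks.0.hook_resid_pre",  # placeholder SAE id
    device="cpu",
)

# The old tuple of (SAE, cfg dict, sparsity tensor or None) now lives here.
sae, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(
    release="gpt2-small-res-jb",
    sae_id="blocks.0.hook_resid_pre",
    device="cpu",
)
```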
@@ -647,9 +735,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

  @dataclass(kw_only=True)
  class TrainingSAEConfig(SAEConfig, ABC):
-     noise_scale: float = 0.0
-     mse_loss_normalization: str | None = None
-     b_dec_init_method: Literal["zeros", "geometric_median", "mean"] = "zeros"
      # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
      # 0.1 corresponds to the "heuristic" initialization, use None to disable
      decoder_init_norm: float | None = 0.1
@@ -666,7 +751,6 @@ class TrainingSAEConfig(SAEConfig, ABC):
          metadata = SAEMetadata(
              model_name=cfg.model_name,
              hook_name=cfg.hook_name,
-             hook_layer=cfg.hook_layer,
              hook_head_index=cfg.hook_head_index,
              context_size=cfg.context_size,
              prepend_bos=cfg.prepend_bos,
@@ -683,9 +767,6 @@ class TrainingSAEConfig(SAEConfig, ABC):
      def from_dict(
          cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
      ) -> T_TRAINING_SAE_CONFIG:
-         # remove any keys that are not in the dataclass
-         # since we sometimes enhance the config with the whole LM runner config
-         valid_config_dict = filter_valid_dataclass_fields(config_dict, cls)
          cfg_class = cls
          if "architecture" in config_dict:
              cfg_class = get_sae_training_class(config_dict["architecture"])[1]
@@ -693,6 +774,9 @@ class TrainingSAEConfig(SAEConfig, ABC):
              raise ValueError(
                  f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
              )
+         # remove any keys that are not in the dataclass
+         # since we sometimes enhance the config with the whole LM runner config
+         valid_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
          if "metadata" in config_dict:
              valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
          return cfg_class(**valid_config_dict)
@@ -701,6 +785,7 @@ class TrainingSAEConfig(SAEConfig, ABC):
          return {
              **super().to_dict(),
              **asdict(self),
+             "metadata": self.metadata.to_dict(),
              "architecture": self.architecture(),
          }

@@ -711,12 +796,14 @@ class TrainingSAEConfig(SAEConfig, ABC):
          Creates a dictionary containing attributes corresponding to the fields
          defined in the base SAEConfig class.
          """
-         base_config_field_names = {f.name for f in fields(SAEConfig)}
+         base_sae_cfg_class = get_sae_class(self.architecture())[1]
+         base_config_field_names = {f.name for f in fields(base_sae_cfg_class)}
          result_dict = {
              field_name: getattr(self, field_name)
              for field_name in base_config_field_names
          }
          result_dict["architecture"] = self.architecture()
+         result_dict["metadata"] = self.metadata.to_dict()
          return result_dict


@@ -729,7 +816,7 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
          # Turn off hook_z reshaping for training mode - the activation store
          # is expected to handle reshaping before passing data to the SAE
          self.turn_off_forward_pass_hook_z_reshaping()
-         self.mse_loss_fn = self._get_mse_loss_fn()
+         self.mse_loss_fn = mse_loss

      @abstractmethod
      def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...
@@ -864,27 +951,6 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
          """
          return self.process_state_dict_for_saving(state_dict)

-     def _get_mse_loss_fn(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
-         """Get the MSE loss function based on config."""
-
-         def standard_mse_loss_fn(
-             preds: torch.Tensor, target: torch.Tensor
-         ) -> torch.Tensor:
-             return torch.nn.functional.mse_loss(preds, target, reduction="none")
-
-         def batch_norm_mse_loss_fn(
-             preds: torch.Tensor, target: torch.Tensor
-         ) -> torch.Tensor:
-             target_centered = target - target.mean(dim=0, keepdim=True)
-             normalization = target_centered.norm(dim=-1, keepdim=True)
-             return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
-                 normalization + 1e-6
-             )
-
-         if self.cfg.mse_loss_normalization == "dense_batch":
-             return batch_norm_mse_loss_fn
-         return standard_mse_loss_fn
-
      @torch.no_grad()
      def remove_gradient_parallel_to_decoder_directions(self) -> None:
          """Remove gradient components parallel to decoder directions."""
@@ -946,3 +1012,7 @@ def _disable_hooks(sae: SAE[Any]):
      finally:
          for hook_name, hook in sae.hook_dict.items():
              setattr(sae, hook_name, hook)
+
+
+ def mse_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+     return torch.nn.functional.mse_loss(preds, target, reduction="none")
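With `_get_mse_loss_fn` removed (earlier hunk) and `mse_loss_fn` now bound to the module-level `mse_loss` added above, training always uses the unnormalized element-wise MSE; the old `mse_loss_normalization="dense_batch"` path is gone. A small sketch of what the new loss function returns (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

# Same definition as the module-level mse_loss added in the hunk above.
def mse_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return F.mse_loss(preds, target, reduction="none")

preds = torch.zeros(4, 16)    # illustrative (batch, d_in) shapes
target = torch.ones(4, 16)
per_element = mse_loss(preds, target)
print(per_element.shape)          # torch.Size([4, 16]) -- unreduced, per-element errors
print(per_element.mean().item())  # 1.0 -- no batch-norm-style normalization applied
```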
@@ -67,7 +67,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
          sae_in = self.process_sae_in(x)
          # Compute the pre-activation values
          hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
-         # Apply the activation function (e.g., ReLU, tanh-relu, depending on config)
+         # Apply the activation function (e.g., ReLU, depending on config)
          return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

      def decode(
@@ -81,7 +81,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
          sae_out_pre = feature_acts @ self.W_dec + self.b_dec
          # 2) hook reconstruction
          sae_out_pre = self.hook_sae_recons(sae_out_pre)
-         # 4) optional out-normalization (e.g. constant_norm_rescale or layer_norm)
+         # 4) optional out-normalization (e.g. constant_norm_rescale)
          sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
          # 5) if hook_z is enabled, rearrange back to (..., n_heads, d_head).
          return self.reshape_fn_out(sae_out_pre, self.d_head)
@@ -136,16 +136,9 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
          sae_in = self.process_sae_in(x)
          # Compute the pre-activation (and allow for a hook if desired)
          hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) # type: ignore
-         # Add noise during training for robustness (scaled by noise_scale from the configuration)
-         if self.training and self.cfg.noise_scale > 0:
-             hidden_pre_noised = (
-                 hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
-             )
-         else:
-             hidden_pre_noised = hidden_pre
          # Apply the activation function (and any post-activation hook)
-         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
-         return feature_acts, hidden_pre_noised
+         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+         return feature_acts, hidden_pre

      def calculate_aux_loss(
          self,
sae_lens/saes/topk_sae.py CHANGED
@@ -91,8 +91,7 @@ class TopKSAE(SAE[TopKSAEConfig]):
      ) -> Float[torch.Tensor, "... d_sae"]:
          """
          Converts input x into feature activations.
-         Uses topk activation from the config (cfg.activation_fn == "topk")
-         under the hood.
+         Uses topk activation under the hood.
          """
          sae_in = self.process_sae_in(x)
          hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
@@ -116,6 +115,13 @@ class TopKSAE(SAE[TopKSAEConfig]):
      def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
          return TopK(self.cfg.k)

+     @override
+     @torch.no_grad()
+     def fold_W_dec_norm(self) -> None:
+         raise NotImplementedError(
+             "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
+         )
+

  @dataclass
  class TopKTrainingSAEConfig(TrainingSAEConfig):
@@ -156,18 +162,11 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
          sae_in = self.process_sae_in(x)
          hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

-         # Inject noise if training
-         if self.training and self.cfg.noise_scale > 0:
-             hidden_pre_noised = (
-                 hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
-             )
-         else:
-             hidden_pre_noised = hidden_pre
-
          # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
-         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
-         return feature_acts, hidden_pre_noised
+         feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+         return feature_acts, hidden_pre

+     @override
      def calculate_aux_loss(
          self,
          step_input: TrainStepInput,
@@ -184,6 +183,13 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
          )
          return {"auxiliary_reconstruction_loss": topk_loss}

+     @override
+     @torch.no_grad()
+     def fold_W_dec_norm(self) -> None:
+         raise NotImplementedError(
+             "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
+         )
+
      @override
      def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
          return TopK(self.cfg.k)
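Both TopKSAE and TopKTrainingSAE now override fold_W_dec_norm to refuse folding, since per-feature rescaling of the pre-activations could change which units land in the top k. A minimal sketch of the resulting behavior; the TopKSAEConfig constructor fields (d_in, d_sae) are assumptions here, as only cfg.k appears in this diff:

```python
# Sketch only: d_in/d_sae are assumed config fields; only cfg.k is shown in this diff.
from sae_lens.saes.topk_sae import TopKSAE, TopKSAEConfig

cfg = TopKSAEConfig(d_in=768, d_sae=16384, k=32)  # assumed constructor fields
sae = TopKSAE(cfg)

try:
    sae.fold_W_dec_norm()  # overridden above to block norm folding
except NotImplementedError as err:
    print(err)
    # Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations
```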
@@ -0,0 +1,53 @@
+ import json
+ from dataclasses import dataclass
+ from statistics import mean
+
+ import torch
+ from tqdm.auto import tqdm
+
+ from sae_lens.training.types import DataProvider
+
+
+ @dataclass
+ class ActivationScaler:
+     scaling_factor: float | None = None
+
+     def scale(self, acts: torch.Tensor) -> torch.Tensor:
+         return acts if self.scaling_factor is None else acts * self.scaling_factor
+
+     def unscale(self, acts: torch.Tensor) -> torch.Tensor:
+         return acts if self.scaling_factor is None else acts / self.scaling_factor
+
+     def __call__(self, acts: torch.Tensor) -> torch.Tensor:
+         return self.scale(acts)
+
+     @torch.no_grad()
+     def _calculate_mean_norm(
+         self, data_provider: DataProvider, n_batches_for_norm_estimate: int = int(1e3)
+     ) -> float:
+         norms_per_batch: list[float] = []
+         for _ in tqdm(
+             range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
+         ):
+             acts = next(data_provider)
+             norms_per_batch.append(acts.norm(dim=-1).mean().item())
+         return mean(norms_per_batch)
+
+     def estimate_scaling_factor(
+         self,
+         d_in: int,
+         data_provider: DataProvider,
+         n_batches_for_norm_estimate: int = int(1e3),
+     ):
+         mean_norm = self._calculate_mean_norm(
+             data_provider, n_batches_for_norm_estimate
+         )
+         self.scaling_factor = (d_in**0.5) / mean_norm
+
+     def save(self, file_path: str):
+         """save the state dict to a file in json format"""
+         if not file_path.endswith(".json"):
+             raise ValueError("file_path must end with .json")
+
+         with open(file_path, "w") as f:
+             json.dump({"scaling_factor": self.scaling_factor}, f)
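The hunk above adds a new ActivationScaler helper; the path of the new file is not shown in this section, so the import below is an assumption. A minimal usage sketch, with a toy iterator standing in for a real DataProvider:

```python
import itertools

import torch

# Assumed import path -- the new file's location is not shown in this diff section.
from sae_lens.training.activation_scaler import ActivationScaler

d_in = 16
torch.manual_seed(0)

# A DataProvider is consumed via next(...), so any iterator of activation batches works here.
data_provider = itertools.cycle([torch.randn(32, d_in) for _ in range(8)])

scaler = ActivationScaler()
scaler.estimate_scaling_factor(
    d_in=d_in, data_provider=data_provider, n_batches_for_norm_estimate=8
)

acts = next(data_provider)
scaled = scaler(acts)  # __call__ delegates to scale()
# After scaling, the mean activation norm is roughly sqrt(d_in).
print(scaled.norm(dim=-1).mean().item(), d_in**0.5)

assert torch.allclose(scaler.unscale(scaled), acts)  # unscale inverts scale
scaler.save("activation_scaler.json")  # writes {"scaling_factor": ...}
```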