sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,13 +1,37 @@
+ from dataclasses import dataclass
+ from typing import Any
+
  import numpy as np
  import torch
  from jaxtyping import Float
  from numpy.typing import NDArray
  from torch import nn
+ from typing_extensions import override
+
+ from sae_lens.saes.sae import (
+     SAE,
+     SAEConfig,
+     TrainCoefficientConfig,
+     TrainingSAE,
+     TrainingSAEConfig,
+     TrainStepInput,
+ )
+ from sae_lens.util import filter_valid_dataclass_fields
+
+
+ @dataclass
+ class StandardSAEConfig(SAEConfig):
+     """
+     Configuration class for a StandardSAE.
+     """

- from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainStepInput
+     @override
+     @classmethod
+     def architecture(cls) -> str:
+         return "standard"


- class StandardSAE(SAE):
+ class StandardSAE(SAE[StandardSAEConfig]):
      """
      StandardSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
      using a simple linear encoder and decoder.
@@ -23,31 +47,14 @@ class StandardSAE(SAE):

      b_enc: nn.Parameter

-     def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+     def __init__(self, cfg: StandardSAEConfig, use_error_term: bool = False):
          super().__init__(cfg, use_error_term)

+     @override
      def initialize_weights(self) -> None:
          # Initialize encoder weights and bias.
-         self.b_enc = nn.Parameter(
-             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-         )
-         self.b_dec = nn.Parameter(
-             torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-         )
-
-         # Use Kaiming Uniform for W_enc
-         w_enc_data = torch.empty(
-             self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-         )
-         nn.init.kaiming_uniform_(w_enc_data)
-         self.W_enc = nn.Parameter(w_enc_data)
-
-         # Use Kaiming Uniform for W_dec
-         w_dec_data = torch.empty(
-             self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-         )
-         nn.init.kaiming_uniform_(w_dec_data)
-         self.W_dec = nn.Parameter(w_dec_data)
+         super().initialize_weights()
+         _init_weights_standard(self)

      def encode(
          self, x: Float[torch.Tensor, "... d_in"]
@@ -70,11 +77,9 @@ class StandardSAE(SAE):
          Decode the feature activations back to the input space.
          Now, if hook_z reshaping is turned on, we reverse the flattening.
          """
-         # 1) apply finetuning scaling if configured.
-         scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
-         # 2) linear transform
-         sae_out_pre = scaled_features @ self.W_dec + self.b_dec
-         # 3) hook reconstruction
+         # 1) linear transform
+         sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+         # 2) hook reconstruction
          sae_out_pre = self.hook_sae_recons(sae_out_pre)
          # 4) optional out-normalization (e.g. constant_norm_rescale or layer_norm)
          sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
@@ -82,7 +87,23 @@ class StandardSAE(SAE):
          return self.reshape_fn_out(sae_out_pre, self.d_head)


- class StandardTrainingSAE(TrainingSAE):
+ @dataclass
+ class StandardTrainingSAEConfig(TrainingSAEConfig):
+     """
+     Configuration class for training a StandardTrainingSAE.
+     """
+
+     l1_coefficient: float = 1.0
+     lp_norm: float = 1.0
+     l1_warm_up_steps: int = 0
+
+     @override
+     @classmethod
+     def architecture(cls) -> str:
+         return "standard"
+
+
+ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
      """
      StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture.
      It implements:
@@ -96,31 +117,17 @@ class StandardTrainingSAE(TrainingSAE):
      b_enc: nn.Parameter

      def initialize_weights(self) -> None:
-         # Basic init
-         # In Python MRO, this calls StandardSAE.initialize_weights()
-         StandardSAE.initialize_weights(self)  # type: ignore
-
-         # Complex init logic from original TrainingSAE
-         if self.cfg.decoder_orthogonal_init:
-             self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
-
-         elif self.cfg.decoder_heuristic_init:
-             self.W_dec.data = torch.rand(  # Changed from Parameter to data assignment
-                 self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-             )
-             self.initialize_decoder_norm_constant_norm()
-
-         if self.cfg.init_encoder_as_decoder_transpose:
-             self.W_enc.data = self.W_dec.data.T.clone().contiguous()  # type: ignore
+         super().initialize_weights()
+         _init_weights_standard(self)

-         if self.cfg.normalize_sae_decoder:
-             with torch.no_grad():
-                 self.set_decoder_norm_to_unit_norm()
-
-     @torch.no_grad()
-     def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
-         self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)  # type: ignore
-         self.W_dec.data *= norm
+     @override
+     def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+         return {
+             "l1": TrainCoefficientConfig(
+                 value=self.cfg.l1_coefficient,
+                 warm_up_steps=self.cfg.l1_warm_up_steps,
+             ),
+         }

      def encode_with_hidden_pre(
          self, x: Float[torch.Tensor, "... d_in"]
@@ -148,13 +155,11 @@ class StandardTrainingSAE(TrainingSAE):
          sae_out: torch.Tensor,
      ) -> dict[str, torch.Tensor]:
          # The "standard" auxiliary loss is a sparsity penalty on the feature activations
-         weighted_feature_acts = feature_acts
-         if self.cfg.scale_sparsity_penalty_by_decoder_norm:
-             weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
+         weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)

          # Compute the p-norm (set by cfg.lp_norm) over the feature dimension
          sparsity = weighted_feature_acts.norm(p=self.cfg.lp_norm, dim=-1)
-         l1_loss = (step_input.current_l1_coefficient * sparsity).mean()
+         l1_loss = (step_input.coefficients["l1"] * sparsity).mean()

          return {"l1_loss": l1_loss}

@@ -165,3 +170,16 @@ class StandardTrainingSAE(TrainingSAE):
              **super().log_histograms(),
              "weights/b_e": b_e_dist,
          }
+
+     def to_inference_config_dict(self) -> dict[str, Any]:
+         return filter_valid_dataclass_fields(
+             self.cfg.to_dict(), StandardSAEConfig, ["architecture"]
+         )
+
+
+ def _init_weights_standard(
+     sae: SAE[StandardSAEConfig] | TrainingSAE[StandardTrainingSAEConfig],
+ ) -> None:
+     sae.b_enc = nn.Parameter(
+         torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+     )
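
Usage sketch (not part of the diff): this release replaces the shared SAEConfig/TrainingSAEConfig with per-architecture config dataclasses. The snippet below shows how the new standard-architecture classes might fit together; the module path sae_lens.saes.standard_sae is assumed (the file header is not shown above), d_in/d_sae are assumed to be SAEConfig constructor fields (they are referenced as self.cfg.d_in and self.cfg.d_sae in the diff), and the numeric values are illustrative.

from sae_lens.saes.standard_sae import (  # assumed module path
    StandardSAE,
    StandardSAEConfig,
    StandardTrainingSAE,
    StandardTrainingSAEConfig,
)

# Illustrative sizes; d_in / d_sae are assumed SAEConfig fields.
inference_cfg = StandardSAEConfig(d_in=768, d_sae=16_384)
sae = StandardSAE(inference_cfg, use_error_term=False)

training_cfg = StandardTrainingSAEConfig(
    d_in=768,
    d_sae=16_384,
    l1_coefficient=5.0,
    lp_norm=1.0,
    l1_warm_up_steps=1_000,
)
training_sae = StandardTrainingSAE(training_cfg)

# The sparsity schedule is now exposed via get_coefficients() and consumed from
# TrainStepInput.coefficients["l1"] instead of the old current_l1_coefficient.
print(training_sae.get_coefficients())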
sae_lens/saes/topk_sae.py CHANGED
@@ -1,18 +1,22 @@
  """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""

- from typing import Callable
+ from dataclasses import dataclass
+ from typing import Any, Callable

  import torch
  from jaxtyping import Float
  from torch import nn
+ from typing_extensions import override

  from sae_lens.saes.sae import (
      SAE,
      SAEConfig,
+     TrainCoefficientConfig,
      TrainingSAE,
      TrainingSAEConfig,
      TrainStepInput,
  )
+ from sae_lens.util import filter_valid_dataclass_fields


  class TopK(nn.Module):
@@ -45,14 +49,30 @@ class TopK(nn.Module):
          return result


- class TopKSAE(SAE):
+ @dataclass
+ class TopKSAEConfig(SAEConfig):
+     """
+     Configuration class for a TopKSAE.
+     """
+
+     k: int = 100
+
+     @override
+     @classmethod
+     def architecture(cls) -> str:
+         return "topk"
+
+
+ class TopKSAE(SAE[TopKSAEConfig]):
      """
      An inference-only sparse autoencoder using a "topk" activation function.
      It uses linear encoder and decoder layers, applying the TopK activation
      to the hidden pre-activation in its encode step.
      """

-     def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+     b_enc: nn.Parameter
+
+     def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
          """
          Args:
              cfg: SAEConfig defining model size and behavior.
@@ -60,38 +80,11 @@ class TopKSAE(SAE):
          """
          super().__init__(cfg, use_error_term)

-         if self.cfg.activation_fn != "topk":
-             raise ValueError("TopKSAE must use a TopK activation function.")
-
+     @override
      def initialize_weights(self) -> None:
-         """
-         Initializes weights and biases for encoder/decoder similarly to the standard SAE,
-         that is:
-         - b_enc, b_dec are zero-initialized
-         - W_enc, W_dec are Kaiming Uniform
-         """
-         # encoder bias
-         self.b_enc = nn.Parameter(
-             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-         )
-         # decoder bias
-         self.b_dec = nn.Parameter(
-             torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-         )
-
-         # encoder weight
-         w_enc_data = torch.empty(
-             self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-         )
-         nn.init.kaiming_uniform_(w_enc_data)
-         self.W_enc = nn.Parameter(w_enc_data)
-
-         # decoder weight
-         w_dec_data = torch.empty(
-             self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-         )
-         nn.init.kaiming_uniform_(w_dec_data)
-         self.W_dec = nn.Parameter(w_dec_data)
+         # Initialize encoder weights and bias.
+         super().initialize_weights()
+         _init_weights_topk(self)

      def encode(
          self, x: Float[torch.Tensor, "... d_in"]
@@ -114,28 +107,31 @@ class TopKSAE(SAE):
          Applies optional finetuning scaling, hooking to recons, out normalization,
          and optional head reshaping.
          """
-         scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
-         sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+         sae_out_pre = feature_acts @ self.W_dec + self.b_dec
          sae_out_pre = self.hook_sae_recons(sae_out_pre)
          sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
          return self.reshape_fn_out(sae_out_pre, self.d_head)

-     def _get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
-         if self.cfg.activation_fn == "topk":
-             if "k" not in self.cfg.activation_fn_kwargs:
-                 raise ValueError("TopK activation function requires a k value.")
-             k = self.cfg.activation_fn_kwargs.get(
-                 "k", 1
-             )  # Default k to 1 if not provided
-             postact_fn = self.cfg.activation_fn_kwargs.get(
-                 "postact_fn", nn.ReLU()
-             )  # Default post-activation to ReLU if not provided
-             return TopK(k, postact_fn)
-         # Otherwise, return the "standard" handling from BaseSAE
-         return super()._get_activation_fn()
-
-
- class TopKTrainingSAE(TrainingSAE):
+     @override
+     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+         return TopK(self.cfg.k)
+
+
+ @dataclass
+ class TopKTrainingSAEConfig(TrainingSAEConfig):
+     """
+     Configuration class for training a TopKTrainingSAE.
+     """
+
+     k: int = 100
+
+     @override
+     @classmethod
+     def architecture(cls) -> str:
+         return "topk"
+
+
+ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
      """
      TopK variant with training functionality. Injects noise during training, optionally
      calculates a topk-related auxiliary loss, etc.
@@ -143,32 +139,13 @@ class TopKTrainingSAE(TrainingSAE):

      b_enc: nn.Parameter

-     def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
+     def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
          super().__init__(cfg, use_error_term)

-         if self.cfg.activation_fn != "topk":
-             raise ValueError("TopKSAE must use a TopK activation function.")
-
+     @override
      def initialize_weights(self) -> None:
-         """Very similar to TopKSAE, using zero biases + Kaiming Uniform weights."""
-         self.b_enc = nn.Parameter(
-             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-         )
-         self.b_dec = nn.Parameter(
-             torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-         )
-
-         w_enc_data = torch.empty(
-             self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-         )
-         nn.init.kaiming_uniform_(w_enc_data)
-         self.W_enc = nn.Parameter(w_enc_data)
-
-         w_dec_data = torch.empty(
-             self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-         )
-         nn.init.kaiming_uniform_(w_dec_data)
-         self.W_dec = nn.Parameter(w_dec_data)
+         super().initialize_weights()
+         _init_weights_topk(self)

      def encode_with_hidden_pre(
          self, x: Float[torch.Tensor, "... d_in"]
@@ -207,14 +184,13 @@ class TopKTrainingSAE(TrainingSAE):
          )
          return {"auxiliary_reconstruction_loss": topk_loss}

-     def _get_activation_fn(self):
-         if self.cfg.activation_fn == "topk":
-             if "k" not in self.cfg.activation_fn_kwargs:
-                 raise ValueError("TopK activation function requires a k value.")
-             k = self.cfg.activation_fn_kwargs.get("k", 1)
-             postact_fn = self.cfg.activation_fn_kwargs.get("postact_fn", nn.ReLU())
-             return TopK(k, postact_fn)
-         return super()._get_activation_fn()
+     @override
+     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+         return TopK(self.cfg.k)
+
+     @override
+     def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
+         return {}

      def calculate_topk_aux_loss(
          self,
@@ -288,6 +264,11 @@ class TopKTrainingSAE(TrainingSAE):

          return auxk_acts

+     def to_inference_config_dict(self) -> dict[str, Any]:
+         return filter_valid_dataclass_fields(
+             self.cfg.to_dict(), TopKSAEConfig, ["architecture"]
+         )
+


  def _calculate_topk_aux_acts(
      k_aux: int,
@@ -303,3 +284,11 @@ def _calculate_topk_aux_acts(
      auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
      # Set activations to zero for all but top k_aux dead latents
      return auxk_acts
+
+
+ def _init_weights_topk(
+     sae: SAE[TopKSAEConfig] | TrainingSAE[TopKTrainingSAEConfig],
+ ) -> None:
+     sae.b_enc = nn.Parameter(
+         torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+     )
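
Usage sketch (not part of the diff): the TopK classes now carry k directly on their own config instead of reading it from activation_fn_kwargs, and TopKTrainingSAE.get_coefficients() returns an empty dict because sparsity is enforced by the TopK activation rather than a penalty term. As in the previous sketch, d_in/d_sae are assumed SAEConfig fields and the values are illustrative.

import torch

from sae_lens.saes.topk_sae import TopKSAE, TopKSAEConfig

# k lives on the config now; d_in / d_sae are assumed SAEConfig fields.
cfg = TopKSAEConfig(d_in=768, d_sae=16_384, k=64)
sae = TopKSAE(cfg)

acts = torch.randn(4, 768)
feature_acts = sae.encode(acts)  # TopK keeps at most k active features per row
recon = sae.decode(feature_acts)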
@@ -0,0 +1,53 @@
+ import json
+ from dataclasses import dataclass
+ from statistics import mean
+
+ import torch
+ from tqdm import tqdm
+
+ from sae_lens.training.types import DataProvider
+
+
+ @dataclass
+ class ActivationScaler:
+     scaling_factor: float | None = None
+
+     def scale(self, acts: torch.Tensor) -> torch.Tensor:
+         return acts if self.scaling_factor is None else acts * self.scaling_factor
+
+     def unscale(self, acts: torch.Tensor) -> torch.Tensor:
+         return acts if self.scaling_factor is None else acts / self.scaling_factor
+
+     def __call__(self, acts: torch.Tensor) -> torch.Tensor:
+         return self.scale(acts)
+
+     @torch.no_grad()
+     def _calculate_mean_norm(
+         self, data_provider: DataProvider, n_batches_for_norm_estimate: int = int(1e3)
+     ) -> float:
+         norms_per_batch: list[float] = []
+         for _ in tqdm(
+             range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
+         ):
+             acts = next(data_provider)
+             norms_per_batch.append(acts.norm(dim=-1).mean().item())
+         return mean(norms_per_batch)
+
+     def estimate_scaling_factor(
+         self,
+         d_in: int,
+         data_provider: DataProvider,
+         n_batches_for_norm_estimate: int = int(1e3),
+     ):
+         mean_norm = self._calculate_mean_norm(
+             data_provider, n_batches_for_norm_estimate
+         )
+         self.scaling_factor = (d_in**0.5) / mean_norm
+
+     def save(self, file_path: str):
+         """save the state dict to a file in json format"""
+         if not file_path.endswith(".json"):
+             raise ValueError("file_path must end with .json")
+
+         with open(file_path, "w") as f:
+             json.dump({"scaling_factor": self.scaling_factor}, f)
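
Usage sketch (not part of the diff): the new ActivationScaler estimates a scaling factor of sqrt(d_in) / mean activation norm from a stream of activation batches, then scales and unscales activations with it. The module path below is an assumption (the new file is unnamed in this diff), and the toy generator stands in for a real DataProvider.

import torch

# Assumed module path for the new file shown above.
from sae_lens.training.activation_scaler import ActivationScaler


def toy_provider():
    # Stand-in for a DataProvider: an endless stream of [batch, d_in] activations.
    while True:
        yield torch.randn(32, 768)


scaler = ActivationScaler()
scaler.estimate_scaling_factor(
    d_in=768, data_provider=toy_provider(), n_batches_for_norm_estimate=10
)

batch = torch.randn(32, 768)
scaled = scaler(batch)                  # __call__ is an alias for scale()
restored = scaler.unscale(scaled)       # inverse transform
scaler.save("activation_scaler.json")   # save() requires a .json path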