sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,37 @@
+from dataclasses import dataclass
+from typing import Any
+
 import numpy as np
 import torch
 from jaxtyping import Float
 from numpy.typing import NDArray
 from torch import nn
+from typing_extensions import override
+
+from sae_lens.saes.sae import (
+    SAE,
+    SAEConfig,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainingSAEConfig,
+    TrainStepInput,
+)
+from sae_lens.util import filter_valid_dataclass_fields
+
+
+@dataclass
+class StandardSAEConfig(SAEConfig):
+    """
+    Configuration class for a StandardSAE.
+    """
 
-from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainStepInput
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "standard"
 
 
-class StandardSAE(SAE):
+class StandardSAE(SAE[StandardSAEConfig]):
     """
     StandardSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
     using a simple linear encoder and decoder.
@@ -23,31 +47,14 @@ class StandardSAE(SAE):
 
     b_enc: nn.Parameter
 
-    def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+    def __init__(self, cfg: StandardSAEConfig, use_error_term: bool = False):
         super().__init__(cfg, use_error_term)
 
+    @override
     def initialize_weights(self) -> None:
         # Initialize encoder weights and bias.
-        self.b_enc = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
-
-        # Use Kaiming Uniform for W_enc
-        w_enc_data = torch.empty(
-            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_enc_data)
-        self.W_enc = nn.Parameter(w_enc_data)
-
-        # Use Kaiming Uniform for W_dec
-        w_dec_data = torch.empty(
-            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_dec_data)
-        self.W_dec = nn.Parameter(w_dec_data)
+        super().initialize_weights()
+        _init_weights_standard(self)
 
     def encode(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -70,11 +77,9 @@ class StandardSAE(SAE):
         Decode the feature activations back to the input space.
         Now, if hook_z reshaping is turned on, we reverse the flattening.
         """
-        # 1) apply finetuning scaling if configured.
-        scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
-        # 2) linear transform
-        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
-        # 3) hook reconstruction
+        # 1) linear transform
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
+        # 2) hook reconstruction
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
         # 4) optional out-normalization (e.g. constant_norm_rescale or layer_norm)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
@@ -82,7 +87,23 @@ class StandardSAE(SAE):
         return self.reshape_fn_out(sae_out_pre, self.d_head)
 
 
-class StandardTrainingSAE(TrainingSAE):
+@dataclass
+class StandardTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a StandardTrainingSAE.
+    """
+
+    l1_coefficient: float = 1.0
+    lp_norm: float = 1.0
+    l1_warm_up_steps: int = 0
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "standard"
+
+
+class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
     """
     StandardTrainingSAE is a concrete implementation of BaseTrainingSAE using the "standard" SAE architecture.
     It implements:
@@ -96,31 +117,17 @@ class StandardTrainingSAE(TrainingSAE):
 
     b_enc: nn.Parameter
 
     def initialize_weights(self) -> None:
-        # Basic init
-        # In Python MRO, this calls StandardSAE.initialize_weights()
-        StandardSAE.initialize_weights(self)  # type: ignore
-
-        # Complex init logic from original TrainingSAE
-        if self.cfg.decoder_orthogonal_init:
-            self.W_dec.data = nn.init.orthogonal_(self.W_dec.data.T).T
-
-        elif self.cfg.decoder_heuristic_init:
-            self.W_dec.data = torch.rand(  # Changed from Parameter to data assignment
-                self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-            )
-            self.initialize_decoder_norm_constant_norm()
-
-        if self.cfg.init_encoder_as_decoder_transpose:
-            self.W_enc.data = self.W_dec.data.T.clone().contiguous()  # type: ignore
+        super().initialize_weights()
+        _init_weights_standard(self)
 
-        if self.cfg.normalize_sae_decoder:
-            with torch.no_grad():
-                self.set_decoder_norm_to_unit_norm()
-
-    @torch.no_grad()
-    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
-        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)  # type: ignore
-        self.W_dec.data *= norm
+    @override
+    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
+        return {
+            "l1": TrainCoefficientConfig(
+                value=self.cfg.l1_coefficient,
+                warm_up_steps=self.cfg.l1_warm_up_steps,
+            ),
+        }
 
     def encode_with_hidden_pre(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -148,13 +155,11 @@ class StandardTrainingSAE(TrainingSAE):
         sae_out: torch.Tensor,
     ) -> dict[str, torch.Tensor]:
         # The "standard" auxiliary loss is a sparsity penalty on the feature activations
-        weighted_feature_acts = feature_acts
-        if self.cfg.scale_sparsity_penalty_by_decoder_norm:
-            weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
+        weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
 
         # Compute the p-norm (set by cfg.lp_norm) over the feature dimension
         sparsity = weighted_feature_acts.norm(p=self.cfg.lp_norm, dim=-1)
-        l1_loss = (step_input.current_l1_coefficient * sparsity).mean()
+        l1_loss = (step_input.coefficients["l1"] * sparsity).mean()
 
         return {"l1_loss": l1_loss}
 
@@ -165,3 +170,16 @@ class StandardTrainingSAE(TrainingSAE):
             **super().log_histograms(),
             "weights/b_e": b_e_dist,
         }
+
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), StandardSAEConfig, ["architecture"]
+        )
+
+
+def _init_weights_standard(
+    sae: SAE[StandardSAEConfig] | TrainingSAE[StandardTrainingSAEConfig],
+) -> None:
+    sae.b_enc = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
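The practical shift in this file is that architecture-specific options now live on dedicated config dataclasses, and the sparsity term is registered by name through get_coefficients() instead of being read from step_input.current_l1_coefficient. The sketch below is illustrative only, not taken from the package: it assumes the module path sae_lens.saes.standard_sae and that d_in and d_sae are the only required base config fields, neither of which is shown in this diff.

# Illustrative sketch only; module path and required base fields are assumptions.
from sae_lens.saes.standard_sae import StandardTrainingSAE, StandardTrainingSAEConfig

cfg = StandardTrainingSAEConfig(
    d_in=768,            # assumed base SAEConfig field
    d_sae=16_384,        # assumed base SAEConfig field
    l1_coefficient=5.0,
    lp_norm=1.0,
    l1_warm_up_steps=1_000,
)
sae = StandardTrainingSAE(cfg)
print(cfg.architecture())      # "standard"
print(sae.get_coefficients())  # {"l1": TrainCoefficientConfig(...)} warmed up over 1_000 steps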
sae_lens/saes/topk_sae.py CHANGED
@@ -1,18 +1,22 @@
 """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
 
-from typing import Callable
+from dataclasses import dataclass
+from typing import Any, Callable
 
 import torch
 from jaxtyping import Float
 from torch import nn
+from typing_extensions import override
 
 from sae_lens.saes.sae import (
     SAE,
     SAEConfig,
+    TrainCoefficientConfig,
     TrainingSAE,
     TrainingSAEConfig,
     TrainStepInput,
 )
+from sae_lens.util import filter_valid_dataclass_fields
 
 
 class TopK(nn.Module):
@@ -45,14 +49,30 @@ class TopK(nn.Module):
         return result
 
 
-class TopKSAE(SAE):
+@dataclass
+class TopKSAEConfig(SAEConfig):
+    """
+    Configuration class for a TopKSAE.
+    """
+
+    k: int = 100
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "topk"
+
+
+class TopKSAE(SAE[TopKSAEConfig]):
     """
     An inference-only sparse autoencoder using a "topk" activation function.
     It uses linear encoder and decoder layers, applying the TopK activation
     to the hidden pre-activation in its encode step.
     """
 
-    def __init__(self, cfg: SAEConfig, use_error_term: bool = False):
+    b_enc: nn.Parameter
+
+    def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
         """
         Args:
             cfg: SAEConfig defining model size and behavior.
@@ -60,38 +80,11 @@ class TopKSAE(SAE):
         """
         super().__init__(cfg, use_error_term)
 
-        if self.cfg.activation_fn != "topk":
-            raise ValueError("TopKSAE must use a TopK activation function.")
-
+    @override
     def initialize_weights(self) -> None:
-        """
-        Initializes weights and biases for encoder/decoder similarly to the standard SAE,
-        that is:
-          - b_enc, b_dec are zero-initialized
-          - W_enc, W_dec are Kaiming Uniform
-        """
-        # encoder bias
-        self.b_enc = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-        # decoder bias
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
-
-        # encoder weight
-        w_enc_data = torch.empty(
-            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_enc_data)
-        self.W_enc = nn.Parameter(w_enc_data)
-
-        # decoder weight
-        w_dec_data = torch.empty(
-            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_dec_data)
-        self.W_dec = nn.Parameter(w_dec_data)
+        # Initialize encoder weights and bias.
+        super().initialize_weights()
+        _init_weights_topk(self)
 
     def encode(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -114,28 +107,31 @@ class TopKSAE(SAE):
         Applies optional finetuning scaling, hooking to recons, out normalization,
         and optional head reshaping.
         """
-        scaled_features = self.apply_finetuning_scaling_factor(feature_acts)
-        sae_out_pre = scaled_features @ self.W_dec + self.b_dec
+        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
         return self.reshape_fn_out(sae_out_pre, self.d_head)
 
-    def _get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
-        if self.cfg.activation_fn == "topk":
-            if "k" not in self.cfg.activation_fn_kwargs:
-                raise ValueError("TopK activation function requires a k value.")
-            k = self.cfg.activation_fn_kwargs.get(
-                "k", 1
-            )  # Default k to 1 if not provided
-            postact_fn = self.cfg.activation_fn_kwargs.get(
-                "postact_fn", nn.ReLU()
-            )  # Default post-activation to ReLU if not provided
-            return TopK(k, postact_fn)
-        # Otherwise, return the "standard" handling from BaseSAE
-        return super()._get_activation_fn()
-
-
-class TopKTrainingSAE(TrainingSAE):
+    @override
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return TopK(self.cfg.k)
+
+
+@dataclass
+class TopKTrainingSAEConfig(TrainingSAEConfig):
+    """
+    Configuration class for training a TopKTrainingSAE.
+    """
+
+    k: int = 100
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "topk"
+
+
+class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     """
     TopK variant with training functionality. Injects noise during training, optionally
     calculates a topk-related auxiliary loss, etc.
@@ -143,32 +139,13 @@ class TopKTrainingSAE(TrainingSAE):
 
     b_enc: nn.Parameter
 
-    def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
+    def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
         super().__init__(cfg, use_error_term)
 
-        if self.cfg.activation_fn != "topk":
-            raise ValueError("TopKSAE must use a TopK activation function.")
-
+    @override
     def initialize_weights(self) -> None:
-        """Very similar to TopKSAE, using zero biases + Kaiming Uniform weights."""
-        self.b_enc = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
-
-        w_enc_data = torch.empty(
-            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_enc_data)
-        self.W_enc = nn.Parameter(w_enc_data)
-
-        w_dec_data = torch.empty(
-            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-        )
-        nn.init.kaiming_uniform_(w_dec_data)
-        self.W_dec = nn.Parameter(w_dec_data)
+        super().initialize_weights()
+        _init_weights_topk(self)
 
     def encode_with_hidden_pre(
         self, x: Float[torch.Tensor, "... d_in"]
@@ -207,14 +184,13 @@ class TopKTrainingSAE(TrainingSAE):
         )
         return {"auxiliary_reconstruction_loss": topk_loss}
 
-    def _get_activation_fn(self):
-        if self.cfg.activation_fn == "topk":
-            if "k" not in self.cfg.activation_fn_kwargs:
-                raise ValueError("TopK activation function requires a k value.")
-            k = self.cfg.activation_fn_kwargs.get("k", 1)
-            postact_fn = self.cfg.activation_fn_kwargs.get("postact_fn", nn.ReLU())
-            return TopK(k, postact_fn)
-        return super()._get_activation_fn()
+    @override
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return TopK(self.cfg.k)
+
+    @override
+    def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
+        return {}
 
     def calculate_topk_aux_loss(
         self,
@@ -288,6 +264,11 @@
 
         return auxk_acts
 
+    def to_inference_config_dict(self) -> dict[str, Any]:
+        return filter_valid_dataclass_fields(
+            self.cfg.to_dict(), TopKSAEConfig, ["architecture"]
+        )
+
 
 def _calculate_topk_aux_acts(
     k_aux: int,
@@ -303,3 +284,11 @@ def _calculate_topk_aux_acts(
     auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
     # Set activations to zero for all but top k_aux dead latents
     return auxk_acts
+
+
+def _init_weights_topk(
+    sae: SAE[TopKSAEConfig] | TrainingSAE[TopKTrainingSAEConfig],
+) -> None:
+    sae.b_enc = nn.Parameter(
+        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
+    )
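Net effect for users of the TopK classes: k is now a first-class config field rather than an entry in activation_fn_kwargs, and get_activation_fn() simply returns TopK(self.cfg.k). A minimal sketch follows, again assuming d_in and d_sae are the only other required base config fields, which this diff does not show.

# Illustrative sketch only; d_in/d_sae as the sole required base fields is an assumption.
import torch

from sae_lens.saes.topk_sae import TopKSAE, TopKSAEConfig

cfg = TopKSAEConfig(d_in=768, d_sae=16_384, k=64)  # previously activation_fn_kwargs={"k": 64}
sae = TopKSAE(cfg)
feature_acts = sae.encode(torch.randn(4, 768))
print((feature_acts != 0).sum(dim=-1))  # at most k=64 active features per row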
@@ -23,12 +23,12 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from sae_lens import logger
 from sae_lens.config import (
-    DTYPE_MAP,
     CacheActivationsRunnerConfig,
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.saes.sae import SAE
+from sae_lens.constants import DTYPE_MAP
+from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
 
 
@@ -91,7 +91,8 @@ class ActivationsStore:
     def from_config(
         cls,
         model: HookedRootModule,
-        cfg: LanguageModelSAERunnerConfig | CacheActivationsRunnerConfig,
+        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]
+        | CacheActivationsRunnerConfig,
         override_dataset: HfDataset | None = None,
     ) -> ActivationsStore:
         if isinstance(cfg, CacheActivationsRunnerConfig):
@@ -128,13 +129,15 @@ class ActivationsStore:
             hook_layer=cfg.hook_layer,
             hook_head_index=cfg.hook_head_index,
             context_size=cfg.context_size,
-            d_in=cfg.d_in,
+            d_in=cfg.d_in
+            if isinstance(cfg, CacheActivationsRunnerConfig)
+            else cfg.sae.d_in,
             n_batches_in_buffer=cfg.n_batches_in_buffer,
             total_training_tokens=cfg.training_tokens,
             store_batch_size_prompts=cfg.store_batch_size_prompts,
             train_batch_size_tokens=cfg.train_batch_size_tokens,
             prepend_bos=cfg.prepend_bos,
-            normalize_activations=cfg.normalize_activations,
+            normalize_activations=cfg.sae.normalize_activations,
             device=device,
             dtype=cfg.dtype,
             cached_activations_path=cached_activations_path,
@@ -149,9 +152,10 @@ class ActivationsStore:
     def from_sae(
         cls,
         model: HookedRootModule,
-        sae: SAE,
+        sae: SAE[T_SAE_CONFIG],
+        dataset: HfDataset | str,
+        dataset_trust_remote_code: bool = False,
         context_size: int | None = None,
-        dataset: HfDataset | str | None = None,
         streaming: bool = True,
         store_batch_size_prompts: int = 8,
         n_batches_in_buffer: int = 8,
@@ -159,25 +163,37 @@ class ActivationsStore:
         total_tokens: int = 10**9,
         device: str = "cpu",
     ) -> ActivationsStore:
+        if sae.cfg.metadata.hook_name is None:
+            raise ValueError("hook_name is required")
+        if sae.cfg.metadata.hook_layer is None:
+            raise ValueError("hook_layer is required")
+        if sae.cfg.metadata.hook_head_index is None:
+            raise ValueError("hook_head_index is required")
+        if sae.cfg.metadata.context_size is None:
+            raise ValueError("context_size is required")
+        if sae.cfg.metadata.prepend_bos is None:
+            raise ValueError("prepend_bos is required")
         return cls(
             model=model,
-            dataset=sae.cfg.dataset_path if dataset is None else dataset,
+            dataset=dataset,
             d_in=sae.cfg.d_in,
-            hook_name=sae.cfg.hook_name,
-            hook_layer=sae.cfg.hook_layer,
-            hook_head_index=sae.cfg.hook_head_index,
-            context_size=sae.cfg.context_size if context_size is None else context_size,
-            prepend_bos=sae.cfg.prepend_bos,
+            hook_name=sae.cfg.metadata.hook_name,
+            hook_layer=sae.cfg.metadata.hook_layer,
+            hook_head_index=sae.cfg.metadata.hook_head_index,
+            context_size=sae.cfg.metadata.context_size
+            if context_size is None
+            else context_size,
+            prepend_bos=sae.cfg.metadata.prepend_bos,
             streaming=streaming,
             store_batch_size_prompts=store_batch_size_prompts,
             train_batch_size_tokens=train_batch_size_tokens,
            n_batches_in_buffer=n_batches_in_buffer,
             total_training_tokens=total_tokens,
             normalize_activations=sae.cfg.normalize_activations,
-            dataset_trust_remote_code=sae.cfg.dataset_trust_remote_code,
+            dataset_trust_remote_code=dataset_trust_remote_code,
             dtype=sae.cfg.dtype,
             device=torch.device(device),
-            seqpos_slice=sae.cfg.seqpos_slice or (None,),
+            seqpos_slice=sae.cfg.metadata.seqpos_slice or (None,),
         )
 
     def __init__(
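For ActivationsStore.from_sae, the dataset argument is now required (the old fallback to sae.cfg.dataset_path is gone), dataset_trust_remote_code is passed explicitly, and hook details come from sae.cfg.metadata, which must be populated. A sketch of the updated call, where model and sae are assumed to be an already-loaded HookedRootModule and SAE and the dataset path is just an example:

# Sketch of the new call; `model`, `sae`, and the dataset name are placeholders.
store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    dataset="monology/pile-uncopyrighted",  # now required, no dataset_path fallback
    dataset_trust_remote_code=False,        # no longer read from the SAE config
    context_size=None,                      # defaults to sae.cfg.metadata.context_size
    device="cpu",
)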
@@ -101,61 +101,85 @@ def _get_main_lr_scheduler(
     raise ValueError(f"Unsupported scheduler: {scheduler_name}")
 
 
-class L1Scheduler:
+class CoefficientScheduler:
+    """Linearly warms up a scalar value from 0.0 to a final value."""
+
     def __init__(
         self,
-        l1_warm_up_steps: float,
-        total_steps: int,
-        final_l1_coefficient: float,
+        warm_up_steps: float,
+        final_value: float,
     ):
-        self.l1_warmup_steps = l1_warm_up_steps
-        # assume using warm-up
-        if self.l1_warmup_steps != 0:
-            self.current_l1_coefficient = 0.0
-        else:
-            self.current_l1_coefficient = final_l1_coefficient
-
-        self.final_l1_coefficient = final_l1_coefficient
-
+        self.warm_up_steps = warm_up_steps
+        self.final_value = final_value
         self.current_step = 0
-        self.total_steps = total_steps
-        if not isinstance(self.final_l1_coefficient, (float, int)):
+
+        if not isinstance(self.final_value, (float, int)):
             raise TypeError(
-                f"final_l1_coefficient must be float or int, got {type(self.final_l1_coefficient)}."
+                f"final_value must be float or int, got {type(self.final_value)}."
             )
 
+        # Initialize current_value based on whether warm-up is used
+        if self.warm_up_steps > 0:
+            self.current_value = 0.0
+        else:
+            self.current_value = self.final_value
+
     def __repr__(self) -> str:
         return (
-            f"L1Scheduler(final_l1_value={self.final_l1_coefficient}, "
-            f"l1_warmup_steps={self.l1_warmup_steps}, "
-            f"total_steps={self.total_steps})"
+            f"{self.__class__.__name__}(final_value={self.final_value}, "
+            f"warm_up_steps={self.warm_up_steps})"
         )
 
-    def step(self):
+    def step(self) -> float:
         """
-        Updates the l1 coefficient of the sparse autoencoder.
+        Updates the scalar value based on the current step.
+
+        Returns:
+            The current scalar value after the step.
         """
-        step = self.current_step
-        if step < self.l1_warmup_steps:
-            self.current_l1_coefficient = self.final_l1_coefficient * (
-                (1 + step) / self.l1_warmup_steps
-            )  # type: ignore
+        if self.current_step < self.warm_up_steps:
+            self.current_value = self.final_value * (
+                (self.current_step + 1) / self.warm_up_steps
+            )
         else:
-            self.current_l1_coefficient = self.final_l1_coefficient  # type: ignore
+            # Ensure the value stays at final_value after warm-up
+            self.current_value = self.final_value
 
         self.current_step += 1
+        return self.current_value
 
-    def state_dict(self):
-        """State dict for serializing as part of an SAETrainContext."""
+    @property
+    def value(self) -> float:
+        """Returns the current scalar value."""
+        return self.current_value
+
+    def state_dict(self) -> dict[str, Any]:
+        """State dict for serialization."""
         return {
-            "l1_warmup_steps": self.l1_warmup_steps,
-            "total_steps": self.total_steps,
-            "current_l1_coefficient": self.current_l1_coefficient,
-            "final_l1_coefficient": self.final_l1_coefficient,
+            "warm_up_steps": self.warm_up_steps,
+            "final_value": self.final_value,
             "current_step": self.current_step,
+            "current_value": self.current_value,
         }
 
     def load_state_dict(self, state_dict: dict[str, Any]):
-        """Loads all state apart from attached SAE."""
-        for k in state_dict:
-            setattr(self, k, state_dict[k])
+        """Loads the scheduler state."""
+        self.warm_up_steps = state_dict["warm_up_steps"]
+        self.final_value = state_dict["final_value"]
+        self.current_step = state_dict["current_step"]
+        # Maintain consistency: re-calculate current_value based on loaded step
+        # This handles resuming correctly if stopped mid-warmup.
+        if self.current_step <= self.warm_up_steps and self.warm_up_steps > 0:
+            # Use max(0, ...) to handle case where current_step might be loaded as -1 or similar before first step
+            step_for_calc = max(0, self.current_step)
+            # Recalculate based on the step *before* the one about to be taken
+            # Or simply use the saved current_value if available and consistent
+            if "current_value" in state_dict:
+                self.current_value = state_dict["current_value"]
+            else:  # Legacy state dicts might not have current_value
+                self.current_value = self.final_value * (
+                    step_for_calc / self.warm_up_steps
+                )
+
+        else:
+            self.current_value = self.final_value
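CoefficientScheduler generalizes the old L1Scheduler: it drops total_steps, is no longer tied to the l1 coefficient specifically (the per-name TrainCoefficientConfig entries shown earlier presumably feed it), and step() now returns the updated value. A short sketch of the warm-up behaviour, using only what is defined in the hunk above:

# Linear warm-up: the value ramps to final_value over warm_up_steps calls, then holds.
scheduler = CoefficientScheduler(warm_up_steps=4, final_value=2.0)
print([scheduler.step() for _ in range(6)])  # [0.5, 1.0, 1.5, 2.0, 2.0, 2.0]
print(scheduler.value)                       # 2.0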