sae-lens 6.0.0rc5__py3-none-any.whl → 6.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.0.0-rc.5"
+__version__ = "6.2.0"
 
 import logging
 
@@ -7,6 +7,8 @@ logger = logging.getLogger(__name__)
 
 from sae_lens.saes import (
     SAE,
+    BatchTopKTrainingSAE,
+    BatchTopKTrainingSAEConfig,
     GatedSAE,
     GatedSAEConfig,
     GatedTrainingSAE,
@@ -85,6 +87,8 @@ __all__ = [
     "JumpReLUTrainingSAEConfig",
     "SAETrainingRunner",
     "LoggingConfig",
+    "BatchTopKTrainingSAE",
+    "BatchTopKTrainingSAEConfig",
 ]
 
 
@@ -96,3 +100,6 @@ register_sae_class("topk", TopKSAE, TopKSAEConfig)
 register_sae_training_class("topk", TopKTrainingSAE, TopKTrainingSAEConfig)
 register_sae_class("jumprelu", JumpReLUSAE, JumpReLUSAEConfig)
 register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAEConfig)
+register_sae_training_class(
+    "batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
+)
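
For context (not part of the diff): with the "batchtopk" registration above, the new architecture is constructed like any other registered training SAE. A minimal sketch, assuming d_in and d_sae are fields on the base training config and the remaining fields have workable defaults; the sizes are illustrative:

    from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig

    # illustrative sizes; k and topk_threshold_lr are the fields this release adds
    cfg = BatchTopKTrainingSAEConfig(d_in=768, d_sae=16384, k=64, topk_threshold_lr=0.01)
    sae = BatchTopKTrainingSAE(cfg)
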
sae_lens/config.py CHANGED
@@ -1,6 +1,5 @@
 import json
 import math
-import os
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
@@ -353,28 +352,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         d["act_store_device"] = str(self.act_store_device)
         return d
 
-    def to_json(self, path: str) -> None:
-        if not os.path.exists(os.path.dirname(path)):
-            os.makedirs(os.path.dirname(path))
-
-        with open(path + "cfg.json", "w") as f:
-            json.dump(self.to_dict(), f, indent=2)
-
-    @classmethod
-    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig[Any]":
-        with open(path + "cfg.json") as f:
-            cfg = json.load(f)
-
-        # ensure that seqpos slices is a tuple
-        # Ensure seqpos_slice is a tuple
-        if "seqpos_slice" in cfg:
-            if isinstance(cfg["seqpos_slice"], list):
-                cfg["seqpos_slice"] = tuple(cfg["seqpos_slice"])
-            elif not isinstance(cfg["seqpos_slice"], tuple):
-                cfg["seqpos_slice"] = (cfg["seqpos_slice"],)
-
-        return cls(**cfg)
-
     def to_sae_trainer_config(self) -> "SAETrainerConfig":
         return SAETrainerConfig(
             n_checkpoints=self.n_checkpoints,
sae_lens/saes/__init__.py CHANGED
@@ -1,3 +1,7 @@
+from .batchtopk_sae import (
+    BatchTopKTrainingSAE,
+    BatchTopKTrainingSAEConfig,
+)
 from .gated_sae import (
     GatedSAE,
     GatedSAEConfig,
@@ -45,4 +49,6 @@ __all__ = [
     "TopKSAEConfig",
     "TopKTrainingSAE",
     "TopKTrainingSAEConfig",
+    "BatchTopKTrainingSAE",
+    "BatchTopKTrainingSAEConfig",
 ]
sae_lens/saes/batchtopk_sae.py ADDED
@@ -0,0 +1,102 @@
+from dataclasses import dataclass
+from typing import Any, Callable
+
+import torch
+import torch.nn as nn
+from typing_extensions import override
+
+from sae_lens.saes.jumprelu_sae import JumpReLUSAEConfig
+from sae_lens.saes.sae import SAEConfig, TrainStepInput, TrainStepOutput
+from sae_lens.saes.topk_sae import TopKTrainingSAE, TopKTrainingSAEConfig
+
+
+class BatchTopK(nn.Module):
+    """BatchTopK activation function"""
+
+    def __init__(
+        self,
+        k: int,
+    ):
+        super().__init__()
+        self.k = k
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        acts = x.relu()
+        flat_acts = acts.flatten()
+        acts_topk_flat = torch.topk(flat_acts, self.k * acts.shape[0], dim=-1)
+        return (
+            torch.zeros_like(flat_acts)
+            .scatter(-1, acts_topk_flat.indices, acts_topk_flat.values)
+            .reshape(acts.shape)
+        )
+
+
+@dataclass
+class BatchTopKTrainingSAEConfig(TopKTrainingSAEConfig):
+    """
+    Configuration class for training a BatchTopKTrainingSAE.
+    """
+
+    topk_threshold_lr: float = 0.01
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "batchtopk"
+
+    @override
+    def get_inference_config_class(self) -> type[SAEConfig]:
+        return JumpReLUSAEConfig
+
+
+class BatchTopKTrainingSAE(TopKTrainingSAE):
+    """
+    Global Batch TopK Training SAE
+
+    This SAE will maintain the k on average across the batch, rather than enforcing the k per-sample as in standard TopK.
+
+    BatchTopK SAEs are saved as JumpReLU SAEs after training.
+    """
+
+    topk_threshold: torch.Tensor
+    cfg: BatchTopKTrainingSAEConfig  # type: ignore[assignment]
+
+    def __init__(self, cfg: BatchTopKTrainingSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+
+        self.register_buffer(
+            "topk_threshold",
+            # use double precision as otherwise we can run into numerical issues
+            torch.tensor(0.0, dtype=torch.double, device=self.W_dec.device),
+        )
+
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return BatchTopK(self.cfg.k)
+
+    @override
+    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
+        output = super().training_forward_pass(step_input)
+        self.update_topk_threshold(output.feature_acts)
+        output.metrics["topk_threshold"] = self.topk_threshold
+        return output
+
+    @torch.no_grad()
+    def update_topk_threshold(self, acts_topk: torch.Tensor) -> None:
+        positive_mask = acts_topk > 0
+        lr = self.cfg.topk_threshold_lr
+        # autocast can cause numerical issues with the threshold update
+        with torch.autocast(self.topk_threshold.device.type, enabled=False):
+            if positive_mask.any():
+                min_positive = (
+                    acts_topk[positive_mask].min().to(self.topk_threshold.dtype)
+                )
+                self.topk_threshold = (1 - lr) * self.topk_threshold + lr * min_positive
+
+    @override
+    def process_state_dict_for_saving_inference(
+        self, state_dict: dict[str, Any]
+    ) -> None:
+        super().process_state_dict_for_saving_inference(state_dict)
+        # turn the topk threshold into jumprelu threshold
+        topk_threshold = state_dict.pop("topk_threshold").item()
+        state_dict["threshold"] = torch.ones_like(self.b_enc) * topk_threshold
sae_lens/saes/gated_sae.py CHANGED
@@ -15,7 +15,6 @@ from sae_lens.saes.sae import (
     TrainingSAEConfig,
     TrainStepInput,
 )
-from sae_lens.util import filter_valid_dataclass_fields
 
 
 @dataclass
@@ -100,16 +99,10 @@ class GatedSAE(SAE[GatedSAEConfig]):
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
         # Gated-specific parameters need special handling
-        self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
+        # r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path
         self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
         self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
 
-    @torch.no_grad()
-    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
-        """Initialize decoder with constant norm."""
-        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-        self.W_dec.data *= norm
-
 
 @dataclass
 class GatedTrainingSAEConfig(TrainingSAEConfig):
@@ -133,7 +126,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
     - encode: calls encode_with_hidden_pre (standard training approach).
     - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
-    - encode_with_hidden_pre: gating logic + optional noise injection for training.
+    - encode_with_hidden_pre: gating logic.
     - calculate_aux_loss: includes an auxiliary reconstruction path and gating-based sparsity penalty.
     - training_forward_pass: calls encode_with_hidden_pre, decode, and sums up MSE + gating losses.
     """
@@ -158,7 +151,6 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
         """
         Gated forward pass with pre-activation (for training).
-        We also inject noise if self.training is True.
         """
         sae_in = self.process_sae_in(x)
 
@@ -219,12 +211,6 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
             "weights/b_mag": b_mag_dist,
         }
 
-    @torch.no_grad()
-    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
-        """Initialize decoder with constant norm"""
-        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-        self.W_dec.data *= norm
-
     def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
         return {
             "l1": TrainCoefficientConfig(
@@ -233,10 +219,17 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
             ),
         }
 
-    def to_inference_config_dict(self) -> dict[str, Any]:
-        return filter_valid_dataclass_fields(
-            self.cfg.to_dict(), GatedSAEConfig, ["architecture"]
-        )
+    @torch.no_grad()
+    def fold_W_dec_norm(self):
+        """Override to handle gated-specific parameters."""
+        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        self.W_dec.data = self.W_dec.data / W_dec_norms
+        self.W_enc.data = self.W_enc.data * W_dec_norms.T
+
+        # Gated-specific parameters need special handling
+        # r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path
+        self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
+        self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
 
 
 def _init_weights_gated(
sae_lens/saes/jumprelu_sae.py CHANGED
@@ -14,9 +14,7 @@ from sae_lens.saes.sae import (
     TrainingSAE,
     TrainingSAEConfig,
     TrainStepInput,
-    TrainStepOutput,
 )
-from sae_lens.util import filter_valid_dataclass_fields
 
 
 def rectangle(x: torch.Tensor) -> torch.Tensor:
@@ -208,12 +206,11 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
 
     Similar to the inference-only JumpReLUSAE, but with:
     - A learnable log-threshold parameter (instead of a raw threshold).
-    - Forward passes that add noise during training, if configured.
    - A specialized auxiliary loss term for sparsity (L0 or similar).
 
     Methods of interest include:
     - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
-    - encode_with_hidden_pre_jumprelu: runs a forward pass for training, optionally adding noise.
+    - encode_with_hidden_pre_jumprelu: runs a forward pass for training.
     - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.
     """
 
@@ -300,34 +297,6 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         # Fix: Use squeeze() instead of squeeze(-1) to match old behavior
         self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze())
 
-    def _create_train_step_output(
-        self,
-        sae_in: torch.Tensor,
-        sae_out: torch.Tensor,
-        feature_acts: torch.Tensor,
-        hidden_pre: torch.Tensor,
-        loss: torch.Tensor,
-        losses: dict[str, torch.Tensor],
-    ) -> TrainStepOutput:
-        """
-        Helper to produce a TrainStepOutput from the trainer.
-        The old code expects a method named _create_train_step_output().
-        """
-        return TrainStepOutput(
-            sae_in=sae_in,
-            sae_out=sae_out,
-            feature_acts=feature_acts,
-            hidden_pre=hidden_pre,
-            loss=loss,
-            losses=losses,
-        )
-
-    @torch.no_grad()
-    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
-        """Initialize decoder with constant norm"""
-        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-        self.W_dec.data *= norm
-
     def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
         """Convert log_threshold to threshold for saving"""
         if "log_threshold" in state_dict:
@@ -341,8 +310,3 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         threshold = state_dict["threshold"]
         del state_dict["threshold"]
         state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
-
-    def to_inference_config_dict(self) -> dict[str, Any]:
-        return filter_valid_dataclass_fields(
-            self.cfg.to_dict(), JumpReLUSAEConfig, ["architecture"]
-        )
sae_lens/saes/sae.py CHANGED
@@ -27,7 +27,7 @@ from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
 from typing_extensions import deprecated, overload, override
 
-from sae_lens import __version__, logger
+from sae_lens import __version__
 from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
@@ -207,6 +207,8 @@ class TrainStepOutput:
     hidden_pre: torch.Tensor
     loss: torch.Tensor  # we need to call backwards on this
     losses: dict[str, torch.Tensor]
+    # any extra metrics to log can be added here
+    metrics: dict[str, torch.Tensor | float | int] = field(default_factory=dict)
 
 
 @dataclass
@@ -528,28 +530,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
         return model_weights_path, cfg_path
 
-    ## Initialization Methods
-    @torch.no_grad()
-    def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
-        out = torch.tensor(origin, dtype=self.dtype, device=self.device)
-        self.b_dec.data = out
-
-    @torch.no_grad()
-    def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
-        previous_b_dec = self.b_dec.clone().cpu()
-        out = all_activations.mean(dim=0)
-
-        previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
-        distances = torch.norm(all_activations - out, dim=-1)
-
-        logger.info("Reinitializing b_dec with mean of activations")
-        logger.debug(
-            f"Previous distances: {previous_distances.median(0).values.mean().item()}"
-        )
-        logger.debug(f"New distances: {distances.median(0).values.mean().item()}")
-
-        self.b_dec.data = out.to(self.dtype).to(self.device)
-
     # Class methods for loading models
     @classmethod
     @deprecated("Use load_from_disk instead")
@@ -847,20 +827,26 @@ class TrainingSAEConfig(SAEConfig, ABC):
             "architecture": self.architecture(),
         }
 
+    def get_inference_config_class(self) -> type[SAEConfig]:
+        """
+        Get the architecture for inference.
+        """
+        return get_sae_class(self.architecture())[1]
+
     # this needs to exist so we can initialize the parent sae cfg without the training specific
     # parameters. Maybe there's a cleaner way to do this
-    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
+    def get_inference_sae_cfg_dict(self) -> dict[str, Any]:
         """
         Creates a dictionary containing attributes corresponding to the fields
         defined in the base SAEConfig class.
         """
-        base_sae_cfg_class = get_sae_class(self.architecture())[1]
+        base_sae_cfg_class = self.get_inference_config_class()
         base_config_field_names = {f.name for f in fields(base_sae_cfg_class)}
         result_dict = {
             field_name: getattr(self, field_name)
             for field_name in base_config_field_names
         }
-        result_dict["architecture"] = self.architecture()
+        result_dict["architecture"] = base_sae_cfg_class.architecture()
         result_dict["metadata"] = self.metadata.to_dict()
         return result_dict
 
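
For context (not part of the diff): the get_inference_config_class hook lets a training config target a different architecture when saved for inference; BatchTopKTrainingSAEConfig overrides it to return JumpReLUSAEConfig, so the saved cfg.json reports architecture "jumprelu". A hedged sketch of the expected mapping, assuming the config can be constructed with illustrative sizes and sensible defaults elsewhere:

    from sae_lens import BatchTopKTrainingSAEConfig
    from sae_lens.saes.jumprelu_sae import JumpReLUSAEConfig

    cfg = BatchTopKTrainingSAEConfig(d_in=768, d_sae=16384, k=64)  # illustrative sizes
    assert cfg.architecture() == "batchtopk"
    assert cfg.get_inference_config_class() is JumpReLUSAEConfig
    assert cfg.get_inference_sae_cfg_dict()["architecture"] == "jumprelu"
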
@@ -988,18 +974,13 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
         save_file(state_dict, model_weights_path)
 
         # Save the config
-        config = self.to_inference_config_dict()
+        config = self.cfg.get_inference_sae_cfg_dict()
         cfg_path = path / SAE_CFG_FILENAME
         with open(cfg_path, "w") as f:
             json.dump(config, f)
 
         return model_weights_path, cfg_path
 
-    @abstractmethod
-    def to_inference_config_dict(self) -> dict[str, Any]:
-        """Convert the config into an inference SAE config dict."""
-        ...
-
     def process_state_dict_for_saving_inference(
         self, state_dict: dict[str, Any]
     ) -> None:
@@ -1009,23 +990,6 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
         """
         return self.process_state_dict_for_saving(state_dict)
 
-    @torch.no_grad()
-    def remove_gradient_parallel_to_decoder_directions(self) -> None:
-        """Remove gradient components parallel to decoder directions."""
-        # Implement the original logic since this may not be in the base class
-        assert self.W_dec.grad is not None
-
-        parallel_component = einops.einsum(
-            self.W_dec.grad,
-            self.W_dec.data,
-            "d_sae d_in, d_sae d_in -> d_sae",
-        )
-        self.W_dec.grad -= einops.einsum(
-            parallel_component,
-            self.W_dec.data,
-            "d_sae, d_sae d_in -> d_sae d_in",
-        )
-
     @torch.no_grad()
     def log_histograms(self) -> dict[str, NDArray[Any]]:
         """Log histograms of the weights and biases."""
sae_lens/saes/standard_sae.py CHANGED
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import Any
 
 import numpy as np
 import torch
@@ -16,7 +15,6 @@ from sae_lens.saes.sae import (
     TrainingSAEConfig,
     TrainStepInput,
 )
-from sae_lens.util import filter_valid_dataclass_fields
 
 
 @dataclass
@@ -61,7 +59,6 @@ class StandardSAE(SAE[StandardSAEConfig]):
     ) -> Float[torch.Tensor, "... d_sae"]:
         """
         Encode the input tensor into the feature space.
-        For inference, no noise is added.
         """
         # Preprocess the SAE input (casting type, applying hooks, normalization)
         sae_in = self.process_sae_in(x)
@@ -110,7 +107,7 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
     - initialize_weights: basic weight initialization for encoder/decoder.
     - encode: inference encoding (invokes encode_with_hidden_pre).
     - decode: a simple linear decoder.
-    - encode_with_hidden_pre: computes pre-activations, adds noise when training, and then activates.
+    - encode_with_hidden_pre: computes activations and pre-activations.
     - calculate_aux_loss: computes a sparsity penalty based on the (optionally scaled) p-norm of feature activations.
     """
 
@@ -164,11 +161,6 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
             "weights/b_e": b_e_dist,
         }
 
-    def to_inference_config_dict(self) -> dict[str, Any]:
-        return filter_valid_dataclass_fields(
-            self.cfg.to_dict(), StandardSAEConfig, ["architecture"]
-        )
-
 
 def _init_weights_standard(
     sae: SAE[StandardSAEConfig] | TrainingSAE[StandardTrainingSAEConfig],
sae_lens/saes/topk_sae.py CHANGED
@@ -1,7 +1,7 @@
 """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
 
 from dataclasses import dataclass
-from typing import Any, Callable
+from typing import Callable
 
 import torch
 from jaxtyping import Float
@@ -16,13 +16,12 @@ from sae_lens.saes.sae import (
     TrainingSAEConfig,
     TrainStepInput,
 )
-from sae_lens.util import filter_valid_dataclass_fields
 
 
 class TopK(nn.Module):
     """
     A simple TopK activation that zeroes out all but the top K elements along the last dimension,
-    then optionally applies a post-activation function (e.g., ReLU).
+    and applies ReLU to the top K elements.
     """
 
     b_enc: nn.Parameter
@@ -30,20 +29,18 @@ class TopK(nn.Module):
     def __init__(
         self,
         k: int,
-        postact_fn: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
     ):
         super().__init__()
         self.k = k
-        self.postact_fn = postact_fn
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         1) Select top K elements along the last dimension.
-        2) Apply post-activation (often ReLU).
+        2) Apply ReLU.
         3) Zero out all other entries.
         """
         topk = torch.topk(x, k=self.k, dim=-1)
-        values = self.postact_fn(topk.values)
+        values = topk.values.relu()
         result = torch.zeros_like(x)
         result.scatter_(-1, topk.indices, values)
         return result
@@ -130,6 +127,7 @@ class TopKTrainingSAEConfig(TrainingSAEConfig):
     """
 
     k: int = 100
+    aux_loss_coefficient: float = 1.0
 
     @override
     @classmethod
@@ -139,8 +137,7 @@ class TopKTrainingSAEConfig(TrainingSAEConfig):
 
 class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     """
-    TopK variant with training functionality. Injects noise during training, optionally
-    calculates a topk-related auxiliary loss, etc.
+    TopK variant with training functionality. Calculates a topk-related auxiliary loss, etc.
     """
 
     b_enc: nn.Parameter
@@ -157,7 +154,7 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         self, x: Float[torch.Tensor, "... d_in"]
     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
         """
-        Similar to the base training method: cast input, optionally add noise, then apply TopK.
+        Similar to the base training method: calculate pre-activations, then apply TopK.
         """
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
@@ -235,45 +232,7 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         # top k living latents
         recons = self.decode(auxk_acts)
         auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
-        return scale * auxk_loss
-
-    def _calculate_topk_aux_acts(
-        self,
-        k_aux: int,
-        hidden_pre: torch.Tensor,
-        dead_neuron_mask: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Helper method to calculate activations for the auxiliary loss.
-
-        Args:
-            k_aux: Number of top dead neurons to select
-            hidden_pre: Pre-activation values from encoder
-            dead_neuron_mask: Boolean mask indicating which neurons are dead
-
-        Returns:
-            Tensor with activations for only the top-k dead neurons, zeros elsewhere
-        """
-        # Don't include living latents in this loss (set them to -inf so they won't be selected)
-        auxk_latents = torch.where(
-            dead_neuron_mask[None],
-            hidden_pre,
-            torch.tensor(-float("inf"), device=hidden_pre.device),
-        )
-
-        # Find topk values among dead neurons
-        auxk_topk = auxk_latents.topk(k_aux, dim=-1, sorted=False)
-
-        # Create a tensor of zeros, then place the topk values at their proper indices
-        auxk_acts = torch.zeros_like(hidden_pre)
-        auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
-
-        return auxk_acts
-
-    def to_inference_config_dict(self) -> dict[str, Any]:
-        return filter_valid_dataclass_fields(
-            self.cfg.to_dict(), TopKSAEConfig, ["architecture"]
-        )
+        return self.cfg.aux_loss_coefficient * scale * auxk_loss
 
 
 def _calculate_topk_aux_acts(
@@ -281,6 +240,18 @@ def _calculate_topk_aux_acts(
     hidden_pre: torch.Tensor,
     dead_neuron_mask: torch.Tensor,
 ) -> torch.Tensor:
+    """
+    Helper method to calculate activations for the auxiliary loss.
+
+    Args:
+        k_aux: Number of top dead neurons to select
+        hidden_pre: Pre-activation values from encoder
+        dead_neuron_mask: Boolean mask indicating which neurons are dead
+
+    Returns:
+        Tensor with activations for only the top-k dead neurons, zeros elsewhere
+    """
+
     # Don't include living latents in this loss
     auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
     # Top-k dead latents
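
For context (not part of the diff): a minimal numeric sketch of what the module-level _calculate_topk_aux_acts computes, with made-up values. Living latents are masked to -inf so only dead latents compete for the top-k_aux slots:

    import torch

    hidden_pre = torch.tensor([[0.9, 0.4, 0.8, 0.7]])            # 1 sample, 4 latents
    dead_neuron_mask = torch.tensor([False, True, False, True])  # latents 1 and 3 are dead
    k_aux = 1
    auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
    auxk_topk = auxk_latents.topk(k_aux, dim=-1, sorted=False)
    auxk_acts = torch.zeros_like(hidden_pre).scatter_(-1, auxk_topk.indices, auxk_topk.values)
    # auxk_acts: [[0.0, 0.0, 0.0, 0.7]] -- only the strongest dead latent survives
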
sae_lens/training/activations_store.py CHANGED
@@ -7,7 +7,6 @@ from collections.abc import Generator, Iterator, Sequence
 from typing import Any, Literal, cast
 
 import datasets
-import numpy as np
 import torch
 from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
 from huggingface_hub import hf_hub_download
@@ -420,20 +419,6 @@ class ActivationsStore:
 
         return activations_dataset
 
-    @torch.no_grad()
-    def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)):
-        norms_per_batch = []
-        for _ in tqdm(
-            range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
-        ):
-            # temporalily set estimated_norm_scaling_factor to 1.0 so the dataloader works
-            self.estimated_norm_scaling_factor = 1.0
-            acts = self.next_batch()[:, 0]
-            self.estimated_norm_scaling_factor = None
-            norms_per_batch.append(acts.norm(dim=-1).mean().item())
-        mean_norm = np.mean(norms_per_batch)
-        return np.sqrt(self.d_in) / mean_norm
-
     def shuffle_input_dataset(self, seed: int, buffer_size: int = 1):
         """
         This applies a shuffle to the huggingface dataset that is the input to the activations store. This
sae_lens/training/optim.py CHANGED
@@ -2,8 +2,6 @@
 Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
 """
 
-from typing import Any
-
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
 
@@ -152,34 +150,3 @@ class CoefficientScheduler:
     def value(self) -> float:
         """Returns the current scalar value."""
         return self.current_value
-
-    def state_dict(self) -> dict[str, Any]:
-        """State dict for serialization."""
-        return {
-            "warm_up_steps": self.warm_up_steps,
-            "final_value": self.final_value,
-            "current_step": self.current_step,
-            "current_value": self.current_value,
-        }
-
-    def load_state_dict(self, state_dict: dict[str, Any]):
-        """Loads the scheduler state."""
-        self.warm_up_steps = state_dict["warm_up_steps"]
-        self.final_value = state_dict["final_value"]
-        self.current_step = state_dict["current_step"]
-        # Maintain consistency: re-calculate current_value based on loaded step
-        # This handles resuming correctly if stopped mid-warmup.
-        if self.current_step <= self.warm_up_steps and self.warm_up_steps > 0:
-            # Use max(0, ...) to handle case where current_step might be loaded as -1 or similar before first step
-            step_for_calc = max(0, self.current_step)
-            # Recalculate based on the step *before* the one about to be taken
-            # Or simply use the saved current_value if available and consistent
-            if "current_value" in state_dict:
-                self.current_value = state_dict["current_value"]
-            else:  # Legacy state dicts might not have current_value
-                self.current_value = self.final_value * (
-                    step_for_calc / self.warm_up_steps
-                )
-
-        else:
-            self.current_value = self.final_value
sae_lens/training/sae_trainer.py CHANGED
@@ -349,8 +349,10 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
             },
         }
         for loss_name, loss_value in output.losses.items():
-            loss_item = _unwrap_item(loss_value)
-            log_dict[f"losses/{loss_name}"] = loss_item
+            log_dict[f"losses/{loss_name}"] = _unwrap_item(loss_value)
+
+        for metric_name, metric_value in output.metrics.items():
+            log_dict[f"metrics/{metric_name}"] = _unwrap_item(metric_value)
 
         return log_dict
 
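
For context (not part of the diff): together with the new TrainStepOutput.metrics field in sae.py, this gives SAE subclasses a generic way to surface extra logging values; BatchTopKTrainingSAE uses it for its topk_threshold. A minimal sketch of the convention, with placeholder tensors:

    import torch
    from sae_lens.saes.sae import TrainStepOutput

    zeros = torch.zeros(1, 4)  # placeholder shapes for illustration
    output = TrainStepOutput(
        sae_in=zeros, sae_out=zeros, feature_acts=zeros, hidden_pre=zeros,
        loss=torch.tensor(0.0), losses={"mse_loss": torch.tensor(0.0)},
    )
    output.metrics["topk_threshold"] = 0.05  # trainer logs this as "metrics/topk_threshold"
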
{sae_lens-6.0.0rc5.dist-info → sae_lens-6.2.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: sae-lens
-Version: 6.0.0rc5
+Version: 6.2.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -68,6 +68,10 @@ This library is maintained by [Joseph Bloom](https://www.jbloomaus.com/), [Curt
 
 Pre-trained SAEs for various models can be imported via SAE Lens. See this [page](https://jbloomaus.github.io/SAELens/sae_table/) in the readme for a list of all SAEs.
 
+## Migrating to SAELens v6
+
+The new v6 update is a major refactor to SAELens and changes the way training code is structured. Check out the [migration guide](https://jbloomaus.github.io/SAELens/latest/migrating/) for more details.
+
 ## Tutorials
 
 - [SAE Lens + Neuronpedia](tutorials/tutorial_2_0.ipynb)[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/jbloomAus/SAELens/blob/main/tutorials/tutorial_2_0.ipynb)
{sae_lens-6.0.0rc5.dist-info → sae_lens-6.2.0.dist-info}/RECORD RENAMED
@@ -1,9 +1,9 @@
-sae_lens/__init__.py,sha256=hiHDLT9_1V7iVulw5hwqDqDj2HVxUR9I88xOfYx6X94,2861
+sae_lens/__init__.py,sha256=ByxdNdLeg_pvK89IX1lHa6iHgs2ab-UulX55Y0hUhY4,3073
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
 sae_lens/analysis/neuronpedia_integration.py,sha256=MrENqc81Mc2SMbxGjbwHzpkGUCAFKSf0i4EdaUF2Oj4,18707
 sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
-sae_lens/config.py,sha256=9Lg4HkQvj1t9QZJdmC071lyJMc_iqNQknosT7zOYfwM,27278
+sae_lens/config.py,sha256=qMMx9KuiXTD5lG3g0VzaekWOnvdAzGFSq8j1n-GObEQ,26467
 sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
 sae_lens/evals.py,sha256=kQyrzczKaVD9rHwfFa_DxL_gMXDxsoIVHmsFIPIU2bY,38696
 sae_lens/llm_sae_training_runner.py,sha256=58XbDylw2fPOD7C-ZfSAjeNqJLXB05uHGTuiYVVbXXY,13354
@@ -14,24 +14,25 @@ sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gk
 sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
 sae_lens/pretrained_saes.yaml,sha256=nhHW1auhyi4GHYrjUnHQqbNVhI5cMJv-HThzbzU1xG0,574145
 sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
-sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
-sae_lens/saes/gated_sae.py,sha256=0zd66bH04nsaGk3bxHk10hsZofa2GrFbMo15LOsuqgU,9233
-sae_lens/saes/jumprelu_sae.py,sha256=iwmPQJ4XpIxzgosty680u8Zj7x1uVZhM75kPOT3obi0,12060
-sae_lens/saes/sae.py,sha256=ZEXEXFVtrtFrzuOV3nyweTBleNCV4EDGh1ImaF32uqg,39618
-sae_lens/saes/standard_sae.py,sha256=PfkGLsw_6La3PXHOQL0u7qQsaZsXCJqYCeCcRDj5n64,6274
-sae_lens/saes/topk_sae.py,sha256=kmry1FE1H06OvCfn84V-j2JfWGKcU5b2urwAq_Oq5j4,9893
+sae_lens/saes/__init__.py,sha256=RYqE1qkMws-kwQLmBZFhA_VCa69zVtBjGPIy_UAk2pw,1159
+sae_lens/saes/batchtopk_sae.py,sha256=CyaFG2hMyyDaEaXXrAMJC8wQDW1JoddTKF5mvxxBQKY,3395
+sae_lens/saes/gated_sae.py,sha256=qcmM9JwBA8aZR8z_IRHV1_gQX-q_63tKewWXRnhdXuo,8986
+sae_lens/saes/jumprelu_sae.py,sha256=3xkhBcCol2mEpIBLceymCpudocm2ypOjTeTXbpiXoA4,10794
+sae_lens/saes/sae.py,sha256=McpF4pTh70r6SQUbHFm0YQ9X2c2qPULBUSd_YmnEk4Y,38284
+sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
+sae_lens/saes/topk_sae.py,sha256=CXMBI6CFvI5829bOhoQ350VXR9d8uFHUDlULTIWHXoU,8686
 sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
-sae_lens/training/activations_store.py,sha256=z8erbiB6ODbsqlu-bwEWbyj4XZvgsVgjCRBuQovqp2Q,32612
+sae_lens/training/activations_store.py,sha256=HBN3oEib3PlPUDJb_yVFabQp0JcN9rWbnUN1s2DBMAs,31933
 sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
-sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
-sae_lens/training/sae_trainer.py,sha256=9K0VudwSTJp9OlCVzaU_ngZ0WlYNrN6-ozTCCAxR9_k,15421
+sae_lens/training/optim.py,sha256=TiI9nbffzXNsI8WjcIsqa2uheW6suxqL_KDDmWXobWI,5312
+sae_lens/training/sae_trainer.py,sha256=2xcO-02OozFunob5vwoHud-hVMhVl9d28_F9gDCiL6o,15529
 sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
 sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
-sae_lens-6.0.0rc5.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
-sae_lens-6.0.0rc5.dist-info/METADATA,sha256=ZrBaBFeIuM-ZJ9r0HHKakxnx3tGv7Zf6l_Z2OIdBxIU,5326
-sae_lens-6.0.0rc5.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-sae_lens-6.0.0rc5.dist-info/RECORD,,
+sae_lens-6.2.0.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.2.0.dist-info/METADATA,sha256=Fqsq0scF5Uia0YBmeZQwVi4m4DX16_Ck-cKokbuch7U,5555
+sae_lens-6.2.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+sae_lens-6.2.0.dist-info/RECORD,,