sae-lens 6.0.0rc4__py3-none-any.whl → 6.1.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.0.0-rc.4"
+__version__ = "6.1.0"
 
 import logging
 
@@ -7,6 +7,8 @@ logger = logging.getLogger(__name__)
 
 from sae_lens.saes import (
     SAE,
+    BatchTopKTrainingSAE,
+    BatchTopKTrainingSAEConfig,
     GatedSAE,
     GatedSAEConfig,
     GatedTrainingSAE,
@@ -85,6 +87,8 @@ __all__ = [
     "JumpReLUTrainingSAEConfig",
     "SAETrainingRunner",
     "LoggingConfig",
+    "BatchTopKTrainingSAE",
+    "BatchTopKTrainingSAEConfig",
 ]
 
 
@@ -96,3 +100,6 @@ register_sae_class("topk", TopKSAE, TopKSAEConfig)
 register_sae_training_class("topk", TopKTrainingSAE, TopKTrainingSAEConfig)
 register_sae_class("jumprelu", JumpReLUSAE, JumpReLUSAEConfig)
 register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAEConfig)
+register_sae_training_class(
+    "batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
+)
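
Since `batchtopk` is now registered for training (inference reuses the JumpReLU classes, as the new module further below shows), the architecture can be constructed like any other registered SAE. A minimal sketch, assuming the `d_in`/`d_sae`/`k` constructor fields inherited from the TopK training config:

    import torch
    from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig

    # field names/values assumed for illustration
    cfg = BatchTopKTrainingSAEConfig(d_in=512, d_sae=4096, k=32)
    sae = BatchTopKTrainingSAE(cfg)
    feature_acts = sae.encode(torch.randn(8, 512))  # k enforced on average over the batch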
sae_lens/config.py CHANGED
@@ -1,6 +1,5 @@
 import json
 import math
-import os
 from dataclasses import asdict, dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast
@@ -353,28 +352,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         d["act_store_device"] = str(self.act_store_device)
         return d
 
-    def to_json(self, path: str) -> None:
-        if not os.path.exists(os.path.dirname(path)):
-            os.makedirs(os.path.dirname(path))
-
-        with open(path + "cfg.json", "w") as f:
-            json.dump(self.to_dict(), f, indent=2)
-
-    @classmethod
-    def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig[Any]":
-        with open(path + "cfg.json") as f:
-            cfg = json.load(f)
-
-        # ensure that seqpos slices is a tuple
-        # Ensure seqpos_slice is a tuple
-        if "seqpos_slice" in cfg:
-            if isinstance(cfg["seqpos_slice"], list):
-                cfg["seqpos_slice"] = tuple(cfg["seqpos_slice"])
-            elif not isinstance(cfg["seqpos_slice"], tuple):
-                cfg["seqpos_slice"] = (cfg["seqpos_slice"],)
-
-        return cls(**cfg)
-
     def to_sae_trainer_config(self) -> "SAETrainerConfig":
         return SAETrainerConfig(
             n_checkpoints=self.n_checkpoints,
sae_lens/constants.py CHANGED
@@ -16,5 +16,6 @@ SPARSITY_FILENAME = "sparsity.safetensors"
 SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
 SAE_CFG_FILENAME = "cfg.json"
 RUNNER_CFG_FILENAME = "runner_cfg.json"
+SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
 ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
 ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
sae_lens/evals.py CHANGED
@@ -769,17 +769,6 @@ def nested_dict() -> defaultdict[Any, Any]:
     return defaultdict(nested_dict)
 
 
-def dict_to_nested(flat_dict: dict[str, Any]) -> defaultdict[Any, Any]:
-    nested = nested_dict()
-    for key, value in flat_dict.items():
-        parts = key.split("/")
-        d = nested
-        for part in parts[:-1]:
-            d = d[part]
-        d[parts[-1]] = value
-    return nested
-
-
 def multiple_evals(
     sae_regex_pattern: str,
     sae_block_pattern: str,
sae_lens/loading/pretrained_sae_loaders.py CHANGED
@@ -16,6 +16,7 @@ from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
+    SPARSIFY_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
 )
 from sae_lens.loading.pretrained_saes_directory import (
@@ -248,7 +249,7 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
     config_class = get_sae_class(architecture)[1]
 
     sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
-    if architecture == "topk":
+    if architecture == "topk" and "activation_fn_kwargs" in new_cfg:
         sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
 
     sae_cfg_dict["metadata"] = {
@@ -530,11 +531,20 @@ def get_llama_scope_config_from_hf(
     # Model specific parameters
     model_name, d_in = "meta-llama/Llama-3.1-8B", old_cfg_dict["d_model"]
 
+    # Get norm scaling factor to rescale jumprelu threshold.
+    # We need this because sae.fold_activation_norm_scaling_factor folds scaling norm into W_enc.
+    # This requires jumprelu threshold to be scaled in the same way
+    norm_scaling_factor = (
+        d_in**0.5 / old_cfg_dict["dataset_average_activation_norm"]["in"]
+    )
+
     cfg_dict = {
         "architecture": "jumprelu",
-        "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"],
+        "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"]
+        * norm_scaling_factor,
         # We use a scalar jump_relu_threshold for all features
         # This is different from Gemma Scope JumpReLU SAEs.
+        # Scaled with norm_scaling_factor to match sae.fold_activation_norm_scaling_factor
         "d_in": d_in,
         "d_sae": old_cfg_dict["d_sae"],
         "dtype": "bfloat16",
@@ -942,6 +952,146 @@ def llama_scope_r1_distill_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity
 
 
+def get_sparsify_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    cfg_filename = f"{folder_name}/{SAE_CFG_FILENAME}"
+    cfg_path = hf_hub_download(
+        repo_id,
+        filename=cfg_filename,
+        force_download=force_download,
+    )
+    sae_path = Path(cfg_path).parent
+    return get_sparsify_config_from_disk(
+        sae_path, device=device, cfg_overrides=cfg_overrides
+    )
+
+
+def get_sparsify_config_from_disk(
+    path: str | Path,
+    device: str | None = None,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    path = Path(path)
+
+    with open(path / SAE_CFG_FILENAME) as f:
+        old_cfg_dict = json.load(f)
+
+    config_path = path.parent / "config.json"
+    if config_path.exists():
+        with open(config_path) as f:
+            config_dict = json.load(f)
+    else:
+        config_dict = {}
+
+    folder_name = path.name
+    if folder_name == "embed_tokens":
+        hook_name, layer = "hook_embed", 0
+    else:
+        match = re.search(r"layers[._](\d+)", folder_name)
+        if match is None:
+            raise ValueError(f"Unrecognized Sparsify folder: {folder_name}")
+        layer = int(match.group(1))
+        hook_name = f"blocks.{layer}.hook_resid_post"
+
+    cfg_dict: dict[str, Any] = {
+        "architecture": "standard",
+        "d_in": old_cfg_dict["d_in"],
+        "d_sae": old_cfg_dict["d_in"] * old_cfg_dict["expansion_factor"],
+        "dtype": "bfloat16",
+        "device": device or "cpu",
+        "model_name": config_dict.get("model", path.parts[-2]),
+        "hook_name": hook_name,
+        "hook_layer": layer,
+        "hook_head_index": None,
+        "activation_fn_str": "topk",
+        "activation_fn_kwargs": {
+            "k": old_cfg_dict["k"],
+            "signed": old_cfg_dict.get("signed", False),
+        },
+        "apply_b_dec_to_input": not old_cfg_dict.get("normalize_decoder", False),
+        "dataset_path": config_dict.get(
+            "dataset", "togethercomputer/RedPajama-Data-1T-Sample"
+        ),
+        "context_size": config_dict.get("ctx_len", 2048),
+        "finetuning_scaling_factor": False,
+        "sae_lens_training_version": None,
+        "prepend_bos": True,
+        "dataset_trust_remote_code": True,
+        "normalize_activations": "none",
+        "neuronpedia_id": None,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def sparsify_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], None]:
+    weights_filename = f"{folder_name}/{SPARSIFY_WEIGHTS_FILENAME}"
+    sae_path = hf_hub_download(
+        repo_id,
+        filename=weights_filename,
+        force_download=force_download,
+    )
+    cfg_dict, state_dict = sparsify_disk_loader(
+        Path(sae_path).parent, device=device, cfg_overrides=cfg_overrides
+    )
+    return cfg_dict, state_dict, None
+
+
+def sparsify_disk_loader(
+    path: str | Path,
+    device: str = "cpu",
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
+    cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)
+
+    weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
+    state_dict_loaded = load_file(weight_path, device=device)
+
+    dtype = DTYPE_MAP[cfg_dict["dtype"]]
+
+    W_enc = (
+        state_dict_loaded["W_enc"]
+        if "W_enc" in state_dict_loaded
+        else state_dict_loaded["encoder.weight"].T
+    ).to(dtype)
+
+    if "W_dec" in state_dict_loaded:
+        W_dec = state_dict_loaded["W_dec"].T.to(dtype)
+    else:
+        W_dec = state_dict_loaded["decoder.weight"].T.to(dtype)
+
+    if "b_enc" in state_dict_loaded:
+        b_enc = state_dict_loaded["b_enc"].to(dtype)
+    elif "encoder.bias" in state_dict_loaded:
+        b_enc = state_dict_loaded["encoder.bias"].to(dtype)
+    else:
+        b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+
+    if "b_dec" in state_dict_loaded:
+        b_dec = state_dict_loaded["b_dec"].to(dtype)
+    elif "decoder.bias" in state_dict_loaded:
+        b_dec = state_dict_loaded["decoder.bias"].to(dtype)
+    else:
+        b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+
+    state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
+    return cfg_dict, state_dict
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -950,6 +1100,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
     "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
     "deepseek_r1": deepseek_r1_sae_huggingface_loader,
+    "sparsify": sparsify_huggingface_loader,
 }
 
 
@@ -961,4 +1112,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
     "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
     "deepseek_r1": get_deepseek_r1_config_from_hf,
+    "sparsify": get_sparsify_config_from_hf,
 }
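
To make the Llama Scope threshold rescaling above concrete: `fold_activation_norm_scaling_factor` multiplies inputs by `sqrt(d_in) / E[||x||]`, so a JumpReLU threshold learned on unscaled activations must be multiplied by the same factor. A worked sketch with assumed illustrative numbers:

    # values assumed for illustration only
    d_in = 4096
    avg_norm_in = 128.0  # stands in for old_cfg_dict["dataset_average_activation_norm"]["in"]
    norm_scaling_factor = d_in**0.5 / avg_norm_in  # 64.0 / 128.0 = 0.5
    scaled_threshold = 0.2 * norm_scaling_factor   # a raw threshold of 0.2 becomes 0.1

The new Sparsify loaders can also be driven directly from disk; a minimal sketch, assuming a local checkout laid out as `<model_dir>/layers.<n>/` containing `sae.safetensors` and `cfg.json`:

    from sae_lens.loading.pretrained_sae_loaders import sparsify_disk_loader

    # path assumed for illustration
    cfg_dict, state_dict = sparsify_disk_loader("checkpoints/pythia-70m/layers.3", device="cpu")
    print(cfg_dict["hook_name"])  # -> "blocks.3.hook_resid_post"
    print(sorted(state_dict))     # -> ['W_dec', 'W_enc', 'b_dec', 'b_enc']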
sae_lens/pretrained_saes.yaml CHANGED
@@ -13634,39 +13634,51 @@ gemma-2-2b-res-matryoshka-dc:
   - id: blocks.13.hook_resid_post
     path: standard/blocks.13.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/13-res-matryoshka-dc
   - id: blocks.14.hook_resid_post
     path: standard/blocks.14.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/14-res-matryoshka-dc
   - id: blocks.15.hook_resid_post
     path: standard/blocks.15.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/15-res-matryoshka-dc
   - id: blocks.16.hook_resid_post
     path: standard/blocks.16.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/16-res-matryoshka-dc
   - id: blocks.17.hook_resid_post
     path: standard/blocks.17.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/17-res-matryoshka-dc
   - id: blocks.18.hook_resid_post
     path: standard/blocks.18.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/18-res-matryoshka-dc
   - id: blocks.19.hook_resid_post
     path: standard/blocks.19.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/19-res-matryoshka-dc
   - id: blocks.20.hook_resid_post
     path: standard/blocks.20.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/20-res-matryoshka-dc
   - id: blocks.21.hook_resid_post
     path: standard/blocks.21.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/21-res-matryoshka-dc
   - id: blocks.22.hook_resid_post
     path: standard/blocks.22.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/22-res-matryoshka-dc
   - id: blocks.23.hook_resid_post
     path: standard/blocks.23.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/23-res-matryoshka-dc
   - id: blocks.24.hook_resid_post
     path: standard/blocks.24.hook_resid_post
     l0: 40.0
+    neuronpedia: gemma-2-2b/24-res-matryoshka-dc
 gemma-2-2b-res-snap-matryoshka-dc:
   conversion_func: null
   links:
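
The new `neuronpedia` fields link each of these SAEs to its Neuronpedia dashboard. A sketch of reading the ID back after loading (the metadata attribute path is assumed):

    from sae_lens import SAE

    sae = SAE.from_pretrained("gemma-2-2b-res-matryoshka-dc", "blocks.13.hook_resid_post")
    print(sae.cfg.metadata.neuronpedia_id)  # -> "gemma-2-2b/13-res-matryoshka-dc"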
sae_lens/saes/__init__.py CHANGED
@@ -1,3 +1,7 @@
+from .batchtopk_sae import (
+    BatchTopKTrainingSAE,
+    BatchTopKTrainingSAEConfig,
+)
 from .gated_sae import (
     GatedSAE,
     GatedSAEConfig,
@@ -45,4 +49,6 @@ __all__ = [
     "TopKSAEConfig",
     "TopKTrainingSAE",
     "TopKTrainingSAEConfig",
+    "BatchTopKTrainingSAE",
+    "BatchTopKTrainingSAEConfig",
 ]
sae_lens/saes/batchtopk_sae.py ADDED
@@ -0,0 +1,102 @@
+from dataclasses import dataclass
+from typing import Any, Callable
+
+import torch
+import torch.nn as nn
+from typing_extensions import override
+
+from sae_lens.saes.jumprelu_sae import JumpReLUSAEConfig
+from sae_lens.saes.sae import SAEConfig, TrainStepInput, TrainStepOutput
+from sae_lens.saes.topk_sae import TopKTrainingSAE, TopKTrainingSAEConfig
+
+
+class BatchTopK(nn.Module):
+    """BatchTopK activation function"""
+
+    def __init__(
+        self,
+        k: int,
+    ):
+        super().__init__()
+        self.k = k
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        acts = x.relu()
+        flat_acts = acts.flatten()
+        acts_topk_flat = torch.topk(flat_acts, self.k * acts.shape[0], dim=-1)
+        return (
+            torch.zeros_like(flat_acts)
+            .scatter(-1, acts_topk_flat.indices, acts_topk_flat.values)
+            .reshape(acts.shape)
+        )
+
+
+@dataclass
+class BatchTopKTrainingSAEConfig(TopKTrainingSAEConfig):
+    """
+    Configuration class for training a BatchTopKTrainingSAE.
+    """
+
+    topk_threshold_lr: float = 0.01
+
+    @override
+    @classmethod
+    def architecture(cls) -> str:
+        return "batchtopk"
+
+    @override
+    def get_inference_config_class(self) -> type[SAEConfig]:
+        return JumpReLUSAEConfig
+
+
+class BatchTopKTrainingSAE(TopKTrainingSAE):
+    """
+    Global Batch TopK Training SAE
+
+    This SAE will maintain the k on average across the batch, rather than enforcing the k per-sample as in standard TopK.
+
+    BatchTopK SAEs are saved as JumpReLU SAEs after training.
+    """
+
+    topk_threshold: torch.Tensor
+    cfg: BatchTopKTrainingSAEConfig  # type: ignore[assignment]
+
+    def __init__(self, cfg: BatchTopKTrainingSAEConfig, use_error_term: bool = False):
+        super().__init__(cfg, use_error_term)
+
+        self.register_buffer(
+            "topk_threshold",
+            # use double precision as otherwise we can run into numerical issues
+            torch.tensor(0.0, dtype=torch.double, device=self.W_dec.device),
+        )
+
+    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
+        return BatchTopK(self.cfg.k)
+
+    @override
+    def training_forward_pass(self, step_input: TrainStepInput) -> TrainStepOutput:
+        output = super().training_forward_pass(step_input)
+        self.update_topk_threshold(output.feature_acts)
+        output.metrics["topk_threshold"] = self.topk_threshold
+        return output
+
+    @torch.no_grad()
+    def update_topk_threshold(self, acts_topk: torch.Tensor) -> None:
+        positive_mask = acts_topk > 0
+        lr = self.cfg.topk_threshold_lr
+        # autocast can cause numerical issues with the threshold update
+        with torch.autocast(self.topk_threshold.device.type, enabled=False):
+            if positive_mask.any():
+                min_positive = (
+                    acts_topk[positive_mask].min().to(self.topk_threshold.dtype)
+                )
+                self.topk_threshold = (1 - lr) * self.topk_threshold + lr * min_positive
+
+    @override
+    def process_state_dict_for_saving_inference(
+        self, state_dict: dict[str, Any]
+    ) -> None:
+        super().process_state_dict_for_saving_inference(state_dict)
+        # turn the topk threshold into jumprelu threshold
+        topk_threshold = state_dict.pop("topk_threshold").item()
+        state_dict["threshold"] = torch.ones_like(self.b_enc) * topk_threshold
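
A quick numeric illustration of the flattened selection above: with batch size B, `BatchTopK` keeps the B*k largest post-ReLU activations across the whole batch, so a strong sample may use more than k latents while a weak one uses fewer. The running `topk_threshold` is an exponential moving average of the smallest surviving activation, and it becomes the JumpReLU `threshold` when the trained SAE is saved for inference.

    import torch
    from sae_lens.saes.batchtopk_sae import BatchTopK

    acts = torch.tensor([[3.0, 2.0, 0.1],
                         [0.2, 0.1, 0.05]])
    out = BatchTopK(k=1)(acts)  # keeps the 2 * 1 = 2 largest values overall
    # tensor([[3., 2., 0.],
    #         [0., 0., 0.]])  -- row 0 keeps two latents, row 1 keeps none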
sae_lens/saes/gated_sae.py CHANGED
@@ -15,7 +15,6 @@ from sae_lens.saes.sae import (
     TrainingSAEConfig,
     TrainStepInput,
 )
-from sae_lens.util import filter_valid_dataclass_fields
 
 
 @dataclass
@@ -100,16 +99,10 @@ class GatedSAE(SAE[GatedSAEConfig]):
         self.W_enc.data = self.W_enc.data * W_dec_norms.T
 
         # Gated-specific parameters need special handling
-        self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze()
+        # r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path
         self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
         self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
 
-    @torch.no_grad()
-    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
-        """Initialize decoder with constant norm."""
-        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-        self.W_dec.data *= norm
-
 
 @dataclass
 class GatedTrainingSAEConfig(TrainingSAEConfig):
@@ -133,7 +126,7 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     - initialize_weights: sets up gating parameters (as in GatedSAE) plus optional training-specific init.
     - encode: calls encode_with_hidden_pre (standard training approach).
     - decode: linear transformation + hooking, same as GatedSAE or StandardTrainingSAE.
-    - encode_with_hidden_pre: gating logic + optional noise injection for training.
+    - encode_with_hidden_pre: gating logic.
     - calculate_aux_loss: includes an auxiliary reconstruction path and gating-based sparsity penalty.
     - training_forward_pass: calls encode_with_hidden_pre, decode, and sums up MSE + gating losses.
     """
@@ -158,7 +151,6 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
         """
         Gated forward pass with pre-activation (for training).
-        We also inject noise if self.training is True.
         """
         sae_in = self.process_sae_in(x)
@@ -219,12 +211,6 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
             "weights/b_mag": b_mag_dist,
         }
 
-    @torch.no_grad()
-    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
-        """Initialize decoder with constant norm"""
-        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-        self.W_dec.data *= norm
-
     def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
         return {
             "l1": TrainCoefficientConfig(
@@ -233,10 +219,17 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
             ),
         }
 
-    def to_inference_config_dict(self) -> dict[str, Any]:
-        return filter_valid_dataclass_fields(
-            self.cfg.to_dict(), GatedSAEConfig, ["architecture"]
-        )
+    @torch.no_grad()
+    def fold_W_dec_norm(self):
+        """Override to handle gated-specific parameters."""
+        W_dec_norms = self.W_dec.norm(dim=-1).unsqueeze(1)
+        self.W_dec.data = self.W_dec.data / W_dec_norms
+        self.W_enc.data = self.W_enc.data * W_dec_norms.T
+
+        # Gated-specific parameters need special handling
+        # r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path
+        self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
+        self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
 
 
 def _init_weights_gated(
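
The `fold_W_dec_norm` logic above is behaviour-preserving: scaling the encoder columns and the gate/magnitude biases by the (positive) decoder row norms cancels against normalizing the decoder rows, and the gate's sign is unchanged. A quick invariance check, assuming the config's constructor fields:

    import torch
    from sae_lens.saes.gated_sae import GatedSAE, GatedSAEConfig

    sae = GatedSAE(GatedSAEConfig(d_in=16, d_sae=64))  # assumed constructor fields
    x = torch.randn(4, 16)
    before = sae.decode(sae.encode(x))
    sae.fold_W_dec_norm()
    after = sae.decode(sae.encode(x))
    torch.testing.assert_close(before, after, rtol=1e-4, atol=1e-4)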
sae_lens/saes/jumprelu_sae.py CHANGED
@@ -14,9 +14,7 @@ from sae_lens.saes.sae import (
     TrainingSAE,
     TrainingSAEConfig,
     TrainStepInput,
-    TrainStepOutput,
 )
-from sae_lens.util import filter_valid_dataclass_fields
 
 
 def rectangle(x: torch.Tensor) -> torch.Tensor:
@@ -208,12 +206,11 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
 
     Similar to the inference-only JumpReLUSAE, but with:
     - A learnable log-threshold parameter (instead of a raw threshold).
-    - Forward passes that add noise during training, if configured.
     - A specialized auxiliary loss term for sparsity (L0 or similar).
 
     Methods of interest include:
     - initialize_weights: sets up W_enc, b_enc, W_dec, b_dec, and log_threshold.
-    - encode_with_hidden_pre_jumprelu: runs a forward pass for training, optionally adding noise.
+    - encode_with_hidden_pre_jumprelu: runs a forward pass for training.
     - training_forward_pass: calculates MSE and auxiliary losses, returning a TrainStepOutput.
     """
@@ -300,34 +297,6 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         # Fix: Use squeeze() instead of squeeze(-1) to match old behavior
         self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze())
 
-    def _create_train_step_output(
-        self,
-        sae_in: torch.Tensor,
-        sae_out: torch.Tensor,
-        feature_acts: torch.Tensor,
-        hidden_pre: torch.Tensor,
-        loss: torch.Tensor,
-        losses: dict[str, torch.Tensor],
-    ) -> TrainStepOutput:
-        """
-        Helper to produce a TrainStepOutput from the trainer.
-        The old code expects a method named _create_train_step_output().
-        """
-        return TrainStepOutput(
-            sae_in=sae_in,
-            sae_out=sae_out,
-            feature_acts=feature_acts,
-            hidden_pre=hidden_pre,
-            loss=loss,
-            losses=losses,
-        )
-
-    @torch.no_grad()
-    def initialize_decoder_norm_constant_norm(self, norm: float = 0.1):
-        """Initialize decoder with constant norm"""
-        self.W_dec.data /= torch.norm(self.W_dec.data, dim=1, keepdim=True)
-        self.W_dec.data *= norm
-
     def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
         """Convert log_threshold to threshold for saving"""
         if "log_threshold" in state_dict:
@@ -341,8 +310,3 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         threshold = state_dict["threshold"]
         del state_dict["threshold"]
         state_dict["log_threshold"] = torch.log(threshold).detach().contiguous()
-
-    def to_inference_config_dict(self) -> dict[str, Any]:
-        return filter_valid_dataclass_fields(
-            self.cfg.to_dict(), JumpReLUSAEConfig, ["architecture"]
-        )
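
The save/load hooks above store `threshold` on disk while training operates on `log_threshold`; the two are exact inverses via exp/log:

    import torch

    log_threshold = torch.tensor([-2.0, 0.0, 1.5])
    threshold = log_threshold.exp()               # written at save time
    restored = torch.log(threshold).contiguous()  # recovered at load time
    torch.testing.assert_close(restored, log_threshold)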
sae_lens/saes/sae.py CHANGED
@@ -27,7 +27,7 @@ from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
 from typing_extensions import deprecated, overload, override
 
-from sae_lens import __version__, logger
+from sae_lens import __version__
 from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
@@ -207,6 +207,8 @@ class TrainStepOutput:
     hidden_pre: torch.Tensor
     loss: torch.Tensor  # we need to call backwards on this
     losses: dict[str, torch.Tensor]
+    # any extra metrics to log can be added here
+    metrics: dict[str, torch.Tensor | float | int] = field(default_factory=dict)
 
 
 @dataclass
@@ -528,28 +530,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
         return model_weights_path, cfg_path
 
-    ## Initialization Methods
-    @torch.no_grad()
-    def initialize_b_dec_with_precalculated(self, origin: torch.Tensor):
-        out = torch.tensor(origin, dtype=self.dtype, device=self.device)
-        self.b_dec.data = out
-
-    @torch.no_grad()
-    def initialize_b_dec_with_mean(self, all_activations: torch.Tensor):
-        previous_b_dec = self.b_dec.clone().cpu()
-        out = all_activations.mean(dim=0)
-
-        previous_distances = torch.norm(all_activations - previous_b_dec, dim=-1)
-        distances = torch.norm(all_activations - out, dim=-1)
-
-        logger.info("Reinitializing b_dec with mean of activations")
-        logger.debug(
-            f"Previous distances: {previous_distances.median(0).values.mean().item()}"
-        )
-        logger.debug(f"New distances: {distances.median(0).values.mean().item()}")
-
-        self.b_dec.data = out.to(self.dtype).to(self.device)
-
     # Class methods for loading models
     @classmethod
     @deprecated("Use load_from_disk instead")
@@ -732,6 +712,64 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
     ) -> type[SAEConfig]:
         return SAEConfig
 
+    ### Methods to support deprecated usage of SAE.from_pretrained() ###
+
+    def __getitem__(self, index: int) -> Any:
+        """
+        Support indexing for backward compatibility with tuple unpacking.
+        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
+        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
+        """
+        warnings.warn(
+            "Indexing SAE objects is deprecated. SAE.from_pretrained() now returns "
+            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
+            "to get the config dict and sparsity as well.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+        if index == 0:
+            return self
+        if index == 1:
+            return self.cfg.to_dict()
+        if index == 2:
+            return None
+        raise IndexError(f"SAE tuple index {index} out of range")
+
+    def __iter__(self):
+        """
+        Support unpacking for backward compatibility with tuple unpacking.
+        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
+        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
+        """
+        warnings.warn(
+            "Unpacking SAE objects is deprecated. SAE.from_pretrained() now returns "
+            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
+            "to get the config dict and sparsity as well.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+        yield self
+        yield self.cfg.to_dict()
+        yield None
+
+    def __len__(self) -> int:
+        """
+        Support len() for backward compatibility with tuple unpacking.
+        DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
+        Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
+        """
+        warnings.warn(
+            "Getting length of SAE objects is deprecated. SAE.from_pretrained() now returns "
+            "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
+            "to get the config dict and sparsity as well.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+        return 3
+
 
 @dataclass(kw_only=True)
@@ -789,20 +827,26 @@ class TrainingSAEConfig(SAEConfig, ABC):
             "architecture": self.architecture(),
         }
 
+    def get_inference_config_class(self) -> type[SAEConfig]:
+        """
+        Get the architecture for inference.
+        """
+        return get_sae_class(self.architecture())[1]
+
     # this needs to exist so we can initialize the parent sae cfg without the training specific
     # parameters. Maybe there's a cleaner way to do this
-    def get_base_sae_cfg_dict(self) -> dict[str, Any]:
+    def get_inference_sae_cfg_dict(self) -> dict[str, Any]:
         """
         Creates a dictionary containing attributes corresponding to the fields
         defined in the base SAEConfig class.
         """
-        base_sae_cfg_class = get_sae_class(self.architecture())[1]
+        base_sae_cfg_class = self.get_inference_config_class()
         base_config_field_names = {f.name for f in fields(base_sae_cfg_class)}
         result_dict = {
             field_name: getattr(self, field_name)
            for field_name in base_config_field_names
         }
-        result_dict["architecture"] = self.architecture()
+        result_dict["architecture"] = base_sae_cfg_class.architecture()
         result_dict["metadata"] = self.metadata.to_dict()
         return result_dict
 
@@ -930,18 +974,13 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
         save_file(state_dict, model_weights_path)
 
         # Save the config
-        config = self.to_inference_config_dict()
+        config = self.cfg.get_inference_sae_cfg_dict()
         cfg_path = path / SAE_CFG_FILENAME
         with open(cfg_path, "w") as f:
             json.dump(config, f)
 
         return model_weights_path, cfg_path
 
-    @abstractmethod
-    def to_inference_config_dict(self) -> dict[str, Any]:
-        """Convert the config into an inference SAE config dict."""
-        ...
-
     def process_state_dict_for_saving_inference(
         self, state_dict: dict[str, Any]
     ) -> None:
@@ -951,23 +990,6 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
         """
         return self.process_state_dict_for_saving(state_dict)
 
-    @torch.no_grad()
-    def remove_gradient_parallel_to_decoder_directions(self) -> None:
-        """Remove gradient components parallel to decoder directions."""
-        # Implement the original logic since this may not be in the base class
-        assert self.W_dec.grad is not None
-
-        parallel_component = einops.einsum(
-            self.W_dec.grad,
-            self.W_dec.data,
-            "d_sae d_in, d_sae d_in -> d_sae",
-        )
-        self.W_dec.grad -= einops.einsum(
-            parallel_component,
-            self.W_dec.data,
-            "d_sae, d_sae d_in -> d_sae d_in",
-        )
-
     @torch.no_grad()
     def log_histograms(self) -> dict[str, NDArray[Any]]:
         """Log histograms of the weights and biases."""
sae_lens/saes/standard_sae.py CHANGED
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import Any
 
 import numpy as np
 import torch
@@ -16,7 +15,6 @@ from sae_lens.saes.sae import (
     TrainingSAEConfig,
     TrainStepInput,
 )
-from sae_lens.util import filter_valid_dataclass_fields
 
 
 @dataclass
@@ -61,7 +59,6 @@ class StandardSAE(SAE[StandardSAEConfig]):
     ) -> Float[torch.Tensor, "... d_sae"]:
         """
         Encode the input tensor into the feature space.
-        For inference, no noise is added.
         """
         # Preprocess the SAE input (casting type, applying hooks, normalization)
         sae_in = self.process_sae_in(x)
@@ -110,7 +107,7 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
     - initialize_weights: basic weight initialization for encoder/decoder.
     - encode: inference encoding (invokes encode_with_hidden_pre).
     - decode: a simple linear decoder.
-    - encode_with_hidden_pre: computes pre-activations, adds noise when training, and then activates.
+    - encode_with_hidden_pre: computes activations and pre-activations.
     - calculate_aux_loss: computes a sparsity penalty based on the (optionally scaled) p-norm of feature activations.
     """
 
@@ -164,11 +161,6 @@
         "weights/b_e": b_e_dist,
     }
 
-    def to_inference_config_dict(self) -> dict[str, Any]:
-        return filter_valid_dataclass_fields(
-            self.cfg.to_dict(), StandardSAEConfig, ["architecture"]
-        )
-
 
 def _init_weights_standard(
     sae: SAE[StandardSAEConfig] | TrainingSAE[StandardTrainingSAEConfig],
sae_lens/saes/topk_sae.py CHANGED
@@ -1,7 +1,7 @@
 """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
 
 from dataclasses import dataclass
-from typing import Any, Callable
+from typing import Callable
 
 import torch
 from jaxtyping import Float
@@ -16,13 +16,12 @@ from sae_lens.saes.sae import (
     TrainingSAEConfig,
     TrainStepInput,
 )
-from sae_lens.util import filter_valid_dataclass_fields
 
 
 class TopK(nn.Module):
     """
     A simple TopK activation that zeroes out all but the top K elements along the last dimension,
-    then optionally applies a post-activation function (e.g., ReLU).
+    and applies ReLU to the top K elements.
     """
 
     b_enc: nn.Parameter
@@ -30,20 +29,18 @@ class TopK(nn.Module):
     def __init__(
         self,
         k: int,
-        postact_fn: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
     ):
         super().__init__()
         self.k = k
-        self.postact_fn = postact_fn
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
         1) Select top K elements along the last dimension.
-        2) Apply post-activation (often ReLU).
+        2) Apply ReLU.
         3) Zero out all other entries.
         """
         topk = torch.topk(x, k=self.k, dim=-1)
-        values = self.postact_fn(topk.values)
+        values = topk.values.relu()
         result = torch.zeros_like(x)
         result.scatter_(-1, topk.indices, values)
         return result
@@ -139,8 +136,7 @@ class TopKTrainingSAEConfig(TrainingSAEConfig):
 
 class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
     """
-    TopK variant with training functionality. Injects noise during training, optionally
-    calculates a topk-related auxiliary loss, etc.
+    TopK variant with training functionality. Calculates a topk-related auxiliary loss, etc.
     """
 
     b_enc: nn.Parameter
@@ -157,7 +153,7 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         self, x: Float[torch.Tensor, "... d_in"]
     ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
         """
-        Similar to the base training method: cast input, optionally add noise, then apply TopK.
+        Similar to the base training method: calculate pre-activations, then apply TopK.
         """
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
@@ -237,50 +233,24 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
         return scale * auxk_loss
 
-    def _calculate_topk_aux_acts(
-        self,
-        k_aux: int,
-        hidden_pre: torch.Tensor,
-        dead_neuron_mask: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Helper method to calculate activations for the auxiliary loss.
-
-        Args:
-            k_aux: Number of top dead neurons to select
-            hidden_pre: Pre-activation values from encoder
-            dead_neuron_mask: Boolean mask indicating which neurons are dead
-
-        Returns:
-            Tensor with activations for only the top-k dead neurons, zeros elsewhere
-        """
-        # Don't include living latents in this loss (set them to -inf so they won't be selected)
-        auxk_latents = torch.where(
-            dead_neuron_mask[None],
-            hidden_pre,
-            torch.tensor(-float("inf"), device=hidden_pre.device),
-        )
-
-        # Find topk values among dead neurons
-        auxk_topk = auxk_latents.topk(k_aux, dim=-1, sorted=False)
-
-        # Create a tensor of zeros, then place the topk values at their proper indices
-        auxk_acts = torch.zeros_like(hidden_pre)
-        auxk_acts.scatter_(-1, auxk_topk.indices, auxk_topk.values)
-
-        return auxk_acts
-
-    def to_inference_config_dict(self) -> dict[str, Any]:
-        return filter_valid_dataclass_fields(
-            self.cfg.to_dict(), TopKSAEConfig, ["architecture"]
-        )
-
 
 def _calculate_topk_aux_acts(
     k_aux: int,
     hidden_pre: torch.Tensor,
     dead_neuron_mask: torch.Tensor,
 ) -> torch.Tensor:
+    """
+    Helper method to calculate activations for the auxiliary loss.
+
+    Args:
+        k_aux: Number of top dead neurons to select
+        hidden_pre: Pre-activation values from encoder
+        dead_neuron_mask: Boolean mask indicating which neurons are dead
+
+    Returns:
+        Tensor with activations for only the top-k dead neurons, zeros elsewhere
+    """
+
     # Don't include living latents in this loss
     auxk_latents = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
     # Top-k dead latents
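
A small numeric check of the module-level helper above: living latents are masked to -inf, so only dead latents can contribute to the auxiliary reconstruction (values assumed for illustration):

    import torch
    from sae_lens.saes.topk_sae import _calculate_topk_aux_acts

    hidden_pre = torch.tensor([[0.9, 0.5, 0.3, 0.8]])
    dead = torch.tensor([False, True, True, False])
    _calculate_topk_aux_acts(k_aux=1, hidden_pre=hidden_pre, dead_neuron_mask=dead)
    # tensor([[0.0, 0.5, 0.0, 0.0]])  -- 0.9 and 0.8 are alive, so 0.5 is selected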
sae_lens/training/activations_store.py CHANGED
@@ -7,7 +7,6 @@ from collections.abc import Generator, Iterator, Sequence
 from typing import Any, Literal, cast
 
 import datasets
-import numpy as np
 import torch
 from datasets import Dataset, DatasetDict, IterableDataset, load_dataset
 from huggingface_hub import hf_hub_download
@@ -420,20 +419,6 @@ class ActivationsStore:
 
         return activations_dataset
 
-    @torch.no_grad()
-    def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)):
-        norms_per_batch = []
-        for _ in tqdm(
-            range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
-        ):
-            # temporalily set estimated_norm_scaling_factor to 1.0 so the dataloader works
-            self.estimated_norm_scaling_factor = 1.0
-            acts = self.next_batch()[0]
-            self.estimated_norm_scaling_factor = None
-            norms_per_batch.append(acts.norm(dim=-1).mean().item())
-        mean_norm = np.mean(norms_per_batch)
-        return np.sqrt(self.d_in) / mean_norm
-
     def shuffle_input_dataset(self, seed: int, buffer_size: int = 1):
         """
         This applies a shuffle to the huggingface dataset that is the input to the activations store. This
sae_lens/training/optim.py CHANGED
@@ -2,8 +2,6 @@
 Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
 """
 
-from typing import Any
-
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
 
@@ -152,34 +150,3 @@ class CoefficientScheduler:
     def value(self) -> float:
         """Returns the current scalar value."""
         return self.current_value
-
-    def state_dict(self) -> dict[str, Any]:
-        """State dict for serialization."""
-        return {
-            "warm_up_steps": self.warm_up_steps,
-            "final_value": self.final_value,
-            "current_step": self.current_step,
-            "current_value": self.current_value,
-        }
-
-    def load_state_dict(self, state_dict: dict[str, Any]):
-        """Loads the scheduler state."""
-        self.warm_up_steps = state_dict["warm_up_steps"]
-        self.final_value = state_dict["final_value"]
-        self.current_step = state_dict["current_step"]
-        # Maintain consistency: re-calculate current_value based on loaded step
-        # This handles resuming correctly if stopped mid-warmup.
-        if self.current_step <= self.warm_up_steps and self.warm_up_steps > 0:
-            # Use max(0, ...) to handle case where current_step might be loaded as -1 or similar before first step
-            step_for_calc = max(0, self.current_step)
-            # Recalculate based on the step *before* the one about to be taken
-            # Or simply use the saved current_value if available and consistent
-            if "current_value" in state_dict:
-                self.current_value = state_dict["current_value"]
-            else:  # Legacy state dicts might not have current_value
-                self.current_value = self.final_value * (
-                    step_for_calc / self.warm_up_steps
-                )
-
-        else:
-            self.current_value = self.final_value
sae_lens/training/sae_trainer.py CHANGED
@@ -349,8 +349,10 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
             },
         }
         for loss_name, loss_value in output.losses.items():
-            loss_item = _unwrap_item(loss_value)
-            log_dict[f"losses/{loss_name}"] = loss_item
+            log_dict[f"losses/{loss_name}"] = _unwrap_item(loss_value)
+
+        for metric_name, metric_value in output.metrics.items():
+            log_dict[f"metrics/{metric_name}"] = _unwrap_item(metric_value)
 
         return log_dict
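
Anything a `TrainingSAE` records in the new `TrainStepOutput.metrics` field (see the sae.py diff above) is logged under a `metrics/` prefix; the BatchTopK SAE's threshold, for example, appears as `metrics/topk_threshold`. A minimal sketch of the data flow:

    import torch
    from sae_lens.saes.sae import TrainStepOutput

    output = TrainStepOutput(
        sae_in=torch.zeros(1, 4),
        sae_out=torch.zeros(1, 4),
        feature_acts=torch.zeros(1, 8),
        hidden_pre=torch.zeros(1, 8),
        loss=torch.tensor(0.0),
        losses={"mse_loss": torch.tensor(0.0)},  # loss key assumed for illustration
    )
    output.metrics["topk_threshold"] = 0.05  # logged by the trainer as "metrics/topk_threshold"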
{sae_lens-6.0.0rc4.dist-info → sae_lens-6.1.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: sae-lens
-Version: 6.0.0rc4
+Version: 6.1.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -80,7 +80,7 @@ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page
 
 ## Join the Slack!
 
-Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-2o756ku1c-_yKBeUQMVfS_p_qcK6QLeA) for support!
+Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
 
 ## Citation
 
sae_lens-6.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,38 @@
+sae_lens/__init__.py,sha256=vM8ncfMn8YHyl1CHj48L2pG6FWJ54--3blxrY3WtJww,3073
+sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
+sae_lens/analysis/neuronpedia_integration.py,sha256=MrENqc81Mc2SMbxGjbwHzpkGUCAFKSf0i4EdaUF2Oj4,18707
+sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
+sae_lens/config.py,sha256=qMMx9KuiXTD5lG3g0VzaekWOnvdAzGFSq8j1n-GObEQ,26467
+sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
+sae_lens/evals.py,sha256=kQyrzczKaVD9rHwfFa_DxL_gMXDxsoIVHmsFIPIU2bY,38696
+sae_lens/llm_sae_training_runner.py,sha256=58XbDylw2fPOD7C-ZfSAjeNqJLXB05uHGTuiYVVbXXY,13354
+sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
+sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sae_lens/loading/pretrained_sae_loaders.py,sha256=5XEU4uFFeGCePwqDwhlE7CrFGRSI0U9Cu-UQVa33Y1E,36432
+sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
+sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
+sae_lens/pretrained_saes.yaml,sha256=nhHW1auhyi4GHYrjUnHQqbNVhI5cMJv-HThzbzU1xG0,574145
+sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
+sae_lens/saes/__init__.py,sha256=RYqE1qkMws-kwQLmBZFhA_VCa69zVtBjGPIy_UAk2pw,1159
+sae_lens/saes/batchtopk_sae.py,sha256=CyaFG2hMyyDaEaXXrAMJC8wQDW1JoddTKF5mvxxBQKY,3395
+sae_lens/saes/gated_sae.py,sha256=qcmM9JwBA8aZR8z_IRHV1_gQX-q_63tKewWXRnhdXuo,8986
+sae_lens/saes/jumprelu_sae.py,sha256=3xkhBcCol2mEpIBLceymCpudocm2ypOjTeTXbpiXoA4,10794
+sae_lens/saes/sae.py,sha256=McpF4pTh70r6SQUbHFm0YQ9X2c2qPULBUSd_YmnEk4Y,38284
+sae_lens/saes/standard_sae.py,sha256=9UqYyYtQuThYxXKNaDjYcyowpOx2-7cShG-TeUP6JCQ,5940
+sae_lens/saes/topk_sae.py,sha256=CH8LGtSQOrbA_xOdqZUkDCG7TOS81CeQJeyLEpPricU,8616
+sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
+sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
+sae_lens/training/activations_store.py,sha256=HBN3oEib3PlPUDJb_yVFabQp0JcN9rWbnUN1s2DBMAs,31933
+sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
+sae_lens/training/optim.py,sha256=TiI9nbffzXNsI8WjcIsqa2uheW6suxqL_KDDmWXobWI,5312
+sae_lens/training/sae_trainer.py,sha256=2xcO-02OozFunob5vwoHud-hVMhVl9d28_F9gDCiL6o,15529
+sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
+sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
+sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
+sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
+sae_lens-6.1.0.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.1.0.dist-info/METADATA,sha256=AjB2PWa1s8CCluq-_jjeBj7OsCSswoRP5GEGGSoNjHo,5323
+sae_lens-6.1.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+sae_lens-6.1.0.dist-info/RECORD,,
sae_lens-6.0.0rc4.dist-info/RECORD DELETED
@@ -1,37 +0,0 @@
-sae_lens/__init__.py,sha256=dGZU3Y6iwiuW5oQVTfNvUmfnHO3bHWWbpU-nvXvw9M8,2861
-sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
-sae_lens/analysis/neuronpedia_integration.py,sha256=MrENqc81Mc2SMbxGjbwHzpkGUCAFKSf0i4EdaUF2Oj4,18707
-sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
-sae_lens/config.py,sha256=9Lg4HkQvj1t9QZJdmC071lyJMc_iqNQknosT7zOYfwM,27278
-sae_lens/constants.py,sha256=RJlzWx7wLNMNmrdI63naF7-M3enb55vYRN4x1hXx6vI,593
-sae_lens/evals.py,sha256=PIMGQobE9o2bHksDAtQe5bnTMYyHoZKB_elFhDOjrmo,38991
-sae_lens/llm_sae_training_runner.py,sha256=58XbDylw2fPOD7C-ZfSAjeNqJLXB05uHGTuiYVVbXXY,13354
-sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
-sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/loading/pretrained_sae_loaders.py,sha256=kbirwfCg4Ks9Cg3rt78bYxIHMhz5h015n0UTRJQLJY0,31291
-sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
-sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
-sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
-sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
-sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
-sae_lens/saes/gated_sae.py,sha256=0zd66bH04nsaGk3bxHk10hsZofa2GrFbMo15LOsuqgU,9233
-sae_lens/saes/jumprelu_sae.py,sha256=iwmPQJ4XpIxzgosty680u8Zj7x1uVZhM75kPOT3obi0,12060
-sae_lens/saes/sae.py,sha256=HAGkJAj_FIDzbSR1dsG8b2AyMq8UauUU_yx-LvdfjuE,37465
-sae_lens/saes/standard_sae.py,sha256=PfkGLsw_6La3PXHOQL0u7qQsaZsXCJqYCeCcRDj5n64,6274
-sae_lens/saes/topk_sae.py,sha256=kmry1FE1H06OvCfn84V-j2JfWGKcU5b2urwAq_Oq5j4,9893
-sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
-sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
-sae_lens/training/activations_store.py,sha256=s3Qvztv2siuuXSuXEUDZYSKq1QQCsqsGXY767kv6grc,32609
-sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
-sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
-sae_lens/training/sae_trainer.py,sha256=9K0VudwSTJp9OlCVzaU_ngZ0WlYNrN6-ozTCCAxR9_k,15421
-sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
-sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
-sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
-sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
-sae_lens-6.0.0rc4.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
-sae_lens-6.0.0rc4.dist-info/METADATA,sha256=wOQMSV4yNlpgpGxuE4DI0-q4KzTRYOg1m9ZxpdCsNjk,5326
-sae_lens-6.0.0rc4.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-sae_lens-6.0.0rc4.dist-info/RECORD,,