sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +6 -3
- sae_lens/analysis/neuronpedia_integration.py +3 -3
- sae_lens/cache_activations_runner.py +7 -6
- sae_lens/config.py +50 -6
- sae_lens/constants.py +2 -0
- sae_lens/evals.py +39 -28
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +24 -12
- sae_lens/saes/gated_sae.py +0 -4
- sae_lens/saes/jumprelu_sae.py +4 -10
- sae_lens/saes/sae.py +121 -51
- sae_lens/saes/standard_sae.py +4 -11
- sae_lens/saes/topk_sae.py +18 -12
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +77 -174
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/sae_trainer.py +107 -98
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +1 -1
- sae_lens/util.py +19 -0
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc4.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc4.dist-info/RECORD +37 -0
- sae_lens/sae_training_runner.py +0 -237
- sae_lens/training/geometric_median.py +0 -101
- sae_lens-6.0.0rc2.dist-info/RECORD +0 -35
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc4.dist-info}/LICENSE +0 -0
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc4.dist-info}/WHEEL +0 -0
sae_lens/saes/sae.py
CHANGED
@@ -1,5 +1,6 @@
 """Base classes for Sparse Autoencoders (SAEs)."""
 
+import copy
 import json
 import warnings
 from abc import ABC, abstractmethod
@@ -59,24 +60,91 @@ T_SAE = TypeVar("T_SAE", bound="SAE")  # type: ignore
 T_TRAINING_SAE = TypeVar("T_TRAINING_SAE", bound="TrainingSAE")  # type: ignore
 
 
-@dataclass
 class SAEMetadata:
     """Core metadata about how this SAE should be used, if known."""
 
-    [... 14 removed lines (old 66-79) not captured in this view ...]
+    def __init__(self, **kwargs: Any):
+        # Set default version fields with their current behavior
+        self.sae_lens_version = kwargs.pop("sae_lens_version", __version__)
+        self.sae_lens_training_version = kwargs.pop(
+            "sae_lens_training_version", __version__
+        )
+
+        # Set all other attributes dynamically
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+    def __getattr__(self, name: str) -> None:
+        """Return None for any missing attribute (like defaultdict)"""
+        return
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        """Allow setting any attribute"""
+        super().__setattr__(name, value)
+
+    def __getitem__(self, key: str) -> Any:
+        """Allow dictionary-style access: metadata['key'] - returns None for missing keys"""
+        return getattr(self, key)
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        """Allow dictionary-style assignment: metadata['key'] = value"""
+        setattr(self, key, value)
+
+    def __contains__(self, key: str) -> bool:
+        """Allow 'in' operator: 'key' in metadata"""
+        # Only return True if the attribute was explicitly set (not just defaulting to None)
+        return key in self.__dict__
+
+    def get(self, key: str, default: Any = None) -> Any:
+        """Dictionary-style get with default"""
+        value = getattr(self, key)
+        # If the attribute wasn't explicitly set and we got None from __getattr__,
+        # use the provided default instead
+        if key not in self.__dict__ and value is None:
+            return default
+        return value
+
+    def keys(self):
+        """Return all explicitly set attribute names"""
+        return self.__dict__.keys()
+
+    def values(self):
+        """Return all explicitly set attribute values"""
+        return self.__dict__.values()
+
+    def items(self):
+        """Return all explicitly set attribute name-value pairs"""
+        return self.__dict__.items()
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary for serialization"""
+        return self.__dict__.copy()
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "SAEMetadata":
+        """Create from dictionary"""
+        return cls(**data)
+
+    def __repr__(self) -> str:
+        return f"SAEMetadata({self.__dict__})"
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, SAEMetadata):
+            return False
+        return self.__dict__ == other.__dict__
+
+    def __deepcopy__(self, memo: dict[int, Any]) -> "SAEMetadata":
+        """Support for deep copying"""
+        return SAEMetadata(**copy.deepcopy(self.__dict__, memo))
+
+    def __getstate__(self) -> dict[str, Any]:
+        """Support for pickling"""
+        return self.__dict__
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        """Support for unpickling"""
+        self.__dict__.update(state)
 
 
 @dataclass
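The rewritten SAEMetadata behaves like a permissive attribute/dict hybrid: any field can be set, unknown fields read as None rather than raising, and only explicitly set keys round-trip through to_dict()/from_dict(). A small sketch of that behavior based on the methods shown above (the field names and values below are arbitrary examples, not a fixed schema):

```python
from sae_lens.saes.sae import SAEMetadata

meta = SAEMetadata(model_name="gpt2", hook_name="blocks.0.hook_resid_pre")

meta.model_name             # "gpt2" (attribute access)
meta["hook_name"]           # "blocks.0.hook_resid_pre" (dict-style access)
meta.context_size           # None: missing attributes default to None instead of raising
"context_size" in meta      # False: only explicitly set keys count as present

meta["neuronpedia_id"] = "gpt2-small/0-res-jb"   # dict-style assignment (illustrative value)

# sae_lens_version / sae_lens_training_version are auto-populated in __init__,
# so a round trip through to_dict()/from_dict() preserves equality
assert SAEMetadata.from_dict(meta.to_dict()) == meta
```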
@@ -100,7 +168,7 @@ class SAEConfig(ABC):
 
     def to_dict(self) -> dict[str, Any]:
         res = {field.name: getattr(self, field.name) for field in fields(self)}
-        res["metadata"] =
+        res["metadata"] = self.metadata.to_dict()
         res["architecture"] = self.architecture()
         return res
 
@@ -125,7 +193,7 @@ class SAEConfig(ABC):
             "layer_norm",
         ]:
             raise ValueError(
-                f"normalize_activations must be none, expected_average_only_in,
+                f"normalize_activations must be none, expected_average_only_in, layer_norm, or constant_norm_rescale. Got {self.normalize_activations}"
             )
 
 
@@ -239,9 +307,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
             self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
             self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
-
         elif self.cfg.normalize_activations == "layer_norm":
-
+            # we need to scale the norm of the input and store the scaling factor
             def run_time_activation_ln_in(
                 x: torch.Tensor, eps: float = 1e-5
             ) -> torch.Tensor:
@@ -523,7 +590,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
         device: str = "cpu",
         force_download: bool = False,
         converter: PretrainedSaeHuggingfaceLoader | None = None,
-    ) ->
+    ) -> T_SAE:
         """
         Load a pretrained SAE from the Hugging Face model hub.
 
@@ -531,7 +598,28 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
             release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
             id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
             device: The device to load the SAE on.
-
+        """
+        return cls.from_pretrained_with_cfg_and_sparsity(
+            release, sae_id, device, force_download, converter=converter
+        )[0]
+
+    @classmethod
+    def from_pretrained_with_cfg_and_sparsity(
+        cls: Type[T_SAE],
+        release: str,
+        sae_id: str,
+        device: str = "cpu",
+        force_download: bool = False,
+        converter: PretrainedSaeHuggingfaceLoader | None = None,
+    ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
+        """
+        Load a pretrained SAE from the Hugging Face model hub, along with its config dict and sparsity, if present.
+        In SAELens <= 5.x.x, this was called SAE.from_pretrained().
+
+        Args:
+            release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
+            id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
+            device: The device to load the SAE on.
         """
 
         # get sae directory
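As the docstrings above note, SAE.from_pretrained() now returns only the SAE, while the tuple-returning behaviour from SAELens <= 5.x lives on as from_pretrained_with_cfg_and_sparsity(). A migration sketch; the release and SAE id strings are illustrative placeholders:

```python
from sae_lens import SAE

# 6.x: from_pretrained returns just the SAE instance
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",        # illustrative release name
    sae_id="blocks.8.hook_resid_pre",   # illustrative SAE id within that release
    device="cpu",
)

# Equivalent of the <= 5.x from_pretrained: also returns the config dict and
# the feature sparsity tensor (None if the release does not ship one)
sae, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
    device="cpu",
)
```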
@@ -647,9 +735,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 
 @dataclass(kw_only=True)
 class TrainingSAEConfig(SAEConfig, ABC):
-    noise_scale: float = 0.0
-    mse_loss_normalization: str | None = None
-    b_dec_init_method: Literal["zeros", "geometric_median", "mean"] = "zeros"
     # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
     # 0.1 corresponds to the "heuristic" initialization, use None to disable
     decoder_init_norm: float | None = 0.1
@@ -666,7 +751,6 @@ class TrainingSAEConfig(SAEConfig, ABC):
         metadata = SAEMetadata(
             model_name=cfg.model_name,
             hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
             hook_head_index=cfg.hook_head_index,
             context_size=cfg.context_size,
             prepend_bos=cfg.prepend_bos,
@@ -683,9 +767,6 @@ class TrainingSAEConfig(SAEConfig, ABC):
     def from_dict(
         cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
     ) -> T_TRAINING_SAE_CONFIG:
-        # remove any keys that are not in the dataclass
-        # since we sometimes enhance the config with the whole LM runner config
-        valid_config_dict = filter_valid_dataclass_fields(config_dict, cls)
         cfg_class = cls
         if "architecture" in config_dict:
             cfg_class = get_sae_training_class(config_dict["architecture"])[1]
@@ -693,6 +774,9 @@ class TrainingSAEConfig(SAEConfig, ABC):
             raise ValueError(
                 f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
             )
+        # remove any keys that are not in the dataclass
+        # since we sometimes enhance the config with the whole LM runner config
+        valid_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
         if "metadata" in config_dict:
             valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
         return cfg_class(**valid_config_dict)
@@ -701,6 +785,7 @@ class TrainingSAEConfig(SAEConfig, ABC):
         return {
             **super().to_dict(),
             **asdict(self),
+            "metadata": self.metadata.to_dict(),
             "architecture": self.architecture(),
         }
 
@@ -711,12 +796,14 @@ class TrainingSAEConfig(SAEConfig, ABC):
         Creates a dictionary containing attributes corresponding to the fields
         defined in the base SAEConfig class.
         """
-
+        base_sae_cfg_class = get_sae_class(self.architecture())[1]
+        base_config_field_names = {f.name for f in fields(base_sae_cfg_class)}
         result_dict = {
             field_name: getattr(self, field_name)
             for field_name in base_config_field_names
         }
         result_dict["architecture"] = self.architecture()
+        result_dict["metadata"] = self.metadata.to_dict()
         return result_dict
 
 
@@ -729,7 +816,7 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
         # Turn off hook_z reshaping for training mode - the activation store
         # is expected to handle reshaping before passing data to the SAE
         self.turn_off_forward_pass_hook_z_reshaping()
-        self.mse_loss_fn =
+        self.mse_loss_fn = mse_loss
 
     @abstractmethod
     def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...
@@ -864,27 +951,6 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
         """
         return self.process_state_dict_for_saving(state_dict)
 
-    def _get_mse_loss_fn(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
-        """Get the MSE loss function based on config."""
-
-        def standard_mse_loss_fn(
-            preds: torch.Tensor, target: torch.Tensor
-        ) -> torch.Tensor:
-            return torch.nn.functional.mse_loss(preds, target, reduction="none")
-
-        def batch_norm_mse_loss_fn(
-            preds: torch.Tensor, target: torch.Tensor
-        ) -> torch.Tensor:
-            target_centered = target - target.mean(dim=0, keepdim=True)
-            normalization = target_centered.norm(dim=-1, keepdim=True)
-            return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
-                normalization + 1e-6
-            )
-
-        if self.cfg.mse_loss_normalization == "dense_batch":
-            return batch_norm_mse_loss_fn
-        return standard_mse_loss_fn
-
     @torch.no_grad()
     def remove_gradient_parallel_to_decoder_directions(self) -> None:
         """Remove gradient components parallel to decoder directions."""
@@ -946,3 +1012,7 @@ def _disable_hooks(sae: SAE[Any]):
     finally:
         for hook_name, hook in sae.hook_dict.items():
             setattr(sae, hook_name, hook)
+
+
+def mse_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    return torch.nn.functional.mse_loss(preds, target, reduction="none")
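With mse_loss_normalization and _get_mse_loss_fn() removed, every TrainingSAE now uses this single module-level mse_loss, i.e. an unreduced elementwise MSE; the old "dense_batch" normalization path no longer exists. A quick sketch of what the new helper returns (the tensor shapes are arbitrary examples):

```python
import torch
from sae_lens.saes.sae import mse_loss

preds = torch.randn(4, 512)
target = torch.randn(4, 512)

per_element = mse_loss(preds, target)
print(per_element.shape)  # torch.Size([4, 512]) -- reduction="none", so no averaging here
```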
sae_lens/saes/standard_sae.py
CHANGED
@@ -67,7 +67,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
         sae_in = self.process_sae_in(x)
         # Compute the pre-activation values
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
-        # Apply the activation function (e.g., ReLU,
+        # Apply the activation function (e.g., ReLU, depending on config)
         return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
 
     def decode(
@@ -81,7 +81,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
         sae_out_pre = feature_acts @ self.W_dec + self.b_dec
         # 2) hook reconstruction
         sae_out_pre = self.hook_sae_recons(sae_out_pre)
-        # 4) optional out-normalization (e.g. constant_norm_rescale
+        # 4) optional out-normalization (e.g. constant_norm_rescale)
         sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
         # 5) if hook_z is enabled, rearrange back to (..., n_heads, d_head).
         return self.reshape_fn_out(sae_out_pre, self.d_head)
@@ -136,16 +136,9 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
         sae_in = self.process_sae_in(x)
         # Compute the pre-activation (and allow for a hook if desired)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)  # type: ignore
-        # Add noise during training for robustness (scaled by noise_scale from the configuration)
-        if self.training and self.cfg.noise_scale > 0:
-            hidden_pre_noised = (
-                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
-            )
-        else:
-            hidden_pre_noised = hidden_pre
         # Apply the activation function (and any post-activation hook)
-        feature_acts = self.hook_sae_acts_post(self.activation_fn(
-        return feature_acts,
+        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+        return feature_acts, hidden_pre
 
     def calculate_aux_loss(
         self,
sae_lens/saes/topk_sae.py
CHANGED
@@ -91,8 +91,7 @@ class TopKSAE(SAE[TopKSAEConfig]):
     ) -> Float[torch.Tensor, "... d_sae"]:
         """
         Converts input x into feature activations.
-        Uses topk activation
-        under the hood.
+        Uses topk activation under the hood.
         """
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
@@ -116,6 +115,13 @@ class TopKSAE(SAE[TopKSAEConfig]):
     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
         return TopK(self.cfg.k)
 
+    @override
+    @torch.no_grad()
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError(
+            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
+        )
+
 
 @dataclass
 class TopKTrainingSAEConfig(TrainingSAEConfig):
@@ -156,18 +162,11 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         sae_in = self.process_sae_in(x)
         hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
 
-        # Inject noise if training
-        if self.training and self.cfg.noise_scale > 0:
-            hidden_pre_noised = (
-                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
-            )
-        else:
-            hidden_pre_noised = hidden_pre
-
         # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
-        feature_acts = self.hook_sae_acts_post(self.activation_fn(
-        return feature_acts,
+        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
+        return feature_acts, hidden_pre
 
+    @override
     def calculate_aux_loss(
         self,
         step_input: TrainStepInput,
@@ -184,6 +183,13 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
         )
         return {"auxiliary_reconstruction_loss": topk_loss}
 
+    @override
+    @torch.no_grad()
+    def fold_W_dec_norm(self) -> None:
+        raise NotImplementedError(
+            "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
+        )
+
     @override
     def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
         return TopK(self.cfg.k)
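Both TopKSAE and TopKTrainingSAE now refuse fold_W_dec_norm(), since rescaling by decoder norms can change which units land in the top k. A minimal sketch of the new behavior; the config values are assumed placeholders and TopKSAEConfig may require additional fields in practice:

```python
from sae_lens.saes.topk_sae import TopKSAE, TopKSAEConfig

# Assumed minimal config, for illustration only
cfg = TopKSAEConfig(d_in=512, d_sae=2048, k=32)
sae = TopKSAE(cfg)

try:
    sae.fold_W_dec_norm()
except NotImplementedError as err:
    print(err)  # "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
```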
sae_lens/training/activation_scaler.py
ADDED
@@ -0,0 +1,53 @@
+import json
+from dataclasses import dataclass
+from statistics import mean
+
+import torch
+from tqdm.auto import tqdm
+
+from sae_lens.training.types import DataProvider
+
+
+@dataclass
+class ActivationScaler:
+    scaling_factor: float | None = None
+
+    def scale(self, acts: torch.Tensor) -> torch.Tensor:
+        return acts if self.scaling_factor is None else acts * self.scaling_factor
+
+    def unscale(self, acts: torch.Tensor) -> torch.Tensor:
+        return acts if self.scaling_factor is None else acts / self.scaling_factor
+
+    def __call__(self, acts: torch.Tensor) -> torch.Tensor:
+        return self.scale(acts)
+
+    @torch.no_grad()
+    def _calculate_mean_norm(
+        self, data_provider: DataProvider, n_batches_for_norm_estimate: int = int(1e3)
+    ) -> float:
+        norms_per_batch: list[float] = []
+        for _ in tqdm(
+            range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
+        ):
+            acts = next(data_provider)
+            norms_per_batch.append(acts.norm(dim=-1).mean().item())
+        return mean(norms_per_batch)
+
+    def estimate_scaling_factor(
+        self,
+        d_in: int,
+        data_provider: DataProvider,
+        n_batches_for_norm_estimate: int = int(1e3),
+    ):
+        mean_norm = self._calculate_mean_norm(
+            data_provider, n_batches_for_norm_estimate
+        )
+        self.scaling_factor = (d_in**0.5) / mean_norm
+
+    def save(self, file_path: str):
+        """save the state dict to a file in json format"""
+        if not file_path.endswith(".json"):
+            raise ValueError("file_path must end with .json")
+
+        with open(file_path, "w") as f:
+            json.dump({"scaling_factor": self.scaling_factor}, f)
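The new ActivationScaler estimates a single scaling factor so that scaled activations have mean norm near sqrt(d_in), applies or undoes it, and persists it to JSON. A usage sketch with a synthetic data provider (the random-tensor generator below stands in for a real DataProvider, which only needs to support next()):

```python
import torch
from sae_lens.training.activation_scaler import ActivationScaler

d_in = 512

def fake_activation_batches():
    # Stand-in for a real DataProvider: yields (batch, d_in) activation tensors forever
    while True:
        yield torch.randn(64, d_in)

scaler = ActivationScaler()
scaler.estimate_scaling_factor(
    d_in=d_in,
    data_provider=fake_activation_batches(),
    n_batches_for_norm_estimate=10,  # small value just for this sketch
)

acts = torch.randn(64, d_in)
scaled = scaler(acts)                   # same as scaler.scale(acts); norms land near sqrt(d_in)
restored = scaler.unscale(scaled)       # inverse transform
scaler.save("activation_scaler.json")   # writes {"scaling_factor": ...}
```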