sae-lens 6.0.0rc3__py3-none-any.whl → 6.0.0rc5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +1 -1
- sae_lens/analysis/neuronpedia_integration.py +3 -3
- sae_lens/config.py +5 -3
- sae_lens/constants.py +1 -0
- sae_lens/evals.py +20 -20
- sae_lens/llm_sae_training_runner.py +113 -5
- sae_lens/loading/pretrained_sae_loaders.py +178 -7
- sae_lens/pretrained_saes.yaml +12 -0
- sae_lens/saes/gated_sae.py +0 -4
- sae_lens/saes/jumprelu_sae.py +4 -10
- sae_lens/saes/sae.py +179 -48
- sae_lens/saes/standard_sae.py +4 -11
- sae_lens/saes/topk_sae.py +18 -12
- sae_lens/training/activation_scaler.py +1 -1
- sae_lens/training/activations_store.py +1 -3
- sae_lens/training/sae_trainer.py +11 -3
- sae_lens/training/upload_saes_to_huggingface.py +1 -1
- {sae_lens-6.0.0rc3.dist-info → sae_lens-6.0.0rc5.dist-info}/METADATA +2 -2
- sae_lens-6.0.0rc5.dist-info/RECORD +37 -0
- sae_lens/training/geometric_median.py +0 -101
- sae_lens-6.0.0rc3.dist-info/RECORD +0 -38
- {sae_lens-6.0.0rc3.dist-info → sae_lens-6.0.0rc5.dist-info}/LICENSE +0 -0
- {sae_lens-6.0.0rc3.dist-info → sae_lens-6.0.0rc5.dist-info}/WHEEL +0 -0
sae_lens/saes/sae.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Base classes for Sparse Autoencoders (SAEs)."""
|
|
2
2
|
|
|
3
|
+
import copy
|
|
3
4
|
import json
|
|
4
5
|
import warnings
|
|
5
6
|
from abc import ABC, abstractmethod
|
|
@@ -59,23 +60,91 @@ T_SAE = TypeVar("T_SAE", bound="SAE") # type: ignore
|
|
|
59
60
|
T_TRAINING_SAE = TypeVar("T_TRAINING_SAE", bound="TrainingSAE") # type: ignore
|
|
60
61
|
|
|
61
62
|
|
|
62
|
-
@dataclass
|
|
63
63
|
class SAEMetadata:
|
|
64
64
|
"""Core metadata about how this SAE should be used, if known."""
|
|
65
65
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
66
|
+
def __init__(self, **kwargs: Any):
|
|
67
|
+
# Set default version fields with their current behavior
|
|
68
|
+
self.sae_lens_version = kwargs.pop("sae_lens_version", __version__)
|
|
69
|
+
self.sae_lens_training_version = kwargs.pop(
|
|
70
|
+
"sae_lens_training_version", __version__
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
# Set all other attributes dynamically
|
|
74
|
+
for key, value in kwargs.items():
|
|
75
|
+
setattr(self, key, value)
|
|
76
|
+
|
|
77
|
+
def __getattr__(self, name: str) -> None:
|
|
78
|
+
"""Return None for any missing attribute (like defaultdict)"""
|
|
79
|
+
return
|
|
80
|
+
|
|
81
|
+
def __setattr__(self, name: str, value: Any) -> None:
|
|
82
|
+
"""Allow setting any attribute"""
|
|
83
|
+
super().__setattr__(name, value)
|
|
84
|
+
|
|
85
|
+
def __getitem__(self, key: str) -> Any:
|
|
86
|
+
"""Allow dictionary-style access: metadata['key'] - returns None for missing keys"""
|
|
87
|
+
return getattr(self, key)
|
|
88
|
+
|
|
89
|
+
def __setitem__(self, key: str, value: Any) -> None:
|
|
90
|
+
"""Allow dictionary-style assignment: metadata['key'] = value"""
|
|
91
|
+
setattr(self, key, value)
|
|
92
|
+
|
|
93
|
+
def __contains__(self, key: str) -> bool:
|
|
94
|
+
"""Allow 'in' operator: 'key' in metadata"""
|
|
95
|
+
# Only return True if the attribute was explicitly set (not just defaulting to None)
|
|
96
|
+
return key in self.__dict__
|
|
97
|
+
|
|
98
|
+
def get(self, key: str, default: Any = None) -> Any:
|
|
99
|
+
"""Dictionary-style get with default"""
|
|
100
|
+
value = getattr(self, key)
|
|
101
|
+
# If the attribute wasn't explicitly set and we got None from __getattr__,
|
|
102
|
+
# use the provided default instead
|
|
103
|
+
if key not in self.__dict__ and value is None:
|
|
104
|
+
return default
|
|
105
|
+
return value
|
|
106
|
+
|
|
107
|
+
def keys(self):
|
|
108
|
+
"""Return all explicitly set attribute names"""
|
|
109
|
+
return self.__dict__.keys()
|
|
110
|
+
|
|
111
|
+
def values(self):
|
|
112
|
+
"""Return all explicitly set attribute values"""
|
|
113
|
+
return self.__dict__.values()
|
|
114
|
+
|
|
115
|
+
def items(self):
|
|
116
|
+
"""Return all explicitly set attribute name-value pairs"""
|
|
117
|
+
return self.__dict__.items()
|
|
118
|
+
|
|
119
|
+
def to_dict(self) -> dict[str, Any]:
|
|
120
|
+
"""Convert to dictionary for serialization"""
|
|
121
|
+
return self.__dict__.copy()
|
|
122
|
+
|
|
123
|
+
@classmethod
|
|
124
|
+
def from_dict(cls, data: dict[str, Any]) -> "SAEMetadata":
|
|
125
|
+
"""Create from dictionary"""
|
|
126
|
+
return cls(**data)
|
|
127
|
+
|
|
128
|
+
def __repr__(self) -> str:
|
|
129
|
+
return f"SAEMetadata({self.__dict__})"
|
|
130
|
+
|
|
131
|
+
def __eq__(self, other: object) -> bool:
|
|
132
|
+
if not isinstance(other, SAEMetadata):
|
|
133
|
+
return False
|
|
134
|
+
return self.__dict__ == other.__dict__
|
|
135
|
+
|
|
136
|
+
def __deepcopy__(self, memo: dict[int, Any]) -> "SAEMetadata":
|
|
137
|
+
"""Support for deep copying"""
|
|
138
|
+
|
|
139
|
+
return SAEMetadata(**copy.deepcopy(self.__dict__, memo))
|
|
140
|
+
|
|
141
|
+
def __getstate__(self) -> dict[str, Any]:
|
|
142
|
+
"""Support for pickling"""
|
|
143
|
+
return self.__dict__
|
|
144
|
+
|
|
145
|
+
def __setstate__(self, state: dict[str, Any]) -> None:
|
|
146
|
+
"""Support for unpickling"""
|
|
147
|
+
self.__dict__.update(state)
|
|
79
148
|
|
|
80
149
|
|
|
81
150
|
@dataclass
|
|
@@ -99,7 +168,7 @@ class SAEConfig(ABC):
|
|
|
99
168
|
|
|
100
169
|
def to_dict(self) -> dict[str, Any]:
|
|
101
170
|
res = {field.name: getattr(self, field.name) for field in fields(self)}
|
|
102
|
-
res["metadata"] =
|
|
171
|
+
res["metadata"] = self.metadata.to_dict()
|
|
103
172
|
res["architecture"] = self.architecture()
|
|
104
173
|
return res
|
|
105
174
|
|
|
@@ -124,7 +193,7 @@ class SAEConfig(ABC):
|
|
|
124
193
|
"layer_norm",
|
|
125
194
|
]:
|
|
126
195
|
raise ValueError(
|
|
127
|
-
f"normalize_activations must be none, expected_average_only_in,
|
|
196
|
+
f"normalize_activations must be none, expected_average_only_in, layer_norm, or constant_norm_rescale. Got {self.normalize_activations}"
|
|
128
197
|
)
|
|
129
198
|
|
|
130
199
|
|
|
@@ -238,9 +307,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
|
|
|
238
307
|
|
|
239
308
|
self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
|
|
240
309
|
self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
|
|
241
|
-
|
|
242
310
|
elif self.cfg.normalize_activations == "layer_norm":
|
|
243
|
-
|
|
311
|
+
# we need to scale the norm of the input and store the scaling factor
|
|
244
312
|
def run_time_activation_ln_in(
|
|
245
313
|
x: torch.Tensor, eps: float = 1e-5
|
|
246
314
|
) -> torch.Tensor:
|
|
@@ -522,7 +590,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
|
|
|
522
590
|
device: str = "cpu",
|
|
523
591
|
force_download: bool = False,
|
|
524
592
|
converter: PretrainedSaeHuggingfaceLoader | None = None,
|
|
525
|
-
) ->
|
|
593
|
+
) -> T_SAE:
|
|
526
594
|
"""
|
|
527
595
|
Load a pretrained SAE from the Hugging Face model hub.
|
|
528
596
|
|
|
@@ -530,7 +598,28 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
|
|
|
530
598
|
release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
|
|
531
599
|
id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
|
|
532
600
|
device: The device to load the SAE on.
|
|
533
|
-
|
|
601
|
+
"""
|
|
602
|
+
return cls.from_pretrained_with_cfg_and_sparsity(
|
|
603
|
+
release, sae_id, device, force_download, converter=converter
|
|
604
|
+
)[0]
|
|
605
|
+
|
|
606
|
+
@classmethod
|
|
607
|
+
def from_pretrained_with_cfg_and_sparsity(
|
|
608
|
+
cls: Type[T_SAE],
|
|
609
|
+
release: str,
|
|
610
|
+
sae_id: str,
|
|
611
|
+
device: str = "cpu",
|
|
612
|
+
force_download: bool = False,
|
|
613
|
+
converter: PretrainedSaeHuggingfaceLoader | None = None,
|
|
614
|
+
) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
|
|
615
|
+
"""
|
|
616
|
+
Load a pretrained SAE from the Hugging Face model hub, along with its config dict and sparsity, if present.
|
|
617
|
+
In SAELens <= 5.x.x, this was called SAE.from_pretrained().
|
|
618
|
+
|
|
619
|
+
Args:
|
|
620
|
+
release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
|
|
621
|
+
id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
|
|
622
|
+
device: The device to load the SAE on.
|
|
534
623
|
"""
|
|
535
624
|
|
|
536
625
|
# get sae directory
|
|
@@ -643,11 +732,67 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
|
|
|
643
732
|
) -> type[SAEConfig]:
|
|
644
733
|
return SAEConfig
|
|
645
734
|
|
|
735
|
+
### Methods to support deprecated usage of SAE.from_pretrained() ###
|
|
736
|
+
|
|
737
|
+
def __getitem__(self, index: int) -> Any:
|
|
738
|
+
"""
|
|
739
|
+
Support indexing for backward compatibility with tuple unpacking.
|
|
740
|
+
DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
|
|
741
|
+
Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
|
|
742
|
+
"""
|
|
743
|
+
warnings.warn(
|
|
744
|
+
"Indexing SAE objects is deprecated. SAE.from_pretrained() now returns "
|
|
745
|
+
"only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
|
|
746
|
+
"to get the config dict and sparsity as well.",
|
|
747
|
+
DeprecationWarning,
|
|
748
|
+
stacklevel=2,
|
|
749
|
+
)
|
|
750
|
+
|
|
751
|
+
if index == 0:
|
|
752
|
+
return self
|
|
753
|
+
if index == 1:
|
|
754
|
+
return self.cfg.to_dict()
|
|
755
|
+
if index == 2:
|
|
756
|
+
return None
|
|
757
|
+
raise IndexError(f"SAE tuple index {index} out of range")
|
|
758
|
+
|
|
759
|
+
def __iter__(self):
|
|
760
|
+
"""
|
|
761
|
+
Support unpacking for backward compatibility with tuple unpacking.
|
|
762
|
+
DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
|
|
763
|
+
Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
|
|
764
|
+
"""
|
|
765
|
+
warnings.warn(
|
|
766
|
+
"Unpacking SAE objects is deprecated. SAE.from_pretrained() now returns "
|
|
767
|
+
"only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
|
|
768
|
+
"to get the config dict and sparsity as well.",
|
|
769
|
+
DeprecationWarning,
|
|
770
|
+
stacklevel=2,
|
|
771
|
+
)
|
|
772
|
+
|
|
773
|
+
yield self
|
|
774
|
+
yield self.cfg.to_dict()
|
|
775
|
+
yield None
|
|
776
|
+
|
|
777
|
+
def __len__(self) -> int:
|
|
778
|
+
"""
|
|
779
|
+
Support len() for backward compatibility with tuple unpacking.
|
|
780
|
+
DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
|
|
781
|
+
Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
|
|
782
|
+
"""
|
|
783
|
+
warnings.warn(
|
|
784
|
+
"Getting length of SAE objects is deprecated. SAE.from_pretrained() now returns "
|
|
785
|
+
"only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
|
|
786
|
+
"to get the config dict and sparsity as well.",
|
|
787
|
+
DeprecationWarning,
|
|
788
|
+
stacklevel=2,
|
|
789
|
+
)
|
|
790
|
+
|
|
791
|
+
return 3
|
|
792
|
+
|
|
646
793
|
|
|
647
794
|
@dataclass(kw_only=True)
|
|
648
795
|
class TrainingSAEConfig(SAEConfig, ABC):
|
|
649
|
-
noise_scale: float = 0.0
|
|
650
|
-
mse_loss_normalization: str | None = None
|
|
651
796
|
# https://transformer-circuits.pub/2024/april-update/index.html#training-saes
|
|
652
797
|
# 0.1 corresponds to the "heuristic" initialization, use None to disable
|
|
653
798
|
decoder_init_norm: float | None = 0.1
|
|
@@ -680,9 +825,6 @@ class TrainingSAEConfig(SAEConfig, ABC):
|
|
|
680
825
|
def from_dict(
|
|
681
826
|
cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
|
|
682
827
|
) -> T_TRAINING_SAE_CONFIG:
|
|
683
|
-
# remove any keys that are not in the dataclass
|
|
684
|
-
# since we sometimes enhance the config with the whole LM runner config
|
|
685
|
-
valid_config_dict = filter_valid_dataclass_fields(config_dict, cls)
|
|
686
828
|
cfg_class = cls
|
|
687
829
|
if "architecture" in config_dict:
|
|
688
830
|
cfg_class = get_sae_training_class(config_dict["architecture"])[1]
|
|
@@ -690,6 +832,9 @@ class TrainingSAEConfig(SAEConfig, ABC):
|
|
|
690
832
|
raise ValueError(
|
|
691
833
|
f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
|
|
692
834
|
)
|
|
835
|
+
# remove any keys that are not in the dataclass
|
|
836
|
+
# since we sometimes enhance the config with the whole LM runner config
|
|
837
|
+
valid_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
|
|
693
838
|
if "metadata" in config_dict:
|
|
694
839
|
valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
|
|
695
840
|
return cfg_class(**valid_config_dict)
|
|
@@ -698,6 +843,7 @@ class TrainingSAEConfig(SAEConfig, ABC):
|
|
|
698
843
|
return {
|
|
699
844
|
**super().to_dict(),
|
|
700
845
|
**asdict(self),
|
|
846
|
+
"metadata": self.metadata.to_dict(),
|
|
701
847
|
"architecture": self.architecture(),
|
|
702
848
|
}
|
|
703
849
|
|
|
@@ -708,12 +854,14 @@ class TrainingSAEConfig(SAEConfig, ABC):
|
|
|
708
854
|
Creates a dictionary containing attributes corresponding to the fields
|
|
709
855
|
defined in the base SAEConfig class.
|
|
710
856
|
"""
|
|
711
|
-
|
|
857
|
+
base_sae_cfg_class = get_sae_class(self.architecture())[1]
|
|
858
|
+
base_config_field_names = {f.name for f in fields(base_sae_cfg_class)}
|
|
712
859
|
result_dict = {
|
|
713
860
|
field_name: getattr(self, field_name)
|
|
714
861
|
for field_name in base_config_field_names
|
|
715
862
|
}
|
|
716
863
|
result_dict["architecture"] = self.architecture()
|
|
864
|
+
result_dict["metadata"] = self.metadata.to_dict()
|
|
717
865
|
return result_dict
|
|
718
866
|
|
|
719
867
|
|
|
@@ -726,7 +874,7 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
|
|
|
726
874
|
# Turn off hook_z reshaping for training mode - the activation store
|
|
727
875
|
# is expected to handle reshaping before passing data to the SAE
|
|
728
876
|
self.turn_off_forward_pass_hook_z_reshaping()
|
|
729
|
-
self.mse_loss_fn =
|
|
877
|
+
self.mse_loss_fn = mse_loss
|
|
730
878
|
|
|
731
879
|
@abstractmethod
|
|
732
880
|
def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...
|
|
@@ -861,27 +1009,6 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
|
|
|
861
1009
|
"""
|
|
862
1010
|
return self.process_state_dict_for_saving(state_dict)
|
|
863
1011
|
|
|
864
|
-
def _get_mse_loss_fn(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
|
|
865
|
-
"""Get the MSE loss function based on config."""
|
|
866
|
-
|
|
867
|
-
def standard_mse_loss_fn(
|
|
868
|
-
preds: torch.Tensor, target: torch.Tensor
|
|
869
|
-
) -> torch.Tensor:
|
|
870
|
-
return torch.nn.functional.mse_loss(preds, target, reduction="none")
|
|
871
|
-
|
|
872
|
-
def batch_norm_mse_loss_fn(
|
|
873
|
-
preds: torch.Tensor, target: torch.Tensor
|
|
874
|
-
) -> torch.Tensor:
|
|
875
|
-
target_centered = target - target.mean(dim=0, keepdim=True)
|
|
876
|
-
normalization = target_centered.norm(dim=-1, keepdim=True)
|
|
877
|
-
return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
|
|
878
|
-
normalization + 1e-6
|
|
879
|
-
)
|
|
880
|
-
|
|
881
|
-
if self.cfg.mse_loss_normalization == "dense_batch":
|
|
882
|
-
return batch_norm_mse_loss_fn
|
|
883
|
-
return standard_mse_loss_fn
|
|
884
|
-
|
|
885
1012
|
@torch.no_grad()
|
|
886
1013
|
def remove_gradient_parallel_to_decoder_directions(self) -> None:
|
|
887
1014
|
"""Remove gradient components parallel to decoder directions."""
|
|
@@ -943,3 +1070,7 @@ def _disable_hooks(sae: SAE[Any]):
|
|
|
943
1070
|
finally:
|
|
944
1071
|
for hook_name, hook in sae.hook_dict.items():
|
|
945
1072
|
setattr(sae, hook_name, hook)
|
|
1073
|
+
|
|
1074
|
+
|
|
1075
|
+
def mse_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
|
1076
|
+
return torch.nn.functional.mse_loss(preds, target, reduction="none")
|
sae_lens/saes/standard_sae.py
CHANGED
|
@@ -67,7 +67,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
|
|
|
67
67
|
sae_in = self.process_sae_in(x)
|
|
68
68
|
# Compute the pre-activation values
|
|
69
69
|
hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
|
|
70
|
-
# Apply the activation function (e.g., ReLU,
|
|
70
|
+
# Apply the activation function (e.g., ReLU, depending on config)
|
|
71
71
|
return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
|
|
72
72
|
|
|
73
73
|
def decode(
|
|
@@ -81,7 +81,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
|
|
|
81
81
|
sae_out_pre = feature_acts @ self.W_dec + self.b_dec
|
|
82
82
|
# 2) hook reconstruction
|
|
83
83
|
sae_out_pre = self.hook_sae_recons(sae_out_pre)
|
|
84
|
-
# 4) optional out-normalization (e.g. constant_norm_rescale
|
|
84
|
+
# 4) optional out-normalization (e.g. constant_norm_rescale)
|
|
85
85
|
sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
|
|
86
86
|
# 5) if hook_z is enabled, rearrange back to (..., n_heads, d_head).
|
|
87
87
|
return self.reshape_fn_out(sae_out_pre, self.d_head)
|
|
@@ -136,16 +136,9 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
|
|
|
136
136
|
sae_in = self.process_sae_in(x)
|
|
137
137
|
# Compute the pre-activation (and allow for a hook if desired)
|
|
138
138
|
hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) # type: ignore
|
|
139
|
-
# Add noise during training for robustness (scaled by noise_scale from the configuration)
|
|
140
|
-
if self.training and self.cfg.noise_scale > 0:
|
|
141
|
-
hidden_pre_noised = (
|
|
142
|
-
hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
|
|
143
|
-
)
|
|
144
|
-
else:
|
|
145
|
-
hidden_pre_noised = hidden_pre
|
|
146
139
|
# Apply the activation function (and any post-activation hook)
|
|
147
|
-
feature_acts = self.hook_sae_acts_post(self.activation_fn(
|
|
148
|
-
return feature_acts,
|
|
140
|
+
feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
|
|
141
|
+
return feature_acts, hidden_pre
|
|
149
142
|
|
|
150
143
|
def calculate_aux_loss(
|
|
151
144
|
self,
|
sae_lens/saes/topk_sae.py
CHANGED
|
@@ -91,8 +91,7 @@ class TopKSAE(SAE[TopKSAEConfig]):
|
|
|
91
91
|
) -> Float[torch.Tensor, "... d_sae"]:
|
|
92
92
|
"""
|
|
93
93
|
Converts input x into feature activations.
|
|
94
|
-
Uses topk activation
|
|
95
|
-
under the hood.
|
|
94
|
+
Uses topk activation under the hood.
|
|
96
95
|
"""
|
|
97
96
|
sae_in = self.process_sae_in(x)
|
|
98
97
|
hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
|
|
@@ -116,6 +115,13 @@ class TopKSAE(SAE[TopKSAEConfig]):
|
|
|
116
115
|
def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
|
|
117
116
|
return TopK(self.cfg.k)
|
|
118
117
|
|
|
118
|
+
@override
|
|
119
|
+
@torch.no_grad()
|
|
120
|
+
def fold_W_dec_norm(self) -> None:
|
|
121
|
+
raise NotImplementedError(
|
|
122
|
+
"Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
|
|
123
|
+
)
|
|
124
|
+
|
|
119
125
|
|
|
120
126
|
@dataclass
|
|
121
127
|
class TopKTrainingSAEConfig(TrainingSAEConfig):
|
|
@@ -156,18 +162,11 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
|
|
|
156
162
|
sae_in = self.process_sae_in(x)
|
|
157
163
|
hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
|
|
158
164
|
|
|
159
|
-
# Inject noise if training
|
|
160
|
-
if self.training and self.cfg.noise_scale > 0:
|
|
161
|
-
hidden_pre_noised = (
|
|
162
|
-
hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
|
|
163
|
-
)
|
|
164
|
-
else:
|
|
165
|
-
hidden_pre_noised = hidden_pre
|
|
166
|
-
|
|
167
165
|
# Apply the TopK activation function (already set in self.activation_fn if config is "topk")
|
|
168
|
-
feature_acts = self.hook_sae_acts_post(self.activation_fn(
|
|
169
|
-
return feature_acts,
|
|
166
|
+
feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
|
|
167
|
+
return feature_acts, hidden_pre
|
|
170
168
|
|
|
169
|
+
@override
|
|
171
170
|
def calculate_aux_loss(
|
|
172
171
|
self,
|
|
173
172
|
step_input: TrainStepInput,
|
|
@@ -184,6 +183,13 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
|
|
|
184
183
|
)
|
|
185
184
|
return {"auxiliary_reconstruction_loss": topk_loss}
|
|
186
185
|
|
|
186
|
+
@override
|
|
187
|
+
@torch.no_grad()
|
|
188
|
+
def fold_W_dec_norm(self) -> None:
|
|
189
|
+
raise NotImplementedError(
|
|
190
|
+
"Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
|
|
191
|
+
)
|
|
192
|
+
|
|
187
193
|
@override
|
|
188
194
|
def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
|
|
189
195
|
return TopK(self.cfg.k)
|
|
@@ -161,8 +161,6 @@ class ActivationsStore:
|
|
|
161
161
|
) -> ActivationsStore:
|
|
162
162
|
if sae.cfg.metadata.hook_name is None:
|
|
163
163
|
raise ValueError("hook_name is required")
|
|
164
|
-
if sae.cfg.metadata.hook_head_index is None:
|
|
165
|
-
raise ValueError("hook_head_index is required")
|
|
166
164
|
if sae.cfg.metadata.context_size is None:
|
|
167
165
|
raise ValueError("context_size is required")
|
|
168
166
|
if sae.cfg.metadata.prepend_bos is None:
|
|
@@ -430,7 +428,7 @@ class ActivationsStore:
|
|
|
430
428
|
):
|
|
431
429
|
# temporalily set estimated_norm_scaling_factor to 1.0 so the dataloader works
|
|
432
430
|
self.estimated_norm_scaling_factor = 1.0
|
|
433
|
-
acts = self.next_batch()[0]
|
|
431
|
+
acts = self.next_batch()[:, 0]
|
|
434
432
|
self.estimated_norm_scaling_factor = None
|
|
435
433
|
norms_per_batch.append(acts.norm(dim=-1).mean().item())
|
|
436
434
|
mean_norm = np.mean(norms_per_batch)
|
sae_lens/training/sae_trainer.py
CHANGED
|
@@ -7,7 +7,7 @@ import torch
|
|
|
7
7
|
import wandb
|
|
8
8
|
from safetensors.torch import save_file
|
|
9
9
|
from torch.optim import Adam
|
|
10
|
-
from tqdm import tqdm
|
|
10
|
+
from tqdm.auto import tqdm
|
|
11
11
|
|
|
12
12
|
from sae_lens import __version__
|
|
13
13
|
from sae_lens.config import SAETrainerConfig
|
|
@@ -161,6 +161,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
161
161
|
return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()
|
|
162
162
|
|
|
163
163
|
def fit(self) -> T_TRAINING_SAE:
|
|
164
|
+
self.sae.to(self.cfg.device)
|
|
164
165
|
pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE")
|
|
165
166
|
|
|
166
167
|
if self.sae.cfg.normalize_activations == "expected_average_only_in":
|
|
@@ -194,10 +195,11 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
194
195
|
)
|
|
195
196
|
self.activation_scaler.scaling_factor = None
|
|
196
197
|
|
|
197
|
-
# save final sae group to checkpoints folder
|
|
198
|
+
# save final inference sae group to checkpoints folder
|
|
198
199
|
self.save_checkpoint(
|
|
199
200
|
checkpoint_name=f"final_{self.n_training_samples}",
|
|
200
201
|
wandb_aliases=["final_model"],
|
|
202
|
+
save_inference_model=True,
|
|
201
203
|
)
|
|
202
204
|
|
|
203
205
|
pbar.close()
|
|
@@ -207,11 +209,17 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
|
|
|
207
209
|
self,
|
|
208
210
|
checkpoint_name: str,
|
|
209
211
|
wandb_aliases: list[str] | None = None,
|
|
212
|
+
save_inference_model: bool = False,
|
|
210
213
|
) -> None:
|
|
211
214
|
checkpoint_path = Path(self.cfg.checkpoint_path) / checkpoint_name
|
|
212
215
|
checkpoint_path.mkdir(exist_ok=True, parents=True)
|
|
213
216
|
|
|
214
|
-
|
|
217
|
+
save_fn = (
|
|
218
|
+
self.sae.save_inference_model
|
|
219
|
+
if save_inference_model
|
|
220
|
+
else self.sae.save_model
|
|
221
|
+
)
|
|
222
|
+
weights_path, cfg_path = save_fn(str(checkpoint_path))
|
|
215
223
|
|
|
216
224
|
sparsity_path = checkpoint_path / SPARSITY_FILENAME
|
|
217
225
|
save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
|
|
@@ -88,7 +88,7 @@ def _create_default_readme(repo_id: str, sae_ids: Iterable[str]) -> str:
|
|
|
88
88
|
```python
|
|
89
89
|
from sae_lens import SAE
|
|
90
90
|
|
|
91
|
-
sae
|
|
91
|
+
sae = SAE.from_pretrained("{repo_id}", "<sae_id>")
|
|
92
92
|
```
|
|
93
93
|
"""
|
|
94
94
|
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: sae-lens
|
|
3
|
-
Version: 6.0.
|
|
3
|
+
Version: 6.0.0rc5
|
|
4
4
|
Summary: Training and Analyzing Sparse Autoencoders (SAEs)
|
|
5
5
|
License: MIT
|
|
6
6
|
Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
|
|
@@ -80,7 +80,7 @@ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page
|
|
|
80
80
|
|
|
81
81
|
## Join the Slack!
|
|
82
82
|
|
|
83
|
-
Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-
|
|
83
|
+
Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
|
|
84
84
|
|
|
85
85
|
## Citation
|
|
86
86
|
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
sae_lens/__init__.py,sha256=hiHDLT9_1V7iVulw5hwqDqDj2HVxUR9I88xOfYx6X94,2861
|
|
2
|
+
sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
|
|
4
|
+
sae_lens/analysis/neuronpedia_integration.py,sha256=MrENqc81Mc2SMbxGjbwHzpkGUCAFKSf0i4EdaUF2Oj4,18707
|
|
5
|
+
sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
|
|
6
|
+
sae_lens/config.py,sha256=9Lg4HkQvj1t9QZJdmC071lyJMc_iqNQknosT7zOYfwM,27278
|
|
7
|
+
sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
|
|
8
|
+
sae_lens/evals.py,sha256=kQyrzczKaVD9rHwfFa_DxL_gMXDxsoIVHmsFIPIU2bY,38696
|
|
9
|
+
sae_lens/llm_sae_training_runner.py,sha256=58XbDylw2fPOD7C-ZfSAjeNqJLXB05uHGTuiYVVbXXY,13354
|
|
10
|
+
sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
|
|
11
|
+
sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
|
+
sae_lens/loading/pretrained_sae_loaders.py,sha256=5XEU4uFFeGCePwqDwhlE7CrFGRSI0U9Cu-UQVa33Y1E,36432
|
|
13
|
+
sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
|
|
14
|
+
sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
|
|
15
|
+
sae_lens/pretrained_saes.yaml,sha256=nhHW1auhyi4GHYrjUnHQqbNVhI5cMJv-HThzbzU1xG0,574145
|
|
16
|
+
sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
|
|
17
|
+
sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
|
|
18
|
+
sae_lens/saes/gated_sae.py,sha256=0zd66bH04nsaGk3bxHk10hsZofa2GrFbMo15LOsuqgU,9233
|
|
19
|
+
sae_lens/saes/jumprelu_sae.py,sha256=iwmPQJ4XpIxzgosty680u8Zj7x1uVZhM75kPOT3obi0,12060
|
|
20
|
+
sae_lens/saes/sae.py,sha256=ZEXEXFVtrtFrzuOV3nyweTBleNCV4EDGh1ImaF32uqg,39618
|
|
21
|
+
sae_lens/saes/standard_sae.py,sha256=PfkGLsw_6La3PXHOQL0u7qQsaZsXCJqYCeCcRDj5n64,6274
|
|
22
|
+
sae_lens/saes/topk_sae.py,sha256=kmry1FE1H06OvCfn84V-j2JfWGKcU5b2urwAq_Oq5j4,9893
|
|
23
|
+
sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
|
|
24
|
+
sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
25
|
+
sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
|
|
26
|
+
sae_lens/training/activations_store.py,sha256=z8erbiB6ODbsqlu-bwEWbyj4XZvgsVgjCRBuQovqp2Q,32612
|
|
27
|
+
sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
|
|
28
|
+
sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
|
|
29
|
+
sae_lens/training/sae_trainer.py,sha256=9K0VudwSTJp9OlCVzaU_ngZ0WlYNrN6-ozTCCAxR9_k,15421
|
|
30
|
+
sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
|
|
31
|
+
sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
|
|
32
|
+
sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
|
|
33
|
+
sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
|
|
34
|
+
sae_lens-6.0.0rc5.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
|
|
35
|
+
sae_lens-6.0.0rc5.dist-info/METADATA,sha256=ZrBaBFeIuM-ZJ9r0HHKakxnx3tGv7Zf6l_Z2OIdBxIU,5326
|
|
36
|
+
sae_lens-6.0.0rc5.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
37
|
+
sae_lens-6.0.0rc5.dist-info/RECORD,,
|
|
@@ -1,101 +0,0 @@
|
|
|
1
|
-
from types import SimpleNamespace
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
import tqdm
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def weighted_average(points: torch.Tensor, weights: torch.Tensor):
|
|
8
|
-
weights = weights / weights.sum()
|
|
9
|
-
return (points * weights.view(-1, 1)).sum(dim=0)
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
@torch.no_grad()
|
|
13
|
-
def geometric_median_objective(
|
|
14
|
-
median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
|
|
15
|
-
) -> torch.Tensor:
|
|
16
|
-
norms = torch.linalg.norm(points - median.view(1, -1), dim=1) # type: ignore
|
|
17
|
-
|
|
18
|
-
return (norms * weights).sum()
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def compute_geometric_median(
|
|
22
|
-
points: torch.Tensor,
|
|
23
|
-
weights: torch.Tensor | None = None,
|
|
24
|
-
eps: float = 1e-6,
|
|
25
|
-
maxiter: int = 100,
|
|
26
|
-
ftol: float = 1e-20,
|
|
27
|
-
do_log: bool = False,
|
|
28
|
-
):
|
|
29
|
-
"""
|
|
30
|
-
:param points: ``torch.Tensor`` of shape ``(n, d)``
|
|
31
|
-
:param weights: Optional ``torch.Tensor`` of shape :math:``(n,)``.
|
|
32
|
-
:param eps: Smallest allowed value of denominator, to avoid divide by zero.
|
|
33
|
-
Equivalently, this is a smoothing parameter. Default 1e-6.
|
|
34
|
-
:param maxiter: Maximum number of Weiszfeld iterations. Default 100
|
|
35
|
-
:param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
|
|
36
|
-
:param do_log: If true will return a log of function values encountered through the course of the algorithm
|
|
37
|
-
:return: SimpleNamespace object with fields
|
|
38
|
-
- `median`: estimate of the geometric median, which is a ``torch.Tensor`` object of shape :math:``(d,)``
|
|
39
|
-
- `termination`: string explaining how the algorithm terminated.
|
|
40
|
-
- `logs`: function values encountered through the course of the algorithm in a list (None if do_log is false).
|
|
41
|
-
"""
|
|
42
|
-
with torch.no_grad():
|
|
43
|
-
if weights is None:
|
|
44
|
-
weights = torch.ones((points.shape[0],), device=points.device)
|
|
45
|
-
# initialize median estimate at mean
|
|
46
|
-
new_weights = weights
|
|
47
|
-
median = weighted_average(points, weights)
|
|
48
|
-
objective_value = geometric_median_objective(median, points, weights)
|
|
49
|
-
logs = [objective_value] if do_log else None
|
|
50
|
-
|
|
51
|
-
# Weiszfeld iterations
|
|
52
|
-
early_termination = False
|
|
53
|
-
pbar = tqdm.tqdm(range(maxiter))
|
|
54
|
-
for _ in pbar:
|
|
55
|
-
prev_obj_value = objective_value
|
|
56
|
-
|
|
57
|
-
norms = torch.linalg.norm(points - median.view(1, -1), dim=1) # type: ignore
|
|
58
|
-
new_weights = weights / torch.clamp(norms, min=eps)
|
|
59
|
-
median = weighted_average(points, new_weights)
|
|
60
|
-
objective_value = geometric_median_objective(median, points, weights)
|
|
61
|
-
|
|
62
|
-
if logs is not None:
|
|
63
|
-
logs.append(objective_value)
|
|
64
|
-
if abs(prev_obj_value - objective_value) <= ftol * objective_value:
|
|
65
|
-
early_termination = True
|
|
66
|
-
break
|
|
67
|
-
|
|
68
|
-
pbar.set_description(f"Objective value: {objective_value:.4f}")
|
|
69
|
-
|
|
70
|
-
median = weighted_average(points, new_weights) # allow autodiff to track it
|
|
71
|
-
return SimpleNamespace(
|
|
72
|
-
median=median,
|
|
73
|
-
new_weights=new_weights,
|
|
74
|
-
termination=(
|
|
75
|
-
"function value converged within tolerance"
|
|
76
|
-
if early_termination
|
|
77
|
-
else "maximum iterations reached"
|
|
78
|
-
),
|
|
79
|
-
logs=logs,
|
|
80
|
-
)
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
if __name__ == "__main__":
|
|
84
|
-
import time
|
|
85
|
-
|
|
86
|
-
TOLERANCE = 1e-2
|
|
87
|
-
|
|
88
|
-
dim1 = 10000
|
|
89
|
-
dim2 = 768
|
|
90
|
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
91
|
-
|
|
92
|
-
sample = (
|
|
93
|
-
torch.randn((dim1, dim2), device=device) * 100
|
|
94
|
-
) # seems to be the order of magnitude of the actual use case
|
|
95
|
-
weights = torch.randn((dim1,), device=device)
|
|
96
|
-
|
|
97
|
-
torch.tensor(weights, device=device)
|
|
98
|
-
|
|
99
|
-
tic = time.perf_counter()
|
|
100
|
-
new = compute_geometric_median(sample, weights=weights, maxiter=100)
|
|
101
|
-
print(f"new code takes {time.perf_counter()-tic} seconds!") # noqa: T201
|