sae-lens 6.0.0rc3__py3-none-any.whl → 6.0.0rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/saes/sae.py CHANGED
@@ -1,5 +1,6 @@
1
1
  """Base classes for Sparse Autoencoders (SAEs)."""
2
2
 
3
+ import copy
3
4
  import json
4
5
  import warnings
5
6
  from abc import ABC, abstractmethod
@@ -59,23 +60,91 @@ T_SAE = TypeVar("T_SAE", bound="SAE") # type: ignore
59
60
  T_TRAINING_SAE = TypeVar("T_TRAINING_SAE", bound="TrainingSAE") # type: ignore
60
61
 
61
62
 
62
- @dataclass
63
63
  class SAEMetadata:
64
64
  """Core metadata about how this SAE should be used, if known."""
65
65
 
66
- model_name: str | None = None
67
- hook_name: str | None = None
68
- model_class_name: str | None = None
69
- hook_head_index: int | None = None
70
- model_from_pretrained_kwargs: dict[str, Any] | None = None
71
- prepend_bos: bool | None = None
72
- exclude_special_tokens: bool | list[int] | None = None
73
- neuronpedia_id: str | None = None
74
- context_size: int | None = None
75
- seqpos_slice: tuple[int | None, ...] | None = None
76
- dataset_path: str | None = None
77
- sae_lens_version: str = field(default_factory=lambda: __version__)
78
- sae_lens_training_version: str = field(default_factory=lambda: __version__)
66
+ def __init__(self, **kwargs: Any):
67
+ # Set default version fields with their current behavior
68
+ self.sae_lens_version = kwargs.pop("sae_lens_version", __version__)
69
+ self.sae_lens_training_version = kwargs.pop(
70
+ "sae_lens_training_version", __version__
71
+ )
72
+
73
+ # Set all other attributes dynamically
74
+ for key, value in kwargs.items():
75
+ setattr(self, key, value)
76
+
77
+ def __getattr__(self, name: str) -> None:
78
+ """Return None for any missing attribute (like defaultdict)"""
79
+ return
80
+
81
+ def __setattr__(self, name: str, value: Any) -> None:
82
+ """Allow setting any attribute"""
83
+ super().__setattr__(name, value)
84
+
85
+ def __getitem__(self, key: str) -> Any:
86
+ """Allow dictionary-style access: metadata['key'] - returns None for missing keys"""
87
+ return getattr(self, key)
88
+
89
+ def __setitem__(self, key: str, value: Any) -> None:
90
+ """Allow dictionary-style assignment: metadata['key'] = value"""
91
+ setattr(self, key, value)
92
+
93
+ def __contains__(self, key: str) -> bool:
94
+ """Allow 'in' operator: 'key' in metadata"""
95
+ # Only return True if the attribute was explicitly set (not just defaulting to None)
96
+ return key in self.__dict__
97
+
98
+ def get(self, key: str, default: Any = None) -> Any:
99
+ """Dictionary-style get with default"""
100
+ value = getattr(self, key)
101
+ # If the attribute wasn't explicitly set and we got None from __getattr__,
102
+ # use the provided default instead
103
+ if key not in self.__dict__ and value is None:
104
+ return default
105
+ return value
106
+
107
+ def keys(self):
108
+ """Return all explicitly set attribute names"""
109
+ return self.__dict__.keys()
110
+
111
+ def values(self):
112
+ """Return all explicitly set attribute values"""
113
+ return self.__dict__.values()
114
+
115
+ def items(self):
116
+ """Return all explicitly set attribute name-value pairs"""
117
+ return self.__dict__.items()
118
+
119
+ def to_dict(self) -> dict[str, Any]:
120
+ """Convert to dictionary for serialization"""
121
+ return self.__dict__.copy()
122
+
123
+ @classmethod
124
+ def from_dict(cls, data: dict[str, Any]) -> "SAEMetadata":
125
+ """Create from dictionary"""
126
+ return cls(**data)
127
+
128
+ def __repr__(self) -> str:
129
+ return f"SAEMetadata({self.__dict__})"
130
+
131
+ def __eq__(self, other: object) -> bool:
132
+ if not isinstance(other, SAEMetadata):
133
+ return False
134
+ return self.__dict__ == other.__dict__
135
+
136
+ def __deepcopy__(self, memo: dict[int, Any]) -> "SAEMetadata":
137
+ """Support for deep copying"""
138
+
139
+ return SAEMetadata(**copy.deepcopy(self.__dict__, memo))
140
+
141
+ def __getstate__(self) -> dict[str, Any]:
142
+ """Support for pickling"""
143
+ return self.__dict__
144
+
145
+ def __setstate__(self, state: dict[str, Any]) -> None:
146
+ """Support for unpickling"""
147
+ self.__dict__.update(state)
79
148
 
80
149
 
81
150
  @dataclass
@@ -99,7 +168,7 @@ class SAEConfig(ABC):
99
168
 
100
169
  def to_dict(self) -> dict[str, Any]:
101
170
  res = {field.name: getattr(self, field.name) for field in fields(self)}
102
- res["metadata"] = asdict(self.metadata)
171
+ res["metadata"] = self.metadata.to_dict()
103
172
  res["architecture"] = self.architecture()
104
173
  return res
105
174
 
@@ -124,7 +193,7 @@ class SAEConfig(ABC):
124
193
  "layer_norm",
125
194
  ]:
126
195
  raise ValueError(
127
- f"normalize_activations must be none, expected_average_only_in, constant_norm_rescale, or layer_norm. Got {self.normalize_activations}"
196
+ f"normalize_activations must be none, expected_average_only_in, layer_norm, or constant_norm_rescale. Got {self.normalize_activations}"
128
197
  )
129
198
 
130
199
 
@@ -238,9 +307,8 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
238
307
 
239
308
  self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
240
309
  self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
241
-
242
310
  elif self.cfg.normalize_activations == "layer_norm":
243
-
311
+ # we need to scale the norm of the input and store the scaling factor
244
312
  def run_time_activation_ln_in(
245
313
  x: torch.Tensor, eps: float = 1e-5
246
314
  ) -> torch.Tensor:
@@ -522,7 +590,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
522
590
  device: str = "cpu",
523
591
  force_download: bool = False,
524
592
  converter: PretrainedSaeHuggingfaceLoader | None = None,
525
- ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
593
+ ) -> T_SAE:
526
594
  """
527
595
  Load a pretrained SAE from the Hugging Face model hub.
528
596
 
@@ -530,7 +598,28 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
530
598
  release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
531
599
  id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
532
600
  device: The device to load the SAE on.
533
- return_sparsity_if_present: If True, will return the log sparsity tensor if it is present in the model directory in the Hugging Face model hub.
601
+ """
602
+ return cls.from_pretrained_with_cfg_and_sparsity(
603
+ release, sae_id, device, force_download, converter=converter
604
+ )[0]
605
+
606
+ @classmethod
607
+ def from_pretrained_with_cfg_and_sparsity(
608
+ cls: Type[T_SAE],
609
+ release: str,
610
+ sae_id: str,
611
+ device: str = "cpu",
612
+ force_download: bool = False,
613
+ converter: PretrainedSaeHuggingfaceLoader | None = None,
614
+ ) -> tuple[T_SAE, dict[str, Any], torch.Tensor | None]:
615
+ """
616
+ Load a pretrained SAE from the Hugging Face model hub, along with its config dict and sparsity, if present.
617
+ In SAELens <= 5.x.x, this was called SAE.from_pretrained().
618
+
619
+ Args:
620
+ release: The release name. This will be mapped to a huggingface repo id based on the pretrained_saes.yaml file.
621
+ id: The id of the SAE to load. This will be mapped to a path in the huggingface repo.
622
+ device: The device to load the SAE on.
534
623
  """
535
624
 
536
625
  # get sae directory
@@ -643,11 +732,67 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
643
732
  ) -> type[SAEConfig]:
644
733
  return SAEConfig
645
734
 
735
+ ### Methods to support deprecated usage of SAE.from_pretrained() ###
736
+
737
+ def __getitem__(self, index: int) -> Any:
738
+ """
739
+ Support indexing for backward compatibility with tuple unpacking.
740
+ DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
741
+ Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
742
+ """
743
+ warnings.warn(
744
+ "Indexing SAE objects is deprecated. SAE.from_pretrained() now returns "
745
+ "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
746
+ "to get the config dict and sparsity as well.",
747
+ DeprecationWarning,
748
+ stacklevel=2,
749
+ )
750
+
751
+ if index == 0:
752
+ return self
753
+ if index == 1:
754
+ return self.cfg.to_dict()
755
+ if index == 2:
756
+ return None
757
+ raise IndexError(f"SAE tuple index {index} out of range")
758
+
759
+ def __iter__(self):
760
+ """
761
+ Support unpacking for backward compatibility with tuple unpacking.
762
+ DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
763
+ Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
764
+ """
765
+ warnings.warn(
766
+ "Unpacking SAE objects is deprecated. SAE.from_pretrained() now returns "
767
+ "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
768
+ "to get the config dict and sparsity as well.",
769
+ DeprecationWarning,
770
+ stacklevel=2,
771
+ )
772
+
773
+ yield self
774
+ yield self.cfg.to_dict()
775
+ yield None
776
+
777
+ def __len__(self) -> int:
778
+ """
779
+ Support len() for backward compatibility with tuple unpacking.
780
+ DEPRECATED: SAE.from_pretrained() no longer returns a tuple.
781
+ Use SAE.from_pretrained_with_cfg_and_sparsity() instead.
782
+ """
783
+ warnings.warn(
784
+ "Getting length of SAE objects is deprecated. SAE.from_pretrained() now returns "
785
+ "only the SAE object. Use SAE.from_pretrained_with_cfg_and_sparsity() "
786
+ "to get the config dict and sparsity as well.",
787
+ DeprecationWarning,
788
+ stacklevel=2,
789
+ )
790
+
791
+ return 3
792
+
646
793
 
647
794
  @dataclass(kw_only=True)
648
795
  class TrainingSAEConfig(SAEConfig, ABC):
649
- noise_scale: float = 0.0
650
- mse_loss_normalization: str | None = None
651
796
  # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
652
797
  # 0.1 corresponds to the "heuristic" initialization, use None to disable
653
798
  decoder_init_norm: float | None = 0.1
@@ -680,9 +825,6 @@ class TrainingSAEConfig(SAEConfig, ABC):
680
825
  def from_dict(
681
826
  cls: type[T_TRAINING_SAE_CONFIG], config_dict: dict[str, Any]
682
827
  ) -> T_TRAINING_SAE_CONFIG:
683
- # remove any keys that are not in the dataclass
684
- # since we sometimes enhance the config with the whole LM runner config
685
- valid_config_dict = filter_valid_dataclass_fields(config_dict, cls)
686
828
  cfg_class = cls
687
829
  if "architecture" in config_dict:
688
830
  cfg_class = get_sae_training_class(config_dict["architecture"])[1]
@@ -690,6 +832,9 @@ class TrainingSAEConfig(SAEConfig, ABC):
690
832
  raise ValueError(
691
833
  f"SAE config class {cls} does not match dict config class {type(cfg_class)}"
692
834
  )
835
+ # remove any keys that are not in the dataclass
836
+ # since we sometimes enhance the config with the whole LM runner config
837
+ valid_config_dict = filter_valid_dataclass_fields(config_dict, cfg_class)
693
838
  if "metadata" in config_dict:
694
839
  valid_config_dict["metadata"] = SAEMetadata(**config_dict["metadata"])
695
840
  return cfg_class(**valid_config_dict)
@@ -698,6 +843,7 @@ class TrainingSAEConfig(SAEConfig, ABC):
698
843
  return {
699
844
  **super().to_dict(),
700
845
  **asdict(self),
846
+ "metadata": self.metadata.to_dict(),
701
847
  "architecture": self.architecture(),
702
848
  }
703
849
 
@@ -708,12 +854,14 @@ class TrainingSAEConfig(SAEConfig, ABC):
708
854
  Creates a dictionary containing attributes corresponding to the fields
709
855
  defined in the base SAEConfig class.
710
856
  """
711
- base_config_field_names = {f.name for f in fields(SAEConfig)}
857
+ base_sae_cfg_class = get_sae_class(self.architecture())[1]
858
+ base_config_field_names = {f.name for f in fields(base_sae_cfg_class)}
712
859
  result_dict = {
713
860
  field_name: getattr(self, field_name)
714
861
  for field_name in base_config_field_names
715
862
  }
716
863
  result_dict["architecture"] = self.architecture()
864
+ result_dict["metadata"] = self.metadata.to_dict()
717
865
  return result_dict
718
866
 
719
867
 
@@ -726,7 +874,7 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
726
874
  # Turn off hook_z reshaping for training mode - the activation store
727
875
  # is expected to handle reshaping before passing data to the SAE
728
876
  self.turn_off_forward_pass_hook_z_reshaping()
729
- self.mse_loss_fn = self._get_mse_loss_fn()
877
+ self.mse_loss_fn = mse_loss
730
878
 
731
879
  @abstractmethod
732
880
  def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: ...
@@ -861,27 +1009,6 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
861
1009
  """
862
1010
  return self.process_state_dict_for_saving(state_dict)
863
1011
 
864
- def _get_mse_loss_fn(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
865
- """Get the MSE loss function based on config."""
866
-
867
- def standard_mse_loss_fn(
868
- preds: torch.Tensor, target: torch.Tensor
869
- ) -> torch.Tensor:
870
- return torch.nn.functional.mse_loss(preds, target, reduction="none")
871
-
872
- def batch_norm_mse_loss_fn(
873
- preds: torch.Tensor, target: torch.Tensor
874
- ) -> torch.Tensor:
875
- target_centered = target - target.mean(dim=0, keepdim=True)
876
- normalization = target_centered.norm(dim=-1, keepdim=True)
877
- return torch.nn.functional.mse_loss(preds, target, reduction="none") / (
878
- normalization + 1e-6
879
- )
880
-
881
- if self.cfg.mse_loss_normalization == "dense_batch":
882
- return batch_norm_mse_loss_fn
883
- return standard_mse_loss_fn
884
-
885
1012
  @torch.no_grad()
886
1013
  def remove_gradient_parallel_to_decoder_directions(self) -> None:
887
1014
  """Remove gradient components parallel to decoder directions."""
@@ -943,3 +1070,7 @@ def _disable_hooks(sae: SAE[Any]):
943
1070
  finally:
944
1071
  for hook_name, hook in sae.hook_dict.items():
945
1072
  setattr(sae, hook_name, hook)
1073
+
1074
+
1075
+ def mse_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
1076
+ return torch.nn.functional.mse_loss(preds, target, reduction="none")
@@ -67,7 +67,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
67
67
  sae_in = self.process_sae_in(x)
68
68
  # Compute the pre-activation values
69
69
  hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
70
- # Apply the activation function (e.g., ReLU, tanh-relu, depending on config)
70
+ # Apply the activation function (e.g., ReLU, depending on config)
71
71
  return self.hook_sae_acts_post(self.activation_fn(hidden_pre))
72
72
 
73
73
  def decode(
@@ -81,7 +81,7 @@ class StandardSAE(SAE[StandardSAEConfig]):
81
81
  sae_out_pre = feature_acts @ self.W_dec + self.b_dec
82
82
  # 2) hook reconstruction
83
83
  sae_out_pre = self.hook_sae_recons(sae_out_pre)
84
- # 4) optional out-normalization (e.g. constant_norm_rescale or layer_norm)
84
+ # 4) optional out-normalization (e.g. constant_norm_rescale)
85
85
  sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
86
86
  # 5) if hook_z is enabled, rearrange back to (..., n_heads, d_head).
87
87
  return self.reshape_fn_out(sae_out_pre, self.d_head)
@@ -136,16 +136,9 @@ class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
136
136
  sae_in = self.process_sae_in(x)
137
137
  # Compute the pre-activation (and allow for a hook if desired)
138
138
  hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) # type: ignore
139
- # Add noise during training for robustness (scaled by noise_scale from the configuration)
140
- if self.training and self.cfg.noise_scale > 0:
141
- hidden_pre_noised = (
142
- hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
143
- )
144
- else:
145
- hidden_pre_noised = hidden_pre
146
139
  # Apply the activation function (and any post-activation hook)
147
- feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
148
- return feature_acts, hidden_pre_noised
140
+ feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
141
+ return feature_acts, hidden_pre
149
142
 
150
143
  def calculate_aux_loss(
151
144
  self,
sae_lens/saes/topk_sae.py CHANGED
@@ -91,8 +91,7 @@ class TopKSAE(SAE[TopKSAEConfig]):
91
91
  ) -> Float[torch.Tensor, "... d_sae"]:
92
92
  """
93
93
  Converts input x into feature activations.
94
- Uses topk activation from the config (cfg.activation_fn == "topk")
95
- under the hood.
94
+ Uses topk activation under the hood.
96
95
  """
97
96
  sae_in = self.process_sae_in(x)
98
97
  hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
@@ -116,6 +115,13 @@ class TopKSAE(SAE[TopKSAEConfig]):
116
115
  def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
117
116
  return TopK(self.cfg.k)
118
117
 
118
+ @override
119
+ @torch.no_grad()
120
+ def fold_W_dec_norm(self) -> None:
121
+ raise NotImplementedError(
122
+ "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
123
+ )
124
+
119
125
 
120
126
  @dataclass
121
127
  class TopKTrainingSAEConfig(TrainingSAEConfig):
@@ -156,18 +162,11 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
156
162
  sae_in = self.process_sae_in(x)
157
163
  hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
158
164
 
159
- # Inject noise if training
160
- if self.training and self.cfg.noise_scale > 0:
161
- hidden_pre_noised = (
162
- hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
163
- )
164
- else:
165
- hidden_pre_noised = hidden_pre
166
-
167
165
  # Apply the TopK activation function (already set in self.activation_fn if config is "topk")
168
- feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
169
- return feature_acts, hidden_pre_noised
166
+ feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre))
167
+ return feature_acts, hidden_pre
170
168
 
169
+ @override
171
170
  def calculate_aux_loss(
172
171
  self,
173
172
  step_input: TrainStepInput,
@@ -184,6 +183,13 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
184
183
  )
185
184
  return {"auxiliary_reconstruction_loss": topk_loss}
186
185
 
186
+ @override
187
+ @torch.no_grad()
188
+ def fold_W_dec_norm(self) -> None:
189
+ raise NotImplementedError(
190
+ "Folding W_dec_norm is not safe for TopKSAEs, as this may change the topk activations"
191
+ )
192
+
187
193
  @override
188
194
  def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
189
195
  return TopK(self.cfg.k)
@@ -3,7 +3,7 @@ from dataclasses import dataclass
3
3
  from statistics import mean
4
4
 
5
5
  import torch
6
- from tqdm import tqdm
6
+ from tqdm.auto import tqdm
7
7
 
8
8
  from sae_lens.training.types import DataProvider
9
9
 
@@ -161,8 +161,6 @@ class ActivationsStore:
161
161
  ) -> ActivationsStore:
162
162
  if sae.cfg.metadata.hook_name is None:
163
163
  raise ValueError("hook_name is required")
164
- if sae.cfg.metadata.hook_head_index is None:
165
- raise ValueError("hook_head_index is required")
166
164
  if sae.cfg.metadata.context_size is None:
167
165
  raise ValueError("context_size is required")
168
166
  if sae.cfg.metadata.prepend_bos is None:
@@ -430,7 +428,7 @@ class ActivationsStore:
430
428
  ):
431
429
  # temporalily set estimated_norm_scaling_factor to 1.0 so the dataloader works
432
430
  self.estimated_norm_scaling_factor = 1.0
433
- acts = self.next_batch()[0]
431
+ acts = self.next_batch()[:, 0]
434
432
  self.estimated_norm_scaling_factor = None
435
433
  norms_per_batch.append(acts.norm(dim=-1).mean().item())
436
434
  mean_norm = np.mean(norms_per_batch)
@@ -7,7 +7,7 @@ import torch
7
7
  import wandb
8
8
  from safetensors.torch import save_file
9
9
  from torch.optim import Adam
10
- from tqdm import tqdm
10
+ from tqdm.auto import tqdm
11
11
 
12
12
  from sae_lens import __version__
13
13
  from sae_lens.config import SAETrainerConfig
@@ -161,6 +161,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
161
161
  return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()
162
162
 
163
163
  def fit(self) -> T_TRAINING_SAE:
164
+ self.sae.to(self.cfg.device)
164
165
  pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE")
165
166
 
166
167
  if self.sae.cfg.normalize_activations == "expected_average_only_in":
@@ -194,10 +195,11 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
194
195
  )
195
196
  self.activation_scaler.scaling_factor = None
196
197
 
197
- # save final sae group to checkpoints folder
198
+ # save final inference sae group to checkpoints folder
198
199
  self.save_checkpoint(
199
200
  checkpoint_name=f"final_{self.n_training_samples}",
200
201
  wandb_aliases=["final_model"],
202
+ save_inference_model=True,
201
203
  )
202
204
 
203
205
  pbar.close()
@@ -207,11 +209,17 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
207
209
  self,
208
210
  checkpoint_name: str,
209
211
  wandb_aliases: list[str] | None = None,
212
+ save_inference_model: bool = False,
210
213
  ) -> None:
211
214
  checkpoint_path = Path(self.cfg.checkpoint_path) / checkpoint_name
212
215
  checkpoint_path.mkdir(exist_ok=True, parents=True)
213
216
 
214
- weights_path, cfg_path = self.sae.save_model(str(checkpoint_path))
217
+ save_fn = (
218
+ self.sae.save_inference_model
219
+ if save_inference_model
220
+ else self.sae.save_model
221
+ )
222
+ weights_path, cfg_path = save_fn(str(checkpoint_path))
215
223
 
216
224
  sparsity_path = checkpoint_path / SPARSITY_FILENAME
217
225
  save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
@@ -88,7 +88,7 @@ def _create_default_readme(repo_id: str, sae_ids: Iterable[str]) -> str:
88
88
  ```python
89
89
  from sae_lens import SAE
90
90
 
91
- sae, cfg_dict, sparsity = SAE.from_pretrained("{repo_id}", "<sae_id>")
91
+ sae = SAE.from_pretrained("{repo_id}", "<sae_id>")
92
92
  ```
93
93
  """
94
94
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sae-lens
3
- Version: 6.0.0rc3
3
+ Version: 6.0.0rc5
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -80,7 +80,7 @@ Pre-trained SAEs for various models can be imported via SAE Lens. See this [page
80
80
 
81
81
  ## Join the Slack!
82
82
 
83
- Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-2o756ku1c-_yKBeUQMVfS_p_qcK6QLeA) for support!
83
+ Feel free to join the [Open Source Mechanistic Interpretability Slack](https://join.slack.com/t/opensourcemechanistic/shared_invite/zt-375zalm04-GFd5tdBU1yLKlu_T_JSqZQ) for support!
84
84
 
85
85
  ## Citation
86
86
 
@@ -0,0 +1,37 @@
1
+ sae_lens/__init__.py,sha256=hiHDLT9_1V7iVulw5hwqDqDj2HVxUR9I88xOfYx6X94,2861
2
+ sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
4
+ sae_lens/analysis/neuronpedia_integration.py,sha256=MrENqc81Mc2SMbxGjbwHzpkGUCAFKSf0i4EdaUF2Oj4,18707
5
+ sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
6
+ sae_lens/config.py,sha256=9Lg4HkQvj1t9QZJdmC071lyJMc_iqNQknosT7zOYfwM,27278
7
+ sae_lens/constants.py,sha256=CSjmiZ-bhjQeVLyRvWxAjBokCgkfM8mnvd7-vxLIWTY,639
8
+ sae_lens/evals.py,sha256=kQyrzczKaVD9rHwfFa_DxL_gMXDxsoIVHmsFIPIU2bY,38696
9
+ sae_lens/llm_sae_training_runner.py,sha256=58XbDylw2fPOD7C-ZfSAjeNqJLXB05uHGTuiYVVbXXY,13354
10
+ sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
11
+ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
+ sae_lens/loading/pretrained_sae_loaders.py,sha256=5XEU4uFFeGCePwqDwhlE7CrFGRSI0U9Cu-UQVa33Y1E,36432
13
+ sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
14
+ sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
15
+ sae_lens/pretrained_saes.yaml,sha256=nhHW1auhyi4GHYrjUnHQqbNVhI5cMJv-HThzbzU1xG0,574145
16
+ sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
+ sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
18
+ sae_lens/saes/gated_sae.py,sha256=0zd66bH04nsaGk3bxHk10hsZofa2GrFbMo15LOsuqgU,9233
19
+ sae_lens/saes/jumprelu_sae.py,sha256=iwmPQJ4XpIxzgosty680u8Zj7x1uVZhM75kPOT3obi0,12060
20
+ sae_lens/saes/sae.py,sha256=ZEXEXFVtrtFrzuOV3nyweTBleNCV4EDGh1ImaF32uqg,39618
21
+ sae_lens/saes/standard_sae.py,sha256=PfkGLsw_6La3PXHOQL0u7qQsaZsXCJqYCeCcRDj5n64,6274
22
+ sae_lens/saes/topk_sae.py,sha256=kmry1FE1H06OvCfn84V-j2JfWGKcU5b2urwAq_Oq5j4,9893
23
+ sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
24
+ sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
+ sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
26
+ sae_lens/training/activations_store.py,sha256=z8erbiB6ODbsqlu-bwEWbyj4XZvgsVgjCRBuQovqp2Q,32612
27
+ sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
28
+ sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
29
+ sae_lens/training/sae_trainer.py,sha256=9K0VudwSTJp9OlCVzaU_ngZ0WlYNrN6-ozTCCAxR9_k,15421
30
+ sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
31
+ sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
32
+ sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
33
+ sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
34
+ sae_lens-6.0.0rc5.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
35
+ sae_lens-6.0.0rc5.dist-info/METADATA,sha256=ZrBaBFeIuM-ZJ9r0HHKakxnx3tGv7Zf6l_Z2OIdBxIU,5326
36
+ sae_lens-6.0.0rc5.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
37
+ sae_lens-6.0.0rc5.dist-info/RECORD,,
@@ -1,101 +0,0 @@
1
- from types import SimpleNamespace
2
-
3
- import torch
4
- import tqdm
5
-
6
-
7
- def weighted_average(points: torch.Tensor, weights: torch.Tensor):
8
- weights = weights / weights.sum()
9
- return (points * weights.view(-1, 1)).sum(dim=0)
10
-
11
-
12
- @torch.no_grad()
13
- def geometric_median_objective(
14
- median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
15
- ) -> torch.Tensor:
16
- norms = torch.linalg.norm(points - median.view(1, -1), dim=1) # type: ignore
17
-
18
- return (norms * weights).sum()
19
-
20
-
21
- def compute_geometric_median(
22
- points: torch.Tensor,
23
- weights: torch.Tensor | None = None,
24
- eps: float = 1e-6,
25
- maxiter: int = 100,
26
- ftol: float = 1e-20,
27
- do_log: bool = False,
28
- ):
29
- """
30
- :param points: ``torch.Tensor`` of shape ``(n, d)``
31
- :param weights: Optional ``torch.Tensor`` of shape :math:``(n,)``.
32
- :param eps: Smallest allowed value of denominator, to avoid divide by zero.
33
- Equivalently, this is a smoothing parameter. Default 1e-6.
34
- :param maxiter: Maximum number of Weiszfeld iterations. Default 100
35
- :param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
36
- :param do_log: If true will return a log of function values encountered through the course of the algorithm
37
- :return: SimpleNamespace object with fields
38
- - `median`: estimate of the geometric median, which is a ``torch.Tensor`` object of shape :math:``(d,)``
39
- - `termination`: string explaining how the algorithm terminated.
40
- - `logs`: function values encountered through the course of the algorithm in a list (None if do_log is false).
41
- """
42
- with torch.no_grad():
43
- if weights is None:
44
- weights = torch.ones((points.shape[0],), device=points.device)
45
- # initialize median estimate at mean
46
- new_weights = weights
47
- median = weighted_average(points, weights)
48
- objective_value = geometric_median_objective(median, points, weights)
49
- logs = [objective_value] if do_log else None
50
-
51
- # Weiszfeld iterations
52
- early_termination = False
53
- pbar = tqdm.tqdm(range(maxiter))
54
- for _ in pbar:
55
- prev_obj_value = objective_value
56
-
57
- norms = torch.linalg.norm(points - median.view(1, -1), dim=1) # type: ignore
58
- new_weights = weights / torch.clamp(norms, min=eps)
59
- median = weighted_average(points, new_weights)
60
- objective_value = geometric_median_objective(median, points, weights)
61
-
62
- if logs is not None:
63
- logs.append(objective_value)
64
- if abs(prev_obj_value - objective_value) <= ftol * objective_value:
65
- early_termination = True
66
- break
67
-
68
- pbar.set_description(f"Objective value: {objective_value:.4f}")
69
-
70
- median = weighted_average(points, new_weights) # allow autodiff to track it
71
- return SimpleNamespace(
72
- median=median,
73
- new_weights=new_weights,
74
- termination=(
75
- "function value converged within tolerance"
76
- if early_termination
77
- else "maximum iterations reached"
78
- ),
79
- logs=logs,
80
- )
81
-
82
-
83
- if __name__ == "__main__":
84
- import time
85
-
86
- TOLERANCE = 1e-2
87
-
88
- dim1 = 10000
89
- dim2 = 768
90
- device = "cuda" if torch.cuda.is_available() else "cpu"
91
-
92
- sample = (
93
- torch.randn((dim1, dim2), device=device) * 100
94
- ) # seems to be the order of magnitude of the actual use case
95
- weights = torch.randn((dim1,), device=device)
96
-
97
- torch.tensor(weights, device=device)
98
-
99
- tic = time.perf_counter()
100
- new = compute_geometric_median(sample, weights=weights, maxiter=100)
101
- print(f"new code takes {time.perf_counter()-tic} seconds!") # noqa: T201