sae-lens 6.15.0__py3-none-any.whl → 6.24.1__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.15.0"
+ __version__ = "6.24.1"
 
  import logging
 
@@ -15,6 +15,8 @@ from sae_lens.saes import (
  GatedTrainingSAEConfig,
  JumpReLUSAE,
  JumpReLUSAEConfig,
+ JumpReLUSkipTranscoder,
+ JumpReLUSkipTranscoderConfig,
  JumpReLUTrainingSAE,
  JumpReLUTrainingSAEConfig,
  JumpReLUTranscoder,
@@ -28,6 +30,8 @@ from sae_lens.saes import (
  StandardSAEConfig,
  StandardTrainingSAE,
  StandardTrainingSAEConfig,
+ TemporalSAE,
+ TemporalSAEConfig,
  TopKSAE,
  TopKSAEConfig,
  TopKTrainingSAE,
@@ -103,8 +107,12 @@ __all__ = [
  "SkipTranscoderConfig",
  "JumpReLUTranscoder",
  "JumpReLUTranscoderConfig",
+ "JumpReLUSkipTranscoder",
+ "JumpReLUSkipTranscoderConfig",
  "MatryoshkaBatchTopKTrainingSAE",
  "MatryoshkaBatchTopKTrainingSAEConfig",
+ "TemporalSAE",
+ "TemporalSAEConfig",
  ]
 
 
@@ -127,3 +135,7 @@ register_sae_training_class(
  register_sae_class("transcoder", Transcoder, TranscoderConfig)
  register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
  register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+ register_sae_class(
+     "jumprelu_skip_transcoder", JumpReLUSkipTranscoder, JumpReLUSkipTranscoderConfig
+ )
+ register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
@@ -3,7 +3,6 @@ from contextlib import contextmanager
  from typing import Any, Callable
 
  import torch
- from jaxtyping import Float
  from transformer_lens.ActivationCache import ActivationCache
  from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
  from transformer_lens.hook_points import HookPoint # Hooking utilities
@@ -11,8 +10,8 @@ from transformer_lens.HookedTransformer import HookedTransformer
 
  from sae_lens.saes.sae import SAE
 
- SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor
- LossPerToken = Float[torch.Tensor, "batch pos-1"]
+ SingleLoss = torch.Tensor # Type alias for a single element tensor
+ LossPerToken = torch.Tensor
  Loss = SingleLoss | LossPerToken
 
 
@@ -171,12 +170,7 @@ class HookedSAETransformer(HookedTransformer):
  reset_saes_end: bool = True,
  use_error_term: bool | None = None,
  **model_kwargs: Any,
- ) -> (
-     None
-     | Float[torch.Tensor, "batch pos d_vocab"]
-     | Loss
-     | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]
- ):
+ ) -> None | torch.Tensor | Loss | tuple[torch.Tensor, Loss]:
  """Wrapper around HookedTransformer forward pass.
 
  Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.
@@ -203,10 +197,7 @@ class HookedSAETransformer(HookedTransformer):
  remove_batch_dim: bool = False,
  **kwargs: Any,
  ) -> tuple[
-     None
-     | Float[torch.Tensor, "batch pos d_vocab"]
-     | Loss
-     | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
+     None | torch.Tensor | Loss | tuple[torch.Tensor, Loss],
      ActivationCache | dict[str, torch.Tensor],
  ]:
  """Wrapper around 'run_with_cache' in HookedTransformer.
@@ -9,8 +9,7 @@ import torch
  from datasets import Array2D, Dataset, Features, Sequence, Value
  from datasets.fingerprint import generate_fingerprint
  from huggingface_hub import HfApi
- from jaxtyping import Float, Int
- from tqdm import tqdm
+ from tqdm.auto import tqdm
  from transformer_lens.HookedTransformer import HookedRootModule
 
  from sae_lens import logger
@@ -318,8 +317,8 @@ class CacheActivationsRunner:
  def _create_shard(
      self,
      buffer: tuple[
-         Float[torch.Tensor, "(bs context_size) d_in"],
-         Int[torch.Tensor, "(bs context_size)"] | None,
+         torch.Tensor, # shape: (bs context_size) d_in
+         torch.Tensor | None, # shape: (bs context_size) or None
      ],
  ) -> Dataset:
  hook_names = [self.cfg.hook_name]
sae_lens/config.py CHANGED
@@ -18,6 +18,7 @@ from datasets import (
 
  from sae_lens import __version__, logger
  from sae_lens.constants import DTYPE_MAP
+ from sae_lens.registry import get_sae_training_class
  from sae_lens.saes.sae import TrainingSAEConfig
 
  if TYPE_CHECKING:
@@ -171,6 +172,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
  checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
  save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
+ resume_from_checkpoint (str | None): The path to the checkpoint to resume training from. (default is None).
  output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
  verbose (bool): Whether to print verbose output. (default is True)
  model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
@@ -261,6 +263,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  checkpoint_path: str | None = "checkpoints"
  save_final_checkpoint: bool = False
  output_path: str | None = "output"
+ resume_from_checkpoint: str | None = None
 
  # Misc
  verbose: bool = True
@@ -385,8 +388,11 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  return self.sae.to_dict()
 
  def to_dict(self) -> dict[str, Any]:
-     # Make a shallow copy of config's dictionary
-     d = dict(self.__dict__)
+     """
+     Convert the config to a dictionary.
+     """
+
+     d = asdict(self)
 
      d["logger"] = asdict(self.logger)
      d["sae"] = self.sae.to_dict()
@@ -396,6 +402,37 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
      d["act_store_device"] = str(self.act_store_device)
      return d
 
+ @classmethod
+ def from_dict(cls, cfg_dict: dict[str, Any]) -> "LanguageModelSAERunnerConfig[Any]":
+     """
+     Load a LanguageModelSAERunnerConfig from a dictionary given by `to_dict`.
+
+     Args:
+         cfg_dict (dict[str, Any]): The dictionary to load the config from.
+
+     Returns:
+         LanguageModelSAERunnerConfig: The loaded config.
+     """
+     if "sae" not in cfg_dict:
+         raise ValueError("sae field is required in the config dictionary")
+     if "architecture" not in cfg_dict["sae"]:
+         raise ValueError("architecture field is required in the sae dictionary")
+     if "logger" not in cfg_dict:
+         raise ValueError("logger field is required in the config dictionary")
+     sae_config_class = get_sae_training_class(cfg_dict["sae"]["architecture"])[1]
+     sae_cfg = sae_config_class.from_dict(cfg_dict["sae"])
+     logger_cfg = LoggingConfig(**cfg_dict["logger"])
+     updated_cfg_dict: dict[str, Any] = {
+         **cfg_dict,
+         "sae": sae_cfg,
+         "logger": logger_cfg,
+     }
+     output = cls(**updated_cfg_dict)
+     # the post_init always appends to checkpoint path, so we need to set it explicitly here.
+     if "checkpoint_path" in cfg_dict:
+         output.checkpoint_path = cfg_dict["checkpoint_path"]
+     return output
+
  def to_sae_trainer_config(self) -> "SAETrainerConfig":
      return SAETrainerConfig(
          n_checkpoints=self.n_checkpoints,
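The new from_dict classmethod is the inverse of to_dict: it rebuilds the nested SAE config via the architecture registry and the logger config from its dict form. A round-trip sketch, assuming cfg is an already-constructed LanguageModelSAERunnerConfig (construction details omitted):

# Round-trip sketch; `cfg` is assumed to be an existing LanguageModelSAERunnerConfig.
from sae_lens.config import LanguageModelSAERunnerConfig

cfg_dict = cfg.to_dict()  # now a deep asdict() copy with nested "sae" and "logger" dicts
restored = LanguageModelSAERunnerConfig.from_dict(cfg_dict)
# checkpoint_path is copied back verbatim, bypassing the unique suffix normally appended after init.
assert restored.checkpoint_path == cfg.checkpoint_path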
sae_lens/constants.py CHANGED
@@ -17,5 +17,6 @@ SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
  SAE_CFG_FILENAME = "cfg.json"
  RUNNER_CFG_FILENAME = "runner_cfg.json"
  SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+ TRAINER_STATE_FILENAME = "trainer_state.pt"
  ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
  ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
@@ -16,7 +16,6 @@ from typing_extensions import deprecated
  from sae_lens import logger
  from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
  from sae_lens.constants import (
-     ACTIVATIONS_STORE_STATE_FILENAME,
      RUNNER_CFG_FILENAME,
      SPARSITY_FILENAME,
  )
@@ -112,6 +111,7 @@ class LanguageModelSAETrainingRunner:
  override_dataset: HfDataset | None = None,
  override_model: HookedRootModule | None = None,
  override_sae: TrainingSAE[Any] | None = None,
+ resume_from_checkpoint: Path | str | None = None,
  ):
  if override_dataset is not None:
      logger.warning(
@@ -153,6 +153,7 @@ class LanguageModelSAETrainingRunner:
      )
  else:
      self.sae = override_sae
+
  self.sae.to(self.cfg.device)
 
  def run(self):
@@ -185,6 +186,12 @@ class LanguageModelSAETrainingRunner:
      cfg=self.cfg.to_sae_trainer_config(),
  )
 
+ if self.cfg.resume_from_checkpoint is not None:
+     logger.info(f"Resuming from checkpoint: {self.cfg.resume_from_checkpoint}")
+     trainer.load_trainer_state(self.cfg.resume_from_checkpoint)
+     self.sae.load_weights_from_checkpoint(self.cfg.resume_from_checkpoint)
+     self.activations_store.load_from_checkpoint(self.cfg.resume_from_checkpoint)
+
  self._compile_if_needed()
  sae = self.run_trainer_with_interruption_handling(trainer)
 
@@ -304,9 +311,7 @@ class LanguageModelSAETrainingRunner:
  if checkpoint_path is None:
      return
 
- self.activations_store.save(
-     str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
- )
+ self.activations_store.save_to_checkpoint(checkpoint_path)
 
  runner_config = self.cfg.to_dict()
  with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
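Together with the new resume_from_checkpoint config field and the TRAINER_STATE_FILENAME constant, the hunks above add a resume path that reloads trainer state, SAE weights, and the activations store state from a previously saved checkpoint. A sketch of that flow, assuming cfg is prepared exactly as for the original run, the checkpoint path is hypothetical, and the runner's top-level import path matches prior releases:

# Resume sketch; `cfg`, the checkpoint path, and the import path are assumptions.
from sae_lens import LanguageModelSAETrainingRunner

cfg.resume_from_checkpoint = "checkpoints/<run_id>/<step>"  # hypothetical path written by an earlier run

runner = LanguageModelSAETrainingRunner(cfg)
# run() now calls trainer.load_trainer_state(...), sae.load_weights_from_checkpoint(...)
# and activations_store.load_from_checkpoint(...) before training continues.
sae = runner.run()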