sae-lens 6.16.3-py3-none-any.whl → 6.21.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of sae-lens has been flagged as potentially problematic.

sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.16.3"
+ __version__ = "6.21.0"

  import logging

@@ -28,6 +28,8 @@ from sae_lens.saes import (
  StandardSAEConfig,
  StandardTrainingSAE,
  StandardTrainingSAEConfig,
+ TemporalSAE,
+ TemporalSAEConfig,
  TopKSAE,
  TopKSAEConfig,
  TopKTrainingSAE,
@@ -105,6 +107,8 @@ __all__ = [
  "JumpReLUTranscoderConfig",
  "MatryoshkaBatchTopKTrainingSAE",
  "MatryoshkaBatchTopKTrainingSAEConfig",
+ "TemporalSAE",
+ "TemporalSAEConfig",
  ]


@@ -127,3 +131,4 @@ register_sae_training_class(
  register_sae_class("transcoder", Transcoder, TranscoderConfig)
  register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
  register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+ register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
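The two new exports and the `register_sae_class("temporal", ...)` call put the temporal architecture in the same registry as the other SAE variants. A minimal sketch of what that enables; the `get_sae_class` helper and its tuple return value are assumptions, mirrored from `register_sae_class` and the `get_sae_training_class` usage elsewhere in this release:

# Sketch only: look up the newly registered architecture by name.
from sae_lens import TemporalSAE, TemporalSAEConfig
from sae_lens.registry import get_sae_class  # assumed counterpart of register_sae_class

sae_cls, cfg_cls = get_sae_class("temporal")
assert sae_cls is TemporalSAE
assert cfg_cls is TemporalSAEConfig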
@@ -10,7 +10,7 @@ from datasets import Array2D, Dataset, Features, Sequence, Value
  from datasets.fingerprint import generate_fingerprint
  from huggingface_hub import HfApi
  from jaxtyping import Float, Int
- from tqdm import tqdm
+ from tqdm.auto import tqdm
  from transformer_lens.HookedTransformer import HookedRootModule

  from sae_lens import logger
sae_lens/config.py CHANGED
@@ -18,6 +18,7 @@ from datasets import (

  from sae_lens import __version__, logger
  from sae_lens.constants import DTYPE_MAP
+ from sae_lens.registry import get_sae_training_class
  from sae_lens.saes.sae import TrainingSAEConfig

  if TYPE_CHECKING:
@@ -171,6 +172,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
  checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
  save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
+ resume_from_checkpoint (str | None): The path to the checkpoint to resume training from. (default is None).
  output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
  verbose (bool): Whether to print verbose output. (default is True)
  model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
@@ -261,6 +263,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  checkpoint_path: str | None = "checkpoints"
  save_final_checkpoint: bool = False
  output_path: str | None = "output"
+ resume_from_checkpoint: str | None = None

  # Misc
  verbose: bool = True
@@ -385,8 +388,11 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  return self.sae.to_dict()

  def to_dict(self) -> dict[str, Any]:
- # Make a shallow copy of config's dictionary
- d = dict(self.__dict__)
+ """
+ Convert the config to a dictionary.
+ """
+
+ d = asdict(self)

  d["logger"] = asdict(self.logger)
  d["sae"] = self.sae.to_dict()
@@ -396,6 +402,37 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  d["act_store_device"] = str(self.act_store_device)
  return d

+ @classmethod
+ def from_dict(cls, cfg_dict: dict[str, Any]) -> "LanguageModelSAERunnerConfig[Any]":
+ """
+ Load a LanguageModelSAERunnerConfig from a dictionary given by `to_dict`.
+
+ Args:
+ cfg_dict (dict[str, Any]): The dictionary to load the config from.
+
+ Returns:
+ LanguageModelSAERunnerConfig: The loaded config.
+ """
+ if "sae" not in cfg_dict:
+ raise ValueError("sae field is required in the config dictionary")
+ if "architecture" not in cfg_dict["sae"]:
+ raise ValueError("architecture field is required in the sae dictionary")
+ if "logger" not in cfg_dict:
+ raise ValueError("logger field is required in the config dictionary")
+ sae_config_class = get_sae_training_class(cfg_dict["sae"]["architecture"])[1]
+ sae_cfg = sae_config_class.from_dict(cfg_dict["sae"])
+ logger_cfg = LoggingConfig(**cfg_dict["logger"])
+ updated_cfg_dict: dict[str, Any] = {
+ **cfg_dict,
+ "sae": sae_cfg,
+ "logger": logger_cfg,
+ }
+ output = cls(**updated_cfg_dict)
+ # the post_init always appends to checkpoint path, so we need to set it explicitly here.
+ if "checkpoint_path" in cfg_dict:
+ output.checkpoint_path = cfg_dict["checkpoint_path"]
+ return output
+
  def to_sae_trainer_config(self) -> "SAETrainerConfig":
  return SAETrainerConfig(
  n_checkpoints=self.n_checkpoints,
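Together, `to_dict` (now based on `asdict`) and the new `from_dict` let a runner config round-trip through the JSON written at each checkpoint, which is what checkpoint resumption builds on. A rough sketch, assuming a standard training SAE config; the specific hyperparameter values and the `model_name`/`hook_name` arguments are illustrative, not taken from this diff:

# Sketch: serialize a runner config and rebuild it from the dict.
from sae_lens import LanguageModelSAERunnerConfig, StandardTrainingSAEConfig

cfg = LanguageModelSAERunnerConfig(
    sae=StandardTrainingSAEConfig(d_in=768, d_sae=16384),  # illustrative sizes
    model_name="gpt2",
    hook_name="blocks.6.hook_resid_pre",
)
restored = LanguageModelSAERunnerConfig.from_dict(cfg.to_dict())

# from_dict restores checkpoint_path verbatim because __post_init__ would
# otherwise append a fresh unique id to it.
assert restored.checkpoint_path == cfg.checkpoint_path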
sae_lens/constants.py CHANGED
@@ -17,5 +17,6 @@ SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
  SAE_CFG_FILENAME = "cfg.json"
  RUNNER_CFG_FILENAME = "runner_cfg.json"
  SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+ TRAINER_STATE_FILENAME = "trainer_state.pt"
  ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
  ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
@@ -16,7 +16,6 @@ from typing_extensions import deprecated
  from sae_lens import logger
  from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
  from sae_lens.constants import (
- ACTIVATIONS_STORE_STATE_FILENAME,
  RUNNER_CFG_FILENAME,
  SPARSITY_FILENAME,
  )
@@ -112,6 +111,7 @@ class LanguageModelSAETrainingRunner:
  override_dataset: HfDataset | None = None,
  override_model: HookedRootModule | None = None,
  override_sae: TrainingSAE[Any] | None = None,
+ resume_from_checkpoint: Path | str | None = None,
  ):
  if override_dataset is not None:
  logger.warning(
@@ -153,6 +153,7 @@ class LanguageModelSAETrainingRunner:
  )
  else:
  self.sae = override_sae
+
  self.sae.to(self.cfg.device)

  def run(self):
@@ -185,6 +186,12 @@ class LanguageModelSAETrainingRunner:
  cfg=self.cfg.to_sae_trainer_config(),
  )

+ if self.cfg.resume_from_checkpoint is not None:
+ logger.info(f"Resuming from checkpoint: {self.cfg.resume_from_checkpoint}")
+ trainer.load_trainer_state(self.cfg.resume_from_checkpoint)
+ self.sae.load_weights_from_checkpoint(self.cfg.resume_from_checkpoint)
+ self.activations_store.load_from_checkpoint(self.cfg.resume_from_checkpoint)
+
  self._compile_if_needed()
  sae = self.run_trainer_with_interruption_handling(trainer)

@@ -304,9 +311,7 @@ class LanguageModelSAETrainingRunner:
  if checkpoint_path is None:
  return

- self.activations_store.save(
- str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
- )
+ self.activations_store.save_to_checkpoint(checkpoint_path)

  runner_config = self.cfg.to_dict()
  with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
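The runner now threads `resume_from_checkpoint` through the trainer state, the SAE weights, and the activations store. A hedged sketch of a resumed run; the path and sizes are placeholders, and the top-level `LanguageModelSAETrainingRunner` import is assumed to match this release's public exports:

# Sketch: resume training from a checkpoint directory written by an earlier run
# (it should contain sae_weights.safetensors, trainer_state.pt, and the
# activations store state).
from sae_lens import (
    LanguageModelSAERunnerConfig,
    LanguageModelSAETrainingRunner,
    StandardTrainingSAEConfig,
)

cfg = LanguageModelSAERunnerConfig(
    sae=StandardTrainingSAEConfig(d_in=768, d_sae=16384),  # illustrative sizes
    model_name="gpt2",
    hook_name="blocks.6.hook_resid_pre",
    resume_from_checkpoint="checkpoints/my_run/10000",  # placeholder path
)
sae = LanguageModelSAETrainingRunner(cfg).run()

The runner_cfg.json saved with the original run could also be fed through `LanguageModelSAERunnerConfig.from_dict` so the resumed run reuses the exact original settings.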
@@ -523,6 +523,82 @@ def gemma_2_sae_huggingface_loader(
  return cfg_dict, state_dict, log_sparsity


+ def get_goodfire_config_from_hf(
+ repo_id: str,
+ folder_name: str, # noqa: ARG001
+ device: str,
+ force_download: bool = False, # noqa: ARG001
+ cfg_overrides: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ cfg_dict = None
+ if repo_id == "Goodfire/Llama-3.3-70B-Instruct-SAE-l50":
+ if folder_name != "Llama-3.3-70B-Instruct-SAE-l50.pt":
+ raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+ cfg_dict = {
+ "architecture": "standard",
+ "d_in": 8192,
+ "d_sae": 65536,
+ "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+ "hook_name": "blocks.50.hook_resid_post",
+ "hook_head_index": None,
+ "dataset_path": "lmsys/lmsys-chat-1m",
+ "apply_b_dec_to_input": False,
+ }
+ elif repo_id == "Goodfire/Llama-3.1-8B-Instruct-SAE-l19":
+ if folder_name != "Llama-3.1-8B-Instruct-SAE-l19.pth":
+ raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+ cfg_dict = {
+ "architecture": "standard",
+ "d_in": 4096,
+ "d_sae": 65536,
+ "model_name": "meta-llama/Llama-3.1-8B-Instruct",
+ "hook_name": "blocks.19.hook_resid_post",
+ "hook_head_index": None,
+ "dataset_path": "lmsys/lmsys-chat-1m",
+ "apply_b_dec_to_input": False,
+ }
+ if cfg_dict is None:
+ raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+ if device is not None:
+ cfg_dict["device"] = device
+ if cfg_overrides is not None:
+ cfg_dict.update(cfg_overrides)
+ return cfg_dict
+
+
+ def get_goodfire_huggingface_loader(
+ repo_id: str,
+ folder_name: str,
+ device: str = "cpu",
+ force_download: bool = False,
+ cfg_overrides: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+ cfg_dict = get_goodfire_config_from_hf(
+ repo_id,
+ folder_name,
+ device,
+ force_download,
+ cfg_overrides,
+ )
+
+ # Download the SAE weights
+ sae_path = hf_hub_download(
+ repo_id=repo_id,
+ filename=folder_name,
+ force_download=force_download,
+ )
+ raw_state_dict = torch.load(sae_path, map_location=device)
+
+ state_dict = {
+ "W_enc": raw_state_dict["encoder_linear.weight"].T,
+ "W_dec": raw_state_dict["decoder_linear.weight"].T,
+ "b_enc": raw_state_dict["encoder_linear.bias"],
+ "b_dec": raw_state_dict["decoder_linear.bias"],
+ }
+
+ return cfg_dict, state_dict, None
+
+
  def get_llama_scope_config_from_hf(
  repo_id: str,
  folder_name: str,
@@ -1475,6 +1551,114 @@ def get_mntss_clt_layer_config_from_hf(
  }


+ def get_temporal_sae_config_from_hf(
+ repo_id: str,
+ folder_name: str,
+ device: str,
+ force_download: bool = False,
+ cfg_overrides: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ """Get TemporalSAE config without loading weights."""
+ # Download config file
+ conf_path = hf_hub_download(
+ repo_id=repo_id,
+ filename=f"{folder_name}/conf.yaml",
+ force_download=force_download,
+ )
+
+ # Load and parse config
+ with open(conf_path) as f:
+ yaml_config = yaml.safe_load(f)
+
+ # Extract parameters
+ d_in = yaml_config["llm"]["dimin"]
+ exp_factor = yaml_config["sae"]["exp_factor"]
+ d_sae = int(d_in * exp_factor)
+
+ # extract layer from folder_name eg : "layer_12/temporal"
+ layer = re.search(r"layer_(\d+)", folder_name)
+ if layer is None:
+ raise ValueError(f"Could not find layer in folder_name: {folder_name}")
+ layer = int(layer.group(1))
+
+ # Build config dict
+ cfg_dict = {
+ "architecture": "temporal",
+ "hook_name": f"blocks.{layer}.hook_resid_post",
+ "d_in": d_in,
+ "d_sae": d_sae,
+ "n_heads": yaml_config["sae"]["n_heads"],
+ "n_attn_layers": yaml_config["sae"]["n_attn_layers"],
+ "bottleneck_factor": yaml_config["sae"]["bottleneck_factor"],
+ "sae_diff_type": yaml_config["sae"]["sae_diff_type"],
+ "kval_topk": yaml_config["sae"]["kval_topk"],
+ "tied_weights": yaml_config["sae"]["tied_weights"],
+ "dtype": yaml_config["data"]["dtype"],
+ "device": device,
+ "normalize_activations": "constant_scalar_rescale",
+ "activation_normalization_factor": yaml_config["sae"]["scaling_factor"],
+ "apply_b_dec_to_input": True,
+ }
+
+ if cfg_overrides:
+ cfg_dict.update(cfg_overrides)
+
+ return cfg_dict
+
+
+ def temporal_sae_huggingface_loader(
+ repo_id: str,
+ folder_name: str,
+ device: str = "cpu",
+ force_download: bool = False,
+ cfg_overrides: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+ """
+ Load TemporalSAE from canrager/temporalSAEs format (safetensors version).
+
+ Expects folder_name to contain:
+ - conf.yaml (configuration)
+ - latest_ckpt.safetensors (model weights)
+ """
+
+ cfg_dict = get_temporal_sae_config_from_hf(
+ repo_id=repo_id,
+ folder_name=folder_name,
+ device=device,
+ force_download=force_download,
+ cfg_overrides=cfg_overrides,
+ )
+
+ # Download checkpoint (safetensors format)
+ ckpt_path = hf_hub_download(
+ repo_id=repo_id,
+ filename=f"{folder_name}/latest_ckpt.safetensors",
+ force_download=force_download,
+ )
+
+ # Load checkpoint from safetensors
+ state_dict_raw = load_file(ckpt_path, device=device)
+
+ # Convert to SAELens naming convention
+ # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
+ state_dict = {}
+
+ # Copy attention layers as-is
+ for key, value in state_dict_raw.items():
+ if key.startswith("attn_layers."):
+ state_dict[key] = value.to(device)
+
+ # Main parameters
+ state_dict["W_dec"] = state_dict_raw["D"].to(device)
+ state_dict["b_dec"] = state_dict_raw["b"].to(device)
+
+ # Handle tied/untied weights
+ if "E" in state_dict_raw:
+ state_dict["W_enc"] = state_dict_raw["E"].to(device)
+
+ return cfg_dict, state_dict, None
+
+
  NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
  "sae_lens": sae_lens_huggingface_loader,
  "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1487,6 +1671,8 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
  "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
  "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
  "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
+ "temporal": temporal_sae_huggingface_loader,
+ "goodfire": get_goodfire_huggingface_loader,
  }


@@ -1502,4 +1688,6 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
  "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
  "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
  "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
+ "temporal": get_temporal_sae_config_from_hf,
+ "goodfire": get_goodfire_config_from_hf,
  }
@@ -1,6 +1,6 @@
  from dataclasses import dataclass
  from functools import cache
- from importlib import resources
+ from importlib.resources import files
  from typing import Any

  import yaml
@@ -24,7 +24,8 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
  package = "sae_lens"
  # Access the file within the package using importlib.resources
  directory: dict[str, PretrainedSAELookup] = {}
- with resources.open_text(package, "pretrained_saes.yaml") as file:
+ yaml_file = files(package).joinpath("pretrained_saes.yaml")
+ with yaml_file.open("r") as file:
  # Load the YAML file content
  data = yaml.safe_load(file)
  for release, value in data.items():
@@ -68,7 +69,8 @@ def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
  float | None: The norm_scaling_factor if it exists, None otherwise.
  """
  package = "sae_lens"
- with resources.open_text(package, "pretrained_saes.yaml") as file:
+ yaml_file = files(package).joinpath("pretrained_saes.yaml")
+ with yaml_file.open("r") as file:
  data = yaml.safe_load(file)
  if release in data:
  for sae_info in data[release]["saes"]:
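The switch from `resources.open_text` to `importlib.resources.files` avoids the deprecated API without changing behaviour; the bundled YAML is read the same way. A sketch of listing the registry entries added below; the module path here is an assumption and has moved between sae-lens releases:

# Sketch: enumerate releases parsed from pretrained_saes.yaml.
# Assumption: this import path; in older releases the function lived under
# sae_lens.toolkit.pretrained_saes_directory.
from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory

directory = get_pretrained_saes_directory()
print([name for name in directory if "temporal" in name or "goodfire" in name])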
@@ -1,3 +1,35 @@
+ temporal-sae-gemma-2-2b:
+ conversion_func: temporal
+ model: gemma-2-2b
+ repo_id: canrager/temporalSAEs
+ config_overrides:
+ model_name: gemma-2-2b
+ hook_name: blocks.12.hook_resid_post
+ dataset_path: monology/pile-uncopyrighted
+ saes:
+ - id: blocks.12.hook_resid_post
+ l0: 192
+ norm_scaling_factor: 0.00666666667
+ path: gemma-2-2B/layer_12/temporal
+ neuronpedia: gemma-2-2b/12-temporal-res
+ temporal-sae-llama-3.1-8b:
+ conversion_func: temporal
+ model: meta-llama/Llama-3.1-8B
+ repo_id: canrager/temporalSAEs
+ config_overrides:
+ model_name: meta-llama/Llama-3.1-8B
+ dataset_path: monology/pile-uncopyrighted
+ saes:
+ - id: blocks.15.hook_resid_post
+ l0: 256
+ norm_scaling_factor: 0.029
+ path: llama-3.1-8B/layer_15/temporal
+ neuronpedia: llama3.1-8b/15-temporal-res
+ - id: blocks.26.hook_resid_post
+ l0: 256
+ norm_scaling_factor: 0.029
+ path: llama-3.1-8B/layer_26/temporal
+ neuronpedia: llama3.1-8b/26-temporal-res
  deepseek-r1-distill-llama-8b-qresearch:
  conversion_func: deepseek_r1
  model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
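These entries make the temporal SAEs addressable by release name and SAE id. A sketch of loading one, assuming the 6.x `SAE.from_pretrained(release, sae_id, device=...)` signature that returns the SAE directly:

# Sketch: load the Gemma-2-2B temporal SAE registered above.
from sae_lens import SAE

sae = SAE.from_pretrained(
    release="temporal-sae-gemma-2-2b",
    sae_id="blocks.12.hook_resid_post",
    device="cpu",
)
print(type(sae).__name__, sae.cfg.d_in, sae.cfg.d_sae)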
@@ -14882,4 +14914,46 @@ qwen2.5-7b-instruct-andyrdt:
  neuronpedia: qwen2.5-7b-it/23-resid-post-aa
  - id: resid_post_layer_27_trainer_1
  path: resid_post_layer_27/trainer_1
- neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+ neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+
+ gpt-oss-20b-andyrdt:
+ conversion_func: dictionary_learning_1
+ model: openai/gpt-oss-20b
+ repo_id: andyrdt/saes-gpt-oss-20b
+ saes:
+ - id: resid_post_layer_3_trainer_0
+ path: resid_post_layer_3/trainer_0
+ neuronpedia: gpt-oss-20b/3-resid-post-aa
+ - id: resid_post_layer_7_trainer_0
+ path: resid_post_layer_7/trainer_0
+ neuronpedia: gpt-oss-20b/7-resid-post-aa
+ - id: resid_post_layer_11_trainer_0
+ path: resid_post_layer_11/trainer_0
+ neuronpedia: gpt-oss-20b/11-resid-post-aa
+ - id: resid_post_layer_15_trainer_0
+ path: resid_post_layer_15/trainer_0
+ neuronpedia: gpt-oss-20b/15-resid-post-aa
+ - id: resid_post_layer_19_trainer_0
+ path: resid_post_layer_19/trainer_0
+ neuronpedia: gpt-oss-20b/19-resid-post-aa
+ - id: resid_post_layer_23_trainer_0
+ path: resid_post_layer_23/trainer_0
+ neuronpedia: gpt-oss-20b/23-resid-post-aa
+
+ goodfire-llama-3.3-70b-instruct:
+ conversion_func: goodfire
+ model: meta-llama/Llama-3.3-70B-Instruct
+ repo_id: Goodfire/Llama-3.3-70B-Instruct-SAE-l50
+ saes:
+ - id: layer_50
+ path: Llama-3.3-70B-Instruct-SAE-l50.pt
+ l0: 121
+
+ goodfire-llama-3.1-8b-instruct:
+ conversion_func: goodfire
+ model: meta-llama/Llama-3.1-8B-Instruct
+ repo_id: Goodfire/Llama-3.1-8B-Instruct-SAE-l19
+ saes:
+ - id: layer_19
+ path: Llama-3.1-8B-Instruct-SAE-l19.pth
+ l0: 91
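The Goodfire entries resolve through the `goodfire` loader added earlier in this diff, which remaps the released `encoder_linear`/`decoder_linear` weights onto a standard SAE. A hedged sketch, under the same `from_pretrained` assumption as above (the 70B variant needs far more memory):

# Sketch: load the Goodfire Llama-3.1-8B-Instruct SAE as a standard SAE.
from sae_lens import SAE

sae = SAE.from_pretrained(
    release="goodfire-llama-3.1-8b-instruct",
    sae_id="layer_19",
    device="cpu",
)
print(sae.W_enc.shape)  # expected (4096, 65536) given the config in the loader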
sae_lens/saes/__init__.py CHANGED
@@ -25,6 +25,7 @@ from .standard_sae import (
  StandardTrainingSAE,
  StandardTrainingSAEConfig,
  )
+ from .temporal_sae import TemporalSAE, TemporalSAEConfig
  from .topk_sae import (
  TopKSAE,
  TopKSAEConfig,
@@ -71,4 +72,6 @@ __all__ = [
  "JumpReLUTranscoderConfig",
  "MatryoshkaBatchTopKTrainingSAE",
  "MatryoshkaBatchTopKTrainingSAEConfig",
+ "TemporalSAE",
+ "TemporalSAEConfig",
  ]
sae_lens/saes/sae.py CHANGED
@@ -21,7 +21,7 @@ import einops
  import torch
  from jaxtyping import Float
  from numpy.typing import NDArray
- from safetensors.torch import save_file
+ from safetensors.torch import load_file, save_file
  from torch import nn
  from transformer_lens.hook_points import HookedRootModule, HookPoint
  from typing_extensions import deprecated, overload, override
@@ -155,9 +155,9 @@ class SAEConfig(ABC):
  dtype: str = "float32"
  device: str = "cpu"
  apply_b_dec_to_input: bool = True
- normalize_activations: Literal[
- "none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"
- ] = "none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
+ normalize_activations: Literal["none", "expected_average_only_in", "layer_norm"] = (
+ "none" # none, expected_average_only_in (Anthropic April Update)
+ )
  reshape_activations: Literal["none", "hook_z"] = "none"
  metadata: SAEMetadata = field(default_factory=SAEMetadata)

@@ -309,6 +309,7 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):

  self.run_time_activation_norm_fn_in = run_time_activation_norm_fn_in
  self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out
+
  elif self.cfg.normalize_activations == "layer_norm":
  # we need to scale the norm of the input and store the scaling factor
  def run_time_activation_ln_in(
@@ -452,23 +453,14 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
  def process_sae_in(
  self, sae_in: Float[torch.Tensor, "... d_in"]
  ) -> Float[torch.Tensor, "... d_in"]:
- # print(f"Input shape to process_sae_in: {sae_in.shape}")
- # print(f"self.cfg.hook_name: {self.cfg.hook_name}")
- # print(f"self.b_dec shape: {self.b_dec.shape}")
- # print(f"Hook z reshaping mode: {getattr(self, 'hook_z_reshaping_mode', False)}")
-
  sae_in = sae_in.to(self.dtype)
-
- # print(f"Shape before reshape_fn_in: {sae_in.shape}")
  sae_in = self.reshape_fn_in(sae_in)
- # print(f"Shape after reshape_fn_in: {sae_in.shape}")

  sae_in = self.hook_sae_input(sae_in)
  sae_in = self.run_time_activation_norm_fn_in(sae_in)

  # Here's where the error happens
  bias_term = self.b_dec * self.cfg.apply_b_dec_to_input
- # print(f"Bias term shape: {bias_term.shape}")

  return sae_in - bias_term

@@ -1018,6 +1010,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
  ) -> type[TrainingSAEConfig]:
  return get_sae_training_class(architecture)[1]

+ def load_weights_from_checkpoint(self, checkpoint_path: Path | str) -> None:
+ checkpoint_path = Path(checkpoint_path)
+ state_dict = load_file(checkpoint_path / SAE_WEIGHTS_FILENAME)
+ self.process_state_dict_for_loading(state_dict)
+ self.load_state_dict(state_dict)
+

  _blank_hook = nn.Identity()
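`load_weights_from_checkpoint` is the piece the training runner calls when resuming: it reads `sae_weights.safetensors` from a checkpoint directory into an existing training SAE. A minimal sketch, assuming an SAE built from the same config that produced the checkpoint (sizes and path are placeholders):

# Sketch: restore weights into a freshly constructed TrainingSAE.
from sae_lens import StandardTrainingSAE, StandardTrainingSAEConfig

sae = StandardTrainingSAE(StandardTrainingSAEConfig(d_in=768, d_sae=16384))
sae.load_weights_from_checkpoint("checkpoints/my_run/10000")  # placeholder path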