sae-lens 6.15.0__py3-none-any.whl → 6.22.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.15.0"
+__version__ = "6.22.1"
 
 import logging
 
@@ -28,6 +28,8 @@ from sae_lens.saes import (
     StandardSAEConfig,
     StandardTrainingSAE,
     StandardTrainingSAEConfig,
+    TemporalSAE,
+    TemporalSAEConfig,
     TopKSAE,
     TopKSAEConfig,
     TopKTrainingSAE,
@@ -105,6 +107,8 @@ __all__ = [
     "JumpReLUTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
+    "TemporalSAE",
+    "TemporalSAEConfig",
 ]
 
 
@@ -127,3 +131,4 @@ register_sae_training_class(
 register_sae_class("transcoder", Transcoder, TranscoderConfig)
 register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
 register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
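The registration call added above uses the same mechanism that is available for third-party architectures. Below is a minimal sketch, assuming `register_sae_class` is importable from `sae_lens.registry` (this diff only shows `get_sae_training_class` being imported from that module) and using hypothetical class names:

```python
# Hypothetical example of registering a custom architecture; MyCustomSAE and
# MyCustomSAEConfig are placeholders, not part of sae-lens.
from sae_lens.registry import register_sae_class  # assumed import location

from my_package.saes import MyCustomSAE, MyCustomSAEConfig  # hypothetical

register_sae_class("my_custom", MyCustomSAE, MyCustomSAEConfig)
# Config dicts with {"architecture": "my_custom"} would then resolve to MyCustomSAE.
```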
@@ -3,7 +3,6 @@ from contextlib import contextmanager
 from typing import Any, Callable
 
 import torch
-from jaxtyping import Float
 from transformer_lens.ActivationCache import ActivationCache
 from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
 from transformer_lens.hook_points import HookPoint  # Hooking utilities
@@ -11,8 +10,8 @@ from transformer_lens.HookedTransformer import HookedTransformer
 
 from sae_lens.saes.sae import SAE
 
-SingleLoss = Float[torch.Tensor, ""]  # Type alias for a single element tensor
-LossPerToken = Float[torch.Tensor, "batch pos-1"]
+SingleLoss = torch.Tensor  # Type alias for a single element tensor
+LossPerToken = torch.Tensor
 Loss = SingleLoss | LossPerToken
 
 
@@ -171,12 +170,7 @@ class HookedSAETransformer(HookedTransformer):
         reset_saes_end: bool = True,
         use_error_term: bool | None = None,
         **model_kwargs: Any,
-    ) -> (
-        None
-        | Float[torch.Tensor, "batch pos d_vocab"]
-        | Loss
-        | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]
-    ):
+    ) -> None | torch.Tensor | Loss | tuple[torch.Tensor, Loss]:
         """Wrapper around HookedTransformer forward pass.
 
         Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.
@@ -203,10 +197,7 @@ class HookedSAETransformer(HookedTransformer):
         remove_batch_dim: bool = False,
         **kwargs: Any,
     ) -> tuple[
-        None
-        | Float[torch.Tensor, "batch pos d_vocab"]
-        | Loss
-        | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
+        None | torch.Tensor | Loss | tuple[torch.Tensor, Loss],
         ActivationCache | dict[str, torch.Tensor],
     ]:
         """Wrapper around 'run_with_cache' in HookedTransformer.
@@ -9,8 +9,7 @@ import torch
 from datasets import Array2D, Dataset, Features, Sequence, Value
 from datasets.fingerprint import generate_fingerprint
 from huggingface_hub import HfApi
-from jaxtyping import Float, Int
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from transformer_lens.HookedTransformer import HookedRootModule
 
 from sae_lens import logger
@@ -318,8 +317,8 @@ class CacheActivationsRunner:
     def _create_shard(
         self,
         buffer: tuple[
-            Float[torch.Tensor, "(bs context_size) d_in"],
-            Int[torch.Tensor, "(bs context_size)"] | None,
+            torch.Tensor,  # shape: (bs context_size) d_in
+            torch.Tensor | None,  # shape: (bs context_size) or None
         ],
     ) -> Dataset:
         hook_names = [self.cfg.hook_name]
sae_lens/config.py CHANGED
@@ -18,6 +18,7 @@ from datasets import (
 
 from sae_lens import __version__, logger
 from sae_lens.constants import DTYPE_MAP
+from sae_lens.registry import get_sae_training_class
 from sae_lens.saes.sae import TrainingSAEConfig
 
 if TYPE_CHECKING:
@@ -171,6 +172,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
         checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
         save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
+        resume_from_checkpoint (str | None): The path to the checkpoint to resume training from. (default is None).
         output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
         verbose (bool): Whether to print verbose output. (default is True)
         model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
@@ -261,6 +263,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     checkpoint_path: str | None = "checkpoints"
     save_final_checkpoint: bool = False
     output_path: str | None = "output"
+    resume_from_checkpoint: str | None = None
 
     # Misc
     verbose: bool = True
@@ -385,8 +388,11 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         return self.sae.to_dict()
 
     def to_dict(self) -> dict[str, Any]:
-        # Make a shallow copy of config's dictionary
-        d = dict(self.__dict__)
+        """
+        Convert the config to a dictionary.
+        """
+
+        d = asdict(self)
 
         d["logger"] = asdict(self.logger)
         d["sae"] = self.sae.to_dict()
@@ -396,6 +402,37 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         d["act_store_device"] = str(self.act_store_device)
         return d
 
+    @classmethod
+    def from_dict(cls, cfg_dict: dict[str, Any]) -> "LanguageModelSAERunnerConfig[Any]":
+        """
+        Load a LanguageModelSAERunnerConfig from a dictionary given by `to_dict`.
+
+        Args:
+            cfg_dict (dict[str, Any]): The dictionary to load the config from.
+
+        Returns:
+            LanguageModelSAERunnerConfig: The loaded config.
+        """
+        if "sae" not in cfg_dict:
+            raise ValueError("sae field is required in the config dictionary")
+        if "architecture" not in cfg_dict["sae"]:
+            raise ValueError("architecture field is required in the sae dictionary")
+        if "logger" not in cfg_dict:
+            raise ValueError("logger field is required in the config dictionary")
+        sae_config_class = get_sae_training_class(cfg_dict["sae"]["architecture"])[1]
+        sae_cfg = sae_config_class.from_dict(cfg_dict["sae"])
+        logger_cfg = LoggingConfig(**cfg_dict["logger"])
+        updated_cfg_dict: dict[str, Any] = {
+            **cfg_dict,
+            "sae": sae_cfg,
+            "logger": logger_cfg,
+        }
+        output = cls(**updated_cfg_dict)
+        # the post_init always appends to checkpoint path, so we need to set it explicitly here.
+        if "checkpoint_path" in cfg_dict:
+            output.checkpoint_path = cfg_dict["checkpoint_path"]
+        return output
+
     def to_sae_trainer_config(self) -> "SAETrainerConfig":
         return SAETrainerConfig(
             n_checkpoints=self.n_checkpoints,
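Together with the `asdict`-based `to_dict` above, this gives a full round trip through plain dictionaries. A minimal sketch, assuming `d_in` and `d_sae` are the only required fields of `StandardTrainingSAEConfig` and leaving the other runner fields at their defaults; the field values are illustrative:

```python
# Hedged round-trip sketch for the new from_dict classmethod.
from sae_lens import StandardTrainingSAEConfig
from sae_lens.config import LanguageModelSAERunnerConfig

cfg = LanguageModelSAERunnerConfig(sae=StandardTrainingSAEConfig(d_in=512, d_sae=4096))

cfg_dict = cfg.to_dict()                                    # nested plain dict (sae/logger included)
restored = LanguageModelSAERunnerConfig.from_dict(cfg_dict)

assert restored.sae.d_sae == cfg.sae.d_sae
assert restored.checkpoint_path == cfg.checkpoint_path      # kept verbatim, no extra unique suffix
```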
sae_lens/constants.py CHANGED
@@ -17,5 +17,6 @@ SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
 SAE_CFG_FILENAME = "cfg.json"
 RUNNER_CFG_FILENAME = "runner_cfg.json"
 SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+TRAINER_STATE_FILENAME = "trainer_state.pt"
 ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
 ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
@@ -16,7 +16,6 @@ from typing_extensions import deprecated
 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
 from sae_lens.constants import (
-    ACTIVATIONS_STORE_STATE_FILENAME,
     RUNNER_CFG_FILENAME,
     SPARSITY_FILENAME,
 )
@@ -112,6 +111,7 @@ class LanguageModelSAETrainingRunner:
         override_dataset: HfDataset | None = None,
         override_model: HookedRootModule | None = None,
         override_sae: TrainingSAE[Any] | None = None,
+        resume_from_checkpoint: Path | str | None = None,
     ):
         if override_dataset is not None:
             logger.warning(
@@ -153,6 +153,7 @@ class LanguageModelSAETrainingRunner:
             )
         else:
             self.sae = override_sae
+
         self.sae.to(self.cfg.device)
 
     def run(self):
@@ -185,6 +186,12 @@ class LanguageModelSAETrainingRunner:
             cfg=self.cfg.to_sae_trainer_config(),
         )
 
+        if self.cfg.resume_from_checkpoint is not None:
+            logger.info(f"Resuming from checkpoint: {self.cfg.resume_from_checkpoint}")
+            trainer.load_trainer_state(self.cfg.resume_from_checkpoint)
+            self.sae.load_weights_from_checkpoint(self.cfg.resume_from_checkpoint)
+            self.activations_store.load_from_checkpoint(self.cfg.resume_from_checkpoint)
+
         self._compile_if_needed()
         sae = self.run_trainer_with_interruption_handling(trainer)
 
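A minimal sketch of how the resume path wired up above is meant to be used. The checkpoint path is illustrative, and the runner's import location is an assumption (the changed file's name is not shown in this diff); `runner_cfg.json` is the `RUNNER_CFG_FILENAME` written by `save_checkpoint`:

```python
# Hedged sketch: resume an interrupted run from a saved checkpoint directory.
import json

from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.llm_sae_training_runner import LanguageModelSAETrainingRunner  # assumed module path

ckpt_dir = "checkpoints/my_run/122880000"  # illustrative path
with open(f"{ckpt_dir}/runner_cfg.json") as f:
    cfg = LanguageModelSAERunnerConfig.from_dict(json.load(f))

cfg.resume_from_checkpoint = ckpt_dir
# Trainer state, SAE weights, and the activations store are restored before training continues.
sae = LanguageModelSAETrainingRunner(cfg).run()
```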
@@ -304,9 +311,7 @@ class LanguageModelSAETrainingRunner:
         if checkpoint_path is None:
             return
 
-        self.activations_store.save(
-            str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
-        )
+        self.activations_store.save_to_checkpoint(checkpoint_path)
 
         runner_config = self.cfg.to_dict()
         with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
@@ -523,6 +523,82 @@ def gemma_2_sae_huggingface_loader(
     return cfg_dict, state_dict, log_sparsity
 
 
+def get_goodfire_config_from_hf(
+    repo_id: str,
+    folder_name: str,  # noqa: ARG001
+    device: str,
+    force_download: bool = False,  # noqa: ARG001
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    cfg_dict = None
+    if repo_id == "Goodfire/Llama-3.3-70B-Instruct-SAE-l50":
+        if folder_name != "Llama-3.3-70B-Instruct-SAE-l50.pt":
+            raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+        cfg_dict = {
+            "architecture": "standard",
+            "d_in": 8192,
+            "d_sae": 65536,
+            "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+            "hook_name": "blocks.50.hook_resid_post",
+            "hook_head_index": None,
+            "dataset_path": "lmsys/lmsys-chat-1m",
+            "apply_b_dec_to_input": False,
+        }
+    elif repo_id == "Goodfire/Llama-3.1-8B-Instruct-SAE-l19":
+        if folder_name != "Llama-3.1-8B-Instruct-SAE-l19.pth":
+            raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+        cfg_dict = {
+            "architecture": "standard",
+            "d_in": 4096,
+            "d_sae": 65536,
+            "model_name": "meta-llama/Llama-3.1-8B-Instruct",
+            "hook_name": "blocks.19.hook_resid_post",
+            "hook_head_index": None,
+            "dataset_path": "lmsys/lmsys-chat-1m",
+            "apply_b_dec_to_input": False,
+        }
+    if cfg_dict is None:
+        raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+    if device is not None:
+        cfg_dict["device"] = device
+    if cfg_overrides is not None:
+        cfg_dict.update(cfg_overrides)
+    return cfg_dict
+
+
+def get_goodfire_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    cfg_dict = get_goodfire_config_from_hf(
+        repo_id,
+        folder_name,
+        device,
+        force_download,
+        cfg_overrides,
+    )
+
+    # Download the SAE weights
+    sae_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=folder_name,
+        force_download=force_download,
+    )
+    raw_state_dict = torch.load(sae_path, map_location=device)
+
+    state_dict = {
+        "W_enc": raw_state_dict["encoder_linear.weight"].T,
+        "W_dec": raw_state_dict["decoder_linear.weight"].T,
+        "b_enc": raw_state_dict["encoder_linear.bias"],
+        "b_dec": raw_state_dict["decoder_linear.bias"],
+    }
+
+    return cfg_dict, state_dict, None
+
+
 def get_llama_scope_config_from_hf(
     repo_id: str,
     folder_name: str,
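The `.T` in the state-dict conversion above reflects PyTorch's layout convention: `nn.Linear` stores its weight as `(out_features, in_features)`, so Goodfire's `encoder_linear.weight` arrives as `(d_sae, d_in)` and is transposed into SAELens' `(d_in, d_sae)` `W_enc` (and likewise for the decoder). A small self-contained check with illustrative sizes:

```python
from torch import nn

# nn.Linear(in_features, out_features) stores .weight with shape (out_features, in_features),
# so an encoder Linear mapping d_in -> d_sae holds a (d_sae, d_in) matrix and needs a
# transpose to match SAELens' W_enc layout of (d_in, d_sae).
d_in, d_sae = 16, 64  # illustrative sizes
encoder_linear = nn.Linear(d_in, d_sae)

W_enc = encoder_linear.weight.T
assert encoder_linear.weight.shape == (d_sae, d_in)
assert W_enc.shape == (d_in, d_sae)
assert encoder_linear.bias.shape == (d_sae,)  # b_enc needs no transpose
```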
@@ -1475,6 +1551,114 @@ def get_mntss_clt_layer_config_from_hf(
     }
 
 
+def get_temporal_sae_config_from_hf(
+    repo_id: str,
+    folder_name: str,
+    device: str,
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Get TemporalSAE config without loading weights."""
+    # Download config file
+    conf_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=f"{folder_name}/conf.yaml",
+        force_download=force_download,
+    )
+
+    # Load and parse config
+    with open(conf_path) as f:
+        yaml_config = yaml.safe_load(f)
+
+    # Extract parameters
+    d_in = yaml_config["llm"]["dimin"]
+    exp_factor = yaml_config["sae"]["exp_factor"]
+    d_sae = int(d_in * exp_factor)
+
+    # extract layer from folder_name eg : "layer_12/temporal"
+    layer = re.search(r"layer_(\d+)", folder_name)
+    if layer is None:
+        raise ValueError(f"Could not find layer in folder_name: {folder_name}")
+    layer = int(layer.group(1))
+
+    # Build config dict
+    cfg_dict = {
+        "architecture": "temporal",
+        "hook_name": f"blocks.{layer}.hook_resid_post",
+        "d_in": d_in,
+        "d_sae": d_sae,
+        "n_heads": yaml_config["sae"]["n_heads"],
+        "n_attn_layers": yaml_config["sae"]["n_attn_layers"],
+        "bottleneck_factor": yaml_config["sae"]["bottleneck_factor"],
+        "sae_diff_type": yaml_config["sae"]["sae_diff_type"],
+        "kval_topk": yaml_config["sae"]["kval_topk"],
+        "tied_weights": yaml_config["sae"]["tied_weights"],
+        "dtype": yaml_config["data"]["dtype"],
+        "device": device,
+        "normalize_activations": "constant_scalar_rescale",
+        "activation_normalization_factor": yaml_config["sae"]["scaling_factor"],
+        "apply_b_dec_to_input": True,
+    }
+
+    if cfg_overrides:
+        cfg_dict.update(cfg_overrides)
+
+    return cfg_dict
+
+
+def temporal_sae_huggingface_loader(
+    repo_id: str,
+    folder_name: str,
+    device: str = "cpu",
+    force_download: bool = False,
+    cfg_overrides: dict[str, Any] | None = None,
+) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+    """
+    Load TemporalSAE from canrager/temporalSAEs format (safetensors version).
+
+    Expects folder_name to contain:
+    - conf.yaml (configuration)
+    - latest_ckpt.safetensors (model weights)
+    """
+
+    cfg_dict = get_temporal_sae_config_from_hf(
+        repo_id=repo_id,
+        folder_name=folder_name,
+        device=device,
+        force_download=force_download,
+        cfg_overrides=cfg_overrides,
+    )
+
+    # Download checkpoint (safetensors format)
+    ckpt_path = hf_hub_download(
+        repo_id=repo_id,
+        filename=f"{folder_name}/latest_ckpt.safetensors",
+        force_download=force_download,
+    )
+
+    # Load checkpoint from safetensors
+    state_dict_raw = load_file(ckpt_path, device=device)
+
+    # Convert to SAELens naming convention
+    # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
+    state_dict = {}
+
+    # Copy attention layers as-is
+    for key, value in state_dict_raw.items():
+        if key.startswith("attn_layers."):
+            state_dict[key] = value.to(device)
+
+    # Main parameters
+    state_dict["W_dec"] = state_dict_raw["D"].to(device)
+    state_dict["b_dec"] = state_dict_raw["b"].to(device)
+
+    # Handle tied/untied weights
+    if "E" in state_dict_raw:
+        state_dict["W_enc"] = state_dict_raw["E"].to(device)
+
+    return cfg_dict, state_dict, None
+
+
 NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "sae_lens": sae_lens_huggingface_loader,
     "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1487,6 +1671,8 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
     "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
     "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
     "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
+    "temporal": temporal_sae_huggingface_loader,
+    "goodfire": get_goodfire_huggingface_loader,
 }
 
@@ -1502,4 +1688,6 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
     "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
     "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
     "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
+    "temporal": get_temporal_sae_config_from_hf,
+    "goodfire": get_goodfire_config_from_hf,
 }
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from functools import cache
-from importlib import resources
+from importlib.resources import files
 from typing import Any
 
 import yaml
@@ -24,7 +24,8 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
     package = "sae_lens"
     # Access the file within the package using importlib.resources
     directory: dict[str, PretrainedSAELookup] = {}
-    with resources.open_text(package, "pretrained_saes.yaml") as file:
+    yaml_file = files(package).joinpath("pretrained_saes.yaml")
+    with yaml_file.open("r") as file:
         # Load the YAML file content
         data = yaml.safe_load(file)
         for release, value in data.items():
@@ -68,7 +69,8 @@ def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
         float | None: The norm_scaling_factor if it exists, None otherwise.
     """
     package = "sae_lens"
-    with resources.open_text(package, "pretrained_saes.yaml") as file:
+    yaml_file = files(package).joinpath("pretrained_saes.yaml")
+    with yaml_file.open("r") as file:
         data = yaml.safe_load(file)
         if release in data:
             for sae_info in data[release]["saes"]:
@@ -1,3 +1,35 @@
+temporal-sae-gemma-2-2b:
+  conversion_func: temporal
+  model: gemma-2-2b
+  repo_id: canrager/temporalSAEs
+  config_overrides:
+    model_name: gemma-2-2b
+    hook_name: blocks.12.hook_resid_post
+    dataset_path: monology/pile-uncopyrighted
+  saes:
+    - id: blocks.12.hook_resid_post
+      l0: 192
+      norm_scaling_factor: 0.00666666667
+      path: gemma-2-2B/layer_12/temporal
+      neuronpedia: gemma-2-2b/12-temporal-res
+temporal-sae-llama-3.1-8b:
+  conversion_func: temporal
+  model: meta-llama/Llama-3.1-8B
+  repo_id: canrager/temporalSAEs
+  config_overrides:
+    model_name: meta-llama/Llama-3.1-8B
+    dataset_path: monology/pile-uncopyrighted
+  saes:
+    - id: blocks.15.hook_resid_post
+      l0: 256
+      norm_scaling_factor: 0.029
+      path: llama-3.1-8B/layer_15/temporal
+      neuronpedia: llama3.1-8b/15-temporal-res
+    - id: blocks.26.hook_resid_post
+      l0: 256
+      norm_scaling_factor: 0.029
+      path: llama-3.1-8B/layer_26/temporal
+      neuronpedia: llama3.1-8b/26-temporal-res
 deepseek-r1-distill-llama-8b-qresearch:
   conversion_func: deepseek_r1
   model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
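A minimal sketch of loading one of the newly registered releases by name. `SAE.from_pretrained(release, sae_id, device=...)` is assumed to behave as for other registry entries; depending on the installed SAELens version it may instead return a `(sae, cfg_dict, sparsity)` tuple:

```python
# Hedged sketch: load the new temporal SAE release registered above.
from sae_lens import SAE

sae = SAE.from_pretrained(
    release="temporal-sae-gemma-2-2b",
    sae_id="blocks.12.hook_resid_post",
    device="cpu",
)
print(sae.cfg.d_in, sae.cfg.d_sae)
```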
@@ -14882,4 +14914,48 @@ qwen2.5-7b-instruct-andyrdt:
       neuronpedia: qwen2.5-7b-it/23-resid-post-aa
     - id: resid_post_layer_27_trainer_1
       path: resid_post_layer_27/trainer_1
-      neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+      neuronpedia: qwen2.5-7b-it/27-resid-post-aa
+
+gpt-oss-20b-andyrdt:
+  conversion_func: dictionary_learning_1
+  model: openai/gpt-oss-20b
+  repo_id: andyrdt/saes-gpt-oss-20b
+  saes:
+    - id: resid_post_layer_3_trainer_0
+      path: resid_post_layer_3/trainer_0
+      neuronpedia: gpt-oss-20b/3-resid-post-aa
+    - id: resid_post_layer_7_trainer_0
+      path: resid_post_layer_7/trainer_0
+      neuronpedia: gpt-oss-20b/7-resid-post-aa
+    - id: resid_post_layer_11_trainer_0
+      path: resid_post_layer_11/trainer_0
+      neuronpedia: gpt-oss-20b/11-resid-post-aa
+    - id: resid_post_layer_15_trainer_0
+      path: resid_post_layer_15/trainer_0
+      neuronpedia: gpt-oss-20b/15-resid-post-aa
+    - id: resid_post_layer_19_trainer_0
+      path: resid_post_layer_19/trainer_0
+      neuronpedia: gpt-oss-20b/19-resid-post-aa
+    - id: resid_post_layer_23_trainer_0
+      path: resid_post_layer_23/trainer_0
+      neuronpedia: gpt-oss-20b/23-resid-post-aa
+
+goodfire-llama-3.3-70b-instruct:
+  conversion_func: goodfire
+  model: meta-llama/Llama-3.3-70B-Instruct
+  repo_id: Goodfire/Llama-3.3-70B-Instruct-SAE-l50
+  saes:
+    - id: layer_50
+      path: Llama-3.3-70B-Instruct-SAE-l50.pt
+      l0: 121
+      neuronpedia: llama3.3-70b-it/50-resid-post-gf
+
+goodfire-llama-3.1-8b-instruct:
+  conversion_func: goodfire
+  model: meta-llama/Llama-3.1-8B-Instruct
+  repo_id: Goodfire/Llama-3.1-8B-Instruct-SAE-l19
+  saes:
+    - id: layer_19
+      path: Llama-3.1-8B-Instruct-SAE-l19.pth
+      l0: 91
+      neuronpedia: llama3.1-8b-it/19-resid-post-gf
sae_lens/saes/__init__.py CHANGED
@@ -25,6 +25,7 @@ from .standard_sae import (
     StandardTrainingSAE,
     StandardTrainingSAEConfig,
 )
+from .temporal_sae import TemporalSAE, TemporalSAEConfig
 from .topk_sae import (
     TopKSAE,
     TopKSAEConfig,
@@ -71,4 +72,6 @@ __all__ = [
     "JumpReLUTranscoderConfig",
     "MatryoshkaBatchTopKTrainingSAE",
     "MatryoshkaBatchTopKTrainingSAEConfig",
+    "TemporalSAE",
+    "TemporalSAEConfig",
 ]
@@ -23,7 +23,9 @@ class BatchTopK(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         acts = x.relu()
         flat_acts = acts.flatten()
-        acts_topk_flat = torch.topk(flat_acts, int(self.k * acts.shape[0]), dim=-1)
+        # Calculate total number of samples across all non-feature dimensions
+        num_samples = acts.shape[:-1].numel()
+        acts_topk_flat = torch.topk(flat_acts, int(self.k * num_samples), dim=-1)
         return (
             torch.zeros_like(flat_acts)
             .scatter(-1, acts_topk_flat.indices, acts_topk_flat.values)
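The change above matters for inputs with more than one leading dimension. With the old `acts.shape[0]`, a `(batch, pos, d_sae)` tensor got a top-k budget of only `k * batch` values across the whole flattened tensor; counting every non-feature position restores the intended "k active values per sample on average". A small illustration of the two budgets:

```python
import torch

k = 32
acts = torch.rand(4, 128, 1024).relu()         # (batch, pos, d_sae)

old_budget = int(k * acts.shape[0])            # 32 * 4       = 128 values kept in total
new_budget = int(k * acts.shape[:-1].numel())  # 32 * (4*128) = 16384, i.e. k per token on average
print(old_budget, new_budget)
```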
@@ -2,7 +2,6 @@ from dataclasses import dataclass
 from typing import Any
 
 import torch
-from jaxtyping import Float
 from numpy.typing import NDArray
 from torch import nn
 from typing_extensions import override
@@ -49,9 +48,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
         super().initialize_weights()
         _init_weights_gated(self)
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Encode the input tensor into the feature space using a gated encoder.
         This must match the original encode_gated implementation from SAE class.
@@ -72,9 +69,7 @@ class GatedSAE(SAE[GatedSAEConfig]):
         # Combine gating and magnitudes
         return self.hook_sae_acts_post(active_features * feature_magnitudes)
 
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decode the feature activations back into the input space:
         1) Apply optional finetuning scaling.
@@ -147,8 +142,8 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):
         _init_weights_gated(self)
 
     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Gated forward pass with pre-activation (for training).
         """
@@ -3,7 +3,6 @@ from typing import Any, Literal
 
 import numpy as np
 import torch
-from jaxtyping import Float
 from torch import nn
 from typing_extensions import override
 
@@ -130,9 +129,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
         )
 
-    def encode(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> Float[torch.Tensor, "... d_sae"]:
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
         """
         Encode the input tensor into the feature space using JumpReLU.
         The threshold parameter determines which units remain active.
@@ -150,9 +147,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
         # 3) Multiply the normally activated units by that mask.
         return self.hook_sae_acts_post(base_acts * jump_relu_mask)
 
-    def decode(
-        self, feature_acts: Float[torch.Tensor, "... d_sae"]
-    ) -> Float[torch.Tensor, "... d_in"]:
+    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
         """
         Decode the feature activations back to the input space.
         Follows the same steps as StandardSAE: apply scaling, transform, hook, and optionally reshape.
@@ -265,8 +260,8 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
         return torch.exp(self.log_threshold)
 
     def encode_with_hidden_pre(
-        self, x: Float[torch.Tensor, "... d_in"]
-    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         sae_in = self.process_sae_in(x)
 
         hidden_pre = sae_in @ self.W_enc + self.b_enc