sae-lens 6.14.1__py3-none-any.whl → 6.22.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in its public registry. It is provided for informational purposes only.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.14.1"
+ __version__ = "6.22.1"

  import logging

@@ -19,6 +19,8 @@ from sae_lens.saes import (
  JumpReLUTrainingSAEConfig,
  JumpReLUTranscoder,
  JumpReLUTranscoderConfig,
+ MatryoshkaBatchTopKTrainingSAE,
+ MatryoshkaBatchTopKTrainingSAEConfig,
  SAEConfig,
  SkipTranscoder,
  SkipTranscoderConfig,
@@ -26,6 +28,8 @@ from sae_lens.saes import (
  StandardSAEConfig,
  StandardTrainingSAE,
  StandardTrainingSAEConfig,
+ TemporalSAE,
+ TemporalSAEConfig,
  TopKSAE,
  TopKSAEConfig,
  TopKTrainingSAE,
@@ -101,6 +105,10 @@ __all__ = [
  "SkipTranscoderConfig",
  "JumpReLUTranscoder",
  "JumpReLUTranscoderConfig",
+ "MatryoshkaBatchTopKTrainingSAE",
+ "MatryoshkaBatchTopKTrainingSAEConfig",
+ "TemporalSAE",
+ "TemporalSAEConfig",
  ]


@@ -115,6 +123,12 @@ register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAE
  register_sae_training_class(
  "batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
  )
+ register_sae_training_class(
+ "matryoshka_batchtopk",
+ MatryoshkaBatchTopKTrainingSAE,
+ MatryoshkaBatchTopKTrainingSAEConfig,
+ )
  register_sae_class("transcoder", Transcoder, TranscoderConfig)
  register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
  register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+ register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
@@ -3,7 +3,6 @@ from contextlib import contextmanager
  from typing import Any, Callable

  import torch
- from jaxtyping import Float
  from transformer_lens.ActivationCache import ActivationCache
  from transformer_lens.components.mlps.can_be_used_as_mlp import CanBeUsedAsMLP
  from transformer_lens.hook_points import HookPoint # Hooking utilities
@@ -11,8 +10,8 @@ from transformer_lens.HookedTransformer import HookedTransformer

  from sae_lens.saes.sae import SAE

- SingleLoss = Float[torch.Tensor, ""] # Type alias for a single element tensor
- LossPerToken = Float[torch.Tensor, "batch pos-1"]
+ SingleLoss = torch.Tensor # Type alias for a single element tensor
+ LossPerToken = torch.Tensor
  Loss = SingleLoss | LossPerToken


@@ -171,12 +170,7 @@ class HookedSAETransformer(HookedTransformer):
  reset_saes_end: bool = True,
  use_error_term: bool | None = None,
  **model_kwargs: Any,
- ) -> (
- None
- | Float[torch.Tensor, "batch pos d_vocab"]
- | Loss
- | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss]
- ):
+ ) -> None | torch.Tensor | Loss | tuple[torch.Tensor, Loss]:
  """Wrapper around HookedTransformer forward pass.

  Runs the model with the given SAEs attached for one forward pass, then removes them. By default, will reset all SAEs to original state after.
@@ -203,10 +197,7 @@ class HookedSAETransformer(HookedTransformer):
  remove_batch_dim: bool = False,
  **kwargs: Any,
  ) -> tuple[
- None
- | Float[torch.Tensor, "batch pos d_vocab"]
- | Loss
- | tuple[Float[torch.Tensor, "batch pos d_vocab"], Loss],
+ None | torch.Tensor | Loss | tuple[torch.Tensor, Loss],
  ActivationCache | dict[str, torch.Tensor],
  ]:
  """Wrapper around 'run_with_cache' in HookedTransformer.
@@ -9,8 +9,7 @@ import torch
  from datasets import Array2D, Dataset, Features, Sequence, Value
  from datasets.fingerprint import generate_fingerprint
  from huggingface_hub import HfApi
- from jaxtyping import Float, Int
- from tqdm import tqdm
+ from tqdm.auto import tqdm
  from transformer_lens.HookedTransformer import HookedRootModule

  from sae_lens import logger
@@ -318,8 +317,8 @@ class CacheActivationsRunner:
  def _create_shard(
  self,
  buffer: tuple[
- Float[torch.Tensor, "(bs context_size) d_in"],
- Int[torch.Tensor, "(bs context_size)"] | None,
+ torch.Tensor, # shape: (bs context_size) d_in
+ torch.Tensor | None, # shape: (bs context_size) or None
  ],
  ) -> Dataset:
  hook_names = [self.cfg.hook_name]
sae_lens/config.py CHANGED
@@ -18,6 +18,7 @@ from datasets import (

  from sae_lens import __version__, logger
  from sae_lens.constants import DTYPE_MAP
+ from sae_lens.registry import get_sae_training_class
  from sae_lens.saes.sae import TrainingSAEConfig

  if TYPE_CHECKING:
@@ -171,6 +172,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
  checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
  save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
+ resume_from_checkpoint (str | None): The path to the checkpoint to resume training from. (default is None).
  output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
  verbose (bool): Whether to print verbose output. (default is True)
  model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
@@ -261,6 +263,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  checkpoint_path: str | None = "checkpoints"
  save_final_checkpoint: bool = False
  output_path: str | None = "output"
+ resume_from_checkpoint: str | None = None

  # Misc
  verbose: bool = True
@@ -385,8 +388,11 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  return self.sae.to_dict()

  def to_dict(self) -> dict[str, Any]:
- # Make a shallow copy of config's dictionary
- d = dict(self.__dict__)
+ """
+ Convert the config to a dictionary.
+ """
+
+ d = asdict(self)

  d["logger"] = asdict(self.logger)
  d["sae"] = self.sae.to_dict()
@@ -396,6 +402,37 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  d["act_store_device"] = str(self.act_store_device)
  return d

+ @classmethod
+ def from_dict(cls, cfg_dict: dict[str, Any]) -> "LanguageModelSAERunnerConfig[Any]":
+ """
+ Load a LanguageModelSAERunnerConfig from a dictionary given by `to_dict`.
+
+ Args:
+ cfg_dict (dict[str, Any]): The dictionary to load the config from.
+
+ Returns:
+ LanguageModelSAERunnerConfig: The loaded config.
+ """
+ if "sae" not in cfg_dict:
+ raise ValueError("sae field is required in the config dictionary")
+ if "architecture" not in cfg_dict["sae"]:
+ raise ValueError("architecture field is required in the sae dictionary")
+ if "logger" not in cfg_dict:
+ raise ValueError("logger field is required in the config dictionary")
+ sae_config_class = get_sae_training_class(cfg_dict["sae"]["architecture"])[1]
+ sae_cfg = sae_config_class.from_dict(cfg_dict["sae"])
+ logger_cfg = LoggingConfig(**cfg_dict["logger"])
+ updated_cfg_dict: dict[str, Any] = {
+ **cfg_dict,
+ "sae": sae_cfg,
+ "logger": logger_cfg,
+ }
+ output = cls(**updated_cfg_dict)
+ # the post_init always appends to checkpoint path, so we need to set it explicitly here.
+ if "checkpoint_path" in cfg_dict:
+ output.checkpoint_path = cfg_dict["checkpoint_path"]
+ return output
+
  def to_sae_trainer_config(self) -> "SAETrainerConfig":
  return SAETrainerConfig(
  n_checkpoints=self.n_checkpoints,
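Annotation (not part of the diff): a hedged round-trip sketch for the new `from_dict` companion to `to_dict`. The constructor arguments are assumptions — it presumes `sae` is the only runner field without a default and that `StandardTrainingSAEConfig` needs only `d_in` and `d_sae`:

```python
from sae_lens import StandardTrainingSAEConfig
from sae_lens.config import LanguageModelSAERunnerConfig

# Assumed-minimal config; real runs would also set model_name, dataset_path, etc.
cfg = LanguageModelSAERunnerConfig(sae=StandardTrainingSAEConfig(d_in=512, d_sae=4096))

restored = LanguageModelSAERunnerConfig.from_dict(cfg.to_dict())
# from_dict re-applies checkpoint_path verbatim, since __post_init__ appends a
# unique id whenever a config is constructed normally.
assert restored.checkpoint_path == cfg.checkpoint_path
```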
sae_lens/constants.py CHANGED
@@ -17,5 +17,6 @@ SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
  SAE_CFG_FILENAME = "cfg.json"
  RUNNER_CFG_FILENAME = "runner_cfg.json"
  SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+ TRAINER_STATE_FILENAME = "trainer_state.pt"
  ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
  ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
sae_lens/evals.py CHANGED
@@ -11,7 +11,7 @@ from dataclasses import dataclass, field
  from functools import partial
  from importlib.metadata import PackageNotFoundError, version
  from pathlib import Path
- from typing import Any
+ from typing import Any, Iterable

  import einops
  import pandas as pd
@@ -24,7 +24,10 @@ from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_direc
  from sae_lens.saes.sae import SAE, SAEConfig
  from sae_lens.training.activation_scaler import ActivationScaler
  from sae_lens.training.activations_store import ActivationsStore
- from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
+ from sae_lens.util import (
+ extract_stop_at_layer_from_tlens_hook_name,
+ get_special_token_ids,
+ )


  def get_library_version() -> str:
@@ -109,9 +112,15 @@ def run_evals(
  activation_scaler: ActivationScaler,
  eval_config: EvalConfig = EvalConfig(),
  model_kwargs: Mapping[str, Any] = {},
- ignore_tokens: set[int | None] = set(),
+ exclude_special_tokens: Iterable[int] | bool = True,
  verbose: bool = False,
  ) -> tuple[dict[str, Any], dict[str, Any]]:
+ ignore_tokens = None
+ if exclude_special_tokens is True:
+ ignore_tokens = list(get_special_token_ids(model.tokenizer)) # type: ignore
+ elif exclude_special_tokens:
+ ignore_tokens = list(exclude_special_tokens)
+
  hook_name = sae.cfg.metadata.hook_name
  actual_batch_size = (
  eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
@@ -312,7 +321,7 @@ def get_downstream_reconstruction_metrics(
  compute_ce_loss: bool,
  n_batches: int,
  eval_batch_size_prompts: int,
- ignore_tokens: set[int | None] = set(),
+ ignore_tokens: list[int] | None = None,
  verbose: bool = False,
  ):
  metrics_dict = {}
@@ -339,7 +348,7 @@ def get_downstream_reconstruction_metrics(
  compute_ce_loss=compute_ce_loss,
  ignore_tokens=ignore_tokens,
  ).items():
- if len(ignore_tokens) > 0:
+ if ignore_tokens:
  mask = torch.logical_not(
  torch.any(
  torch.stack(
@@ -384,7 +393,7 @@ def get_sparsity_and_variance_metrics(
  compute_featurewise_density_statistics: bool,
  eval_batch_size_prompts: int,
  model_kwargs: Mapping[str, Any],
- ignore_tokens: set[int | None] = set(),
+ ignore_tokens: list[int] | None = None,
  verbose: bool = False,
  ) -> tuple[dict[str, Any], dict[str, Any]]:
  hook_name = sae.cfg.metadata.hook_name
@@ -426,7 +435,7 @@ def get_sparsity_and_variance_metrics(
  for _ in batch_iter:
  batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)

- if len(ignore_tokens) > 0:
+ if ignore_tokens:
  mask = torch.logical_not(
  torch.any(
  torch.stack(
@@ -596,7 +605,7 @@ def get_recons_loss(
  batch_tokens: torch.Tensor,
  compute_kl: bool,
  compute_ce_loss: bool,
- ignore_tokens: set[int | None] = set(),
+ ignore_tokens: list[int] | None = None,
  model_kwargs: Mapping[str, Any] = {},
  hook_name: str | None = None,
  ) -> dict[str, Any]:
@@ -610,7 +619,7 @@ def get_recons_loss(
  batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
  )

- if len(ignore_tokens) > 0:
+ if ignore_tokens:
  mask = torch.logical_not(
  torch.any(
  torch.stack([batch_tokens == token for token in ignore_tokens], dim=0),
@@ -856,11 +865,6 @@ def multiple_evals(
  activation_scaler=ActivationScaler(),
  model=current_model,
  eval_config=eval_config,
- ignore_tokens={
- current_model.tokenizer.pad_token_id, # type: ignore
- current_model.tokenizer.eos_token_id, # type: ignore
- current_model.tokenizer.bos_token_id, # type: ignore
- },
  verbose=verbose,
  )
  eval_metrics["metrics"] = scalar_metrics
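Annotation (not part of the diff): hedged sketch of the new `exclude_special_tokens` argument that replaces the old `ignore_tokens` set in `run_evals`. The `sae` and `activation_store` keyword names for the leading parameters are assumed from the surrounding calls:

```python
from sae_lens.evals import run_evals
from sae_lens.training.activation_scaler import ActivationScaler


def eval_without_special_tokens(sae, activation_store, model):
    # exclude_special_tokens=True (the default) derives the ignore list from the
    # model's tokenizer via get_special_token_ids; an explicit iterable of token
    # ids is also accepted, and False disables the filtering entirely.
    scalar_metrics, feature_metrics = run_evals(
        sae=sae,
        activation_store=activation_store,
        model=model,
        activation_scaler=ActivationScaler(),
        exclude_special_tokens=True,
    )
    return scalar_metrics, feature_metrics
```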
@@ -16,23 +16,18 @@ from typing_extensions import deprecated
  from sae_lens import logger
  from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
  from sae_lens.constants import (
- ACTIVATIONS_STORE_STATE_FILENAME,
  RUNNER_CFG_FILENAME,
  SPARSITY_FILENAME,
  )
  from sae_lens.evals import EvalConfig, run_evals
  from sae_lens.load_model import load_model
- from sae_lens.saes.batchtopk_sae import BatchTopKTrainingSAEConfig
- from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
- from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
+ from sae_lens.registry import SAE_TRAINING_CLASS_REGISTRY
  from sae_lens.saes.sae import (
  T_TRAINING_SAE,
  T_TRAINING_SAE_CONFIG,
  TrainingSAE,
  TrainingSAEConfig,
  )
- from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
- from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
  from sae_lens.training.activation_scaler import ActivationScaler
  from sae_lens.training.activations_store import ActivationsStore
  from sae_lens.training.sae_trainer import SAETrainer
@@ -61,9 +56,11 @@ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
  data_provider: DataProvider,
  activation_scaler: ActivationScaler,
  ) -> dict[str, Any]:
- ignore_tokens = set()
+ exclude_special_tokens = False
  if self.activations_store.exclude_special_tokens is not None:
- ignore_tokens = set(self.activations_store.exclude_special_tokens.tolist())
+ exclude_special_tokens = (
+ self.activations_store.exclude_special_tokens.tolist()
+ )

  eval_config = EvalConfig(
  batch_size_prompts=self.eval_batch_size_prompts,
@@ -81,7 +78,7 @@ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
  model=self.model,
  activation_scaler=activation_scaler,
  eval_config=eval_config,
- ignore_tokens=ignore_tokens,
+ exclude_special_tokens=exclude_special_tokens,
  model_kwargs=self.model_kwargs,
  ) # not calculating featurwise metrics here.

@@ -114,6 +111,7 @@ class LanguageModelSAETrainingRunner:
  override_dataset: HfDataset | None = None,
  override_model: HookedRootModule | None = None,
  override_sae: TrainingSAE[Any] | None = None,
+ resume_from_checkpoint: Path | str | None = None,
  ):
  if override_dataset is not None:
  logger.warning(
@@ -155,6 +153,7 @@ class LanguageModelSAETrainingRunner:
  )
  else:
  self.sae = override_sae
+
  self.sae.to(self.cfg.device)

  def run(self):
@@ -187,6 +186,12 @@ class LanguageModelSAETrainingRunner:
  cfg=self.cfg.to_sae_trainer_config(),
  )

+ if self.cfg.resume_from_checkpoint is not None:
+ logger.info(f"Resuming from checkpoint: {self.cfg.resume_from_checkpoint}")
+ trainer.load_trainer_state(self.cfg.resume_from_checkpoint)
+ self.sae.load_weights_from_checkpoint(self.cfg.resume_from_checkpoint)
+ self.activations_store.load_from_checkpoint(self.cfg.resume_from_checkpoint)
+
  self._compile_if_needed()
  sae = self.run_trainer_with_interruption_handling(trainer)

@@ -306,9 +311,7 @@ class LanguageModelSAETrainingRunner:
  if checkpoint_path is None:
  return

- self.activations_store.save(
- str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
- )
+ self.activations_store.save_to_checkpoint(checkpoint_path)

  runner_config = self.cfg.to_dict()
  with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
@@ -393,12 +396,8 @@ def _parse_cfg_args(
  )

  # Map architecture to concrete config class
- sae_config_map = {
- "standard": StandardTrainingSAEConfig,
- "gated": GatedTrainingSAEConfig,
- "jumprelu": JumpReLUTrainingSAEConfig,
- "topk": TopKTrainingSAEConfig,
- "batchtopk": BatchTopKTrainingSAEConfig,
+ sae_config_map: dict[str, type[TrainingSAEConfig]] = {
+ name: cfg for name, (_, cfg) in SAE_TRAINING_CLASS_REGISTRY.items()
  }

  sae_config_type = sae_config_map[architecture]
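Annotation (not part of the diff): hedged sketch of the resume flow added here. The checkpoint path is a hypothetical placeholder, and the top-level export of `LanguageModelSAETrainingRunner` is assumed:

```python
from sae_lens import LanguageModelSAETrainingRunner  # top-level export assumed
from sae_lens.config import LanguageModelSAERunnerConfig


def resume_training(cfg: LanguageModelSAERunnerConfig) -> None:
    # `cfg` is assumed to match the config of the interrupted run; only the new
    # field is shown, and the path below is a placeholder.
    cfg.resume_from_checkpoint = "checkpoints/<run_id>/<step>"
    # run() reloads trainer state, SAE weights, and the activations-store state
    # from that directory before continuing training.
    LanguageModelSAETrainingRunner(cfg).run()
```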
@@ -523,6 +523,82 @@ def gemma_2_sae_huggingface_loader(
  return cfg_dict, state_dict, log_sparsity


+ def get_goodfire_config_from_hf(
+ repo_id: str,
+ folder_name: str, # noqa: ARG001
+ device: str,
+ force_download: bool = False, # noqa: ARG001
+ cfg_overrides: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ cfg_dict = None
+ if repo_id == "Goodfire/Llama-3.3-70B-Instruct-SAE-l50":
+ if folder_name != "Llama-3.3-70B-Instruct-SAE-l50.pt":
+ raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+ cfg_dict = {
+ "architecture": "standard",
+ "d_in": 8192,
+ "d_sae": 65536,
+ "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+ "hook_name": "blocks.50.hook_resid_post",
+ "hook_head_index": None,
+ "dataset_path": "lmsys/lmsys-chat-1m",
+ "apply_b_dec_to_input": False,
+ }
+ elif repo_id == "Goodfire/Llama-3.1-8B-Instruct-SAE-l19":
+ if folder_name != "Llama-3.1-8B-Instruct-SAE-l19.pth":
+ raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+ cfg_dict = {
+ "architecture": "standard",
+ "d_in": 4096,
+ "d_sae": 65536,
+ "model_name": "meta-llama/Llama-3.1-8B-Instruct",
+ "hook_name": "blocks.19.hook_resid_post",
+ "hook_head_index": None,
+ "dataset_path": "lmsys/lmsys-chat-1m",
+ "apply_b_dec_to_input": False,
+ }
+ if cfg_dict is None:
+ raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+ if device is not None:
+ cfg_dict["device"] = device
+ if cfg_overrides is not None:
+ cfg_dict.update(cfg_overrides)
+ return cfg_dict
+
+
+ def get_goodfire_huggingface_loader(
+ repo_id: str,
+ folder_name: str,
+ device: str = "cpu",
+ force_download: bool = False,
+ cfg_overrides: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+ cfg_dict = get_goodfire_config_from_hf(
+ repo_id,
+ folder_name,
+ device,
+ force_download,
+ cfg_overrides,
+ )
+
+ # Download the SAE weights
+ sae_path = hf_hub_download(
+ repo_id=repo_id,
+ filename=folder_name,
+ force_download=force_download,
+ )
+ raw_state_dict = torch.load(sae_path, map_location=device)
+
+ state_dict = {
+ "W_enc": raw_state_dict["encoder_linear.weight"].T,
+ "W_dec": raw_state_dict["decoder_linear.weight"].T,
+ "b_enc": raw_state_dict["encoder_linear.bias"],
+ "b_dec": raw_state_dict["decoder_linear.bias"],
+ }
+
+ return cfg_dict, state_dict, None
+
+
  def get_llama_scope_config_from_hf(
  repo_id: str,
  folder_name: str,
@@ -1475,6 +1551,114 @@ def get_mntss_clt_layer_config_from_hf(
  }


+ def get_temporal_sae_config_from_hf(
+ repo_id: str,
+ folder_name: str,
+ device: str,
+ force_download: bool = False,
+ cfg_overrides: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ """Get TemporalSAE config without loading weights."""
+ # Download config file
+ conf_path = hf_hub_download(
+ repo_id=repo_id,
+ filename=f"{folder_name}/conf.yaml",
+ force_download=force_download,
+ )
+
+ # Load and parse config
+ with open(conf_path) as f:
+ yaml_config = yaml.safe_load(f)
+
+ # Extract parameters
+ d_in = yaml_config["llm"]["dimin"]
+ exp_factor = yaml_config["sae"]["exp_factor"]
+ d_sae = int(d_in * exp_factor)
+
+ # extract layer from folder_name eg : "layer_12/temporal"
+ layer = re.search(r"layer_(\d+)", folder_name)
+ if layer is None:
+ raise ValueError(f"Could not find layer in folder_name: {folder_name}")
+ layer = int(layer.group(1))
+
+ # Build config dict
+ cfg_dict = {
+ "architecture": "temporal",
+ "hook_name": f"blocks.{layer}.hook_resid_post",
+ "d_in": d_in,
+ "d_sae": d_sae,
+ "n_heads": yaml_config["sae"]["n_heads"],
+ "n_attn_layers": yaml_config["sae"]["n_attn_layers"],
+ "bottleneck_factor": yaml_config["sae"]["bottleneck_factor"],
+ "sae_diff_type": yaml_config["sae"]["sae_diff_type"],
+ "kval_topk": yaml_config["sae"]["kval_topk"],
+ "tied_weights": yaml_config["sae"]["tied_weights"],
+ "dtype": yaml_config["data"]["dtype"],
+ "device": device,
+ "normalize_activations": "constant_scalar_rescale",
+ "activation_normalization_factor": yaml_config["sae"]["scaling_factor"],
+ "apply_b_dec_to_input": True,
+ }
+
+ if cfg_overrides:
+ cfg_dict.update(cfg_overrides)
+
+ return cfg_dict
+
+
+ def temporal_sae_huggingface_loader(
+ repo_id: str,
+ folder_name: str,
+ device: str = "cpu",
+ force_download: bool = False,
+ cfg_overrides: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+ """
+ Load TemporalSAE from canrager/temporalSAEs format (safetensors version).
+
+ Expects folder_name to contain:
+ - conf.yaml (configuration)
+ - latest_ckpt.safetensors (model weights)
+ """
+
+ cfg_dict = get_temporal_sae_config_from_hf(
+ repo_id=repo_id,
+ folder_name=folder_name,
+ device=device,
+ force_download=force_download,
+ cfg_overrides=cfg_overrides,
+ )
+
+ # Download checkpoint (safetensors format)
+ ckpt_path = hf_hub_download(
+ repo_id=repo_id,
+ filename=f"{folder_name}/latest_ckpt.safetensors",
+ force_download=force_download,
+ )
+
+ # Load checkpoint from safetensors
+ state_dict_raw = load_file(ckpt_path, device=device)
+
+ # Convert to SAELens naming convention
+ # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
+ state_dict = {}
+
+ # Copy attention layers as-is
+ for key, value in state_dict_raw.items():
+ if key.startswith("attn_layers."):
+ state_dict[key] = value.to(device)
+
+ # Main parameters
+ state_dict["W_dec"] = state_dict_raw["D"].to(device)
+ state_dict["b_dec"] = state_dict_raw["b"].to(device)
+
+ # Handle tied/untied weights
+ if "E" in state_dict_raw:
+ state_dict["W_enc"] = state_dict_raw["E"].to(device)
+
+ return cfg_dict, state_dict, None
+
+
  NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
  "sae_lens": sae_lens_huggingface_loader,
  "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1487,6 +1671,8 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
  "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
  "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
  "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
+ "temporal": temporal_sae_huggingface_loader,
+ "goodfire": get_goodfire_huggingface_loader,
  }


@@ -1502,4 +1688,6 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
  "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
  "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
  "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
+ "temporal": get_temporal_sae_config_from_hf,
+ "goodfire": get_goodfire_config_from_hf,
  }
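Annotation (not part of the diff): both new entries follow the existing `PretrainedSaeHuggingfaceLoader` protocol and return `(cfg_dict, state_dict, log_sparsity)`. A hedged sketch; the module path in the import is an assumption, since this file's header is not shown above:

```python
from sae_lens.loading.pretrained_sae_loaders import (  # module path assumed
    get_goodfire_huggingface_loader,
)

# Downloads the Goodfire checkpoint and remaps it to SAELens parameter names.
cfg_dict, state_dict, log_sparsity = get_goodfire_huggingface_loader(
    repo_id="Goodfire/Llama-3.1-8B-Instruct-SAE-l19",
    folder_name="Llama-3.1-8B-Instruct-SAE-l19.pth",
    device="cpu",
)
assert cfg_dict["architecture"] == "standard" and log_sparsity is None
```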
@@ -1,6 +1,6 @@
  from dataclasses import dataclass
  from functools import cache
- from importlib import resources
+ from importlib.resources import files
  from typing import Any

  import yaml
@@ -24,7 +24,8 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
  package = "sae_lens"
  # Access the file within the package using importlib.resources
  directory: dict[str, PretrainedSAELookup] = {}
- with resources.open_text(package, "pretrained_saes.yaml") as file:
+ yaml_file = files(package).joinpath("pretrained_saes.yaml")
+ with yaml_file.open("r") as file:
  # Load the YAML file content
  data = yaml.safe_load(file)
  for release, value in data.items():
@@ -68,7 +69,8 @@ def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
  float | None: The norm_scaling_factor if it exists, None otherwise.
  """
  package = "sae_lens"
- with resources.open_text(package, "pretrained_saes.yaml") as file:
+ yaml_file = files(package).joinpath("pretrained_saes.yaml")
+ with yaml_file.open("r") as file:
  data = yaml.safe_load(file)
  if release in data:
  for sae_info in data[release]["saes"]: