sae-lens 6.12.1__py3-none-any.whl → 6.21.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.12.1"
+ __version__ = "6.21.0"
 
  import logging
 
@@ -19,6 +19,8 @@ from sae_lens.saes import (
  JumpReLUTrainingSAEConfig,
  JumpReLUTranscoder,
  JumpReLUTranscoderConfig,
+ MatryoshkaBatchTopKTrainingSAE,
+ MatryoshkaBatchTopKTrainingSAEConfig,
  SAEConfig,
  SkipTranscoder,
  SkipTranscoderConfig,
@@ -26,6 +28,8 @@ from sae_lens.saes import (
  StandardSAEConfig,
  StandardTrainingSAE,
  StandardTrainingSAEConfig,
+ TemporalSAE,
+ TemporalSAEConfig,
  TopKSAE,
  TopKSAEConfig,
  TopKTrainingSAE,
@@ -101,6 +105,10 @@ __all__ = [
  "SkipTranscoderConfig",
  "JumpReLUTranscoder",
  "JumpReLUTranscoderConfig",
+ "MatryoshkaBatchTopKTrainingSAE",
+ "MatryoshkaBatchTopKTrainingSAEConfig",
+ "TemporalSAE",
+ "TemporalSAEConfig",
  ]
 
 
@@ -115,6 +123,12 @@ register_sae_training_class("jumprelu", JumpReLUTrainingSAE, JumpReLUTrainingSAE
  register_sae_training_class(
  "batchtopk", BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
  )
+ register_sae_training_class(
+ "matryoshka_batchtopk",
+ MatryoshkaBatchTopKTrainingSAE,
+ MatryoshkaBatchTopKTrainingSAEConfig,
+ )
  register_sae_class("transcoder", Transcoder, TranscoderConfig)
  register_sae_class("skip_transcoder", SkipTranscoder, SkipTranscoderConfig)
  register_sae_class("jumprelu_transcoder", JumpReLUTranscoder, JumpReLUTranscoderConfig)
+ register_sae_class("temporal", TemporalSAE, TemporalSAEConfig)
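The newly registered architectures are exported from the package root and can be looked up through the training registry. A minimal sketch; the (SAE class, config class) tuple layout of a registry entry is inferred from its use in `LanguageModelSAERunnerConfig.from_dict` and `_parse_cfg_args` further down in this diff:

```python
# Sketch only: assumes a registry entry is a (sae_class, config_class) pair,
# as suggested by SAE_TRAINING_CLASS_REGISTRY usage elsewhere in this release.
from sae_lens import (
    MatryoshkaBatchTopKTrainingSAE,
    MatryoshkaBatchTopKTrainingSAEConfig,
)
from sae_lens.registry import get_sae_training_class

sae_cls, cfg_cls = get_sae_training_class("matryoshka_batchtopk")
assert sae_cls is MatryoshkaBatchTopKTrainingSAE
assert cfg_cls is MatryoshkaBatchTopKTrainingSAEConfig
```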
@@ -10,7 +10,7 @@ from datasets import Array2D, Dataset, Features, Sequence, Value
  from datasets.fingerprint import generate_fingerprint
  from huggingface_hub import HfApi
  from jaxtyping import Float, Int
- from tqdm import tqdm
+ from tqdm.auto import tqdm
  from transformer_lens.HookedTransformer import HookedRootModule
 
  from sae_lens import logger
sae_lens/config.py CHANGED
@@ -18,6 +18,7 @@ from datasets import (
 
  from sae_lens import __version__, logger
  from sae_lens.constants import DTYPE_MAP
+ from sae_lens.registry import get_sae_training_class
  from sae_lens.saes.sae import TrainingSAEConfig
 
  if TYPE_CHECKING:
@@ -171,6 +172,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
  checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
  save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
+ resume_from_checkpoint (str | None): The path to the checkpoint to resume training from. (default is None).
  output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
  verbose (bool): Whether to print verbose output. (default is True)
  model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
@@ -261,6 +263,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  checkpoint_path: str | None = "checkpoints"
  save_final_checkpoint: bool = False
  output_path: str | None = "output"
+ resume_from_checkpoint: str | None = None
 
  # Misc
  verbose: bool = True
@@ -385,8 +388,11 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  return self.sae.to_dict()
 
  def to_dict(self) -> dict[str, Any]:
- # Make a shallow copy of config's dictionary
- d = dict(self.__dict__)
+ """
+ Convert the config to a dictionary.
+ """
+
+ d = asdict(self)
 
  d["logger"] = asdict(self.logger)
  d["sae"] = self.sae.to_dict()
@@ -396,6 +402,37 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  d["act_store_device"] = str(self.act_store_device)
  return d
 
+ @classmethod
+ def from_dict(cls, cfg_dict: dict[str, Any]) -> "LanguageModelSAERunnerConfig[Any]":
+ """
+ Load a LanguageModelSAERunnerConfig from a dictionary given by `to_dict`.
+
+ Args:
+ cfg_dict (dict[str, Any]): The dictionary to load the config from.
+
+ Returns:
+ LanguageModelSAERunnerConfig: The loaded config.
+ """
+ if "sae" not in cfg_dict:
+ raise ValueError("sae field is required in the config dictionary")
+ if "architecture" not in cfg_dict["sae"]:
+ raise ValueError("architecture field is required in the sae dictionary")
+ if "logger" not in cfg_dict:
+ raise ValueError("logger field is required in the config dictionary")
+ sae_config_class = get_sae_training_class(cfg_dict["sae"]["architecture"])[1]
+ sae_cfg = sae_config_class.from_dict(cfg_dict["sae"])
+ logger_cfg = LoggingConfig(**cfg_dict["logger"])
+ updated_cfg_dict: dict[str, Any] = {
+ **cfg_dict,
+ "sae": sae_cfg,
+ "logger": logger_cfg,
+ }
+ output = cls(**updated_cfg_dict)
+ # the post_init always appends to checkpoint path, so we need to set it explicitly here.
+ if "checkpoint_path" in cfg_dict:
+ output.checkpoint_path = cfg_dict["checkpoint_path"]
+ return output
+
  def to_sae_trainer_config(self) -> "SAETrainerConfig":
  return SAETrainerConfig(
  n_checkpoints=self.n_checkpoints,
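The new `from_dict` classmethod pairs with `to_dict` and the `runner_cfg.json` written at each checkpoint (see the runner changes further down). A minimal sketch; the checkpoint directory path is illustrative:

```python
import json

from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.constants import RUNNER_CFG_FILENAME  # "runner_cfg.json"

# Hypothetical checkpoint directory written by an earlier training run.
checkpoint_dir = "checkpoints/my_run/final"

with open(f"{checkpoint_dir}/{RUNNER_CFG_FILENAME}") as f:
    cfg_dict = json.load(f)

# Rebuilds the typed config, resolving the architecture-specific SAE config
# via the training registry and reconstructing the LoggingConfig.
cfg = LanguageModelSAERunnerConfig.from_dict(cfg_dict)
print(type(cfg.sae).__name__, cfg.checkpoint_path)
```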
sae_lens/constants.py CHANGED
@@ -17,5 +17,6 @@ SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
  SAE_CFG_FILENAME = "cfg.json"
  RUNNER_CFG_FILENAME = "runner_cfg.json"
  SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+ TRAINER_STATE_FILENAME = "trainer_state.pt"
  ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
  ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
sae_lens/evals.py CHANGED
@@ -11,7 +11,7 @@ from dataclasses import dataclass, field
  from functools import partial
  from importlib.metadata import PackageNotFoundError, version
  from pathlib import Path
- from typing import Any
+ from typing import Any, Iterable
 
  import einops
  import pandas as pd
@@ -24,7 +24,10 @@ from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_direc
  from sae_lens.saes.sae import SAE, SAEConfig
  from sae_lens.training.activation_scaler import ActivationScaler
  from sae_lens.training.activations_store import ActivationsStore
- from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
+ from sae_lens.util import (
+ extract_stop_at_layer_from_tlens_hook_name,
+ get_special_token_ids,
+ )
 
 
  def get_library_version() -> str:
@@ -109,9 +112,15 @@ def run_evals(
  activation_scaler: ActivationScaler,
  eval_config: EvalConfig = EvalConfig(),
  model_kwargs: Mapping[str, Any] = {},
- ignore_tokens: set[int | None] = set(),
+ exclude_special_tokens: Iterable[int] | bool = True,
  verbose: bool = False,
  ) -> tuple[dict[str, Any], dict[str, Any]]:
+ ignore_tokens = None
+ if exclude_special_tokens is True:
+ ignore_tokens = list(get_special_token_ids(model.tokenizer)) # type: ignore
+ elif exclude_special_tokens:
+ ignore_tokens = list(exclude_special_tokens)
+
  hook_name = sae.cfg.metadata.hook_name
  actual_batch_size = (
  eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
@@ -312,7 +321,7 @@ def get_downstream_reconstruction_metrics(
  compute_ce_loss: bool,
  n_batches: int,
  eval_batch_size_prompts: int,
- ignore_tokens: set[int | None] = set(),
+ ignore_tokens: list[int] | None = None,
  verbose: bool = False,
  ):
  metrics_dict = {}
@@ -339,7 +348,7 @@ def get_downstream_reconstruction_metrics(
  compute_ce_loss=compute_ce_loss,
  ignore_tokens=ignore_tokens,
  ).items():
- if len(ignore_tokens) > 0:
+ if ignore_tokens:
  mask = torch.logical_not(
  torch.any(
  torch.stack(
@@ -384,7 +393,7 @@ def get_sparsity_and_variance_metrics(
  compute_featurewise_density_statistics: bool,
  eval_batch_size_prompts: int,
  model_kwargs: Mapping[str, Any],
- ignore_tokens: set[int | None] = set(),
+ ignore_tokens: list[int] | None = None,
  verbose: bool = False,
  ) -> tuple[dict[str, Any], dict[str, Any]]:
  hook_name = sae.cfg.metadata.hook_name
@@ -426,7 +435,7 @@ def get_sparsity_and_variance_metrics(
  for _ in batch_iter:
  batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
 
- if len(ignore_tokens) > 0:
+ if ignore_tokens:
  mask = torch.logical_not(
  torch.any(
  torch.stack(
@@ -466,6 +475,8 @@ def get_sparsity_and_variance_metrics(
  sae_out_scaled = sae.decode(sae_feature_activations).to(
  original_act_scaled.device
  )
+ if sae_feature_activations.is_sparse:
+ sae_feature_activations = sae_feature_activations.to_dense()
  del cache
 
  sae_out = activation_scaler.unscale(sae_out_scaled)
@@ -594,7 +605,7 @@ def get_recons_loss(
  batch_tokens: torch.Tensor,
  compute_kl: bool,
  compute_ce_loss: bool,
- ignore_tokens: set[int | None] = set(),
+ ignore_tokens: list[int] | None = None,
  model_kwargs: Mapping[str, Any] = {},
  hook_name: str | None = None,
  ) -> dict[str, Any]:
@@ -608,7 +619,7 @@ def get_recons_loss(
  batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
  )
 
- if len(ignore_tokens) > 0:
+ if ignore_tokens:
  mask = torch.logical_not(
  torch.any(
  torch.stack([batch_tokens == token for token in ignore_tokens], dim=0),
@@ -854,11 +865,6 @@ def multiple_evals(
  activation_scaler=ActivationScaler(),
  model=current_model,
  eval_config=eval_config,
- ignore_tokens={
- current_model.tokenizer.pad_token_id, # type: ignore
- current_model.tokenizer.eos_token_id, # type: ignore
- current_model.tokenizer.bos_token_id, # type: ignore
- },
  verbose=verbose,
  )
  eval_metrics["metrics"] = scalar_metrics
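A minimal sketch of the reworked `run_evals` entry point: `exclude_special_tokens=True` (the default) derives the ignored token ids from the model tokenizer via `get_special_token_ids`, an explicit iterable overrides that set, and `False` disables the filtering. The `sae`, `activation_store`, and `model` objects are assumed to be set up as usual:

```python
from sae_lens.evals import EvalConfig, run_evals
from sae_lens.training.activation_scaler import ActivationScaler

# Sketch only: sae, activation_store, and model are assumed to exist already.
metrics, feature_metrics = run_evals(
    sae=sae,
    activation_store=activation_store,
    model=model,
    activation_scaler=ActivationScaler(),
    eval_config=EvalConfig(),
    # Rely on the default (True), pass an explicit list of token ids,
    # or pass False to keep special tokens in the metrics.
    exclude_special_tokens=[model.tokenizer.bos_token_id, model.tokenizer.eos_token_id],
)
```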
@@ -16,23 +16,18 @@ from typing_extensions import deprecated
  from sae_lens import logger
  from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
  from sae_lens.constants import (
- ACTIVATIONS_STORE_STATE_FILENAME,
  RUNNER_CFG_FILENAME,
  SPARSITY_FILENAME,
  )
  from sae_lens.evals import EvalConfig, run_evals
  from sae_lens.load_model import load_model
- from sae_lens.saes.batchtopk_sae import BatchTopKTrainingSAEConfig
- from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
- from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
+ from sae_lens.registry import SAE_TRAINING_CLASS_REGISTRY
  from sae_lens.saes.sae import (
  T_TRAINING_SAE,
  T_TRAINING_SAE_CONFIG,
  TrainingSAE,
  TrainingSAEConfig,
  )
- from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
- from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
  from sae_lens.training.activation_scaler import ActivationScaler
  from sae_lens.training.activations_store import ActivationsStore
  from sae_lens.training.sae_trainer import SAETrainer
@@ -61,9 +56,11 @@ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
  data_provider: DataProvider,
  activation_scaler: ActivationScaler,
  ) -> dict[str, Any]:
- ignore_tokens = set()
+ exclude_special_tokens = False
  if self.activations_store.exclude_special_tokens is not None:
- ignore_tokens = set(self.activations_store.exclude_special_tokens.tolist())
+ exclude_special_tokens = (
+ self.activations_store.exclude_special_tokens.tolist()
+ )
 
  eval_config = EvalConfig(
  batch_size_prompts=self.eval_batch_size_prompts,
@@ -81,7 +78,7 @@ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
  model=self.model,
  activation_scaler=activation_scaler,
  eval_config=eval_config,
- ignore_tokens=ignore_tokens,
+ exclude_special_tokens=exclude_special_tokens,
  model_kwargs=self.model_kwargs,
  ) # not calculating featurwise metrics here.
 
@@ -114,6 +111,7 @@ class LanguageModelSAETrainingRunner:
  override_dataset: HfDataset | None = None,
  override_model: HookedRootModule | None = None,
  override_sae: TrainingSAE[Any] | None = None,
+ resume_from_checkpoint: Path | str | None = None,
  ):
  if override_dataset is not None:
  logger.warning(
@@ -155,6 +153,7 @@ class LanguageModelSAETrainingRunner:
  )
  else:
  self.sae = override_sae
+
  self.sae.to(self.cfg.device)
 
  def run(self):
@@ -187,6 +186,12 @@ class LanguageModelSAETrainingRunner:
  cfg=self.cfg.to_sae_trainer_config(),
  )
 
+ if self.cfg.resume_from_checkpoint is not None:
+ logger.info(f"Resuming from checkpoint: {self.cfg.resume_from_checkpoint}")
+ trainer.load_trainer_state(self.cfg.resume_from_checkpoint)
+ self.sae.load_weights_from_checkpoint(self.cfg.resume_from_checkpoint)
+ self.activations_store.load_from_checkpoint(self.cfg.resume_from_checkpoint)
+
  self._compile_if_needed()
  sae = self.run_trainer_with_interruption_handling(trainer)
 
@@ -306,9 +311,7 @@ class LanguageModelSAETrainingRunner:
  if checkpoint_path is None:
  return
 
- self.activations_store.save(
- str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
- )
+ self.activations_store.save_to_checkpoint(checkpoint_path)
 
  runner_config = self.cfg.to_dict()
  with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
@@ -393,12 +396,8 @@ def _parse_cfg_args(
  )
 
  # Map architecture to concrete config class
- sae_config_map = {
- "standard": StandardTrainingSAEConfig,
- "gated": GatedTrainingSAEConfig,
- "jumprelu": JumpReLUTrainingSAEConfig,
- "topk": TopKTrainingSAEConfig,
- "batchtopk": BatchTopKTrainingSAEConfig,
+ sae_config_map: dict[str, type[TrainingSAEConfig]] = {
+ name: cfg for name, (_, cfg) in SAE_TRAINING_CLASS_REGISTRY.items()
  }
 
  sae_config_type = sae_config_map[architecture]
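The new `resume_from_checkpoint` config field wires the trainer state, SAE weights, and activations-store state loaders together at the start of `run()`. A minimal sketch, assuming the runner class is exported from the package root as in earlier releases and that `cfg` is an existing `LanguageModelSAERunnerConfig`; the checkpoint path is illustrative:

```python
from sae_lens import LanguageModelSAETrainingRunner

# Point the runner at a checkpoint directory written by an earlier run
# (hypothetical path). Trainer state, SAE weights, and the activations-store
# state are restored before training continues.
cfg.resume_from_checkpoint = "checkpoints/my_run/12345"
runner = LanguageModelSAETrainingRunner(cfg)
sae = runner.run()
```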
@@ -233,6 +233,12 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
  "reshape_activations",
  "hook_z" if "hook_z" in new_cfg.get("hook_name", "") else "none",
  )
+ if (
+ new_cfg.get("activation_fn") == "topk"
+ and new_cfg.get("activation_fn_kwargs", {}).get("k") is not None
+ ):
+ new_cfg["architecture"] = "topk"
+ new_cfg["k"] = new_cfg["activation_fn_kwargs"]["k"]
 
  if "normalize_activations" in new_cfg and isinstance(
  new_cfg["normalize_activations"], bool
@@ -517,6 +523,82 @@ gemma_2_sae_huggingface_loader(
  return cfg_dict, state_dict, log_sparsity
 
 
+ def get_goodfire_config_from_hf(
+ repo_id: str,
+ folder_name: str, # noqa: ARG001
+ device: str,
+ force_download: bool = False, # noqa: ARG001
+ cfg_overrides: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ cfg_dict = None
+ if repo_id == "Goodfire/Llama-3.3-70B-Instruct-SAE-l50":
+ if folder_name != "Llama-3.3-70B-Instruct-SAE-l50.pt":
+ raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+ cfg_dict = {
+ "architecture": "standard",
+ "d_in": 8192,
+ "d_sae": 65536,
+ "model_name": "meta-llama/Llama-3.3-70B-Instruct",
+ "hook_name": "blocks.50.hook_resid_post",
+ "hook_head_index": None,
+ "dataset_path": "lmsys/lmsys-chat-1m",
+ "apply_b_dec_to_input": False,
+ }
+ elif repo_id == "Goodfire/Llama-3.1-8B-Instruct-SAE-l19":
+ if folder_name != "Llama-3.1-8B-Instruct-SAE-l19.pth":
+ raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+ cfg_dict = {
+ "architecture": "standard",
+ "d_in": 4096,
+ "d_sae": 65536,
+ "model_name": "meta-llama/Llama-3.1-8B-Instruct",
+ "hook_name": "blocks.19.hook_resid_post",
+ "hook_head_index": None,
+ "dataset_path": "lmsys/lmsys-chat-1m",
+ "apply_b_dec_to_input": False,
+ }
+ if cfg_dict is None:
+ raise ValueError(f"Unsupported Goodfire SAE: {repo_id}/{folder_name}")
+ if device is not None:
+ cfg_dict["device"] = device
+ if cfg_overrides is not None:
+ cfg_dict.update(cfg_overrides)
+ return cfg_dict
+
+
+ def get_goodfire_huggingface_loader(
+ repo_id: str,
+ folder_name: str,
+ device: str = "cpu",
+ force_download: bool = False,
+ cfg_overrides: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+ cfg_dict = get_goodfire_config_from_hf(
+ repo_id,
+ folder_name,
+ device,
+ force_download,
+ cfg_overrides,
+ )
+
+ # Download the SAE weights
+ sae_path = hf_hub_download(
+ repo_id=repo_id,
+ filename=folder_name,
+ force_download=force_download,
+ )
+ raw_state_dict = torch.load(sae_path, map_location=device)
+
+ state_dict = {
+ "W_enc": raw_state_dict["encoder_linear.weight"].T,
+ "W_dec": raw_state_dict["decoder_linear.weight"].T,
+ "b_enc": raw_state_dict["encoder_linear.bias"],
+ "b_dec": raw_state_dict["decoder_linear.bias"],
+ }
+
+ return cfg_dict, state_dict, None
+
+
  def get_llama_scope_config_from_hf(
  repo_id: str,
  folder_name: str,
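A minimal usage sketch of the new Goodfire loader, called directly with the repo id and filename handled above. The import path is an assumption, since the diff omits this file's header; in recent releases these loaders live in `sae_lens.loading.pretrained_sae_loaders`:

```python
# Sketch only: module path assumed; repo id and filename come from the loader above.
from sae_lens.loading.pretrained_sae_loaders import get_goodfire_huggingface_loader

cfg_dict, state_dict, _ = get_goodfire_huggingface_loader(
    repo_id="Goodfire/Llama-3.1-8B-Instruct-SAE-l19",
    folder_name="Llama-3.1-8B-Instruct-SAE-l19.pth",
    device="cpu",
)
# Expected from the config above: "standard" and torch.Size([4096, 65536]).
print(cfg_dict["architecture"], state_dict["W_enc"].shape)
```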
@@ -1469,6 +1551,114 @@ get_mntss_clt_layer_config_from_hf(
  }
 
 
+ def get_temporal_sae_config_from_hf(
+ repo_id: str,
+ folder_name: str,
+ device: str,
+ force_download: bool = False,
+ cfg_overrides: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ """Get TemporalSAE config without loading weights."""
+ # Download config file
+ conf_path = hf_hub_download(
+ repo_id=repo_id,
+ filename=f"{folder_name}/conf.yaml",
+ force_download=force_download,
+ )
+
+ # Load and parse config
+ with open(conf_path) as f:
+ yaml_config = yaml.safe_load(f)
+
+ # Extract parameters
+ d_in = yaml_config["llm"]["dimin"]
+ exp_factor = yaml_config["sae"]["exp_factor"]
+ d_sae = int(d_in * exp_factor)
+
+ # extract layer from folder_name eg : "layer_12/temporal"
+ layer = re.search(r"layer_(\d+)", folder_name)
+ if layer is None:
+ raise ValueError(f"Could not find layer in folder_name: {folder_name}")
+ layer = int(layer.group(1))
+
+ # Build config dict
+ cfg_dict = {
+ "architecture": "temporal",
+ "hook_name": f"blocks.{layer}.hook_resid_post",
+ "d_in": d_in,
+ "d_sae": d_sae,
+ "n_heads": yaml_config["sae"]["n_heads"],
+ "n_attn_layers": yaml_config["sae"]["n_attn_layers"],
+ "bottleneck_factor": yaml_config["sae"]["bottleneck_factor"],
+ "sae_diff_type": yaml_config["sae"]["sae_diff_type"],
+ "kval_topk": yaml_config["sae"]["kval_topk"],
+ "tied_weights": yaml_config["sae"]["tied_weights"],
+ "dtype": yaml_config["data"]["dtype"],
+ "device": device,
+ "normalize_activations": "constant_scalar_rescale",
+ "activation_normalization_factor": yaml_config["sae"]["scaling_factor"],
+ "apply_b_dec_to_input": True,
+ }
+
+ if cfg_overrides:
+ cfg_dict.update(cfg_overrides)
+
+ return cfg_dict
+
+
+ def temporal_sae_huggingface_loader(
+ repo_id: str,
+ folder_name: str,
+ device: str = "cpu",
+ force_download: bool = False,
+ cfg_overrides: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor], torch.Tensor | None]:
+ """
+ Load TemporalSAE from canrager/temporalSAEs format (safetensors version).
+
+ Expects folder_name to contain:
+ - conf.yaml (configuration)
+ - latest_ckpt.safetensors (model weights)
+ """
+
+ cfg_dict = get_temporal_sae_config_from_hf(
+ repo_id=repo_id,
+ folder_name=folder_name,
+ device=device,
+ force_download=force_download,
+ cfg_overrides=cfg_overrides,
+ )
+
+ # Download checkpoint (safetensors format)
+ ckpt_path = hf_hub_download(
+ repo_id=repo_id,
+ filename=f"{folder_name}/latest_ckpt.safetensors",
+ force_download=force_download,
+ )
+
+ # Load checkpoint from safetensors
+ state_dict_raw = load_file(ckpt_path, device=device)
+
+ # Convert to SAELens naming convention
+ # TemporalSAE uses: D (decoder), E (encoder), b (bias), attn_layers.*
+ state_dict = {}
+
+ # Copy attention layers as-is
+ for key, value in state_dict_raw.items():
+ if key.startswith("attn_layers."):
+ state_dict[key] = value.to(device)
+
+ # Main parameters
+ state_dict["W_dec"] = state_dict_raw["D"].to(device)
+ state_dict["b_dec"] = state_dict_raw["b"].to(device)
+
+ # Handle tied/untied weights
+ if "E" in state_dict_raw:
+ state_dict["W_enc"] = state_dict_raw["E"].to(device)
+
+ return cfg_dict, state_dict, None
+
+
  NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
  "sae_lens": sae_lens_huggingface_loader,
  "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -1481,6 +1671,8 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
  "gemma_2_transcoder": gemma_2_transcoder_huggingface_loader,
  "mwhanna_transcoder": mwhanna_transcoder_huggingface_loader,
  "mntss_clt_layer_transcoder": mntss_clt_layer_huggingface_loader,
+ "temporal": temporal_sae_huggingface_loader,
+ "goodfire": get_goodfire_huggingface_loader,
  }
 
 
@@ -1496,4 +1688,6 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
  "gemma_2_transcoder": get_gemma_2_transcoder_config_from_hf,
  "mwhanna_transcoder": get_mwhanna_transcoder_config_from_hf,
  "mntss_clt_layer_transcoder": get_mntss_clt_layer_config_from_hf,
+ "temporal": get_temporal_sae_config_from_hf,
+ "goodfire": get_goodfire_config_from_hf,
  }
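Likewise, a minimal sketch for the temporal loader. The `canrager/temporalSAEs` repo and the `layer_12/temporal` folder layout come from the docstring and comments above; the module path and the exact layout on the Hub are assumptions:

```python
# Sketch only: module path and folder layout on the Hub are assumptions.
from sae_lens.loading.pretrained_sae_loaders import temporal_sae_huggingface_loader

cfg_dict, state_dict, _ = temporal_sae_huggingface_loader(
    repo_id="canrager/temporalSAEs",
    folder_name="layer_12/temporal",  # must contain conf.yaml and latest_ckpt.safetensors
    device="cpu",
)
# Expected from the config getter: "temporal" and "blocks.12.hook_resid_post".
print(cfg_dict["architecture"], cfg_dict["hook_name"])
```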
@@ -1,6 +1,6 @@
  from dataclasses import dataclass
  from functools import cache
- from importlib import resources
+ from importlib.resources import files
  from typing import Any
 
  import yaml
@@ -24,7 +24,8 @@ def get_pretrained_saes_directory() -> dict[str, PretrainedSAELookup]:
  package = "sae_lens"
  # Access the file within the package using importlib.resources
  directory: dict[str, PretrainedSAELookup] = {}
- with resources.open_text(package, "pretrained_saes.yaml") as file:
+ yaml_file = files(package).joinpath("pretrained_saes.yaml")
+ with yaml_file.open("r") as file:
  # Load the YAML file content
  data = yaml.safe_load(file)
  for release, value in data.items():
@@ -68,7 +69,8 @@ def get_norm_scaling_factor(release: str, sae_id: str) -> float | None:
  float | None: The norm_scaling_factor if it exists, None otherwise.
  """
  package = "sae_lens"
- with resources.open_text(package, "pretrained_saes.yaml") as file:
+ yaml_file = files(package).joinpath("pretrained_saes.yaml")
+ with yaml_file.open("r") as file:
  data = yaml.safe_load(file)
  if release in data:
  for sae_info in data[release]["saes"]:
@@ -1,9 +1,10 @@
  import io
  import json
  import sys
+ from collections.abc import Iterator
  from dataclasses import dataclass
  from pathlib import Path
- from typing import Iterator, Literal, cast
+ from typing import Literal, cast
 
  import torch
  from datasets import Dataset, DatasetDict, load_dataset