sae-lens 5.10.3__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
sae_lens/evals.py CHANGED
@@ -19,8 +19,8 @@ from tqdm import tqdm
 from transformer_lens import HookedTransformer
 from transformer_lens.hook_points import HookedRootModule
 
-from sae_lens.sae import SAE
-from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
+from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
+from sae_lens.saes.sae import SAE, SAEConfig
 
 from sae_lens.training.activations_store import ActivationsStore
 
 
@@ -100,7 +100,7 @@ def get_eval_everything_config(
 
 @torch.no_grad()
 def run_evals(
-    sae: SAE,
+    sae: SAE[Any],
     activation_store: ActivationsStore,
     model: HookedRootModule,
     eval_config: EvalConfig = EvalConfig(),
@@ -108,7 +108,7 @@ def run_evals(
     ignore_tokens: set[int | None] = set(),
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
-    hook_name = sae.cfg.hook_name
+    hook_name = sae.cfg.metadata.hook_name
     actual_batch_size = (
         eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
     )
@@ -274,12 +274,11 @@ def run_evals(
     return all_metrics, feature_metrics
 
 
-def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
+def get_featurewise_weight_based_metrics(sae: SAE[Any]) -> dict[str, Any]:
     unit_norm_encoders = (sae.W_enc / sae.W_enc.norm(dim=0, keepdim=True)).cpu()
     unit_norm_decoder = (sae.W_dec.T / sae.W_dec.T.norm(dim=0, keepdim=True)).cpu()
 
     encoder_norms = sae.W_enc.norm(dim=-2).cpu().tolist()
-    encoder_bias = sae.b_enc.cpu().tolist()
     encoder_decoder_cosine_sim = (
         torch.nn.functional.cosine_similarity(
             unit_norm_decoder.T,
@@ -289,15 +288,17 @@ def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
         .tolist()
     )
 
-    return {
-        "encoder_bias": encoder_bias,
+    metrics = {
         "encoder_norm": encoder_norms,
         "encoder_decoder_cosine_sim": encoder_decoder_cosine_sim,
     }
+    if hasattr(sae, "b_enc") and sae.b_enc is not None:
+        metrics["encoder_bias"] = sae.b_enc.cpu().tolist()  # type: ignore
+    return metrics
 
 
 def get_downstream_reconstruction_metrics(
-    sae: SAE,
+    sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
     compute_kl: bool,
@@ -365,7 +366,7 @@ def get_downstream_reconstruction_metrics(
 
 
 def get_sparsity_and_variance_metrics(
-    sae: SAE,
+    sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
     n_batches: int,
@@ -378,8 +379,8 @@ def get_sparsity_and_variance_metrics(
     ignore_tokens: set[int | None] = set(),
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
-    hook_name = sae.cfg.hook_name
-    hook_head_index = sae.cfg.hook_head_index
+    hook_name = sae.cfg.metadata.hook_name
+    hook_head_index = sae.cfg.metadata.hook_head_index
 
     metric_dict = {}
     feature_metric_dict = {}
@@ -435,7 +436,7 @@ def get_sparsity_and_variance_metrics(
             batch_tokens,
             prepend_bos=False,
             names_filter=[hook_name],
-            stop_at_layer=sae.cfg.hook_layer + 1,
+            stop_at_layer=sae.cfg.metadata.hook_layer + 1,
             **model_kwargs,
         )
 
@@ -579,7 +580,7 @@ def get_sparsity_and_variance_metrics(
 
 @torch.no_grad()
 def get_recons_loss(
-    sae: SAE,
+    sae: SAE[SAEConfig],
     model: HookedRootModule,
     batch_tokens: torch.Tensor,
     activation_store: ActivationsStore,
@@ -587,9 +588,13 @@ def get_recons_loss(
     compute_ce_loss: bool,
     ignore_tokens: set[int | None] = set(),
     model_kwargs: Mapping[str, Any] = {},
+    hook_name: str | None = None,
 ) -> dict[str, Any]:
-    hook_name = sae.cfg.hook_name
-    head_index = sae.cfg.hook_head_index
+    hook_name = hook_name or sae.cfg.metadata.hook_name
+    head_index = sae.cfg.metadata.hook_head_index
+
+    if hook_name is None:
+        raise ValueError("hook_name must be provided")
 
     original_logits, original_ce_loss = model(
         batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
@@ -764,6 +769,17 @@ def nested_dict() -> defaultdict[Any, Any]:
     return defaultdict(nested_dict)
 
 
+def dict_to_nested(flat_dict: dict[str, Any]) -> defaultdict[Any, Any]:
+    nested = nested_dict()
+    for key, value in flat_dict.items():
+        parts = key.split("/")
+        d = nested
+        for part in parts[:-1]:
+            d = d[part]
+        d[parts[-1]] = value
+    return nested
+
+
 def multiple_evals(
     sae_regex_pattern: str,
     sae_block_pattern: str,
@@ -794,7 +810,6 @@ def multiple_evals(
 
     current_model = None
     current_model_str = None
-    print(filtered_saes)
     for sae_release_name, sae_id, _, _ in tqdm(filtered_saes):
         sae = SAE.from_pretrained(
             release=sae_release_name,  # see other options in sae_lens/pretrained_saes.yaml
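
For code that consumes a loaded SAE, the practical change in the hunks above is that hook information moves from the config root to sae.cfg.metadata, and get_recons_loss now accepts an optional hook_name override. A minimal sketch of that lookup pattern (resolve_hook_name is an illustrative helper, not part of the package):

from typing import Any

from sae_lens.saes.sae import SAE


def resolve_hook_name(sae: SAE[Any], hook_name: str | None = None) -> str:
    # 6.0: hook metadata lives on cfg.metadata rather than directly on cfg
    hook_name = hook_name or sae.cfg.metadata.hook_name
    if hook_name is None:
        raise ValueError("hook_name must be provided")
    return hook_name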
@@ -7,21 +7,24 @@ import numpy as np
 import torch
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError
+from packaging.version import Version
 from safetensors import safe_open
 from safetensors.torch import load_file
 
 from sae_lens import logger
-from sae_lens.config import (
+from sae_lens.constants import (
     DTYPE_MAP,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
 )
-from sae_lens.toolkit.pretrained_saes_directory import (
+from sae_lens.loading.pretrained_saes_directory import (
     get_config_overrides,
     get_pretrained_saes_directory,
     get_repo_id_and_folder_name,
 )
+from sae_lens.registry import get_sae_class
+from sae_lens.util import filter_valid_dataclass_fields
 
 
 # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
@@ -174,30 +177,68 @@ def get_sae_lens_config_from_disk(
 
 
 def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    sae_lens_version = cfg_dict.get("sae_lens_version")
+    if not sae_lens_version and "metadata" in cfg_dict:
+        sae_lens_version = cfg_dict["metadata"].get("sae_lens_version")
+
+    if not sae_lens_version or Version(sae_lens_version) < Version("6.0.0-rc.0"):
+        cfg_dict = handle_pre_6_0_config(cfg_dict)
+    return cfg_dict
+
+
+def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    """
+    Format a config dictionary for a Sparse Autoencoder (SAE) to be compatible with the new 6.0 format.
+    """
+
+    rename_keys_map = {
+        "hook_point": "hook_name",
+        "hook_point_layer": "hook_layer",
+        "hook_point_head_index": "hook_head_index",
+        "activation_fn_str": "activation_fn",
+    }
+    new_cfg = {rename_keys_map.get(k, k): v for k, v in cfg_dict.items()}
+
     # Set default values for backwards compatibility
-    cfg_dict.setdefault("prepend_bos", True)
-    cfg_dict.setdefault("dataset_trust_remote_code", True)
-    cfg_dict.setdefault("apply_b_dec_to_input", True)
-    cfg_dict.setdefault("finetuning_scaling_factor", False)
-    cfg_dict.setdefault("sae_lens_training_version", None)
-    cfg_dict.setdefault("activation_fn_str", cfg_dict.get("activation_fn", "relu"))
-    cfg_dict.setdefault("architecture", "standard")
-    cfg_dict.setdefault("neuronpedia_id", None)
-
-    if "normalize_activations" in cfg_dict and isinstance(
-        cfg_dict["normalize_activations"], bool
+    new_cfg.setdefault("prepend_bos", True)
+    new_cfg.setdefault("dataset_trust_remote_code", True)
+    new_cfg.setdefault("apply_b_dec_to_input", True)
+    new_cfg.setdefault("finetuning_scaling_factor", False)
+    new_cfg.setdefault("sae_lens_training_version", None)
+    new_cfg.setdefault("activation_fn", new_cfg.get("activation_fn", "relu"))
+    new_cfg.setdefault("architecture", "standard")
+    new_cfg.setdefault("neuronpedia_id", None)
+
+    if "normalize_activations" in new_cfg and isinstance(
+        new_cfg["normalize_activations"], bool
     ):
         # backwards compatibility
-        cfg_dict["normalize_activations"] = (
+        new_cfg["normalize_activations"] = (
             "none"
-            if not cfg_dict["normalize_activations"]
+            if not new_cfg["normalize_activations"]
             else "expected_average_only_in"
         )
 
-    cfg_dict.setdefault("normalize_activations", "none")
-    cfg_dict.setdefault("device", "cpu")
+    if new_cfg.get("normalize_activations") is None:
+        new_cfg["normalize_activations"] = "none"
 
-    return cfg_dict
+    new_cfg.setdefault("device", "cpu")
+
+    architecture = new_cfg.get("architecture", "standard")
+
+    config_class = get_sae_class(architecture)[1]
+
+    sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
+    if architecture == "topk":
+        sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]
+
+    # import here to avoid circular import
+    from sae_lens.saes.sae import SAEMetadata
+
+    meta_dict = filter_valid_dataclass_fields(new_cfg, SAEMetadata)
+    sae_cfg_dict["metadata"] = meta_dict
+    sae_cfg_dict["architecture"] = architecture
+    return sae_cfg_dict
 
 
 def get_connor_rob_hook_z_config_from_hf(
@@ -223,7 +264,7 @@ def get_connor_rob_hook_z_config_from_hf(
         "hook_name": old_cfg_dict["act_name"],
         "hook_layer": old_cfg_dict["layer"],
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "apply_b_dec_to_input": True,
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
@@ -372,7 +413,7 @@ def get_gemma_2_config_from_hf(
         "hook_name": hook_name,
         "hook_layer": layer,
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -473,20 +514,11 @@ def get_llama_scope_config_from_hf(
     # Model specific parameters
     model_name, d_in = "meta-llama/Llama-3.1-8B", old_cfg_dict["d_model"]
 
-    # Get norm scaling factor to rescale jumprelu threshold.
-    # We need this because sae.fold_activation_norm_scaling_factor folds scaling norm into W_enc.
-    # This requires jumprelu threshold to be scaled in the same way
-    norm_scaling_factor = (
-        d_in**0.5 / old_cfg_dict["dataset_average_activation_norm"]["in"]
-    )
-
     cfg_dict = {
         "architecture": "jumprelu",
-        "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"]
-        * norm_scaling_factor,
+        "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"],
         # We use a scalar jump_relu_threshold for all features
         # This is different from Gemma Scope JumpReLU SAEs.
-        # Scaled with norm_scaling_factor to match sae.fold_activation_norm_scaling_factor
         "d_in": d_in,
         "d_sae": old_cfg_dict["d_sae"],
         "dtype": "bfloat16",
@@ -494,7 +526,7 @@ def get_llama_scope_config_from_hf(
         "hook_name": old_cfg_dict["hook_point_in"],
         "hook_layer": int(old_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -606,8 +638,8 @@ def get_dictionary_learning_config_1_from_hf(
 
     hook_point_name = f"blocks.{trainer['layer']}.hook_resid_post"
 
-    activation_fn_str = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
-    activation_fn_kwargs = {"k": trainer["k"]} if activation_fn_str == "topk" else {}
+    activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
+    activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}
 
     return {
         "architecture": (
@@ -621,7 +653,7 @@ def get_dictionary_learning_config_1_from_hf(
         "hook_name": hook_point_name,
         "hook_layer": trainer["layer"],
         "hook_head_index": None,
-        "activation_fn_str": activation_fn_str,
+        "activation_fn": activation_fn,
         "activation_fn_kwargs": activation_fn_kwargs,
         "apply_b_dec_to_input": True,
         "finetuning_scaling_factor": False,
@@ -664,7 +696,7 @@ def get_deepseek_r1_config_from_hf(
         "dataset_path": "lmsys/lmsys-chat-1m",
         "dataset_trust_remote_code": True,
         "sae_lens_training_version": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "normalize_activations": "none",
         "device": device,
         "apply_b_dec_to_input": False,
@@ -819,7 +851,7 @@ def get_llama_scope_r1_distill_config_from_hf(
         "hook_name": huggingface_cfg_dict["hook_point_in"],
         "hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -13634,51 +13634,39 @@ gemma-2-2b-res-matryoshka-dc:
   - id: blocks.13.hook_resid_post
     path: standard/blocks.13.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/13-res-matryoshka-dc
   - id: blocks.14.hook_resid_post
     path: standard/blocks.14.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/14-res-matryoshka-dc
   - id: blocks.15.hook_resid_post
     path: standard/blocks.15.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/15-res-matryoshka-dc
   - id: blocks.16.hook_resid_post
     path: standard/blocks.16.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/16-res-matryoshka-dc
   - id: blocks.17.hook_resid_post
     path: standard/blocks.17.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/17-res-matryoshka-dc
   - id: blocks.18.hook_resid_post
     path: standard/blocks.18.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/18-res-matryoshka-dc
   - id: blocks.19.hook_resid_post
     path: standard/blocks.19.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/19-res-matryoshka-dc
   - id: blocks.20.hook_resid_post
     path: standard/blocks.20.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/20-res-matryoshka-dc
   - id: blocks.21.hook_resid_post
     path: standard/blocks.21.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/21-res-matryoshka-dc
   - id: blocks.22.hook_resid_post
     path: standard/blocks.22.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/22-res-matryoshka-dc
   - id: blocks.23.hook_resid_post
     path: standard/blocks.23.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/23-res-matryoshka-dc
   - id: blocks.24.hook_resid_post
     path: standard/blocks.24.hook_resid_post
     l0: 40.0
-    neuronpedia: gemma-2-2b/24-res-matryoshka-dc
 gemma-2-2b-res-snap-matryoshka-dc:
   conversion_func: null
   links:
sae_lens/registry.py ADDED
@@ -0,0 +1,49 @@
+from typing import TYPE_CHECKING, Any
+
+# avoid circular imports
+if TYPE_CHECKING:
+    from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
+
+SAE_CLASS_REGISTRY: dict[str, tuple["type[SAE[Any]]", "type[SAEConfig]"]] = {}
+SAE_TRAINING_CLASS_REGISTRY: dict[
+    str, tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]
+] = {}
+
+
+def register_sae_class(
+    architecture: str,
+    sae_class: "type[SAE[Any]]",
+    sae_config_class: "type[SAEConfig]",
+) -> None:
+    if architecture in SAE_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE class for architecture {architecture} already registered."
+        )
+    SAE_CLASS_REGISTRY[architecture] = (sae_class, sae_config_class)
+
+
+def register_sae_training_class(
+    architecture: str,
+    sae_training_class: "type[TrainingSAE[Any]]",
+    sae_training_config_class: "type[TrainingSAEConfig]",
+) -> None:
+    if architecture in SAE_TRAINING_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE training class for architecture {architecture} already registered."
+        )
+    SAE_TRAINING_CLASS_REGISTRY[architecture] = (
+        sae_training_class,
+        sae_training_config_class,
+    )
+
+
+def get_sae_class(
+    architecture: str,
+) -> tuple["type[SAE[Any]]", "type[SAEConfig]"]:
+    return SAE_CLASS_REGISTRY[architecture]
+
+
+def get_sae_training_class(
+    architecture: str,
+) -> tuple["type[TrainingSAE[Any]]", "type[TrainingSAEConfig]"]:
+    return SAE_TRAINING_CLASS_REGISTRY[architecture]
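
A minimal usage sketch of the new registry, assuming the built-in architectures call register_sae_class when sae_lens.saes is imported and that "standard" is one of the registered keys; the "my_custom_arch" name is hypothetical:

import sae_lens.saes  # noqa: F401  (assumed to trigger registration of the built-in architectures)

from sae_lens.registry import SAE_CLASS_REGISTRY, get_sae_class, register_sae_class

# Look up the (SAE class, config class) pair for a registered architecture.
sae_class, sae_config_class = get_sae_class("standard")
print(sae_class.__name__, sae_config_class.__name__)

# Registering a duplicate architecture raises ValueError, so guard experimental registrations.
if "my_custom_arch" not in SAE_CLASS_REGISTRY:
    register_sae_class("my_custom_arch", sae_class, sae_config_class)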
@@ -7,16 +7,18 @@ from typing import Any, cast
 
 import torch
 import wandb
+from safetensors.torch import save_file
 from simple_parsing import ArgumentParser
 from transformer_lens.hook_points import HookedRootModule
 
 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
+from sae_lens.constants import RUNNER_CFG_FILENAME, SPARSITY_FILENAME
 from sae_lens.load_model import load_model
+from sae_lens.saes.sae import T_TRAINING_SAE_CONFIG, TrainingSAE, TrainingSAEConfig
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.geometric_median import compute_geometric_median
 from sae_lens.training.sae_trainer import SAETrainer
-from sae_lens.training.training_sae import TrainingSAE, TrainingSAEConfig
 
 
 class InterruptedException(Exception):
@@ -32,17 +34,17 @@ class SAETrainingRunner:
     Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
     """
 
-    cfg: LanguageModelSAERunnerConfig
+    cfg: LanguageModelSAERunnerConfig[Any]
     model: HookedRootModule
-    sae: TrainingSAE
+    sae: TrainingSAE[Any]
     activations_store: ActivationsStore
 
     def __init__(
         self,
-        cfg: LanguageModelSAERunnerConfig,
+        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
         override_dataset: HfDataset | None = None,
         override_model: HookedRootModule | None = None,
-        override_sae: TrainingSAE | None = None,
+        override_sae: TrainingSAE[Any] | None = None,
     ):
         if override_dataset is not None:
             logger.warning(
@@ -73,14 +75,14 @@ class SAETrainingRunner:
 
         if override_sae is None:
             if self.cfg.from_pretrained_path is not None:
-                self.sae = TrainingSAE.load_from_pretrained(
+                self.sae = TrainingSAE.load_from_disk(
                     self.cfg.from_pretrained_path, self.cfg.device
                 )
             else:
-                self.sae = TrainingSAE(
+                self.sae = TrainingSAE.from_dict(
                     TrainingSAEConfig.from_dict(
                         self.cfg.get_training_sae_cfg_dict(),
-                    )
+                    ).to_dict()
                 )
             self._init_sae_group_b_decs()
         else:
@@ -91,13 +93,13 @@ class SAETrainingRunner:
         Run the training of the SAE.
         """
 
-        if self.cfg.log_to_wandb:
+        if self.cfg.logger.log_to_wandb:
             wandb.init(
-                project=self.cfg.wandb_project,
-                entity=self.cfg.wandb_entity,
+                project=self.cfg.logger.wandb_project,
+                entity=self.cfg.logger.wandb_entity,
                 config=cast(Any, self.cfg),
-                name=self.cfg.run_name,
-                id=self.cfg.wandb_id,
+                name=self.cfg.logger.run_name,
+                id=self.cfg.logger.wandb_id,
             )
 
         trainer = SAETrainer(
@@ -111,7 +113,7 @@ class SAETrainingRunner:
         self._compile_if_needed()
         sae = self.run_trainer_with_interruption_handling(trainer)
 
-        if self.cfg.log_to_wandb:
+        if self.cfg.logger.log_to_wandb:
             wandb.finish()
 
         return sae
@@ -141,7 +143,9 @@ class SAETrainingRunner:
             backend=backend,
         )  # type: ignore
 
-    def run_trainer_with_interruption_handling(self, trainer: SAETrainer):
+    def run_trainer_with_interruption_handling(
+        self, trainer: SAETrainer[TrainingSAE[TrainingSAEConfig], TrainingSAEConfig]
+    ):
         try:
             # signal handlers (if preempted)
             signal.signal(signal.SIGINT, interrupt_callback)
@@ -167,7 +171,7 @@ class SAETrainingRunner:
         extract all activations at a certain layer and use for sae b_dec initialization
         """
 
-        if self.cfg.b_dec_init_method == "geometric_median":
+        if self.cfg.sae.b_dec_init_method == "geometric_median":
             self.activations_store.set_norm_scaling_factor_if_needed()
             layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
             # get geometric median of the activations if we're using those.
@@ -175,15 +179,15 @@ class SAETrainingRunner:
                 layer_acts,
                 maxiter=100,
             ).median
-            self.sae.initialize_b_dec_with_precalculated(median)  # type: ignore
-        elif self.cfg.b_dec_init_method == "mean":
+            self.sae.initialize_b_dec_with_precalculated(median)
+        elif self.cfg.sae.b_dec_init_method == "mean":
             self.activations_store.set_norm_scaling_factor_if_needed()
             layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
             self.sae.initialize_b_dec_with_mean(layer_acts)  # type: ignore
 
     @staticmethod
     def save_checkpoint(
-        trainer: SAETrainer,
+        trainer: SAETrainer[TrainingSAE[Any], Any],
         checkpoint_name: str,
         wandb_aliases: list[str] | None = None,
     ) -> None:
@@ -194,46 +198,28 @@ class SAETrainingRunner:
             str(base_path / "activations_store_state.safetensors")
         )
 
-        if trainer.sae.cfg.normalize_sae_decoder:
-            trainer.sae.set_decoder_norm_to_unit_norm()
+        weights_path, cfg_path = trainer.sae.save_model(str(base_path))
 
-        weights_path, cfg_path, sparsity_path = trainer.sae.save_model(
-            str(base_path),
-            trainer.log_feature_sparsity,
-        )
+        sparsity_path = base_path / SPARSITY_FILENAME
+        save_file({"sparsity": trainer.log_feature_sparsity}, sparsity_path)
 
-        # let's over write the cfg file with the trainer cfg, which is a super set of the original cfg.
-        # and should not cause issues but give us more info about SAEs we trained in SAE Lens.
-        config = trainer.cfg.to_dict()
-        with open(cfg_path, "w") as f:
-            json.dump(config, f)
-
-        if trainer.cfg.log_to_wandb:
-            # Avoid wandb saving errors such as:
-            # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
-            sae_name = trainer.sae.get_name().replace("/", "__")
-
-            # save model weights and cfg
-            model_artifact = wandb.Artifact(
-                sae_name,
-                type="model",
-                metadata=dict(trainer.cfg.__dict__),
-            )
-            model_artifact.add_file(str(weights_path))
-            model_artifact.add_file(str(cfg_path))
-            wandb.log_artifact(model_artifact, aliases=wandb_aliases)
-
-            # save log feature sparsity
-            sparsity_artifact = wandb.Artifact(
-                f"{sae_name}_log_feature_sparsity",
-                type="log_feature_sparsity",
-                metadata=dict(trainer.cfg.__dict__),
+        runner_config = trainer.cfg.to_dict()
+        with open(base_path / RUNNER_CFG_FILENAME, "w") as f:
+            json.dump(runner_config, f)
+
+        if trainer.cfg.logger.log_to_wandb:
+            trainer.cfg.logger.log(
+                trainer,
+                weights_path,
+                cfg_path,
+                sparsity_path=sparsity_path,
+                wandb_aliases=wandb_aliases,
             )
-            sparsity_artifact.add_file(str(sparsity_path))
-            wandb.log_artifact(sparsity_artifact)
 
 
-def _parse_cfg_args(args: Sequence[str]) -> LanguageModelSAERunnerConfig:
+def _parse_cfg_args(
+    args: Sequence[str],
+) -> LanguageModelSAERunnerConfig[TrainingSAEConfig]:
     if len(args) == 0:
         args = ["--help"]
     parser = ArgumentParser(exit_on_error=False)
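
The reworked save_checkpoint implies a new checkpoint layout: weights and SAE config come from sae.save_model, while feature sparsity and the full runner config are written by the runner itself. A standalone sketch of the two files the runner now writes directly (paths and contents are placeholders):

import json
from pathlib import Path

import torch
from safetensors.torch import save_file

from sae_lens.constants import RUNNER_CFG_FILENAME, SPARSITY_FILENAME

base_path = Path("checkpoints/example")  # placeholder checkpoint directory
base_path.mkdir(parents=True, exist_ok=True)

# Feature sparsity is now saved by the runner rather than by sae.save_model.
log_feature_sparsity = torch.zeros(16_384)  # placeholder tensor
save_file({"sparsity": log_feature_sparsity}, str(base_path / SPARSITY_FILENAME))

# The full runner config is stored alongside the SAE config as JSON.
with open(base_path / RUNNER_CFG_FILENAME, "w") as f:
    json.dump({"model_name": "gpt2"}, f)  # placeholder contents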
@@ -0,0 +1,48 @@
+from .gated_sae import (
+    GatedSAE,
+    GatedSAEConfig,
+    GatedTrainingSAE,
+    GatedTrainingSAEConfig,
+)
+from .jumprelu_sae import (
+    JumpReLUSAE,
+    JumpReLUSAEConfig,
+    JumpReLUTrainingSAE,
+    JumpReLUTrainingSAEConfig,
+)
+from .sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
+from .standard_sae import (
+    StandardSAE,
+    StandardSAEConfig,
+    StandardTrainingSAE,
+    StandardTrainingSAEConfig,
+)
+from .topk_sae import (
+    TopKSAE,
+    TopKSAEConfig,
+    TopKTrainingSAE,
+    TopKTrainingSAEConfig,
+)
+
+__all__ = [
+    "SAE",
+    "SAEConfig",
+    "TrainingSAE",
+    "TrainingSAEConfig",
+    "StandardSAE",
+    "StandardSAEConfig",
+    "StandardTrainingSAE",
+    "StandardTrainingSAEConfig",
+    "GatedSAE",
+    "GatedSAEConfig",
+    "GatedTrainingSAE",
+    "GatedTrainingSAEConfig",
+    "JumpReLUSAE",
+    "JumpReLUSAEConfig",
+    "JumpReLUTrainingSAE",
+    "JumpReLUTrainingSAEConfig",
+    "TopKSAE",
+    "TopKSAEConfig",
+    "TopKTrainingSAE",
+    "TopKTrainingSAEConfig",
+]
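
With this __init__, each architecture's inference and training classes and their configs are importable from one flat namespace; a minimal import sketch using names from __all__ above:

from sae_lens.saes import (
    SAE,
    GatedSAE,
    JumpReLUTrainingSAE,
    StandardSAEConfig,
    TopKTrainingSAEConfig,
)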