sae-lens 6.0.0rc3__py3-none-any.whl → 6.0.0rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.0.0-rc.3"
+ __version__ = "6.0.0-rc.5"

  import logging

@@ -59,7 +59,7 @@ def NanAndInfReplacer(value: str):


  def open_neuronpedia_feature_dashboard(sae: SAE[Any], index: int):
-     sae_id = sae.cfg.neuronpedia_id
+     sae_id = sae.cfg.metadata.neuronpedia_id
      if sae_id is None:
          logger.warning(
              "SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
@@ -74,7 +74,7 @@ def get_neuronpedia_quick_list(
      features: list[int],
      name: str = "temporary_list",
  ):
-     sae_id = sae.cfg.neuronpedia_id
+     sae_id = sae.cfg.metadata.neuronpedia_id
      if sae_id is None:
          logger.warning(
              "SAE does not have a Neuronpedia ID. Either dashboards for this SAE do not exist (yet) on Neuronpedia, or the SAE was not loaded via the from_pretrained method"
@@ -86,7 +86,7 @@ def get_neuronpedia_quick_list(
      url = url + "?name=" + name
      list_feature = [
          {
-             "modelId": sae.cfg.model_name,
+             "modelId": sae.cfg.metadata.model_name,
              "layer": sae_id.split("/")[1],
              "index": str(feature),
          }
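For downstream code, the practical effect of the changes above is that LLM-related fields such as `model_name` and `neuronpedia_id` now live on `sae.cfg.metadata` rather than directly on `sae.cfg`. A minimal sketch of the new access pattern; the release and SAE id below are placeholders for any entry in the pretrained SAE directory:

```python
from sae_lens import SAE

# Placeholder release / id from the pretrained SAE directory.
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.5.hook_resid_pre",
)

# pre-6.0 style: sae.cfg.model_name, sae.cfg.neuronpedia_id
# 6.0 style:
print(sae.cfg.metadata.model_name)
print(sae.cfg.metadata.neuronpedia_id)
```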
sae_lens/config.py CHANGED
@@ -201,7 +201,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
      train_batch_size_tokens: int = 4096

      ## Adam
-     adam_beta1: float = 0.0
+     adam_beta1: float = 0.9
      adam_beta2: float = 0.999

      ## Learning Rate Schedule
@@ -390,7 +390,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
          adam_beta2=self.adam_beta2,
          lr_decay_steps=self.lr_decay_steps,
          n_restart_cycles=self.n_restart_cycles,
-         total_training_steps=self.total_training_steps,
          train_batch_size_samples=self.train_batch_size_tokens,
          dead_feature_window=self.dead_feature_window,
          feature_sampling_window=self.feature_sampling_window,
@@ -613,8 +612,11 @@ class SAETrainerConfig:
      adam_beta2: float
      lr_decay_steps: int
      n_restart_cycles: int
-     total_training_steps: int
      train_batch_size_samples: int
      dead_feature_window: int
      feature_sampling_window: int
      logger: LoggingConfig
+
+     @property
+     def total_training_steps(self) -> int:
+         return self.total_training_samples // self.train_batch_size_samples
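Making `total_training_steps` a derived property keeps it consistent with the sample count and batch size instead of being a third, independently set field. A minimal arithmetic sketch, assuming `total_training_samples` is the corresponding field on `SAETrainerConfig` (the numbers are illustrative):

```python
# Illustrative values; train_batch_size_samples mirrors the 4096-token default above.
total_training_samples = 2_000_000
train_batch_size_samples = 4096

total_training_steps = total_training_samples // train_batch_size_samples
print(total_training_steps)  # 488
```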
sae_lens/constants.py CHANGED
@@ -16,5 +16,6 @@ SPARSITY_FILENAME = "sparsity.safetensors"
  SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
  SAE_CFG_FILENAME = "cfg.json"
  RUNNER_CFG_FILENAME = "runner_cfg.json"
+ SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
  ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
  ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
sae_lens/evals.py CHANGED
@@ -4,6 +4,7 @@ import json
  import math
  import re
  import subprocess
+ import sys
  from collections import defaultdict
  from collections.abc import Mapping
  from dataclasses import dataclass, field
@@ -15,7 +16,7 @@ from typing import Any
  import einops
  import pandas as pd
  import torch
- from tqdm import tqdm
+ from tqdm.auto import tqdm
  from transformer_lens import HookedTransformer
  from transformer_lens.hook_points import HookedRootModule

@@ -768,17 +769,6 @@ def nested_dict() -> defaultdict[Any, Any]:
      return defaultdict(nested_dict)


- def dict_to_nested(flat_dict: dict[str, Any]) -> defaultdict[Any, Any]:
-     nested = nested_dict()
-     for key, value in flat_dict.items():
-         parts = key.split("/")
-         d = nested
-         for part in parts[:-1]:
-             d = d[part]
-         d[parts[-1]] = value
-     return nested
-
-
  def multiple_evals(
      sae_regex_pattern: str,
      sae_block_pattern: str,
@@ -814,16 +804,18 @@ def multiple_evals(
          release=sae_release_name,  # see other options in sae_lens/pretrained_saes.yaml
          sae_id=sae_id,  # won't always be a hook point
          device=device,
-     )[0]
+     )

      # move SAE to device if not there already
      sae.to(device)

-     if current_model_str != sae.cfg.model_name:
+     if current_model_str != sae.cfg.metadata.model_name:
          del current_model  # potentially saves GPU memory
-         current_model_str = sae.cfg.model_name
+         current_model_str = sae.cfg.metadata.model_name
          current_model = HookedTransformer.from_pretrained_no_processing(
-             current_model_str, device=device, **sae.cfg.model_from_pretrained_kwargs
+             current_model_str,
+             device=device,
+             **sae.cfg.metadata.model_from_pretrained_kwargs,
          )
          assert current_model is not None

@@ -941,7 +933,7 @@ def process_results(
      }


- if __name__ == "__main__":
+ def process_args(args: list[str]) -> argparse.Namespace:
      arg_parser = argparse.ArgumentParser(description="Run evaluations on SAEs")
      arg_parser.add_argument(
          "sae_regex_pattern",
@@ -1031,11 +1023,19 @@ if __name__ == "__main__":
          help="Enable verbose output with tqdm loaders.",
      )

-     args = arg_parser.parse_args()
-     eval_results = run_evaluations(args)
-     output_files = process_results(eval_results, args.output_dir)
+     return arg_parser.parse_args(args)
+
+
+ def run_evals_cli(args: list[str]) -> None:
+     opts = process_args(args)
+     eval_results = run_evaluations(opts)
+     output_files = process_results(eval_results, opts.output_dir)

      print("Evaluation complete. Output files:")
      print(f"Individual JSONs: {len(output_files['individual_jsons'])}")  # type: ignore
      print(f"Combined JSON: {output_files['combined_json']}")
      print(f"CSV: {output_files['csv']}")
+
+
+ if __name__ == "__main__":
+     run_evals_cli(sys.argv[1:])
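Moving the parser and entry point out of the `__main__` block means the eval CLI can be driven in-process (for example from tests) instead of only via `sys.argv`. A minimal sketch; only the `sae_regex_pattern` positional is visible in the hunk above, so the argument list here is hypothetical and stands in for whatever positionals and flags the full parser defines:

```python
from sae_lens.evals import process_args, run_evaluations

# Hypothetical argument list; the real parser defines additional
# positionals/flags that are not shown in this diff.
opts = process_args(["gpt2-small-res-jb", "blocks.5.hook_resid_pre"])
eval_results = run_evaluations(opts)
```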
@@ -4,7 +4,7 @@ import sys
  from collections.abc import Sequence
  from dataclasses import dataclass
  from pathlib import Path
- from typing import Any, Generic, cast
+ from typing import Any, Generic

  import torch
  import wandb
@@ -17,12 +17,16 @@ from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
  from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
  from sae_lens.evals import EvalConfig, run_evals
  from sae_lens.load_model import load_model
+ from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
+ from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
  from sae_lens.saes.sae import (
      T_TRAINING_SAE,
      T_TRAINING_SAE_CONFIG,
      TrainingSAE,
      TrainingSAEConfig,
  )
+ from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
+ from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
  from sae_lens.training.activation_scaler import ActivationScaler
  from sae_lens.training.activations_store import ActivationsStore
  from sae_lens.training.sae_trainer import SAETrainer
@@ -145,17 +149,18 @@ class LanguageModelSAETrainingRunner:
              )
          else:
              self.sae = override_sae
+             self.sae.to(self.cfg.device)

      def run(self):
          """
          Run the training of the SAE.
          """
-
+         self._set_sae_metadata()
          if self.cfg.logger.log_to_wandb:
              wandb.init(
                  project=self.cfg.logger.wandb_project,
                  entity=self.cfg.logger.wandb_entity,
-                 config=cast(Any, self.cfg),
+                 config=self.cfg.to_dict(),
                  name=self.cfg.logger.run_name,
                  id=self.cfg.logger.wandb_id,
              )
@@ -184,6 +189,20 @@ class LanguageModelSAETrainingRunner:

          return sae

+     def _set_sae_metadata(self):
+         self.sae.cfg.metadata.dataset_path = self.cfg.dataset_path
+         self.sae.cfg.metadata.hook_name = self.cfg.hook_name
+         self.sae.cfg.metadata.model_name = self.cfg.model_name
+         self.sae.cfg.metadata.model_class_name = self.cfg.model_class_name
+         self.sae.cfg.metadata.hook_head_index = self.cfg.hook_head_index
+         self.sae.cfg.metadata.context_size = self.cfg.context_size
+         self.sae.cfg.metadata.seqpos_slice = self.cfg.seqpos_slice
+         self.sae.cfg.metadata.model_from_pretrained_kwargs = (
+             self.cfg.model_from_pretrained_kwargs
+         )
+         self.sae.cfg.metadata.prepend_bos = self.cfg.prepend_bos
+         self.sae.cfg.metadata.exclude_special_tokens = self.cfg.exclude_special_tokens
+
      def _compile_if_needed(self):
          # Compile model and SAE
          # torch.compile can provide significant speedups (10-20% in testing)
@@ -247,11 +266,100 @@ class LanguageModelSAETrainingRunner:
  def _parse_cfg_args(
      args: Sequence[str],
  ) -> LanguageModelSAERunnerConfig[TrainingSAEConfig]:
+     """
+     Parse command line arguments into a LanguageModelSAERunnerConfig.
+
+     This function first parses the architecture argument to determine which
+     concrete SAE config class to use, then parses the full configuration
+     with that concrete type.
+     """
      if len(args) == 0:
          args = ["--help"]
+
+     # First, parse only the architecture to determine which concrete class to use
+     architecture_parser = ArgumentParser(
+         description="Parse architecture to determine SAE config class",
+         exit_on_error=False,
+         add_help=False,  # Don't add help to avoid conflicts
+     )
+     architecture_parser.add_argument(
+         "--architecture",
+         type=str,
+         choices=["standard", "gated", "jumprelu", "topk"],
+         default="standard",
+         help="SAE architecture to use",
+     )
+
+     # Parse known args to extract architecture, ignore unknown args for now
+     arch_args, remaining_args = architecture_parser.parse_known_args(args)
+     architecture = arch_args.architecture
+
+     # Remove architecture from remaining args if it exists
+     filtered_args = []
+     skip_next = False
+     for arg in remaining_args:
+         if skip_next:
+             skip_next = False
+             continue
+         if arg == "--architecture":
+             skip_next = True  # Skip the next argument (the architecture value)
+             continue
+         filtered_args.append(arg)
+
+     # Create a custom wrapper class that simple_parsing can handle
+     def create_config_class(
+         sae_config_type: type[TrainingSAEConfig],
+     ) -> type[LanguageModelSAERunnerConfig[TrainingSAEConfig]]:
+         """Create a concrete config class for the given SAE config type."""
+
+         # Create the base config without the sae field
+         from dataclasses import field as dataclass_field
+         from dataclasses import fields, make_dataclass
+
+         # Get all fields from LanguageModelSAERunnerConfig except the generic sae field
+         base_fields = []
+         for field_obj in fields(LanguageModelSAERunnerConfig):
+             if field_obj.name != "sae":
+                 base_fields.append((field_obj.name, field_obj.type, field_obj))
+
+         # Add the concrete sae field
+         base_fields.append(
+             (
+                 "sae",
+                 sae_config_type,
+                 dataclass_field(
+                     default_factory=lambda: sae_config_type(d_in=512, d_sae=1024)
+                 ),
+             )
+         )
+
+         # Create the concrete class
+         return make_dataclass(
+             f"{sae_config_type.__name__}RunnerConfig",
+             base_fields,
+             bases=(LanguageModelSAERunnerConfig,),
+         )
+
+     # Map architecture to concrete config class
+     sae_config_map = {
+         "standard": StandardTrainingSAEConfig,
+         "gated": GatedTrainingSAEConfig,
+         "jumprelu": JumpReLUTrainingSAEConfig,
+         "topk": TopKTrainingSAEConfig,
+     }
+
+     sae_config_type = sae_config_map[architecture]
+     concrete_config_class = create_config_class(sae_config_type)
+
+     # Now parse the full configuration with the concrete type
      parser = ArgumentParser(exit_on_error=False)
-     parser.add_arguments(LanguageModelSAERunnerConfig, dest="cfg")
-     return parser.parse_args(args).cfg
+     parser.add_arguments(concrete_config_class, dest="cfg")
+
+     # Parse the filtered arguments (without --architecture)
+     parsed_args = parser.parse_args(filtered_args)
+
+     # Return the parsed configuration
+     return parsed_args.cfg


  # moved into its own function to make it easier to test
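The net effect is a two-stage parse: `--architecture` is peeled off first and used to pick the concrete SAE config class, and only then is the full runner config parsed with `simple_parsing`. A minimal sketch of that first stage in isolation, using only the standard library; `--lr` is an illustrative extra flag, and `_parse_cfg_args` itself remains a private helper rather than a public API:

```python
import argparse

# Stage 1: extract --architecture, leave everything else for the full parser.
arch_parser = argparse.ArgumentParser(add_help=False)
arch_parser.add_argument(
    "--architecture",
    choices=["standard", "gated", "jumprelu", "topk"],
    default="standard",
)
arch_args, remaining = arch_parser.parse_known_args(
    ["--architecture", "topk", "--lr", "3e-4"]  # illustrative argv
)

print(arch_args.architecture)  # "topk" -> selects TopKTrainingSAEConfig
print(remaining)               # ["--lr", "3e-4"] -> handed to the architecture-specific parser
```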
@@ -16,6 +16,7 @@ from sae_lens.constants import (
      DTYPE_MAP,
      SAE_CFG_FILENAME,
      SAE_WEIGHTS_FILENAME,
+     SPARSIFY_WEIGHTS_FILENAME,
      SPARSITY_FILENAME,
  )
  from sae_lens.loading.pretrained_saes_directory import (
@@ -26,6 +27,22 @@ from sae_lens.loading.pretrained_saes_directory import (
  from sae_lens.registry import get_sae_class
  from sae_lens.util import filter_valid_dataclass_fields

+ LLM_METADATA_KEYS = {
+     "model_name",
+     "hook_name",
+     "model_class_name",
+     "hook_head_index",
+     "model_from_pretrained_kwargs",
+     "prepend_bos",
+     "exclude_special_tokens",
+     "neuronpedia_id",
+     "context_size",
+     "seqpos_slice",
+     "dataset_path",
+     "sae_lens_version",
+     "sae_lens_training_version",
+ }
+

  # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
  class PretrainedSaeHuggingfaceLoader(Protocol):
@@ -207,6 +224,10 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
      new_cfg.setdefault("activation_fn", new_cfg.get("activation_fn", "relu"))
      new_cfg.setdefault("architecture", "standard")
      new_cfg.setdefault("neuronpedia_id", None)
+     new_cfg.setdefault(
+         "reshape_activations",
+         "hook_z" if "hook_z" in new_cfg.get("hook_name", "") else "none",
+     )

      if "normalize_activations" in new_cfg and isinstance(
          new_cfg["normalize_activations"], bool
@@ -228,14 +249,12 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
      config_class = get_sae_class(architecture)[1]

      sae_cfg_dict = filter_valid_dataclass_fields(new_cfg, config_class)
-     if architecture == "topk":
+     if architecture == "topk" and "activation_fn_kwargs" in new_cfg:
          sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]

-     # import here to avoid circular import
-     from sae_lens.saes.sae import SAEMetadata
-
-     meta_dict = filter_valid_dataclass_fields(new_cfg, SAEMetadata)
-     sae_cfg_dict["metadata"] = meta_dict
+     sae_cfg_dict["metadata"] = {
+         k: v for k, v in new_cfg.items() if k in LLM_METADATA_KEYS
+     }
      sae_cfg_dict["architecture"] = architecture
      return sae_cfg_dict

@@ -271,6 +290,7 @@ def get_connor_rob_hook_z_config_from_hf(
          "context_size": 128,
          "normalize_activations": "none",
          "dataset_trust_remote_code": True,
+         "reshape_activations": "hook_z",
          **(cfg_overrides or {}),
      }

@@ -511,11 +531,20 @@ def get_llama_scope_config_from_hf(
      # Model specific parameters
      model_name, d_in = "meta-llama/Llama-3.1-8B", old_cfg_dict["d_model"]

+     # Get norm scaling factor to rescale jumprelu threshold.
+     # We need this because sae.fold_activation_norm_scaling_factor folds scaling norm into W_enc.
+     # This requires jumprelu threshold to be scaled in the same way
+     norm_scaling_factor = (
+         d_in**0.5 / old_cfg_dict["dataset_average_activation_norm"]["in"]
+     )
+
      cfg_dict = {
          "architecture": "jumprelu",
-         "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"],
+         "jump_relu_threshold": old_cfg_dict["jump_relu_threshold"]
+         * norm_scaling_factor,
          # We use a scalar jump_relu_threshold for all features
          # This is different from Gemma Scope JumpReLU SAEs.
+         # Scaled with norm_scaling_factor to match sae.fold_activation_norm_scaling_factor
          "d_in": d_in,
          "d_sae": old_cfg_dict["d_sae"],
          "dtype": "bfloat16",
@@ -923,6 +952,146 @@ def llama_scope_r1_distill_sae_huggingface_loader(
      return cfg_dict, state_dict, log_sparsity


+ def get_sparsify_config_from_hf(
+     repo_id: str,
+     folder_name: str,
+     device: str,
+     force_download: bool = False,
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+     cfg_filename = f"{folder_name}/{SAE_CFG_FILENAME}"
+     cfg_path = hf_hub_download(
+         repo_id,
+         filename=cfg_filename,
+         force_download=force_download,
+     )
+     sae_path = Path(cfg_path).parent
+     return get_sparsify_config_from_disk(
+         sae_path, device=device, cfg_overrides=cfg_overrides
+     )
+
+
+ def get_sparsify_config_from_disk(
+     path: str | Path,
+     device: str | None = None,
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+     path = Path(path)
+
+     with open(path / SAE_CFG_FILENAME) as f:
+         old_cfg_dict = json.load(f)
+
+     config_path = path.parent / "config.json"
+     if config_path.exists():
+         with open(config_path) as f:
+             config_dict = json.load(f)
+     else:
+         config_dict = {}
+
+     folder_name = path.name
+     if folder_name == "embed_tokens":
+         hook_name, layer = "hook_embed", 0
+     else:
+         match = re.search(r"layers[._](\d+)", folder_name)
+         if match is None:
+             raise ValueError(f"Unrecognized Sparsify folder: {folder_name}")
+         layer = int(match.group(1))
+         hook_name = f"blocks.{layer}.hook_resid_post"
+
+     cfg_dict: dict[str, Any] = {
+         "architecture": "standard",
+         "d_in": old_cfg_dict["d_in"],
+         "d_sae": old_cfg_dict["d_in"] * old_cfg_dict["expansion_factor"],
+         "dtype": "bfloat16",
+         "device": device or "cpu",
+         "model_name": config_dict.get("model", path.parts[-2]),
+         "hook_name": hook_name,
+         "hook_layer": layer,
+         "hook_head_index": None,
+         "activation_fn_str": "topk",
+         "activation_fn_kwargs": {
+             "k": old_cfg_dict["k"],
+             "signed": old_cfg_dict.get("signed", False),
+         },
+         "apply_b_dec_to_input": not old_cfg_dict.get("normalize_decoder", False),
+         "dataset_path": config_dict.get(
+             "dataset", "togethercomputer/RedPajama-Data-1T-Sample"
+         ),
+         "context_size": config_dict.get("ctx_len", 2048),
+         "finetuning_scaling_factor": False,
+         "sae_lens_training_version": None,
+         "prepend_bos": True,
+         "dataset_trust_remote_code": True,
+         "normalize_activations": "none",
+         "neuronpedia_id": None,
+     }
+
+     if cfg_overrides:
+         cfg_dict.update(cfg_overrides)
+
+     return cfg_dict
+
+
+ def sparsify_huggingface_loader(
+     repo_id: str,
+     folder_name: str,
+     device: str = "cpu",
+     force_download: bool = False,
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor], None]:
+     weights_filename = f"{folder_name}/{SPARSIFY_WEIGHTS_FILENAME}"
+     sae_path = hf_hub_download(
+         repo_id,
+         filename=weights_filename,
+         force_download=force_download,
+     )
+     cfg_dict, state_dict = sparsify_disk_loader(
+         Path(sae_path).parent, device=device, cfg_overrides=cfg_overrides
+     )
+     return cfg_dict, state_dict, None
+
+
+ def sparsify_disk_loader(
+     path: str | Path,
+     device: str = "cpu",
+     cfg_overrides: dict[str, Any] | None = None,
+ ) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
+     cfg_dict = get_sparsify_config_from_disk(path, device, cfg_overrides)
+
+     weight_path = Path(path) / SPARSIFY_WEIGHTS_FILENAME
+     state_dict_loaded = load_file(weight_path, device=device)
+
+     dtype = DTYPE_MAP[cfg_dict["dtype"]]
+
+     W_enc = (
+         state_dict_loaded["W_enc"]
+         if "W_enc" in state_dict_loaded
+         else state_dict_loaded["encoder.weight"].T
+     ).to(dtype)
+
+     if "W_dec" in state_dict_loaded:
+         W_dec = state_dict_loaded["W_dec"].T.to(dtype)
+     else:
+         W_dec = state_dict_loaded["decoder.weight"].T.to(dtype)
+
+     if "b_enc" in state_dict_loaded:
+         b_enc = state_dict_loaded["b_enc"].to(dtype)
+     elif "encoder.bias" in state_dict_loaded:
+         b_enc = state_dict_loaded["encoder.bias"].to(dtype)
+     else:
+         b_enc = torch.zeros(cfg_dict["d_sae"], dtype=dtype, device=device)
+
+     if "b_dec" in state_dict_loaded:
+         b_dec = state_dict_loaded["b_dec"].to(dtype)
+     elif "decoder.bias" in state_dict_loaded:
+         b_dec = state_dict_loaded["decoder.bias"].to(dtype)
+     else:
+         b_dec = torch.zeros(cfg_dict["d_in"], dtype=dtype, device=device)
+
+     state_dict = {"W_enc": W_enc, "b_enc": b_enc, "W_dec": W_dec, "b_dec": b_dec}
+     return cfg_dict, state_dict
+
+
  NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
      "sae_lens": sae_lens_huggingface_loader,
      "connor_rob_hook_z": connor_rob_hook_z_huggingface_loader,
@@ -931,6 +1100,7 @@ NAMED_PRETRAINED_SAE_LOADERS: dict[str, PretrainedSaeHuggingfaceLoader] = {
      "llama_scope_r1_distill": llama_scope_r1_distill_sae_huggingface_loader,
      "dictionary_learning_1": dictionary_learning_sae_huggingface_loader_1,
      "deepseek_r1": deepseek_r1_sae_huggingface_loader,
+     "sparsify": sparsify_huggingface_loader,
  }


@@ -942,4 +1112,5 @@ NAMED_PRETRAINED_SAE_CONFIG_GETTERS: dict[str, PretrainedSaeConfigHuggingfaceLoa
      "llama_scope_r1_distill": get_llama_scope_r1_distill_config_from_hf,
      "dictionary_learning_1": get_dictionary_learning_config_1_from_hf,
      "deepseek_r1": get_deepseek_r1_config_from_hf,
+     "sparsify": get_sparsify_config_from_hf,
  }
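A minimal usage sketch for the new `sparsify` loaders (these appear to target checkpoints from the sparsify SAE trainer, given the config and weight key names). The path below is hypothetical, as is the module path `sae_lens.loading.pretrained_sae_loaders`, since this diff does not show the file name; the layer folder is assumed to contain `cfg.json` and `sae.safetensors`:

```python
from sae_lens.loading.pretrained_sae_loaders import sparsify_disk_loader  # assumed module path

# Hypothetical local layer folder, e.g. "my-sparsify-run/layers.5",
# containing cfg.json and sae.safetensors (SPARSIFY_WEIGHTS_FILENAME).
cfg_dict, state_dict = sparsify_disk_loader("my-sparsify-run/layers.5", device="cpu")

print(cfg_dict["hook_name"])   # "blocks.5.hook_resid_post" (parsed from the folder name)
print(sorted(state_dict))      # ["W_dec", "W_enc", "b_dec", "b_enc"]
```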
@@ -13634,39 +13634,51 @@ gemma-2-2b-res-matryoshka-dc:
    - id: blocks.13.hook_resid_post
      path: standard/blocks.13.hook_resid_post
      l0: 40.0
+     neuronpedia: gemma-2-2b/13-res-matryoshka-dc
    - id: blocks.14.hook_resid_post
      path: standard/blocks.14.hook_resid_post
      l0: 40.0
+     neuronpedia: gemma-2-2b/14-res-matryoshka-dc
    - id: blocks.15.hook_resid_post
      path: standard/blocks.15.hook_resid_post
      l0: 40.0
+     neuronpedia: gemma-2-2b/15-res-matryoshka-dc
    - id: blocks.16.hook_resid_post
      path: standard/blocks.16.hook_resid_post
      l0: 40.0
+     neuronpedia: gemma-2-2b/16-res-matryoshka-dc
    - id: blocks.17.hook_resid_post
      path: standard/blocks.17.hook_resid_post
      l0: 40.0
+     neuronpedia: gemma-2-2b/17-res-matryoshka-dc
    - id: blocks.18.hook_resid_post
      path: standard/blocks.18.hook_resid_post
      l0: 40.0
+     neuronpedia: gemma-2-2b/18-res-matryoshka-dc
    - id: blocks.19.hook_resid_post
      path: standard/blocks.19.hook_resid_post
      l0: 40.0
+     neuronpedia: gemma-2-2b/19-res-matryoshka-dc
    - id: blocks.20.hook_resid_post
      path: standard/blocks.20.hook_resid_post
      l0: 40.0
+     neuronpedia: gemma-2-2b/20-res-matryoshka-dc
    - id: blocks.21.hook_resid_post
      path: standard/blocks.21.hook_resid_post
      l0: 40.0
+     neuronpedia: gemma-2-2b/21-res-matryoshka-dc
    - id: blocks.22.hook_resid_post
      path: standard/blocks.22.hook_resid_post
      l0: 40.0
+     neuronpedia: gemma-2-2b/22-res-matryoshka-dc
    - id: blocks.23.hook_resid_post
      path: standard/blocks.23.hook_resid_post
      l0: 40.0
+     neuronpedia: gemma-2-2b/23-res-matryoshka-dc
    - id: blocks.24.hook_resid_post
      path: standard/blocks.24.hook_resid_post
      l0: 40.0
+     neuronpedia: gemma-2-2b/24-res-matryoshka-dc
  gemma-2-2b-res-snap-matryoshka-dc:
    conversion_func: null
    links:
@@ -168,10 +168,6 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):

          # Magnitude path
          magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
-         if self.training and self.cfg.noise_scale > 0:
-             magnitude_pre_activation += (
-                 torch.randn_like(magnitude_pre_activation) * self.cfg.noise_scale
-             )
          magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)

          feature_magnitudes = self.activation_fn(magnitude_pre_activation)
@@ -105,7 +105,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
      JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
      using a JumpReLU activation. For each unit, if its pre-activation is
      <= threshold, that unit is zeroed out; otherwise, it follows a user-specified
-     activation function (e.g., ReLU, tanh-relu, etc.).
+     activation function (e.g., ReLU etc.).

      It implements:
      - initialize_weights: sets up parameters, including a threshold.
@@ -142,7 +142,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
          sae_in = self.process_sae_in(x)
          hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

-         # 1) Apply the base "activation_fn" from config (e.g., ReLU, tanh-relu).
+         # 1) Apply the base "activation_fn" from config (e.g., ReLU).
          base_acts = self.activation_fn(hidden_pre)

          # 2) Zero out any unit whose (hidden_pre <= threshold).
@@ -191,8 +191,8 @@ class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
      Configuration class for training a JumpReLUTrainingSAE.
      """

-     jumprelu_init_threshold: float = 0.001
-     jumprelu_bandwidth: float = 0.001
+     jumprelu_init_threshold: float = 0.01
+     jumprelu_bandwidth: float = 0.05
      l0_coefficient: float = 1.0
      l0_warm_up_steps: int = 0

@@ -257,12 +257,6 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
          sae_in = self.process_sae_in(x)

          hidden_pre = sae_in @ self.W_enc + self.b_enc
-
-         if self.training and self.cfg.noise_scale > 0:
-             hidden_pre = (
-                 hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
-             )
-
          feature_acts = JumpReLU.apply(hidden_pre, self.threshold, self.bandwidth)

          return feature_acts, hidden_pre  # type: ignore
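For reference, the inference-time JumpReLU described in the docstring above amounts to masking the base activations wherever the pre-activation is at or below the threshold. A standalone sketch of that rule (this is not the library's `JumpReLU.apply`, which additionally takes the `bandwidth` argument used for its training-time gradient estimate):

```python
import torch

def jumprelu_forward(hidden_pre: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
    # Base activation (ReLU here), then zero out units with pre-activation <= threshold.
    base_acts = torch.relu(hidden_pre)
    return base_acts * (hidden_pre > threshold)

hidden_pre = torch.tensor([-0.5, 0.2, 0.8, 1.5])
print(jumprelu_forward(hidden_pre, torch.tensor(0.5)))
# tensor([0.0000, 0.0000, 0.8000, 1.5000])
```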