sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,377 @@
+ import json
+ import signal
+ import sys
+ from collections.abc import Sequence
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Generic
+
+ import torch
+ import wandb
+ from simple_parsing import ArgumentParser
+ from transformer_lens.hook_points import HookedRootModule
+ from typing_extensions import deprecated
+
+ from sae_lens import logger
+ from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
+ from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
+ from sae_lens.evals import EvalConfig, run_evals
+ from sae_lens.load_model import load_model
+ from sae_lens.saes.gated_sae import GatedTrainingSAEConfig
+ from sae_lens.saes.jumprelu_sae import JumpReLUTrainingSAEConfig
+ from sae_lens.saes.sae import (
+     T_TRAINING_SAE,
+     T_TRAINING_SAE_CONFIG,
+     TrainingSAE,
+     TrainingSAEConfig,
+ )
+ from sae_lens.saes.standard_sae import StandardTrainingSAEConfig
+ from sae_lens.saes.topk_sae import TopKTrainingSAEConfig
+ from sae_lens.training.activation_scaler import ActivationScaler
+ from sae_lens.training.activations_store import ActivationsStore
+ from sae_lens.training.sae_trainer import SAETrainer
+ from sae_lens.training.types import DataProvider
+
+
+ class InterruptedException(Exception):
+     pass
+
+
+ def interrupt_callback(sig_num: Any, stack_frame: Any): # noqa: ARG001
+     raise InterruptedException()
+
+
+ @dataclass
+ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
+     model: HookedRootModule
+     activations_store: ActivationsStore
+     eval_batch_size_prompts: int | None
+     n_eval_batches: int
+     model_kwargs: dict[str, Any]
+
+     def __call__(
+         self,
+         sae: T_TRAINING_SAE,
+         data_provider: DataProvider,
+         activation_scaler: ActivationScaler,
+     ) -> dict[str, Any]:
+         ignore_tokens = set()
+         if self.activations_store.exclude_special_tokens is not None:
+             ignore_tokens = set(self.activations_store.exclude_special_tokens.tolist())
+
+         eval_config = EvalConfig(
+             batch_size_prompts=self.eval_batch_size_prompts,
+             n_eval_reconstruction_batches=self.n_eval_batches,
+             n_eval_sparsity_variance_batches=self.n_eval_batches,
+             compute_ce_loss=True,
+             compute_l2_norms=True,
+             compute_sparsity_metrics=True,
+             compute_variance_metrics=True,
+         )
+
+         eval_metrics, _ = run_evals(
+             sae=sae,
+             activation_store=self.activations_store,
+             model=self.model,
+             activation_scaler=activation_scaler,
+             eval_config=eval_config,
+             ignore_tokens=ignore_tokens,
+             model_kwargs=self.model_kwargs,
+         ) # not calculating featurewise metrics here.
+
+         # Remove eval metrics that are already logged during training
+         eval_metrics.pop("metrics/explained_variance", None)
+         eval_metrics.pop("metrics/explained_variance_std", None)
+         eval_metrics.pop("metrics/l0", None)
+         eval_metrics.pop("metrics/l1", None)
+         eval_metrics.pop("metrics/mse", None)
+
+         # Remove metrics that are not useful for wandb logging
+         eval_metrics.pop("metrics/total_tokens_evaluated", None)
+
+         return eval_metrics
+
+
+ class LanguageModelSAETrainingRunner:
+     """
+     Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
+     """
+
+     cfg: LanguageModelSAERunnerConfig[Any]
+     model: HookedRootModule
+     sae: TrainingSAE[Any]
+     activations_store: ActivationsStore
+
+     def __init__(
+         self,
+         cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
+         override_dataset: HfDataset | None = None,
+         override_model: HookedRootModule | None = None,
+         override_sae: TrainingSAE[Any] | None = None,
+     ):
+         if override_dataset is not None:
+             logger.warning(
+                 f"You just passed in a dataset which will override the one specified in your configuration: {cfg.dataset_path}. As a consequence this run will not be reproducible via configuration alone."
+             )
+         if override_model is not None:
+             logger.warning(
+                 f"You just passed in a model which will override the one specified in your configuration: {cfg.model_name}. As a consequence this run will not be reproducible via configuration alone."
+             )
+
+         self.cfg = cfg
+
+         if override_model is None:
+             self.model = load_model(
+                 self.cfg.model_class_name,
+                 self.cfg.model_name,
+                 device=self.cfg.device,
+                 model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
+             )
+         else:
+             self.model = override_model
+
+         self.activations_store = ActivationsStore.from_config(
+             self.model,
+             self.cfg,
+             override_dataset=override_dataset,
+         )
+
+         if override_sae is None:
+             if self.cfg.from_pretrained_path is not None:
+                 self.sae = TrainingSAE.load_from_disk(
+                     self.cfg.from_pretrained_path, self.cfg.device
+                 )
+             else:
+                 self.sae = TrainingSAE.from_dict(
+                     TrainingSAEConfig.from_dict(
+                         self.cfg.get_training_sae_cfg_dict(),
+                     ).to_dict()
+                 )
+         else:
+             self.sae = override_sae
+         self.sae.to(self.cfg.device)
+
+     def run(self):
+         """
+         Run the training of the SAE.
+         """
+         self._set_sae_metadata()
+         if self.cfg.logger.log_to_wandb:
+             wandb.init(
+                 project=self.cfg.logger.wandb_project,
+                 entity=self.cfg.logger.wandb_entity,
+                 config=self.cfg.to_dict(),
+                 name=self.cfg.logger.run_name,
+                 id=self.cfg.logger.wandb_id,
+             )
+
+         evaluator = LLMSaeEvaluator(
+             model=self.model,
+             activations_store=self.activations_store,
+             eval_batch_size_prompts=self.cfg.eval_batch_size_prompts,
+             n_eval_batches=self.cfg.n_eval_batches,
+             model_kwargs=self.cfg.model_kwargs,
+         )
+
+         trainer = SAETrainer(
+             sae=self.sae,
+             data_provider=self.activations_store,
+             evaluator=evaluator,
+             save_checkpoint_fn=self.save_checkpoint,
+             cfg=self.cfg.to_sae_trainer_config(),
+         )
+
+         self._compile_if_needed()
+         sae = self.run_trainer_with_interruption_handling(trainer)
+
+         if self.cfg.logger.log_to_wandb:
+             wandb.finish()
+
+         return sae
+
+     def _set_sae_metadata(self):
+         self.sae.cfg.metadata.dataset_path = self.cfg.dataset_path
+         self.sae.cfg.metadata.hook_name = self.cfg.hook_name
+         self.sae.cfg.metadata.model_name = self.cfg.model_name
+         self.sae.cfg.metadata.model_class_name = self.cfg.model_class_name
+         self.sae.cfg.metadata.hook_head_index = self.cfg.hook_head_index
+         self.sae.cfg.metadata.context_size = self.cfg.context_size
+         self.sae.cfg.metadata.seqpos_slice = self.cfg.seqpos_slice
+         self.sae.cfg.metadata.model_from_pretrained_kwargs = (
+             self.cfg.model_from_pretrained_kwargs
+         )
+         self.sae.cfg.metadata.prepend_bos = self.cfg.prepend_bos
+         self.sae.cfg.metadata.exclude_special_tokens = self.cfg.exclude_special_tokens
+
+     def _compile_if_needed(self):
+         # Compile model and SAE
+         # torch.compile can provide significant speedups (10-20% in testing)
+         # using max-autotune gives the best speedups but:
+         # (a) increases VRAM usage,
+         # (b) can't be used on both SAE and LM (some issue with cudagraphs), and
+         # (c) takes some time to compile
+         # optimal settings seem to be:
+         # use max-autotune on SAE and max-autotune-no-cudagraphs on LM
+         # (also pylance seems to really hate this)
+         if self.cfg.compile_llm:
+             self.model = torch.compile(
+                 self.model,
+                 mode=self.cfg.llm_compilation_mode,
+             ) # type: ignore
+
+         if self.cfg.compile_sae:
+             backend = "aot_eager" if self.cfg.device == "mps" else "inductor"
+
+             self.sae.training_forward_pass = torch.compile( # type: ignore
+                 self.sae.training_forward_pass,
+                 mode=self.cfg.sae_compilation_mode,
+                 backend=backend,
+             ) # type: ignore
+
+     def run_trainer_with_interruption_handling(
+         self, trainer: SAETrainer[TrainingSAE[TrainingSAEConfig], TrainingSAEConfig]
+     ):
+         try:
+             # signal handlers (if preempted)
+             signal.signal(signal.SIGINT, interrupt_callback)
+             signal.signal(signal.SIGTERM, interrupt_callback)
+
+             # train SAE
+             sae = trainer.fit()
+
+         except (KeyboardInterrupt, InterruptedException):
+             logger.warning("interrupted, saving progress")
+             checkpoint_path = Path(self.cfg.checkpoint_path) / str(
+                 trainer.n_training_samples
+             )
+             self.save_checkpoint(checkpoint_path)
+             logger.info("done saving")
+             raise
+
+         return sae
+
+     def save_checkpoint(
+         self,
+         checkpoint_path: Path,
+     ) -> None:
+         self.activations_store.save(
+             str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
+         )
+
+         runner_config = self.cfg.to_dict()
+         with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
+             json.dump(runner_config, f)
+
+
+ def _parse_cfg_args(
+     args: Sequence[str],
+ ) -> LanguageModelSAERunnerConfig[TrainingSAEConfig]:
+     """
+     Parse command line arguments into a LanguageModelSAERunnerConfig.
+
+     This function first parses the architecture argument to determine which
+     concrete SAE config class to use, then parses the full configuration
+     with that concrete type.
+     """
+     if len(args) == 0:
+         args = ["--help"]
+
+     # First, parse only the architecture to determine which concrete class to use
+     architecture_parser = ArgumentParser(
+         description="Parse architecture to determine SAE config class",
+         exit_on_error=False,
+         add_help=False, # Don't add help to avoid conflicts
+     )
+     architecture_parser.add_argument(
+         "--architecture",
+         type=str,
+         choices=["standard", "gated", "jumprelu", "topk"],
+         default="standard",
+         help="SAE architecture to use",
+     )
+
+     # Parse known args to extract architecture, ignore unknown args for now
+     arch_args, remaining_args = architecture_parser.parse_known_args(args)
+     architecture = arch_args.architecture
+
+     # Remove architecture from remaining args if it exists
+     filtered_args = []
+     skip_next = False
+     for arg in remaining_args:
+         if skip_next:
+             skip_next = False
+             continue
+         if arg == "--architecture":
+             skip_next = True # Skip the next argument (the architecture value)
+             continue
+         filtered_args.append(arg)
+
+     # Create a custom wrapper class that simple_parsing can handle
+     def create_config_class(
+         sae_config_type: type[TrainingSAEConfig],
+     ) -> type[LanguageModelSAERunnerConfig[TrainingSAEConfig]]:
+         """Create a concrete config class for the given SAE config type."""
+
+         # Create the base config without the sae field
+         from dataclasses import field as dataclass_field
+         from dataclasses import fields, make_dataclass
+
+         # Get all fields from LanguageModelSAERunnerConfig except the generic sae field
+         base_fields = []
+         for field_obj in fields(LanguageModelSAERunnerConfig):
+             if field_obj.name != "sae":
+                 base_fields.append((field_obj.name, field_obj.type, field_obj))
+
+         # Add the concrete sae field
+         base_fields.append(
+             (
+                 "sae",
+                 sae_config_type,
+                 dataclass_field(
+                     default_factory=lambda: sae_config_type(d_in=512, d_sae=1024)
+                 ),
+             )
+         )
+
+         # Create the concrete class
+         return make_dataclass(
+             f"{sae_config_type.__name__}RunnerConfig",
+             base_fields,
+             bases=(LanguageModelSAERunnerConfig,),
+         )
+
+     # Map architecture to concrete config class
+     sae_config_map = {
+         "standard": StandardTrainingSAEConfig,
+         "gated": GatedTrainingSAEConfig,
+         "jumprelu": JumpReLUTrainingSAEConfig,
+         "topk": TopKTrainingSAEConfig,
+     }
+
+     sae_config_type = sae_config_map[architecture]
+     concrete_config_class = create_config_class(sae_config_type)
+
+     # Now parse the full configuration with the concrete type
+     parser = ArgumentParser(exit_on_error=False)
+     parser.add_arguments(concrete_config_class, dest="cfg")
+
+     # Parse the filtered arguments (without --architecture)
+     parsed_args = parser.parse_args(filtered_args)
+
+     # Return the parsed configuration
+     return parsed_args.cfg
+
+
+ # moved into its own function to make it easier to test
+ def _run_cli(args: Sequence[str]):
+     cfg = _parse_cfg_args(args)
+     LanguageModelSAETrainingRunner(cfg=cfg).run()
+
+
+ if __name__ == "__main__":
+     _run_cli(args=sys.argv[1:])
+
+
+ @deprecated("Use LanguageModelSAETrainingRunner instead")
+ class SAETrainingRunner(LanguageModelSAETrainingRunner):
+     pass
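
The new runner module above can be driven either through its CLI entry point (`_run_cli`) or programmatically. A minimal sketch, assuming the module is importable (its path is not shown in this diff) and that the remaining fields of the generated runner config have workable defaults:

    # Minimal usage sketch, not verbatim sae-lens documentation.
    # "--architecture topk" selects TopKTrainingSAEConfig via sae_config_map;
    # all other flags are generated by simple_parsing from the concrete
    # runner-config dataclass, so their names are not listed here.
    args = ["--architecture", "topk"]
    cfg = _parse_cfg_args(args)                       # two-pass parse shown above
    sae = LanguageModelSAETrainingRunner(cfg=cfg).run()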
sae_lens/load_model.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Any, Literal, cast
+ from typing import Any, Callable, Literal, cast

  import torch
  from transformer_lens import HookedTransformer
@@ -77,6 +77,7 @@ class HookedProxyLM(HookedRootModule):
      # copied and modified from base HookedRootModule
      def setup(self):
          self.mod_dict = {}
+         self.named_modules_dict = {}
          self.hook_dict: dict[str, HookPoint] = {}
          for name, module in self.model.named_modules():
              if name == "":
@@ -89,14 +90,21 @@ class HookedProxyLM(HookedRootModule):

              self.hook_dict[name] = hook_point
              self.mod_dict[name] = hook_point
+             self.named_modules_dict[name] = module
+
+     def run_with_cache(self, *args: Any, **kwargs: Any): # type: ignore
+         if "names_filter" in kwargs:
+             # hacky way to make sure that the names_filter is passed to our forward method
+             kwargs["_names_filter"] = kwargs["names_filter"]
+         return super().run_with_cache(*args, **kwargs)

      def forward(
          self,
          tokens: torch.Tensor,
          return_type: Literal["both", "logits"] = "logits",
          loss_per_token: bool = False,
-         # TODO: implement real support for stop_at_layer
          stop_at_layer: int | None = None,
+         _names_filter: list[str] | None = None,
          **kwargs: Any,
      ) -> Output | Loss:
          # This is just what's needed for evals, not everything that HookedTransformer has
@@ -107,8 +115,28 @@ class HookedProxyLM(HookedRootModule):
              raise NotImplementedError(
                  "Only return_type supported is 'both' or 'logits' to match what's in evals.py and ActivationsStore"
              )
-         output = self.model(tokens)
-         logits = _extract_logits_from_output(output)
+
+         stop_hooks = []
+         if stop_at_layer is not None and _names_filter is not None:
+             if return_type != "logits":
+                 raise NotImplementedError(
+                     "stop_at_layer is not supported for return_type='both'"
+                 )
+             stop_manager = StopManager(_names_filter)
+
+             for hook_name in _names_filter:
+                 module = self.named_modules_dict[hook_name]
+                 stop_fn = stop_manager.get_stop_hook_fn(hook_name)
+                 stop_hooks.append(module.register_forward_hook(stop_fn))
+         try:
+             output = self.model(tokens)
+             logits = _extract_logits_from_output(output)
+         except StopForward:
+             # If we stop early, we don't care about the return output
+             return None # type: ignore
+         finally:
+             for stop_hook in stop_hooks:
+                 stop_hook.remove()

          if return_type == "logits":
              return logits
@@ -159,7 +187,7 @@ class HookedProxyLM(HookedRootModule):

          # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
          if hasattr(self.tokenizer, "add_bos_token") and self.tokenizer.add_bos_token: # type: ignore
-             tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)
+             tokens = get_tokens_with_bos_removed(self.tokenizer, tokens) # type: ignore
          return tokens # type: ignore


@@ -183,3 +211,23 @@ def get_hook_fn(hook_point: HookPoint):
          return output

      return hook_fn
+
+
+ class StopForward(Exception):
+     pass
+
+
+ class StopManager:
+     def __init__(self, hook_names: list[str]):
+         self.hook_names = hook_names
+         self.total_hook_names = len(set(hook_names))
+         self.called_hook_names = set()
+
+     def get_stop_hook_fn(self, hook_name: str) -> Callable[[Any, Any, Any], Any]:
+         def stop_hook_fn(module: Any, input: Any, output: Any) -> Any: # noqa: ARG001
+             self.called_hook_names.add(hook_name)
+             if len(self.called_hook_names) == self.total_hook_names:
+                 raise StopForward()
+             return output
+
+         return stop_hook_fn
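
The `StopForward`/`StopManager` additions implement early stopping of the wrapped model's forward pass with ordinary PyTorch forward hooks: each requested module's hook records that it fired, and once every requested hook has fired, the raised exception aborts the remainder of the forward pass, which `HookedProxyLM.forward` then catches. A self-contained sketch of the same pattern in plain PyTorch (illustrative only, not sae-lens code):

    import torch
    import torch.nn as nn

    class StopForward(Exception):
        pass

    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4))
    wanted = {"0", "1"}                      # submodules whose outputs we need
    captured: dict[str, torch.Tensor] = {}

    def make_hook(name: str):
        def hook(module, inputs, output):    # standard forward-hook signature
            captured[name] = output
            if wanted <= captured.keys():    # everything we need has been seen
                raise StopForward()
            return output
        return hook

    modules = dict(model.named_modules())
    handles = [modules[n].register_forward_hook(make_hook(n)) for n in wanted]
    try:
        model(torch.randn(2, 4))
    except StopForward:
        pass                                 # layer "2" was never executed
    finally:
        for h in handles:
            h.remove()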
@@ -26,6 +26,22 @@ from sae_lens.loading.pretrained_saes_directory import (
  from sae_lens.registry import get_sae_class
  from sae_lens.util import filter_valid_dataclass_fields

+ LLM_METADATA_KEYS = {
+     "model_name",
+     "hook_name",
+     "model_class_name",
+     "hook_head_index",
+     "model_from_pretrained_kwargs",
+     "prepend_bos",
+     "exclude_special_tokens",
+     "neuronpedia_id",
+     "context_size",
+     "seqpos_slice",
+     "dataset_path",
+     "sae_lens_version",
+     "sae_lens_training_version",
+ }
+

  # loaders take in a release, sae_id, device, and whether to force download, and returns a tuple of config, state_dict, and log sparsity
  class PretrainedSaeHuggingfaceLoader(Protocol):
@@ -193,7 +209,6 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:

      rename_keys_map = {
          "hook_point": "hook_name",
-         "hook_point_layer": "hook_layer",
          "hook_point_head_index": "hook_head_index",
          "activation_fn_str": "activation_fn",
      }
@@ -208,6 +223,10 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
      new_cfg.setdefault("activation_fn", new_cfg.get("activation_fn", "relu"))
      new_cfg.setdefault("architecture", "standard")
      new_cfg.setdefault("neuronpedia_id", None)
+     new_cfg.setdefault(
+         "reshape_activations",
+         "hook_z" if "hook_z" in new_cfg.get("hook_name", "") else "none",
+     )

      if "normalize_activations" in new_cfg and isinstance(
          new_cfg["normalize_activations"], bool
@@ -232,11 +251,9 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
      if architecture == "topk":
          sae_cfg_dict["k"] = new_cfg["activation_fn_kwargs"]["k"]

-     # import here to avoid circular import
-     from sae_lens.saes.sae import SAEMetadata
-
-     meta_dict = filter_valid_dataclass_fields(new_cfg, SAEMetadata)
-     sae_cfg_dict["metadata"] = meta_dict
+     sae_cfg_dict["metadata"] = {
+         k: v for k, v in new_cfg.items() if k in LLM_METADATA_KEYS
+     }
      sae_cfg_dict["architecture"] = architecture
      return sae_cfg_dict

@@ -262,7 +279,6 @@ def get_connor_rob_hook_z_config_from_hf(
          "device": device if device is not None else "cpu",
          "model_name": "gpt2-small",
          "hook_name": old_cfg_dict["act_name"],
-         "hook_layer": old_cfg_dict["layer"],
          "hook_head_index": None,
          "activation_fn": "relu",
          "apply_b_dec_to_input": True,
@@ -273,6 +289,7 @@ def get_connor_rob_hook_z_config_from_hf(
          "context_size": 128,
          "normalize_activations": "none",
          "dataset_trust_remote_code": True,
+         "reshape_activations": "hook_z",
          **(cfg_overrides or {}),
      }

@@ -411,7 +428,6 @@ def get_gemma_2_config_from_hf(
          "dtype": "float32",
          "model_name": model_name,
          "hook_name": hook_name,
-         "hook_layer": layer,
          "hook_head_index": None,
          "activation_fn": "relu",
          "finetuning_scaling_factor": False,
@@ -524,7 +540,6 @@ def get_llama_scope_config_from_hf(
          "dtype": "bfloat16",
          "model_name": model_name,
          "hook_name": old_cfg_dict["hook_point_in"],
-         "hook_layer": int(old_cfg_dict["hook_point_in"].split(".")[1]),
          "hook_head_index": None,
          "activation_fn": "relu",
          "finetuning_scaling_factor": False,
@@ -651,7 +666,6 @@ def get_dictionary_learning_config_1_from_hf(
          "device": device,
          "model_name": trainer["lm_name"].split("/")[-1],
          "hook_name": hook_point_name,
-         "hook_layer": trainer["layer"],
          "hook_head_index": None,
          "activation_fn": activation_fn,
          "activation_fn_kwargs": activation_fn_kwargs,
@@ -690,7 +704,6 @@ def get_deepseek_r1_config_from_hf(
          "context_size": 1024,
          "model_name": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
          "hook_name": f"blocks.{layer}.hook_resid_post",
-         "hook_layer": layer,
          "hook_head_index": None,
          "prepend_bos": True,
          "dataset_path": "lmsys/lmsys-chat-1m",
@@ -849,7 +862,6 @@ def get_llama_scope_r1_distill_config_from_hf(
          "device": device,
          "model_name": model_name,
          "hook_name": huggingface_cfg_dict["hook_point_in"],
-         "hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
          "hook_head_index": None,
          "activation_fn": "relu",
          "finetuning_scaling_factor": False,
@@ -168,10 +168,6 @@ class GatedTrainingSAE(TrainingSAE[GatedTrainingSAEConfig]):

          # Magnitude path
          magnitude_pre_activation = sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag
-         if self.training and self.cfg.noise_scale > 0:
-             magnitude_pre_activation += (
-                 torch.randn_like(magnitude_pre_activation) * self.cfg.noise_scale
-             )
          magnitude_pre_activation = self.hook_sae_acts_pre(magnitude_pre_activation)

          feature_magnitudes = self.activation_fn(magnitude_pre_activation)
@@ -105,7 +105,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
      JumpReLUSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
      using a JumpReLU activation. For each unit, if its pre-activation is
      <= threshold, that unit is zeroed out; otherwise, it follows a user-specified
-     activation function (e.g., ReLU, tanh-relu, etc.).
+     activation function (e.g., ReLU).

      It implements:
      - initialize_weights: sets up parameters, including a threshold.
@@ -142,7 +142,7 @@ class JumpReLUSAE(SAE[JumpReLUSAEConfig]):
          sae_in = self.process_sae_in(x)
          hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

-         # 1) Apply the base "activation_fn" from config (e.g., ReLU, tanh-relu).
+         # 1) Apply the base "activation_fn" from config (e.g., ReLU).
          base_acts = self.activation_fn(hidden_pre)

          # 2) Zero out any unit whose (hidden_pre <= threshold).
@@ -191,8 +191,8 @@ class JumpReLUTrainingSAEConfig(TrainingSAEConfig):
      Configuration class for training a JumpReLUTrainingSAE.
      """

-     jumprelu_init_threshold: float = 0.001
-     jumprelu_bandwidth: float = 0.001
+     jumprelu_init_threshold: float = 0.01
+     jumprelu_bandwidth: float = 0.05
      l0_coefficient: float = 1.0
      l0_warm_up_steps: int = 0

@@ -257,12 +257,6 @@ class JumpReLUTrainingSAE(TrainingSAE[JumpReLUTrainingSAEConfig]):
          sae_in = self.process_sae_in(x)

          hidden_pre = sae_in @ self.W_enc + self.b_enc
-
-         if self.training and self.cfg.noise_scale > 0:
-             hidden_pre = (
-                 hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
-             )
-
          feature_acts = JumpReLU.apply(hidden_pre, self.threshold, self.bandwidth)

          return feature_acts, hidden_pre # type: ignore
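
The training encoder above relies on `JumpReLU.apply(hidden_pre, self.threshold, self.bandwidth)`, a custom autograd function whose definition is not part of this diff. As a rough sketch of how a straight-through JumpReLU estimator is typically written (using the rectangle-window pseudo-gradient from the JumpReLU SAE paper); this is an illustration, not the sae-lens implementation, and `JumpReLUSketch` is a hypothetical name:

    import torch

    class JumpReLUSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: torch.Tensor, threshold: torch.Tensor, bandwidth: float):
            ctx.save_for_backward(x, threshold)
            ctx.bandwidth = bandwidth
            # zero out units at or below the (per-feature) threshold
            return x * (x > threshold)

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor):
            x, threshold = ctx.saved_tensors
            bandwidth = ctx.bandwidth
            # rectangle window of width `bandwidth` centred on the threshold
            in_window = ((x - threshold).abs() < bandwidth / 2).to(x.dtype)
            grad_x = (x > threshold).to(x.dtype) * grad_output
            grad_threshold = -(threshold / bandwidth) * in_window * grad_output
            # reduce over batch dimensions so the shape matches `threshold`
            while grad_threshold.dim() > threshold.dim():
                grad_threshold = grad_threshold.sum(0)
            return grad_x, grad_threshold, None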