sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
  # ruff: noqa: E402
- __version__ = "6.0.0-rc.2"
+ __version__ = "6.0.0-rc.3"

  import logging

@@ -33,16 +33,17 @@ from .cache_activations_runner import CacheActivationsRunner
  from .config import (
  CacheActivationsRunnerConfig,
  LanguageModelSAERunnerConfig,
+ LoggingConfig,
  PretokenizeRunnerConfig,
  )
  from .evals import run_evals
+ from .llm_sae_training_runner import LanguageModelSAETrainingRunner, SAETrainingRunner
  from .loading.pretrained_sae_loaders import (
  PretrainedSaeDiskLoader,
  PretrainedSaeHuggingfaceLoader,
  )
  from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
  from .registry import register_sae_class, register_sae_training_class
- from .sae_training_runner import SAETrainingRunner
  from .training.activations_store import ActivationsStore
  from .training.upload_saes_to_huggingface import upload_saes_to_huggingface

@@ -54,7 +55,7 @@ __all__ = [
  "HookedSAETransformer",
  "ActivationsStore",
  "LanguageModelSAERunnerConfig",
- "SAETrainingRunner",
+ "LanguageModelSAETrainingRunner",
  "CacheActivationsRunnerConfig",
  "CacheActivationsRunner",
  "PretokenizeRunnerConfig",
@@ -82,6 +83,8 @@ __all__ = [
  "JumpReLUSAEConfig",
  "JumpReLUTrainingSAE",
  "JumpReLUTrainingSAEConfig",
+ "SAETrainingRunner",
+ "LoggingConfig",
  ]

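A quick sketch (not from the package itself) of what the re-export change means for downstream imports: the training runner is now exposed as LanguageModelSAETrainingRunner, while SAETrainingRunner survives as a deprecated alias (see the llm_sae_training_runner.py hunks later in this diff), so both names still resolve.

from sae_lens import LanguageModelSAETrainingRunner, SAETrainingRunner

# The old name is a deprecated subclass of the new one, so this holds.
assert issubclass(SAETrainingRunner, LanguageModelSAETrainingRunner)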
sae_lens/cache_activations_runner.py CHANGED
@@ -34,7 +34,6 @@ def _mk_activations_store(
  dataset=override_dataset or cfg.dataset_path,
  streaming=cfg.streaming,
  hook_name=cfg.hook_name,
- hook_layer=cfg.hook_layer,
  hook_head_index=None,
  context_size=cfg.context_size,
  d_in=cfg.d_in,
@@ -265,7 +264,7 @@ class CacheActivationsRunner:

  for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
  try:
- buffer = self.activations_store.get_buffer(
+ buffer = self.activations_store.get_raw_buffer(
  self.cfg.n_batches_in_buffer, shuffle=False
  )
  shard = self._create_shard(buffer)
@@ -319,7 +318,7 @@ class CacheActivationsRunner:
  def _create_shard(
  self,
  buffer: tuple[
- Float[torch.Tensor, "(bs context_size) num_layers d_in"],
+ Float[torch.Tensor, "(bs context_size) d_in"],
  Int[torch.Tensor, "(bs context_size)"] | None,
  ],
  ) -> Dataset:
@@ -327,13 +326,15 @@ class CacheActivationsRunner:
  acts, token_ids = buffer
  acts = einops.rearrange(
  acts,
- "(bs context_size) num_layers d_in -> num_layers bs context_size d_in",
+ "(bs context_size) d_in -> bs context_size d_in",
  bs=self.cfg.n_seq_in_buffer,
  context_size=self.context_size,
  d_in=self.cfg.d_in,
- num_layers=len(hook_names),
  )
- shard_dict = {hook_name: act for hook_name, act in zip(hook_names, acts)}
+ shard_dict: dict[str, object] = {
+ hook_name: act_batch
+ for hook_name, act_batch in zip(hook_names, [acts], strict=True)
+ }

  if token_ids is not None:
  token_ids = einops.rearrange(
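A self-contained sketch (toy shapes, not sae-lens code) of the reshaping change above: with hook_layer gone, a cached buffer is a 2-D tensor of shape ((batch * context), d_in) for a single hook, and _create_shard simply unflattens it.

import einops
import torch

bs, context_size, d_in = 2, 3, 4
acts = torch.randn(bs * context_size, d_in)  # shape of the new raw buffer
acts = einops.rearrange(
    acts,
    "(bs context_size) d_in -> bs context_size d_in",
    bs=bs,
    context_size=context_size,
    d_in=d_in,
)
print(acts.shape)  # torch.Size([2, 3, 4])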
sae_lens/config.py CHANGED
@@ -23,7 +23,9 @@ from sae_lens.saes.sae import TrainingSAEConfig
  if TYPE_CHECKING:
  pass

- T_TRAINING_SAE_CONFIG = TypeVar("T_TRAINING_SAE_CONFIG", bound=TrainingSAEConfig)
+ T_TRAINING_SAE_CONFIG = TypeVar(
+ "T_TRAINING_SAE_CONFIG", bound=TrainingSAEConfig, covariant=True
+ )

  HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset

@@ -102,7 +104,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`.
  hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook.
  hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation.
- hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing.
  hook_head_index (int, optional): When the hook is for an activation with a head index, we can specify a specific head to use here.
  dataset_path (str): A Hugging Face dataset path.
  dataset_trust_remote_code (bool): Whether to trust remote code when loading datasets from Huggingface.
@@ -159,7 +160,6 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
  model_class_name: str = "HookedTransformer"
  hook_name: str = "blocks.0.hook_mlp_out"
  hook_eval: str = "NOT_IN_USE"
- hook_layer: int = 0
  hook_head_index: int | None = None
  dataset_path: str = ""
  dataset_trust_remote_code: bool = True
@@ -375,6 +375,28 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):

  return cls(**cfg)

+ def to_sae_trainer_config(self) -> "SAETrainerConfig":
+ return SAETrainerConfig(
+ n_checkpoints=self.n_checkpoints,
+ checkpoint_path=self.checkpoint_path,
+ total_training_samples=self.total_training_tokens,
+ device=self.device,
+ autocast=self.autocast,
+ lr=self.lr,
+ lr_end=self.lr_end,
+ lr_scheduler_name=self.lr_scheduler_name,
+ lr_warm_up_steps=self.lr_warm_up_steps,
+ adam_beta1=self.adam_beta1,
+ adam_beta2=self.adam_beta2,
+ lr_decay_steps=self.lr_decay_steps,
+ n_restart_cycles=self.n_restart_cycles,
+ total_training_steps=self.total_training_steps,
+ train_batch_size_samples=self.train_batch_size_tokens,
+ dead_feature_window=self.dead_feature_window,
+ feature_sampling_window=self.feature_sampling_window,
+ logger=self.logger,
+ )
+

  @dataclass
  class CacheActivationsRunnerConfig:
@@ -386,7 +408,6 @@ class CacheActivationsRunnerConfig:
  model_name (str): The name of the model to use.
  model_batch_size (int): How many prompts are in the batch of the language model when generating activations.
  hook_name (str): The name of the hook to use.
- hook_layer (int): The layer of the final hook. Currently only support a single hook, so this should be the same as hook_name.
  d_in (int): Dimension of the model.
  total_training_tokens (int): Total number of tokens to process.
  context_size (int): Context size to process. Can be left as -1 if the dataset is tokenized.
@@ -416,7 +437,6 @@ class CacheActivationsRunnerConfig:
  model_name: str
  model_batch_size: int
  hook_name: str
- hook_layer: int
  d_in: int
  training_tokens: int

@@ -576,3 +596,25 @@ class PretokenizeRunnerConfig:
  hf_num_shards: int = 64
  hf_revision: str = "main"
  hf_is_private_repo: bool = False
+
+
+ @dataclass
+ class SAETrainerConfig:
+ n_checkpoints: int
+ checkpoint_path: str
+ total_training_samples: int
+ device: str
+ autocast: bool
+ lr: float
+ lr_end: float | None
+ lr_scheduler_name: str
+ lr_warm_up_steps: int
+ adam_beta1: float
+ adam_beta2: float
+ lr_decay_steps: int
+ n_restart_cycles: int
+ total_training_steps: int
+ train_batch_size_samples: int
+ dead_feature_window: int
+ feature_sampling_window: int
+ logger: LoggingConfig
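Two details in this file are easy to skim past: to_sae_trainer_config() narrows the full runner config down to the new SAETrainerConfig that SAETrainer consumes, and T_TRAINING_SAE_CONFIG is now covariant. The toy classes below (stand-ins, not sae-lens types) sketch what covariance buys: a config parameterized over a subclass is accepted where one over the base class is expected.

from typing import Generic, TypeVar

class BaseSAEConfig: ...
class JumpReLUConfig(BaseSAEConfig): ...

T = TypeVar("T", bound=BaseSAEConfig, covariant=True)

class RunnerConfig(Generic[T]):
    def __init__(self, sae: T) -> None:
        self.sae = sae

def train(cfg: RunnerConfig[BaseSAEConfig]) -> None:
    ...

# Accepted by type checkers only because T is covariant.
train(RunnerConfig(JumpReLUConfig()))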
sae_lens/constants.py CHANGED
@@ -16,3 +16,5 @@ SPARSITY_FILENAME = "sparsity.safetensors"
  SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
  SAE_CFG_FILENAME = "cfg.json"
  RUNNER_CFG_FILENAME = "runner_cfg.json"
+ ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
+ ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
sae_lens/evals.py CHANGED
@@ -21,7 +21,9 @@ from transformer_lens.hook_points import HookedRootModule

  from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
  from sae_lens.saes.sae import SAE, SAEConfig
+ from sae_lens.training.activation_scaler import ActivationScaler
  from sae_lens.training.activations_store import ActivationsStore
+ from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name


  def get_library_version() -> str:
@@ -103,6 +105,7 @@ def run_evals(
  sae: SAE[Any],
  activation_store: ActivationsStore,
  model: HookedRootModule,
+ activation_scaler: ActivationScaler,
  eval_config: EvalConfig = EvalConfig(),
  model_kwargs: Mapping[str, Any] = {},
  ignore_tokens: set[int | None] = set(),
@@ -140,6 +143,7 @@ def run_evals(
  sae,
  model,
  activation_store,
+ activation_scaler,
  compute_kl=eval_config.compute_kl,
  compute_ce_loss=eval_config.compute_ce_loss,
  n_batches=eval_config.n_eval_reconstruction_batches,
@@ -189,6 +193,7 @@ def run_evals(
  sae,
  model,
  activation_store,
+ activation_scaler,
  compute_l2_norms=eval_config.compute_l2_norms,
  compute_sparsity_metrics=eval_config.compute_sparsity_metrics,
  compute_variance_metrics=eval_config.compute_variance_metrics,
@@ -301,6 +306,7 @@ def get_downstream_reconstruction_metrics(
  sae: SAE[Any],
  model: HookedRootModule,
  activation_store: ActivationsStore,
+ activation_scaler: ActivationScaler,
  compute_kl: bool,
  compute_ce_loss: bool,
  n_batches: int,
@@ -326,8 +332,8 @@ def get_downstream_reconstruction_metrics(
  for metric_name, metric_value in get_recons_loss(
  sae,
  model,
+ activation_scaler,
  batch_tokens,
- activation_store,
  compute_kl=compute_kl,
  compute_ce_loss=compute_ce_loss,
  ignore_tokens=ignore_tokens,
@@ -369,6 +375,7 @@ def get_sparsity_and_variance_metrics(
  sae: SAE[Any],
  model: HookedRootModule,
  activation_store: ActivationsStore,
+ activation_scaler: ActivationScaler,
  n_batches: int,
  compute_l2_norms: bool,
  compute_sparsity_metrics: bool,
@@ -436,7 +443,7 @@ def get_sparsity_and_variance_metrics(
  batch_tokens,
  prepend_bos=False,
  names_filter=[hook_name],
- stop_at_layer=sae.cfg.metadata.hook_layer + 1,
+ stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(hook_name),
  **model_kwargs,
  )

@@ -451,16 +458,14 @@ def get_sparsity_and_variance_metrics(
  original_act = cache[hook_name]

  # normalise if necessary (necessary in training only, otherwise we should fold the scaling in)
- if activation_store.normalize_activations == "expected_average_only_in":
- original_act = activation_store.apply_norm_scaling_factor(original_act)
+ original_act = activation_scaler.scale(original_act)

  # send the (maybe normalised) activations into the SAE
  sae_feature_activations = sae.encode(original_act.to(sae.device))
  sae_out = sae.decode(sae_feature_activations).to(original_act.device)
  del cache

- if activation_store.normalize_activations == "expected_average_only_in":
- sae_out = activation_store.unscale(sae_out)
+ sae_out = activation_scaler.unscale(sae_out)

  flattened_sae_input = einops.rearrange(original_act, "b ctx d -> (b ctx) d")
  flattened_sae_feature_acts = einops.rearrange(
@@ -582,8 +587,8 @@ def get_sparsity_and_variance_metrics(
  def get_recons_loss(
  sae: SAE[SAEConfig],
  model: HookedRootModule,
+ activation_scaler: ActivationScaler,
  batch_tokens: torch.Tensor,
- activation_store: ActivationsStore,
  compute_kl: bool,
  compute_ce_loss: bool,
  ignore_tokens: set[int | None] = set(),
@@ -618,15 +623,13 @@ def get_recons_loss(
  activations = activations.to(sae.device)

  # Handle rescaling if SAE expects it
- if activation_store.normalize_activations == "expected_average_only_in":
- activations = activation_store.apply_norm_scaling_factor(activations)
+ activations = activation_scaler.scale(activations)

  # SAE class agnost forward forward pass.
  new_activations = sae.decode(sae.encode(activations)).to(activations.dtype)

  # Unscale if activations were scaled prior to going into the SAE
- if activation_store.normalize_activations == "expected_average_only_in":
- new_activations = activation_store.unscale(new_activations)
+ new_activations = activation_scaler.unscale(new_activations)

  new_activations = torch.where(mask[..., None], new_activations, activations)

@@ -637,8 +640,7 @@ def get_recons_loss(
  activations = activations.to(sae.device)

  # Handle rescaling if SAE expects it
- if activation_store.normalize_activations == "expected_average_only_in":
- activations = activation_store.apply_norm_scaling_factor(activations)
+ activations = activation_scaler.scale(activations)

  # SAE class agnost forward forward pass.
  new_activations = sae.decode(sae.encode(activations.flatten(-2, -1))).to(
@@ -650,8 +652,7 @@ def get_recons_loss(
  ) # reshape to match original shape

  # Unscale if activations were scaled prior to going into the SAE
- if activation_store.normalize_activations == "expected_average_only_in":
- new_activations = activation_store.unscale(new_activations)
+ new_activations = activation_scaler.unscale(new_activations)

  return new_activations.to(original_device)

@@ -660,8 +661,7 @@ def get_recons_loss(
  activations = activations.to(sae.device)

  # Handle rescaling if SAE expects it
- if activation_store.normalize_activations == "expected_average_only_in":
- activations = activation_store.apply_norm_scaling_factor(activations)
+ activations = activation_scaler.scale(activations)

  new_activations = sae.decode(sae.encode(activations[:, :, head_index])).to(
  activations.dtype
@@ -669,8 +669,7 @@ def get_recons_loss(
  activations[:, :, head_index] = new_activations

  # Unscale if activations were scaled prior to going into the SAE
- if activation_store.normalize_activations == "expected_average_only_in":
- activations = activation_store.unscale(activations)
+ activations = activation_scaler.unscale(activations)

  return activations.to(original_device)

@@ -849,6 +848,7 @@ def multiple_evals(
  scalar_metrics, feature_metrics = run_evals(
  sae=sae,
  activation_store=activation_store,
+ activation_scaler=ActivationScaler(),
  model=current_model,
  eval_config=eval_config,
  ignore_tokens={
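A hedged sketch of the updated evals API, mirroring the multiple_evals hunk above: normalization is now supplied via an explicit ActivationScaler instead of being read off the ActivationsStore. The sae, activation_store, and model arguments are whatever objects you already have loaded; ActivationScaler() matches the default used by multiple_evals.

from typing import Any

from sae_lens.evals import EvalConfig, run_evals
from sae_lens.training.activation_scaler import ActivationScaler


def evaluate(sae: Any, activation_store: Any, model: Any) -> dict[str, Any]:
    # run_evals returns (scalar_metrics, feature_metrics); we keep the scalars.
    scalar_metrics, _feature_metrics = run_evals(
        sae=sae,
        activation_store=activation_store,
        model=model,
        activation_scaler=ActivationScaler(),
        eval_config=EvalConfig(),
    )
    return scalar_metrics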
sae_lens/llm_sae_training_runner.py CHANGED
@@ -2,23 +2,31 @@ import json
  import signal
  import sys
  from collections.abc import Sequence
+ from dataclasses import dataclass
  from pathlib import Path
- from typing import Any, cast
+ from typing import Any, Generic, cast

  import torch
  import wandb
- from safetensors.torch import save_file
  from simple_parsing import ArgumentParser
  from transformer_lens.hook_points import HookedRootModule
+ from typing_extensions import deprecated

  from sae_lens import logger
  from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
- from sae_lens.constants import RUNNER_CFG_FILENAME, SPARSITY_FILENAME
+ from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
+ from sae_lens.evals import EvalConfig, run_evals
  from sae_lens.load_model import load_model
- from sae_lens.saes.sae import T_TRAINING_SAE_CONFIG, TrainingSAE, TrainingSAEConfig
+ from sae_lens.saes.sae import (
+ T_TRAINING_SAE,
+ T_TRAINING_SAE_CONFIG,
+ TrainingSAE,
+ TrainingSAEConfig,
+ )
+ from sae_lens.training.activation_scaler import ActivationScaler
  from sae_lens.training.activations_store import ActivationsStore
- from sae_lens.training.geometric_median import compute_geometric_median
  from sae_lens.training.sae_trainer import SAETrainer
+ from sae_lens.training.types import DataProvider


  class InterruptedException(Exception):
@@ -29,7 +37,58 @@ def interrupt_callback(sig_num: Any, stack_frame: Any): # noqa: ARG001
  raise InterruptedException()


- class SAETrainingRunner:
+ @dataclass
+ class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
+ model: HookedRootModule
+ activations_store: ActivationsStore
+ eval_batch_size_prompts: int | None
+ n_eval_batches: int
+ model_kwargs: dict[str, Any]
+
+ def __call__(
+ self,
+ sae: T_TRAINING_SAE,
+ data_provider: DataProvider,
+ activation_scaler: ActivationScaler,
+ ) -> dict[str, Any]:
+ ignore_tokens = set()
+ if self.activations_store.exclude_special_tokens is not None:
+ ignore_tokens = set(self.activations_store.exclude_special_tokens.tolist())
+
+ eval_config = EvalConfig(
+ batch_size_prompts=self.eval_batch_size_prompts,
+ n_eval_reconstruction_batches=self.n_eval_batches,
+ n_eval_sparsity_variance_batches=self.n_eval_batches,
+ compute_ce_loss=True,
+ compute_l2_norms=True,
+ compute_sparsity_metrics=True,
+ compute_variance_metrics=True,
+ )
+
+ eval_metrics, _ = run_evals(
+ sae=sae,
+ activation_store=self.activations_store,
+ model=self.model,
+ activation_scaler=activation_scaler,
+ eval_config=eval_config,
+ ignore_tokens=ignore_tokens,
+ model_kwargs=self.model_kwargs,
+ ) # not calculating featurwise metrics here.
+
+ # Remove eval metrics that are already logged during training
+ eval_metrics.pop("metrics/explained_variance", None)
+ eval_metrics.pop("metrics/explained_variance_std", None)
+ eval_metrics.pop("metrics/l0", None)
+ eval_metrics.pop("metrics/l1", None)
+ eval_metrics.pop("metrics/mse", None)
+
+ # Remove metrics that are not useful for wandb logging
+ eval_metrics.pop("metrics/total_tokens_evaluated", None)
+
+ return eval_metrics
+
+
+ class LanguageModelSAETrainingRunner:
  """
  Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
  """
@@ -84,7 +143,6 @@ class SAETrainingRunner:
  self.cfg.get_training_sae_cfg_dict(),
  ).to_dict()
  )
- self._init_sae_group_b_decs()
  else:
  self.sae = override_sae

@@ -102,12 +160,20 @@ class SAETrainingRunner:
  id=self.cfg.logger.wandb_id,
  )

- trainer = SAETrainer(
+ evaluator = LLMSaeEvaluator(
  model=self.model,
+ activations_store=self.activations_store,
+ eval_batch_size_prompts=self.cfg.eval_batch_size_prompts,
+ n_eval_batches=self.cfg.n_eval_batches,
+ model_kwargs=self.cfg.model_kwargs,
+ )
+
+ trainer = SAETrainer(
  sae=self.sae,
- activation_store=self.activations_store,
+ data_provider=self.activations_store,
+ evaluator=evaluator,
  save_checkpoint_fn=self.save_checkpoint,
- cfg=self.cfg,
+ cfg=self.cfg.to_sae_trainer_config(),
  )

  self._compile_if_needed()
@@ -156,66 +222,27 @@ class SAETrainingRunner:

  except (KeyboardInterrupt, InterruptedException):
  logger.warning("interrupted, saving progress")
- checkpoint_name = str(trainer.n_training_tokens)
- self.save_checkpoint(trainer, checkpoint_name=checkpoint_name)
+ checkpoint_path = Path(self.cfg.checkpoint_path) / str(
+ trainer.n_training_samples
+ )
+ self.save_checkpoint(checkpoint_path)
  logger.info("done saving")
  raise

  return sae

- # TODO: move this into the SAE trainer or Training SAE class
- def _init_sae_group_b_decs(
- self,
- ) -> None:
- """
- extract all activations at a certain layer and use for sae b_dec initialization
- """
-
- if self.cfg.sae.b_dec_init_method == "geometric_median":
- self.activations_store.set_norm_scaling_factor_if_needed()
- layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
- # get geometric median of the activations if we're using those.
- median = compute_geometric_median(
- layer_acts,
- maxiter=100,
- ).median
- self.sae.initialize_b_dec_with_precalculated(median)
- elif self.cfg.sae.b_dec_init_method == "mean":
- self.activations_store.set_norm_scaling_factor_if_needed()
- layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
- self.sae.initialize_b_dec_with_mean(layer_acts) # type: ignore
-
- @staticmethod
  def save_checkpoint(
- trainer: SAETrainer[TrainingSAE[Any], Any],
- checkpoint_name: str,
- wandb_aliases: list[str] | None = None,
+ self,
+ checkpoint_path: Path,
  ) -> None:
- base_path = Path(trainer.cfg.checkpoint_path) / checkpoint_name
- base_path.mkdir(exist_ok=True, parents=True)
-
- trainer.activations_store.save(
- str(base_path / "activations_store_state.safetensors")
+ self.activations_store.save(
+ str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
  )

- weights_path, cfg_path = trainer.sae.save_model(str(base_path))
-
- sparsity_path = base_path / SPARSITY_FILENAME
- save_file({"sparsity": trainer.log_feature_sparsity}, sparsity_path)
-
- runner_config = trainer.cfg.to_dict()
- with open(base_path / RUNNER_CFG_FILENAME, "w") as f:
+ runner_config = self.cfg.to_dict()
+ with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
  json.dump(runner_config, f)

- if trainer.cfg.logger.log_to_wandb:
- trainer.cfg.logger.log(
- trainer,
- weights_path,
- cfg_path,
- sparsity_path=sparsity_path,
- wandb_aliases=wandb_aliases,
- )
-

  def _parse_cfg_args(
  args: Sequence[str],
@@ -230,8 +257,13 @@ def _parse_cfg_args(
  # moved into its own function to make it easier to test
  def _run_cli(args: Sequence[str]):
  cfg = _parse_cfg_args(args)
- SAETrainingRunner(cfg=cfg).run()
+ LanguageModelSAETrainingRunner(cfg=cfg).run()


  if __name__ == "__main__":
  _run_cli(args=sys.argv[1:])
+
+
+ @deprecated("Use LanguageModelSAETrainingRunner instead")
+ class SAETrainingRunner(LanguageModelSAETrainingRunner):
+ pass
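The trainer now receives its evaluator as a plain callable with the signature shown in LLMSaeEvaluator.__call__ above. A toy stand-in (hypothetical, purely to illustrate the shape of that contract; whether SAETrainer accepts arbitrary evaluators is not confirmed by this diff):

from typing import Any

class NoOpEvaluator:
    """Same call shape as LLMSaeEvaluator, but reports no extra metrics."""

    def __call__(
        self, sae: Any, data_provider: Any, activation_scaler: Any
    ) -> dict[str, Any]:
        return {}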
sae_lens/load_model.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Any, Literal, cast
+ from typing import Any, Callable, Literal, cast

  import torch
  from transformer_lens import HookedTransformer
@@ -77,6 +77,7 @@ class HookedProxyLM(HookedRootModule):
  # copied and modified from base HookedRootModule
  def setup(self):
  self.mod_dict = {}
+ self.named_modules_dict = {}
  self.hook_dict: dict[str, HookPoint] = {}
  for name, module in self.model.named_modules():
  if name == "":
@@ -89,14 +90,21 @@ class HookedProxyLM(HookedRootModule):

  self.hook_dict[name] = hook_point
  self.mod_dict[name] = hook_point
+ self.named_modules_dict[name] = module
+
+ def run_with_cache(self, *args: Any, **kwargs: Any): # type: ignore
+ if "names_filter" in kwargs:
+ # hacky way to make sure that the names_filter is passed to our forward method
+ kwargs["_names_filter"] = kwargs["names_filter"]
+ return super().run_with_cache(*args, **kwargs)

  def forward(
  self,
  tokens: torch.Tensor,
  return_type: Literal["both", "logits"] = "logits",
  loss_per_token: bool = False,
- # TODO: implement real support for stop_at_layer
  stop_at_layer: int | None = None,
+ _names_filter: list[str] | None = None,
  **kwargs: Any,
  ) -> Output | Loss:
  # This is just what's needed for evals, not everything that HookedTransformer has
@@ -107,8 +115,28 @@ class HookedProxyLM(HookedRootModule):
  raise NotImplementedError(
  "Only return_type supported is 'both' or 'logits' to match what's in evals.py and ActivationsStore"
  )
- output = self.model(tokens)
- logits = _extract_logits_from_output(output)
+
+ stop_hooks = []
+ if stop_at_layer is not None and _names_filter is not None:
+ if return_type != "logits":
+ raise NotImplementedError(
+ "stop_at_layer is not supported for return_type='both'"
+ )
+ stop_manager = StopManager(_names_filter)
+
+ for hook_name in _names_filter:
+ module = self.named_modules_dict[hook_name]
+ stop_fn = stop_manager.get_stop_hook_fn(hook_name)
+ stop_hooks.append(module.register_forward_hook(stop_fn))
+ try:
+ output = self.model(tokens)
+ logits = _extract_logits_from_output(output)
+ except StopForward:
+ # If we stop early, we don't care about the return output
+ return None # type: ignore
+ finally:
+ for stop_hook in stop_hooks:
+ stop_hook.remove()

  if return_type == "logits":
  return logits
@@ -159,7 +187,7 @@ class HookedProxyLM(HookedRootModule):

  # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
  if hasattr(self.tokenizer, "add_bos_token") and self.tokenizer.add_bos_token: # type: ignore
- tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)
+ tokens = get_tokens_with_bos_removed(self.tokenizer, tokens) # type: ignore
  return tokens # type: ignore


@@ -183,3 +211,23 @@ def get_hook_fn(hook_point: HookPoint):
  return output

  return hook_fn
+
+
+ class StopForward(Exception):
+ pass
+
+
+ class StopManager:
+ def __init__(self, hook_names: list[str]):
+ self.hook_names = hook_names
+ self.total_hook_names = len(set(hook_names))
+ self.called_hook_names = set()
+
+ def get_stop_hook_fn(self, hook_name: str) -> Callable[[Any, Any, Any], Any]:
+ def stop_hook_fn(module: Any, input: Any, output: Any) -> Any: # noqa: ARG001
+ self.called_hook_names.add(hook_name)
+ if len(self.called_hook_names) == self.total_hook_names:
+ raise StopForward()
+ return output
+
+ return stop_hook_fn
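A standalone toy (plain PyTorch, not sae-lens code) showing the early-stop pattern StopManager implements above: register forward hooks on the modules whose outputs you need, raise once all of them have fired, and catch the exception so later layers never run.

import torch
import torch.nn as nn


class StopForward(Exception):
    pass


model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4))
wanted = {"0", "1"}  # submodule names whose outputs we need
seen: set[str] = set()
captured: dict[str, torch.Tensor] = {}


def make_hook(name: str):
    def hook(module, inputs, output):  # standard forward-hook signature
        captured[name] = output
        seen.add(name)
        if seen == wanted:
            raise StopForward()
        return output

    return hook


modules = dict(model.named_modules())
handles = [modules[n].register_forward_hook(make_hook(n)) for n in wanted]
try:
    model(torch.randn(2, 4))
except StopForward:
    pass  # layer "2" never ran
finally:
    for h in handles:
        h.remove()
print(sorted(captured))  # ['0', '1']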