sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sae_lens/evals.py CHANGED
@@ -20,8 +20,10 @@ from transformer_lens import HookedTransformer
 from transformer_lens.hook_points import HookedRootModule
 
 from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
-from sae_lens.saes.sae import SAE
+from sae_lens.saes.sae import SAE, SAEConfig
+from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
+from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
 
 
 def get_library_version() -> str:
@@ -100,15 +102,16 @@ def get_eval_everything_config(
 
 @torch.no_grad()
 def run_evals(
-    sae: SAE,
+    sae: SAE[Any],
     activation_store: ActivationsStore,
     model: HookedRootModule,
+    activation_scaler: ActivationScaler,
     eval_config: EvalConfig = EvalConfig(),
     model_kwargs: Mapping[str, Any] = {},
     ignore_tokens: set[int | None] = set(),
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
-    hook_name = sae.cfg.hook_name
+    hook_name = sae.cfg.metadata.hook_name
     actual_batch_size = (
         eval_config.batch_size_prompts or activation_store.store_batch_size_prompts
     )
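The signature changes above reflect two 6.x API shifts: SAE is now generic over its config type, and hook metadata has moved off the config itself onto a nested metadata object. A minimal migration sketch (the release and sae_id values are illustrative):

from sae_lens import SAE

sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.8.hook_resid_pre",
)
hook_name = sae.cfg.metadata.hook_name  # was sae.cfg.hook_name in 5.x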
@@ -140,6 +143,7 @@ def run_evals(
         sae,
         model,
         activation_store,
+        activation_scaler,
         compute_kl=eval_config.compute_kl,
         compute_ce_loss=eval_config.compute_ce_loss,
         n_batches=eval_config.n_eval_reconstruction_batches,
@@ -189,6 +193,7 @@ def run_evals(
         sae,
         model,
         activation_store,
+        activation_scaler,
         compute_l2_norms=eval_config.compute_l2_norms,
         compute_sparsity_metrics=eval_config.compute_sparsity_metrics,
         compute_variance_metrics=eval_config.compute_variance_metrics,
@@ -274,7 +279,7 @@ def run_evals(
     return all_metrics, feature_metrics
 
 
-def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
+def get_featurewise_weight_based_metrics(sae: SAE[Any]) -> dict[str, Any]:
     unit_norm_encoders = (sae.W_enc / sae.W_enc.norm(dim=0, keepdim=True)).cpu()
     unit_norm_decoder = (sae.W_dec.T / sae.W_dec.T.norm(dim=0, keepdim=True)).cpu()
 
@@ -298,9 +303,10 @@ def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
 
 
 def get_downstream_reconstruction_metrics(
-    sae: SAE,
+    sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
+    activation_scaler: ActivationScaler,
     compute_kl: bool,
     compute_ce_loss: bool,
     n_batches: int,
@@ -326,8 +332,8 @@ def get_downstream_reconstruction_metrics(
         for metric_name, metric_value in get_recons_loss(
             sae,
             model,
+            activation_scaler,
             batch_tokens,
-            activation_store,
             compute_kl=compute_kl,
             compute_ce_loss=compute_ce_loss,
             ignore_tokens=ignore_tokens,
@@ -366,9 +372,10 @@ def get_downstream_reconstruction_metrics(
 
 
 def get_sparsity_and_variance_metrics(
-    sae: SAE,
+    sae: SAE[Any],
     model: HookedRootModule,
     activation_store: ActivationsStore,
+    activation_scaler: ActivationScaler,
     n_batches: int,
     compute_l2_norms: bool,
     compute_sparsity_metrics: bool,
@@ -379,8 +386,8 @@ def get_sparsity_and_variance_metrics(
     ignore_tokens: set[int | None] = set(),
     verbose: bool = False,
 ) -> tuple[dict[str, Any], dict[str, Any]]:
-    hook_name = sae.cfg.hook_name
-    hook_head_index = sae.cfg.hook_head_index
+    hook_name = sae.cfg.metadata.hook_name
+    hook_head_index = sae.cfg.metadata.hook_head_index
 
     metric_dict = {}
     feature_metric_dict = {}
@@ -436,7 +443,7 @@ def get_sparsity_and_variance_metrics(
             batch_tokens,
             prepend_bos=False,
             names_filter=[hook_name],
-            stop_at_layer=sae.cfg.hook_layer + 1,
+            stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(hook_name),
             **model_kwargs,
         )
 
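Here extract_stop_at_layer_from_tlens_hook_name replaces the removed cfg.hook_layer field by deriving the stop layer from the hook name itself. A rough sketch of what such a helper presumably does (a hypothetical reimplementation, not the library's code):

import re

def stop_at_layer_from_hook_name(hook_name: str) -> int | None:
    """Parse 'blocks.<n>.' out of a TransformerLens hook name, e.g.
    'blocks.5.hook_resid_post' -> 6, so the forward pass can halt one
    layer after the hooked block. Returns None if no block index is found."""
    match = re.search(r"blocks\.(\d+)\.", hook_name)
    return int(match.group(1)) + 1 if match else None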
@@ -451,16 +458,14 @@ def get_sparsity_and_variance_metrics(
         original_act = cache[hook_name]
 
         # normalise if necessary (necessary in training only, otherwise we should fold the scaling in)
-        if activation_store.normalize_activations == "expected_average_only_in":
-            original_act = activation_store.apply_norm_scaling_factor(original_act)
+        original_act = activation_scaler.scale(original_act)
 
         # send the (maybe normalised) activations into the SAE
         sae_feature_activations = sae.encode(original_act.to(sae.device))
         sae_out = sae.decode(sae_feature_activations).to(original_act.device)
         del cache
 
-        if activation_store.normalize_activations == "expected_average_only_in":
-            sae_out = activation_store.unscale(sae_out)
+        sae_out = activation_scaler.unscale(sae_out)
 
         flattened_sae_input = einops.rearrange(original_act, "b ctx d -> (b ctx) d")
         flattened_sae_feature_acts = einops.rearrange(
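The ActivationScaler threaded through these functions centralizes the old normalize_activations == "expected_average_only_in" branches. A minimal sketch of the assumed interface (a single multiplicative factor, where None means no-op, which is why callers no longer need the conditional):

from dataclasses import dataclass

import torch

@dataclass
class ActivationScalerSketch:
    scaling_factor: float | None = None

    def scale(self, acts: torch.Tensor) -> torch.Tensor:
        # No-op when no factor has been estimated (the eval-time default)
        return acts if self.scaling_factor is None else acts * self.scaling_factor

    def unscale(self, acts: torch.Tensor) -> torch.Tensor:
        return acts if self.scaling_factor is None else acts / self.scaling_factor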
@@ -580,17 +585,21 @@ def get_sparsity_and_variance_metrics(
 
 @torch.no_grad()
 def get_recons_loss(
-    sae: SAE,
+    sae: SAE[SAEConfig],
     model: HookedRootModule,
+    activation_scaler: ActivationScaler,
     batch_tokens: torch.Tensor,
-    activation_store: ActivationsStore,
     compute_kl: bool,
     compute_ce_loss: bool,
     ignore_tokens: set[int | None] = set(),
     model_kwargs: Mapping[str, Any] = {},
+    hook_name: str | None = None,
 ) -> dict[str, Any]:
-    hook_name = sae.cfg.hook_name
-    head_index = sae.cfg.hook_head_index
+    hook_name = hook_name or sae.cfg.metadata.hook_name
+    head_index = sae.cfg.metadata.hook_head_index
+
+    if hook_name is None:
+        raise ValueError("hook_name must be provided")
 
     original_logits, original_ce_loss = model(
         batch_tokens, return_type="both", loss_per_token=True, **model_kwargs
@@ -614,15 +623,13 @@ def get_recons_loss(
         activations = activations.to(sae.device)
 
         # Handle rescaling if SAE expects it
-        if activation_store.normalize_activations == "expected_average_only_in":
-            activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)
 
         # SAE class agnostic forward pass.
         new_activations = sae.decode(sae.encode(activations)).to(activations.dtype)
 
         # Unscale if activations were scaled prior to going into the SAE
-        if activation_store.normalize_activations == "expected_average_only_in":
-            new_activations = activation_store.unscale(new_activations)
+        new_activations = activation_scaler.unscale(new_activations)
 
         new_activations = torch.where(mask[..., None], new_activations, activations)
 
@@ -633,8 +640,7 @@ def get_recons_loss(
         activations = activations.to(sae.device)
 
         # Handle rescaling if SAE expects it
-        if activation_store.normalize_activations == "expected_average_only_in":
-            activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)
 
         # SAE class agnostic forward pass.
         new_activations = sae.decode(sae.encode(activations.flatten(-2, -1))).to(
@@ -646,8 +652,7 @@ def get_recons_loss(
         )  # reshape to match original shape
 
         # Unscale if activations were scaled prior to going into the SAE
-        if activation_store.normalize_activations == "expected_average_only_in":
-            new_activations = activation_store.unscale(new_activations)
+        new_activations = activation_scaler.unscale(new_activations)
 
         return new_activations.to(original_device)
 
@@ -656,8 +661,7 @@ def get_recons_loss(
         activations = activations.to(sae.device)
 
         # Handle rescaling if SAE expects it
-        if activation_store.normalize_activations == "expected_average_only_in":
-            activations = activation_store.apply_norm_scaling_factor(activations)
+        activations = activation_scaler.scale(activations)
 
         new_activations = sae.decode(sae.encode(activations[:, :, head_index])).to(
             activations.dtype
@@ -665,8 +669,7 @@ def get_recons_loss(
         activations[:, :, head_index] = new_activations
 
         # Unscale if activations were scaled prior to going into the SAE
-        if activation_store.normalize_activations == "expected_average_only_in":
-            activations = activation_store.unscale(activations)
+        activations = activation_scaler.unscale(activations)
 
         return activations.to(original_device)
 
@@ -806,7 +809,6 @@ def multiple_evals(
 
     current_model = None
     current_model_str = None
-    print(filtered_saes)
    for sae_release_name, sae_id, _, _ in tqdm(filtered_saes):
         sae = SAE.from_pretrained(
             release=sae_release_name,  # see other options in sae_lens/pretrained_saes.yaml
@@ -846,6 +848,7 @@ def multiple_evals(
         scalar_metrics, feature_metrics = run_evals(
             sae=sae,
             activation_store=activation_store,
+            activation_scaler=ActivationScaler(),
             model=current_model,
             eval_config=eval_config,
             ignore_tokens={
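Putting the evals.py changes together, a 6.x call to run_evals now looks roughly like the following (assumes model, sae, and activation_store are already built; a default-constructed ActivationScaler, as used in multiple_evals above, presumably leaves activations unscaled):

from sae_lens.evals import EvalConfig, run_evals
from sae_lens.training.activation_scaler import ActivationScaler

metrics, feature_metrics = run_evals(
    sae=sae,
    activation_store=activation_store,
    model=model,
    activation_scaler=ActivationScaler(),  # new required argument
    eval_config=EvalConfig(),
)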
@@ -2,21 +2,31 @@ import json
 import signal
 import sys
 from collections.abc import Sequence
+from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, cast
+from typing import Any, Generic, cast
 
 import torch
 import wandb
 from simple_parsing import ArgumentParser
 from transformer_lens.hook_points import HookedRootModule
+from typing_extensions import deprecated
 
 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, RUNNER_CFG_FILENAME
+from sae_lens.evals import EvalConfig, run_evals
 from sae_lens.load_model import load_model
-from sae_lens.saes.sae import TrainingSAE, TrainingSAEConfig
+from sae_lens.saes.sae import (
+    T_TRAINING_SAE,
+    T_TRAINING_SAE_CONFIG,
+    TrainingSAE,
+    TrainingSAEConfig,
+)
+from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.activations_store import ActivationsStore
-from sae_lens.training.geometric_median import compute_geometric_median
 from sae_lens.training.sae_trainer import SAETrainer
+from sae_lens.training.types import DataProvider
 
 
 class InterruptedException(Exception):
@@ -27,22 +37,73 @@ def interrupt_callback(sig_num: Any, stack_frame: Any):  # noqa: ARG001
     raise InterruptedException()
 
 
-class SAETrainingRunner:
+@dataclass
+class LLMSaeEvaluator(Generic[T_TRAINING_SAE]):
+    model: HookedRootModule
+    activations_store: ActivationsStore
+    eval_batch_size_prompts: int | None
+    n_eval_batches: int
+    model_kwargs: dict[str, Any]
+
+    def __call__(
+        self,
+        sae: T_TRAINING_SAE,
+        data_provider: DataProvider,
+        activation_scaler: ActivationScaler,
+    ) -> dict[str, Any]:
+        ignore_tokens = set()
+        if self.activations_store.exclude_special_tokens is not None:
+            ignore_tokens = set(self.activations_store.exclude_special_tokens.tolist())
+
+        eval_config = EvalConfig(
+            batch_size_prompts=self.eval_batch_size_prompts,
+            n_eval_reconstruction_batches=self.n_eval_batches,
+            n_eval_sparsity_variance_batches=self.n_eval_batches,
+            compute_ce_loss=True,
+            compute_l2_norms=True,
+            compute_sparsity_metrics=True,
+            compute_variance_metrics=True,
+        )
+
+        eval_metrics, _ = run_evals(
+            sae=sae,
+            activation_store=self.activations_store,
+            model=self.model,
+            activation_scaler=activation_scaler,
+            eval_config=eval_config,
+            ignore_tokens=ignore_tokens,
+            model_kwargs=self.model_kwargs,
+        )  # not calculating featurewise metrics here.
+
+        # Remove eval metrics that are already logged during training
+        eval_metrics.pop("metrics/explained_variance", None)
+        eval_metrics.pop("metrics/explained_variance_std", None)
+        eval_metrics.pop("metrics/l0", None)
+        eval_metrics.pop("metrics/l1", None)
+        eval_metrics.pop("metrics/mse", None)
+
+        # Remove metrics that are not useful for wandb logging
+        eval_metrics.pop("metrics/total_tokens_evaluated", None)
+
+        return eval_metrics
+
+
+class LanguageModelSAETrainingRunner:
     """
     Class to run the training of a Sparse Autoencoder (SAE) on a TransformerLens model.
     """
 
-    cfg: LanguageModelSAERunnerConfig
+    cfg: LanguageModelSAERunnerConfig[Any]
     model: HookedRootModule
-    sae: TrainingSAE
+    sae: TrainingSAE[Any]
     activations_store: ActivationsStore
 
     def __init__(
         self,
-        cfg: LanguageModelSAERunnerConfig,
+        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
         override_dataset: HfDataset | None = None,
         override_model: HookedRootModule | None = None,
-        override_sae: TrainingSAE | None = None,
+        override_sae: TrainingSAE[Any] | None = None,
     ):
         if override_dataset is not None:
             logger.warning(
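LLMSaeEvaluator is a plain callable, so the trainer is no longer hard-wired to LLM evals: anything with the same call shape should plug in. A hypothetical minimal evaluator (names are illustrative, and it assumes the data provider yields activation batches when iterated):

from typing import Any

from sae_lens.training.activation_scaler import ActivationScaler


def l0_only_evaluator(
    sae: Any,
    data_provider: Any,
    activation_scaler: ActivationScaler,
) -> dict[str, Any]:
    batch = activation_scaler.scale(next(iter(data_provider)))
    feature_acts = sae.encode(batch.to(sae.device))
    return {"metrics/l0": (feature_acts > 0).float().sum(-1).mean().item()}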
@@ -82,7 +143,6 @@ class SAETrainingRunner:
                     self.cfg.get_training_sae_cfg_dict(),
                 ).to_dict()
            )
-            self._init_sae_group_b_decs()
        else:
            self.sae = override_sae
 
@@ -100,12 +160,20 @@ class SAETrainingRunner:
             id=self.cfg.logger.wandb_id,
         )
 
-        trainer = SAETrainer(
+        evaluator = LLMSaeEvaluator(
             model=self.model,
+            activations_store=self.activations_store,
+            eval_batch_size_prompts=self.cfg.eval_batch_size_prompts,
+            n_eval_batches=self.cfg.n_eval_batches,
+            model_kwargs=self.cfg.model_kwargs,
+        )
+
+        trainer = SAETrainer(
             sae=self.sae,
-            activation_store=self.activations_store,
+            data_provider=self.activations_store,
+            evaluator=evaluator,
             save_checkpoint_fn=self.save_checkpoint,
-            cfg=self.cfg,
+            cfg=self.cfg.to_sae_trainer_config(),
         )
 
         self._compile_if_needed()
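Condensed, the trainer construction changes like so (a sketch; the variable names are illustrative): the activation_store keyword becomes the more general data_provider, the evaluator callable is new, and the runner config is converted via to_sae_trainer_config() instead of being passed whole.

trainer = SAETrainer(
    sae=sae,
    data_provider=activations_store,         # 5.x: activation_store=...
    evaluator=evaluator,                     # new in 6.x
    save_checkpoint_fn=save_checkpoint,
    cfg=runner_cfg.to_sae_trainer_config(),  # 5.x: cfg=runner_cfg
)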
@@ -141,7 +209,9 @@ class SAETrainingRunner:
             backend=backend,
         )  # type: ignore
 
-    def run_trainer_with_interruption_handling(self, trainer: SAETrainer):
+    def run_trainer_with_interruption_handling(
+        self, trainer: SAETrainer[TrainingSAE[TrainingSAEConfig], TrainingSAEConfig]
+    ):
         try:
             # signal handlers (if preempted)
             signal.signal(signal.SIGINT, interrupt_callback)
@@ -152,73 +222,31 @@ class SAETrainingRunner:
 
         except (KeyboardInterrupt, InterruptedException):
             logger.warning("interrupted, saving progress")
-            checkpoint_name = str(trainer.n_training_tokens)
-            self.save_checkpoint(trainer, checkpoint_name=checkpoint_name)
+            checkpoint_path = Path(self.cfg.checkpoint_path) / str(
+                trainer.n_training_samples
+            )
+            self.save_checkpoint(checkpoint_path)
             logger.info("done saving")
             raise
 
         return sae
 
-    # TODO: move this into the SAE trainer or Training SAE class
-    def _init_sae_group_b_decs(
-        self,
-    ) -> None:
-        """
-        extract all activations at a certain layer and use for sae b_dec initialization
-        """
-
-        if self.cfg.b_dec_init_method == "geometric_median":
-            self.activations_store.set_norm_scaling_factor_if_needed()
-            layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :]
-            # get geometric median of the activations if we're using those.
-            median = compute_geometric_median(
-                layer_acts,
-                maxiter=100,
-            ).median
-            self.sae.initialize_b_dec_with_precalculated(median)
-        elif self.cfg.b_dec_init_method == "mean":
-            self.activations_store.set_norm_scaling_factor_if_needed()
-            layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
-            self.sae.initialize_b_dec_with_mean(layer_acts)  # type: ignore
-
-    @staticmethod
     def save_checkpoint(
-        trainer: SAETrainer,
-        checkpoint_name: str,
-        wandb_aliases: list[str] | None = None,
+        self,
+        checkpoint_path: Path,
     ) -> None:
-        base_path = Path(trainer.cfg.checkpoint_path) / checkpoint_name
-        base_path.mkdir(exist_ok=True, parents=True)
-
-        trainer.activations_store.save(
-            str(base_path / "activations_store_state.safetensors")
-        )
-
-        if trainer.sae.cfg.normalize_sae_decoder:
-            trainer.sae.set_decoder_norm_to_unit_norm()
-
-        weights_path, cfg_path, sparsity_path = trainer.sae.save_model(
-            str(base_path),
-            trainer.log_feature_sparsity,
+        self.activations_store.save(
+            str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
         )
-
-        # let's over write the cfg file with the trainer cfg, which is a super set of the original cfg.
-        # and should not cause issues but give us more info about SAEs we trained in SAE Lens.
-        config = trainer.cfg.to_dict()
-        with open(cfg_path, "w") as f:
-            json.dump(config, f)
-
-        if trainer.cfg.logger.log_to_wandb:
-            trainer.cfg.logger.log(
-                trainer,
-                weights_path,
-                cfg_path,
-                sparsity_path=sparsity_path,
-                wandb_aliases=wandb_aliases,
-            )
+
+        runner_config = self.cfg.to_dict()
+        with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
+            json.dump(runner_config, f)
 
 
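save_checkpoint is now an instance method that takes a target directory and writes only the activations-store state and the runner config (the SAE weights are presumably handled by the trainer itself now). Callers build the path, as the interruption handler above does; a sketch:

from pathlib import Path

checkpoint_dir = Path(cfg.checkpoint_path) / str(trainer.n_training_samples)
checkpoint_dir.mkdir(parents=True, exist_ok=True)  # assumption: the caller creates the dir
runner.save_checkpoint(checkpoint_dir)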
- def _parse_cfg_args(args: Sequence[str]) -> LanguageModelSAERunnerConfig:
247
+ def _parse_cfg_args(
248
+ args: Sequence[str],
249
+ ) -> LanguageModelSAERunnerConfig[TrainingSAEConfig]:
222
250
  if len(args) == 0:
223
251
  args = ["--help"]
224
252
  parser = ArgumentParser(exit_on_error=False)
@@ -229,8 +257,13 @@ def _parse_cfg_args(args: Sequence[str]) -> LanguageModelSAERunnerConfig:
229
257
  # moved into its own function to make it easier to test
230
258
  def _run_cli(args: Sequence[str]):
231
259
  cfg = _parse_cfg_args(args)
232
- SAETrainingRunner(cfg=cfg).run()
260
+ LanguageModelSAETrainingRunner(cfg=cfg).run()
233
261
 
234
262
 
235
263
  if __name__ == "__main__":
236
264
  _run_cli(args=sys.argv[1:])
265
+
266
+
267
+ @deprecated("Use LanguageModelSAETrainingRunner instead")
268
+ class SAETrainingRunner(LanguageModelSAETrainingRunner):
269
+ pass
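The old name is kept as a deprecated alias so existing imports keep working: typing_extensions.deprecated flags uses in static type checkers and emits a DeprecationWarning when the class is instantiated at runtime. The same pattern in miniature:

from typing_extensions import deprecated


class NewRunner:
    pass


@deprecated("Use NewRunner instead")
class OldRunner(NewRunner):
    pass


OldRunner()  # type checkers flag this; DeprecationWarning at runtime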
sae_lens/load_model.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, Literal, cast
+from typing import Any, Callable, Literal, cast
 
 import torch
 from transformer_lens import HookedTransformer
@@ -77,6 +77,7 @@ class HookedProxyLM(HookedRootModule):
     # copied and modified from base HookedRootModule
     def setup(self):
         self.mod_dict = {}
+        self.named_modules_dict = {}
         self.hook_dict: dict[str, HookPoint] = {}
         for name, module in self.model.named_modules():
             if name == "":
@@ -89,14 +90,21 @@ class HookedProxyLM(HookedRootModule):
 
             self.hook_dict[name] = hook_point
             self.mod_dict[name] = hook_point
+            self.named_modules_dict[name] = module
+
+    def run_with_cache(self, *args: Any, **kwargs: Any):  # type: ignore
+        if "names_filter" in kwargs:
+            # hacky way to make sure that the names_filter is passed to our forward method
+            kwargs["_names_filter"] = kwargs["names_filter"]
+        return super().run_with_cache(*args, **kwargs)
 
     def forward(
         self,
         tokens: torch.Tensor,
         return_type: Literal["both", "logits"] = "logits",
         loss_per_token: bool = False,
-        # TODO: implement real support for stop_at_layer
         stop_at_layer: int | None = None,
+        _names_filter: list[str] | None = None,
         **kwargs: Any,
     ) -> Output | Loss:
         # This is just what's needed for evals, not everything that HookedTransformer has
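Duplicating names_filter into the private _names_filter kwarg is what lets forward() know which modules are needed so it can arm the stop hooks added below. A hedged usage sketch (the module name is hypothetical and depends on the wrapped HF model):

output, cache = proxy_model.run_with_cache(
    tokens,
    names_filter=["transformer.h.5"],  # hook/module name in the wrapped model
    stop_at_layer=1,  # any non-None value enables early stopping
)
# output is None if the forward pass stopped early; cache still holds the
# activations for the requested names.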
@@ -107,8 +115,28 @@ class HookedProxyLM(HookedRootModule):
             raise NotImplementedError(
                 "Only return_type supported is 'both' or 'logits' to match what's in evals.py and ActivationsStore"
             )
-        output = self.model(tokens)
-        logits = _extract_logits_from_output(output)
+
+        stop_hooks = []
+        if stop_at_layer is not None and _names_filter is not None:
+            if return_type != "logits":
+                raise NotImplementedError(
+                    "stop_at_layer is not supported for return_type='both'"
+                )
+            stop_manager = StopManager(_names_filter)
+
+            for hook_name in _names_filter:
+                module = self.named_modules_dict[hook_name]
+                stop_fn = stop_manager.get_stop_hook_fn(hook_name)
+                stop_hooks.append(module.register_forward_hook(stop_fn))
+        try:
+            output = self.model(tokens)
+            logits = _extract_logits_from_output(output)
+        except StopForward:
+            # If we stop early, we don't care about the return output
+            return None  # type: ignore
+        finally:
+            for stop_hook in stop_hooks:
+                stop_hook.remove()
 
         if return_type == "logits":
             return logits
@@ -159,7 +187,7 @@ class HookedProxyLM(HookedRootModule):
 
         # We don't want to prepend bos but the tokenizer does it automatically, so we remove it manually
         if hasattr(self.tokenizer, "add_bos_token") and self.tokenizer.add_bos_token:  # type: ignore
-            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)
+            tokens = get_tokens_with_bos_removed(self.tokenizer, tokens)  # type: ignore
         return tokens  # type: ignore
 
@@ -183,3 +211,23 @@ def get_hook_fn(hook_point: HookPoint):
         return output
 
     return hook_fn
+
+
+class StopForward(Exception):
+    pass
+
+
+class StopManager:
+    def __init__(self, hook_names: list[str]):
+        self.hook_names = hook_names
+        self.total_hook_names = len(set(hook_names))
+        self.called_hook_names = set()
+
+    def get_stop_hook_fn(self, hook_name: str) -> Callable[[Any, Any, Any], Any]:
+        def stop_hook_fn(module: Any, input: Any, output: Any) -> Any:  # noqa: ARG001
+            self.called_hook_names.add(hook_name)
+            if len(self.called_hook_names) == self.total_hook_names:
+                raise StopForward()
+            return output
+
+        return stop_hook_fn
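StopForward/StopManager implement a common PyTorch trick: raise a sentinel exception from a forward hook once every module you care about has fired, then swallow it in the caller so the rest of the network never runs. A self-contained illustration (standalone, not SAELens code):

import torch
import torch.nn as nn


class StopForward(Exception):
    pass


model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4))
captured = {}


def capture_and_stop(module, inputs, output):  # noqa: ARG001
    captured["layer1"] = output
    raise StopForward()  # layers after this one are never executed


handle = model[1].register_forward_hook(capture_and_stop)
try:
    model(torch.randn(2, 4))
except StopForward:
    pass  # expected: we only wanted the intermediate activation
finally:
    handle.remove()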