sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,5 @@
  from __future__ import annotations

- import contextlib
  import json
  import os
  import warnings
@@ -16,7 +15,6 @@ from huggingface_hub.utils import HfHubHTTPError
  from jaxtyping import Float, Int
  from requests import HTTPError
  from safetensors.torch import save_file
- from torch.utils.data import DataLoader
  from tqdm import tqdm
  from transformer_lens.hook_points import HookedRootModule
  from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -30,6 +28,8 @@ from sae_lens.config import (
  from sae_lens.constants import DTYPE_MAP
  from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
  from sae_lens.tokenization_and_batching import concat_and_batch_sequences
+ from sae_lens.training.mixing_buffer import mixing_buffer
+ from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name


  # TODO: Make an activation store config class to be consistent with the rest of the code.
@@ -45,10 +45,8 @@ class ActivationsStore:
  cached_activation_dataset: Dataset | None = None
  tokens_column: Literal["tokens", "input_ids", "text", "problem"]
  hook_name: str
- hook_layer: int
  hook_head_index: int | None
  _dataloader: Iterator[Any] | None = None
- _storage_buffer: torch.Tensor | None = None
  exclude_special_tokens: torch.Tensor | None = None
  device: torch.device

@@ -65,7 +63,6 @@ class ActivationsStore:
  cached_activations_path=cfg.new_cached_activations_path,
  dtype=cfg.dtype,
  hook_name=cfg.hook_name,
- hook_layer=cfg.hook_layer,
  context_size=cfg.context_size,
  d_in=cfg.d_in,
  n_batches_in_buffer=cfg.n_batches_in_buffer,
@@ -126,7 +123,6 @@ class ActivationsStore:
  dataset=override_dataset or cfg.dataset_path,
  streaming=cfg.streaming,
  hook_name=cfg.hook_name,
- hook_layer=cfg.hook_layer,
  hook_head_index=cfg.hook_head_index,
  context_size=cfg.context_size,
  d_in=cfg.d_in
@@ -165,10 +161,6 @@ class ActivationsStore:
  ) -> ActivationsStore:
  if sae.cfg.metadata.hook_name is None:
  raise ValueError("hook_name is required")
- if sae.cfg.metadata.hook_layer is None:
- raise ValueError("hook_layer is required")
- if sae.cfg.metadata.hook_head_index is None:
- raise ValueError("hook_head_index is required")
  if sae.cfg.metadata.context_size is None:
  raise ValueError("context_size is required")
  if sae.cfg.metadata.prepend_bos is None:
@@ -178,7 +170,6 @@ class ActivationsStore:
  dataset=dataset,
  d_in=sae.cfg.d_in,
  hook_name=sae.cfg.metadata.hook_name,
- hook_layer=sae.cfg.metadata.hook_layer,
  hook_head_index=sae.cfg.metadata.hook_head_index,
  context_size=sae.cfg.metadata.context_size
  if context_size is None
@@ -202,7 +193,6 @@ class ActivationsStore:
  dataset: HfDataset | str,
  streaming: bool,
  hook_name: str,
- hook_layer: int,
  hook_head_index: int | None,
  context_size: int,
  d_in: int,
@@ -246,7 +236,6 @@ class ActivationsStore:
  )

  self.hook_name = hook_name
- self.hook_layer = hook_layer
  self.hook_head_index = hook_head_index
  self.context_size = context_size
  self.d_in = d_in
@@ -262,12 +251,11 @@ class ActivationsStore:
  self.cached_activations_path = cached_activations_path
  self.autocast_lm = autocast_lm
  self.seqpos_slice = seqpos_slice
+ self.training_context_size = len(range(context_size)[slice(*seqpos_slice)])
  self.exclude_special_tokens = exclude_special_tokens

  self.n_dataset_processed = 0

- self.estimated_norm_scaling_factor = None
-
  # Check if dataset is tokenized
  dataset_sample = next(iter(self.dataset))

@@ -432,30 +420,6 @@ class ActivationsStore:

  return activations_dataset

- def set_norm_scaling_factor_if_needed(self):
- if (
- self.normalize_activations == "expected_average_only_in"
- and self.estimated_norm_scaling_factor is None
- ):
- self.estimated_norm_scaling_factor = self.estimate_norm_scaling_factor()
-
- def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
- if self.estimated_norm_scaling_factor is None:
- raise ValueError(
- "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
- )
- return activations * self.estimated_norm_scaling_factor
-
- def unscale(self, activations: torch.Tensor) -> torch.Tensor:
- if self.estimated_norm_scaling_factor is None:
- raise ValueError(
- "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
- )
- return activations / self.estimated_norm_scaling_factor
-
- def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
- return (self.d_in**0.5) / activations.norm(dim=-1).mean()
-
  @torch.no_grad()
  def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)):
  norms_per_batch = []
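This hunk drops the store's norm-scaling plumbing (`set_norm_scaling_factor_if_needed`, `apply_norm_scaling_factor`, `unscale`, `get_norm_scaling_factor`) while keeping `estimate_norm_scaling_factor`. A hedged sketch of how a caller could reproduce the old behaviour on its own, assuming `store` is an `ActivationsStore` and `activations` a batch it produced (both placeholder names, not part of the diff):

```python
# Hypothetical caller-side replacement for the removed helpers:
# estimate_norm_scaling_factor() still returns the scalar that the old
# apply_norm_scaling_factor()/unscale() multiplied and divided by.
scaling_factor = store.estimate_norm_scaling_factor(n_batches_for_norm_estimate=100)
scaled = activations * scaling_factor    # formerly apply_norm_scaling_factor(activations)
unscaled = scaled / scaling_factor       # formerly unscale(scaled)
```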
@@ -490,21 +454,6 @@ class ActivationsStore:
  """
  self.iterable_dataset = iter(self.dataset)

- @property
- def storage_buffer(self) -> torch.Tensor:
- if self._storage_buffer is None:
- self._storage_buffer = _filter_buffer_acts(
- self.get_buffer(self.half_buffer_size), self.exclude_special_tokens
- )
-
- return self._storage_buffer
-
- @property
- def dataloader(self) -> Iterator[Any]:
- if self._dataloader is None:
- self._dataloader = self.get_data_loader()
- return self._dataloader
-
  def get_batch_tokens(
  self, batch_size: int | None = None, raise_at_epoch_end: bool = False
  ):
@@ -537,22 +486,17 @@ class ActivationsStore:

  d_in may result from a concatenated head dimension.
  """
-
- # Setup autocast if using
- if self.autocast_lm:
- autocast_if_enabled = torch.autocast(
- device_type="cuda",
- dtype=torch.bfloat16,
- enabled=self.autocast_lm,
- )
- else:
- autocast_if_enabled = contextlib.nullcontext()
-
- with autocast_if_enabled:
+ with torch.autocast(
+ device_type="cuda",
+ dtype=torch.bfloat16,
+ enabled=self.autocast_lm,
+ ):
  layerwise_activations_cache = self.model.run_with_cache(
  batch_tokens,
  names_filter=[self.hook_name],
- stop_at_layer=self.hook_layer + 1,
+ stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(
+ self.hook_name
+ ),
  prepend_bos=False,
  **self.model_kwargs,
  )[1]
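Two changes land here: the `contextlib.nullcontext()` branch is gone since `torch.autocast(..., enabled=self.autocast_lm)` already handles the disabled case, and the stop layer is now parsed from the hook name instead of read from a stored `hook_layer`. The real helper lives in `sae_lens.util` and is not shown in this diff; a minimal sketch of the idea, assuming TransformerLens hook names like `blocks.<n>.hook_resid_pre` and preserving the old `hook_layer + 1` behaviour, might look like:

```python
import re


def extract_stop_at_layer_from_tlens_hook_name(hook_name: str) -> int | None:
    # Hypothetical sketch only: pull the block index out of names such as
    # "blocks.5.hook_resid_pre" and stop one layer past it; fall back to None
    # (run the whole model) for hooks without a block index, e.g. "hook_embed".
    match = re.search(r"blocks\.(\d+)\.", hook_name)
    return int(match.group(1)) + 1 if match else None
```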
@@ -563,25 +507,25 @@ class ActivationsStore:

  n_batches, n_context = layerwise_activations.shape[:2]

- stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))
+ stacked_activations = torch.zeros((n_batches, n_context, self.d_in))

  if self.hook_head_index is not None:
- stacked_activations[:, :, 0] = layerwise_activations[
+ stacked_activations[:, :] = layerwise_activations[
  :, :, self.hook_head_index
  ]
  elif layerwise_activations.ndim > 3: # if we have a head dimension
  try:
- stacked_activations[:, :, 0] = layerwise_activations.view(
+ stacked_activations[:, :] = layerwise_activations.view(
  n_batches, n_context, -1
  )
  except RuntimeError as e:
  logger.error(f"Error during view operation: {e}")
  logger.info("Attempting to use reshape instead...")
- stacked_activations[:, :, 0] = layerwise_activations.reshape(
+ stacked_activations[:, :] = layerwise_activations.reshape(
  n_batches, n_context, -1
  )
  else:
- stacked_activations[:, :, 0] = layerwise_activations
+ stacked_activations[:, :] = layerwise_activations

  return stacked_activations

@@ -589,7 +533,6 @@ class ActivationsStore:
  self,
  total_size: int,
  context_size: int,
- num_layers: int,
  d_in: int,
  raise_on_epoch_end: bool,
  ) -> tuple[
@@ -606,10 +549,9 @@ class ActivationsStore:
  """
  assert self.cached_activation_dataset is not None
  # In future, could be a list of multiple hook names
- hook_names = [self.hook_name]
- if not set(hook_names).issubset(self.cached_activation_dataset.column_names):
+ if self.hook_name not in self.cached_activation_dataset.column_names:
  raise ValueError(
- f"Missing columns in dataset. Expected {hook_names}, "
+ f"Missing columns in dataset. Expected {self.hook_name}, "
  f"got {self.cached_activation_dataset.column_names}."
  )

@@ -622,28 +564,17 @@ class ActivationsStore:
  ds_slice = self.cached_activation_dataset[
  self.current_row_idx : self.current_row_idx + total_size
  ]
- for hook_name in hook_names:
- # Load activations for each hook.
- # Usually faster to first slice dataset then pick column
- _hook_buffer = ds_slice[hook_name]
- if _hook_buffer.shape != (total_size, context_size, d_in):
- raise ValueError(
- f"_hook_buffer has shape {_hook_buffer.shape}, "
- f"but expected ({total_size}, {context_size}, {d_in})."
- )
- new_buffer.append(_hook_buffer)
-
- # Stack across num_layers dimension
- # list of num_layers; shape: (total_size, context_size, d_in) -> (total_size, context_size, num_layers, d_in)
- new_buffer = torch.stack(new_buffer, dim=2)
- if new_buffer.shape != (total_size, context_size, num_layers, d_in):
+ # Load activations for each hook.
+ # Usually faster to first slice dataset then pick column
+ new_buffer = ds_slice[self.hook_name]
+ if new_buffer.shape != (total_size, context_size, d_in):
  raise ValueError(
  f"new_buffer has shape {new_buffer.shape}, "
- f"but expected ({total_size}, {context_size}, {num_layers}, {d_in})."
+ f"but expected ({total_size}, {context_size}, {d_in})."
  )

  self.current_row_idx += total_size
- acts_buffer = new_buffer.reshape(total_size * context_size, num_layers, d_in)
+ acts_buffer = new_buffer.reshape(total_size * context_size, d_in)

  if "token_ids" not in self.cached_activation_dataset.column_names:
  return acts_buffer, None
@@ -658,7 +589,7 @@ class ActivationsStore:
  return acts_buffer, token_ids_buffer

  @torch.no_grad()
- def get_buffer(
+ def get_raw_buffer(
  self,
  n_batches_in_buffer: int,
  raise_on_epoch_end: bool = False,
@@ -672,26 +603,24 @@ class ActivationsStore:
  If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
  """
  context_size = self.context_size
- training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
  batch_size = self.store_batch_size_prompts
  d_in = self.d_in
  total_size = batch_size * n_batches_in_buffer
- num_layers = 1

  if self.cached_activation_dataset is not None:
  return self._load_buffer_from_cached(
- total_size, context_size, num_layers, d_in, raise_on_epoch_end
+ total_size, context_size, d_in, raise_on_epoch_end
  )

  refill_iterator = range(0, total_size, batch_size)
  # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
  new_buffer_activations = torch.zeros(
- (total_size, training_context_size, num_layers, d_in),
+ (total_size, self.training_context_size, d_in),
  dtype=self.dtype, # type: ignore
  device=self.device,
  )
  new_buffer_token_ids = torch.zeros(
- (total_size, training_context_size),
+ (total_size, self.training_context_size),
  dtype=torch.long,
  device=self.device,
  )
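`training_context_size` (now precomputed in `__init__` as `len(range(context_size)[slice(*seqpos_slice)])`) replaces the inline computation that used to live here, and the buffer loses its singleton layer dimension. A quick illustration of both, with purely illustrative numbers rather than package defaults:

```python
context_size = 128
seqpos_slice = (32, None)        # illustrative: keep only positions 32..127
training_context_size = len(range(context_size)[slice(*seqpos_slice)])
assert training_context_size == 96

store_batch_size_prompts = 16    # illustrative store config
n_batches_in_buffer = 8
d_in = 768
total_size = store_batch_size_prompts * n_batches_in_buffer       # 128 prompts

buffer_shape = (total_size, training_context_size, d_in)           # was (..., 1, d_in) in rc2
flat_shape = (total_size * training_context_size, d_in)            # (12288, 768) after reshape
```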
@@ -716,106 +645,80 @@ class ActivationsStore:
  refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
  ] = refill_batch_tokens

- new_buffer_activations = new_buffer_activations.reshape(-1, num_layers, d_in)
+ new_buffer_activations = new_buffer_activations.reshape(-1, d_in)
  new_buffer_token_ids = new_buffer_token_ids.reshape(-1)
  if shuffle:
  new_buffer_activations, new_buffer_token_ids = permute_together(
  [new_buffer_activations, new_buffer_token_ids]
  )

- # every buffer should be normalized:
- if self.normalize_activations == "expected_average_only_in":
- new_buffer_activations = self.apply_norm_scaling_factor(
- new_buffer_activations
- )
-
  return (
  new_buffer_activations,
  new_buffer_token_ids,
  )

- def get_data_loader(
+ def get_filtered_buffer(
  self,
- ) -> Iterator[Any]:
- """
- Return a torch.utils.dataloader which you can get batches from.
-
- Should automatically refill the buffer when it gets to n % full.
- (better mixing if you refill and shuffle regularly).
+ n_batches_in_buffer: int,
+ raise_on_epoch_end: bool = False,
+ shuffle: bool = True,
+ ) -> torch.Tensor:
+ return _filter_buffer_acts(
+ self.get_raw_buffer(
+ n_batches_in_buffer=n_batches_in_buffer,
+ raise_on_epoch_end=raise_on_epoch_end,
+ shuffle=shuffle,
+ ),
+ self.exclude_special_tokens,
+ )

+ def _iterate_filtered_activations(self) -> Generator[torch.Tensor, None, None]:
  """
-
- batch_size = self.train_batch_size_tokens
-
- try:
- new_samples = _filter_buffer_acts(
- self.get_buffer(self.half_buffer_size, raise_on_epoch_end=True),
- self.exclude_special_tokens,
- )
- except StopIteration:
- warnings.warn(
- "All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples."
- )
- self._storage_buffer = (
- None # dump the current buffer so samples do not leak between epochs
- )
+ Iterate over the filtered tokens in the buffer.
+ """
+ while True:
  try:
- new_samples = _filter_buffer_acts(
- self.get_buffer(self.half_buffer_size),
- self.exclude_special_tokens,
+ yield self.get_filtered_buffer(
+ self.half_buffer_size, raise_on_epoch_end=True
  )
  except StopIteration:
- raise ValueError(
- "We were unable to fill up the buffer directly after starting a new epoch. This could indicate that there are less samples in the dataset than are required to fill up the buffer. Consider reducing batch_size or n_batches_in_buffer. "
+ warnings.warn(
+ "All samples in the training dataset have been exhausted, beginning new epoch."
  )
+ try:
+ yield self.get_filtered_buffer(self.half_buffer_size)
+ except StopIteration:
+ raise ValueError(
+ "Unable to fill buffer after starting new epoch. Dataset may be too small."
+ )

- # 1. # create new buffer by mixing stored and new buffer
- mixing_buffer = torch.cat(
- [new_samples, self.storage_buffer],
- dim=0,
- )
-
- mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]
-
- # 2. put 50 % in storage
- self._storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]
-
- # 3. put other 50 % in a dataloader
- return iter(
- DataLoader(
- # TODO: seems like a typing bug?
- cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
- batch_size=batch_size,
- shuffle=True,
- )
+ def get_data_loader(
+ self,
+ ) -> Iterator[Any]:
+ """
+ Return an auto-refilling stream of filtered and mixed activations.
+ """
+ return mixing_buffer(
+ buffer_size=self.n_batches_in_buffer * self.training_context_size,
+ batch_size=self.train_batch_size_tokens,
+ activations_loader=self._iterate_filtered_activations(),
  )

  def next_batch(self) -> torch.Tensor:
- """
- Get the next batch from the current DataLoader.
- If the DataLoader is exhausted, refill the buffer and create a new DataLoader.
- """
- try:
- # Try to get the next batch
- return next(self.dataloader)
- except StopIteration:
- # If the DataLoader is exhausted, create a new one
+ """Get next batch, updating buffer if needed."""
+ return self.__next__()
+
+ # ActivationsStore should be an iterator
+ def __next__(self) -> torch.Tensor:
+ if self._dataloader is None:
  self._dataloader = self.get_data_loader()
- return next(self.dataloader)
+ return next(self._dataloader)
+
+ def __iter__(self) -> Iterator[torch.Tensor]:
+ return self

  def state_dict(self) -> dict[str, torch.Tensor]:
- result = {
- "n_dataset_processed": torch.tensor(self.n_dataset_processed),
- }
- if self._storage_buffer is not None: # first time might be None
- result["storage_buffer_activations"] = self._storage_buffer[0]
- if self._storage_buffer[1] is not None:
- result["storage_buffer_tokens"] = self._storage_buffer[1]
- if self.estimated_norm_scaling_factor is not None:
- result["estimated_norm_scaling_factor"] = torch.tensor(
- self.estimated_norm_scaling_factor
- )
- return result
+ return {"n_dataset_processed": torch.tensor(self.n_dataset_processed)}

  def save(self, file_path: str):
  """save the state dict to a file in safetensors format"""
The hunk below adds a new module, the `mixing_buffer` generator imported above as `sae_lens.training.mixing_buffer`, which takes over the storage-buffer/DataLoader mixing logic removed from `get_data_loader`:

@@ -0,0 +1,56 @@
+ from collections.abc import Iterator
+
+ import torch
+
+
+ @torch.no_grad()
+ def mixing_buffer(
+ buffer_size: int,
+ batch_size: int,
+ activations_loader: Iterator[torch.Tensor],
+ ) -> Iterator[torch.Tensor]:
+ """
+ A generator that maintains a mix of old and new activations for better training.
+ It stores half of the activations and mixes them with new ones to create batches.
+
+ Args:
+ buffer_size: Total size of the buffer (will store buffer_size/2 activations)
+ batch_size: Size of batches to return
+ activations_loader: Iterator providing new activations
+
+ Yields:
+ Batches of activations of shape (batch_size, *activation_dims)
+ """
+
+ if buffer_size < batch_size:
+ raise ValueError("Buffer size must be greater than or equal to batch size")
+
+ storage_buffer: torch.Tensor | None = None
+
+ for new_activations in activations_loader:
+ storage_buffer = (
+ new_activations
+ if storage_buffer is None
+ else torch.cat([storage_buffer, new_activations], dim=0)
+ )
+
+ if storage_buffer.shape[0] >= buffer_size:
+ # Shuffle
+ storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
+
+ num_serving_batches = max(1, storage_buffer.shape[0] // (2 * batch_size))
+ serving_cutoff = num_serving_batches * batch_size
+ serving_buffer = storage_buffer[:serving_cutoff]
+ storage_buffer = storage_buffer[serving_cutoff:]
+
+ # Yield batches from the serving_buffer
+ for batch_idx in range(num_serving_batches):
+ yield serving_buffer[
+ batch_idx * batch_size : (batch_idx + 1) * batch_size
+ ]
+
+ # If there are any remaining activations, yield them
+ if storage_buffer is not None:
+ remaining_batches = storage_buffer.shape[0] // batch_size
+ for i in range(remaining_batches):
+ yield storage_buffer[i * batch_size : (i + 1) * batch_size]
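Since this file is new in rc4, a small standalone usage sketch may help; the numbers and the fake activation source are purely illustrative, but the call signature is exactly the one added in this diff:

```python
import torch

from sae_lens.training.mixing_buffer import mixing_buffer


def fake_activations(n_chunks: int = 8, chunk_size: int = 256, d_in: int = 64):
    # Stand-in for ActivationsStore._iterate_filtered_activations():
    # an iterator of (n_tokens, d_in) activation tensors.
    for _ in range(n_chunks):
        yield torch.randn(chunk_size, d_in)


batches = mixing_buffer(
    buffer_size=512,                   # retains ~256 activations between refills
    batch_size=32,
    activations_loader=fake_activations(),
)

for batch in batches:
    assert batch.shape == (32, 64)     # every yielded batch is (batch_size, d_in)
```

Each time the internal buffer reaches `buffer_size`, it is shuffled, roughly half is served out in `batch_size` chunks, and the other half is held back to be mixed with the next activations from the loader, which is the same 50/50 mixing the removed DataLoader-based `get_data_loader` performed in place.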