sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. sae_lens/__init__.py +60 -7
  2. sae_lens/analysis/hooked_sae_transformer.py +12 -12
  3. sae_lens/analysis/neuronpedia_integration.py +16 -14
  4. sae_lens/cache_activations_runner.py +9 -7
  5. sae_lens/config.py +170 -258
  6. sae_lens/constants.py +21 -0
  7. sae_lens/evals.py +59 -44
  8. sae_lens/llm_sae_training_runner.py +377 -0
  9. sae_lens/load_model.py +52 -4
  10. sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
  11. sae_lens/registry.py +49 -0
  12. sae_lens/saes/__init__.py +48 -0
  13. sae_lens/saes/gated_sae.py +254 -0
  14. sae_lens/saes/jumprelu_sae.py +348 -0
  15. sae_lens/saes/sae.py +1076 -0
  16. sae_lens/saes/standard_sae.py +178 -0
  17. sae_lens/saes/topk_sae.py +300 -0
  18. sae_lens/training/activation_scaler.py +53 -0
  19. sae_lens/training/activations_store.py +103 -184
  20. sae_lens/training/mixing_buffer.py +56 -0
  21. sae_lens/training/optim.py +60 -36
  22. sae_lens/training/sae_trainer.py +155 -177
  23. sae_lens/training/types.py +5 -0
  24. sae_lens/training/upload_saes_to_huggingface.py +13 -7
  25. sae_lens/util.py +47 -0
  26. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
  27. sae_lens-6.0.0.dist-info/RECORD +37 -0
  28. sae_lens/sae.py +0 -747
  29. sae_lens/sae_training_runner.py +0 -251
  30. sae_lens/training/geometric_median.py +0 -101
  31. sae_lens/training/training_sae.py +0 -710
  32. sae_lens-5.11.0.dist-info/RECORD +0 -28
  33. /sae_lens/{toolkit → loading}/__init__.py +0 -0
  34. /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
  35. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
  36. {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0

sae_lens/training/activations_store.py
@@ -1,6 +1,5 @@
  from __future__ import annotations

- import contextlib
  import json
  import os
  import warnings
@@ -16,20 +15,21 @@ from huggingface_hub.utils import HfHubHTTPError
  from jaxtyping import Float, Int
  from requests import HTTPError
  from safetensors.torch import save_file
- from torch.utils.data import DataLoader
  from tqdm import tqdm
  from transformer_lens.hook_points import HookedRootModule
  from transformers import AutoTokenizer, PreTrainedTokenizerBase

  from sae_lens import logger
  from sae_lens.config import (
- DTYPE_MAP,
  CacheActivationsRunnerConfig,
  HfDataset,
  LanguageModelSAERunnerConfig,
  )
- from sae_lens.sae import SAE
+ from sae_lens.constants import DTYPE_MAP
+ from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
  from sae_lens.tokenization_and_batching import concat_and_batch_sequences
+ from sae_lens.training.mixing_buffer import mixing_buffer
+ from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name


  # TODO: Make an activation store config class to be consistent with the rest of the code.
@@ -45,10 +45,8 @@ class ActivationsStore:
  cached_activation_dataset: Dataset | None = None
  tokens_column: Literal["tokens", "input_ids", "text", "problem"]
  hook_name: str
- hook_layer: int
  hook_head_index: int | None
  _dataloader: Iterator[Any] | None = None
- _storage_buffer: torch.Tensor | None = None
  exclude_special_tokens: torch.Tensor | None = None
  device: torch.device

@@ -65,7 +63,6 @@ class ActivationsStore:
  cached_activations_path=cfg.new_cached_activations_path,
  dtype=cfg.dtype,
  hook_name=cfg.hook_name,
- hook_layer=cfg.hook_layer,
  context_size=cfg.context_size,
  d_in=cfg.d_in,
  n_batches_in_buffer=cfg.n_batches_in_buffer,
@@ -91,7 +88,8 @@ class ActivationsStore:
  def from_config(
  cls,
  model: HookedRootModule,
- cfg: LanguageModelSAERunnerConfig | CacheActivationsRunnerConfig,
+ cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]
+ | CacheActivationsRunnerConfig,
  override_dataset: HfDataset | None = None,
  ) -> ActivationsStore:
  if isinstance(cfg, CacheActivationsRunnerConfig):
@@ -125,16 +123,17 @@ class ActivationsStore:
  dataset=override_dataset or cfg.dataset_path,
  streaming=cfg.streaming,
  hook_name=cfg.hook_name,
- hook_layer=cfg.hook_layer,
  hook_head_index=cfg.hook_head_index,
  context_size=cfg.context_size,
- d_in=cfg.d_in,
+ d_in=cfg.d_in
+ if isinstance(cfg, CacheActivationsRunnerConfig)
+ else cfg.sae.d_in,
  n_batches_in_buffer=cfg.n_batches_in_buffer,
  total_training_tokens=cfg.training_tokens,
  store_batch_size_prompts=cfg.store_batch_size_prompts,
  train_batch_size_tokens=cfg.train_batch_size_tokens,
  prepend_bos=cfg.prepend_bos,
- normalize_activations=cfg.normalize_activations,
+ normalize_activations=cfg.sae.normalize_activations,
  device=device,
  dtype=cfg.dtype,
  cached_activations_path=cached_activations_path,
@@ -149,9 +148,10 @@ class ActivationsStore:
  def from_sae(
  cls,
  model: HookedRootModule,
- sae: SAE,
+ sae: SAE[T_SAE_CONFIG],
+ dataset: HfDataset | str,
+ dataset_trust_remote_code: bool = False,
  context_size: int | None = None,
- dataset: HfDataset | str | None = None,
  streaming: bool = True,
  store_batch_size_prompts: int = 8,
  n_batches_in_buffer: int = 8,
@@ -159,25 +159,32 @@ class ActivationsStore:
  total_tokens: int = 10**9,
  device: str = "cpu",
  ) -> ActivationsStore:
+ if sae.cfg.metadata.hook_name is None:
+ raise ValueError("hook_name is required")
+ if sae.cfg.metadata.context_size is None:
+ raise ValueError("context_size is required")
+ if sae.cfg.metadata.prepend_bos is None:
+ raise ValueError("prepend_bos is required")
  return cls(
  model=model,
- dataset=sae.cfg.dataset_path if dataset is None else dataset,
+ dataset=dataset,
  d_in=sae.cfg.d_in,
- hook_name=sae.cfg.hook_name,
- hook_layer=sae.cfg.hook_layer,
- hook_head_index=sae.cfg.hook_head_index,
- context_size=sae.cfg.context_size if context_size is None else context_size,
- prepend_bos=sae.cfg.prepend_bos,
+ hook_name=sae.cfg.metadata.hook_name,
+ hook_head_index=sae.cfg.metadata.hook_head_index,
+ context_size=sae.cfg.metadata.context_size
+ if context_size is None
+ else context_size,
+ prepend_bos=sae.cfg.metadata.prepend_bos,
  streaming=streaming,
  store_batch_size_prompts=store_batch_size_prompts,
  train_batch_size_tokens=train_batch_size_tokens,
  n_batches_in_buffer=n_batches_in_buffer,
  total_training_tokens=total_tokens,
  normalize_activations=sae.cfg.normalize_activations,
- dataset_trust_remote_code=sae.cfg.dataset_trust_remote_code,
+ dataset_trust_remote_code=dataset_trust_remote_code,
  dtype=sae.cfg.dtype,
  device=torch.device(device),
- seqpos_slice=sae.cfg.seqpos_slice,
+ seqpos_slice=sae.cfg.metadata.seqpos_slice or (None,),
  )

  def __init__(
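
Note on the `from_sae` change above: the dataset is now a required argument (the old `sae.cfg.dataset_path` fallback is gone), `dataset_trust_remote_code` is passed explicitly rather than read from the SAE config, and the hook/context settings come from `sae.cfg.metadata`, with a ValueError raised when any of them is missing. A minimal sketch of the new call shape; the dataset path and model are placeholders, not taken from this diff:

```python
# Sketch (not from the diff) of calling ActivationsStore.from_sae under the 6.0.0 signature.
from transformer_lens import HookedTransformer

from sae_lens import SAE
from sae_lens.training.activations_store import ActivationsStore


def make_store(model: HookedTransformer, sae: SAE) -> ActivationsStore:
    # `model` and `sae` are assumed to be already loaded by the caller.
    return ActivationsStore.from_sae(
        model=model,
        sae=sae,
        dataset="NeelNanda/pile-10k",     # now required: no sae.cfg.dataset_path fallback
        dataset_trust_remote_code=False,  # new explicit argument
        context_size=None,                # None -> use sae.cfg.metadata.context_size
        device="cpu",
    )
```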
@@ -186,7 +193,6 @@ class ActivationsStore:
  dataset: HfDataset | str,
  streaming: bool,
  hook_name: str,
- hook_layer: int,
  hook_head_index: int | None,
  context_size: int,
  d_in: int,
@@ -230,7 +236,6 @@ class ActivationsStore:
  )

  self.hook_name = hook_name
- self.hook_layer = hook_layer
  self.hook_head_index = hook_head_index
  self.context_size = context_size
  self.d_in = d_in
@@ -246,12 +251,11 @@ class ActivationsStore:
  self.cached_activations_path = cached_activations_path
  self.autocast_lm = autocast_lm
  self.seqpos_slice = seqpos_slice
+ self.training_context_size = len(range(context_size)[slice(*seqpos_slice)])
  self.exclude_special_tokens = exclude_special_tokens

  self.n_dataset_processed = 0

- self.estimated_norm_scaling_factor = None
-
  # Check if dataset is tokenized
  dataset_sample = next(iter(self.dataset))

@@ -416,30 +420,6 @@ class ActivationsStore:

  return activations_dataset

- def set_norm_scaling_factor_if_needed(self):
- if (
- self.normalize_activations == "expected_average_only_in"
- and self.estimated_norm_scaling_factor is None
- ):
- self.estimated_norm_scaling_factor = self.estimate_norm_scaling_factor()
-
- def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
- if self.estimated_norm_scaling_factor is None:
- raise ValueError(
- "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
- )
- return activations * self.estimated_norm_scaling_factor
-
- def unscale(self, activations: torch.Tensor) -> torch.Tensor:
- if self.estimated_norm_scaling_factor is None:
- raise ValueError(
- "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
- )
- return activations / self.estimated_norm_scaling_factor
-
- def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
- return (self.d_in**0.5) / activations.norm(dim=-1).mean()
-
  @torch.no_grad()
  def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)):
  norms_per_batch = []
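
The norm-scaling helpers deleted above (`set_norm_scaling_factor_if_needed`, `apply_norm_scaling_factor`, `unscale`, `get_norm_scaling_factor`) no longer live on `ActivationsStore`; the file list shows a new `sae_lens/training/activation_scaler.py`, which presumably takes over this responsibility. For reference, the math the removed methods implemented is a rescaling so that the mean activation norm equals sqrt(d_in). The sketch below only restates that computation and is not the `activation_scaler` API:

```python
# Standalone restatement of the scaling math from the deleted get_norm_scaling_factor,
# apply_norm_scaling_factor, and unscale methods above. Illustration only.
import torch


def norm_scaling_factor(activations: torch.Tensor, d_in: int) -> float:
    # Scale so the mean activation norm becomes sqrt(d_in).
    return (d_in**0.5) / activations.norm(dim=-1).mean().item()


acts = torch.randn(4096, 768)
factor = norm_scaling_factor(acts, d_in=768)
scaled = acts * factor      # what apply_norm_scaling_factor did
restored = scaled / factor  # what unscale did
```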
@@ -474,21 +454,6 @@ class ActivationsStore:
  """
  self.iterable_dataset = iter(self.dataset)

- @property
- def storage_buffer(self) -> torch.Tensor:
- if self._storage_buffer is None:
- self._storage_buffer = _filter_buffer_acts(
- self.get_buffer(self.half_buffer_size), self.exclude_special_tokens
- )
-
- return self._storage_buffer
-
- @property
- def dataloader(self) -> Iterator[Any]:
- if self._dataloader is None:
- self._dataloader = self.get_data_loader()
- return self._dataloader
-
  def get_batch_tokens(
  self, batch_size: int | None = None, raise_at_epoch_end: bool = False
  ):
@@ -521,22 +486,17 @@ class ActivationsStore:

  d_in may result from a concatenated head dimension.
  """
-
- # Setup autocast if using
- if self.autocast_lm:
- autocast_if_enabled = torch.autocast(
- device_type="cuda",
- dtype=torch.bfloat16,
- enabled=self.autocast_lm,
- )
- else:
- autocast_if_enabled = contextlib.nullcontext()
-
- with autocast_if_enabled:
+ with torch.autocast(
+ device_type="cuda",
+ dtype=torch.bfloat16,
+ enabled=self.autocast_lm,
+ ):
  layerwise_activations_cache = self.model.run_with_cache(
  batch_tokens,
  names_filter=[self.hook_name],
- stop_at_layer=self.hook_layer + 1,
+ stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(
+ self.hook_name
+ ),
  prepend_bos=False,
  **self.model_kwargs,
  )[1]
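
`stop_at_layer` is now derived from the hook name via `extract_stop_at_layer_from_tlens_hook_name` (added in `sae_lens/util.py`, whose body is not shown in this diff) instead of the removed `hook_layer` field. Judging only from the call site above, the helper must map a TransformerLens hook name to the layer after which the forward pass can stop; a plausible, hypothetical sketch of such a mapping:

```python
# Hypothetical sketch inferred from the call site above; the real implementation
# lives in sae_lens/util.py and may differ.
from __future__ import annotations

import re


def extract_stop_at_layer_from_tlens_hook_name(hook_name: str) -> int | None:
    # "blocks.8.hook_resid_pre" -> stop after layer 8, i.e. stop_at_layer=9.
    match = re.search(r"blocks\.(\d+)\.", hook_name)
    # None means "run the full model" in TransformerLens.
    return int(match.group(1)) + 1 if match else None


assert extract_stop_at_layer_from_tlens_hook_name("blocks.8.hook_resid_pre") == 9
assert extract_stop_at_layer_from_tlens_hook_name("hook_embed") is None
```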
@@ -547,25 +507,25 @@ class ActivationsStore:

  n_batches, n_context = layerwise_activations.shape[:2]

- stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))
+ stacked_activations = torch.zeros((n_batches, n_context, self.d_in))

  if self.hook_head_index is not None:
- stacked_activations[:, :, 0] = layerwise_activations[
+ stacked_activations[:, :] = layerwise_activations[
  :, :, self.hook_head_index
  ]
  elif layerwise_activations.ndim > 3: # if we have a head dimension
  try:
- stacked_activations[:, :, 0] = layerwise_activations.view(
+ stacked_activations[:, :] = layerwise_activations.view(
  n_batches, n_context, -1
  )
  except RuntimeError as e:
  logger.error(f"Error during view operation: {e}")
  logger.info("Attempting to use reshape instead...")
- stacked_activations[:, :, 0] = layerwise_activations.reshape(
+ stacked_activations[:, :] = layerwise_activations.reshape(
  n_batches, n_context, -1
  )
  else:
- stacked_activations[:, :, 0] = layerwise_activations
+ stacked_activations[:, :] = layerwise_activations

  return stacked_activations

@@ -573,7 +533,6 @@ class ActivationsStore:
  self,
  total_size: int,
  context_size: int,
- num_layers: int,
  d_in: int,
  raise_on_epoch_end: bool,
  ) -> tuple[
@@ -590,10 +549,9 @@ class ActivationsStore:
  """
  assert self.cached_activation_dataset is not None
  # In future, could be a list of multiple hook names
- hook_names = [self.hook_name]
- if not set(hook_names).issubset(self.cached_activation_dataset.column_names):
+ if self.hook_name not in self.cached_activation_dataset.column_names:
  raise ValueError(
- f"Missing columns in dataset. Expected {hook_names}, "
+ f"Missing columns in dataset. Expected {self.hook_name}, "
  f"got {self.cached_activation_dataset.column_names}."
  )

@@ -606,28 +564,17 @@ class ActivationsStore:
  ds_slice = self.cached_activation_dataset[
  self.current_row_idx : self.current_row_idx + total_size
  ]
- for hook_name in hook_names:
- # Load activations for each hook.
- # Usually faster to first slice dataset then pick column
- _hook_buffer = ds_slice[hook_name]
- if _hook_buffer.shape != (total_size, context_size, d_in):
- raise ValueError(
- f"_hook_buffer has shape {_hook_buffer.shape}, "
- f"but expected ({total_size}, {context_size}, {d_in})."
- )
- new_buffer.append(_hook_buffer)
-
- # Stack across num_layers dimension
- # list of num_layers; shape: (total_size, context_size, d_in) -> (total_size, context_size, num_layers, d_in)
- new_buffer = torch.stack(new_buffer, dim=2)
- if new_buffer.shape != (total_size, context_size, num_layers, d_in):
+ # Load activations for each hook.
+ # Usually faster to first slice dataset then pick column
+ new_buffer = ds_slice[self.hook_name]
+ if new_buffer.shape != (total_size, context_size, d_in):
  raise ValueError(
  f"new_buffer has shape {new_buffer.shape}, "
- f"but expected ({total_size}, {context_size}, {num_layers}, {d_in})."
+ f"but expected ({total_size}, {context_size}, {d_in})."
  )

  self.current_row_idx += total_size
- acts_buffer = new_buffer.reshape(total_size * context_size, num_layers, d_in)
+ acts_buffer = new_buffer.reshape(total_size * context_size, d_in)

  if "token_ids" not in self.cached_activation_dataset.column_names:
  return acts_buffer, None
@@ -642,7 +589,7 @@ class ActivationsStore:
  return acts_buffer, token_ids_buffer

  @torch.no_grad()
- def get_buffer(
+ def get_raw_buffer(
  self,
  n_batches_in_buffer: int,
  raise_on_epoch_end: bool = False,
@@ -656,26 +603,24 @@ class ActivationsStore:
  If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
  """
  context_size = self.context_size
- training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
  batch_size = self.store_batch_size_prompts
  d_in = self.d_in
  total_size = batch_size * n_batches_in_buffer
- num_layers = 1

  if self.cached_activation_dataset is not None:
  return self._load_buffer_from_cached(
- total_size, context_size, num_layers, d_in, raise_on_epoch_end
+ total_size, context_size, d_in, raise_on_epoch_end
  )

  refill_iterator = range(0, total_size, batch_size)
  # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
  new_buffer_activations = torch.zeros(
- (total_size, training_context_size, num_layers, d_in),
+ (total_size, self.training_context_size, d_in),
  dtype=self.dtype, # type: ignore
  device=self.device,
  )
  new_buffer_token_ids = torch.zeros(
- (total_size, training_context_size),
+ (total_size, self.training_context_size),
  dtype=torch.long,
  device=self.device,
  )
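
The buffer shapes above now use `self.training_context_size`, computed once in `__init__` (see the earlier hunk) as `len(range(context_size)[slice(*seqpos_slice)])` rather than recomputed per call. A quick worked example of what that expression yields for a few `seqpos_slice` values (numbers are illustrative):

```python
# Worked example of the training_context_size computation used for the buffer shapes above.
context_size = 128

print(len(range(context_size)[slice(*(None,))]))     # (None,)    -> 128: keep every position
print(len(range(context_size)[slice(*(10, None))]))  # (10, None) -> 118: drop the first 10 positions
print(len(range(context_size)[slice(*(None, -1))]))  # (None, -1) -> 127: drop the final position
```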
@@ -700,106 +645,80 @@ class ActivationsStore:
  refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
  ] = refill_batch_tokens

- new_buffer_activations = new_buffer_activations.reshape(-1, num_layers, d_in)
+ new_buffer_activations = new_buffer_activations.reshape(-1, d_in)
  new_buffer_token_ids = new_buffer_token_ids.reshape(-1)
  if shuffle:
  new_buffer_activations, new_buffer_token_ids = permute_together(
  [new_buffer_activations, new_buffer_token_ids]
  )

- # every buffer should be normalized:
- if self.normalize_activations == "expected_average_only_in":
- new_buffer_activations = self.apply_norm_scaling_factor(
- new_buffer_activations
- )
-
  return (
  new_buffer_activations,
  new_buffer_token_ids,
  )

- def get_data_loader(
+ def get_filtered_buffer(
  self,
- ) -> Iterator[Any]:
- """
- Return a torch.utils.dataloader which you can get batches from.
-
- Should automatically refill the buffer when it gets to n % full.
- (better mixing if you refill and shuffle regularly).
+ n_batches_in_buffer: int,
+ raise_on_epoch_end: bool = False,
+ shuffle: bool = True,
+ ) -> torch.Tensor:
+ return _filter_buffer_acts(
+ self.get_raw_buffer(
+ n_batches_in_buffer=n_batches_in_buffer,
+ raise_on_epoch_end=raise_on_epoch_end,
+ shuffle=shuffle,
+ ),
+ self.exclude_special_tokens,
+ )

+ def _iterate_filtered_activations(self) -> Generator[torch.Tensor, None, None]:
  """
-
- batch_size = self.train_batch_size_tokens
-
- try:
- new_samples = _filter_buffer_acts(
- self.get_buffer(self.half_buffer_size, raise_on_epoch_end=True),
- self.exclude_special_tokens,
- )
- except StopIteration:
- warnings.warn(
- "All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples."
- )
- self._storage_buffer = (
- None # dump the current buffer so samples do not leak between epochs
- )
+ Iterate over the filtered tokens in the buffer.
+ """
+ while True:
  try:
- new_samples = _filter_buffer_acts(
- self.get_buffer(self.half_buffer_size),
- self.exclude_special_tokens,
+ yield self.get_filtered_buffer(
+ self.half_buffer_size, raise_on_epoch_end=True
  )
  except StopIteration:
- raise ValueError(
- "We were unable to fill up the buffer directly after starting a new epoch. This could indicate that there are less samples in the dataset than are required to fill up the buffer. Consider reducing batch_size or n_batches_in_buffer. "
+ warnings.warn(
+ "All samples in the training dataset have been exhausted, beginning new epoch."
  )
+ try:
+ yield self.get_filtered_buffer(self.half_buffer_size)
+ except StopIteration:
+ raise ValueError(
+ "Unable to fill buffer after starting new epoch. Dataset may be too small."
+ )

- # 1. # create new buffer by mixing stored and new buffer
- mixing_buffer = torch.cat(
- [new_samples, self.storage_buffer],
- dim=0,
- )
-
- mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]
-
- # 2. put 50 % in storage
- self._storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]
-
- # 3. put other 50 % in a dataloader
- return iter(
- DataLoader(
- # TODO: seems like a typing bug?
- cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
- batch_size=batch_size,
- shuffle=True,
- )
+ def get_data_loader(
+ self,
+ ) -> Iterator[Any]:
+ """
+ Return an auto-refilling stream of filtered and mixed activations.
+ """
+ return mixing_buffer(
+ buffer_size=self.n_batches_in_buffer * self.training_context_size,
+ batch_size=self.train_batch_size_tokens,
+ activations_loader=self._iterate_filtered_activations(),
  )

  def next_batch(self) -> torch.Tensor:
- """
- Get the next batch from the current DataLoader.
- If the DataLoader is exhausted, refill the buffer and create a new DataLoader.
- """
- try:
- # Try to get the next batch
- return next(self.dataloader)
- except StopIteration:
- # If the DataLoader is exhausted, create a new one
+ """Get next batch, updating buffer if needed."""
+ return self.__next__()
+
+ # ActivationsStore should be an iterator
+ def __next__(self) -> torch.Tensor:
+ if self._dataloader is None:
  self._dataloader = self.get_data_loader()
- return next(self.dataloader)
+ return next(self._dataloader)
+
+ def __iter__(self) -> Iterator[torch.Tensor]:
+ return self

  def state_dict(self) -> dict[str, torch.Tensor]:
- result = {
- "n_dataset_processed": torch.tensor(self.n_dataset_processed),
- }
- if self._storage_buffer is not None: # first time might be None
- result["storage_buffer_activations"] = self._storage_buffer[0]
- if self._storage_buffer[1] is not None:
- result["storage_buffer_tokens"] = self._storage_buffer[1]
- if self.estimated_norm_scaling_factor is not None:
- result["estimated_norm_scaling_factor"] = torch.tensor(
- self.estimated_norm_scaling_factor
- )
- return result
+ return {"n_dataset_processed": torch.tensor(self.n_dataset_processed)}

  def save(self, file_path: str):
  """save the state dict to a file in safetensors format"""
sae_lens/training/mixing_buffer.py (new file)
@@ -0,0 +1,56 @@
+ from collections.abc import Iterator
+
+ import torch
+
+
+ @torch.no_grad()
+ def mixing_buffer(
+ buffer_size: int,
+ batch_size: int,
+ activations_loader: Iterator[torch.Tensor],
+ ) -> Iterator[torch.Tensor]:
+ """
+ A generator that maintains a mix of old and new activations for better training.
+ It stores half of the activations and mixes them with new ones to create batches.
+
+ Args:
+ buffer_size: Total size of the buffer (will store buffer_size/2 activations)
+ batch_size: Size of batches to return
+ activations_loader: Iterator providing new activations
+
+ Yields:
+ Batches of activations of shape (batch_size, *activation_dims)
+ """
+
+ if buffer_size < batch_size:
+ raise ValueError("Buffer size must be greater than or equal to batch size")
+
+ storage_buffer: torch.Tensor | None = None
+
+ for new_activations in activations_loader:
+ storage_buffer = (
+ new_activations
+ if storage_buffer is None
+ else torch.cat([storage_buffer, new_activations], dim=0)
+ )
+
+ if storage_buffer.shape[0] >= buffer_size:
+ # Shuffle
+ storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
+
+ num_serving_batches = max(1, storage_buffer.shape[0] // (2 * batch_size))
+ serving_cutoff = num_serving_batches * batch_size
+ serving_buffer = storage_buffer[:serving_cutoff]
+ storage_buffer = storage_buffer[serving_cutoff:]
+
+ # Yield batches from the serving_buffer
+ for batch_idx in range(num_serving_batches):
+ yield serving_buffer[
+ batch_idx * batch_size : (batch_idx + 1) * batch_size
+ ]
+
+ # If there are any remaining activations, yield them
+ if storage_buffer is not None:
+ remaining_batches = storage_buffer.shape[0] // batch_size
+ for i in range(remaining_batches):
+ yield storage_buffer[i * batch_size : (i + 1) * batch_size]
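
The new `mixing_buffer` generator replaces the old `storage_buffer`/`DataLoader` machinery in `ActivationsStore.get_data_loader`: it accumulates incoming activations, and once the buffer fills it shuffles, serves roughly half as batches, and keeps the rest to mix with the next refill. A self-contained sketch driving it with synthetic activations (all shapes and sizes below are illustrative; in training it is fed by `_iterate_filtered_activations`):

```python
# Sketch: feeding mixing_buffer with synthetic activations to see the batch shapes it yields.
import torch

from sae_lens.training.mixing_buffer import mixing_buffer


def fake_activations(n_chunks: int = 8, chunk: int = 1024, d_in: int = 64):
    # Stand-in for the store's filtered-activation iterator.
    for _ in range(n_chunks):
        yield torch.randn(chunk, d_in)


batches = mixing_buffer(
    buffer_size=2048,  # once full: shuffle, serve ~half, retain the rest for mixing
    batch_size=256,
    activations_loader=fake_activations(),
)
for batch in batches:
    assert batch.shape == (256, 64)
```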