sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,5 @@
  from __future__ import annotations

- import contextlib
  import json
  import os
  import warnings
@@ -16,20 +15,21 @@ from huggingface_hub.utils import HfHubHTTPError
  from jaxtyping import Float, Int
  from requests import HTTPError
  from safetensors.torch import save_file
- from torch.utils.data import DataLoader
  from tqdm import tqdm
  from transformer_lens.hook_points import HookedRootModule
  from transformers import AutoTokenizer, PreTrainedTokenizerBase

  from sae_lens import logger
  from sae_lens.config import (
- DTYPE_MAP,
  CacheActivationsRunnerConfig,
  HfDataset,
  LanguageModelSAERunnerConfig,
  )
- from sae_lens.saes.sae import SAE
+ from sae_lens.constants import DTYPE_MAP
+ from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
  from sae_lens.tokenization_and_batching import concat_and_batch_sequences
+ from sae_lens.training.mixing_buffer import mixing_buffer
+ from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name


  # TODO: Make an activation store config class to be consistent with the rest of the code.
@@ -45,10 +45,8 @@ class ActivationsStore:
  cached_activation_dataset: Dataset | None = None
  tokens_column: Literal["tokens", "input_ids", "text", "problem"]
  hook_name: str
- hook_layer: int
  hook_head_index: int | None
  _dataloader: Iterator[Any] | None = None
- _storage_buffer: torch.Tensor | None = None
  exclude_special_tokens: torch.Tensor | None = None
  device: torch.device

@@ -65,7 +63,6 @@ class ActivationsStore:
  cached_activations_path=cfg.new_cached_activations_path,
  dtype=cfg.dtype,
  hook_name=cfg.hook_name,
- hook_layer=cfg.hook_layer,
  context_size=cfg.context_size,
  d_in=cfg.d_in,
  n_batches_in_buffer=cfg.n_batches_in_buffer,
@@ -91,7 +88,8 @@ class ActivationsStore:
  def from_config(
  cls,
  model: HookedRootModule,
- cfg: LanguageModelSAERunnerConfig | CacheActivationsRunnerConfig,
+ cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]
+ | CacheActivationsRunnerConfig,
  override_dataset: HfDataset | None = None,
  ) -> ActivationsStore:
  if isinstance(cfg, CacheActivationsRunnerConfig):
@@ -125,16 +123,17 @@ class ActivationsStore:
  dataset=override_dataset or cfg.dataset_path,
  streaming=cfg.streaming,
  hook_name=cfg.hook_name,
- hook_layer=cfg.hook_layer,
  hook_head_index=cfg.hook_head_index,
  context_size=cfg.context_size,
- d_in=cfg.d_in,
+ d_in=cfg.d_in
+ if isinstance(cfg, CacheActivationsRunnerConfig)
+ else cfg.sae.d_in,
  n_batches_in_buffer=cfg.n_batches_in_buffer,
  total_training_tokens=cfg.training_tokens,
  store_batch_size_prompts=cfg.store_batch_size_prompts,
  train_batch_size_tokens=cfg.train_batch_size_tokens,
  prepend_bos=cfg.prepend_bos,
- normalize_activations=cfg.normalize_activations,
+ normalize_activations=cfg.sae.normalize_activations,
  device=device,
  dtype=cfg.dtype,
  cached_activations_path=cached_activations_path,
@@ -149,9 +148,10 @@ class ActivationsStore:
  def from_sae(
  cls,
  model: HookedRootModule,
- sae: SAE,
+ sae: SAE[T_SAE_CONFIG],
+ dataset: HfDataset | str,
+ dataset_trust_remote_code: bool = False,
  context_size: int | None = None,
- dataset: HfDataset | str | None = None,
  streaming: bool = True,
  store_batch_size_prompts: int = 8,
  n_batches_in_buffer: int = 8,
@@ -159,25 +159,34 @@ class ActivationsStore:
  total_tokens: int = 10**9,
  device: str = "cpu",
  ) -> ActivationsStore:
+ if sae.cfg.metadata.hook_name is None:
+ raise ValueError("hook_name is required")
+ if sae.cfg.metadata.hook_head_index is None:
+ raise ValueError("hook_head_index is required")
+ if sae.cfg.metadata.context_size is None:
+ raise ValueError("context_size is required")
+ if sae.cfg.metadata.prepend_bos is None:
+ raise ValueError("prepend_bos is required")
  return cls(
  model=model,
- dataset=sae.cfg.dataset_path if dataset is None else dataset,
+ dataset=dataset,
  d_in=sae.cfg.d_in,
- hook_name=sae.cfg.hook_name,
- hook_layer=sae.cfg.hook_layer,
- hook_head_index=sae.cfg.hook_head_index,
- context_size=sae.cfg.context_size if context_size is None else context_size,
- prepend_bos=sae.cfg.prepend_bos,
+ hook_name=sae.cfg.metadata.hook_name,
+ hook_head_index=sae.cfg.metadata.hook_head_index,
+ context_size=sae.cfg.metadata.context_size
+ if context_size is None
+ else context_size,
+ prepend_bos=sae.cfg.metadata.prepend_bos,
  streaming=streaming,
  store_batch_size_prompts=store_batch_size_prompts,
  train_batch_size_tokens=train_batch_size_tokens,
  n_batches_in_buffer=n_batches_in_buffer,
  total_training_tokens=total_tokens,
  normalize_activations=sae.cfg.normalize_activations,
- dataset_trust_remote_code=sae.cfg.dataset_trust_remote_code,
+ dataset_trust_remote_code=dataset_trust_remote_code,
  dtype=sae.cfg.dtype,
  device=torch.device(device),
- seqpos_slice=sae.cfg.seqpos_slice or (None,),
+ seqpos_slice=sae.cfg.metadata.seqpos_slice or (None,),
  )

  def __init__(
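
Illustrative sketch, not part of the diff: with the updated from_sae signature above, dataset is now a required argument and dataset_trust_remote_code is passed explicitly rather than read from the SAE config. The model, SAE, and dataset name below are assumed placeholders.

    # Hypothetical usage; `model` and `sae` are assumed to be an already-loaded
    # HookedRootModule and SAE respectively.
    store = ActivationsStore.from_sae(
        model=model,
        sae=sae,
        dataset="NeelNanda/pile-10k",     # now required, no longer read from sae.cfg
        dataset_trust_remote_code=False,  # now passed explicitly
        context_size=None,                # falls back to sae.cfg.metadata.context_size
        device="cpu",
    )
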
@@ -186,7 +195,6 @@ class ActivationsStore:
  dataset: HfDataset | str,
  streaming: bool,
  hook_name: str,
- hook_layer: int,
  hook_head_index: int | None,
  context_size: int,
  d_in: int,
@@ -230,7 +238,6 @@ class ActivationsStore:
  )

  self.hook_name = hook_name
- self.hook_layer = hook_layer
  self.hook_head_index = hook_head_index
  self.context_size = context_size
  self.d_in = d_in
@@ -246,12 +253,11 @@ class ActivationsStore:
  self.cached_activations_path = cached_activations_path
  self.autocast_lm = autocast_lm
  self.seqpos_slice = seqpos_slice
+ self.training_context_size = len(range(context_size)[slice(*seqpos_slice)])
  self.exclude_special_tokens = exclude_special_tokens

  self.n_dataset_processed = 0

- self.estimated_norm_scaling_factor = None
-
  # Check if dataset is tokenized
  dataset_sample = next(iter(self.dataset))

@@ -416,30 +422,6 @@ class ActivationsStore:

  return activations_dataset

- def set_norm_scaling_factor_if_needed(self):
- if (
- self.normalize_activations == "expected_average_only_in"
- and self.estimated_norm_scaling_factor is None
- ):
- self.estimated_norm_scaling_factor = self.estimate_norm_scaling_factor()
-
- def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
- if self.estimated_norm_scaling_factor is None:
- raise ValueError(
- "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
- )
- return activations * self.estimated_norm_scaling_factor
-
- def unscale(self, activations: torch.Tensor) -> torch.Tensor:
- if self.estimated_norm_scaling_factor is None:
- raise ValueError(
- "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
- )
- return activations / self.estimated_norm_scaling_factor
-
- def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
- return (self.d_in**0.5) / activations.norm(dim=-1).mean()
-
  @torch.no_grad()
  def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)):
  norms_per_batch = []
@@ -474,21 +456,6 @@ class ActivationsStore:
  """
  self.iterable_dataset = iter(self.dataset)

- @property
- def storage_buffer(self) -> torch.Tensor:
- if self._storage_buffer is None:
- self._storage_buffer = _filter_buffer_acts(
- self.get_buffer(self.half_buffer_size), self.exclude_special_tokens
- )
-
- return self._storage_buffer
-
- @property
- def dataloader(self) -> Iterator[Any]:
- if self._dataloader is None:
- self._dataloader = self.get_data_loader()
- return self._dataloader
-
  def get_batch_tokens(
  self, batch_size: int | None = None, raise_at_epoch_end: bool = False
  ):
@@ -521,22 +488,17 @@ class ActivationsStore:

  d_in may result from a concatenated head dimension.
  """
-
- # Setup autocast if using
- if self.autocast_lm:
- autocast_if_enabled = torch.autocast(
- device_type="cuda",
- dtype=torch.bfloat16,
- enabled=self.autocast_lm,
- )
- else:
- autocast_if_enabled = contextlib.nullcontext()
-
- with autocast_if_enabled:
+ with torch.autocast(
+ device_type="cuda",
+ dtype=torch.bfloat16,
+ enabled=self.autocast_lm,
+ ):
  layerwise_activations_cache = self.model.run_with_cache(
  batch_tokens,
  names_filter=[self.hook_name],
- stop_at_layer=self.hook_layer + 1,
+ stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(
+ self.hook_name
+ ),
  prepend_bos=False,
  **self.model_kwargs,
  )[1]
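
Illustrative sketch only: extract_stop_at_layer_from_tlens_hook_name replaces the removed hook_layer + 1 argument, so it presumably derives the stop layer from the TransformerLens hook name. The real helper lives in sae_lens.util and may differ; the function below is a hypothetical stand-in.

    import re

    def guess_stop_at_layer(hook_name: str) -> int | None:
        # Assumption: "blocks.8.hook_resid_pre" -> stop at layer 9 (old behaviour: hook_layer + 1)
        match = re.match(r"blocks\.(\d+)\.", hook_name)
        return int(match.group(1)) + 1 if match else None
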
@@ -547,25 +509,25 @@ class ActivationsStore:

  n_batches, n_context = layerwise_activations.shape[:2]

- stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))
+ stacked_activations = torch.zeros((n_batches, n_context, self.d_in))

  if self.hook_head_index is not None:
- stacked_activations[:, :, 0] = layerwise_activations[
+ stacked_activations[:, :] = layerwise_activations[
  :, :, self.hook_head_index
  ]
  elif layerwise_activations.ndim > 3: # if we have a head dimension
  try:
- stacked_activations[:, :, 0] = layerwise_activations.view(
+ stacked_activations[:, :] = layerwise_activations.view(
  n_batches, n_context, -1
  )
  except RuntimeError as e:
  logger.error(f"Error during view operation: {e}")
  logger.info("Attempting to use reshape instead...")
- stacked_activations[:, :, 0] = layerwise_activations.reshape(
+ stacked_activations[:, :] = layerwise_activations.reshape(
  n_batches, n_context, -1
  )
  else:
- stacked_activations[:, :, 0] = layerwise_activations
+ stacked_activations[:, :] = layerwise_activations

  return stacked_activations

@@ -573,7 +535,6 @@ class ActivationsStore:
  self,
  total_size: int,
  context_size: int,
- num_layers: int,
  d_in: int,
  raise_on_epoch_end: bool,
  ) -> tuple[
@@ -590,10 +551,9 @@ class ActivationsStore:
  """
  assert self.cached_activation_dataset is not None
  # In future, could be a list of multiple hook names
- hook_names = [self.hook_name]
- if not set(hook_names).issubset(self.cached_activation_dataset.column_names):
+ if self.hook_name not in self.cached_activation_dataset.column_names:
  raise ValueError(
- f"Missing columns in dataset. Expected {hook_names}, "
+ f"Missing columns in dataset. Expected {self.hook_name}, "
  f"got {self.cached_activation_dataset.column_names}."
  )

@@ -606,28 +566,17 @@ class ActivationsStore:
  ds_slice = self.cached_activation_dataset[
  self.current_row_idx : self.current_row_idx + total_size
  ]
- for hook_name in hook_names:
- # Load activations for each hook.
- # Usually faster to first slice dataset then pick column
- _hook_buffer = ds_slice[hook_name]
- if _hook_buffer.shape != (total_size, context_size, d_in):
- raise ValueError(
- f"_hook_buffer has shape {_hook_buffer.shape}, "
- f"but expected ({total_size}, {context_size}, {d_in})."
- )
- new_buffer.append(_hook_buffer)
-
- # Stack across num_layers dimension
- # list of num_layers; shape: (total_size, context_size, d_in) -> (total_size, context_size, num_layers, d_in)
- new_buffer = torch.stack(new_buffer, dim=2)
- if new_buffer.shape != (total_size, context_size, num_layers, d_in):
+ # Load activations for each hook.
+ # Usually faster to first slice dataset then pick column
+ new_buffer = ds_slice[self.hook_name]
+ if new_buffer.shape != (total_size, context_size, d_in):
  raise ValueError(
  f"new_buffer has shape {new_buffer.shape}, "
- f"but expected ({total_size}, {context_size}, {num_layers}, {d_in})."
+ f"but expected ({total_size}, {context_size}, {d_in})."
  )

  self.current_row_idx += total_size
- acts_buffer = new_buffer.reshape(total_size * context_size, num_layers, d_in)
+ acts_buffer = new_buffer.reshape(total_size * context_size, d_in)

  if "token_ids" not in self.cached_activation_dataset.column_names:
  return acts_buffer, None
@@ -642,7 +591,7 @@ class ActivationsStore:
  return acts_buffer, token_ids_buffer

  @torch.no_grad()
- def get_buffer(
+ def get_raw_buffer(
  self,
  n_batches_in_buffer: int,
  raise_on_epoch_end: bool = False,
@@ -656,26 +605,24 @@ class ActivationsStore:
  If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
  """
  context_size = self.context_size
- training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
  batch_size = self.store_batch_size_prompts
  d_in = self.d_in
  total_size = batch_size * n_batches_in_buffer
- num_layers = 1

  if self.cached_activation_dataset is not None:
  return self._load_buffer_from_cached(
- total_size, context_size, num_layers, d_in, raise_on_epoch_end
+ total_size, context_size, d_in, raise_on_epoch_end
  )

  refill_iterator = range(0, total_size, batch_size)
  # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
  new_buffer_activations = torch.zeros(
- (total_size, training_context_size, num_layers, d_in),
+ (total_size, self.training_context_size, d_in),
  dtype=self.dtype, # type: ignore
  device=self.device,
  )
  new_buffer_token_ids = torch.zeros(
- (total_size, training_context_size),
+ (total_size, self.training_context_size),
  dtype=torch.long,
  device=self.device,
  )
@@ -700,106 +647,80 @@ class ActivationsStore:
  refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
  ] = refill_batch_tokens

- new_buffer_activations = new_buffer_activations.reshape(-1, num_layers, d_in)
+ new_buffer_activations = new_buffer_activations.reshape(-1, d_in)
  new_buffer_token_ids = new_buffer_token_ids.reshape(-1)
  if shuffle:
  new_buffer_activations, new_buffer_token_ids = permute_together(
  [new_buffer_activations, new_buffer_token_ids]
  )

- # every buffer should be normalized:
- if self.normalize_activations == "expected_average_only_in":
- new_buffer_activations = self.apply_norm_scaling_factor(
- new_buffer_activations
- )
-
  return (
  new_buffer_activations,
  new_buffer_token_ids,
  )

- def get_data_loader(
+ def get_filtered_buffer(
  self,
- ) -> Iterator[Any]:
- """
- Return a torch.utils.dataloader which you can get batches from.
-
- Should automatically refill the buffer when it gets to n % full.
- (better mixing if you refill and shuffle regularly).
+ n_batches_in_buffer: int,
+ raise_on_epoch_end: bool = False,
+ shuffle: bool = True,
+ ) -> torch.Tensor:
+ return _filter_buffer_acts(
+ self.get_raw_buffer(
+ n_batches_in_buffer=n_batches_in_buffer,
+ raise_on_epoch_end=raise_on_epoch_end,
+ shuffle=shuffle,
+ ),
+ self.exclude_special_tokens,
+ )

+ def _iterate_filtered_activations(self) -> Generator[torch.Tensor, None, None]:
  """
-
- batch_size = self.train_batch_size_tokens
-
- try:
- new_samples = _filter_buffer_acts(
- self.get_buffer(self.half_buffer_size, raise_on_epoch_end=True),
- self.exclude_special_tokens,
- )
- except StopIteration:
- warnings.warn(
- "All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples."
- )
- self._storage_buffer = (
- None # dump the current buffer so samples do not leak between epochs
- )
+ Iterate over the filtered tokens in the buffer.
+ """
+ while True:
  try:
- new_samples = _filter_buffer_acts(
- self.get_buffer(self.half_buffer_size),
- self.exclude_special_tokens,
+ yield self.get_filtered_buffer(
+ self.half_buffer_size, raise_on_epoch_end=True
  )
  except StopIteration:
- raise ValueError(
- "We were unable to fill up the buffer directly after starting a new epoch. This could indicate that there are less samples in the dataset than are required to fill up the buffer. Consider reducing batch_size or n_batches_in_buffer. "
+ warnings.warn(
+ "All samples in the training dataset have been exhausted, beginning new epoch."
  )
+ try:
+ yield self.get_filtered_buffer(self.half_buffer_size)
+ except StopIteration:
+ raise ValueError(
+ "Unable to fill buffer after starting new epoch. Dataset may be too small."
+ )

- # 1. # create new buffer by mixing stored and new buffer
- mixing_buffer = torch.cat(
- [new_samples, self.storage_buffer],
- dim=0,
- )
-
- mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]
-
- # 2. put 50 % in storage
- self._storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]
-
- # 3. put other 50 % in a dataloader
- return iter(
- DataLoader(
- # TODO: seems like a typing bug?
- cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
- batch_size=batch_size,
- shuffle=True,
- )
+ def get_data_loader(
+ self,
+ ) -> Iterator[Any]:
+ """
+ Return an auto-refilling stream of filtered and mixed activations.
+ """
+ return mixing_buffer(
+ buffer_size=self.n_batches_in_buffer * self.training_context_size,
+ batch_size=self.train_batch_size_tokens,
+ activations_loader=self._iterate_filtered_activations(),
  )

  def next_batch(self) -> torch.Tensor:
- """
- Get the next batch from the current DataLoader.
- If the DataLoader is exhausted, refill the buffer and create a new DataLoader.
- """
- try:
- # Try to get the next batch
- return next(self.dataloader)
- except StopIteration:
- # If the DataLoader is exhausted, create a new one
+ """Get next batch, updating buffer if needed."""
+ return self.__next__()
+
+ # ActivationsStore should be an iterator
+ def __next__(self) -> torch.Tensor:
+ if self._dataloader is None:
  self._dataloader = self.get_data_loader()
- return next(self.dataloader)
+ return next(self._dataloader)
+
+ def __iter__(self) -> Iterator[torch.Tensor]:
+ return self

  def state_dict(self) -> dict[str, torch.Tensor]:
- result = {
- "n_dataset_processed": torch.tensor(self.n_dataset_processed),
- }
- if self._storage_buffer is not None: # first time might be None
- result["storage_buffer_activations"] = self._storage_buffer[0]
- if self._storage_buffer[1] is not None:
- result["storage_buffer_tokens"] = self._storage_buffer[1]
- if self.estimated_norm_scaling_factor is not None:
- result["estimated_norm_scaling_factor"] = torch.tensor(
- self.estimated_norm_scaling_factor
- )
- return result
+ return {"n_dataset_processed": torch.tensor(self.n_dataset_processed)}

  def save(self, file_path: str):
  """save the state dict to a file in safetensors format"""
@@ -0,0 +1,56 @@ sae_lens/training/mixing_buffer.py (new file)
+ from collections.abc import Iterator
+
+ import torch
+
+
+ @torch.no_grad()
+ def mixing_buffer(
+ buffer_size: int,
+ batch_size: int,
+ activations_loader: Iterator[torch.Tensor],
+ ) -> Iterator[torch.Tensor]:
+ """
+ A generator that maintains a mix of old and new activations for better training.
+ It stores half of the activations and mixes them with new ones to create batches.
+
+ Args:
+ buffer_size: Total size of the buffer (will store buffer_size/2 activations)
+ batch_size: Size of batches to return
+ activations_loader: Iterator providing new activations
+
+ Yields:
+ Batches of activations of shape (batch_size, *activation_dims)
+ """
+
+ if buffer_size < batch_size:
+ raise ValueError("Buffer size must be greater than or equal to batch size")
+
+ storage_buffer: torch.Tensor | None = None
+
+ for new_activations in activations_loader:
+ storage_buffer = (
+ new_activations
+ if storage_buffer is None
+ else torch.cat([storage_buffer, new_activations], dim=0)
+ )
+
+ if storage_buffer.shape[0] >= buffer_size:
+ # Shuffle
+ storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
+
+ num_serving_batches = max(1, storage_buffer.shape[0] // (2 * batch_size))
+ serving_cutoff = num_serving_batches * batch_size
+ serving_buffer = storage_buffer[:serving_cutoff]
+ storage_buffer = storage_buffer[serving_cutoff:]
+
+ # Yield batches from the serving_buffer
+ for batch_idx in range(num_serving_batches):
+ yield serving_buffer[
+ batch_idx * batch_size : (batch_idx + 1) * batch_size
+ ]
+
+ # If there are any remaining activations, yield them
+ if storage_buffer is not None:
+ remaining_batches = storage_buffer.shape[0] // batch_size
+ for i in range(remaining_batches):
+ yield storage_buffer[i * batch_size : (i + 1) * batch_size]
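
Illustrative usage sketch, not part of the package: driving the new mixing_buffer generator with a toy iterator of random activation chunks. fake_chunks and the sizes below are made-up placeholders standing in for ActivationsStore's filtered-activation stream.

    import torch
    from sae_lens.training.mixing_buffer import mixing_buffer

    def fake_chunks(n_chunks: int = 10, chunk_size: int = 256, d_in: int = 64):
        # Stand-in for ActivationsStore._iterate_filtered_activations()
        for _ in range(n_chunks):
            yield torch.randn(chunk_size, d_in)

    for batch in mixing_buffer(
        buffer_size=512,   # roughly half is held back and re-mixed with new chunks
        batch_size=128,
        activations_loader=fake_chunks(),
    ):
        assert batch.shape == (128, 64)
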