sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -193,7 +193,6 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:
 
     rename_keys_map = {
         "hook_point": "hook_name",
-        "hook_point_layer": "hook_layer",
         "hook_point_head_index": "hook_head_index",
         "activation_fn_str": "activation_fn",
     }
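
With the "hook_point_layer" mapping removed, a pre-6.0 layer index no longer has a new-style key to land in; elsewhere in rc3 the layer is derived from the hook name instead. A minimal sketch of how a rename map like this is typically applied when migrating an old config dict (hypothetical helper for illustration only, not code from the wheel; the assumption that unmapped legacy keys are simply dropped is the editor's):

from typing import Any

def apply_rename_map(cfg_dict: dict[str, Any], rename_keys_map: dict[str, str]) -> dict[str, Any]:
    # Rename old-style keys; legacy keys with no new-style equivalent
    # (e.g. the old "hook_point_layer") are assumed to be dropped.
    dropped = {"hook_point_layer"}
    return {
        rename_keys_map.get(key, key): value
        for key, value in cfg_dict.items()
        if key not in dropped
    }

# {"hook_point": "blocks.5.hook_resid_pre", "hook_point_layer": 5, "d_in": 768}
# -> {"hook_name": "blocks.5.hook_resid_pre", "d_in": 768}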
@@ -262,7 +261,6 @@ def get_connor_rob_hook_z_config_from_hf(
         "device": device if device is not None else "cpu",
         "model_name": "gpt2-small",
         "hook_name": old_cfg_dict["act_name"],
-        "hook_layer": old_cfg_dict["layer"],
         "hook_head_index": None,
         "activation_fn": "relu",
         "apply_b_dec_to_input": True,
@@ -411,7 +409,6 @@ def get_gemma_2_config_from_hf(
         "dtype": "float32",
         "model_name": model_name,
         "hook_name": hook_name,
-        "hook_layer": layer,
         "hook_head_index": None,
         "activation_fn": "relu",
         "finetuning_scaling_factor": False,
@@ -524,7 +521,6 @@ def get_llama_scope_config_from_hf(
         "dtype": "bfloat16",
         "model_name": model_name,
         "hook_name": old_cfg_dict["hook_point_in"],
-        "hook_layer": int(old_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
         "activation_fn": "relu",
         "finetuning_scaling_factor": False,
@@ -651,7 +647,6 @@ def get_dictionary_learning_config_1_from_hf(
         "device": device,
         "model_name": trainer["lm_name"].split("/")[-1],
         "hook_name": hook_point_name,
-        "hook_layer": trainer["layer"],
         "hook_head_index": None,
         "activation_fn": activation_fn,
         "activation_fn_kwargs": activation_fn_kwargs,
@@ -690,7 +685,6 @@ def get_deepseek_r1_config_from_hf(
         "context_size": 1024,
         "model_name": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         "hook_name": f"blocks.{layer}.hook_resid_post",
-        "hook_layer": layer,
         "hook_head_index": None,
         "prepend_bos": True,
         "dataset_path": "lmsys/lmsys-chat-1m",
@@ -849,7 +843,6 @@ def get_llama_scope_r1_distill_config_from_hf(
         "device": device,
         "model_name": model_name,
         "hook_name": huggingface_cfg_dict["hook_point_in"],
-        "hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
         "activation_fn": "relu",
         "finetuning_scaling_factor": False,
sae_lens/saes/sae.py CHANGED
@@ -66,7 +66,6 @@ class SAEMetadata:
     model_name: str | None = None
     hook_name: str | None = None
     model_class_name: str | None = None
-    hook_layer: int | None = None
     hook_head_index: int | None = None
     model_from_pretrained_kwargs: dict[str, Any] | None = None
     prepend_bos: bool | None = None
@@ -649,7 +648,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 class TrainingSAEConfig(SAEConfig, ABC):
     noise_scale: float = 0.0
     mse_loss_normalization: str | None = None
-    b_dec_init_method: Literal["zeros", "geometric_median", "mean"] = "zeros"
     # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
     # 0.1 corresponds to the "heuristic" initialization, use None to disable
     decoder_init_norm: float | None = 0.1
@@ -666,7 +664,6 @@ class TrainingSAEConfig(SAEConfig, ABC):
         metadata = SAEMetadata(
             model_name=cfg.model_name,
             hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
             hook_head_index=cfg.hook_head_index,
             context_size=cfg.context_size,
             prepend_bos=cfg.prepend_bos,
@@ -0,0 +1,53 @@
+import json
+from dataclasses import dataclass
+from statistics import mean
+
+import torch
+from tqdm import tqdm
+
+from sae_lens.training.types import DataProvider
+
+
+@dataclass
+class ActivationScaler:
+    scaling_factor: float | None = None
+
+    def scale(self, acts: torch.Tensor) -> torch.Tensor:
+        return acts if self.scaling_factor is None else acts * self.scaling_factor
+
+    def unscale(self, acts: torch.Tensor) -> torch.Tensor:
+        return acts if self.scaling_factor is None else acts / self.scaling_factor
+
+    def __call__(self, acts: torch.Tensor) -> torch.Tensor:
+        return self.scale(acts)
+
+    @torch.no_grad()
+    def _calculate_mean_norm(
+        self, data_provider: DataProvider, n_batches_for_norm_estimate: int = int(1e3)
+    ) -> float:
+        norms_per_batch: list[float] = []
+        for _ in tqdm(
+            range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
+        ):
+            acts = next(data_provider)
+            norms_per_batch.append(acts.norm(dim=-1).mean().item())
+        return mean(norms_per_batch)
+
+    def estimate_scaling_factor(
+        self,
+        d_in: int,
+        data_provider: DataProvider,
+        n_batches_for_norm_estimate: int = int(1e3),
+    ):
+        mean_norm = self._calculate_mean_norm(
+            data_provider, n_batches_for_norm_estimate
+        )
+        self.scaling_factor = (d_in**0.5) / mean_norm
+
+    def save(self, file_path: str):
+        """save the state dict to a file in json format"""
+        if not file_path.endswith(".json"):
+            raise ValueError("file_path must end with .json")
+
+        with open(file_path, "w") as f:
+            json.dump({"scaling_factor": self.scaling_factor}, f)
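
A short usage sketch of the new scaler (illustration only, not part of the package diff; the import path of the new module is not shown above, and the random tensors merely stand in for a real DataProvider, which here is just any iterator yielding activation batches):

import torch

# hypothetical stand-in for a DataProvider: an iterator of activation batches
fake_provider = (torch.randn(4096, 768) for _ in range(1000))

scaler = ActivationScaler()  # assumes the class above is importable
scaler.estimate_scaling_factor(
    d_in=768, data_provider=fake_provider, n_batches_for_norm_estimate=100
)
scaled = scaler(torch.randn(32, 768))   # same as scaler.scale(...)
restored = scaler.unscale(scaled)       # inverse transform
scaler.save("activation_scaler.json")   # persists {"scaling_factor": ...}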
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import contextlib
 import json
 import os
 import warnings
@@ -16,7 +15,6 @@ from huggingface_hub.utils import HfHubHTTPError
 from jaxtyping import Float, Int
 from requests import HTTPError
 from safetensors.torch import save_file
-from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformer_lens.hook_points import HookedRootModule
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -30,6 +28,8 @@ from sae_lens.config import (
 from sae_lens.constants import DTYPE_MAP
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
+from sae_lens.training.mixing_buffer import mixing_buffer
+from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
 
 
 # TODO: Make an activation store config class to be consistent with the rest of the code.
@@ -45,10 +45,8 @@ class ActivationsStore:
     cached_activation_dataset: Dataset | None = None
     tokens_column: Literal["tokens", "input_ids", "text", "problem"]
     hook_name: str
-    hook_layer: int
     hook_head_index: int | None
     _dataloader: Iterator[Any] | None = None
-    _storage_buffer: torch.Tensor | None = None
     exclude_special_tokens: torch.Tensor | None = None
     device: torch.device
 
@@ -65,7 +63,6 @@ class ActivationsStore:
             cached_activations_path=cfg.new_cached_activations_path,
             dtype=cfg.dtype,
             hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
             context_size=cfg.context_size,
             d_in=cfg.d_in,
             n_batches_in_buffer=cfg.n_batches_in_buffer,
@@ -126,7 +123,6 @@ class ActivationsStore:
             dataset=override_dataset or cfg.dataset_path,
             streaming=cfg.streaming,
             hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
             hook_head_index=cfg.hook_head_index,
             context_size=cfg.context_size,
             d_in=cfg.d_in
@@ -165,8 +161,6 @@ class ActivationsStore:
     ) -> ActivationsStore:
         if sae.cfg.metadata.hook_name is None:
             raise ValueError("hook_name is required")
-        if sae.cfg.metadata.hook_layer is None:
-            raise ValueError("hook_layer is required")
         if sae.cfg.metadata.hook_head_index is None:
             raise ValueError("hook_head_index is required")
         if sae.cfg.metadata.context_size is None:
@@ -178,7 +172,6 @@ class ActivationsStore:
             dataset=dataset,
             d_in=sae.cfg.d_in,
             hook_name=sae.cfg.metadata.hook_name,
-            hook_layer=sae.cfg.metadata.hook_layer,
             hook_head_index=sae.cfg.metadata.hook_head_index,
             context_size=sae.cfg.metadata.context_size
             if context_size is None
@@ -202,7 +195,6 @@ class ActivationsStore:
         dataset: HfDataset | str,
         streaming: bool,
         hook_name: str,
-        hook_layer: int,
         hook_head_index: int | None,
         context_size: int,
         d_in: int,
@@ -246,7 +238,6 @@ class ActivationsStore:
         )
 
         self.hook_name = hook_name
-        self.hook_layer = hook_layer
         self.hook_head_index = hook_head_index
         self.context_size = context_size
         self.d_in = d_in
@@ -262,12 +253,11 @@ class ActivationsStore:
         self.cached_activations_path = cached_activations_path
         self.autocast_lm = autocast_lm
         self.seqpos_slice = seqpos_slice
+        self.training_context_size = len(range(context_size)[slice(*seqpos_slice)])
         self.exclude_special_tokens = exclude_special_tokens
 
         self.n_dataset_processed = 0
 
-        self.estimated_norm_scaling_factor = None
-
         # Check if dataset is tokenized
         dataset_sample = next(iter(self.dataset))
 
@@ -432,30 +422,6 @@ class ActivationsStore:
 
         return activations_dataset
 
-    def set_norm_scaling_factor_if_needed(self):
-        if (
-            self.normalize_activations == "expected_average_only_in"
-            and self.estimated_norm_scaling_factor is None
-        ):
-            self.estimated_norm_scaling_factor = self.estimate_norm_scaling_factor()
-
-    def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
-        if self.estimated_norm_scaling_factor is None:
-            raise ValueError(
-                "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
-            )
-        return activations * self.estimated_norm_scaling_factor
-
-    def unscale(self, activations: torch.Tensor) -> torch.Tensor:
-        if self.estimated_norm_scaling_factor is None:
-            raise ValueError(
-                "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
-            )
-        return activations / self.estimated_norm_scaling_factor
-
-    def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
-        return (self.d_in**0.5) / activations.norm(dim=-1).mean()
-
     @torch.no_grad()
     def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)):
         norms_per_batch = []
@@ -490,21 +456,6 @@ class ActivationsStore:
         """
         self.iterable_dataset = iter(self.dataset)
 
-    @property
-    def storage_buffer(self) -> torch.Tensor:
-        if self._storage_buffer is None:
-            self._storage_buffer = _filter_buffer_acts(
-                self.get_buffer(self.half_buffer_size), self.exclude_special_tokens
-            )
-
-        return self._storage_buffer
-
-    @property
-    def dataloader(self) -> Iterator[Any]:
-        if self._dataloader is None:
-            self._dataloader = self.get_data_loader()
-        return self._dataloader
-
     def get_batch_tokens(
         self, batch_size: int | None = None, raise_at_epoch_end: bool = False
     ):
@@ -537,22 +488,17 @@ class ActivationsStore:
 
         d_in may result from a concatenated head dimension.
         """
-
-        # Setup autocast if using
-        if self.autocast_lm:
-            autocast_if_enabled = torch.autocast(
-                device_type="cuda",
-                dtype=torch.bfloat16,
-                enabled=self.autocast_lm,
-            )
-        else:
-            autocast_if_enabled = contextlib.nullcontext()
-
-        with autocast_if_enabled:
+        with torch.autocast(
+            device_type="cuda",
+            dtype=torch.bfloat16,
+            enabled=self.autocast_lm,
+        ):
             layerwise_activations_cache = self.model.run_with_cache(
                 batch_tokens,
                 names_filter=[self.hook_name],
-                stop_at_layer=self.hook_layer + 1,
+                stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(
+                    self.hook_name
+                ),
                 prepend_bos=False,
                 **self.model_kwargs,
            )[1]
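
The removed hook_layer field is replaced at this call site by deriving the stop layer from the hook name itself. The helper's implementation is not part of this diff; a plausible sketch of the behaviour the call site needs (parse the block index out of a TransformerLens hook name and stop one layer after it), offered purely as an illustration and not as the code in sae_lens.util:

import re

def extract_stop_at_layer_from_tlens_hook_name(hook_name: str) -> int | None:
    # Sketch only: "blocks.12.hook_resid_post" -> 13, so run_with_cache stops
    # right after layer 12; hooks without a block index fall back to None
    # (run the full model).
    match = re.match(r"blocks\.(\d+)\.", hook_name)
    return int(match.group(1)) + 1 if match else None

assert extract_stop_at_layer_from_tlens_hook_name("blocks.5.attn.hook_z") == 6
assert extract_stop_at_layer_from_tlens_hook_name("hook_embed") is None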
@@ -563,25 +509,25 @@ class ActivationsStore:
 
         n_batches, n_context = layerwise_activations.shape[:2]
 
-        stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))
+        stacked_activations = torch.zeros((n_batches, n_context, self.d_in))
 
         if self.hook_head_index is not None:
-            stacked_activations[:, :, 0] = layerwise_activations[
+            stacked_activations[:, :] = layerwise_activations[
                 :, :, self.hook_head_index
             ]
         elif layerwise_activations.ndim > 3:  # if we have a head dimension
             try:
-                stacked_activations[:, :, 0] = layerwise_activations.view(
+                stacked_activations[:, :] = layerwise_activations.view(
                     n_batches, n_context, -1
                 )
             except RuntimeError as e:
                 logger.error(f"Error during view operation: {e}")
                 logger.info("Attempting to use reshape instead...")
-                stacked_activations[:, :, 0] = layerwise_activations.reshape(
+                stacked_activations[:, :] = layerwise_activations.reshape(
                     n_batches, n_context, -1
                 )
         else:
-            stacked_activations[:, :, 0] = layerwise_activations
+            stacked_activations[:, :] = layerwise_activations
 
         return stacked_activations
 
@@ -589,7 +535,6 @@ class ActivationsStore:
         self,
         total_size: int,
         context_size: int,
-        num_layers: int,
         d_in: int,
         raise_on_epoch_end: bool,
     ) -> tuple[
@@ -606,10 +551,9 @@ class ActivationsStore:
         """
         assert self.cached_activation_dataset is not None
         # In future, could be a list of multiple hook names
-        hook_names = [self.hook_name]
-        if not set(hook_names).issubset(self.cached_activation_dataset.column_names):
+        if self.hook_name not in self.cached_activation_dataset.column_names:
             raise ValueError(
-                f"Missing columns in dataset. Expected {hook_names}, "
+                f"Missing columns in dataset. Expected {self.hook_name}, "
                 f"got {self.cached_activation_dataset.column_names}."
             )
 
@@ -622,28 +566,17 @@ class ActivationsStore:
         ds_slice = self.cached_activation_dataset[
             self.current_row_idx : self.current_row_idx + total_size
         ]
-        for hook_name in hook_names:
-            # Load activations for each hook.
-            # Usually faster to first slice dataset then pick column
-            _hook_buffer = ds_slice[hook_name]
-            if _hook_buffer.shape != (total_size, context_size, d_in):
-                raise ValueError(
-                    f"_hook_buffer has shape {_hook_buffer.shape}, "
-                    f"but expected ({total_size}, {context_size}, {d_in})."
-                )
-            new_buffer.append(_hook_buffer)
-
-        # Stack across num_layers dimension
-        # list of num_layers; shape: (total_size, context_size, d_in) -> (total_size, context_size, num_layers, d_in)
-        new_buffer = torch.stack(new_buffer, dim=2)
-        if new_buffer.shape != (total_size, context_size, num_layers, d_in):
+        # Load activations for each hook.
+        # Usually faster to first slice dataset then pick column
+        new_buffer = ds_slice[self.hook_name]
+        if new_buffer.shape != (total_size, context_size, d_in):
             raise ValueError(
                 f"new_buffer has shape {new_buffer.shape}, "
-                f"but expected ({total_size}, {context_size}, {num_layers}, {d_in})."
+                f"but expected ({total_size}, {context_size}, {d_in})."
             )
 
         self.current_row_idx += total_size
-        acts_buffer = new_buffer.reshape(total_size * context_size, num_layers, d_in)
+        acts_buffer = new_buffer.reshape(total_size * context_size, d_in)
 
         if "token_ids" not in self.cached_activation_dataset.column_names:
             return acts_buffer, None
@@ -658,7 +591,7 @@ class ActivationsStore:
         return acts_buffer, token_ids_buffer
 
     @torch.no_grad()
-    def get_buffer(
+    def get_raw_buffer(
         self,
         n_batches_in_buffer: int,
         raise_on_epoch_end: bool = False,
@@ -672,26 +605,24 @@ class ActivationsStore:
         If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
         """
         context_size = self.context_size
-        training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
         batch_size = self.store_batch_size_prompts
         d_in = self.d_in
         total_size = batch_size * n_batches_in_buffer
-        num_layers = 1
 
         if self.cached_activation_dataset is not None:
             return self._load_buffer_from_cached(
-                total_size, context_size, num_layers, d_in, raise_on_epoch_end
+                total_size, context_size, d_in, raise_on_epoch_end
             )
 
         refill_iterator = range(0, total_size, batch_size)
         # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
         new_buffer_activations = torch.zeros(
-            (total_size, training_context_size, num_layers, d_in),
+            (total_size, self.training_context_size, d_in),
             dtype=self.dtype,  # type: ignore
             device=self.device,
         )
         new_buffer_token_ids = torch.zeros(
-            (total_size, training_context_size),
+            (total_size, self.training_context_size),
             dtype=torch.long,
             device=self.device,
         )
@@ -716,106 +647,80 @@ class ActivationsStore:
                 refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
             ] = refill_batch_tokens
 
-        new_buffer_activations = new_buffer_activations.reshape(-1, num_layers, d_in)
+        new_buffer_activations = new_buffer_activations.reshape(-1, d_in)
         new_buffer_token_ids = new_buffer_token_ids.reshape(-1)
         if shuffle:
             new_buffer_activations, new_buffer_token_ids = permute_together(
                 [new_buffer_activations, new_buffer_token_ids]
             )
 
-        # every buffer should be normalized:
-        if self.normalize_activations == "expected_average_only_in":
-            new_buffer_activations = self.apply_norm_scaling_factor(
-                new_buffer_activations
-            )
-
         return (
             new_buffer_activations,
             new_buffer_token_ids,
         )
 
-    def get_data_loader(
+    def get_filtered_buffer(
         self,
-    ) -> Iterator[Any]:
-        """
-        Return a torch.utils.dataloader which you can get batches from.
-
-        Should automatically refill the buffer when it gets to n % full.
-        (better mixing if you refill and shuffle regularly).
+        n_batches_in_buffer: int,
+        raise_on_epoch_end: bool = False,
+        shuffle: bool = True,
+    ) -> torch.Tensor:
+        return _filter_buffer_acts(
+            self.get_raw_buffer(
+                n_batches_in_buffer=n_batches_in_buffer,
+                raise_on_epoch_end=raise_on_epoch_end,
+                shuffle=shuffle,
+            ),
+            self.exclude_special_tokens,
+        )
 
+    def _iterate_filtered_activations(self) -> Generator[torch.Tensor, None, None]:
         """
-
-        batch_size = self.train_batch_size_tokens
-
-        try:
-            new_samples = _filter_buffer_acts(
-                self.get_buffer(self.half_buffer_size, raise_on_epoch_end=True),
-                self.exclude_special_tokens,
-            )
-        except StopIteration:
-            warnings.warn(
-                "All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples."
-            )
-            self._storage_buffer = (
-                None  # dump the current buffer so samples do not leak between epochs
-            )
+        Iterate over the filtered tokens in the buffer.
+        """
+        while True:
             try:
-                new_samples = _filter_buffer_acts(
-                    self.get_buffer(self.half_buffer_size),
-                    self.exclude_special_tokens,
+                yield self.get_filtered_buffer(
+                    self.half_buffer_size, raise_on_epoch_end=True
                 )
             except StopIteration:
-                raise ValueError(
-                    "We were unable to fill up the buffer directly after starting a new epoch. This could indicate that there are less samples in the dataset than are required to fill up the buffer. Consider reducing batch_size or n_batches_in_buffer. "
+                warnings.warn(
+                    "All samples in the training dataset have been exhausted, beginning new epoch."
                 )
+                try:
+                    yield self.get_filtered_buffer(self.half_buffer_size)
+                except StopIteration:
+                    raise ValueError(
+                        "Unable to fill buffer after starting new epoch. Dataset may be too small."
+                    )
 
-        # 1. # create new buffer by mixing stored and new buffer
-        mixing_buffer = torch.cat(
-            [new_samples, self.storage_buffer],
-            dim=0,
-        )
-
-        mixing_buffer = mixing_buffer[torch.randperm(mixing_buffer.shape[0])]
-
-        # 2. put 50 % in storage
-        self._storage_buffer = mixing_buffer[: mixing_buffer.shape[0] // 2]
-
-        # 3. put other 50 % in a dataloader
-        return iter(
-            DataLoader(
-                # TODO: seems like a typing bug?
-                cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
-                batch_size=batch_size,
-                shuffle=True,
-            )
+    def get_data_loader(
+        self,
+    ) -> Iterator[Any]:
+        """
+        Return an auto-refilling stream of filtered and mixed activations.
+        """
+        return mixing_buffer(
+            buffer_size=self.n_batches_in_buffer * self.training_context_size,
+            batch_size=self.train_batch_size_tokens,
+            activations_loader=self._iterate_filtered_activations(),
         )
 
     def next_batch(self) -> torch.Tensor:
-        """
-        Get the next batch from the current DataLoader.
-        If the DataLoader is exhausted, refill the buffer and create a new DataLoader.
-        """
-        try:
-            # Try to get the next batch
-            return next(self.dataloader)
-        except StopIteration:
-            # If the DataLoader is exhausted, create a new one
+        """Get next batch, updating buffer if needed."""
+        return self.__next__()
+
+    # ActivationsStore should be an iterator
+    def __next__(self) -> torch.Tensor:
+        if self._dataloader is None:
             self._dataloader = self.get_data_loader()
-        return next(self.dataloader)
+        return next(self._dataloader)
+
+    def __iter__(self) -> Iterator[torch.Tensor]:
+        return self
 
     def state_dict(self) -> dict[str, torch.Tensor]:
-        result = {
-            "n_dataset_processed": torch.tensor(self.n_dataset_processed),
-        }
-        if self._storage_buffer is not None:  # first time might be None
-            result["storage_buffer_activations"] = self._storage_buffer[0]
-            if self._storage_buffer[1] is not None:
-                result["storage_buffer_tokens"] = self._storage_buffer[1]
-        if self.estimated_norm_scaling_factor is not None:
-            result["estimated_norm_scaling_factor"] = torch.tensor(
-                self.estimated_norm_scaling_factor
-            )
-        return result
+        return {"n_dataset_processed": torch.tensor(self.n_dataset_processed)}
 
     def save(self, file_path: str):
         """save the state dict to a file in safetensors format"""
@@ -0,0 +1,56 @@
+from collections.abc import Iterator
+
+import torch
+
+
+@torch.no_grad()
+def mixing_buffer(
+    buffer_size: int,
+    batch_size: int,
+    activations_loader: Iterator[torch.Tensor],
+) -> Iterator[torch.Tensor]:
+    """
+    A generator that maintains a mix of old and new activations for better training.
+    It stores half of the activations and mixes them with new ones to create batches.
+
+    Args:
+        buffer_size: Total size of the buffer (will store buffer_size/2 activations)
+        batch_size: Size of batches to return
+        activations_loader: Iterator providing new activations
+
+    Yields:
+        Batches of activations of shape (batch_size, *activation_dims)
+    """
+
+    if buffer_size < batch_size:
+        raise ValueError("Buffer size must be greater than or equal to batch size")
+
+    storage_buffer: torch.Tensor | None = None
+
+    for new_activations in activations_loader:
+        storage_buffer = (
+            new_activations
+            if storage_buffer is None
+            else torch.cat([storage_buffer, new_activations], dim=0)
+        )
+
+        if storage_buffer.shape[0] >= buffer_size:
+            # Shuffle
+            storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
+
+            num_serving_batches = max(1, storage_buffer.shape[0] // (2 * batch_size))
+            serving_cutoff = num_serving_batches * batch_size
+            serving_buffer = storage_buffer[:serving_cutoff]
+            storage_buffer = storage_buffer[serving_cutoff:]
+
+            # Yield batches from the serving_buffer
+            for batch_idx in range(num_serving_batches):
+                yield serving_buffer[
+                    batch_idx * batch_size : (batch_idx + 1) * batch_size
+                ]
+
+    # If there are any remaining activations, yield them
+    if storage_buffer is not None:
+        remaining_batches = storage_buffer.shape[0] // batch_size
+        for i in range(remaining_batches):
+            yield storage_buffer[i * batch_size : (i + 1) * batch_size]
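
A quick sketch of how the new generator behaves in isolation (illustration only, using toy tensors; the import path matches the one added to the activations store above):

import torch
from sae_lens.training.mixing_buffer import mixing_buffer

def fake_loader():
    # three refills of 8 activation vectors each, tagged by value for inspection
    for i in range(3):
        yield torch.full((8, 4), float(i))

batches = list(mixing_buffer(buffer_size=8, batch_size=4, activations_loader=fake_loader()))
# Each yielded batch has shape (4, 4); roughly half of each refill is held back,
# shuffled together with later refills, and the leftover is flushed at the end.
print(len(batches), batches[0].shape)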