sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +6 -3
- sae_lens/analysis/neuronpedia_integration.py +3 -3
- sae_lens/cache_activations_runner.py +7 -6
- sae_lens/config.py +50 -6
- sae_lens/constants.py +2 -0
- sae_lens/evals.py +39 -28
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +24 -12
- sae_lens/saes/gated_sae.py +0 -4
- sae_lens/saes/jumprelu_sae.py +4 -10
- sae_lens/saes/sae.py +121 -51
- sae_lens/saes/standard_sae.py +4 -11
- sae_lens/saes/topk_sae.py +18 -12
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +77 -174
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/sae_trainer.py +107 -98
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +1 -1
- sae_lens/util.py +19 -0
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc4.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc4.dist-info/RECORD +37 -0
- sae_lens/sae_training_runner.py +0 -237
- sae_lens/training/geometric_median.py +0 -101
- sae_lens-6.0.0rc2.dist-info/RECORD +0 -35
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc4.dist-info}/LICENSE +0 -0
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc4.dist-info}/WHEEL +0 -0
sae_lens/training/activations_store.py

@@ -1,6 +1,5 @@
 from __future__ import annotations

-import contextlib
 import json
 import os
 import warnings
@@ -16,7 +15,6 @@ from huggingface_hub.utils import HfHubHTTPError
 from jaxtyping import Float, Int
 from requests import HTTPError
 from safetensors.torch import save_file
-from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformer_lens.hook_points import HookedRootModule
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -30,6 +28,8 @@ from sae_lens.config import (
 from sae_lens.constants import DTYPE_MAP
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
+from sae_lens.training.mixing_buffer import mixing_buffer
+from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name


 # TODO: Make an activation store config class to be consistent with the rest of the code.
@@ -45,10 +45,8 @@ class ActivationsStore:
     cached_activation_dataset: Dataset | None = None
     tokens_column: Literal["tokens", "input_ids", "text", "problem"]
     hook_name: str
-    hook_layer: int
     hook_head_index: int | None
     _dataloader: Iterator[Any] | None = None
-    _storage_buffer: torch.Tensor | None = None
     exclude_special_tokens: torch.Tensor | None = None
     device: torch.device

@@ -65,7 +63,6 @@ class ActivationsStore:
             cached_activations_path=cfg.new_cached_activations_path,
             dtype=cfg.dtype,
             hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
             context_size=cfg.context_size,
             d_in=cfg.d_in,
             n_batches_in_buffer=cfg.n_batches_in_buffer,
@@ -126,7 +123,6 @@ class ActivationsStore:
             dataset=override_dataset or cfg.dataset_path,
             streaming=cfg.streaming,
             hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
             hook_head_index=cfg.hook_head_index,
             context_size=cfg.context_size,
             d_in=cfg.d_in
@@ -165,10 +161,6 @@ class ActivationsStore:
     ) -> ActivationsStore:
         if sae.cfg.metadata.hook_name is None:
             raise ValueError("hook_name is required")
-        if sae.cfg.metadata.hook_layer is None:
-            raise ValueError("hook_layer is required")
-        if sae.cfg.metadata.hook_head_index is None:
-            raise ValueError("hook_head_index is required")
         if sae.cfg.metadata.context_size is None:
             raise ValueError("context_size is required")
         if sae.cfg.metadata.prepend_bos is None:
@@ -178,7 +170,6 @@ class ActivationsStore:
             dataset=dataset,
             d_in=sae.cfg.d_in,
             hook_name=sae.cfg.metadata.hook_name,
-            hook_layer=sae.cfg.metadata.hook_layer,
             hook_head_index=sae.cfg.metadata.hook_head_index,
             context_size=sae.cfg.metadata.context_size
             if context_size is None
@@ -202,7 +193,6 @@ class ActivationsStore:
         dataset: HfDataset | str,
         streaming: bool,
         hook_name: str,
-        hook_layer: int,
         hook_head_index: int | None,
         context_size: int,
         d_in: int,
@@ -246,7 +236,6 @@ class ActivationsStore:
         )

         self.hook_name = hook_name
-        self.hook_layer = hook_layer
         self.hook_head_index = hook_head_index
         self.context_size = context_size
         self.d_in = d_in
@@ -262,12 +251,11 @@ class ActivationsStore:
         self.cached_activations_path = cached_activations_path
         self.autocast_lm = autocast_lm
         self.seqpos_slice = seqpos_slice
+        self.training_context_size = len(range(context_size)[slice(*seqpos_slice)])
         self.exclude_special_tokens = exclude_special_tokens

         self.n_dataset_processed = 0

-        self.estimated_norm_scaling_factor = None
-
         # Check if dataset is tokenized
         dataset_sample = next(iter(self.dataset))

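The added `training_context_size` attribute is simply the number of token positions that survive `seqpos_slice`. A quick illustration with hypothetical values:

```python
# Illustration (hypothetical values) of the training_context_size expression added above:
# count the token positions kept after applying seqpos_slice to the context window.
context_size = 128
seqpos_slice = (1, None)  # e.g. drop position 0
training_context_size = len(range(context_size)[slice(*seqpos_slice)])
assert training_context_size == 127
```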
@@ -432,30 +420,6 @@ class ActivationsStore:

         return activations_dataset

-    def set_norm_scaling_factor_if_needed(self):
-        if (
-            self.normalize_activations == "expected_average_only_in"
-            and self.estimated_norm_scaling_factor is None
-        ):
-            self.estimated_norm_scaling_factor = self.estimate_norm_scaling_factor()
-
-    def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
-        if self.estimated_norm_scaling_factor is None:
-            raise ValueError(
-                "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
-            )
-        return activations * self.estimated_norm_scaling_factor
-
-    def unscale(self, activations: torch.Tensor) -> torch.Tensor:
-        if self.estimated_norm_scaling_factor is None:
-            raise ValueError(
-                "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
-            )
-        return activations / self.estimated_norm_scaling_factor
-
-    def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
-        return (self.d_in**0.5) / activations.norm(dim=-1).mean()
-
     @torch.no_grad()
     def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)):
         norms_per_batch = []
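The deleted helpers scaled activations by `sqrt(d_in) / mean(||activation||)` (see the removed `get_norm_scaling_factor`); judging by the file summary above, this responsibility presumably moves into the new `sae_lens/training/activation_scaler.py`, which is not shown in this diff. For reference, a standalone illustration of the factor itself:

```python
# Standalone illustration of the scaling factor the deleted helpers computed:
# sqrt(d_in) / mean L2 norm, so scaled activations have mean norm ~= sqrt(d_in).
import torch

d_in = 16
activations = torch.randn(1024, d_in)
scaling_factor = (d_in**0.5) / activations.norm(dim=-1).mean()

scaled = activations * scaling_factor
assert torch.isclose(scaled.norm(dim=-1).mean(), torch.tensor(d_in**0.5), rtol=1e-3)
```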
@@ -490,21 +454,6 @@ class ActivationsStore:
         """
         self.iterable_dataset = iter(self.dataset)

-    @property
-    def storage_buffer(self) -> torch.Tensor:
-        if self._storage_buffer is None:
-            self._storage_buffer = _filter_buffer_acts(
-                self.get_buffer(self.half_buffer_size), self.exclude_special_tokens
-            )
-
-        return self._storage_buffer
-
-    @property
-    def dataloader(self) -> Iterator[Any]:
-        if self._dataloader is None:
-            self._dataloader = self.get_data_loader()
-        return self._dataloader
-
     def get_batch_tokens(
         self, batch_size: int | None = None, raise_at_epoch_end: bool = False
     ):
@@ -537,22 +486,17 @@ class ActivationsStore:

         d_in may result from a concatenated head dimension.
         """
-
-
-
-
-
-                dtype=torch.bfloat16,
-                enabled=self.autocast_lm,
-            )
-        else:
-            autocast_if_enabled = contextlib.nullcontext()
-
-        with autocast_if_enabled:
+        with torch.autocast(
+            device_type="cuda",
+            dtype=torch.bfloat16,
+            enabled=self.autocast_lm,
+        ):
             layerwise_activations_cache = self.model.run_with_cache(
                 batch_tokens,
                 names_filter=[self.hook_name],
-                stop_at_layer=
+                stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(
+                    self.hook_name
+                ),
                 prepend_bos=False,
                 **self.model_kwargs,
             )[1]
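The new `stop_at_layer` argument is now derived from the hook name via `extract_stop_at_layer_from_tlens_hook_name` (imported at the top of the file; its implementation lives in `sae_lens/util.py` and is not shown in this hunk), replacing the deleted `hook_layer` plumbing. A rough, hypothetical sketch of the idea, not the actual sae_lens implementation:

```python
# Hypothetical stand-in sketching how a stop_at_layer value could be derived from a
# TransformerLens hook name; the real sae_lens/util.py implementation may differ.
import re

def stop_at_layer_from_hook_name(hook_name: str) -> int | None:
    match = re.search(r"blocks\.(\d+)\.", hook_name)
    # Stop one layer past the hooked block so the hook still fires; None = run all layers.
    return int(match.group(1)) + 1 if match else None

assert stop_at_layer_from_hook_name("blocks.8.hook_resid_post") == 9
assert stop_at_layer_from_hook_name("hook_embed") is None
```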
@@ -563,25 +507,25 @@ class ActivationsStore:

         n_batches, n_context = layerwise_activations.shape[:2]

-        stacked_activations = torch.zeros((n_batches, n_context,
+        stacked_activations = torch.zeros((n_batches, n_context, self.d_in))

         if self.hook_head_index is not None:
-            stacked_activations[:,
+            stacked_activations[:, :] = layerwise_activations[
                 :, :, self.hook_head_index
             ]
         elif layerwise_activations.ndim > 3:  # if we have a head dimension
             try:
-                stacked_activations[:,
+                stacked_activations[:, :] = layerwise_activations.view(
                     n_batches, n_context, -1
                 )
             except RuntimeError as e:
                 logger.error(f"Error during view operation: {e}")
                 logger.info("Attempting to use reshape instead...")
-                stacked_activations[:,
+                stacked_activations[:, :] = layerwise_activations.reshape(
                     n_batches, n_context, -1
                 )
         else:
-            stacked_activations[:,
+            stacked_activations[:, :] = layerwise_activations

         return stacked_activations

@@ -589,7 +533,6 @@ class ActivationsStore:
         self,
         total_size: int,
         context_size: int,
-        num_layers: int,
         d_in: int,
         raise_on_epoch_end: bool,
     ) -> tuple[
@@ -606,10 +549,9 @@ class ActivationsStore:
         """
         assert self.cached_activation_dataset is not None
         # In future, could be a list of multiple hook names
-
-        if not set(hook_names).issubset(self.cached_activation_dataset.column_names):
+        if self.hook_name not in self.cached_activation_dataset.column_names:
             raise ValueError(
-                f"Missing columns in dataset. Expected {
+                f"Missing columns in dataset. Expected {self.hook_name}, "
                 f"got {self.cached_activation_dataset.column_names}."
             )

@@ -622,28 +564,17 @@ class ActivationsStore:
         ds_slice = self.cached_activation_dataset[
             self.current_row_idx : self.current_row_idx + total_size
         ]
-        for
-
-
-
-            if _hook_buffer.shape != (total_size, context_size, d_in):
-                raise ValueError(
-                    f"_hook_buffer has shape {_hook_buffer.shape}, "
-                    f"but expected ({total_size}, {context_size}, {d_in})."
-                )
-            new_buffer.append(_hook_buffer)
-
-        # Stack across num_layers dimension
-        # list of num_layers; shape: (total_size, context_size, d_in) -> (total_size, context_size, num_layers, d_in)
-        new_buffer = torch.stack(new_buffer, dim=2)
-        if new_buffer.shape != (total_size, context_size, num_layers, d_in):
+        # Load activations for each hook.
+        # Usually faster to first slice dataset then pick column
+        new_buffer = ds_slice[self.hook_name]
+        if new_buffer.shape != (total_size, context_size, d_in):
             raise ValueError(
                 f"new_buffer has shape {new_buffer.shape}, "
-                f"but expected ({total_size}, {context_size}, {
+                f"but expected ({total_size}, {context_size}, {d_in})."
             )

         self.current_row_idx += total_size
-        acts_buffer = new_buffer.reshape(total_size * context_size,
+        acts_buffer = new_buffer.reshape(total_size * context_size, d_in)

         if "token_ids" not in self.cached_activation_dataset.column_names:
             return acts_buffer, None
@@ -658,7 +589,7 @@ class ActivationsStore:
         return acts_buffer, token_ids_buffer

     @torch.no_grad()
-    def
+    def get_raw_buffer(
         self,
         n_batches_in_buffer: int,
         raise_on_epoch_end: bool = False,
@@ -672,26 +603,24 @@ class ActivationsStore:
         If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
         """
         context_size = self.context_size
-        training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
         batch_size = self.store_batch_size_prompts
         d_in = self.d_in
         total_size = batch_size * n_batches_in_buffer
-        num_layers = 1

         if self.cached_activation_dataset is not None:
             return self._load_buffer_from_cached(
-                total_size, context_size,
+                total_size, context_size, d_in, raise_on_epoch_end
             )

         refill_iterator = range(0, total_size, batch_size)
         # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
         new_buffer_activations = torch.zeros(
-            (total_size, training_context_size,
+            (total_size, self.training_context_size, d_in),
             dtype=self.dtype,  # type: ignore
             device=self.device,
         )
         new_buffer_token_ids = torch.zeros(
-            (total_size, training_context_size),
+            (total_size, self.training_context_size),
             dtype=torch.long,
             device=self.device,
         )
@@ -716,106 +645,80 @@ class ActivationsStore:
             refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
         ] = refill_batch_tokens

-        new_buffer_activations = new_buffer_activations.reshape(-1,
+        new_buffer_activations = new_buffer_activations.reshape(-1, d_in)
         new_buffer_token_ids = new_buffer_token_ids.reshape(-1)
         if shuffle:
             new_buffer_activations, new_buffer_token_ids = permute_together(
                 [new_buffer_activations, new_buffer_token_ids]
             )

-        # every buffer should be normalized:
-        if self.normalize_activations == "expected_average_only_in":
-            new_buffer_activations = self.apply_norm_scaling_factor(
-                new_buffer_activations
-            )
-
         return (
             new_buffer_activations,
             new_buffer_token_ids,
         )

-    def
+    def get_filtered_buffer(
         self,
-
-
-
-
-
-
+        n_batches_in_buffer: int,
+        raise_on_epoch_end: bool = False,
+        shuffle: bool = True,
+    ) -> torch.Tensor:
+        return _filter_buffer_acts(
+            self.get_raw_buffer(
+                n_batches_in_buffer=n_batches_in_buffer,
+                raise_on_epoch_end=raise_on_epoch_end,
+                shuffle=shuffle,
+            ),
+            self.exclude_special_tokens,
+        )

+    def _iterate_filtered_activations(self) -> Generator[torch.Tensor, None, None]:
         """
-
-
-
-        try:
-            new_samples = _filter_buffer_acts(
-                self.get_buffer(self.half_buffer_size, raise_on_epoch_end=True),
-                self.exclude_special_tokens,
-            )
-        except StopIteration:
-            warnings.warn(
-                "All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples."
-            )
-            self._storage_buffer = (
-                None  # dump the current buffer so samples do not leak between epochs
-            )
+        Iterate over the filtered tokens in the buffer.
+        """
+        while True:
             try:
-
-                self.
-                self.exclude_special_tokens,
+                yield self.get_filtered_buffer(
+                    self.half_buffer_size, raise_on_epoch_end=True
                 )
             except StopIteration:
-
-                "
+                warnings.warn(
+                    "All samples in the training dataset have been exhausted, beginning new epoch."
                 )
+                try:
+                    yield self.get_filtered_buffer(self.half_buffer_size)
+                except StopIteration:
+                    raise ValueError(
+                        "Unable to fill buffer after starting new epoch. Dataset may be too small."
+                    )

-
-
-
-
-
-
-
-
-
-
-
-        # 3. put other 50 % in a dataloader
-        return iter(
-            DataLoader(
-                # TODO: seems like a typing bug?
-                cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
-                batch_size=batch_size,
-                shuffle=True,
-            )
+    def get_data_loader(
+        self,
+    ) -> Iterator[Any]:
+        """
+        Return an auto-refilling stream of filtered and mixed activations.
+        """
+        return mixing_buffer(
+            buffer_size=self.n_batches_in_buffer * self.training_context_size,
+            batch_size=self.train_batch_size_tokens,
+            activations_loader=self._iterate_filtered_activations(),
         )

     def next_batch(self) -> torch.Tensor:
-        """
-
-
-
-
-
-            return next(self.dataloader)
-        except StopIteration:
-            # If the DataLoader is exhausted, create a new one
+        """Get next batch, updating buffer if needed."""
+        return self.__next__()
+
+    # ActivationsStore should be an iterator
+    def __next__(self) -> torch.Tensor:
+        if self._dataloader is None:
             self._dataloader = self.get_data_loader()
-
+        return next(self._dataloader)
+
+    def __iter__(self) -> Iterator[torch.Tensor]:
+        return self

     def state_dict(self) -> dict[str, torch.Tensor]:
-
-            "n_dataset_processed": torch.tensor(self.n_dataset_processed),
-        }
-        if self._storage_buffer is not None:  # first time might be None
-            result["storage_buffer_activations"] = self._storage_buffer[0]
-            if self._storage_buffer[1] is not None:
-                result["storage_buffer_tokens"] = self._storage_buffer[1]
-        if self.estimated_norm_scaling_factor is not None:
-            result["estimated_norm_scaling_factor"] = torch.tensor(
-                self.estimated_norm_scaling_factor
-            )
-        return result
+        return {"n_dataset_processed": torch.tensor(self.n_dataset_processed)}

     def save(self, file_path: str):
         """save the state dict to a file in safetensors format"""
sae_lens/training/mixing_buffer.py

@@ -0,0 +1,56 @@
+from collections.abc import Iterator
+
+import torch
+
+
+@torch.no_grad()
+def mixing_buffer(
+    buffer_size: int,
+    batch_size: int,
+    activations_loader: Iterator[torch.Tensor],
+) -> Iterator[torch.Tensor]:
+    """
+    A generator that maintains a mix of old and new activations for better training.
+    It stores half of the activations and mixes them with new ones to create batches.
+
+    Args:
+        buffer_size: Total size of the buffer (will store buffer_size/2 activations)
+        batch_size: Size of batches to return
+        activations_loader: Iterator providing new activations
+
+    Yields:
+        Batches of activations of shape (batch_size, *activation_dims)
+    """
+
+    if buffer_size < batch_size:
+        raise ValueError("Buffer size must be greater than or equal to batch size")
+
+    storage_buffer: torch.Tensor | None = None
+
+    for new_activations in activations_loader:
+        storage_buffer = (
+            new_activations
+            if storage_buffer is None
+            else torch.cat([storage_buffer, new_activations], dim=0)
+        )
+
+        if storage_buffer.shape[0] >= buffer_size:
+            # Shuffle
+            storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
+
+            num_serving_batches = max(1, storage_buffer.shape[0] // (2 * batch_size))
+            serving_cutoff = num_serving_batches * batch_size
+            serving_buffer = storage_buffer[:serving_cutoff]
+            storage_buffer = storage_buffer[serving_cutoff:]
+
+            # Yield batches from the serving_buffer
+            for batch_idx in range(num_serving_batches):
+                yield serving_buffer[
+                    batch_idx * batch_size : (batch_idx + 1) * batch_size
+                ]
+
+    # If there are any remaining activations, yield them
+    if storage_buffer is not None:
+        remaining_batches = storage_buffer.shape[0] // batch_size
+        for i in range(remaining_batches):
+            yield storage_buffer[i * batch_size : (i + 1) * batch_size]
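For reference, a minimal usage sketch of the new generator, assuming sae-lens >= 6.0.0rc4 is installed; the fake loader below is a stand-in for `ActivationsStore._iterate_filtered_activations()`:

```python
# Minimal usage sketch of mixing_buffer with a synthetic activation stream.
import torch

from sae_lens.training.mixing_buffer import mixing_buffer

def fake_activation_chunks(n_chunks: int = 8, rows_per_chunk: int = 64, d_in: int = 16):
    # Each chunk is a (rows, d_in) tensor, like a filtered half-buffer of activations.
    for _ in range(n_chunks):
        yield torch.randn(rows_per_chunk, d_in)

for batch in mixing_buffer(
    buffer_size=128,  # accumulate at least this many rows before serving
    batch_size=32,    # rows per yielded training batch
    activations_loader=fake_activation_chunks(),
):
    assert batch.shape == (32, 16)  # half the buffer is served, half is carried over
```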