sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +55 -18
- sae_lens/analysis/hooked_sae_transformer.py +10 -10
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +105 -235
- sae_lens/constants.py +20 -0
- sae_lens/evals.py +34 -31
- sae_lens/{sae_training_runner.py → llm_sae_training_runner.py} +103 -70
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +36 -10
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +70 -59
- sae_lens/saes/jumprelu_sae.py +58 -72
- sae_lens/saes/sae.py +248 -273
- sae_lens/saes/standard_sae.py +75 -57
- sae_lens/saes/topk_sae.py +72 -83
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +105 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +134 -158
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +11 -5
- sae_lens/util.py +47 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc3.dist-info/RECORD +38 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/WHEEL +1 -1
- sae_lens/regsitry.py +0 -34
- sae_lens-6.0.0rc1.dist-info/RECORD +0 -32
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc3.dist-info}/LICENSE +0 -0
sae_lens/training/activations_store.py

@@ -1,6 +1,5 @@
 from __future__ import annotations

-import contextlib
 import json
 import os
 import warnings
@@ -16,20 +15,21 @@ from huggingface_hub.utils import HfHubHTTPError
 from jaxtyping import Float, Int
 from requests import HTTPError
 from safetensors.torch import save_file
-from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformer_lens.hook_points import HookedRootModule
 from transformers import AutoTokenizer, PreTrainedTokenizerBase

 from sae_lens import logger
 from sae_lens.config import (
-    DTYPE_MAP,
     CacheActivationsRunnerConfig,
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.
+from sae_lens.constants import DTYPE_MAP
+from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
+from sae_lens.training.mixing_buffer import mixing_buffer
+from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name


 # TODO: Make an activation store config class to be consistent with the rest of the code.
@@ -45,10 +45,8 @@ class ActivationsStore:
     cached_activation_dataset: Dataset | None = None
     tokens_column: Literal["tokens", "input_ids", "text", "problem"]
     hook_name: str
-    hook_layer: int
     hook_head_index: int | None
     _dataloader: Iterator[Any] | None = None
-    _storage_buffer: torch.Tensor | None = None
     exclude_special_tokens: torch.Tensor | None = None
     device: torch.device

@@ -65,7 +63,6 @@ class ActivationsStore:
             cached_activations_path=cfg.new_cached_activations_path,
             dtype=cfg.dtype,
             hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
             context_size=cfg.context_size,
             d_in=cfg.d_in,
             n_batches_in_buffer=cfg.n_batches_in_buffer,
@@ -91,7 +88,8 @@ class ActivationsStore:
     def from_config(
         cls,
         model: HookedRootModule,
-        cfg: LanguageModelSAERunnerConfig
+        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]
+        | CacheActivationsRunnerConfig,
         override_dataset: HfDataset | None = None,
     ) -> ActivationsStore:
         if isinstance(cfg, CacheActivationsRunnerConfig):
@@ -125,16 +123,17 @@ class ActivationsStore:
             dataset=override_dataset or cfg.dataset_path,
             streaming=cfg.streaming,
             hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
             hook_head_index=cfg.hook_head_index,
             context_size=cfg.context_size,
-            d_in=cfg.d_in
+            d_in=cfg.d_in
+            if isinstance(cfg, CacheActivationsRunnerConfig)
+            else cfg.sae.d_in,
             n_batches_in_buffer=cfg.n_batches_in_buffer,
             total_training_tokens=cfg.training_tokens,
             store_batch_size_prompts=cfg.store_batch_size_prompts,
             train_batch_size_tokens=cfg.train_batch_size_tokens,
             prepend_bos=cfg.prepend_bos,
-            normalize_activations=cfg.normalize_activations,
+            normalize_activations=cfg.sae.normalize_activations,
             device=device,
             dtype=cfg.dtype,
             cached_activations_path=cached_activations_path,
@@ -149,9 +148,10 @@ class ActivationsStore:
     def from_sae(
         cls,
         model: HookedRootModule,
-        sae: SAE,
+        sae: SAE[T_SAE_CONFIG],
+        dataset: HfDataset | str,
+        dataset_trust_remote_code: bool = False,
         context_size: int | None = None,
-        dataset: HfDataset | str | None = None,
         streaming: bool = True,
         store_batch_size_prompts: int = 8,
         n_batches_in_buffer: int = 8,
@@ -159,25 +159,34 @@ class ActivationsStore:
         total_tokens: int = 10**9,
         device: str = "cpu",
     ) -> ActivationsStore:
+        if sae.cfg.metadata.hook_name is None:
+            raise ValueError("hook_name is required")
+        if sae.cfg.metadata.hook_head_index is None:
+            raise ValueError("hook_head_index is required")
+        if sae.cfg.metadata.context_size is None:
+            raise ValueError("context_size is required")
+        if sae.cfg.metadata.prepend_bos is None:
+            raise ValueError("prepend_bos is required")
         return cls(
             model=model,
-            dataset=
+            dataset=dataset,
             d_in=sae.cfg.d_in,
-            hook_name=sae.cfg.hook_name,
-
-
-
-
+            hook_name=sae.cfg.metadata.hook_name,
+            hook_head_index=sae.cfg.metadata.hook_head_index,
+            context_size=sae.cfg.metadata.context_size
+            if context_size is None
+            else context_size,
+            prepend_bos=sae.cfg.metadata.prepend_bos,
             streaming=streaming,
             store_batch_size_prompts=store_batch_size_prompts,
             train_batch_size_tokens=train_batch_size_tokens,
             n_batches_in_buffer=n_batches_in_buffer,
             total_training_tokens=total_tokens,
             normalize_activations=sae.cfg.normalize_activations,
-            dataset_trust_remote_code=
+            dataset_trust_remote_code=dataset_trust_remote_code,
             dtype=sae.cfg.dtype,
             device=torch.device(device),
-            seqpos_slice=sae.cfg.seqpos_slice or (None,),
+            seqpos_slice=sae.cfg.metadata.seqpos_slice or (None,),
         )

     def __init__(
@@ -186,7 +195,6 @@ class ActivationsStore:
         dataset: HfDataset | str,
         streaming: bool,
         hook_name: str,
-        hook_layer: int,
         hook_head_index: int | None,
         context_size: int,
         d_in: int,
@@ -230,7 +238,6 @@ class ActivationsStore:
         )

         self.hook_name = hook_name
-        self.hook_layer = hook_layer
         self.hook_head_index = hook_head_index
         self.context_size = context_size
         self.d_in = d_in
@@ -246,12 +253,11 @@ class ActivationsStore:
         self.cached_activations_path = cached_activations_path
         self.autocast_lm = autocast_lm
         self.seqpos_slice = seqpos_slice
+        self.training_context_size = len(range(context_size)[slice(*seqpos_slice)])
         self.exclude_special_tokens = exclude_special_tokens

         self.n_dataset_processed = 0

-        self.estimated_norm_scaling_factor = None
-
         # Check if dataset is tokenized
         dataset_sample = next(iter(self.dataset))

@@ -416,30 +422,6 @@ class ActivationsStore:

         return activations_dataset

-    def set_norm_scaling_factor_if_needed(self):
-        if (
-            self.normalize_activations == "expected_average_only_in"
-            and self.estimated_norm_scaling_factor is None
-        ):
-            self.estimated_norm_scaling_factor = self.estimate_norm_scaling_factor()
-
-    def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
-        if self.estimated_norm_scaling_factor is None:
-            raise ValueError(
-                "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
-            )
-        return activations * self.estimated_norm_scaling_factor
-
-    def unscale(self, activations: torch.Tensor) -> torch.Tensor:
-        if self.estimated_norm_scaling_factor is None:
-            raise ValueError(
-                "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
-            )
-        return activations / self.estimated_norm_scaling_factor
-
-    def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
-        return (self.d_in**0.5) / activations.norm(dim=-1).mean()
-
     @torch.no_grad()
     def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)):
         norms_per_batch = []
@@ -474,21 +456,6 @@ class ActivationsStore:
         """
         self.iterable_dataset = iter(self.dataset)

-    @property
-    def storage_buffer(self) -> torch.Tensor:
-        if self._storage_buffer is None:
-            self._storage_buffer = _filter_buffer_acts(
-                self.get_buffer(self.half_buffer_size), self.exclude_special_tokens
-            )
-
-        return self._storage_buffer
-
-    @property
-    def dataloader(self) -> Iterator[Any]:
-        if self._dataloader is None:
-            self._dataloader = self.get_data_loader()
-        return self._dataloader
-
     def get_batch_tokens(
         self, batch_size: int | None = None, raise_at_epoch_end: bool = False
     ):
@@ -521,22 +488,17 @@ class ActivationsStore:

         d_in may result from a concatenated head dimension.
         """
-
-
-
-
-
-                dtype=torch.bfloat16,
-                enabled=self.autocast_lm,
-            )
-        else:
-            autocast_if_enabled = contextlib.nullcontext()
-
-        with autocast_if_enabled:
+        with torch.autocast(
+            device_type="cuda",
+            dtype=torch.bfloat16,
+            enabled=self.autocast_lm,
+        ):
             layerwise_activations_cache = self.model.run_with_cache(
                 batch_tokens,
                 names_filter=[self.hook_name],
-                stop_at_layer=
+                stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(
+                    self.hook_name
+                ),
                 prepend_bos=False,
                 **self.model_kwargs,
             )[1]
@@ -547,25 +509,25 @@ class ActivationsStore:

         n_batches, n_context = layerwise_activations.shape[:2]

-        stacked_activations = torch.zeros((n_batches, n_context,
+        stacked_activations = torch.zeros((n_batches, n_context, self.d_in))

         if self.hook_head_index is not None:
-            stacked_activations[:,
+            stacked_activations[:, :] = layerwise_activations[
                 :, :, self.hook_head_index
             ]
         elif layerwise_activations.ndim > 3:  # if we have a head dimension
             try:
-                stacked_activations[:,
+                stacked_activations[:, :] = layerwise_activations.view(
                     n_batches, n_context, -1
                 )
             except RuntimeError as e:
                 logger.error(f"Error during view operation: {e}")
                 logger.info("Attempting to use reshape instead...")
-                stacked_activations[:,
+                stacked_activations[:, :] = layerwise_activations.reshape(
                     n_batches, n_context, -1
                 )
         else:
-            stacked_activations[:,
+            stacked_activations[:, :] = layerwise_activations

         return stacked_activations

@@ -573,7 +535,6 @@ class ActivationsStore:
         self,
         total_size: int,
         context_size: int,
-        num_layers: int,
         d_in: int,
         raise_on_epoch_end: bool,
     ) -> tuple[
@@ -590,10 +551,9 @@ class ActivationsStore:
         """
         assert self.cached_activation_dataset is not None
         # In future, could be a list of multiple hook names
-
-        if not set(hook_names).issubset(self.cached_activation_dataset.column_names):
+        if self.hook_name not in self.cached_activation_dataset.column_names:
             raise ValueError(
-                f"Missing columns in dataset. Expected {
+                f"Missing columns in dataset. Expected {self.hook_name}, "
                 f"got {self.cached_activation_dataset.column_names}."
             )

@@ -606,28 +566,17 @@ class ActivationsStore:
         ds_slice = self.cached_activation_dataset[
             self.current_row_idx : self.current_row_idx + total_size
         ]
-        for
-
-
-
-            if _hook_buffer.shape != (total_size, context_size, d_in):
-                raise ValueError(
-                    f"_hook_buffer has shape {_hook_buffer.shape}, "
-                    f"but expected ({total_size}, {context_size}, {d_in})."
-                )
-            new_buffer.append(_hook_buffer)
-
-        # Stack across num_layers dimension
-        # list of num_layers; shape: (total_size, context_size, d_in) -> (total_size, context_size, num_layers, d_in)
-        new_buffer = torch.stack(new_buffer, dim=2)
-        if new_buffer.shape != (total_size, context_size, num_layers, d_in):
+        # Load activations for each hook.
+        # Usually faster to first slice dataset then pick column
+        new_buffer = ds_slice[self.hook_name]
+        if new_buffer.shape != (total_size, context_size, d_in):
             raise ValueError(
                 f"new_buffer has shape {new_buffer.shape}, "
-                f"but expected ({total_size}, {context_size}, {
+                f"but expected ({total_size}, {context_size}, {d_in})."
             )

         self.current_row_idx += total_size
-        acts_buffer = new_buffer.reshape(total_size * context_size,
+        acts_buffer = new_buffer.reshape(total_size * context_size, d_in)

         if "token_ids" not in self.cached_activation_dataset.column_names:
             return acts_buffer, None
@@ -642,7 +591,7 @@ class ActivationsStore:
         return acts_buffer, token_ids_buffer

     @torch.no_grad()
-    def
+    def get_raw_buffer(
         self,
         n_batches_in_buffer: int,
         raise_on_epoch_end: bool = False,
@@ -656,26 +605,24 @@ class ActivationsStore:
         If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
         """
         context_size = self.context_size
-        training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
         batch_size = self.store_batch_size_prompts
         d_in = self.d_in
         total_size = batch_size * n_batches_in_buffer
-        num_layers = 1

         if self.cached_activation_dataset is not None:
             return self._load_buffer_from_cached(
-                total_size, context_size,
+                total_size, context_size, d_in, raise_on_epoch_end
             )

         refill_iterator = range(0, total_size, batch_size)
         # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
         new_buffer_activations = torch.zeros(
-            (total_size, training_context_size,
+            (total_size, self.training_context_size, d_in),
             dtype=self.dtype,  # type: ignore
             device=self.device,
         )
         new_buffer_token_ids = torch.zeros(
-            (total_size, training_context_size),
+            (total_size, self.training_context_size),
             dtype=torch.long,
             device=self.device,
         )
@@ -700,106 +647,80 @@ class ActivationsStore:
             refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
         ] = refill_batch_tokens

-        new_buffer_activations = new_buffer_activations.reshape(-1,
+        new_buffer_activations = new_buffer_activations.reshape(-1, d_in)
         new_buffer_token_ids = new_buffer_token_ids.reshape(-1)
         if shuffle:
             new_buffer_activations, new_buffer_token_ids = permute_together(
                 [new_buffer_activations, new_buffer_token_ids]
             )

-        # every buffer should be normalized:
-        if self.normalize_activations == "expected_average_only_in":
-            new_buffer_activations = self.apply_norm_scaling_factor(
-                new_buffer_activations
-            )
-
         return (
             new_buffer_activations,
             new_buffer_token_ids,
         )

-    def
+    def get_filtered_buffer(
         self,
-
-
-
-
-
-
+        n_batches_in_buffer: int,
+        raise_on_epoch_end: bool = False,
+        shuffle: bool = True,
+    ) -> torch.Tensor:
+        return _filter_buffer_acts(
+            self.get_raw_buffer(
+                n_batches_in_buffer=n_batches_in_buffer,
+                raise_on_epoch_end=raise_on_epoch_end,
+                shuffle=shuffle,
+            ),
+            self.exclude_special_tokens,
+        )

+    def _iterate_filtered_activations(self) -> Generator[torch.Tensor, None, None]:
         """
-
-
-
-        try:
-            new_samples = _filter_buffer_acts(
-                self.get_buffer(self.half_buffer_size, raise_on_epoch_end=True),
-                self.exclude_special_tokens,
-            )
-        except StopIteration:
-            warnings.warn(
-                "All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples."
-            )
-            self._storage_buffer = (
-                None  # dump the current buffer so samples do not leak between epochs
-            )
+        Iterate over the filtered tokens in the buffer.
+        """
+        while True:
             try:
-
-                self.
-                self.exclude_special_tokens,
+                yield self.get_filtered_buffer(
+                    self.half_buffer_size, raise_on_epoch_end=True
                 )
             except StopIteration:
-
-                "
+                warnings.warn(
+                    "All samples in the training dataset have been exhausted, beginning new epoch."
                 )
+                try:
+                    yield self.get_filtered_buffer(self.half_buffer_size)
+                except StopIteration:
+                    raise ValueError(
+                        "Unable to fill buffer after starting new epoch. Dataset may be too small."
+                    )

-
-
-
-
-
-
-
-
-
-
-        # 3. put other 50 % in a dataloader
-        return iter(
-            DataLoader(
-                # TODO: seems like a typing bug?
-                cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
-                batch_size=batch_size,
-                shuffle=True,
-            )
+    def get_data_loader(
+        self,
+    ) -> Iterator[Any]:
+        """
+        Return an auto-refilling stream of filtered and mixed activations.
+        """
+        return mixing_buffer(
+            buffer_size=self.n_batches_in_buffer * self.training_context_size,
+            batch_size=self.train_batch_size_tokens,
+            activations_loader=self._iterate_filtered_activations(),
         )

     def next_batch(self) -> torch.Tensor:
-        """
-
-
-
-
-
-            return next(self.dataloader)
-        except StopIteration:
-            # If the DataLoader is exhausted, create a new one
+        """Get next batch, updating buffer if needed."""
+        return self.__next__()
+
+    # ActivationsStore should be an iterator
+    def __next__(self) -> torch.Tensor:
+        if self._dataloader is None:
             self._dataloader = self.get_data_loader()
-
+        return next(self._dataloader)
+
+    def __iter__(self) -> Iterator[torch.Tensor]:
+        return self

     def state_dict(self) -> dict[str, torch.Tensor]:
-
-            "n_dataset_processed": torch.tensor(self.n_dataset_processed),
-        }
-        if self._storage_buffer is not None:  # first time might be None
-            result["storage_buffer_activations"] = self._storage_buffer[0]
-            if self._storage_buffer[1] is not None:
-                result["storage_buffer_tokens"] = self._storage_buffer[1]
-        if self.estimated_norm_scaling_factor is not None:
-            result["estimated_norm_scaling_factor"] = torch.tensor(
-                self.estimated_norm_scaling_factor
-            )
-        return result
+        return {"n_dataset_processed": torch.tensor(self.n_dataset_processed)}

     def save(self, file_path: str):
         """save the state dict to a file in safetensors format"""
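The refactor above replaces the old storage-buffer and DataLoader plumbing with an iterator protocol: `next_batch()` now delegates to `__next__`, and `get_data_loader()` returns the stream produced by `mixing_buffer` (the new module shown below). A minimal usage sketch of that new surface follows; `model`, `cfg`, and `train_step` are placeholders assumed to be built elsewhere, not part of this diff:

```python
from sae_lens.training.activations_store import ActivationsStore

# Hypothetical sketch of the rc3 iterator-style API. `model` is assumed to be a
# HookedRootModule and `cfg` a LanguageModelSAERunnerConfig constructed elsewhere.
store = ActivationsStore.from_config(model, cfg)

batch = store.next_batch()            # equivalent to next(store)
for step, acts in enumerate(store):   # the store is now an iterator
    train_step(acts)                  # acts: (train_batch_size_tokens, d_in); train_step is a placeholder
    if step == 100:
        break
```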
sae_lens/training/mixing_buffer.py (new file)

@@ -0,0 +1,56 @@
+from collections.abc import Iterator
+
+import torch
+
+
+@torch.no_grad()
+def mixing_buffer(
+    buffer_size: int,
+    batch_size: int,
+    activations_loader: Iterator[torch.Tensor],
+) -> Iterator[torch.Tensor]:
+    """
+    A generator that maintains a mix of old and new activations for better training.
+    It stores half of the activations and mixes them with new ones to create batches.
+
+    Args:
+        buffer_size: Total size of the buffer (will store buffer_size/2 activations)
+        batch_size: Size of batches to return
+        activations_loader: Iterator providing new activations
+
+    Yields:
+        Batches of activations of shape (batch_size, *activation_dims)
+    """
+
+    if buffer_size < batch_size:
+        raise ValueError("Buffer size must be greater than or equal to batch size")
+
+    storage_buffer: torch.Tensor | None = None
+
+    for new_activations in activations_loader:
+        storage_buffer = (
+            new_activations
+            if storage_buffer is None
+            else torch.cat([storage_buffer, new_activations], dim=0)
+        )
+
+        if storage_buffer.shape[0] >= buffer_size:
+            # Shuffle
+            storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
+
+            num_serving_batches = max(1, storage_buffer.shape[0] // (2 * batch_size))
+            serving_cutoff = num_serving_batches * batch_size
+            serving_buffer = storage_buffer[:serving_cutoff]
+            storage_buffer = storage_buffer[serving_cutoff:]
+
+            # Yield batches from the serving_buffer
+            for batch_idx in range(num_serving_batches):
+                yield serving_buffer[
+                    batch_idx * batch_size : (batch_idx + 1) * batch_size
+                ]
+
+    # If there are any remaining activations, yield them
+    if storage_buffer is not None:
+        remaining_batches = storage_buffer.shape[0] // batch_size
+        for i in range(remaining_batches):
+            yield storage_buffer[i * batch_size : (i + 1) * batch_size]