sae-lens 5.11.0__py3-none-any.whl → 6.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +60 -7
- sae_lens/analysis/hooked_sae_transformer.py +12 -12
- sae_lens/analysis/neuronpedia_integration.py +16 -14
- sae_lens/cache_activations_runner.py +9 -7
- sae_lens/config.py +170 -258
- sae_lens/constants.py +21 -0
- sae_lens/evals.py +59 -44
- sae_lens/llm_sae_training_runner.py +377 -0
- sae_lens/load_model.py +52 -4
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +85 -32
- sae_lens/registry.py +49 -0
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +254 -0
- sae_lens/saes/jumprelu_sae.py +348 -0
- sae_lens/saes/sae.py +1076 -0
- sae_lens/saes/standard_sae.py +178 -0
- sae_lens/saes/topk_sae.py +300 -0
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +103 -184
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +155 -177
- sae_lens/training/types.py +5 -0
- sae_lens/training/upload_saes_to_huggingface.py +13 -7
- sae_lens/util.py +47 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/METADATA +1 -1
- sae_lens-6.0.0.dist-info/RECORD +37 -0
- sae_lens/sae.py +0 -747
- sae_lens/sae_training_runner.py +0 -251
- sae_lens/training/geometric_median.py +0 -101
- sae_lens/training/training_sae.py +0 -710
- sae_lens-5.11.0.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/LICENSE +0 -0
- {sae_lens-5.11.0.dist-info → sae_lens-6.0.0.dist-info}/WHEEL +0 -0
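The layout change above moves the SAE implementations into a sae_lens.saes package, splits constants into sae_lens/constants.py, and renames the toolkit package to loading. A minimal sketch of the resulting import paths, using only modules that appear in this file list and in the import hunk below (whether the old top-level re-exports are preserved is not shown in this diff):

    # 6.0.0 module layout, as reflected in the imports in the hunks below
    from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
    from sae_lens.constants import DTYPE_MAP
    from sae_lens.training.mixing_buffer import mixing_buffer
    from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name
    # sae_lens.toolkit.* is now sae_lens.loading.* (see the renamed pretrained_sae_loaders.py above)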
sae_lens/training/activations_store.py

@@ -1,6 +1,5 @@
 from __future__ import annotations

-import contextlib
 import json
 import os
 import warnings
@@ -16,20 +15,21 @@ from huggingface_hub.utils import HfHubHTTPError
 from jaxtyping import Float, Int
 from requests import HTTPError
 from safetensors.torch import save_file
-from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformer_lens.hook_points import HookedRootModule
 from transformers import AutoTokenizer, PreTrainedTokenizerBase

 from sae_lens import logger
 from sae_lens.config import (
-    DTYPE_MAP,
     CacheActivationsRunnerConfig,
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.
+from sae_lens.constants import DTYPE_MAP
+from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
+from sae_lens.training.mixing_buffer import mixing_buffer
+from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name


 # TODO: Make an activation store config class to be consistent with the rest of the code.
@@ -45,10 +45,8 @@ class ActivationsStore:
     cached_activation_dataset: Dataset | None = None
     tokens_column: Literal["tokens", "input_ids", "text", "problem"]
     hook_name: str
-    hook_layer: int
     hook_head_index: int | None
     _dataloader: Iterator[Any] | None = None
-    _storage_buffer: torch.Tensor | None = None
     exclude_special_tokens: torch.Tensor | None = None
     device: torch.device

@@ -65,7 +63,6 @@ class ActivationsStore:
             cached_activations_path=cfg.new_cached_activations_path,
             dtype=cfg.dtype,
             hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
             context_size=cfg.context_size,
             d_in=cfg.d_in,
             n_batches_in_buffer=cfg.n_batches_in_buffer,
@@ -91,7 +88,8 @@ class ActivationsStore:
     def from_config(
         cls,
         model: HookedRootModule,
-        cfg: LanguageModelSAERunnerConfig
+        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]
+        | CacheActivationsRunnerConfig,
         override_dataset: HfDataset | None = None,
     ) -> ActivationsStore:
         if isinstance(cfg, CacheActivationsRunnerConfig):
@@ -125,16 +123,17 @@ class ActivationsStore:
             dataset=override_dataset or cfg.dataset_path,
             streaming=cfg.streaming,
             hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
             hook_head_index=cfg.hook_head_index,
             context_size=cfg.context_size,
-            d_in=cfg.d_in
+            d_in=cfg.d_in
+            if isinstance(cfg, CacheActivationsRunnerConfig)
+            else cfg.sae.d_in,
             n_batches_in_buffer=cfg.n_batches_in_buffer,
             total_training_tokens=cfg.training_tokens,
             store_batch_size_prompts=cfg.store_batch_size_prompts,
             train_batch_size_tokens=cfg.train_batch_size_tokens,
             prepend_bos=cfg.prepend_bos,
-            normalize_activations=cfg.normalize_activations,
+            normalize_activations=cfg.sae.normalize_activations,
             device=device,
             dtype=cfg.dtype,
             cached_activations_path=cached_activations_path,
@@ -149,9 +148,10 @@ class ActivationsStore:
     def from_sae(
         cls,
         model: HookedRootModule,
-        sae: SAE,
+        sae: SAE[T_SAE_CONFIG],
+        dataset: HfDataset | str,
+        dataset_trust_remote_code: bool = False,
         context_size: int | None = None,
-        dataset: HfDataset | str | None = None,
         streaming: bool = True,
         store_batch_size_prompts: int = 8,
         n_batches_in_buffer: int = 8,
@@ -159,25 +159,32 @@ class ActivationsStore:
         total_tokens: int = 10**9,
         device: str = "cpu",
     ) -> ActivationsStore:
+        if sae.cfg.metadata.hook_name is None:
+            raise ValueError("hook_name is required")
+        if sae.cfg.metadata.context_size is None:
+            raise ValueError("context_size is required")
+        if sae.cfg.metadata.prepend_bos is None:
+            raise ValueError("prepend_bos is required")
         return cls(
             model=model,
-            dataset=
+            dataset=dataset,
             d_in=sae.cfg.d_in,
-            hook_name=sae.cfg.hook_name,
-
-
-
-
+            hook_name=sae.cfg.metadata.hook_name,
+            hook_head_index=sae.cfg.metadata.hook_head_index,
+            context_size=sae.cfg.metadata.context_size
+            if context_size is None
+            else context_size,
+            prepend_bos=sae.cfg.metadata.prepend_bos,
             streaming=streaming,
             store_batch_size_prompts=store_batch_size_prompts,
             train_batch_size_tokens=train_batch_size_tokens,
             n_batches_in_buffer=n_batches_in_buffer,
             total_training_tokens=total_tokens,
             normalize_activations=sae.cfg.normalize_activations,
-            dataset_trust_remote_code=
+            dataset_trust_remote_code=dataset_trust_remote_code,
             dtype=sae.cfg.dtype,
             device=torch.device(device),
-            seqpos_slice=sae.cfg.seqpos_slice,
+            seqpos_slice=sae.cfg.metadata.seqpos_slice or (None,),
         )

     def __init__(
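For orientation, a hypothetical call against the 6.0.0 from_sae signature shown above; the model, SAE object, and dataset name are placeholders and are not taken from this diff:

    from sae_lens.training.activations_store import ActivationsStore

    # assumes `model` is a HookedRootModule and `sae` is an already-loaded SAE
    store = ActivationsStore.from_sae(
        model=model,
        sae=sae,
        dataset="NeelNanda/pile-10k",      # dataset is now a required argument
        dataset_trust_remote_code=False,   # new explicit flag
        context_size=128,                  # falls back to sae.cfg.metadata.context_size if None
        device="cpu",
    )
    batch = store.next_batch()             # one activation batch for training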
@@ -186,7 +193,6 @@ class ActivationsStore:
         dataset: HfDataset | str,
         streaming: bool,
         hook_name: str,
-        hook_layer: int,
         hook_head_index: int | None,
         context_size: int,
         d_in: int,
@@ -230,7 +236,6 @@ class ActivationsStore:
         )

         self.hook_name = hook_name
-        self.hook_layer = hook_layer
         self.hook_head_index = hook_head_index
         self.context_size = context_size
         self.d_in = d_in
@@ -246,12 +251,11 @@ class ActivationsStore:
         self.cached_activations_path = cached_activations_path
         self.autocast_lm = autocast_lm
         self.seqpos_slice = seqpos_slice
+        self.training_context_size = len(range(context_size)[slice(*seqpos_slice)])
         self.exclude_special_tokens = exclude_special_tokens

         self.n_dataset_processed = 0

-        self.estimated_norm_scaling_factor = None
-
         # Check if dataset is tokenized
         dataset_sample = next(iter(self.dataset))

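The new training_context_size attribute precomputes how many token positions survive seqpos_slice, replacing the per-call computation removed further down in get_raw_buffer. A quick illustration of the expression used above (the numbers are made up):

    context_size = 128
    seqpos_slice = (1, -1)   # e.g. skip the BOS position and the final position
    training_context_size = len(range(context_size)[slice(*seqpos_slice)])   # 126
    # with the default seqpos_slice=(None,), all 128 positions are kept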
@@ -416,30 +420,6 @@ class ActivationsStore:

         return activations_dataset

-    def set_norm_scaling_factor_if_needed(self):
-        if (
-            self.normalize_activations == "expected_average_only_in"
-            and self.estimated_norm_scaling_factor is None
-        ):
-            self.estimated_norm_scaling_factor = self.estimate_norm_scaling_factor()
-
-    def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
-        if self.estimated_norm_scaling_factor is None:
-            raise ValueError(
-                "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
-            )
-        return activations * self.estimated_norm_scaling_factor
-
-    def unscale(self, activations: torch.Tensor) -> torch.Tensor:
-        if self.estimated_norm_scaling_factor is None:
-            raise ValueError(
-                "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
-            )
-        return activations / self.estimated_norm_scaling_factor
-
-    def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
-        return (self.d_in**0.5) / activations.norm(dim=-1).mean()
-
     @torch.no_grad()
     def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)):
         norms_per_batch = []
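The helpers removed above carried the "expected_average_only_in" normalization inside the store; the file list adds a separate sae_lens/training/activation_scaler.py, which presumably takes over this job, though its API is not part of this hunk. For reference, the removed logic amounts to:

    import torch

    def norm_scaling_factor(activations: torch.Tensor, d_in: int) -> torch.Tensor:
        # same formula as the removed get_norm_scaling_factor: sqrt(d_in) / E[||a||_2]
        return (d_in**0.5) / activations.norm(dim=-1).mean()

    acts = torch.randn(4096, 768)
    factor = norm_scaling_factor(acts, d_in=768)
    scaled = acts * factor       # what apply_norm_scaling_factor did
    restored = scaled / factor   # what unscale did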
@@ -474,21 +454,6 @@ class ActivationsStore:
         """
         self.iterable_dataset = iter(self.dataset)

-    @property
-    def storage_buffer(self) -> torch.Tensor:
-        if self._storage_buffer is None:
-            self._storage_buffer = _filter_buffer_acts(
-                self.get_buffer(self.half_buffer_size), self.exclude_special_tokens
-            )
-
-        return self._storage_buffer
-
-    @property
-    def dataloader(self) -> Iterator[Any]:
-        if self._dataloader is None:
-            self._dataloader = self.get_data_loader()
-        return self._dataloader
-
     def get_batch_tokens(
         self, batch_size: int | None = None, raise_at_epoch_end: bool = False
     ):
@@ -521,22 +486,17 @@ class ActivationsStore:

         d_in may result from a concatenated head dimension.
         """
-
-
-
-
-
-                dtype=torch.bfloat16,
-                enabled=self.autocast_lm,
-            )
-        else:
-            autocast_if_enabled = contextlib.nullcontext()
-
-        with autocast_if_enabled:
+        with torch.autocast(
+            device_type="cuda",
+            dtype=torch.bfloat16,
+            enabled=self.autocast_lm,
+        ):
             layerwise_activations_cache = self.model.run_with_cache(
                 batch_tokens,
                 names_filter=[self.hook_name],
-                stop_at_layer=
+                stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(
+                    self.hook_name
+                ),
                 prepend_bos=False,
                 **self.model_kwargs,
             )[1]
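The rewrite above drops the contextlib.nullcontext branch because torch.autocast itself takes an enabled flag: with enabled=False the context manager is a no-op, so a single code path suffices. A minimal sketch of the pattern (model_fn and x are placeholders, not names from this diff):

    import torch

    def run_forward(model_fn, x: torch.Tensor, autocast_lm: bool) -> torch.Tensor:
        # enabled=False disables autocasting, making the with-block a pass-through
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=autocast_lm):
            return model_fn(x)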
@@ -547,25 +507,25 @@ class ActivationsStore:

         n_batches, n_context = layerwise_activations.shape[:2]

-        stacked_activations = torch.zeros((n_batches, n_context,
+        stacked_activations = torch.zeros((n_batches, n_context, self.d_in))

         if self.hook_head_index is not None:
-            stacked_activations[:,
+            stacked_activations[:, :] = layerwise_activations[
                 :, :, self.hook_head_index
             ]
         elif layerwise_activations.ndim > 3:  # if we have a head dimension
             try:
-                stacked_activations[:,
+                stacked_activations[:, :] = layerwise_activations.view(
                     n_batches, n_context, -1
                 )
             except RuntimeError as e:
                 logger.error(f"Error during view operation: {e}")
                 logger.info("Attempting to use reshape instead...")
-                stacked_activations[:,
+                stacked_activations[:, :] = layerwise_activations.reshape(
                     n_batches, n_context, -1
                 )
         else:
-            stacked_activations[:,
+            stacked_activations[:, :] = layerwise_activations

         return stacked_activations

@@ -573,7 +533,6 @@ class ActivationsStore:
         self,
         total_size: int,
         context_size: int,
-        num_layers: int,
         d_in: int,
         raise_on_epoch_end: bool,
     ) -> tuple[
@@ -590,10 +549,9 @@ class ActivationsStore:
         """
         assert self.cached_activation_dataset is not None
         # In future, could be a list of multiple hook names
-
-        if not set(hook_names).issubset(self.cached_activation_dataset.column_names):
+        if self.hook_name not in self.cached_activation_dataset.column_names:
             raise ValueError(
-                f"Missing columns in dataset. Expected {
+                f"Missing columns in dataset. Expected {self.hook_name}, "
                 f"got {self.cached_activation_dataset.column_names}."
             )

@@ -606,28 +564,17 @@ class ActivationsStore:
         ds_slice = self.cached_activation_dataset[
             self.current_row_idx : self.current_row_idx + total_size
         ]
-        for
-
-
-
-            if _hook_buffer.shape != (total_size, context_size, d_in):
-                raise ValueError(
-                    f"_hook_buffer has shape {_hook_buffer.shape}, "
-                    f"but expected ({total_size}, {context_size}, {d_in})."
-                )
-            new_buffer.append(_hook_buffer)
-
-        # Stack across num_layers dimension
-        # list of num_layers; shape: (total_size, context_size, d_in) -> (total_size, context_size, num_layers, d_in)
-        new_buffer = torch.stack(new_buffer, dim=2)
-        if new_buffer.shape != (total_size, context_size, num_layers, d_in):
+        # Load activations for each hook.
+        # Usually faster to first slice dataset then pick column
+        new_buffer = ds_slice[self.hook_name]
+        if new_buffer.shape != (total_size, context_size, d_in):
             raise ValueError(
                 f"new_buffer has shape {new_buffer.shape}, "
-                f"but expected ({total_size}, {context_size}, {
+                f"but expected ({total_size}, {context_size}, {d_in})."
             )

         self.current_row_idx += total_size
-        acts_buffer = new_buffer.reshape(total_size * context_size,
+        acts_buffer = new_buffer.reshape(total_size * context_size, d_in)

         if "token_ids" not in self.cached_activation_dataset.column_names:
             return acts_buffer, None
@@ -642,7 +589,7 @@ class ActivationsStore:
         return acts_buffer, token_ids_buffer

     @torch.no_grad()
-    def
+    def get_raw_buffer(
         self,
         n_batches_in_buffer: int,
         raise_on_epoch_end: bool = False,
@@ -656,26 +603,24 @@ class ActivationsStore:
         If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
         """
         context_size = self.context_size
-        training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
         batch_size = self.store_batch_size_prompts
         d_in = self.d_in
         total_size = batch_size * n_batches_in_buffer
-        num_layers = 1

         if self.cached_activation_dataset is not None:
             return self._load_buffer_from_cached(
-                total_size, context_size,
+                total_size, context_size, d_in, raise_on_epoch_end
             )

         refill_iterator = range(0, total_size, batch_size)
         # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
         new_buffer_activations = torch.zeros(
-            (total_size, training_context_size,
+            (total_size, self.training_context_size, d_in),
             dtype=self.dtype,  # type: ignore
             device=self.device,
         )
         new_buffer_token_ids = torch.zeros(
-            (total_size, training_context_size),
+            (total_size, self.training_context_size),
             dtype=torch.long,
             device=self.device,
         )
@@ -700,106 +645,80 @@ class ActivationsStore:
             refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
         ] = refill_batch_tokens

-        new_buffer_activations = new_buffer_activations.reshape(-1,
+        new_buffer_activations = new_buffer_activations.reshape(-1, d_in)
         new_buffer_token_ids = new_buffer_token_ids.reshape(-1)
         if shuffle:
             new_buffer_activations, new_buffer_token_ids = permute_together(
                 [new_buffer_activations, new_buffer_token_ids]
             )

-        # every buffer should be normalized:
-        if self.normalize_activations == "expected_average_only_in":
-            new_buffer_activations = self.apply_norm_scaling_factor(
-                new_buffer_activations
-            )
-
         return (
             new_buffer_activations,
             new_buffer_token_ids,
         )

-    def
+    def get_filtered_buffer(
         self,
-
-
-
-
-
-
+        n_batches_in_buffer: int,
+        raise_on_epoch_end: bool = False,
+        shuffle: bool = True,
+    ) -> torch.Tensor:
+        return _filter_buffer_acts(
+            self.get_raw_buffer(
+                n_batches_in_buffer=n_batches_in_buffer,
+                raise_on_epoch_end=raise_on_epoch_end,
+                shuffle=shuffle,
+            ),
+            self.exclude_special_tokens,
+        )

+    def _iterate_filtered_activations(self) -> Generator[torch.Tensor, None, None]:
         """
-
-
-
-        try:
-            new_samples = _filter_buffer_acts(
-                self.get_buffer(self.half_buffer_size, raise_on_epoch_end=True),
-                self.exclude_special_tokens,
-            )
-        except StopIteration:
-            warnings.warn(
-                "All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples."
-            )
-            self._storage_buffer = (
-                None  # dump the current buffer so samples do not leak between epochs
-            )
+        Iterate over the filtered tokens in the buffer.
+        """
+        while True:
             try:
-
-                self.
-                self.exclude_special_tokens,
+                yield self.get_filtered_buffer(
+                    self.half_buffer_size, raise_on_epoch_end=True
                 )
             except StopIteration:
-
-                "
+                warnings.warn(
+                    "All samples in the training dataset have been exhausted, beginning new epoch."
                 )
+                try:
+                    yield self.get_filtered_buffer(self.half_buffer_size)
+                except StopIteration:
+                    raise ValueError(
+                        "Unable to fill buffer after starting new epoch. Dataset may be too small."
                    )

-
-
-
-
-
-
-
-
-
-
-
-        # 3. put other 50 % in a dataloader
-        return iter(
-            DataLoader(
-                # TODO: seems like a typing bug?
-                cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
-                batch_size=batch_size,
-                shuffle=True,
-            )
+    def get_data_loader(
+        self,
+    ) -> Iterator[Any]:
+        """
+        Return an auto-refilling stream of filtered and mixed activations.
+        """
+        return mixing_buffer(
+            buffer_size=self.n_batches_in_buffer * self.training_context_size,
+            batch_size=self.train_batch_size_tokens,
+            activations_loader=self._iterate_filtered_activations(),
        )

     def next_batch(self) -> torch.Tensor:
-        """
-
-
-
-
-
-            return next(self.dataloader)
-        except StopIteration:
-            # If the DataLoader is exhausted, create a new one
+        """Get next batch, updating buffer if needed."""
+        return self.__next__()
+
+    # ActivationsStore should be an iterator
+    def __next__(self) -> torch.Tensor:
+        if self._dataloader is None:
             self._dataloader = self.get_data_loader()
-
+        return next(self._dataloader)
+
+    def __iter__(self) -> Iterator[torch.Tensor]:
+        return self

     def state_dict(self) -> dict[str, torch.Tensor]:
-
-            "n_dataset_processed": torch.tensor(self.n_dataset_processed),
-        }
-        if self._storage_buffer is not None:  # first time might be None
-            result["storage_buffer_activations"] = self._storage_buffer[0]
-            if self._storage_buffer[1] is not None:
-                result["storage_buffer_tokens"] = self._storage_buffer[1]
-        if self.estimated_norm_scaling_factor is not None:
-            result["estimated_norm_scaling_factor"] = torch.tensor(
-                self.estimated_norm_scaling_factor
-            )
-        return result
+        return {"n_dataset_processed": torch.tensor(self.n_dataset_processed)}

     def save(self, file_path: str):
         """save the state dict to a file in safetensors format"""
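With __iter__ and __next__ added above, the store itself is now an iterator over training batches, and next_batch() simply delegates to __next__(). A short sketch, assuming `store` is an already-constructed ActivationsStore:

    batch = store.next_batch()   # same as next(store)
    batch = next(store)          # iterator protocol; lazily builds the mixing-buffer stream

    for i, batch in zip(range(10), store):
        ...                      # each batch is an activation tensor of width d_in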
sae_lens/training/mixing_buffer.py (new file)

@@ -0,0 +1,56 @@
+from collections.abc import Iterator
+
+import torch
+
+
+@torch.no_grad()
+def mixing_buffer(
+    buffer_size: int,
+    batch_size: int,
+    activations_loader: Iterator[torch.Tensor],
+) -> Iterator[torch.Tensor]:
+    """
+    A generator that maintains a mix of old and new activations for better training.
+    It stores half of the activations and mixes them with new ones to create batches.
+
+    Args:
+        buffer_size: Total size of the buffer (will store buffer_size/2 activations)
+        batch_size: Size of batches to return
+        activations_loader: Iterator providing new activations
+
+    Yields:
+        Batches of activations of shape (batch_size, *activation_dims)
+    """
+
+    if buffer_size < batch_size:
+        raise ValueError("Buffer size must be greater than or equal to batch size")
+
+    storage_buffer: torch.Tensor | None = None
+
+    for new_activations in activations_loader:
+        storage_buffer = (
+            new_activations
+            if storage_buffer is None
+            else torch.cat([storage_buffer, new_activations], dim=0)
+        )
+
+        if storage_buffer.shape[0] >= buffer_size:
+            # Shuffle
+            storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
+
+            num_serving_batches = max(1, storage_buffer.shape[0] // (2 * batch_size))
+            serving_cutoff = num_serving_batches * batch_size
+            serving_buffer = storage_buffer[:serving_cutoff]
+            storage_buffer = storage_buffer[serving_cutoff:]
+
+            # Yield batches from the serving_buffer
+            for batch_idx in range(num_serving_batches):
+                yield serving_buffer[
+                    batch_idx * batch_size : (batch_idx + 1) * batch_size
+                ]
+
+    # If there are any remaining activations, yield them
+    if storage_buffer is not None:
+        remaining_batches = storage_buffer.shape[0] // batch_size
+        for i in range(remaining_batches):
+            yield storage_buffer[i * batch_size : (i + 1) * batch_size]