sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +6 -3
- sae_lens/cache_activations_runner.py +7 -6
- sae_lens/config.py +47 -5
- sae_lens/constants.py +2 -0
- sae_lens/evals.py +19 -19
- sae_lens/{sae_training_runner.py → llm_sae_training_runner.py} +92 -60
- sae_lens/load_model.py +53 -5
- sae_lens/loading/pretrained_sae_loaders.py +0 -7
- sae_lens/saes/sae.py +0 -3
- sae_lens/training/activation_scaler.py +53 -0
- sae_lens/training/activations_store.py +77 -172
- sae_lens/training/mixing_buffer.py +56 -0
- sae_lens/training/sae_trainer.py +96 -95
- sae_lens/training/types.py +5 -0
- sae_lens/util.py +19 -0
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc3.dist-info}/METADATA +1 -1
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc3.dist-info}/RECORD +19 -16
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc3.dist-info}/LICENSE +0 -0
- {sae_lens-6.0.0rc2.dist-info → sae_lens-6.0.0rc3.dist-info}/WHEEL +0 -0
sae_lens/loading/pretrained_sae_loaders.py CHANGED

@@ -193,7 +193,6 @@ def handle_pre_6_0_config(cfg_dict: dict[str, Any]) -> dict[str, Any]:

     rename_keys_map = {
         "hook_point": "hook_name",
-        "hook_point_layer": "hook_layer",
         "hook_point_head_index": "hook_head_index",
         "activation_fn_str": "activation_fn",
     }
@@ -262,7 +261,6 @@ get_connor_rob_hook_z_config_from_hf(
         "device": device if device is not None else "cpu",
         "model_name": "gpt2-small",
         "hook_name": old_cfg_dict["act_name"],
-        "hook_layer": old_cfg_dict["layer"],
         "hook_head_index": None,
         "activation_fn": "relu",
         "apply_b_dec_to_input": True,
@@ -411,7 +409,6 @@ def get_gemma_2_config_from_hf(
         "dtype": "float32",
         "model_name": model_name,
         "hook_name": hook_name,
-        "hook_layer": layer,
         "hook_head_index": None,
         "activation_fn": "relu",
         "finetuning_scaling_factor": False,
@@ -524,7 +521,6 @@ def get_llama_scope_config_from_hf(
         "dtype": "bfloat16",
         "model_name": model_name,
         "hook_name": old_cfg_dict["hook_point_in"],
-        "hook_layer": int(old_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
         "activation_fn": "relu",
         "finetuning_scaling_factor": False,
@@ -651,7 +647,6 @@ def get_dictionary_learning_config_1_from_hf(
         "device": device,
         "model_name": trainer["lm_name"].split("/")[-1],
         "hook_name": hook_point_name,
-        "hook_layer": trainer["layer"],
         "hook_head_index": None,
         "activation_fn": activation_fn,
         "activation_fn_kwargs": activation_fn_kwargs,
@@ -690,7 +685,6 @@ def get_deepseek_r1_config_from_hf(
         "context_size": 1024,
         "model_name": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
         "hook_name": f"blocks.{layer}.hook_resid_post",
-        "hook_layer": layer,
         "hook_head_index": None,
         "prepend_bos": True,
         "dataset_path": "lmsys/lmsys-chat-1m",
@@ -849,7 +843,6 @@ def get_llama_scope_r1_distill_config_from_hf(
         "device": device,
         "model_name": model_name,
         "hook_name": huggingface_cfg_dict["hook_point_in"],
-        "hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
         "activation_fn": "relu",
         "finetuning_scaling_factor": False,
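For context, each of these hunks drops the `hook_layer` entry from the config dicts returned by the pretrained loaders, and the pre-6.0 rename map no longer carries `hook_point_layer` forward. A minimal sketch of how a key-rename map like the one above is typically applied; this is not the library's actual implementation, and the handling of the dropped key is illustrative only:

```python
# Hypothetical illustration: apply a key-rename map to an old-style config dict.
rename_keys_map = {
    "hook_point": "hook_name",
    "hook_point_head_index": "hook_head_index",
    "activation_fn_str": "activation_fn",
}

old_cfg = {
    "hook_point": "blocks.8.hook_resid_pre",
    "hook_point_layer": 8,  # no longer mapped to hook_layer in rc3
    "activation_fn_str": "relu",
}

new_cfg = {rename_keys_map.get(k, k): v for k, v in old_cfg.items()}
# {'hook_name': 'blocks.8.hook_resid_pre', 'hook_point_layer': 8, 'activation_fn': 'relu'}
```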
sae_lens/saes/sae.py CHANGED

@@ -66,7 +66,6 @@ class SAEMetadata:
     model_name: str | None = None
     hook_name: str | None = None
     model_class_name: str | None = None
-    hook_layer: int | None = None
     hook_head_index: int | None = None
     model_from_pretrained_kwargs: dict[str, Any] | None = None
     prepend_bos: bool | None = None
@@ -649,7 +648,6 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
 class TrainingSAEConfig(SAEConfig, ABC):
     noise_scale: float = 0.0
     mse_loss_normalization: str | None = None
-    b_dec_init_method: Literal["zeros", "geometric_median", "mean"] = "zeros"
     # https://transformer-circuits.pub/2024/april-update/index.html#training-saes
     # 0.1 corresponds to the "heuristic" initialization, use None to disable
     decoder_init_norm: float | None = 0.1
@@ -666,7 +664,6 @@ class TrainingSAEConfig(SAEConfig, ABC):
         metadata = SAEMetadata(
             model_name=cfg.model_name,
             hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
             hook_head_index=cfg.hook_head_index,
             context_size=cfg.context_size,
             prepend_bos=cfg.prepend_bos,
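With this change, SAE metadata in rc3 no longer carries a `hook_layer` field. A hedged sketch of constructing `SAEMetadata` after the change; the field values are illustrative and only fields visible in this diff are used:

```python
from sae_lens.saes.sae import SAEMetadata

# Illustrative values: the layer is implied by hook_name rather than stored separately.
metadata = SAEMetadata(
    model_name="gpt2-small",
    hook_name="blocks.8.hook_resid_post",
    hook_head_index=None,
    prepend_bos=True,
)
```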
sae_lens/training/activation_scaler.py ADDED

@@ -0,0 +1,53 @@
+import json
+from dataclasses import dataclass
+from statistics import mean
+
+import torch
+from tqdm import tqdm
+
+from sae_lens.training.types import DataProvider
+
+
+@dataclass
+class ActivationScaler:
+    scaling_factor: float | None = None
+
+    def scale(self, acts: torch.Tensor) -> torch.Tensor:
+        return acts if self.scaling_factor is None else acts * self.scaling_factor
+
+    def unscale(self, acts: torch.Tensor) -> torch.Tensor:
+        return acts if self.scaling_factor is None else acts / self.scaling_factor
+
+    def __call__(self, acts: torch.Tensor) -> torch.Tensor:
+        return self.scale(acts)
+
+    @torch.no_grad()
+    def _calculate_mean_norm(
+        self, data_provider: DataProvider, n_batches_for_norm_estimate: int = int(1e3)
+    ) -> float:
+        norms_per_batch: list[float] = []
+        for _ in tqdm(
+            range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
+        ):
+            acts = next(data_provider)
+            norms_per_batch.append(acts.norm(dim=-1).mean().item())
+        return mean(norms_per_batch)
+
+    def estimate_scaling_factor(
+        self,
+        d_in: int,
+        data_provider: DataProvider,
+        n_batches_for_norm_estimate: int = int(1e3),
+    ):
+        mean_norm = self._calculate_mean_norm(
+            data_provider, n_batches_for_norm_estimate
+        )
+        self.scaling_factor = (d_in**0.5) / mean_norm
+
+    def save(self, file_path: str):
+        """save the state dict to a file in json format"""
+        if not file_path.endswith(".json"):
+            raise ValueError("file_path must end with .json")
+
+        with open(file_path, "w") as f:
+            json.dump({"scaling_factor": self.scaling_factor}, f)
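A minimal usage sketch of the new `ActivationScaler`; the data provider and shapes below are made up for illustration, and `DataProvider` is treated simply as anything `next()` works on, per the code above:

```python
import torch

from sae_lens.training.activation_scaler import ActivationScaler

# Hypothetical activation stream: an iterator of (batch, d_in) tensors.
d_in = 512
fake_acts = iter(torch.randn(16, d_in) for _ in range(100))

scaler = ActivationScaler()
# Estimate scaling_factor so scaled activations have mean norm ~ sqrt(d_in).
scaler.estimate_scaling_factor(
    d_in=d_in, data_provider=fake_acts, n_batches_for_norm_estimate=100
)

batch = torch.randn(16, d_in)
scaled = scaler(batch)             # same as scaler.scale(batch)
restored = scaler.unscale(scaled)  # inverse transform
scaler.save("activation_scaler.json")
```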
sae_lens/training/activations_store.py CHANGED

@@ -1,6 +1,5 @@
 from __future__ import annotations

-import contextlib
 import json
 import os
 import warnings
@@ -16,7 +15,6 @@ from huggingface_hub.utils import HfHubHTTPError
 from jaxtyping import Float, Int
 from requests import HTTPError
 from safetensors.torch import save_file
-from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformer_lens.hook_points import HookedRootModule
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
@@ -30,6 +28,8 @@ from sae_lens.config import (
 from sae_lens.constants import DTYPE_MAP
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
+from sae_lens.training.mixing_buffer import mixing_buffer
+from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name


 # TODO: Make an activation store config class to be consistent with the rest of the code.
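The new helper `extract_stop_at_layer_from_tlens_hook_name` (added to `sae_lens/util.py`, whose body is not shown in this diff) takes over the role of the removed `hook_layer` field. A plausible sketch of what such a helper might do, assuming TransformerLens hook names of the form `blocks.<layer>....`; this is an assumption, not the library's actual implementation:

```python
import re


def extract_stop_at_layer_from_tlens_hook_name(hook_name: str) -> int | None:
    """Hypothetical sketch: parse the layer out of a hook name such as
    'blocks.8.hook_resid_post' and return layer + 1, so run_with_cache can
    stop as early as possible. Returns None for layer-less hooks."""
    match = re.search(r"blocks\.(\d+)\.", hook_name)
    return int(match.group(1)) + 1 if match else None
```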
@@ -45,10 +45,8 @@ class ActivationsStore:
     cached_activation_dataset: Dataset | None = None
     tokens_column: Literal["tokens", "input_ids", "text", "problem"]
     hook_name: str
-    hook_layer: int
     hook_head_index: int | None
     _dataloader: Iterator[Any] | None = None
-    _storage_buffer: torch.Tensor | None = None
     exclude_special_tokens: torch.Tensor | None = None
     device: torch.device

@@ -65,7 +63,6 @@ class ActivationsStore:
             cached_activations_path=cfg.new_cached_activations_path,
             dtype=cfg.dtype,
             hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
             context_size=cfg.context_size,
             d_in=cfg.d_in,
             n_batches_in_buffer=cfg.n_batches_in_buffer,
@@ -126,7 +123,6 @@ class ActivationsStore:
             dataset=override_dataset or cfg.dataset_path,
             streaming=cfg.streaming,
             hook_name=cfg.hook_name,
-            hook_layer=cfg.hook_layer,
             hook_head_index=cfg.hook_head_index,
             context_size=cfg.context_size,
             d_in=cfg.d_in
@@ -165,8 +161,6 @@ class ActivationsStore:
     ) -> ActivationsStore:
         if sae.cfg.metadata.hook_name is None:
             raise ValueError("hook_name is required")
-        if sae.cfg.metadata.hook_layer is None:
-            raise ValueError("hook_layer is required")
         if sae.cfg.metadata.hook_head_index is None:
             raise ValueError("hook_head_index is required")
         if sae.cfg.metadata.context_size is None:
@@ -178,7 +172,6 @@ class ActivationsStore:
             dataset=dataset,
             d_in=sae.cfg.d_in,
             hook_name=sae.cfg.metadata.hook_name,
-            hook_layer=sae.cfg.metadata.hook_layer,
             hook_head_index=sae.cfg.metadata.hook_head_index,
             context_size=sae.cfg.metadata.context_size
             if context_size is None
@@ -202,7 +195,6 @@ class ActivationsStore:
         dataset: HfDataset | str,
         streaming: bool,
         hook_name: str,
-        hook_layer: int,
         hook_head_index: int | None,
         context_size: int,
         d_in: int,
@@ -246,7 +238,6 @@ class ActivationsStore:
         )

         self.hook_name = hook_name
-        self.hook_layer = hook_layer
         self.hook_head_index = hook_head_index
         self.context_size = context_size
         self.d_in = d_in
@@ -262,12 +253,11 @@ class ActivationsStore:
         self.cached_activations_path = cached_activations_path
         self.autocast_lm = autocast_lm
         self.seqpos_slice = seqpos_slice
+        self.training_context_size = len(range(context_size)[slice(*seqpos_slice)])
         self.exclude_special_tokens = exclude_special_tokens

         self.n_dataset_processed = 0

-        self.estimated_norm_scaling_factor = None
-
         # Check if dataset is tokenized
         dataset_sample = next(iter(self.dataset))

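A quick worked example of the new `training_context_size` expression, with illustrative values:

```python
context_size = 128
seqpos_slice = (1, None)  # e.g. drop the first (BOS) position

# len(range(context_size)[slice(*seqpos_slice)]) counts the positions kept
training_context_size = len(range(context_size)[slice(*seqpos_slice)])
assert training_context_size == 127

# with no slicing, the full context is used
assert len(range(128)[slice(*(None, None))]) == 128
```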
@@ -432,30 +422,6 @@ class ActivationsStore:

         return activations_dataset

-    def set_norm_scaling_factor_if_needed(self):
-        if (
-            self.normalize_activations == "expected_average_only_in"
-            and self.estimated_norm_scaling_factor is None
-        ):
-            self.estimated_norm_scaling_factor = self.estimate_norm_scaling_factor()
-
-    def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
-        if self.estimated_norm_scaling_factor is None:
-            raise ValueError(
-                "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
-            )
-        return activations * self.estimated_norm_scaling_factor
-
-    def unscale(self, activations: torch.Tensor) -> torch.Tensor:
-        if self.estimated_norm_scaling_factor is None:
-            raise ValueError(
-                "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first"
-            )
-        return activations / self.estimated_norm_scaling_factor
-
-    def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor:
-        return (self.d_in**0.5) / activations.norm(dim=-1).mean()
-
     @torch.no_grad()
     def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)):
         norms_per_batch = []
@@ -490,21 +456,6 @@ class ActivationsStore:
         """
         self.iterable_dataset = iter(self.dataset)

-    @property
-    def storage_buffer(self) -> torch.Tensor:
-        if self._storage_buffer is None:
-            self._storage_buffer = _filter_buffer_acts(
-                self.get_buffer(self.half_buffer_size), self.exclude_special_tokens
-            )
-
-        return self._storage_buffer
-
-    @property
-    def dataloader(self) -> Iterator[Any]:
-        if self._dataloader is None:
-            self._dataloader = self.get_data_loader()
-        return self._dataloader
-
     def get_batch_tokens(
         self, batch_size: int | None = None, raise_at_epoch_end: bool = False
     ):
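The scaling heuristic itself is unchanged by these removals; it now lives in `ActivationScaler`: activations are rescaled so their mean L2 norm becomes sqrt(d_in). A small self-contained numeric check of that formula, with synthetic data:

```python
import torch

# Rows with a known L2 norm of 3.7 stand in for model activations.
d_in = 64
acts = 3.7 * torch.nn.functional.normalize(torch.randn(1000, d_in), dim=-1)

# Same formula as the removed get_norm_scaling_factor / ActivationScaler:
scaling_factor = (d_in**0.5) / acts.norm(dim=-1).mean()
scaled = acts * scaling_factor
assert torch.allclose(scaled.norm(dim=-1).mean(), torch.tensor(d_in**0.5), atol=1e-3)
```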
@@ -537,22 +488,17 @@ class ActivationsStore:

         d_in may result from a concatenated head dimension.
         """
-
-
-
-
-
-                dtype=torch.bfloat16,
-                enabled=self.autocast_lm,
-            )
-        else:
-            autocast_if_enabled = contextlib.nullcontext()
-
-        with autocast_if_enabled:
+        with torch.autocast(
+            device_type="cuda",
+            dtype=torch.bfloat16,
+            enabled=self.autocast_lm,
+        ):
             layerwise_activations_cache = self.model.run_with_cache(
                 batch_tokens,
                 names_filter=[self.hook_name],
-                stop_at_layer=
+                stop_at_layer=extract_stop_at_layer_from_tlens_hook_name(
+                    self.hook_name
+                ),
                 prepend_bos=False,
                 **self.model_kwargs,
             )[1]
@@ -563,25 +509,25 @@ class ActivationsStore:

         n_batches, n_context = layerwise_activations.shape[:2]

-        stacked_activations = torch.zeros((n_batches, n_context,
+        stacked_activations = torch.zeros((n_batches, n_context, self.d_in))

         if self.hook_head_index is not None:
-            stacked_activations[:,
+            stacked_activations[:, :] = layerwise_activations[
                 :, :, self.hook_head_index
             ]
         elif layerwise_activations.ndim > 3:  # if we have a head dimension
             try:
-                stacked_activations[:,
+                stacked_activations[:, :] = layerwise_activations.view(
                     n_batches, n_context, -1
                 )
             except RuntimeError as e:
                 logger.error(f"Error during view operation: {e}")
                 logger.info("Attempting to use reshape instead...")
-                stacked_activations[:,
+                stacked_activations[:, :] = layerwise_activations.reshape(
                     n_batches, n_context, -1
                 )
         else:
-            stacked_activations[:,
+            stacked_activations[:, :] = layerwise_activations

         return stacked_activations

@@ -589,7 +535,6 @@ class ActivationsStore:
         self,
         total_size: int,
         context_size: int,
-        num_layers: int,
         d_in: int,
         raise_on_epoch_end: bool,
     ) -> tuple[
@@ -606,10 +551,9 @@ class ActivationsStore:
         """
         assert self.cached_activation_dataset is not None
         # In future, could be a list of multiple hook names
-
-        if not set(hook_names).issubset(self.cached_activation_dataset.column_names):
+        if self.hook_name not in self.cached_activation_dataset.column_names:
             raise ValueError(
-                f"Missing columns in dataset. Expected {
+                f"Missing columns in dataset. Expected {self.hook_name}, "
                 f"got {self.cached_activation_dataset.column_names}."
             )

@@ -622,28 +566,17 @@ class ActivationsStore:
         ds_slice = self.cached_activation_dataset[
             self.current_row_idx : self.current_row_idx + total_size
         ]
-        for
-
-
-
-            if _hook_buffer.shape != (total_size, context_size, d_in):
-                raise ValueError(
-                    f"_hook_buffer has shape {_hook_buffer.shape}, "
-                    f"but expected ({total_size}, {context_size}, {d_in})."
-                )
-            new_buffer.append(_hook_buffer)
-
-        # Stack across num_layers dimension
-        # list of num_layers; shape: (total_size, context_size, d_in) -> (total_size, context_size, num_layers, d_in)
-        new_buffer = torch.stack(new_buffer, dim=2)
-        if new_buffer.shape != (total_size, context_size, num_layers, d_in):
+        # Load activations for each hook.
+        # Usually faster to first slice dataset then pick column
+        new_buffer = ds_slice[self.hook_name]
+        if new_buffer.shape != (total_size, context_size, d_in):
             raise ValueError(
                 f"new_buffer has shape {new_buffer.shape}, "
-                f"but expected ({total_size}, {context_size}, {
+                f"but expected ({total_size}, {context_size}, {d_in})."
             )

         self.current_row_idx += total_size
-        acts_buffer = new_buffer.reshape(total_size * context_size,
+        acts_buffer = new_buffer.reshape(total_size * context_size, d_in)

         if "token_ids" not in self.cached_activation_dataset.column_names:
             return acts_buffer, None
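For orientation, the shape bookkeeping in the cached-activation path, with illustrative sizes:

```python
import torch

# Illustrative sizes for a cached-activation buffer.
total_size, context_size, d_in = 32, 128, 512

new_buffer = torch.zeros(total_size, context_size, d_in)
assert new_buffer.shape == (total_size, context_size, d_in)

# Each cached row holds one prompt's activations; flattening yields one
# activation vector per token position, ready to be batched for training.
acts_buffer = new_buffer.reshape(total_size * context_size, d_in)
assert acts_buffer.shape == (32 * 128, 512)
```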
@@ -658,7 +591,7 @@ class ActivationsStore:
         return acts_buffer, token_ids_buffer

     @torch.no_grad()
-    def
+    def get_raw_buffer(
         self,
         n_batches_in_buffer: int,
         raise_on_epoch_end: bool = False,
@@ -672,26 +605,24 @@ class ActivationsStore:
         If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
         """
         context_size = self.context_size
-        training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
         batch_size = self.store_batch_size_prompts
         d_in = self.d_in
         total_size = batch_size * n_batches_in_buffer
-        num_layers = 1

         if self.cached_activation_dataset is not None:
             return self._load_buffer_from_cached(
-                total_size, context_size,
+                total_size, context_size, d_in, raise_on_epoch_end
             )

         refill_iterator = range(0, total_size, batch_size)
         # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
         new_buffer_activations = torch.zeros(
-            (total_size, training_context_size,
+            (total_size, self.training_context_size, d_in),
             dtype=self.dtype,  # type: ignore
             device=self.device,
         )
         new_buffer_token_ids = torch.zeros(
-            (total_size, training_context_size),
+            (total_size, self.training_context_size),
             dtype=torch.long,
             device=self.device,
         )
@@ -716,106 +647,80 @@ class ActivationsStore:
                 refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
             ] = refill_batch_tokens

-        new_buffer_activations = new_buffer_activations.reshape(-1,
+        new_buffer_activations = new_buffer_activations.reshape(-1, d_in)
         new_buffer_token_ids = new_buffer_token_ids.reshape(-1)
         if shuffle:
             new_buffer_activations, new_buffer_token_ids = permute_together(
                 [new_buffer_activations, new_buffer_token_ids]
             )

-        # every buffer should be normalized:
-        if self.normalize_activations == "expected_average_only_in":
-            new_buffer_activations = self.apply_norm_scaling_factor(
-                new_buffer_activations
-            )
-
         return (
             new_buffer_activations,
             new_buffer_token_ids,
         )

-    def
+    def get_filtered_buffer(
         self,
-
-
-
-
-
-
+        n_batches_in_buffer: int,
+        raise_on_epoch_end: bool = False,
+        shuffle: bool = True,
+    ) -> torch.Tensor:
+        return _filter_buffer_acts(
+            self.get_raw_buffer(
+                n_batches_in_buffer=n_batches_in_buffer,
+                raise_on_epoch_end=raise_on_epoch_end,
+                shuffle=shuffle,
+            ),
+            self.exclude_special_tokens,
+        )

+    def _iterate_filtered_activations(self) -> Generator[torch.Tensor, None, None]:
         """
-
-
-
-        try:
-            new_samples = _filter_buffer_acts(
-                self.get_buffer(self.half_buffer_size, raise_on_epoch_end=True),
-                self.exclude_special_tokens,
-            )
-        except StopIteration:
-            warnings.warn(
-                "All samples in the training dataset have been exhausted, we are now beginning a new epoch with the same samples."
-            )
-            self._storage_buffer = (
-                None  # dump the current buffer so samples do not leak between epochs
-            )
+        Iterate over the filtered tokens in the buffer.
+        """
+        while True:
             try:
-
-                self.
-                self.exclude_special_tokens,
+                yield self.get_filtered_buffer(
+                    self.half_buffer_size, raise_on_epoch_end=True
                 )
             except StopIteration:
-
-                "
+                warnings.warn(
+                    "All samples in the training dataset have been exhausted, beginning new epoch."
                 )
+                try:
+                    yield self.get_filtered_buffer(self.half_buffer_size)
+                except StopIteration:
+                    raise ValueError(
+                        "Unable to fill buffer after starting new epoch. Dataset may be too small."
+                    )

-
-
-
-
-
-
-
-
-
-
-
-        # 3. put other 50 % in a dataloader
-        return iter(
-            DataLoader(
-                # TODO: seems like a typing bug?
-                cast(Any, mixing_buffer[mixing_buffer.shape[0] // 2 :]),
-                batch_size=batch_size,
-                shuffle=True,
-            )
+    def get_data_loader(
+        self,
+    ) -> Iterator[Any]:
+        """
+        Return an auto-refilling stream of filtered and mixed activations.
+        """
+        return mixing_buffer(
+            buffer_size=self.n_batches_in_buffer * self.training_context_size,
+            batch_size=self.train_batch_size_tokens,
+            activations_loader=self._iterate_filtered_activations(),
         )

     def next_batch(self) -> torch.Tensor:
-        """
-
-
-
-
-
-            return next(self.dataloader)
-        except StopIteration:
-            # If the DataLoader is exhausted, create a new one
+        """Get next batch, updating buffer if needed."""
+        return self.__next__()
+
+    # ActivationsStore should be an iterator
+    def __next__(self) -> torch.Tensor:
+        if self._dataloader is None:
             self._dataloader = self.get_data_loader()
-
+        return next(self._dataloader)
+
+    def __iter__(self) -> Iterator[torch.Tensor]:
+        return self

     def state_dict(self) -> dict[str, torch.Tensor]:
-
-            "n_dataset_processed": torch.tensor(self.n_dataset_processed),
-        }
-        if self._storage_buffer is not None:  # first time might be None
-            result["storage_buffer_activations"] = self._storage_buffer[0]
-            if self._storage_buffer[1] is not None:
-                result["storage_buffer_tokens"] = self._storage_buffer[1]
-        if self.estimated_norm_scaling_factor is not None:
-            result["estimated_norm_scaling_factor"] = torch.tensor(
-                self.estimated_norm_scaling_factor
-            )
-        return result
+        return {"n_dataset_processed": torch.tensor(self.n_dataset_processed)}

     def save(self, file_path: str):
         """save the state dict to a file in safetensors format"""
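With `__iter__`/`__next__` added, the store itself is the batch stream. A hedged usage sketch; construction of the store is elided and the loop body is illustrative only:

```python
# `store` is assumed to be an already-constructed ActivationsStore
# (e.g. via ActivationsStore.from_config(model, cfg)); shown for shape only.
for i, acts in enumerate(store):  # __iter__/__next__ drive the mixing buffer
    assert acts.shape[-1] == store.d_in
    if i >= 10:
        break

# next_batch() is now just an alias for next(store).
assert next(store).shape == store.next_batch().shape
```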
sae_lens/training/mixing_buffer.py ADDED

@@ -0,0 +1,56 @@
+from collections.abc import Iterator
+
+import torch
+
+
+@torch.no_grad()
+def mixing_buffer(
+    buffer_size: int,
+    batch_size: int,
+    activations_loader: Iterator[torch.Tensor],
+) -> Iterator[torch.Tensor]:
+    """
+    A generator that maintains a mix of old and new activations for better training.
+    It stores half of the activations and mixes them with new ones to create batches.
+
+    Args:
+        buffer_size: Total size of the buffer (will store buffer_size/2 activations)
+        batch_size: Size of batches to return
+        activations_loader: Iterator providing new activations
+
+    Yields:
+        Batches of activations of shape (batch_size, *activation_dims)
+    """
+
+    if buffer_size < batch_size:
+        raise ValueError("Buffer size must be greater than or equal to batch size")
+
+    storage_buffer: torch.Tensor | None = None
+
+    for new_activations in activations_loader:
+        storage_buffer = (
+            new_activations
+            if storage_buffer is None
+            else torch.cat([storage_buffer, new_activations], dim=0)
+        )
+
+        if storage_buffer.shape[0] >= buffer_size:
+            # Shuffle
+            storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
+
+            num_serving_batches = max(1, storage_buffer.shape[0] // (2 * batch_size))
+            serving_cutoff = num_serving_batches * batch_size
+            serving_buffer = storage_buffer[:serving_cutoff]
+            storage_buffer = storage_buffer[serving_cutoff:]
+
+            # Yield batches from the serving_buffer
+            for batch_idx in range(num_serving_batches):
+                yield serving_buffer[
+                    batch_idx * batch_size : (batch_idx + 1) * batch_size
+                ]
+
+    # If there are any remaining activations, yield them
+    if storage_buffer is not None:
+        remaining_batches = storage_buffer.shape[0] // batch_size
+        for i in range(remaining_batches):
+            yield storage_buffer[i * batch_size : (i + 1) * batch_size]
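A self-contained usage sketch of the new `mixing_buffer` generator; the sizes and the synthetic activation chunks are illustrative:

```python
import torch

from sae_lens.training.mixing_buffer import mixing_buffer

d_in = 16


# Stand-in for the store's activation stream: chunks of 256 activation vectors.
def fake_activation_chunks(n_chunks: int = 20):
    for _ in range(n_chunks):
        yield torch.randn(256, d_in)


batches = mixing_buffer(
    buffer_size=1024,  # accumulate at least this many rows before serving
    batch_size=64,     # each yielded batch has 64 rows
    activations_loader=fake_activation_chunks(),
)

first = next(batches)
assert first.shape == (64, d_in)

# Roughly half of each full buffer is served; the rest is kept, mixed with
# the next incoming chunks, reshuffled, and served later.
total = first.shape[0] + sum(b.shape[0] for b in batches)
assert total <= 20 * 256
```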