sae-lens 6.27.1__py3-none-any.whl → 6.27.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +1 -1
- sae_lens/cache_activations_runner.py +12 -5
- sae_lens/training/activations_store.py +47 -91
- sae_lens/training/mixing_buffer.py +4 -3
- {sae_lens-6.27.1.dist-info → sae_lens-6.27.3.dist-info}/METADATA +1 -1
- {sae_lens-6.27.1.dist-info → sae_lens-6.27.3.dist-info}/RECORD +8 -8
- {sae_lens-6.27.1.dist-info → sae_lens-6.27.3.dist-info}/WHEEL +0 -0
- {sae_lens-6.27.1.dist-info → sae_lens-6.27.3.dist-info}/licenses/LICENSE +0 -0
sae_lens/__init__.py
CHANGED
|
@@ -263,14 +263,21 @@ class CacheActivationsRunner:
|
|
|
263
263
|
|
|
264
264
|
for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
|
|
265
265
|
try:
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
)
|
|
269
|
-
|
|
266
|
+
# Accumulate n_batches_in_buffer batches into one shard
|
|
267
|
+
buffers: list[tuple[torch.Tensor, torch.Tensor | None]] = []
|
|
268
|
+
for _ in range(self.cfg.n_batches_in_buffer):
|
|
269
|
+
buffers.append(self.activations_store.get_raw_llm_batch())
|
|
270
|
+
# Concatenate all batches
|
|
271
|
+
acts = torch.cat([b[0] for b in buffers], dim=0)
|
|
272
|
+
token_ids: torch.Tensor | None = None
|
|
273
|
+
if buffers[0][1] is not None:
|
|
274
|
+
# All batches have token_ids if the first one does
|
|
275
|
+
token_ids = torch.cat([b[1] for b in buffers], dim=0) # type: ignore[arg-type]
|
|
276
|
+
shard = self._create_shard((acts, token_ids))
|
|
270
277
|
shard.save_to_disk(
|
|
271
278
|
f"{tmp_cached_activation_path}/shard_{i:05d}", num_shards=1
|
|
272
279
|
)
|
|
273
|
-
del
|
|
280
|
+
del buffers, acts, token_ids, shard
|
|
274
281
|
except StopIteration:
|
|
275
282
|
logger.warning(
|
|
276
283
|
f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.cfg.n_buffers} batches."
|
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
import json
|
|
4
4
|
import os
|
|
5
5
|
import warnings
|
|
6
|
-
from collections.abc import Generator, Iterator
|
|
6
|
+
from collections.abc import Generator, Iterator
|
|
7
7
|
from pathlib import Path
|
|
8
8
|
from typing import Any, Literal, cast
|
|
9
9
|
|
|
@@ -254,7 +254,6 @@ class ActivationsStore:
|
|
|
254
254
|
self.context_size = context_size
|
|
255
255
|
self.d_in = d_in
|
|
256
256
|
self.n_batches_in_buffer = n_batches_in_buffer
|
|
257
|
-
self.half_buffer_size = n_batches_in_buffer // 2
|
|
258
257
|
self.total_training_tokens = total_training_tokens
|
|
259
258
|
self.store_batch_size_prompts = store_batch_size_prompts
|
|
260
259
|
self.train_batch_size_tokens = train_batch_size_tokens
|
|
@@ -538,18 +537,15 @@ class ActivationsStore:
|
|
|
538
537
|
|
|
539
538
|
return stacked_activations
|
|
540
539
|
|
|
541
|
-
def
|
|
540
|
+
def _load_raw_llm_batch_from_cached(
|
|
542
541
|
self,
|
|
543
|
-
total_size: int,
|
|
544
|
-
context_size: int,
|
|
545
|
-
d_in: int,
|
|
546
542
|
raise_on_epoch_end: bool,
|
|
547
543
|
) -> tuple[
|
|
548
544
|
torch.Tensor,
|
|
549
545
|
torch.Tensor | None,
|
|
550
546
|
]:
|
|
551
547
|
"""
|
|
552
|
-
Loads
|
|
548
|
+
Loads a batch of activations from `cached_activation_dataset`
|
|
553
549
|
|
|
554
550
|
The dataset has columns for each hook_name,
|
|
555
551
|
each containing activations of shape (context_size, d_in).
|
|
@@ -557,6 +553,10 @@ class ActivationsStore:
|
|
|
557
553
|
raises StopIteration
|
|
558
554
|
"""
|
|
559
555
|
assert self.cached_activation_dataset is not None
|
|
556
|
+
context_size = self.context_size
|
|
557
|
+
batch_size = self.store_batch_size_prompts
|
|
558
|
+
d_in = self.d_in
|
|
559
|
+
|
|
560
560
|
# In future, could be a list of multiple hook names
|
|
561
561
|
if self.hook_name not in self.cached_activation_dataset.column_names:
|
|
562
562
|
raise ValueError(
|
|
@@ -564,138 +564,100 @@ class ActivationsStore:
|
|
|
564
564
|
f"got {self.cached_activation_dataset.column_names}."
|
|
565
565
|
)
|
|
566
566
|
|
|
567
|
-
if self.current_row_idx > len(self.cached_activation_dataset) -
|
|
567
|
+
if self.current_row_idx > len(self.cached_activation_dataset) - batch_size:
|
|
568
568
|
self.current_row_idx = 0
|
|
569
569
|
if raise_on_epoch_end:
|
|
570
570
|
raise StopIteration
|
|
571
571
|
|
|
572
|
-
new_buffer = []
|
|
573
572
|
ds_slice = self.cached_activation_dataset[
|
|
574
|
-
self.current_row_idx : self.current_row_idx +
|
|
573
|
+
self.current_row_idx : self.current_row_idx + batch_size
|
|
575
574
|
]
|
|
576
575
|
# Load activations for each hook.
|
|
577
576
|
# Usually faster to first slice dataset then pick column
|
|
578
|
-
|
|
579
|
-
if
|
|
577
|
+
acts_buffer = ds_slice[self.hook_name]
|
|
578
|
+
if acts_buffer.shape != (batch_size, context_size, d_in):
|
|
580
579
|
raise ValueError(
|
|
581
|
-
f"
|
|
582
|
-
f"but expected ({
|
|
580
|
+
f"acts_buffer has shape {acts_buffer.shape}, "
|
|
581
|
+
f"but expected ({batch_size}, {context_size}, {d_in})."
|
|
583
582
|
)
|
|
584
583
|
|
|
585
|
-
self.current_row_idx +=
|
|
586
|
-
acts_buffer =
|
|
584
|
+
self.current_row_idx += batch_size
|
|
585
|
+
acts_buffer = acts_buffer.reshape(batch_size * context_size, d_in)
|
|
587
586
|
|
|
588
587
|
if "token_ids" not in self.cached_activation_dataset.column_names:
|
|
589
588
|
return acts_buffer, None
|
|
590
589
|
|
|
591
590
|
token_ids_buffer = ds_slice["token_ids"]
|
|
592
|
-
if token_ids_buffer.shape != (
|
|
591
|
+
if token_ids_buffer.shape != (batch_size, context_size):
|
|
593
592
|
raise ValueError(
|
|
594
593
|
f"token_ids_buffer has shape {token_ids_buffer.shape}, "
|
|
595
|
-
f"but expected ({
|
|
594
|
+
f"but expected ({batch_size}, {context_size})."
|
|
596
595
|
)
|
|
597
|
-
token_ids_buffer = token_ids_buffer.reshape(
|
|
596
|
+
token_ids_buffer = token_ids_buffer.reshape(batch_size * context_size)
|
|
598
597
|
return acts_buffer, token_ids_buffer
|
|
599
598
|
|
|
600
599
|
@torch.no_grad()
|
|
601
|
-
def
|
|
600
|
+
def get_raw_llm_batch(
|
|
602
601
|
self,
|
|
603
|
-
n_batches_in_buffer: int,
|
|
604
602
|
raise_on_epoch_end: bool = False,
|
|
605
|
-
shuffle: bool = True,
|
|
606
603
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
|
607
604
|
"""
|
|
608
|
-
Loads the next
|
|
605
|
+
Loads the next batch of activations from the LLM and returns it.
|
|
609
606
|
|
|
610
|
-
|
|
607
|
+
If raise_on_epoch_end is True, when the dataset is exhausted it will
|
|
608
|
+
automatically refill the dataset and then raise a StopIteration so that
|
|
609
|
+
the caller has a chance to react.
|
|
611
610
|
|
|
612
|
-
|
|
611
|
+
Returns:
|
|
612
|
+
Tuple of (activations, token_ids) where activations has shape
|
|
613
|
+
(batch_size * context_size, d_in) and token_ids has shape
|
|
614
|
+
(batch_size * context_size,).
|
|
613
615
|
"""
|
|
614
|
-
context_size = self.context_size
|
|
615
|
-
batch_size = self.store_batch_size_prompts
|
|
616
616
|
d_in = self.d_in
|
|
617
|
-
total_size = batch_size * n_batches_in_buffer
|
|
618
617
|
|
|
619
618
|
if self.cached_activation_dataset is not None:
|
|
620
|
-
return self.
|
|
621
|
-
total_size, context_size, d_in, raise_on_epoch_end
|
|
622
|
-
)
|
|
619
|
+
return self._load_raw_llm_batch_from_cached(raise_on_epoch_end)
|
|
623
620
|
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
(total_size, self.training_context_size, d_in),
|
|
628
|
-
dtype=self.dtype, # type: ignore
|
|
629
|
-
device=self.device,
|
|
630
|
-
)
|
|
631
|
-
new_buffer_token_ids = torch.zeros(
|
|
632
|
-
(total_size, self.training_context_size),
|
|
633
|
-
dtype=torch.long,
|
|
634
|
-
device=self.device,
|
|
621
|
+
# move batch toks to gpu for model
|
|
622
|
+
batch_tokens = self.get_batch_tokens(raise_at_epoch_end=raise_on_epoch_end).to(
|
|
623
|
+
_get_model_device(self.model)
|
|
635
624
|
)
|
|
625
|
+
activations = self.get_activations(batch_tokens).to(self.device)
|
|
636
626
|
|
|
637
|
-
for
|
|
638
|
-
|
|
639
|
-
):
|
|
640
|
-
# move batch toks to gpu for model
|
|
641
|
-
refill_batch_tokens = self.get_batch_tokens(
|
|
642
|
-
raise_at_epoch_end=raise_on_epoch_end
|
|
643
|
-
).to(_get_model_device(self.model))
|
|
644
|
-
refill_activations = self.get_activations(refill_batch_tokens)
|
|
645
|
-
# move acts back to cpu
|
|
646
|
-
refill_activations.to(self.device)
|
|
647
|
-
new_buffer_activations[
|
|
648
|
-
refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
|
|
649
|
-
] = refill_activations
|
|
650
|
-
|
|
651
|
-
# handle seqpos_slice, this is done for activations in get_activations
|
|
652
|
-
refill_batch_tokens = refill_batch_tokens[:, slice(*self.seqpos_slice)]
|
|
653
|
-
new_buffer_token_ids[
|
|
654
|
-
refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
|
|
655
|
-
] = refill_batch_tokens
|
|
656
|
-
|
|
657
|
-
new_buffer_activations = new_buffer_activations.reshape(-1, d_in)
|
|
658
|
-
new_buffer_token_ids = new_buffer_token_ids.reshape(-1)
|
|
659
|
-
if shuffle:
|
|
660
|
-
new_buffer_activations, new_buffer_token_ids = permute_together(
|
|
661
|
-
[new_buffer_activations, new_buffer_token_ids]
|
|
662
|
-
)
|
|
627
|
+
# handle seqpos_slice, this is done for activations in get_activations
|
|
628
|
+
batch_tokens = batch_tokens[:, slice(*self.seqpos_slice)]
|
|
663
629
|
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
)
|
|
630
|
+
# reshape from (batch, context, d_in) to (batch * context, d_in)
|
|
631
|
+
activations = activations.reshape(-1, d_in)
|
|
632
|
+
token_ids = batch_tokens.reshape(-1)
|
|
668
633
|
|
|
669
|
-
|
|
634
|
+
return activations, token_ids
|
|
635
|
+
|
|
636
|
+
def get_filtered_llm_batch(
|
|
670
637
|
self,
|
|
671
|
-
n_batches_in_buffer: int,
|
|
672
638
|
raise_on_epoch_end: bool = False,
|
|
673
|
-
shuffle: bool = True,
|
|
674
639
|
) -> torch.Tensor:
|
|
640
|
+
"""
|
|
641
|
+
Get a batch of LLM activations with special tokens filtered out.
|
|
642
|
+
"""
|
|
675
643
|
return _filter_buffer_acts(
|
|
676
|
-
self.
|
|
677
|
-
n_batches_in_buffer=n_batches_in_buffer,
|
|
678
|
-
raise_on_epoch_end=raise_on_epoch_end,
|
|
679
|
-
shuffle=shuffle,
|
|
680
|
-
),
|
|
644
|
+
self.get_raw_llm_batch(raise_on_epoch_end=raise_on_epoch_end),
|
|
681
645
|
self.exclude_special_tokens,
|
|
682
646
|
)
|
|
683
647
|
|
|
684
648
|
def _iterate_filtered_activations(self) -> Generator[torch.Tensor, None, None]:
|
|
685
649
|
"""
|
|
686
|
-
Iterate over
|
|
650
|
+
Iterate over filtered LLM activation batches.
|
|
687
651
|
"""
|
|
688
652
|
while True:
|
|
689
653
|
try:
|
|
690
|
-
yield self.
|
|
691
|
-
self.half_buffer_size, raise_on_epoch_end=True
|
|
692
|
-
)
|
|
654
|
+
yield self.get_filtered_llm_batch(raise_on_epoch_end=True)
|
|
693
655
|
except StopIteration:
|
|
694
656
|
warnings.warn(
|
|
695
657
|
"All samples in the training dataset have been exhausted, beginning new epoch."
|
|
696
658
|
)
|
|
697
659
|
try:
|
|
698
|
-
yield self.
|
|
660
|
+
yield self.get_filtered_llm_batch()
|
|
699
661
|
except StopIteration:
|
|
700
662
|
raise ValueError(
|
|
701
663
|
"Unable to fill buffer after starting new epoch. Dataset may be too small."
|
|
@@ -827,9 +789,3 @@ def _filter_buffer_acts(
|
|
|
827
789
|
|
|
828
790
|
mask = torch.isin(tokens, exclude_tokens)
|
|
829
791
|
return activations[~mask]
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
def permute_together(tensors: Sequence[torch.Tensor]) -> tuple[torch.Tensor, ...]:
|
|
833
|
-
"""Permute tensors together."""
|
|
834
|
-
permutation = torch.randperm(tensors[0].shape[0])
|
|
835
|
-
return tuple(t[permutation] for t in tensors)
|
|
@@ -44,9 +44,10 @@ def mixing_buffer(
|
|
|
44
44
|
if mix_fraction > 0:
|
|
45
45
|
storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
|
|
46
46
|
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
47
|
+
# Keep a fixed amount for mixing, serve the rest
|
|
48
|
+
keep_for_mixing = int(buffer_size * mix_fraction)
|
|
49
|
+
num_to_serve = storage_buffer.shape[0] - keep_for_mixing
|
|
50
|
+
num_serving_batches = max(1, num_to_serve // batch_size)
|
|
50
51
|
serving_cutoff = num_serving_batches * batch_size
|
|
51
52
|
serving_buffer = storage_buffer[:serving_cutoff]
|
|
52
53
|
storage_buffer = storage_buffer[serving_cutoff:]
|
|
@@ -1,8 +1,8 @@
|
|
|
1
|
-
sae_lens/__init__.py,sha256=
|
|
1
|
+
sae_lens/__init__.py,sha256=ETLfd3PmdJ2aAaKyeTTHptBE2HaWY0OfzOKNk7dyhKE,4725
|
|
2
2
|
sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
|
|
4
4
|
sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
|
|
5
|
-
sae_lens/cache_activations_runner.py,sha256=
|
|
5
|
+
sae_lens/cache_activations_runner.py,sha256=TjqNWIc46Nw09jHWFjzQzgzG5wdu_87Ahe-iFjI5_0Q,13117
|
|
6
6
|
sae_lens/config.py,sha256=sseYcRMsAyopj8FICup1RGTXjFxzAithZ2OH7OpQV3Y,30839
|
|
7
7
|
sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
|
|
8
8
|
sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
|
|
@@ -28,15 +28,15 @@ sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,1
|
|
|
28
28
|
sae_lens/tokenization_and_batching.py,sha256=uoHtAs9z3XqG0Fh-iQVYVlrbyB_E3kFFhrKU30BosCo,5438
|
|
29
29
|
sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
30
|
sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
|
|
31
|
-
sae_lens/training/activations_store.py,sha256=
|
|
32
|
-
sae_lens/training/mixing_buffer.py,sha256=
|
|
31
|
+
sae_lens/training/activations_store.py,sha256=kp4-6R4rTJUSt-g-Ifg5B1h7iIe7jZj-XQSKDvDpQMI,32187
|
|
32
|
+
sae_lens/training/mixing_buffer.py,sha256=1Z-S2CcQXMWGxRZJFnXeZFxbZcALkO_fP6VO37XdJQQ,2519
|
|
33
33
|
sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
|
|
34
34
|
sae_lens/training/sae_trainer.py,sha256=zhkabyIKxI_tZTV3_kwz6zMrHZ95Ecr97krmwc-9ffs,17600
|
|
35
35
|
sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
|
|
36
36
|
sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
|
|
37
37
|
sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
|
|
38
38
|
sae_lens/util.py,sha256=spkcmQUsjVYFn5H2032nQYr1CKGVnv3tAdfIpY59-Mg,3919
|
|
39
|
-
sae_lens-6.27.
|
|
40
|
-
sae_lens-6.27.
|
|
41
|
-
sae_lens-6.27.
|
|
42
|
-
sae_lens-6.27.
|
|
39
|
+
sae_lens-6.27.3.dist-info/METADATA,sha256=c59mjyoausFHs1bd8n_4J6dA-2uDRPgY9Wwas52zydw,5361
|
|
40
|
+
sae_lens-6.27.3.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
|
|
41
|
+
sae_lens-6.27.3.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
|
|
42
|
+
sae_lens-6.27.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|