sae_lens-6.27.2-py3-none-any.whl → sae_lens-6.27.3-py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.
sae_lens/__init__.py CHANGED
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.27.2"
+__version__ = "6.27.3"
 
 import logging
 
sae_lens/cache_activations_runner.py CHANGED
@@ -263,14 +263,21 @@ class CacheActivationsRunner:
 
         for i in tqdm(range(self.cfg.n_buffers), desc="Caching activations"):
             try:
-                buffer = self.activations_store.get_raw_buffer(
-                    self.cfg.n_batches_in_buffer, shuffle=False
-                )
-                shard = self._create_shard(buffer)
+                # Accumulate n_batches_in_buffer batches into one shard
+                buffers: list[tuple[torch.Tensor, torch.Tensor | None]] = []
+                for _ in range(self.cfg.n_batches_in_buffer):
+                    buffers.append(self.activations_store.get_raw_llm_batch())
+                # Concatenate all batches
+                acts = torch.cat([b[0] for b in buffers], dim=0)
+                token_ids: torch.Tensor | None = None
+                if buffers[0][1] is not None:
+                    # All batches have token_ids if the first one does
+                    token_ids = torch.cat([b[1] for b in buffers], dim=0)  # type: ignore[arg-type]
+                shard = self._create_shard((acts, token_ids))
                 shard.save_to_disk(
                     f"{tmp_cached_activation_path}/shard_{i:05d}", num_shards=1
                 )
-                del buffer, shard
+                del buffers, acts, token_ids, shard
             except StopIteration:
                 logger.warning(
                     f"Warning: Ran out of samples while filling the buffer at batch {i} before reaching {self.cfg.n_buffers} batches."
sae_lens/training/activations_store.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
 import json
 import os
 import warnings
-from collections.abc import Generator, Iterator, Sequence
+from collections.abc import Generator, Iterator
 from pathlib import Path
 from typing import Any, Literal, cast
 
@@ -254,7 +254,6 @@ class ActivationsStore:
         self.context_size = context_size
         self.d_in = d_in
         self.n_batches_in_buffer = n_batches_in_buffer
-        self.half_buffer_size = n_batches_in_buffer // 2
         self.total_training_tokens = total_training_tokens
         self.store_batch_size_prompts = store_batch_size_prompts
         self.train_batch_size_tokens = train_batch_size_tokens
@@ -538,18 +537,15 @@ class ActivationsStore:
 
         return stacked_activations
 
-    def _load_buffer_from_cached(
+    def _load_raw_llm_batch_from_cached(
         self,
-        total_size: int,
-        context_size: int,
-        d_in: int,
         raise_on_epoch_end: bool,
     ) -> tuple[
         torch.Tensor,
         torch.Tensor | None,
     ]:
         """
-        Loads `total_size` activations from `cached_activation_dataset`
+        Loads a batch of activations from `cached_activation_dataset`
 
         The dataset has columns for each hook_name,
         each containing activations of shape (context_size, d_in).
@@ -557,6 +553,10 @@ class ActivationsStore:
         raises StopIteration
         """
         assert self.cached_activation_dataset is not None
+        context_size = self.context_size
+        batch_size = self.store_batch_size_prompts
+        d_in = self.d_in
+
         # In future, could be a list of multiple hook names
         if self.hook_name not in self.cached_activation_dataset.column_names:
             raise ValueError(
@@ -564,138 +564,100 @@
                 f"got {self.cached_activation_dataset.column_names}."
             )
 
-        if self.current_row_idx > len(self.cached_activation_dataset) - total_size:
+        if self.current_row_idx > len(self.cached_activation_dataset) - batch_size:
             self.current_row_idx = 0
             if raise_on_epoch_end:
                 raise StopIteration
 
-        new_buffer = []
         ds_slice = self.cached_activation_dataset[
-            self.current_row_idx : self.current_row_idx + total_size
+            self.current_row_idx : self.current_row_idx + batch_size
         ]
         # Load activations for each hook.
         # Usually faster to first slice dataset then pick column
-        new_buffer = ds_slice[self.hook_name]
-        if new_buffer.shape != (total_size, context_size, d_in):
+        acts_buffer = ds_slice[self.hook_name]
+        if acts_buffer.shape != (batch_size, context_size, d_in):
             raise ValueError(
-                f"new_buffer has shape {new_buffer.shape}, "
-                f"but expected ({total_size}, {context_size}, {d_in})."
+                f"acts_buffer has shape {acts_buffer.shape}, "
+                f"but expected ({batch_size}, {context_size}, {d_in})."
             )
 
-        self.current_row_idx += total_size
-        acts_buffer = new_buffer.reshape(total_size * context_size, d_in)
+        self.current_row_idx += batch_size
+        acts_buffer = acts_buffer.reshape(batch_size * context_size, d_in)
 
         if "token_ids" not in self.cached_activation_dataset.column_names:
             return acts_buffer, None
 
         token_ids_buffer = ds_slice["token_ids"]
-        if token_ids_buffer.shape != (total_size, context_size):
+        if token_ids_buffer.shape != (batch_size, context_size):
             raise ValueError(
                 f"token_ids_buffer has shape {token_ids_buffer.shape}, "
-                f"but expected ({total_size}, {context_size})."
+                f"but expected ({batch_size}, {context_size})."
             )
-        token_ids_buffer = token_ids_buffer.reshape(total_size * context_size)
+        token_ids_buffer = token_ids_buffer.reshape(batch_size * context_size)
         return acts_buffer, token_ids_buffer
 
     @torch.no_grad()
-    def get_raw_buffer(
+    def get_raw_llm_batch(
         self,
-        n_batches_in_buffer: int,
         raise_on_epoch_end: bool = False,
-        shuffle: bool = True,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         """
-        Loads the next n_batches_in_buffer batches of activations into a tensor and returns it.
+        Loads the next batch of activations from the LLM and returns it.
 
-        The primary purpose here is maintaining a shuffling buffer.
+        If raise_on_epoch_end is True, when the dataset is exhausted it will
+        automatically refill the dataset and then raise a StopIteration so that
+        the caller has a chance to react.
 
-        If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
+        Returns:
+            Tuple of (activations, token_ids) where activations has shape
+            (batch_size * context_size, d_in) and token_ids has shape
+            (batch_size * context_size,).
         """
-        context_size = self.context_size
-        batch_size = self.store_batch_size_prompts
         d_in = self.d_in
-        total_size = batch_size * n_batches_in_buffer
 
         if self.cached_activation_dataset is not None:
-            return self._load_buffer_from_cached(
-                total_size, context_size, d_in, raise_on_epoch_end
-            )
+            return self._load_raw_llm_batch_from_cached(raise_on_epoch_end)
 
-        refill_iterator = range(0, total_size, batch_size)
-        # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
-        new_buffer_activations = torch.zeros(
-            (total_size, self.training_context_size, d_in),
-            dtype=self.dtype,  # type: ignore
-            device=self.device,
-        )
-        new_buffer_token_ids = torch.zeros(
-            (total_size, self.training_context_size),
-            dtype=torch.long,
-            device=self.device,
+        # move batch toks to gpu for model
+        batch_tokens = self.get_batch_tokens(raise_at_epoch_end=raise_on_epoch_end).to(
+            _get_model_device(self.model)
         )
+        activations = self.get_activations(batch_tokens).to(self.device)
 
-        for refill_batch_idx_start in tqdm(
-            refill_iterator, leave=False, desc="Refilling buffer"
-        ):
-            # move batch toks to gpu for model
-            refill_batch_tokens = self.get_batch_tokens(
-                raise_at_epoch_end=raise_on_epoch_end
-            ).to(_get_model_device(self.model))
-            refill_activations = self.get_activations(refill_batch_tokens)
-            # move acts back to cpu
-            refill_activations.to(self.device)
-            new_buffer_activations[
-                refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
-            ] = refill_activations
-
-            # handle seqpos_slice, this is done for activations in get_activations
-            refill_batch_tokens = refill_batch_tokens[:, slice(*self.seqpos_slice)]
-            new_buffer_token_ids[
-                refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
-            ] = refill_batch_tokens
-
-        new_buffer_activations = new_buffer_activations.reshape(-1, d_in)
-        new_buffer_token_ids = new_buffer_token_ids.reshape(-1)
-        if shuffle:
-            new_buffer_activations, new_buffer_token_ids = permute_together(
-                [new_buffer_activations, new_buffer_token_ids]
-            )
+        # handle seqpos_slice, this is done for activations in get_activations
+        batch_tokens = batch_tokens[:, slice(*self.seqpos_slice)]
 
-        return (
-            new_buffer_activations,
-            new_buffer_token_ids,
-        )
+        # reshape from (batch, context, d_in) to (batch * context, d_in)
+        activations = activations.reshape(-1, d_in)
+        token_ids = batch_tokens.reshape(-1)
 
-    def get_filtered_buffer(
+        return activations, token_ids
+
+    def get_filtered_llm_batch(
         self,
-        n_batches_in_buffer: int,
         raise_on_epoch_end: bool = False,
-        shuffle: bool = True,
     ) -> torch.Tensor:
+        """
+        Get a batch of LLM activations with special tokens filtered out.
+        """
        return _filter_buffer_acts(
-            self.get_raw_buffer(
-                n_batches_in_buffer=n_batches_in_buffer,
-                raise_on_epoch_end=raise_on_epoch_end,
-                shuffle=shuffle,
-            ),
+            self.get_raw_llm_batch(raise_on_epoch_end=raise_on_epoch_end),
             self.exclude_special_tokens,
         )
 
     def _iterate_filtered_activations(self) -> Generator[torch.Tensor, None, None]:
         """
-        Iterate over the filtered tokens in the buffer.
+        Iterate over filtered LLM activation batches.
         """
         while True:
             try:
-                yield self.get_filtered_buffer(
-                    self.half_buffer_size, raise_on_epoch_end=True
-                )
+                yield self.get_filtered_llm_batch(raise_on_epoch_end=True)
             except StopIteration:
                 warnings.warn(
                     "All samples in the training dataset have been exhausted, beginning new epoch."
                 )
                 try:
-                    yield self.get_filtered_buffer(self.half_buffer_size)
+                    yield self.get_filtered_llm_batch()
                 except StopIteration:
                     raise ValueError(
                         "Unable to fill buffer after starting new epoch. Dataset may be too small."
@@ -827,9 +789,3 @@ def _filter_buffer_acts(
 
     mask = torch.isin(tokens, exclude_tokens)
     return activations[~mask]
-
-
-def permute_together(tensors: Sequence[torch.Tensor]) -> tuple[torch.Tensor, ...]:
-    """Permute tensors together."""
-    permutation = torch.randperm(tensors[0].shape[0])
-    return tuple(t[permutation] for t in tensors)
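
For context, the deleted permute_together helper applied one shared random permutation to several tensors so that activations and their token ids stayed aligned while being shuffled; with the shuffle flag removed from get_raw_llm_batch, that shuffling no longer happens in this module. A standalone sketch of what the removed helper did:

import torch

def permute_together(tensors: list[torch.Tensor]) -> tuple[torch.Tensor, ...]:
    # One permutation shared by every tensor keeps rows aligned across them.
    permutation = torch.randperm(tensors[0].shape[0])
    return tuple(t[permutation] for t in tensors)

acts = torch.randn(6, 3)
token_ids = torch.arange(6)
shuffled_acts, shuffled_ids = permute_together([acts, token_ids])
# Row k of shuffled_acts still corresponds to token shuffled_ids[k].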
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.27.2
+Version: 6.27.3
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE
@@ -1,8 +1,8 @@
-sae_lens/__init__.py,sha256=Y5z21KL2t1NAxsyEAKPnjrhXiDdK4M_gedq_7gwrmD0,4725
+sae_lens/__init__.py,sha256=ETLfd3PmdJ2aAaKyeTTHptBE2HaWY0OfzOKNk7dyhKE,4725
 sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
 sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
-sae_lens/cache_activations_runner.py,sha256=Lvlz-k5-3XxVRtUdC4b1CiKyx5s0ckLa8GDGv9_kcxs,12566
+sae_lens/cache_activations_runner.py,sha256=TjqNWIc46Nw09jHWFjzQzgzG5wdu_87Ahe-iFjI5_0Q,13117
 sae_lens/config.py,sha256=sseYcRMsAyopj8FICup1RGTXjFxzAithZ2OH7OpQV3Y,30839
 sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
 sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
@@ -28,7 +28,7 @@ sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,1
 sae_lens/tokenization_and_batching.py,sha256=uoHtAs9z3XqG0Fh-iQVYVlrbyB_E3kFFhrKU30BosCo,5438
 sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
-sae_lens/training/activations_store.py,sha256=2BVajHRcozKQFf1tkeraUCdFuut3spdk0hhgtdpizzI,34031
+sae_lens/training/activations_store.py,sha256=kp4-6R4rTJUSt-g-Ifg5B1h7iIe7jZj-XQSKDvDpQMI,32187
 sae_lens/training/mixing_buffer.py,sha256=1Z-S2CcQXMWGxRZJFnXeZFxbZcALkO_fP6VO37XdJQQ,2519
 sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
 sae_lens/training/sae_trainer.py,sha256=zhkabyIKxI_tZTV3_kwz6zMrHZ95Ecr97krmwc-9ffs,17600
@@ -36,7 +36,7 @@ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
 sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
 sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
 sae_lens/util.py,sha256=spkcmQUsjVYFn5H2032nQYr1CKGVnv3tAdfIpY59-Mg,3919
-sae_lens-6.27.2.dist-info/METADATA,sha256=O2gl2BGnUAEHUScAel03ovRp2TSf7rZiUG66x97sBBs,5361
-sae_lens-6.27.2.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
-sae_lens-6.27.2.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
-sae_lens-6.27.2.dist-info/RECORD,,
+sae_lens-6.27.3.dist-info/METADATA,sha256=c59mjyoausFHs1bd8n_4J6dA-2uDRPgY9Wwas52zydw,5361
+sae_lens-6.27.3.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
+sae_lens-6.27.3.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.27.3.dist-info/RECORD,,