sae-lens 6.26.1__py3-none-any.whl → 6.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,145 @@
1
+ from collections.abc import Iterator
2
+ from pathlib import Path
3
+ from typing import Any, Callable
4
+
5
+ import torch
6
+
7
+ from sae_lens.config import LoggingConfig, SAETrainerConfig
8
+ from sae_lens.saes.sae import TrainingSAE
9
+ from sae_lens.synthetic.activation_generator import ActivationGenerator
10
+ from sae_lens.synthetic.feature_dictionary import FeatureDictionary
11
+ from sae_lens.training.sae_trainer import SAETrainer, SaveCheckpointFn
12
+
13
+
14
def train_toy_sae(
    sae: TrainingSAE[Any],
    feature_dict: FeatureDictionary,
    activations_generator: ActivationGenerator,
    training_samples: int = 10_000_000,
    batch_size: int = 1024,
    lr: float = 3e-4,
    lr_warm_up_steps: int = 0,
    lr_decay_steps: int = 0,
    device: str | torch.device = "cpu",
    n_snapshots: int = 0,
    snapshot_fn: Callable[[SAETrainer[Any, Any]], None] | None = None,
) -> None:
    """
    Train an SAE on synthetic activations from a feature dictionary.

    This is a convenience function that sets up the training loop with
    sensible defaults for small-scale synthetic data experiments.

    Args:
        sae: The TrainingSAE to train
        feature_dict: The feature dictionary that maps feature activations to
            hidden activations
        activations_generator: Generator that produces feature activations
        training_samples: Total number of training samples
        batch_size: Batch size for training
        lr: Learning rate
        lr_warm_up_steps: Number of warmup steps for learning rate
        lr_decay_steps: Number of steps over which to decay learning rate
        device: Device to train on
        n_snapshots: Number of snapshots to take during training. Snapshots are
            evenly spaced throughout training.
        snapshot_fn: Callback function called at each snapshot point. Receives
            the SAETrainer instance, allowing access to the SAE, training step,
            and other training state. Required if n_snapshots > 0.

    Raises:
        ValueError: If n_snapshots > 0 but no snapshot_fn is provided.
    """
    # Fail fast: without this check, the missing callback would only be
    # discovered inside the checkpoint hook at the first snapshot point,
    # potentially well into a long training run.
    if n_snapshots > 0 and snapshot_fn is None:
        raise ValueError("snapshot_fn must be provided to take snapshots")

    device_str = str(device) if isinstance(device, torch.device) else device

    # Create data iterator
    data_iterator = SyntheticActivationIterator(
        feature_dict=feature_dict,
        activations_generator=activations_generator,
        batch_size=batch_size,
    )

    # Create trainer config. Checkpoints double as "snapshots" here: no files
    # are written (checkpoint_path=None), only the snapshot callback fires.
    trainer_cfg = SAETrainerConfig(
        n_checkpoints=n_snapshots,
        checkpoint_path=None,
        save_final_checkpoint=False,
        total_training_samples=training_samples,
        device=device_str,
        autocast=False,
        lr=lr,
        lr_end=lr,
        lr_scheduler_name="constant",
        lr_warm_up_steps=lr_warm_up_steps,
        adam_beta1=0.9,
        adam_beta2=0.999,
        lr_decay_steps=lr_decay_steps,
        n_restart_cycles=1,
        train_batch_size_samples=batch_size,
        dead_feature_window=1000,
        feature_sampling_window=2000,
        logger=LoggingConfig(
            log_to_wandb=False,
            # hacky way to disable evals, but works for now
            eval_every_n_wandb_logs=2**31 - 1,
        ),
    )

    def snapshot_wrapper(
        snapshot_fn: Callable[[SAETrainer[Any, Any]], None] | None,
    ) -> SaveCheckpointFn:
        # Adapts the user-facing snapshot callback to the trainer's
        # save-checkpoint hook signature; the checkpoint path is ignored.
        def save_checkpoint(checkpoint_path: Path | None) -> None:  # noqa: ARG001
            # Defensive re-check; the entry validation above should make this
            # unreachable in normal use.
            if snapshot_fn is None:
                raise ValueError("snapshot_fn must be provided to take snapshots")
            snapshot_fn(trainer)

        return save_checkpoint

    # Create trainer and train. The feature dictionary is frozen (eval mode)
    # since only the SAE is being optimized.
    feature_dict.eval()
    trainer = SAETrainer(
        cfg=trainer_cfg,
        sae=sae,
        data_provider=data_iterator,
        save_checkpoint_fn=snapshot_wrapper(snapshot_fn),
    )

    trainer.fit()
106
+
107
+
108
class SyntheticActivationIterator(Iterator[torch.Tensor]):
    """
    Infinite iterator yielding synthetic hidden activations for SAE training.

    Each step draws a batch of feature activations from the supplied
    generator and passes them through the FeatureDictionary, yielding the
    resulting hidden activations.
    """

    def __init__(
        self,
        feature_dict: FeatureDictionary,
        activations_generator: ActivationGenerator,
        batch_size: int,
    ):
        """
        Create a new SyntheticActivationIterator.

        Args:
            feature_dict: The feature dictionary to use for generating hidden activations
            activations_generator: Generator that produces feature activations
            batch_size: Number of samples per batch
        """
        self.batch_size = batch_size
        self.activations_generator = activations_generator
        self.feature_dict = feature_dict

    @torch.no_grad()
    def next_batch(self) -> torch.Tensor:
        """Generate the next batch of hidden activations."""
        # Gradients are disabled: the feature dictionary is fixed data, not
        # a module being trained.
        feature_acts = self.activations_generator(self.batch_size)
        hidden_acts = self.feature_dict(feature_acts)
        return hidden_acts

    def __iter__(self) -> "SyntheticActivationIterator":
        return self

    def __next__(self) -> torch.Tensor:
        # Never raises StopIteration; synthetic data is unbounded.
        return self.next_batch()
@@ -85,8 +85,8 @@ def concat_and_batch_sequences(
85
85
  for sequence in tokens_iterator:
86
86
  if (
87
87
  begin_sequence_token_id is not None
88
- and sequence[0] != begin_sequence_token_id
89
88
  and len(sequence) >= context_size - 1
89
+ and sequence[0] != begin_sequence_token_id
90
90
  ):
91
91
  begin_sequence_token_id_tensor = torch.tensor(
92
92
  [begin_sequence_token_id],
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  import json
4
4
  import os
5
5
  import warnings
6
- from collections.abc import Generator, Iterator, Sequence
6
+ from collections.abc import Generator, Iterator
7
7
  from pathlib import Path
8
8
  from typing import Any, Literal, cast
9
9
 
@@ -148,6 +148,7 @@ class ActivationsStore:
148
148
  exclude_special_tokens=exclude_special_tokens,
149
149
  disable_concat_sequences=cfg.disable_concat_sequences,
150
150
  sequence_separator_token=cfg.sequence_separator_token,
151
+ activations_mixing_fraction=cfg.activations_mixing_fraction,
151
152
  )
152
153
 
153
154
  @classmethod
@@ -222,6 +223,7 @@ class ActivationsStore:
222
223
  exclude_special_tokens: torch.Tensor | None = None,
223
224
  disable_concat_sequences: bool = False,
224
225
  sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
226
+ activations_mixing_fraction: float = 0.5,
225
227
  ):
226
228
  self.model = model
227
229
  if model_kwargs is None:
@@ -252,7 +254,6 @@ class ActivationsStore:
252
254
  self.context_size = context_size
253
255
  self.d_in = d_in
254
256
  self.n_batches_in_buffer = n_batches_in_buffer
255
- self.half_buffer_size = n_batches_in_buffer // 2
256
257
  self.total_training_tokens = total_training_tokens
257
258
  self.store_batch_size_prompts = store_batch_size_prompts
258
259
  self.train_batch_size_tokens = train_batch_size_tokens
@@ -269,6 +270,7 @@ class ActivationsStore:
269
270
  self.sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
270
271
  sequence_separator_token
271
272
  )
273
+ self.activations_mixing_fraction = activations_mixing_fraction
272
274
 
273
275
  self.n_dataset_processed = 0
274
276
 
@@ -535,18 +537,15 @@ class ActivationsStore:
535
537
 
536
538
  return stacked_activations
537
539
 
538
- def _load_buffer_from_cached(
540
+ def _load_raw_llm_batch_from_cached(
539
541
  self,
540
- total_size: int,
541
- context_size: int,
542
- d_in: int,
543
542
  raise_on_epoch_end: bool,
544
543
  ) -> tuple[
545
544
  torch.Tensor,
546
545
  torch.Tensor | None,
547
546
  ]:
548
547
  """
549
- Loads `total_size` activations from `cached_activation_dataset`
548
+ Loads a batch of activations from `cached_activation_dataset`
550
549
 
551
550
  The dataset has columns for each hook_name,
552
551
  each containing activations of shape (context_size, d_in).
@@ -554,6 +553,10 @@ class ActivationsStore:
554
553
  raises StopIteration
555
554
  """
556
555
  assert self.cached_activation_dataset is not None
556
+ context_size = self.context_size
557
+ batch_size = self.store_batch_size_prompts
558
+ d_in = self.d_in
559
+
557
560
  # In future, could be a list of multiple hook names
558
561
  if self.hook_name not in self.cached_activation_dataset.column_names:
559
562
  raise ValueError(
@@ -561,138 +564,100 @@ class ActivationsStore:
561
564
  f"got {self.cached_activation_dataset.column_names}."
562
565
  )
563
566
 
564
- if self.current_row_idx > len(self.cached_activation_dataset) - total_size:
567
+ if self.current_row_idx > len(self.cached_activation_dataset) - batch_size:
565
568
  self.current_row_idx = 0
566
569
  if raise_on_epoch_end:
567
570
  raise StopIteration
568
571
 
569
- new_buffer = []
570
572
  ds_slice = self.cached_activation_dataset[
571
- self.current_row_idx : self.current_row_idx + total_size
573
+ self.current_row_idx : self.current_row_idx + batch_size
572
574
  ]
573
575
  # Load activations for each hook.
574
576
  # Usually faster to first slice dataset then pick column
575
- new_buffer = ds_slice[self.hook_name]
576
- if new_buffer.shape != (total_size, context_size, d_in):
577
+ acts_buffer = ds_slice[self.hook_name]
578
+ if acts_buffer.shape != (batch_size, context_size, d_in):
577
579
  raise ValueError(
578
- f"new_buffer has shape {new_buffer.shape}, "
579
- f"but expected ({total_size}, {context_size}, {d_in})."
580
+ f"acts_buffer has shape {acts_buffer.shape}, "
581
+ f"but expected ({batch_size}, {context_size}, {d_in})."
580
582
  )
581
583
 
582
- self.current_row_idx += total_size
583
- acts_buffer = new_buffer.reshape(total_size * context_size, d_in)
584
+ self.current_row_idx += batch_size
585
+ acts_buffer = acts_buffer.reshape(batch_size * context_size, d_in)
584
586
 
585
587
  if "token_ids" not in self.cached_activation_dataset.column_names:
586
588
  return acts_buffer, None
587
589
 
588
590
  token_ids_buffer = ds_slice["token_ids"]
589
- if token_ids_buffer.shape != (total_size, context_size):
591
+ if token_ids_buffer.shape != (batch_size, context_size):
590
592
  raise ValueError(
591
593
  f"token_ids_buffer has shape {token_ids_buffer.shape}, "
592
- f"but expected ({total_size}, {context_size})."
594
+ f"but expected ({batch_size}, {context_size})."
593
595
  )
594
- token_ids_buffer = token_ids_buffer.reshape(total_size * context_size)
596
+ token_ids_buffer = token_ids_buffer.reshape(batch_size * context_size)
595
597
  return acts_buffer, token_ids_buffer
596
598
 
597
599
  @torch.no_grad()
598
- def get_raw_buffer(
600
+ def get_raw_llm_batch(
599
601
  self,
600
- n_batches_in_buffer: int,
601
602
  raise_on_epoch_end: bool = False,
602
- shuffle: bool = True,
603
603
  ) -> tuple[torch.Tensor, torch.Tensor | None]:
604
604
  """
605
- Loads the next n_batches_in_buffer batches of activations into a tensor and returns it.
605
+ Loads the next batch of activations from the LLM and returns it.
606
606
 
607
- The primary purpose here is maintaining a shuffling buffer.
607
+ If raise_on_epoch_end is True, when the dataset is exhausted it will
608
+ automatically refill the dataset and then raise a StopIteration so that
609
+ the caller has a chance to react.
608
610
 
609
- If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
611
+ Returns:
612
+ Tuple of (activations, token_ids) where activations has shape
613
+ (batch_size * context_size, d_in) and token_ids has shape
614
+ (batch_size * context_size,).
610
615
  """
611
- context_size = self.context_size
612
- batch_size = self.store_batch_size_prompts
613
616
  d_in = self.d_in
614
- total_size = batch_size * n_batches_in_buffer
615
617
 
616
618
  if self.cached_activation_dataset is not None:
617
- return self._load_buffer_from_cached(
618
- total_size, context_size, d_in, raise_on_epoch_end
619
- )
619
+ return self._load_raw_llm_batch_from_cached(raise_on_epoch_end)
620
620
 
621
- refill_iterator = range(0, total_size, batch_size)
622
- # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
623
- new_buffer_activations = torch.zeros(
624
- (total_size, self.training_context_size, d_in),
625
- dtype=self.dtype, # type: ignore
626
- device=self.device,
627
- )
628
- new_buffer_token_ids = torch.zeros(
629
- (total_size, self.training_context_size),
630
- dtype=torch.long,
631
- device=self.device,
621
+ # move batch toks to gpu for model
622
+ batch_tokens = self.get_batch_tokens(raise_at_epoch_end=raise_on_epoch_end).to(
623
+ _get_model_device(self.model)
632
624
  )
625
+ activations = self.get_activations(batch_tokens).to(self.device)
633
626
 
634
- for refill_batch_idx_start in tqdm(
635
- refill_iterator, leave=False, desc="Refilling buffer"
636
- ):
637
- # move batch toks to gpu for model
638
- refill_batch_tokens = self.get_batch_tokens(
639
- raise_at_epoch_end=raise_on_epoch_end
640
- ).to(_get_model_device(self.model))
641
- refill_activations = self.get_activations(refill_batch_tokens)
642
- # move acts back to cpu
643
- refill_activations.to(self.device)
644
- new_buffer_activations[
645
- refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
646
- ] = refill_activations
647
-
648
- # handle seqpos_slice, this is done for activations in get_activations
649
- refill_batch_tokens = refill_batch_tokens[:, slice(*self.seqpos_slice)]
650
- new_buffer_token_ids[
651
- refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
652
- ] = refill_batch_tokens
653
-
654
- new_buffer_activations = new_buffer_activations.reshape(-1, d_in)
655
- new_buffer_token_ids = new_buffer_token_ids.reshape(-1)
656
- if shuffle:
657
- new_buffer_activations, new_buffer_token_ids = permute_together(
658
- [new_buffer_activations, new_buffer_token_ids]
659
- )
627
+ # handle seqpos_slice, this is done for activations in get_activations
628
+ batch_tokens = batch_tokens[:, slice(*self.seqpos_slice)]
660
629
 
661
- return (
662
- new_buffer_activations,
663
- new_buffer_token_ids,
664
- )
630
+ # reshape from (batch, context, d_in) to (batch * context, d_in)
631
+ activations = activations.reshape(-1, d_in)
632
+ token_ids = batch_tokens.reshape(-1)
665
633
 
666
- def get_filtered_buffer(
634
+ return activations, token_ids
635
+
636
+ def get_filtered_llm_batch(
667
637
  self,
668
- n_batches_in_buffer: int,
669
638
  raise_on_epoch_end: bool = False,
670
- shuffle: bool = True,
671
639
  ) -> torch.Tensor:
640
+ """
641
+ Get a batch of LLM activations with special tokens filtered out.
642
+ """
672
643
  return _filter_buffer_acts(
673
- self.get_raw_buffer(
674
- n_batches_in_buffer=n_batches_in_buffer,
675
- raise_on_epoch_end=raise_on_epoch_end,
676
- shuffle=shuffle,
677
- ),
644
+ self.get_raw_llm_batch(raise_on_epoch_end=raise_on_epoch_end),
678
645
  self.exclude_special_tokens,
679
646
  )
680
647
 
681
648
  def _iterate_filtered_activations(self) -> Generator[torch.Tensor, None, None]:
682
649
  """
683
- Iterate over the filtered tokens in the buffer.
650
+ Iterate over filtered LLM activation batches.
684
651
  """
685
652
  while True:
686
653
  try:
687
- yield self.get_filtered_buffer(
688
- self.half_buffer_size, raise_on_epoch_end=True
689
- )
654
+ yield self.get_filtered_llm_batch(raise_on_epoch_end=True)
690
655
  except StopIteration:
691
656
  warnings.warn(
692
657
  "All samples in the training dataset have been exhausted, beginning new epoch."
693
658
  )
694
659
  try:
695
- yield self.get_filtered_buffer(self.half_buffer_size)
660
+ yield self.get_filtered_llm_batch()
696
661
  except StopIteration:
697
662
  raise ValueError(
698
663
  "Unable to fill buffer after starting new epoch. Dataset may be too small."
@@ -708,6 +673,7 @@ class ActivationsStore:
708
673
  buffer_size=self.n_batches_in_buffer * self.training_context_size,
709
674
  batch_size=self.train_batch_size_tokens,
710
675
  activations_loader=self._iterate_filtered_activations(),
676
+ mix_fraction=self.activations_mixing_fraction,
711
677
  )
712
678
 
713
679
  def next_batch(self) -> torch.Tensor:
@@ -823,9 +789,3 @@ def _filter_buffer_acts(
823
789
 
824
790
  mask = torch.isin(tokens, exclude_tokens)
825
791
  return activations[~mask]
826
-
827
-
828
- def permute_together(tensors: Sequence[torch.Tensor]) -> tuple[torch.Tensor, ...]:
829
- """Permute tensors together."""
830
- permutation = torch.randperm(tensors[0].shape[0])
831
- return tuple(t[permutation] for t in tensors)
@@ -8,15 +8,19 @@ def mixing_buffer(
8
8
  buffer_size: int,
9
9
  batch_size: int,
10
10
  activations_loader: Iterator[torch.Tensor],
11
+ mix_fraction: float = 0.5,
11
12
  ) -> Iterator[torch.Tensor]:
12
13
  """
13
14
  A generator that maintains a mix of old and new activations for better training.
14
- It stores half of the activations and mixes them with new ones to create batches.
15
+ It keeps a portion of activations and mixes them with new ones to create batches.
15
16
 
16
17
  Args:
17
- buffer_size: Total size of the buffer (will store buffer_size/2 activations)
18
+ buffer_size: Total size of the buffer
18
19
  batch_size: Size of batches to return
19
20
  activations_loader: Iterator providing new activations
21
+ mix_fraction: Fraction of buffer to keep for mixing (default 0.5).
22
+ Higher values mean more temporal mixing but slower throughput.
23
+ If 0, no shuffling occurs (passthrough mode).
20
24
 
21
25
  Yields:
22
26
  Batches of activations of shape (batch_size, *activation_dims)
@@ -24,6 +28,8 @@ def mixing_buffer(
24
28
 
25
29
  if buffer_size < batch_size:
26
30
  raise ValueError("Buffer size must be greater than or equal to batch size")
31
+ if not 0 <= mix_fraction <= 1:
32
+ raise ValueError("mix_fraction must be in [0, 1]")
27
33
 
28
34
  storage_buffer: torch.Tensor | None = None
29
35
 
@@ -35,10 +41,13 @@ def mixing_buffer(
35
41
  )
36
42
 
37
43
  if storage_buffer.shape[0] >= buffer_size:
38
- # Shuffle
39
- storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
44
+ if mix_fraction > 0:
45
+ storage_buffer = storage_buffer[torch.randperm(storage_buffer.shape[0])]
40
46
 
41
- num_serving_batches = max(1, storage_buffer.shape[0] // (2 * batch_size))
47
+ # Keep a fixed amount for mixing, serve the rest
48
+ keep_for_mixing = int(buffer_size * mix_fraction)
49
+ num_to_serve = storage_buffer.shape[0] - keep_for_mixing
50
+ num_serving_batches = max(1, num_to_serve // batch_size)
42
51
  serving_cutoff = num_serving_batches * batch_size
43
52
  serving_buffer = storage_buffer[:serving_cutoff]
44
53
  storage_buffer = storage_buffer[serving_cutoff:]
@@ -55,7 +55,7 @@ Evaluator = Callable[[T_TRAINING_SAE, DataProvider, ActivationScaler], dict[str,
55
55
 
56
56
  class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
57
57
  """
58
- Core SAE class used for inference. For training, see TrainingSAE.
58
+ Trainer for Sparse Autoencoder (SAE) models.
59
59
  """
60
60
 
61
61
  data_provider: DataProvider
sae_lens/util.py CHANGED
@@ -95,8 +95,10 @@ def get_special_token_ids(tokenizer: PreTrainedTokenizerBase) -> list[int]:
95
95
  return list(special_tokens)
96
96
 
97
97
 
98
- def str_to_dtype(dtype: str) -> torch.dtype:
98
+ def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
99
99
  """Convert a string to a torch.dtype."""
100
+ if isinstance(dtype, torch.dtype):
101
+ return dtype
100
102
  if dtype not in DTYPE_MAP:
101
103
  raise ValueError(
102
104
  f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_MAP.keys())}"
@@ -111,3 +113,26 @@ def dtype_to_str(dtype: torch.dtype) -> str:
111
113
  f"Invalid dtype: {dtype}. Must be one of {list(DTYPE_TO_STR.keys())}"
112
114
  )
113
115
  return DTYPE_TO_STR[dtype]
116
+
117
+
118
+ def cosine_similarities(
119
+ mat1: torch.Tensor, mat2: torch.Tensor | None = None
120
+ ) -> torch.Tensor:
121
+ """
122
+ Compute cosine similarities between each row of mat1 and each row of mat2.
123
+
124
+ Args:
125
+ mat1: Tensor of shape [n1, d]
126
+ mat2: Tensor of shape [n2, d]. If not provided, mat1 = mat2
127
+
128
+ Returns:
129
+ Tensor of shape [n1, n2] with cosine similarities
130
+ """
131
+ if mat2 is None:
132
+ mat2 = mat1
133
+ # Clamp norm to 1e-8 to prevent division by zero. This threshold is chosen
134
+ # to be small enough to not affect normal vectors but large enough to avoid
135
+ # numerical instability. Zero vectors will effectively map to zero similarity.
136
+ mat1_normed = mat1 / mat1.norm(dim=1, keepdim=True).clamp(min=1e-8)
137
+ mat2_normed = mat2 / mat2.norm(dim=1, keepdim=True).clamp(min=1e-8)
138
+ return mat1_normed @ mat2_normed.T
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sae-lens
3
- Version: 6.26.1
3
+ Version: 6.28.1
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -77,6 +77,8 @@ The new v6 update is a major refactor to SAELens and changes the way training co
77
77
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/logits_lens_with_features.ipynb)
78
78
  - [Training a Sparse Autoencoder](tutorials/training_a_sparse_autoencoder.ipynb)
79
79
  [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_a_sparse_autoencoder.ipynb)
80
+ - [Training SAEs on Synthetic Data](tutorials/training_saes_on_synthetic_data.ipynb)
81
+ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/decoderesearch/SAELens/blob/main/tutorials/training_saes_on_synthetic_data.ipynb)
80
82
 
81
83
  ## Join the Slack!
82
84
 
@@ -0,0 +1,52 @@
1
+ sae_lens/__init__.py,sha256=S-AS72IxkvKO-wItRQjuyczikDxmfDaUgXRSfu5PU-o,4788
2
+ sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
4
+ sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
5
+ sae_lens/cache_activations_runner.py,sha256=TjqNWIc46Nw09jHWFjzQzgzG5wdu_87Ahe-iFjI5_0Q,13117
6
+ sae_lens/config.py,sha256=sseYcRMsAyopj8FICup1RGTXjFxzAithZ2OH7OpQV3Y,30839
7
+ sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
8
+ sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
9
+ sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
10
+ sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
11
+ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
+ sae_lens/loading/pretrained_sae_loaders.py,sha256=hHMlew1u6zVlbzvS9S_SfUPnAG0_OAjjIcjoUTIUZrU,63657
13
+ sae_lens/loading/pretrained_saes_directory.py,sha256=1at_aQbD8WFywchQCKuwfP-yvCq_Z2aUYrpKDnSN5Nc,4283
14
+ sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
15
+ sae_lens/pretrained_saes.yaml,sha256=Hn8jXwZ7V6QQxzgu41LFEP-LAzuDxwYL5vhoar-pPX8,1509922
16
+ sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
+ sae_lens/saes/__init__.py,sha256=SBqPaP6Gl5uPFwHlumAZATC4Wd26xKIYLAAAo4MSa5Q,2200
18
+ sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
19
+ sae_lens/saes/gated_sae.py,sha256=V_2ZNlV4gRD-rX5JSx1xqY7idT8ChfdQ5yxWDdu_6hg,8826
20
+ sae_lens/saes/jumprelu_sae.py,sha256=miiF-xI_yXdV9EkKjwAbU9zSMsx9KtKCz5YdXEzkN8g,13313
21
+ sae_lens/saes/matching_pursuit_sae.py,sha256=08_G9p1YMLnE5qZVCPp6gll-iG6nHRbMMASf4_bkFt8,13207
22
+ sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
23
+ sae_lens/saes/sae.py,sha256=xRmgiLuaFlDCv8SyLbL-5TwdrWHpNLqSGe8mC1L6WcI,40942
24
+ sae_lens/saes/standard_sae.py,sha256=_hldNZkFPAf9VGrxouR1-tN8T2OEk8IkWBcXoatrC1o,5749
25
+ sae_lens/saes/temporal_sae.py,sha256=83Ap4mYGfdN3sKdPF8nKjhdXph3-7E2QuLobqJ_YuoM,13273
26
+ sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
27
+ sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
28
+ sae_lens/synthetic/__init__.py,sha256=FGUasB6fLPXRFCcrtKfL7vCKDOWebZ5Rx5F9QNJZklI,2875
29
+ sae_lens/synthetic/activation_generator.py,sha256=thWGTwRmhu0K8m66WfJUajHmuIPHkwV4_HjmG0dL3G8,7638
30
+ sae_lens/synthetic/correlation.py,sha256=odr-S5h6c2U-bepwrAQeMfV1iBF_cnnQzqw7zapEXZ4,6056
31
+ sae_lens/synthetic/evals.py,sha256=Nhi314ZnRgLfhBj-3tm_zzI-pGyFTcwllDXbIpPFXeU,4584
32
+ sae_lens/synthetic/feature_dictionary.py,sha256=2A9wqdT1KejRLuIoFWdoiWdDtaHHgIluaKsHGizsVxI,4864
33
+ sae_lens/synthetic/firing_probabilities.py,sha256=yclz1pWl5gE1r8LAxFvzQS88Lxwk5-3r8BCX9HLVejA,3370
34
+ sae_lens/synthetic/hierarchy.py,sha256=dlQdPnnG3VzQDB3QOaqSXwoH8Ij2ioxmTlZg1lXHaRQ,11754
35
+ sae_lens/synthetic/initialization.py,sha256=orMGW-786wRDHIS2W7bEH0HmlVFQ4g2z4bnnwdv5w4s,1386
36
+ sae_lens/synthetic/plotting.py,sha256=5lFrej1QOkGAcImFNo5-o-8mI_rUVqvEI57KzUQPPtQ,8208
37
+ sae_lens/synthetic/training.py,sha256=Bg6NYxdzifq_8g-dJQSZ_z_TXDdGRtEi7tqNDb-gCVc,4986
38
+ sae_lens/tokenization_and_batching.py,sha256=uoHtAs9z3XqG0Fh-iQVYVlrbyB_E3kFFhrKU30BosCo,5438
39
+ sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
+ sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
41
+ sae_lens/training/activations_store.py,sha256=kp4-6R4rTJUSt-g-Ifg5B1h7iIe7jZj-XQSKDvDpQMI,32187
42
+ sae_lens/training/mixing_buffer.py,sha256=1Z-S2CcQXMWGxRZJFnXeZFxbZcALkO_fP6VO37XdJQQ,2519
43
+ sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
44
+ sae_lens/training/sae_trainer.py,sha256=iiGrNwmiX0xSHnJit0lH66yQzB6q8Fww1WNJZbTSBGY,17579
45
+ sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
46
+ sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
47
+ sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
48
+ sae_lens/util.py,sha256=oIMoeyEP2IzcPFmRbKUzOAycgEyMcOasGeO_BGVZbc4,4846
49
+ sae_lens-6.28.1.dist-info/METADATA,sha256=OdPVG1dwWoLGqiutKkAJGazfBLLbYQLBUbs_3h58BKg,5633
50
+ sae_lens-6.28.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
51
+ sae_lens-6.28.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
52
+ sae_lens-6.28.1.dist-info/RECORD,,
@@ -1,42 +0,0 @@
1
- sae_lens/__init__.py,sha256=zRp1nmb41W1Pt1rvlKvRWw73UxjGyz1iHAzH9_X6_WQ,4725
2
- sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- sae_lens/analysis/hooked_sae_transformer.py,sha256=dQRgGVwce8XwylL2AzJE7l9elhtMRFCs2hdUj-Qyy4g,14038
4
- sae_lens/analysis/neuronpedia_integration.py,sha256=Gx1W7hUBEuMoasNcnOnZ1wmqbXDd1pSZ1nqKEya1HQc,4962
5
- sae_lens/cache_activations_runner.py,sha256=Lvlz-k5-3XxVRtUdC4b1CiKyx5s0ckLa8GDGv9_kcxs,12566
6
- sae_lens/config.py,sha256=C982bUELhGHcfTwzeMTtXIf2hPtc946thYpUyctLiBo,30516
7
- sae_lens/constants.py,sha256=CM-h9AjZNAl2aP7hVpKk7YsFHpu-_Lfhhmq2d5qPEVc,887
8
- sae_lens/evals.py,sha256=P0NUsJeGzYxFBiVKhbPzd72IFKY4gH40HHlEZ3jEAmg,39598
9
- sae_lens/llm_sae_training_runner.py,sha256=M7BK55gSFYu2qFQKABHX3c8i46P1LfODCeyHFzGGuqU,15196
10
- sae_lens/load_model.py,sha256=C8AMykctj6H7tz_xRwB06-EXj6TfW64PtSJZR5Jxn1Y,8649
11
- sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- sae_lens/loading/pretrained_sae_loaders.py,sha256=hq-dhxsEdUmlAnZEiZBqX7lNyQQwZ6KXmXZWpzAc5FY,63638
13
- sae_lens/loading/pretrained_saes_directory.py,sha256=hejNfLUepYCSGPalRfQwxxCEUqMMUPsn1tufwvwct5k,3820
14
- sae_lens/pretokenize_runner.py,sha256=amJwIz3CKi2s2wNQn-10E7eAV7VFhNqtFDNTeTkwEI8,7133
15
- sae_lens/pretrained_saes.yaml,sha256=Hn8jXwZ7V6QQxzgu41LFEP-LAzuDxwYL5vhoar-pPX8,1509922
16
- sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
17
- sae_lens/saes/__init__.py,sha256=SBqPaP6Gl5uPFwHlumAZATC4Wd26xKIYLAAAo4MSa5Q,2200
18
- sae_lens/saes/batchtopk_sae.py,sha256=x4EbgZl0GUickRPcCmtKNGS2Ra3Uy1Z1OtF2FnrSabQ,5422
19
- sae_lens/saes/gated_sae.py,sha256=mHnmw-RD7hqIbP9_EBj3p2SK0OqQIkZivdOKRygeRgw,8825
20
- sae_lens/saes/jumprelu_sae.py,sha256=udjGHp3WTABQSL2Qq57j-bINWX61GCmo68EmdjMOXoo,13310
21
- sae_lens/saes/matching_pursuit_sae.py,sha256=08_G9p1YMLnE5qZVCPp6gll-iG6nHRbMMASf4_bkFt8,13207
22
- sae_lens/saes/matryoshka_batchtopk_sae.py,sha256=Qr6htt1HHOuO9FXI9hyaPSnGFIiJG-v7y1t1CEmkFzM,5995
23
- sae_lens/saes/sae.py,sha256=fzXv8lwHskSxsf8hm_wlKPkpq50iafmBjBNQzwZ6a00,40050
24
- sae_lens/saes/standard_sae.py,sha256=nEVETwAmRD2tyX7ESIic1fij48gAq1Dh7s_GQ2fqCZ4,5747
25
- sae_lens/saes/temporal_sae.py,sha256=DsecivcHWId-MTuJpQbz8OhqtmGhZACxJauYZGHo0Ok,13272
26
- sae_lens/saes/topk_sae.py,sha256=vrMRPrCQR1o8G_kXqY_EAoGZARupkQNFB2dNZVLsusE,21073
27
- sae_lens/saes/transcoder.py,sha256=CTpJs8ASOK06npih7gZHygZuxqTR7HICWlOYfTiKjI4,13501
28
- sae_lens/tokenization_and_batching.py,sha256=D_o7cXvRqhT89H3wNzoRymNALNE6eHojBWLdXOUwUGE,5438
29
- sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
- sae_lens/training/activation_scaler.py,sha256=FzNfgBplLWmyiSlZ6TUvE-nur3lOiGTrlvC97ys8S24,1973
31
- sae_lens/training/activations_store.py,sha256=rQadexm2BiwK7_MZIPlRkcKSqabi3iuOTC-R8aJchS8,33778
32
- sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
33
- sae_lens/training/optim.py,sha256=bJpqqcK4enkcPvQAJkeH4Ci1LUOlfjIMTv6-IlaAbRA,5588
34
- sae_lens/training/sae_trainer.py,sha256=zhkabyIKxI_tZTV3_kwz6zMrHZ95Ecr97krmwc-9ffs,17600
35
- sae_lens/training/types.py,sha256=1FpLx_Doda9vZpmfm-x1e8wGBYpyhe9Kpb_JuM5nIFM,90
36
- sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
37
- sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
38
- sae_lens/util.py,sha256=spkcmQUsjVYFn5H2032nQYr1CKGVnv3tAdfIpY59-Mg,3919
39
- sae_lens-6.26.1.dist-info/METADATA,sha256=yoE6CFgQ9L5SLzI3Zgr8H8CfUBgSimihGyEIvKd8TW8,5361
40
- sae_lens-6.26.1.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
41
- sae_lens-6.26.1.dist-info/licenses/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
42
- sae_lens-6.26.1.dist-info/RECORD,,