sae-lens 6.26.0__py3-none-any.whl → 6.28.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,230 @@
+ """
+ Plotting utilities for visualizing SAE training on synthetic data.
+
+ This module provides functions for:
+
+ - Plotting cosine similarities between SAE features and true features
+ - Automatically reordering features for better visualization
+ - Creating comparison plots between encoder and decoder
+ """
+
+ from collections.abc import Iterable
+ from pathlib import Path
+ from typing import Any
+
+ import plotly.graph_objects as go
+ import torch
+ from plotly.subplots import make_subplots
+
+ from sae_lens.saes import SAE
+ from sae_lens.synthetic.feature_dictionary import FeatureDictionary
+ from sae_lens.util import cosine_similarities
+
+
+ def find_best_feature_ordering(
+     sae_features: torch.Tensor,
+     true_features: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Find the best ordering of SAE features to match true features.
+
+     Reorders SAE features so that each SAE latent aligns with its best-matching
+     true feature in order. This makes cosine similarity plots more interpretable.
+
+     Args:
+         sae_features: SAE decoder weights of shape [d_sae, hidden_dim]
+         true_features: True feature vectors of shape [num_features, hidden_dim]
+
+     Returns:
+         Tensor of indices that reorders sae_features for best alignment
+     """
+     cos_sims = cosine_similarities(sae_features, true_features)
+     best_matches = torch.argmax(torch.abs(cos_sims), dim=1)
+     return torch.argsort(best_matches)
+
+
+ def find_best_feature_ordering_from_sae(
+     sae: torch.nn.Module,
+     feature_dict: FeatureDictionary,
+ ) -> torch.Tensor:
+     """
+     Find the best feature ordering for an SAE given a feature dictionary.
+
+     Args:
+         sae: SAE with W_dec attribute of shape [d_sae, hidden_dim]
+         feature_dict: The feature dictionary containing true features
+
+     Returns:
+         Tensor of indices that reorders SAE latents for best alignment
+     """
+     sae_features = sae.W_dec.detach()  # type: ignore[attr-defined]
+     true_features = feature_dict.feature_vectors.detach()
+     return find_best_feature_ordering(sae_features, true_features)
+
+
+ def find_best_feature_ordering_across_saes(
+     saes: Iterable[torch.nn.Module],
+     feature_dict: FeatureDictionary,
+ ) -> torch.Tensor:
+     """
+     Find the best feature ordering that works across multiple SAEs.
+
+     Useful for creating consistent orderings across training snapshots.
+
+     Args:
+         saes: Iterable of SAEs to consider
+         feature_dict: The feature dictionary containing true features
+
+     Returns:
+         The best ordering tensor found across all SAEs
+     """
+     best_score = float("-inf")
+     best_ordering: torch.Tensor | None = None
+
+     true_features = feature_dict.feature_vectors.detach()
+
+     for sae in saes:
+         sae_features = sae.W_dec.detach()  # type: ignore[attr-defined]
+         cos_sims = cosine_similarities(sae_features, true_features)
+         cos_sims = torch.round(cos_sims * 100) / 100  # Reduce numerical noise
+
+         ordering = find_best_feature_ordering(sae_features, true_features)
+         score = cos_sims[ordering, torch.arange(cos_sims.shape[1])].mean().item()
+
+         if score > best_score:
+             best_score = score
+             best_ordering = ordering
+
+     if best_ordering is None:
+         raise ValueError("No SAEs provided")
+
+     return best_ordering
+
+
+ def plot_sae_feature_similarity(
+     sae: SAE[Any],
+     feature_dict: FeatureDictionary,
+     title: str | None = None,
+     reorder_features: bool | torch.Tensor = False,
+     decoder_only: bool = False,
+     show_values: bool = False,
+     height: int = 400,
+     width: int = 800,
+     save_path: str | Path | None = None,
+     show_plot: bool = True,
+     dtick: int | None = 1,
+     scale: float = 1.0,
+ ):
+     """
+     Plot cosine similarities between SAE features and true features.
+
+     Creates a heatmap showing how well each SAE latent aligns with each
+     true feature. Useful for understanding what the SAE has learned.
+
+     Args:
+         sae: The SAE to visualize. Must have W_enc and W_dec attributes.
+         feature_dict: The feature dictionary containing true features
+         title: Plot title. If None, a default title is used.
+         reorder_features: If True, automatically reorders features for best alignment.
+             If a tensor, uses that as the ordering.
+         decoder_only: If True, only plots the decoder (not encoder and decoder side-by-side)
+         show_values: If True, shows numeric values on the heatmap
+         height: Height of the figure in pixels
+         width: Width of the figure in pixels
+         save_path: If provided, saves the figure to this path
+         show_plot: If True, displays the plot
+         dtick: Tick spacing for axes
+         scale: Scale factor for image resolution when saving
+     """
+     # Get cosine similarities
+     true_features = feature_dict.feature_vectors.detach()
+     dec_cos_sims = cosine_similarities(sae.W_dec.detach(), true_features)  # type: ignore[attr-defined]
+     enc_cos_sims = cosine_similarities(sae.W_enc.T.detach(), true_features)  # type: ignore[attr-defined]
+
+     # Round to reduce numerical noise
+     dec_cos_sims = torch.round(dec_cos_sims * 100) / 100
+     enc_cos_sims = torch.round(enc_cos_sims * 100) / 100
+
+     # Apply feature reordering if requested
+     if reorder_features is not False:
+         if isinstance(reorder_features, bool):
+             sorted_indices = find_best_feature_ordering(
+                 sae.W_dec.detach(),
+                 true_features,  # type: ignore[attr-defined]
+             )
+         else:
+             sorted_indices = reorder_features
+         dec_cos_sims = dec_cos_sims[sorted_indices]
+         enc_cos_sims = enc_cos_sims[sorted_indices]
+
+     hovertemplate = "True feature: %{x}<br>SAE Latent: %{y}<br>Cosine Similarity: %{z:.3f}<extra></extra>"
+
+     if decoder_only:
+         fig = make_subplots(rows=1, cols=1)
+
+         decoder_args: dict[str, Any] = {
+             "z": dec_cos_sims.cpu().numpy(),
+             "zmin": -1,
+             "zmax": 1,
+             "colorscale": "RdBu",
+             "colorbar": dict(title="cos sim", x=1.0, dtick=1, tickvals=[-1, 0, 1]),
+             "hovertemplate": hovertemplate,
+         }
+         if show_values:
+             decoder_args["texttemplate"] = "%{z:.2f}"
+             decoder_args["textfont"] = {"size": 10}
+
+         fig.add_trace(go.Heatmap(**decoder_args), row=1, col=1)
+         fig.update_xaxes(title_text="True feature", row=1, col=1, dtick=dtick)
+         fig.update_yaxes(title_text="SAE Latent", row=1, col=1, dtick=dtick)
+     else:
+         fig = make_subplots(
+             rows=1, cols=2, subplot_titles=("SAE encoder", "SAE decoder")
+         )
+
+         # Encoder heatmap
+         encoder_args: dict[str, Any] = {
+             "z": enc_cos_sims.cpu().numpy(),
+             "zmin": -1,
+             "zmax": 1,
+             "colorscale": "RdBu",
+             "showscale": False,
+             "hovertemplate": hovertemplate,
+         }
+         if show_values:
+             encoder_args["texttemplate"] = "%{z:.2f}"
+             encoder_args["textfont"] = {"size": 10}
+
+         fig.add_trace(go.Heatmap(**encoder_args), row=1, col=1)
+
+         # Decoder heatmap
+         decoder_args = {
+             "z": dec_cos_sims.cpu().numpy(),
+             "zmin": -1,
+             "zmax": 1,
+             "colorscale": "RdBu",
+             "colorbar": dict(title="cos sim", x=1.0, dtick=1, tickvals=[-1, 0, 1]),
+             "hovertemplate": hovertemplate,
+         }
+         if show_values:
+             decoder_args["texttemplate"] = "%{z:.2f}"
+             decoder_args["textfont"] = {"size": 10}
+
+         fig.add_trace(go.Heatmap(**decoder_args), row=1, col=2)
+
+         fig.update_xaxes(title_text="True feature", row=1, col=1, dtick=dtick)
+         fig.update_xaxes(title_text="True feature", row=1, col=2, dtick=dtick)
+         fig.update_yaxes(title_text="SAE Latent", row=1, col=1, dtick=dtick)
+         fig.update_yaxes(title_text="SAE Latent", row=1, col=2, dtick=dtick)
+
+     # Set main title
+     if title is None:
+         title = "Cosine similarity with true features"
+     fig.update_layout(height=height, width=width, title_text=title)
+
+     if save_path:
+         Path(save_path).parent.mkdir(parents=True, exist_ok=True)
+         fig.write_image(save_path, scale=scale)
+
+     if show_plot:
+         fig.show()
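The new module gives a small, self-contained API for checking how well a trained SAE recovers a known synthetic feature basis. Below is a minimal usage sketch (not part of the diff); it assumes a trained `sae` and the `FeatureDictionary` used to generate its training data already exist, and the import path is a guess since file paths are not shown in this diff.

```python
# Sketch only: `sae` and `feature_dict` are assumed to be built elsewhere,
# and the module path below is hypothetical (the diff does not show file paths).
from sae_lens.synthetic.plotting import (
    find_best_feature_ordering_from_sae,
    plot_sae_feature_similarity,
)

# Compute one ordering from the decoder and reuse it, so the encoder/decoder
# panels (and plots of later training snapshots) stay aligned.
ordering = find_best_feature_ordering_from_sae(sae, feature_dict)

plot_sae_feature_similarity(
    sae,
    feature_dict,
    reorder_features=ordering,  # or True to recompute from the decoder
    show_values=True,
    save_path="plots/feature_similarity.png",
    show_plot=False,
)
```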
@@ -0,0 +1,145 @@
+ from collections.abc import Iterator
+ from pathlib import Path
+ from typing import Any, Callable
+
+ import torch
+
+ from sae_lens.config import LoggingConfig, SAETrainerConfig
+ from sae_lens.saes.sae import TrainingSAE
+ from sae_lens.synthetic.activation_generator import ActivationGenerator
+ from sae_lens.synthetic.feature_dictionary import FeatureDictionary
+ from sae_lens.training.sae_trainer import SAETrainer, SaveCheckpointFn
+
+
+ def train_toy_sae(
+     sae: TrainingSAE[Any],
+     feature_dict: FeatureDictionary,
+     activations_generator: ActivationGenerator,
+     training_samples: int = 10_000_000,
+     batch_size: int = 1024,
+     lr: float = 3e-4,
+     lr_warm_up_steps: int = 0,
+     lr_decay_steps: int = 0,
+     device: str | torch.device = "cpu",
+     n_snapshots: int = 0,
+     snapshot_fn: Callable[[SAETrainer[Any, Any]], None] | None = None,
+ ) -> None:
+     """
+     Train an SAE on synthetic activations from a feature dictionary.
+
+     This is a convenience function that sets up the training loop with
+     sensible defaults for small-scale synthetic data experiments.
+
+     Args:
+         sae: The TrainingSAE to train
+         feature_dict: The feature dictionary that maps feature activations to
+             hidden activations
+         activations_generator: Generator that produces feature activations
+         training_samples: Total number of training samples
+         batch_size: Batch size for training
+         lr: Learning rate
+         lr_warm_up_steps: Number of warmup steps for learning rate
+         lr_decay_steps: Number of steps over which to decay learning rate
+         device: Device to train on
+         n_snapshots: Number of snapshots to take during training. Snapshots are
+             evenly spaced throughout training.
+         snapshot_fn: Callback function called at each snapshot point. Receives
+             the SAETrainer instance, allowing access to the SAE, training step,
+             and other training state. Required if n_snapshots > 0.
+     """
+
+     device_str = str(device) if isinstance(device, torch.device) else device
+
+     # Create data iterator
+     data_iterator = SyntheticActivationIterator(
+         feature_dict=feature_dict,
+         activations_generator=activations_generator,
+         batch_size=batch_size,
+     )
+
+     # Create trainer config
+     trainer_cfg = SAETrainerConfig(
+         n_checkpoints=n_snapshots,
+         checkpoint_path=None,
+         save_final_checkpoint=False,
+         total_training_samples=training_samples,
+         device=device_str,
+         autocast=False,
+         lr=lr,
+         lr_end=lr,
+         lr_scheduler_name="constant",
+         lr_warm_up_steps=lr_warm_up_steps,
+         adam_beta1=0.9,
+         adam_beta2=0.999,
+         lr_decay_steps=lr_decay_steps,
+         n_restart_cycles=1,
+         train_batch_size_samples=batch_size,
+         dead_feature_window=1000,
+         feature_sampling_window=2000,
+         logger=LoggingConfig(
+             log_to_wandb=False,
+             # hacky way to disable evals, but works for now
+             eval_every_n_wandb_logs=2**31 - 1,
+         ),
+     )
+
+     def snapshot_wrapper(
+         snapshot_fn: Callable[[SAETrainer[Any, Any]], None] | None,
+     ) -> SaveCheckpointFn:
+         def save_checkpoint(checkpoint_path: Path | None) -> None:  # noqa: ARG001
+             if snapshot_fn is None:
+                 raise ValueError("snapshot_fn must be provided to take snapshots")
+             snapshot_fn(trainer)
+
+         return save_checkpoint
+
+     # Create trainer and train
+     feature_dict.eval()
+     trainer = SAETrainer(
+         cfg=trainer_cfg,
+         sae=sae,
+         data_provider=data_iterator,
+         save_checkpoint_fn=snapshot_wrapper(snapshot_fn),
+     )
+
+     trainer.fit()
+
+
+ class SyntheticActivationIterator(Iterator[torch.Tensor]):
+     """
+     An iterator that generates synthetic activations for SAE training.
+
+     This iterator wraps a FeatureDictionary and a function that generates
+     feature activations, producing hidden activations that can be used
+     to train an SAE.
+     """
+
+     def __init__(
+         self,
+         feature_dict: FeatureDictionary,
+         activations_generator: ActivationGenerator,
+         batch_size: int,
+     ):
+         """
+         Create a new SyntheticActivationIterator.
+
+         Args:
+             feature_dict: The feature dictionary to use for generating hidden activations
+             activations_generator: Generator that produces feature activations
+             batch_size: Number of samples per batch
+         """
+         self.feature_dict = feature_dict
+         self.activations_generator = activations_generator
+         self.batch_size = batch_size
+
+     @torch.no_grad()
+     def next_batch(self) -> torch.Tensor:
+         """Generate the next batch of hidden activations."""
+         features = self.activations_generator(self.batch_size)
+         return self.feature_dict(features)
+
+     def __iter__(self) -> "SyntheticActivationIterator":
+         return self
+
+     def __next__(self) -> torch.Tensor:
+         return self.next_batch()
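`train_toy_sae` is meant to be driven with a snapshot callback when intermediate training states are wanted. A minimal sketch follows, assuming a `TrainingSAE` (`sae`), a `FeatureDictionary` (`feature_dict`), and an `ActivationGenerator` (`activations_generator`) constructed elsewhere; the import path and the `trainer.sae` attribute are assumptions, since only the function signatures above come from the diff.

```python
# Sketch only: object construction, the import path, and `trainer.sae`
# are assumptions; train_toy_sae's own parameters come from the diff above.
from sae_lens.synthetic.training import train_toy_sae

decoder_snapshots = []

def take_snapshot(trainer):
    # Assumes the SAETrainer exposes the SAE being trained as `trainer.sae`.
    decoder_snapshots.append(trainer.sae.W_dec.detach().cpu().clone())

train_toy_sae(
    sae=sae,
    feature_dict=feature_dict,
    activations_generator=activations_generator,
    training_samples=1_000_000,
    batch_size=1024,
    lr=3e-4,
    n_snapshots=5,  # evenly spaced snapshots during training
    snapshot_fn=take_snapshot,
)
```

`SyntheticActivationIterator` can also be used on its own wherever an iterator of activation batches is expected, e.g. `next(SyntheticActivationIterator(feature_dict, activations_generator, batch_size=1024))`.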
@@ -85,8 +85,8 @@ def concat_and_batch_sequences(
      for sequence in tokens_iterator:
          if (
              begin_sequence_token_id is not None
-             and sequence[0] != begin_sequence_token_id
              and len(sequence) >= context_size - 1
+             and sequence[0] != begin_sequence_token_id
          ):
              begin_sequence_token_id_tensor = torch.tensor(
                  [begin_sequence_token_id],
@@ -3,7 +3,7 @@ from __future__ import annotations
  import json
  import os
  import warnings
- from collections.abc import Generator, Iterator, Sequence
+ from collections.abc import Generator, Iterator
  from pathlib import Path
  from typing import Any, Literal, cast

@@ -148,6 +148,7 @@ class ActivationsStore:
              exclude_special_tokens=exclude_special_tokens,
              disable_concat_sequences=cfg.disable_concat_sequences,
              sequence_separator_token=cfg.sequence_separator_token,
+             activations_mixing_fraction=cfg.activations_mixing_fraction,
          )

      @classmethod
@@ -222,6 +223,7 @@ class ActivationsStore:
          exclude_special_tokens: torch.Tensor | None = None,
          disable_concat_sequences: bool = False,
          sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = "bos",
+         activations_mixing_fraction: float = 0.5,
      ):
          self.model = model
          if model_kwargs is None:
@@ -252,7 +254,6 @@ class ActivationsStore:
          self.context_size = context_size
          self.d_in = d_in
          self.n_batches_in_buffer = n_batches_in_buffer
-         self.half_buffer_size = n_batches_in_buffer // 2
          self.total_training_tokens = total_training_tokens
          self.store_batch_size_prompts = store_batch_size_prompts
          self.train_batch_size_tokens = train_batch_size_tokens
@@ -269,6 +270,7 @@ class ActivationsStore:
          self.sequence_separator_token: int | Literal["bos", "eos", "sep"] | None = (
              sequence_separator_token
          )
+         self.activations_mixing_fraction = activations_mixing_fraction

          self.n_dataset_processed = 0

@@ -535,18 +537,15 @@ class ActivationsStore:

          return stacked_activations

-     def _load_buffer_from_cached(
+     def _load_raw_llm_batch_from_cached(
          self,
-         total_size: int,
-         context_size: int,
-         d_in: int,
          raise_on_epoch_end: bool,
      ) -> tuple[
          torch.Tensor,
          torch.Tensor | None,
      ]:
          """
-         Loads `total_size` activations from `cached_activation_dataset`
+         Loads a batch of activations from `cached_activation_dataset`

          The dataset has columns for each hook_name,
          each containing activations of shape (context_size, d_in).
@@ -554,6 +553,10 @@ class ActivationsStore:
          raises StopIteration
          """
          assert self.cached_activation_dataset is not None
+         context_size = self.context_size
+         batch_size = self.store_batch_size_prompts
+         d_in = self.d_in
+
          # In future, could be a list of multiple hook names
          if self.hook_name not in self.cached_activation_dataset.column_names:
              raise ValueError(
@@ -561,138 +564,100 @@ class ActivationsStore:
              f"got {self.cached_activation_dataset.column_names}."
          )

-         if self.current_row_idx > len(self.cached_activation_dataset) - total_size:
+         if self.current_row_idx > len(self.cached_activation_dataset) - batch_size:
              self.current_row_idx = 0
              if raise_on_epoch_end:
                  raise StopIteration

-         new_buffer = []
          ds_slice = self.cached_activation_dataset[
-             self.current_row_idx : self.current_row_idx + total_size
+             self.current_row_idx : self.current_row_idx + batch_size
          ]
          # Load activations for each hook.
          # Usually faster to first slice dataset then pick column
-         new_buffer = ds_slice[self.hook_name]
-         if new_buffer.shape != (total_size, context_size, d_in):
+         acts_buffer = ds_slice[self.hook_name]
+         if acts_buffer.shape != (batch_size, context_size, d_in):
              raise ValueError(
-                 f"new_buffer has shape {new_buffer.shape}, "
-                 f"but expected ({total_size}, {context_size}, {d_in})."
+                 f"acts_buffer has shape {acts_buffer.shape}, "
+                 f"but expected ({batch_size}, {context_size}, {d_in})."
              )

-         self.current_row_idx += total_size
-         acts_buffer = new_buffer.reshape(total_size * context_size, d_in)
+         self.current_row_idx += batch_size
+         acts_buffer = acts_buffer.reshape(batch_size * context_size, d_in)

          if "token_ids" not in self.cached_activation_dataset.column_names:
              return acts_buffer, None

          token_ids_buffer = ds_slice["token_ids"]
-         if token_ids_buffer.shape != (total_size, context_size):
+         if token_ids_buffer.shape != (batch_size, context_size):
              raise ValueError(
                  f"token_ids_buffer has shape {token_ids_buffer.shape}, "
-                 f"but expected ({total_size}, {context_size})."
+                 f"but expected ({batch_size}, {context_size})."
              )
-         token_ids_buffer = token_ids_buffer.reshape(total_size * context_size)
+         token_ids_buffer = token_ids_buffer.reshape(batch_size * context_size)
          return acts_buffer, token_ids_buffer

      @torch.no_grad()
-     def get_raw_buffer(
+     def get_raw_llm_batch(
          self,
-         n_batches_in_buffer: int,
          raise_on_epoch_end: bool = False,
-         shuffle: bool = True,
      ) -> tuple[torch.Tensor, torch.Tensor | None]:
          """
-         Loads the next n_batches_in_buffer batches of activations into a tensor and returns it.
+         Loads the next batch of activations from the LLM and returns it.

-         The primary purpose here is maintaining a shuffling buffer.
+         If raise_on_epoch_end is True, when the dataset is exhausted it will
+         automatically refill the dataset and then raise a StopIteration so that
+         the caller has a chance to react.

-         If raise_on_epoch_end is True, when the dataset it exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
+         Returns:
+             Tuple of (activations, token_ids) where activations has shape
+             (batch_size * context_size, d_in) and token_ids has shape
+             (batch_size * context_size,).
          """
-         context_size = self.context_size
-         batch_size = self.store_batch_size_prompts
          d_in = self.d_in
-         total_size = batch_size * n_batches_in_buffer

          if self.cached_activation_dataset is not None:
-             return self._load_buffer_from_cached(
-                 total_size, context_size, d_in, raise_on_epoch_end
-             )
+             return self._load_raw_llm_batch_from_cached(raise_on_epoch_end)

-         refill_iterator = range(0, total_size, batch_size)
-         # Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
-         new_buffer_activations = torch.zeros(
-             (total_size, self.training_context_size, d_in),
-             dtype=self.dtype,  # type: ignore
-             device=self.device,
-         )
-         new_buffer_token_ids = torch.zeros(
-             (total_size, self.training_context_size),
-             dtype=torch.long,
-             device=self.device,
+         # move batch toks to gpu for model
+         batch_tokens = self.get_batch_tokens(raise_at_epoch_end=raise_on_epoch_end).to(
+             _get_model_device(self.model)
          )
+         activations = self.get_activations(batch_tokens).to(self.device)

-         for refill_batch_idx_start in tqdm(
-             refill_iterator, leave=False, desc="Refilling buffer"
-         ):
-             # move batch toks to gpu for model
-             refill_batch_tokens = self.get_batch_tokens(
-                 raise_at_epoch_end=raise_on_epoch_end
-             ).to(_get_model_device(self.model))
-             refill_activations = self.get_activations(refill_batch_tokens)
-             # move acts back to cpu
-             refill_activations.to(self.device)
-             new_buffer_activations[
-                 refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
-             ] = refill_activations
-
-             # handle seqpos_slice, this is done for activations in get_activations
-             refill_batch_tokens = refill_batch_tokens[:, slice(*self.seqpos_slice)]
-             new_buffer_token_ids[
-                 refill_batch_idx_start : refill_batch_idx_start + batch_size, ...
-             ] = refill_batch_tokens
-
-         new_buffer_activations = new_buffer_activations.reshape(-1, d_in)
-         new_buffer_token_ids = new_buffer_token_ids.reshape(-1)
-         if shuffle:
-             new_buffer_activations, new_buffer_token_ids = permute_together(
-                 [new_buffer_activations, new_buffer_token_ids]
-             )
+         # handle seqpos_slice, this is done for activations in get_activations
+         batch_tokens = batch_tokens[:, slice(*self.seqpos_slice)]

-         return (
-             new_buffer_activations,
-             new_buffer_token_ids,
-         )
+         # reshape from (batch, context, d_in) to (batch * context, d_in)
+         activations = activations.reshape(-1, d_in)
+         token_ids = batch_tokens.reshape(-1)

-     def get_filtered_buffer(
+         return activations, token_ids
+
+     def get_filtered_llm_batch(
          self,
-         n_batches_in_buffer: int,
          raise_on_epoch_end: bool = False,
-         shuffle: bool = True,
      ) -> torch.Tensor:
+         """
+         Get a batch of LLM activations with special tokens filtered out.
+         """
          return _filter_buffer_acts(
-             self.get_raw_buffer(
-                 n_batches_in_buffer=n_batches_in_buffer,
-                 raise_on_epoch_end=raise_on_epoch_end,
-                 shuffle=shuffle,
-             ),
+             self.get_raw_llm_batch(raise_on_epoch_end=raise_on_epoch_end),
              self.exclude_special_tokens,
          )

      def _iterate_filtered_activations(self) -> Generator[torch.Tensor, None, None]:
          """
-         Iterate over the filtered tokens in the buffer.
+         Iterate over filtered LLM activation batches.
          """
          while True:
              try:
-                 yield self.get_filtered_buffer(
-                     self.half_buffer_size, raise_on_epoch_end=True
-                 )
+                 yield self.get_filtered_llm_batch(raise_on_epoch_end=True)
              except StopIteration:
                  warnings.warn(
                      "All samples in the training dataset have been exhausted, beginning new epoch."
                  )
                  try:
-                     yield self.get_filtered_buffer(self.half_buffer_size)
+                     yield self.get_filtered_llm_batch()
                  except StopIteration:
                      raise ValueError(
                          "Unable to fill buffer after starting new epoch. Dataset may be too small."
@@ -708,6 +673,7 @@ class ActivationsStore:
              buffer_size=self.n_batches_in_buffer * self.training_context_size,
              batch_size=self.train_batch_size_tokens,
              activations_loader=self._iterate_filtered_activations(),
+             mix_fraction=self.activations_mixing_fraction,
          )

      def next_batch(self) -> torch.Tensor:
@@ -823,9 +789,3 @@ def _filter_buffer_acts(

      mask = torch.isin(tokens, exclude_tokens)
      return activations[~mask]
-
-
- def permute_together(tensors: Sequence[torch.Tensor]) -> tuple[torch.Tensor, ...]:
-     """Permute tensors together."""
-     permutation = torch.randperm(tensors[0].shape[0])
-     return tuple(t[permutation] for t in tensors)
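The net effect of the `ActivationsStore` changes is that `get_raw_llm_batch` now returns a single model batch, flattened over the context dimension, instead of a pre-filled and pre-shuffled multi-batch buffer; the `permute_together` helper is removed along with that logic, and mixing appears to be delegated to the downstream buffer configured with `mix_fraction=self.activations_mixing_fraction`. A small, illustrative check of the flattening described in the new docstring, using made-up sizes:

```python
# Illustrative only: random stand-ins for one LLM batch, checking the
# (batch_size * context_size, ...) shapes documented for get_raw_llm_batch.
import torch

batch_size, context_size, d_in = 4, 128, 768  # hypothetical sizes

activations = torch.randn(batch_size, context_size, d_in)
token_ids = torch.randint(0, 50_000, (batch_size, context_size))

flat_activations = activations.reshape(-1, d_in)  # (512, 768)
flat_token_ids = token_ids.reshape(-1)            # (512,)

assert flat_activations.shape == (batch_size * context_size, d_in)
assert flat_token_ids.shape == (batch_size * context_size,)
```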