sae-lens 5.10.3__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,185 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ import torch
6
+ from jaxtyping import Float
7
+ from numpy.typing import NDArray
8
+ from torch import nn
9
+ from typing_extensions import override
10
+
11
+ from sae_lens.saes.sae import (
12
+ SAE,
13
+ SAEConfig,
14
+ TrainCoefficientConfig,
15
+ TrainingSAE,
16
+ TrainingSAEConfig,
17
+ TrainStepInput,
18
+ )
19
+ from sae_lens.util import filter_valid_dataclass_fields
20
+
21
+
22
@dataclass
class StandardSAEConfig(SAEConfig):
    """
    Config for the inference-time StandardSAE.

    Declares no fields of its own on top of SAEConfig; its only job is to
    report the "standard" architecture name.
    """

    @override
    @classmethod
    def architecture(cls) -> str:
        arch_name = "standard"
        return arch_name
32
+
33
+
34
class StandardSAE(SAE[StandardSAEConfig]):
    """
    StandardSAE is an inference-only implementation of a Sparse Autoencoder (SAE)
    using a simple linear encoder and decoder.

    It provides the architecture-specific pieces on top of the SAE base class:
    - initialize_weights: runs the base-class initialization, then adds a
      zero-initialized encoder bias ``b_enc`` of length d_sae.
    - encode: computes the feature activations from an input.
    - decode: reconstructs the input from the feature activations.

    The SAE base class's forward() is expected to call encode and decode,
    including any error-term processing selected via ``use_error_term``.
    """

    b_enc: nn.Parameter

    def __init__(self, cfg: StandardSAEConfig, use_error_term: bool = False):
        """
        Args:
            cfg: Configuration for this SAE.
            use_error_term: Passed straight through to the SAE base class.
        """
        super().__init__(cfg, use_error_term)

    @override
    def initialize_weights(self) -> None:
        """Run the base-class initialization, then add a zero encoder bias."""
        super().initialize_weights()
        _init_weights_standard(self)

    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Encode the input tensor into the feature space.
        For inference, no noise is added.
        """
        # Preprocess the SAE input via the base-class process_sae_in step
        # (per the original author: dtype casting, hooks, normalization).
        sae_in = self.process_sae_in(x)
        # Affine pre-activation, exposed through the hook_sae_acts_pre hook point.
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
        # Apply the configured activation function, exposed through hook_sae_acts_post.
        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Decode feature activations back to the input space.

        Applies the linear decoder, the hook_sae_recons hook point, the run-time
        output normalization, and finally reshape_fn_out, which for hook_z SAEs
        reverses the head flattening.
        """
        # 1) linear transform
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        # 2) hook reconstruction
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        # 3) optional out-normalization (e.g. constant_norm_rescale or layer_norm)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        # 4) if hook_z is enabled, rearrange back to (..., n_heads, d_head).
        return self.reshape_fn_out(sae_out_pre, self.d_head)
88
+
89
+
90
@dataclass
class StandardTrainingSAEConfig(TrainingSAEConfig):
    """
    Training-time configuration for StandardTrainingSAE.

    Fields (declaration order matters for the dataclass):
        l1_coefficient: base weight of the "l1" sparsity penalty.
        lp_norm: the p of the p-norm used when computing that penalty.
        l1_warm_up_steps: warm-up steps attached to the "l1" coefficient.
    """

    l1_coefficient: float = 1.0
    lp_norm: float = 1.0
    l1_warm_up_steps: int = 0

    @override
    @classmethod
    def architecture(cls) -> str:
        arch_name = "standard"
        return arch_name
104
+
105
+
106
class StandardTrainingSAE(TrainingSAE[StandardTrainingSAEConfig]):
    """
    StandardTrainingSAE is the training counterpart of StandardSAE, built on the
    TrainingSAE base class with the "standard" SAE architecture.

    It provides:
    - initialize_weights: base-class initialization plus a zero-initialized b_enc.
    - get_coefficients: exposes the "l1" coefficient (value + warm-up steps).
    - encode_with_hidden_pre: computes pre-activations, adds noise when training,
      and then activates.
    - calculate_aux_loss: a sparsity penalty based on the p-norm (cfg.lp_norm)
      of feature activations weighted by their decoder-row norms.
    - log_histograms: adds the encoder bias to the base-class histograms.
    - to_inference_config_dict: filters the training config down to the
      fields accepted by StandardSAEConfig.
    """

    b_enc: nn.Parameter

    @override
    def initialize_weights(self) -> None:
        """Run the base-class initialization, then add a zero encoder bias."""
        super().initialize_weights()
        _init_weights_standard(self)

    @override
    def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
        """Return the training coefficients; this SAE reads only "l1"."""
        return {
            "l1": TrainCoefficientConfig(
                value=self.cfg.l1_coefficient,
                warm_up_steps=self.cfg.l1_warm_up_steps,
            ),
        }

    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        """
        Encode x, returning (feature_acts, hidden_pre_noised).

        The second element is the pre-activation *after* any training-time
        noise injection.
        """
        # Process the input (including dtype conversion, hook call, and any activation normalization)
        sae_in = self.process_sae_in(x)
        # Compute the pre-activation (and allow for a hook if desired)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)  # type: ignore
        # Gaussian noise only in training mode and only when noise_scale > 0.
        if self.training and self.cfg.noise_scale > 0:
            hidden_pre_noised = (
                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
            )
        else:
            hidden_pre_noised = hidden_pre
        # Apply the activation function (and any post-activation hook)
        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
        return feature_acts, hidden_pre_noised

    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """
        Return {"l1_loss": ...}, the decoder-norm-weighted sparsity penalty.

        hidden_pre and sae_out are accepted for interface compatibility but are
        not used by this loss.
        """
        # Weight each latent's activation by the L2 norm of its decoder row.
        # NOTE(review): assumes W_dec is (d_sae, d_in) so norm(dim=1) is per latent.
        weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)

        # p-norm (p = cfg.lp_norm) over the latent dimension, one value per position.
        sparsity = weighted_feature_acts.norm(p=self.cfg.lp_norm, dim=-1)
        l1_loss = (step_input.coefficients["l1"] * sparsity).mean()

        return {"l1_loss": l1_loss}

    @override
    def log_histograms(self) -> dict[str, NDArray[np.generic]]:
        """Log histograms of the weights and biases, adding the encoder bias."""
        b_e_dist = self.b_enc.detach().float().cpu().numpy()
        return {
            **super().log_histograms(),
            "weights/b_e": b_e_dist,
        }

    def to_inference_config_dict(self) -> dict[str, Any]:
        """Keep only the config entries that StandardSAEConfig accepts (minus "architecture")."""
        return filter_valid_dataclass_fields(
            self.cfg.to_dict(), StandardSAEConfig, ["architecture"]
        )
178
+
179
+
180
def _init_weights_standard(
    sae: SAE[StandardSAEConfig] | TrainingSAE[StandardTrainingSAEConfig],
) -> None:
    """Attach a zero-initialized encoder bias ``b_enc`` of length d_sae to *sae*."""
    zero_bias = torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
    sae.b_enc = nn.Parameter(zero_bias)
@@ -0,0 +1,294 @@
1
+ """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Callable
5
+
6
+ import torch
7
+ from jaxtyping import Float
8
+ from torch import nn
9
+ from typing_extensions import override
10
+
11
+ from sae_lens.saes.sae import (
12
+ SAE,
13
+ SAEConfig,
14
+ TrainCoefficientConfig,
15
+ TrainingSAE,
16
+ TrainingSAEConfig,
17
+ TrainStepInput,
18
+ )
19
+ from sae_lens.util import filter_valid_dataclass_fields
20
+
21
+
22
+ class TopK(nn.Module):
23
+ """
24
+ A simple TopK activation that zeroes out all but the top K elements along the last dimension,
25
+ then optionally applies a post-activation function (e.g., ReLU).
26
+ """
27
+
28
+ b_enc: nn.Parameter
29
+
30
+ def __init__(
31
+ self,
32
+ k: int,
33
+ postact_fn: Callable[[torch.Tensor], torch.Tensor] = nn.ReLU(),
34
+ ):
35
+ super().__init__()
36
+ self.k = k
37
+ self.postact_fn = postact_fn
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ """
41
+ 1) Select top K elements along the last dimension.
42
+ 2) Apply post-activation (often ReLU).
43
+ 3) Zero out all other entries.
44
+ """
45
+ topk = torch.topk(x, k=self.k, dim=-1)
46
+ values = self.postact_fn(topk.values)
47
+ result = torch.zeros_like(x)
48
+ result.scatter_(-1, topk.indices, values)
49
+ return result
50
+
51
+
52
@dataclass
class TopKSAEConfig(SAEConfig):
    """
    Config for the inference-time TopKSAE.

    Attributes:
        k: number of latents kept by the TopK activation (default 100).
    """

    k: int = 100

    @override
    @classmethod
    def architecture(cls) -> str:
        arch_name = "topk"
        return arch_name
64
+
65
+
66
class TopKSAE(SAE[TopKSAEConfig]):
    """
    An inference-only sparse autoencoder using a "topk" activation function.
    It uses linear encoder and decoder layers, applying the TopK activation
    (built from cfg.k by get_activation_fn) to the hidden pre-activation in
    its encode step.
    """

    b_enc: nn.Parameter

    def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False):
        """
        Args:
            cfg: TopKSAEConfig defining model size, k, and behavior.
            use_error_term: Whether to apply the error-term approach in the forward pass.
        """
        super().__init__(cfg, use_error_term)

    @override
    def initialize_weights(self) -> None:
        """Run the base-class initialization, then add a zero encoder bias."""
        super().initialize_weights()
        _init_weights_topk(self)

    def encode(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> Float[torch.Tensor, "... d_sae"]:
        """
        Converts input x into feature activations.

        The activation applied is self.activation_fn, which is expected to be
        the TopK(cfg.k) module returned by get_activation_fn() (assumed to be
        wired up by the SAE base class).
        """
        sae_in = self.process_sae_in(x)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)
        return self.hook_sae_acts_post(self.activation_fn(hidden_pre))

    def decode(
        self, feature_acts: Float[torch.Tensor, "... d_sae"]
    ) -> Float[torch.Tensor, "... d_in"]:
        """
        Reconstructs the input from topk feature activations.

        Applies the linear decoder, the hook_sae_recons hook point, the run-time
        output normalization, and reshape_fn_out (optional head reshaping).
        """
        sae_out_pre = feature_acts @ self.W_dec + self.b_dec
        sae_out_pre = self.hook_sae_recons(sae_out_pre)
        sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre)
        return self.reshape_fn_out(sae_out_pre, self.d_head)

    @override
    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
        """Return a TopK activation that keeps cfg.k latents along the last dim."""
        return TopK(self.cfg.k)
118
+
119
+
120
@dataclass
class TopKTrainingSAEConfig(TrainingSAEConfig):
    """
    Training-time configuration for TopKTrainingSAE.

    Attributes:
        k: number of latents kept by the TopK activation (default 100).
    """

    k: int = 100

    @override
    @classmethod
    def architecture(cls) -> str:
        arch_name = "topk"
        return arch_name
132
+
133
+
134
class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]):
    """
    TopK variant with training functionality.

    Injects Gaussian noise into the pre-activations during training (when
    cfg.noise_scale > 0) and computes an auxiliary reconstruction loss that
    trains dead latents to predict the residual error of the live ones.
    """

    b_enc: nn.Parameter

    def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False):
        super().__init__(cfg, use_error_term)

    @override
    def initialize_weights(self) -> None:
        """Run the base-class initialization, then add a zero encoder bias."""
        super().initialize_weights()
        _init_weights_topk(self)

    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        """
        Similar to the base training method: process the input, optionally add
        noise, then apply TopK.

        Returns:
            (feature_acts, hidden_pre_noised); the second element is the
            pre-activation after any noise injection.
        """
        sae_in = self.process_sae_in(x)
        hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc)

        # Inject noise only in training mode and only when noise_scale > 0.
        if self.training and self.cfg.noise_scale > 0:
            hidden_pre_noised = (
                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
            )
        else:
            hidden_pre_noised = hidden_pre

        # self.activation_fn is expected to be the TopK module from get_activation_fn().
        feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised))
        return feature_acts, hidden_pre_noised

    def calculate_aux_loss(
        self,
        step_input: TrainStepInput,
        feature_acts: torch.Tensor,
        hidden_pre: torch.Tensor,
        sae_out: torch.Tensor,
    ) -> dict[str, torch.Tensor]:
        """Return {"auxiliary_reconstruction_loss": ...} for dead latents."""
        topk_loss = self.calculate_topk_aux_loss(
            sae_in=step_input.sae_in,
            sae_out=sae_out,
            hidden_pre=hidden_pre,
            dead_neuron_mask=step_input.dead_neuron_mask,
        )
        return {"auxiliary_reconstruction_loss": topk_loss}

    @override
    def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
        """Return a TopK activation that keeps cfg.k latents along the last dim."""
        return TopK(self.cfg.k)

    @override
    def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
        """TopK training uses no scheduled loss coefficients."""
        return {}

    def calculate_topk_aux_loss(
        self,
        sae_in: torch.Tensor,
        sae_out: torch.Tensor,
        hidden_pre: torch.Tensor,
        dead_neuron_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """
        Calculate TopK auxiliary loss.

        This auxiliary loss encourages dead neurons to learn useful features by having
        them reconstruct the residual error from the live neurons. It's a key part of
        preventing neuron death in TopK SAEs.

        Returns a zero tensor when there is no mask or no dead latent.
        """
        # Mostly taken from https://github.com/EleutherAI/sae/blob/main/sae/sae.py, except without variance normalization
        # NOTE: checking the number of dead neurons will force a GPU sync, so performance can likely be improved here
        if dead_neuron_mask is None or (num_dead := int(dead_neuron_mask.sum())) == 0:
            return sae_out.new_tensor(0.0)
        residual = (sae_in - sae_out).detach()

        # Heuristic from Appendix B.1 in the paper (presumably Gao et al.,
        # "Scaling and evaluating sparse autoencoders"): k_aux = d_in // 2.
        # Clamped to at least 1 so a tiny d_in cannot divide by zero below.
        k_aux = max(sae_in.shape[-1] // 2, 1)

        # Reduce the scale of the loss if there are a small number of dead latents
        scale = min(num_dead / k_aux, 1.0)
        # Never ask for more latents than are dead, so no -inf entry is selected.
        k_aux = min(k_aux, num_dead)

        auxk_acts = _calculate_topk_aux_acts(
            k_aux=k_aux,
            hidden_pre=hidden_pre,
            dead_neuron_mask=dead_neuron_mask,
        )

        # Encourage the top ~50% of dead latents to predict the residual of the
        # top k living latents
        recons = self.decode(auxk_acts)
        auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean()
        return scale * auxk_loss

    def _calculate_topk_aux_acts(
        self,
        k_aux: int,
        hidden_pre: torch.Tensor,
        dead_neuron_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Helper method to calculate activations for the auxiliary loss.

        Thin wrapper around the module-level ``_calculate_topk_aux_acts`` so
        there is a single implementation of the selection logic.

        Args:
            k_aux: Number of top dead neurons to select
            hidden_pre: Pre-activation values from encoder
            dead_neuron_mask: Boolean mask indicating which neurons are dead

        Returns:
            Tensor with activations for only the top-k dead neurons, zeros elsewhere
        """
        # Inside a method body the bare name resolves to the module-level function.
        return _calculate_topk_aux_acts(
            k_aux=k_aux,
            hidden_pre=hidden_pre,
            dead_neuron_mask=dead_neuron_mask,
        )

    def to_inference_config_dict(self) -> dict[str, Any]:
        """Keep only the config entries that TopKSAEConfig accepts (minus "architecture")."""
        return filter_valid_dataclass_fields(
            self.cfg.to_dict(), TopKSAEConfig, ["architecture"]
        )
271
+
272
+
273
def _calculate_topk_aux_acts(
    k_aux: int,
    hidden_pre: torch.Tensor,
    dead_neuron_mask: torch.Tensor,
) -> torch.Tensor:
    """
    Keep only the k_aux largest pre-activations among dead latents.

    Args:
        k_aux: How many dead latents to keep along the last dimension.
        hidden_pre: Encoder pre-activations; the last dimension indexes latents.
        dead_neuron_mask: Boolean mask over latents, True where a latent is dead.

    Returns:
        A tensor shaped like hidden_pre holding the selected values at their
        original positions and zeros everywhere else. If k_aux exceeds the
        number of dead latents, -inf entries would be selected, so callers
        should clamp k_aux first.
    """
    # Live latents are replaced by -inf so topk ranks dead latents first.
    only_dead = torch.where(dead_neuron_mask[None], hidden_pre, -torch.inf)
    picked = only_dead.topk(k_aux, dim=-1, sorted=False)
    out = torch.zeros_like(hidden_pre)
    out.scatter_(-1, picked.indices, picked.values)
    return out
287
+
288
+
289
def _init_weights_topk(
    sae: SAE[TopKSAEConfig] | TrainingSAE[TopKTrainingSAEConfig],
) -> None:
    """Attach a zero-initialized encoder bias ``b_enc`` of length d_sae to *sae*."""
    zero_bias = torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
    sae.b_enc = nn.Parameter(zero_bias)
@@ -23,12 +23,12 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase
23
23
 
24
24
  from sae_lens import logger
25
25
  from sae_lens.config import (
26
- DTYPE_MAP,
27
26
  CacheActivationsRunnerConfig,
28
27
  HfDataset,
29
28
  LanguageModelSAERunnerConfig,
30
29
  )
31
- from sae_lens.sae import SAE
30
+ from sae_lens.constants import DTYPE_MAP
31
+ from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
32
32
  from sae_lens.tokenization_and_batching import concat_and_batch_sequences
33
33
 
34
34
 
@@ -91,7 +91,8 @@ class ActivationsStore:
91
91
  def from_config(
92
92
  cls,
93
93
  model: HookedRootModule,
94
- cfg: LanguageModelSAERunnerConfig | CacheActivationsRunnerConfig,
94
+ cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG]
95
+ | CacheActivationsRunnerConfig,
95
96
  override_dataset: HfDataset | None = None,
96
97
  ) -> ActivationsStore:
97
98
  if isinstance(cfg, CacheActivationsRunnerConfig):
@@ -128,13 +129,15 @@ class ActivationsStore:
128
129
  hook_layer=cfg.hook_layer,
129
130
  hook_head_index=cfg.hook_head_index,
130
131
  context_size=cfg.context_size,
131
- d_in=cfg.d_in,
132
+ d_in=cfg.d_in
133
+ if isinstance(cfg, CacheActivationsRunnerConfig)
134
+ else cfg.sae.d_in,
132
135
  n_batches_in_buffer=cfg.n_batches_in_buffer,
133
136
  total_training_tokens=cfg.training_tokens,
134
137
  store_batch_size_prompts=cfg.store_batch_size_prompts,
135
138
  train_batch_size_tokens=cfg.train_batch_size_tokens,
136
139
  prepend_bos=cfg.prepend_bos,
137
- normalize_activations=cfg.normalize_activations,
140
+ normalize_activations=cfg.sae.normalize_activations,
138
141
  device=device,
139
142
  dtype=cfg.dtype,
140
143
  cached_activations_path=cached_activations_path,
@@ -149,9 +152,10 @@ class ActivationsStore:
149
152
  def from_sae(
150
153
  cls,
151
154
  model: HookedRootModule,
152
- sae: SAE,
155
+ sae: SAE[T_SAE_CONFIG],
156
+ dataset: HfDataset | str,
157
+ dataset_trust_remote_code: bool = False,
153
158
  context_size: int | None = None,
154
- dataset: HfDataset | str | None = None,
155
159
  streaming: bool = True,
156
160
  store_batch_size_prompts: int = 8,
157
161
  n_batches_in_buffer: int = 8,
@@ -159,25 +163,37 @@ class ActivationsStore:
159
163
  total_tokens: int = 10**9,
160
164
  device: str = "cpu",
161
165
  ) -> ActivationsStore:
166
+ if sae.cfg.metadata.hook_name is None:
167
+ raise ValueError("hook_name is required")
168
+ if sae.cfg.metadata.hook_layer is None:
169
+ raise ValueError("hook_layer is required")
170
+ if sae.cfg.metadata.hook_head_index is None:
171
+ raise ValueError("hook_head_index is required")
172
+ if sae.cfg.metadata.context_size is None:
173
+ raise ValueError("context_size is required")
174
+ if sae.cfg.metadata.prepend_bos is None:
175
+ raise ValueError("prepend_bos is required")
162
176
  return cls(
163
177
  model=model,
164
- dataset=sae.cfg.dataset_path if dataset is None else dataset,
178
+ dataset=dataset,
165
179
  d_in=sae.cfg.d_in,
166
- hook_name=sae.cfg.hook_name,
167
- hook_layer=sae.cfg.hook_layer,
168
- hook_head_index=sae.cfg.hook_head_index,
169
- context_size=sae.cfg.context_size if context_size is None else context_size,
170
- prepend_bos=sae.cfg.prepend_bos,
180
+ hook_name=sae.cfg.metadata.hook_name,
181
+ hook_layer=sae.cfg.metadata.hook_layer,
182
+ hook_head_index=sae.cfg.metadata.hook_head_index,
183
+ context_size=sae.cfg.metadata.context_size
184
+ if context_size is None
185
+ else context_size,
186
+ prepend_bos=sae.cfg.metadata.prepend_bos,
171
187
  streaming=streaming,
172
188
  store_batch_size_prompts=store_batch_size_prompts,
173
189
  train_batch_size_tokens=train_batch_size_tokens,
174
190
  n_batches_in_buffer=n_batches_in_buffer,
175
191
  total_training_tokens=total_tokens,
176
192
  normalize_activations=sae.cfg.normalize_activations,
177
- dataset_trust_remote_code=sae.cfg.dataset_trust_remote_code,
193
+ dataset_trust_remote_code=dataset_trust_remote_code,
178
194
  dtype=sae.cfg.dtype,
179
195
  device=torch.device(device),
180
- seqpos_slice=sae.cfg.seqpos_slice,
196
+ seqpos_slice=sae.cfg.metadata.seqpos_slice or (None,),
181
197
  )
182
198
 
183
199
  def __init__(
@@ -448,7 +464,7 @@ class ActivationsStore:
448
464
  ):
449
465
  # temporalily set estimated_norm_scaling_factor to 1.0 so the dataloader works
450
466
  self.estimated_norm_scaling_factor = 1.0
451
- acts = self.next_batch()[:, 0]
467
+ acts = self.next_batch()[0]
452
468
  self.estimated_norm_scaling_factor = None
453
469
  norms_per_batch.append(acts.norm(dim=-1).mean().item())
454
470
  mean_norm = np.mean(norms_per_batch)