sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,16 +1,17 @@
1
1
  import contextlib
2
2
  from dataclasses import dataclass
3
- from typing import Any, Generic, Protocol, cast
3
+ from pathlib import Path
4
+ from typing import Any, Callable, Generic, Protocol
4
5
 
5
6
  import torch
6
7
  import wandb
8
+ from safetensors.torch import save_file
7
9
  from torch.optim import Adam
8
10
  from tqdm import tqdm
9
- from transformer_lens.hook_points import HookedRootModule
10
11
 
11
12
  from sae_lens import __version__
12
- from sae_lens.config import LanguageModelSAERunnerConfig
13
- from sae_lens.evals import EvalConfig, run_evals
13
+ from sae_lens.config import SAETrainerConfig
14
+ from sae_lens.constants import ACTIVATION_SCALER_CFG_FILENAME, SPARSITY_FILENAME
14
15
  from sae_lens.saes.sae import (
15
16
  T_TRAINING_SAE,
16
17
  T_TRAINING_SAE_CONFIG,
@@ -19,8 +20,9 @@ from sae_lens.saes.sae import (
19
20
  TrainStepInput,
20
21
  TrainStepOutput,
21
22
  )
22
- from sae_lens.training.activations_store import ActivationsStore
23
+ from sae_lens.training.activation_scaler import ActivationScaler
23
24
  from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler
25
+ from sae_lens.training.types import DataProvider
24
26
 
25
27
 
26
28
  def _log_feature_sparsity(
@@ -46,33 +48,39 @@ class TrainSAEOutput:
46
48
  class SaveCheckpointFn(Protocol):
47
49
  def __call__(
48
50
  self,
49
- trainer: "SAETrainer[Any, Any]",
50
- checkpoint_name: str,
51
- wandb_aliases: list[str] | None = None,
51
+ checkpoint_path: Path,
52
52
  ) -> None: ...
53
53
 
54
54
 
55
+ Evaluator = Callable[[T_TRAINING_SAE, DataProvider, ActivationScaler], dict[str, Any]]
56
+
57
+
55
58
  class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
56
59
  """
57
60
  Core SAE class used for inference. For training, see TrainingSAE.
58
61
  """
59
62
 
63
+ data_provider: DataProvider
64
+ activation_scaler: ActivationScaler
65
+ evaluator: Evaluator[T_TRAINING_SAE] | None
66
+
60
67
  def __init__(
61
68
  self,
62
- model: HookedRootModule,
69
+ cfg: SAETrainerConfig,
63
70
  sae: T_TRAINING_SAE,
64
- activation_store: ActivationsStore,
65
- save_checkpoint_fn: SaveCheckpointFn,
66
- cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
71
+ data_provider: DataProvider,
72
+ evaluator: Evaluator[T_TRAINING_SAE] | None = None,
73
+ save_checkpoint_fn: SaveCheckpointFn | None = None,
67
74
  ) -> None:
68
- self.model = model
69
75
  self.sae = sae
70
- self.activations_store = activation_store
71
- self.save_checkpoint = save_checkpoint_fn
76
+ self.data_provider = data_provider
77
+ self.evaluator = evaluator
78
+ self.activation_scaler = ActivationScaler()
79
+ self.save_checkpoint_fn = save_checkpoint_fn
72
80
  self.cfg = cfg
73
81
 
74
82
  self.n_training_steps: int = 0
75
- self.n_training_tokens: int = 0
83
+ self.n_training_samples: int = 0
76
84
  self.started_fine_tuning: bool = False
77
85
 
78
86
  _update_sae_lens_training_version(self.sae)
@@ -82,20 +90,16 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
82
90
  self.checkpoint_thresholds = list(
83
91
  range(
84
92
  0,
85
- cfg.total_training_tokens,
86
- cfg.total_training_tokens // self.cfg.n_checkpoints,
93
+ cfg.total_training_samples,
94
+ cfg.total_training_samples // self.cfg.n_checkpoints,
87
95
  )
88
96
  )[1:]
89
97
 
90
- self.act_freq_scores = torch.zeros(
91
- cast(int, cfg.sae.d_sae),
92
- device=cfg.device,
93
- )
98
+ self.act_freq_scores = torch.zeros(sae.cfg.d_sae, device=cfg.device)
94
99
  self.n_forward_passes_since_fired = torch.zeros(
95
- cast(int, cfg.sae.d_sae),
96
- device=cfg.device,
100
+ sae.cfg.d_sae, device=cfg.device
97
101
  )
98
- self.n_frac_active_tokens = 0
102
+ self.n_frac_active_samples = 0
99
103
  # we don't train the scaling factor (initially)
100
104
  # set requires grad to false for the scaling factor
101
105
  for name, param in self.sae.named_parameters():
@@ -131,7 +135,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
131
135
  )
132
136
 
133
137
  # Setup autocast if using
134
- self.scaler = torch.amp.GradScaler(
138
+ self.grad_scaler = torch.amp.GradScaler(
135
139
  device=self.cfg.device, enabled=self.cfg.autocast
136
140
  )
137
141
 
@@ -144,23 +148,9 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
144
148
  else:
145
149
  self.autocast_if_enabled = contextlib.nullcontext()
146
150
 
147
- # Set up eval config
148
-
149
- self.trainer_eval_config = EvalConfig(
150
- batch_size_prompts=self.cfg.eval_batch_size_prompts,
151
- n_eval_reconstruction_batches=self.cfg.n_eval_batches,
152
- n_eval_sparsity_variance_batches=self.cfg.n_eval_batches,
153
- compute_ce_loss=True,
154
- compute_l2_norms=True,
155
- compute_sparsity_metrics=True,
156
- compute_variance_metrics=True,
157
- compute_kl=False,
158
- compute_featurewise_weight_based_metrics=False,
159
- )
160
-
161
151
  @property
162
152
  def feature_sparsity(self) -> torch.Tensor:
163
- return self.act_freq_scores / self.n_frac_active_tokens
153
+ return self.act_freq_scores / self.n_frac_active_samples
164
154
 
165
155
  @property
166
156
  def log_feature_sparsity(self) -> torch.Tensor:
@@ -171,19 +161,23 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
171
161
  return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()
172
162
 
173
163
  def fit(self) -> T_TRAINING_SAE:
174
- pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE")
164
+ pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE")
175
165
 
176
- self.activations_store.set_norm_scaling_factor_if_needed()
166
+ if self.sae.cfg.normalize_activations == "expected_average_only_in":
167
+ self.activation_scaler.estimate_scaling_factor(
168
+ d_in=self.sae.cfg.d_in,
169
+ data_provider=self.data_provider,
170
+ n_batches_for_norm_estimate=int(1e3),
171
+ )
177
172
 
178
173
  # Train loop
179
- while self.n_training_tokens < self.cfg.total_training_tokens:
174
+ while self.n_training_samples < self.cfg.total_training_samples:
180
175
  # Do a training step.
181
- layer_acts = self.activations_store.next_batch()[:, 0, :].to(
182
- self.sae.device
183
- )
184
- self.n_training_tokens += self.cfg.train_batch_size_tokens
176
+ batch = next(self.data_provider).to(self.sae.device)
177
+ self.n_training_samples += batch.shape[0]
178
+ scaled_batch = self.activation_scaler(batch)
185
179
 
186
- step_output = self._train_step(sae=self.sae, sae_in=layer_acts)
180
+ step_output = self._train_step(sae=self.sae, sae_in=scaled_batch)
187
181
 
188
182
  if self.cfg.logger.log_to_wandb:
189
183
  self._log_train_step(step_output)
@@ -194,22 +188,49 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
194
188
  self._update_pbar(step_output, pbar)
195
189
 
196
190
  # fold the estimated norm scaling factor into the sae weights
197
- if self.activations_store.estimated_norm_scaling_factor is not None:
191
+ if self.activation_scaler.scaling_factor is not None:
198
192
  self.sae.fold_activation_norm_scaling_factor(
199
- self.activations_store.estimated_norm_scaling_factor
193
+ self.activation_scaler.scaling_factor
200
194
  )
201
- self.activations_store.estimated_norm_scaling_factor = None
195
+ self.activation_scaler.scaling_factor = None
202
196
 
203
197
  # save final sae group to checkpoints folder
204
198
  self.save_checkpoint(
205
- trainer=self,
206
- checkpoint_name=f"final_{self.n_training_tokens}",
199
+ checkpoint_name=f"final_{self.n_training_samples}",
207
200
  wandb_aliases=["final_model"],
208
201
  )
209
202
 
210
203
  pbar.close()
211
204
  return self.sae
212
205
 
206
+ def save_checkpoint(
207
+ self,
208
+ checkpoint_name: str,
209
+ wandb_aliases: list[str] | None = None,
210
+ ) -> None:
211
+ checkpoint_path = Path(self.cfg.checkpoint_path) / checkpoint_name
212
+ checkpoint_path.mkdir(exist_ok=True, parents=True)
213
+
214
+ weights_path, cfg_path = self.sae.save_model(str(checkpoint_path))
215
+
216
+ sparsity_path = checkpoint_path / SPARSITY_FILENAME
217
+ save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
218
+
219
+ activation_scaler_path = checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
220
+ self.activation_scaler.save(str(activation_scaler_path))
221
+
222
+ if self.cfg.logger.log_to_wandb:
223
+ self.cfg.logger.log(
224
+ self,
225
+ weights_path,
226
+ cfg_path,
227
+ sparsity_path=sparsity_path,
228
+ wandb_aliases=wandb_aliases,
229
+ )
230
+
231
+ if self.save_checkpoint_fn is not None:
232
+ self.save_checkpoint_fn(checkpoint_path=checkpoint_path)
233
+
213
234
  def _train_step(
214
235
  self,
215
236
  sae: T_TRAINING_SAE,
@@ -242,17 +263,19 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
242
263
  self.act_freq_scores += (
243
264
  (train_step_output.feature_acts.abs() > 0).float().sum(0)
244
265
  )
245
- self.n_frac_active_tokens += self.cfg.train_batch_size_tokens
266
+ self.n_frac_active_samples += self.cfg.train_batch_size_samples
246
267
 
247
- # Scaler will rescale gradients if autocast is enabled
248
- self.scaler.scale(
268
+ # Grad scaler will rescale gradients if autocast is enabled
269
+ self.grad_scaler.scale(
249
270
  train_step_output.loss
250
271
  ).backward() # loss.backward() if not autocasting
251
- self.scaler.unscale_(self.optimizer) # needed to clip correctly
272
+ self.grad_scaler.unscale_(self.optimizer) # needed to clip correctly
252
273
  # TODO: Work out if grad norm clipping should be in config / how to test it.
253
274
  torch.nn.utils.clip_grad_norm_(sae.parameters(), 1.0)
254
- self.scaler.step(self.optimizer) # just ctx.optimizer.step() if not autocasting
255
- self.scaler.update()
275
+ self.grad_scaler.step(
276
+ self.optimizer
277
+ ) # just ctx.optimizer.step() if not autocasting
278
+ self.grad_scaler.update()
256
279
 
257
280
  self.optimizer.zero_grad()
258
281
  self.lr_scheduler.step()
@@ -267,7 +290,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
267
290
  wandb.log(
268
291
  self._build_train_step_log_dict(
269
292
  output=step_output,
270
- n_training_tokens=self.n_training_tokens,
293
+ n_training_samples=self.n_training_samples,
271
294
  ),
272
295
  step=self.n_training_steps,
273
296
  )
@@ -283,7 +306,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
283
306
  def _build_train_step_log_dict(
284
307
  self,
285
308
  output: TrainStepOutput,
286
- n_training_tokens: int,
309
+ n_training_samples: int,
287
310
  ) -> dict[str, Any]:
288
311
  sae_in = output.sae_in
289
312
  sae_out = output.sae_out
@@ -311,7 +334,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
311
334
  "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
312
335
  "sparsity/dead_features": self.dead_neurons.sum().item(),
313
336
  "details/current_learning_rate": current_learning_rate,
314
- "details/n_training_tokens": n_training_tokens,
337
+ "details/n_training_samples": n_training_samples,
315
338
  **{
316
339
  f"details/{name}_coefficient": scheduler.value
317
340
  for name, scheduler in self.coefficient_schedulers.items()
@@ -331,30 +354,11 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
331
354
  * self.cfg.logger.eval_every_n_wandb_logs
332
355
  ) == 0:
333
356
  self.sae.eval()
334
- ignore_tokens = set()
335
- if self.activations_store.exclude_special_tokens is not None:
336
- ignore_tokens = set(
337
- self.activations_store.exclude_special_tokens.tolist()
338
- )
339
- eval_metrics, _ = run_evals(
340
- sae=self.sae,
341
- activation_store=self.activations_store,
342
- model=self.model,
343
- eval_config=self.trainer_eval_config,
344
- ignore_tokens=ignore_tokens,
345
- model_kwargs=self.cfg.model_kwargs,
346
- ) # not calculating featurwise metrics here.
347
-
348
- # Remove eval metrics that are already logged during training
349
- eval_metrics.pop("metrics/explained_variance", None)
350
- eval_metrics.pop("metrics/explained_variance_std", None)
351
- eval_metrics.pop("metrics/l0", None)
352
- eval_metrics.pop("metrics/l1", None)
353
- eval_metrics.pop("metrics/mse", None)
354
-
355
- # Remove metrics that are not useful for wandb logging
356
- eval_metrics.pop("metrics/total_tokens_evaluated", None)
357
-
357
+ eval_metrics = (
358
+ self.evaluator(self.sae, self.data_provider, self.activation_scaler)
359
+ if self.evaluator is not None
360
+ else {}
361
+ )
358
362
  for key, value in self.sae.log_histograms().items():
359
363
  eval_metrics[key] = wandb.Histogram(value) # type: ignore
360
364
 
@@ -378,21 +382,18 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
378
382
  @torch.no_grad()
379
383
  def _reset_running_sparsity_stats(self) -> None:
380
384
  self.act_freq_scores = torch.zeros(
381
- self.cfg.sae.d_sae, # type: ignore
385
+ self.sae.cfg.d_sae, # type: ignore
382
386
  device=self.cfg.device,
383
387
  )
384
- self.n_frac_active_tokens = 0
388
+ self.n_frac_active_samples = 0
385
389
 
386
390
  @torch.no_grad()
387
391
  def _checkpoint_if_needed(self):
388
392
  if (
389
393
  self.checkpoint_thresholds
390
- and self.n_training_tokens > self.checkpoint_thresholds[0]
394
+ and self.n_training_samples > self.checkpoint_thresholds[0]
391
395
  ):
392
- self.save_checkpoint(
393
- trainer=self,
394
- checkpoint_name=str(self.n_training_tokens),
395
- )
396
+ self.save_checkpoint(checkpoint_name=str(self.n_training_samples))
396
397
  self.checkpoint_thresholds.pop(0)
397
398
 
398
399
  @torch.no_grad()
@@ -408,7 +409,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
408
409
  for loss_name, loss_value in step_output.losses.items()
409
410
  )
410
411
  pbar.set_description(f"{self.n_training_steps}| {loss_strs}")
411
- pbar.update(update_interval * self.cfg.train_batch_size_tokens)
412
+ pbar.update(update_interval * self.cfg.train_batch_size_samples)
412
413
 
413
414
 
414
415
  def _unwrap_item(item: float | torch.Tensor) -> float:
@@ -0,0 +1,5 @@
1
+ from typing import Iterator
2
+
3
+ import torch
4
+
5
+ DataProvider = Iterator[torch.Tensor]
sae_lens/util.py CHANGED
@@ -1,3 +1,4 @@
1
+ import re
1
2
  from dataclasses import asdict, fields, is_dataclass
2
3
  from typing import Sequence, TypeVar
3
4
 
@@ -26,3 +27,21 @@ def filter_valid_dataclass_fields(
26
27
  if whitelist_fields is not None:
27
28
  valid_field_names = valid_field_names.union(whitelist_fields)
28
29
  return {key: val for key, val in source_dict.items() if key in valid_field_names}
30
+
31
+
32
+ def extract_stop_at_layer_from_tlens_hook_name(hook_name: str) -> int | None:
33
+ """Extract the stop_at layer from a HookedTransformer hook name.
34
+
35
+ Returns None if the hook name is not a valid HookedTransformer hook name.
36
+ """
37
+ layer = extract_layer_from_tlens_hook_name(hook_name)
38
+ return None if layer is None else layer + 1
39
+
40
+
41
+ def extract_layer_from_tlens_hook_name(hook_name: str) -> int | None:
42
+ """Extract the layer from a HookedTransformer hook name.
43
+
44
+ Returns None if the hook name is not a valid HookedTransformer hook name.
45
+ """
46
+ hook_match = re.search(r"\.(\d+)\.", hook_name)
47
+ return None if hook_match is None else int(hook_match.group(1))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sae-lens
3
- Version: 6.0.0rc2
3
+ Version: 6.0.0rc3
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -1,35 +1,38 @@
1
- sae_lens/__init__.py,sha256=JZATcdlWGVOXYTHb41hn7dPp7pR2tWgpLAz2ztQOE-A,2747
1
+ sae_lens/__init__.py,sha256=881mDkwEifeN32NsH78_CaeH11sKYK4YnqCW502qHE4,2861
2
2
  sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
4
4
  sae_lens/analysis/neuronpedia_integration.py,sha256=DlI08ThI0zwMrBthICt1OFCMyqmaCUDeZxhOk7b7teY,18680
5
- sae_lens/cache_activations_runner.py,sha256=27jp2hFxZj4foWCRCJJd2VCwYJtMgkvPx6MuIhQBofc,12591
6
- sae_lens/config.py,sha256=Ff6MRzRlVk8xtgkvHdJEmuPh9Owc10XIWBaUwdypzkU,26062
7
- sae_lens/constants.py,sha256=HSiSp0j2Umak2buT30seFhkmj7KNuPmB3u4yLXrgfOg,462
8
- sae_lens/evals.py,sha256=aR0pJMBWBUdZElXPcxUyNnNYWbM2LC5UeaESKAwdOMY,39098
9
- sae_lens/load_model.py,sha256=tE70sXsyyyGYW7o506O3eiw1MXyyW6DCQojLG49hWYI,6771
5
+ sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
6
+ sae_lens/config.py,sha256=5Wgr8SsUvYWU2Xmet1JyJ0upAZArMDpYfr3jaK8TvRY,27234
7
+ sae_lens/constants.py,sha256=RJlzWx7wLNMNmrdI63naF7-M3enb55vYRN4x1hXx6vI,593
8
+ sae_lens/evals.py,sha256=WRdHlVeZxXCi33gef7rQE90PSUBF6pjrHnPP6av_Urg,38747
9
+ sae_lens/llm_sae_training_runner.py,sha256=-FPXaHvDfSw5twSaDO8O80aGIzX6T0HywgdpEFFoO-8,9098
10
+ sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
10
11
  sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- sae_lens/loading/pretrained_sae_loaders.py,sha256=IgQ-XSJ5VTLCzmJavPmk1vExBVB-36wW7w-ZNo7tzPY,31214
12
+ sae_lens/loading/pretrained_sae_loaders.py,sha256=FSAz9Je-8Xl7ccdEyp8-WRn-KFtaJ74zgKMefnfaj3A,30877
12
13
  sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
13
14
  sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
14
15
  sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
15
16
  sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
16
- sae_lens/sae_training_runner.py,sha256=lI_d3ywS312dIz0wctm_Sgt3W9ffBOS7ahnDXBljX1s,8320
17
17
  sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
18
18
  sae_lens/saes/gated_sae.py,sha256=IgWvZxeJpdiu7VqeUnJLC-VWVhz6o8OXvmwCS-LJ-WQ,9426
19
19
  sae_lens/saes/jumprelu_sae.py,sha256=lkhafpoYYn4-62tBlmmufmUomoo3CmFFQQ3NNylBNSM,12264
20
- sae_lens/saes/sae.py,sha256=edJK3VFzOVBPXUX6QJ5fhhoY0wcfEisDmVXiqFRA7Xg,35089
20
+ sae_lens/saes/sae.py,sha256=u4kmsUVxa2rnFt8A5jLfj7T6h6qqBK6CkecHslebQgE,34938
21
21
  sae_lens/saes/standard_sae.py,sha256=tMs6Z6Cv44PWa7pLo53xhXFnHMvO5BM6eVYHtRPLpos,6652
22
22
  sae_lens/saes/topk_sae.py,sha256=CfF59K4J2XwUvztwg4fBbvFO3PyucLkg4Elkxdk0ozs,9786
23
23
  sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
24
24
  sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
- sae_lens/training/activations_store.py,sha256=5V5dExeXWoE0dw-ePOZVnQIbBJwrepRMdsQrRam9Lg8,36790
25
+ sae_lens/training/activation_scaler.py,sha256=1P-vva3wJhs2NH65YONli4Rw4auvgZkxe_KKwTNMCR0,1714
26
+ sae_lens/training/activations_store.py,sha256=Xvnz7l2aw3XWtOQsQDj4G4bt-XT6egbumGBwrAM1mtA,32722
26
27
  sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
28
+ sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
27
29
  sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
28
- sae_lens/training/sae_trainer.py,sha256=zYAk_9QJ8AJi2TjDZ1qW_lyoovSBqrJvBHzyYgb89ZY,15251
30
+ sae_lens/training/sae_trainer.py,sha256=rFuMdnBDe82nd7YV_QKVE18V5jCWmohbzkIGL0Z2kIM,15153
31
+ sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
29
32
  sae_lens/training/upload_saes_to_huggingface.py,sha256=tXvR4j25IgMjJ8R9oczwSdy00Tg-P_jAtnPHRt8yF64,4489
30
33
  sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
31
- sae_lens/util.py,sha256=4lqtl7HT9OiyRK8fe8nXtkcn2lOR1uX7ANrAClf6Bv8,1026
32
- sae_lens-6.0.0rc2.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
33
- sae_lens-6.0.0rc2.dist-info/METADATA,sha256=Z8Zwb6EknAPB5dOvfduYZewr4nldot-1dQoqz50Co3k,5326
34
- sae_lens-6.0.0rc2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
35
- sae_lens-6.0.0rc2.dist-info/RECORD,,
34
+ sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
35
+ sae_lens-6.0.0rc3.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
36
+ sae_lens-6.0.0rc3.dist-info/METADATA,sha256=irWiVHtJUXiACNPxZ0fNIVwq1n7n0wxg87c0WSYUkMw,5326
37
+ sae_lens-6.0.0rc3.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
38
+ sae_lens-6.0.0rc3.dist-info/RECORD,,