sae-lens 6.0.0rc2__py3-none-any.whl → 6.0.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,16 +1,17 @@
  import contextlib
  from dataclasses import dataclass
- from typing import Any, Generic, Protocol, cast
+ from pathlib import Path
+ from typing import Any, Callable, Generic, Protocol

  import torch
  import wandb
+ from safetensors.torch import save_file
  from torch.optim import Adam
- from tqdm import tqdm
- from transformer_lens.hook_points import HookedRootModule
+ from tqdm.auto import tqdm

  from sae_lens import __version__
- from sae_lens.config import LanguageModelSAERunnerConfig
- from sae_lens.evals import EvalConfig, run_evals
+ from sae_lens.config import SAETrainerConfig
+ from sae_lens.constants import ACTIVATION_SCALER_CFG_FILENAME, SPARSITY_FILENAME
  from sae_lens.saes.sae import (
      T_TRAINING_SAE,
      T_TRAINING_SAE_CONFIG,
@@ -19,8 +20,9 @@ from sae_lens.saes.sae import (
      TrainStepInput,
      TrainStepOutput,
  )
- from sae_lens.training.activations_store import ActivationsStore
+ from sae_lens.training.activation_scaler import ActivationScaler
  from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler
+ from sae_lens.training.types import DataProvider


  def _log_feature_sparsity(
@@ -46,33 +48,39 @@ class TrainSAEOutput:
  class SaveCheckpointFn(Protocol):
      def __call__(
          self,
-         trainer: "SAETrainer[Any, Any]",
-         checkpoint_name: str,
-         wandb_aliases: list[str] | None = None,
+         checkpoint_path: Path,
      ) -> None: ...


+ Evaluator = Callable[[T_TRAINING_SAE, DataProvider, ActivationScaler], dict[str, Any]]
+
+
  class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
      """
      Core SAE class used for inference. For training, see TrainingSAE.
      """

+     data_provider: DataProvider
+     activation_scaler: ActivationScaler
+     evaluator: Evaluator[T_TRAINING_SAE] | None
+
      def __init__(
          self,
-         model: HookedRootModule,
+         cfg: SAETrainerConfig,
          sae: T_TRAINING_SAE,
-         activation_store: ActivationsStore,
-         save_checkpoint_fn: SaveCheckpointFn,
-         cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
+         data_provider: DataProvider,
+         evaluator: Evaluator[T_TRAINING_SAE] | None = None,
+         save_checkpoint_fn: SaveCheckpointFn | None = None,
      ) -> None:
-         self.model = model
          self.sae = sae
-         self.activations_store = activation_store
-         self.save_checkpoint = save_checkpoint_fn
+         self.data_provider = data_provider
+         self.evaluator = evaluator
+         self.activation_scaler = ActivationScaler()
+         self.save_checkpoint_fn = save_checkpoint_fn
          self.cfg = cfg

          self.n_training_steps: int = 0
-         self.n_training_tokens: int = 0
+         self.n_training_samples: int = 0
          self.started_fine_tuning: bool = False

          _update_sae_lens_training_version(self.sae)
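For orientation, here is a rough sketch of the two hooks a caller can now plug into the trainer. Only the signatures come from the diff above; the function names and the metric computed are illustrative assumptions, not part of the package.

```python
from pathlib import Path
from typing import Any, Iterator

import torch


# Matches the new SaveCheckpointFn protocol: called after SAETrainer.save_checkpoint
# has already written weights, config, sparsity, and scaler state into checkpoint_path.
def notify_on_checkpoint(checkpoint_path: Path) -> None:
    print(f"checkpoint written to {checkpoint_path}")


# Matches the Evaluator alias: (sae, data_provider, activation_scaler) -> metrics dict.
# The metric here is a placeholder; a real evaluator would run the SAE on the batch.
def simple_evaluator(sae, data_provider: Iterator[torch.Tensor], activation_scaler) -> dict[str, Any]:
    batch = activation_scaler(next(data_provider).to(sae.device))
    return {"eval/mean_input_norm": batch.norm(dim=-1).mean().item()}
```

Both hooks are optional in the new `__init__`; when `evaluator` is `None`, the periodic wandb eval step logs only the SAE's own histograms.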
@@ -82,20 +90,16 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
          self.checkpoint_thresholds = list(
              range(
                  0,
-                 cfg.total_training_tokens,
-                 cfg.total_training_tokens // self.cfg.n_checkpoints,
+                 cfg.total_training_samples,
+                 cfg.total_training_samples // self.cfg.n_checkpoints,
              )
          )[1:]

-         self.act_freq_scores = torch.zeros(
-             cast(int, cfg.sae.d_sae),
-             device=cfg.device,
-         )
+         self.act_freq_scores = torch.zeros(sae.cfg.d_sae, device=cfg.device)
          self.n_forward_passes_since_fired = torch.zeros(
-             cast(int, cfg.sae.d_sae),
-             device=cfg.device,
+             sae.cfg.d_sae, device=cfg.device
          )
-         self.n_frac_active_tokens = 0
+         self.n_frac_active_samples = 0
          # we don't train the scaling factor (initially)
          # set requires grad to false for the scaling factor
          for name, param in self.sae.named_parameters():
@@ -131,7 +135,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
          )

          # Setup autocast if using
-         self.scaler = torch.amp.GradScaler(
+         self.grad_scaler = torch.amp.GradScaler(
              device=self.cfg.device, enabled=self.cfg.autocast
          )

@@ -144,23 +148,9 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
          else:
              self.autocast_if_enabled = contextlib.nullcontext()

-         # Set up eval config
-
-         self.trainer_eval_config = EvalConfig(
-             batch_size_prompts=self.cfg.eval_batch_size_prompts,
-             n_eval_reconstruction_batches=self.cfg.n_eval_batches,
-             n_eval_sparsity_variance_batches=self.cfg.n_eval_batches,
-             compute_ce_loss=True,
-             compute_l2_norms=True,
-             compute_sparsity_metrics=True,
-             compute_variance_metrics=True,
-             compute_kl=False,
-             compute_featurewise_weight_based_metrics=False,
-         )
-
      @property
      def feature_sparsity(self) -> torch.Tensor:
-         return self.act_freq_scores / self.n_frac_active_tokens
+         return self.act_freq_scores / self.n_frac_active_samples

      @property
      def log_feature_sparsity(self) -> torch.Tensor:
@@ -171,19 +161,24 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
          return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()

      def fit(self) -> T_TRAINING_SAE:
-         pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE")
-
-         self.activations_store.set_norm_scaling_factor_if_needed()
+         self.sae.to(self.cfg.device)
+         pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE")
+
+         if self.sae.cfg.normalize_activations == "expected_average_only_in":
+             self.activation_scaler.estimate_scaling_factor(
+                 d_in=self.sae.cfg.d_in,
+                 data_provider=self.data_provider,
+                 n_batches_for_norm_estimate=int(1e3),
+             )

          # Train loop
-         while self.n_training_tokens < self.cfg.total_training_tokens:
+         while self.n_training_samples < self.cfg.total_training_samples:
              # Do a training step.
-             layer_acts = self.activations_store.next_batch()[:, 0, :].to(
-                 self.sae.device
-             )
-             self.n_training_tokens += self.cfg.train_batch_size_tokens
+             batch = next(self.data_provider).to(self.sae.device)
+             self.n_training_samples += batch.shape[0]
+             scaled_batch = self.activation_scaler(batch)

-             step_output = self._train_step(sae=self.sae, sae_in=layer_acts)
+             step_output = self._train_step(sae=self.sae, sae_in=scaled_batch)

              if self.cfg.logger.log_to_wandb:
                  self._log_train_step(step_output)
@@ -194,22 +189,56 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
              self._update_pbar(step_output, pbar)

          # fold the estimated norm scaling factor into the sae weights
-         if self.activations_store.estimated_norm_scaling_factor is not None:
+         if self.activation_scaler.scaling_factor is not None:
              self.sae.fold_activation_norm_scaling_factor(
-                 self.activations_store.estimated_norm_scaling_factor
+                 self.activation_scaler.scaling_factor
              )
-             self.activations_store.estimated_norm_scaling_factor = None
+             self.activation_scaler.scaling_factor = None

-         # save final sae group to checkpoints folder
+         # save final inference sae group to checkpoints folder
          self.save_checkpoint(
-             trainer=self,
-             checkpoint_name=f"final_{self.n_training_tokens}",
+             checkpoint_name=f"final_{self.n_training_samples}",
              wandb_aliases=["final_model"],
+             save_inference_model=True,
          )

          pbar.close()
          return self.sae

+     def save_checkpoint(
+         self,
+         checkpoint_name: str,
+         wandb_aliases: list[str] | None = None,
+         save_inference_model: bool = False,
+     ) -> None:
+         checkpoint_path = Path(self.cfg.checkpoint_path) / checkpoint_name
+         checkpoint_path.mkdir(exist_ok=True, parents=True)
+
+         save_fn = (
+             self.sae.save_inference_model
+             if save_inference_model
+             else self.sae.save_model
+         )
+         weights_path, cfg_path = save_fn(str(checkpoint_path))
+
+         sparsity_path = checkpoint_path / SPARSITY_FILENAME
+         save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
+
+         activation_scaler_path = checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
+         self.activation_scaler.save(str(activation_scaler_path))
+
+         if self.cfg.logger.log_to_wandb:
+             self.cfg.logger.log(
+                 self,
+                 weights_path,
+                 cfg_path,
+                 sparsity_path=sparsity_path,
+                 wandb_aliases=wandb_aliases,
+             )
+
+         if self.save_checkpoint_fn is not None:
+             self.save_checkpoint_fn(checkpoint_path=checkpoint_path)
+
      def _train_step(
          self,
          sae: T_TRAINING_SAE,
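A brief usage sketch of the new public `save_checkpoint` method, assuming an already constructed `trainer: SAETrainer`; it mirrors the calls made by `fit()` and `_checkpoint_if_needed()` in this diff.

```python
# Intermediate checkpoint, named by the number of samples seen so far:
trainer.save_checkpoint(checkpoint_name=str(trainer.n_training_samples))

# Final, inference-ready export with a wandb alias, as done at the end of fit():
trainer.save_checkpoint(
    checkpoint_name=f"final_{trainer.n_training_samples}",
    wandb_aliases=["final_model"],
    save_inference_model=True,
)
```

Each call writes the SAE weights and config plus the sparsity tensor (SPARSITY_FILENAME, via safetensors) and the activation-scaler state (ACTIVATION_SCALER_CFG_FILENAME) under `cfg.checkpoint_path/<checkpoint_name>`, then invokes the optional `save_checkpoint_fn` with the resulting path.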
@@ -242,17 +271,19 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
          self.act_freq_scores += (
              (train_step_output.feature_acts.abs() > 0).float().sum(0)
          )
-         self.n_frac_active_tokens += self.cfg.train_batch_size_tokens
+         self.n_frac_active_samples += self.cfg.train_batch_size_samples

-         # Scaler will rescale gradients if autocast is enabled
-         self.scaler.scale(
+         # Grad scaler will rescale gradients if autocast is enabled
+         self.grad_scaler.scale(
              train_step_output.loss
          ).backward() # loss.backward() if not autocasting
-         self.scaler.unscale_(self.optimizer) # needed to clip correctly
+         self.grad_scaler.unscale_(self.optimizer) # needed to clip correctly
          # TODO: Work out if grad norm clipping should be in config / how to test it.
          torch.nn.utils.clip_grad_norm_(sae.parameters(), 1.0)
-         self.scaler.step(self.optimizer) # just ctx.optimizer.step() if not autocasting
-         self.scaler.update()
+         self.grad_scaler.step(
+             self.optimizer
+         ) # just ctx.optimizer.step() if not autocasting
+         self.grad_scaler.update()

          self.optimizer.zero_grad()
          self.lr_scheduler.step()
@@ -267,7 +298,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
              wandb.log(
                  self._build_train_step_log_dict(
                      output=step_output,
-                     n_training_tokens=self.n_training_tokens,
+                     n_training_samples=self.n_training_samples,
                  ),
                  step=self.n_training_steps,
              )
@@ -283,7 +314,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
      def _build_train_step_log_dict(
          self,
          output: TrainStepOutput,
-         n_training_tokens: int,
+         n_training_samples: int,
      ) -> dict[str, Any]:
          sae_in = output.sae_in
          sae_out = output.sae_out
@@ -311,7 +342,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
              "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
              "sparsity/dead_features": self.dead_neurons.sum().item(),
              "details/current_learning_rate": current_learning_rate,
-             "details/n_training_tokens": n_training_tokens,
+             "details/n_training_samples": n_training_samples,
              **{
                  f"details/{name}_coefficient": scheduler.value
                  for name, scheduler in self.coefficient_schedulers.items()
@@ -331,30 +362,11 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
              * self.cfg.logger.eval_every_n_wandb_logs
          ) == 0:
              self.sae.eval()
-             ignore_tokens = set()
-             if self.activations_store.exclude_special_tokens is not None:
-                 ignore_tokens = set(
-                     self.activations_store.exclude_special_tokens.tolist()
-                 )
-             eval_metrics, _ = run_evals(
-                 sae=self.sae,
-                 activation_store=self.activations_store,
-                 model=self.model,
-                 eval_config=self.trainer_eval_config,
-                 ignore_tokens=ignore_tokens,
-                 model_kwargs=self.cfg.model_kwargs,
-             ) # not calculating featurwise metrics here.
-
-             # Remove eval metrics that are already logged during training
-             eval_metrics.pop("metrics/explained_variance", None)
-             eval_metrics.pop("metrics/explained_variance_std", None)
-             eval_metrics.pop("metrics/l0", None)
-             eval_metrics.pop("metrics/l1", None)
-             eval_metrics.pop("metrics/mse", None)
-
-             # Remove metrics that are not useful for wandb logging
-             eval_metrics.pop("metrics/total_tokens_evaluated", None)
-
+             eval_metrics = (
+                 self.evaluator(self.sae, self.data_provider, self.activation_scaler)
+                 if self.evaluator is not None
+                 else {}
+             )
              for key, value in self.sae.log_histograms().items():
                  eval_metrics[key] = wandb.Histogram(value) # type: ignore

@@ -378,21 +390,18 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
      @torch.no_grad()
      def _reset_running_sparsity_stats(self) -> None:
          self.act_freq_scores = torch.zeros(
-             self.cfg.sae.d_sae, # type: ignore
+             self.sae.cfg.d_sae, # type: ignore
              device=self.cfg.device,
          )
-         self.n_frac_active_tokens = 0
+         self.n_frac_active_samples = 0

      @torch.no_grad()
      def _checkpoint_if_needed(self):
          if (
              self.checkpoint_thresholds
-             and self.n_training_tokens > self.checkpoint_thresholds[0]
+             and self.n_training_samples > self.checkpoint_thresholds[0]
          ):
-             self.save_checkpoint(
-                 trainer=self,
-                 checkpoint_name=str(self.n_training_tokens),
-             )
+             self.save_checkpoint(checkpoint_name=str(self.n_training_samples))
              self.checkpoint_thresholds.pop(0)

      @torch.no_grad()
@@ -408,7 +417,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
              for loss_name, loss_value in step_output.losses.items()
          )
          pbar.set_description(f"{self.n_training_steps}| {loss_strs}")
-         pbar.update(update_interval * self.cfg.train_batch_size_tokens)
+         pbar.update(update_interval * self.cfg.train_batch_size_samples)


  def _unwrap_item(item: float | torch.Tensor) -> float:
@@ -0,0 +1,5 @@
+ from typing import Iterator
+
+ import torch
+
+ DataProvider = Iterator[torch.Tensor]
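This new module defines the data-provider contract the trainer consumes: any iterator of activation tensors. A minimal sketch follows; the generator name and shapes are illustrative, not part of the package.

```python
import torch


def random_activations(d_in: int, batch_size: int = 4096):
    # Any iterator yielding [batch, d_in] activation tensors satisfies DataProvider.
    while True:
        yield torch.randn(batch_size, d_in)


data_provider = random_activations(d_in=768)
batch = next(data_provider)  # SAETrainer.fit() pulls batches with next(...)
```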
@@ -88,7 +88,7 @@ def _create_default_readme(repo_id: str, sae_ids: Iterable[str]) -> str:
  ```python
  from sae_lens import SAE

- sae, cfg_dict, sparsity = SAE.from_pretrained("{repo_id}", "<sae_id>")
+ sae = SAE.from_pretrained("{repo_id}", "<sae_id>")
  ```
  """
  )
sae_lens/util.py CHANGED
@@ -1,3 +1,4 @@
+ import re
  from dataclasses import asdict, fields, is_dataclass
  from typing import Sequence, TypeVar

@@ -26,3 +27,21 @@ def filter_valid_dataclass_fields(
      if whitelist_fields is not None:
          valid_field_names = valid_field_names.union(whitelist_fields)
      return {key: val for key, val in source_dict.items() if key in valid_field_names}
+
+
+ def extract_stop_at_layer_from_tlens_hook_name(hook_name: str) -> int | None:
+     """Extract the stop_at layer from a HookedTransformer hook name.
+
+     Returns None if the hook name is not a valid HookedTransformer hook name.
+     """
+     layer = extract_layer_from_tlens_hook_name(hook_name)
+     return None if layer is None else layer + 1
+
+
+ def extract_layer_from_tlens_hook_name(hook_name: str) -> int | None:
+     """Extract the layer from a HookedTransformer hook name.
+
+     Returns None if the hook name is not a valid HookedTransformer hook name.
+     """
+     hook_match = re.search(r"\.(\d+)\.", hook_name)
+     return None if hook_match is None else int(hook_match.group(1))
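Worked examples for the two new helpers, following directly from the `\.(\d+)\.` regex above; the hook names are standard TransformerLens names used only for illustration.

```python
from sae_lens.util import (
    extract_layer_from_tlens_hook_name,
    extract_stop_at_layer_from_tlens_hook_name,
)

extract_layer_from_tlens_hook_name("blocks.5.hook_resid_pre")          # -> 5
extract_stop_at_layer_from_tlens_hook_name("blocks.5.hook_resid_pre")  # -> 6 (layer + 1)
extract_layer_from_tlens_hook_name("hook_embed")                        # -> None (no layer index)
```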
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: sae-lens
- Version: 6.0.0rc2
+ Version: 6.0.0rc4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
  License: MIT
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -0,0 +1,37 @@
+ sae_lens/__init__.py,sha256=dGZU3Y6iwiuW5oQVTfNvUmfnHO3bHWWbpU-nvXvw9M8,2861
+ sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
+ sae_lens/analysis/neuronpedia_integration.py,sha256=MrENqc81Mc2SMbxGjbwHzpkGUCAFKSf0i4EdaUF2Oj4,18707
+ sae_lens/cache_activations_runner.py,sha256=L5hhuU2-zPQr2S3L64GMKKLeMQfqXxwDl8NbuOtrybI,12567
+ sae_lens/config.py,sha256=9Lg4HkQvj1t9QZJdmC071lyJMc_iqNQknosT7zOYfwM,27278
+ sae_lens/constants.py,sha256=RJlzWx7wLNMNmrdI63naF7-M3enb55vYRN4x1hXx6vI,593
+ sae_lens/evals.py,sha256=PIMGQobE9o2bHksDAtQe5bnTMYyHoZKB_elFhDOjrmo,38991
+ sae_lens/llm_sae_training_runner.py,sha256=58XbDylw2fPOD7C-ZfSAjeNqJLXB05uHGTuiYVVbXXY,13354
+ sae_lens/load_model.py,sha256=dBB_9gO6kWyQ4sXHq7qB8T3YUlXm3PGwYcpR4UVW4QY,8633
+ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ sae_lens/loading/pretrained_sae_loaders.py,sha256=kbirwfCg4Ks9Cg3rt78bYxIHMhz5h015n0UTRJQLJY0,31291
+ sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
+ sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
+ sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
+ sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
+ sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
+ sae_lens/saes/gated_sae.py,sha256=0zd66bH04nsaGk3bxHk10hsZofa2GrFbMo15LOsuqgU,9233
+ sae_lens/saes/jumprelu_sae.py,sha256=iwmPQJ4XpIxzgosty680u8Zj7x1uVZhM75kPOT3obi0,12060
+ sae_lens/saes/sae.py,sha256=HAGkJAj_FIDzbSR1dsG8b2AyMq8UauUU_yx-LvdfjuE,37465
+ sae_lens/saes/standard_sae.py,sha256=PfkGLsw_6La3PXHOQL0u7qQsaZsXCJqYCeCcRDj5n64,6274
+ sae_lens/saes/topk_sae.py,sha256=kmry1FE1H06OvCfn84V-j2JfWGKcU5b2urwAq_Oq5j4,9893
+ sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
+ sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ sae_lens/training/activation_scaler.py,sha256=seEE-2Qd2JMHxqgnsNWPt-DGtYGZxWPnOwCGuVNSOtI,1719
+ sae_lens/training/activations_store.py,sha256=s3Qvztv2siuuXSuXEUDZYSKq1QQCsqsGXY767kv6grc,32609
+ sae_lens/training/mixing_buffer.py,sha256=vDpYG5ZE70szDvBsRKcNHEES3h_WTKJ16qDYk5jPOVA,2015
+ sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
+ sae_lens/training/sae_trainer.py,sha256=9K0VudwSTJp9OlCVzaU_ngZ0WlYNrN6-ozTCCAxR9_k,15421
+ sae_lens/training/types.py,sha256=qSjmGzXf3MLalygG0psnVjmhX_mpLmL47MQtZfe7qxg,81
+ sae_lens/training/upload_saes_to_huggingface.py,sha256=r_WzI1zLtGZ5TzAxuG3xa_8T09j3zXJrWd_vzPsPGkQ,4469
+ sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
+ sae_lens/util.py,sha256=mCwLAilGMVo8Scm7CIsCafU7GsfmBvCcjwmloI4Ly7Y,1718
+ sae_lens-6.0.0rc4.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+ sae_lens-6.0.0rc4.dist-info/METADATA,sha256=wOQMSV4yNlpgpGxuE4DI0-q4KzTRYOg1m9ZxpdCsNjk,5326
+ sae_lens-6.0.0rc4.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ sae_lens-6.0.0rc4.dist-info/RECORD,,
+ sae_lens-6.0.0rc4.dist-info/RECORD,,