sae-lens 5.10.3__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -101,61 +101,85 @@ def _get_main_lr_scheduler(
101
101
  raise ValueError(f"Unsupported scheduler: {scheduler_name}")
102
102
 
103
103
 
104
- class L1Scheduler:
104
+ class CoefficientScheduler:
105
+ """Linearly warms up a scalar value from 0.0 to a final value."""
106
+
105
107
  def __init__(
106
108
  self,
107
- l1_warm_up_steps: float,
108
- total_steps: int,
109
- final_l1_coefficient: float,
109
+ warm_up_steps: float,
110
+ final_value: float,
110
111
  ):
111
- self.l1_warmup_steps = l1_warm_up_steps
112
- # assume using warm-up
113
- if self.l1_warmup_steps != 0:
114
- self.current_l1_coefficient = 0.0
115
- else:
116
- self.current_l1_coefficient = final_l1_coefficient
117
-
118
- self.final_l1_coefficient = final_l1_coefficient
119
-
112
+ self.warm_up_steps = warm_up_steps
113
+ self.final_value = final_value
120
114
  self.current_step = 0
121
- self.total_steps = total_steps
122
- if not isinstance(self.final_l1_coefficient, (float, int)):
115
+
116
+ if not isinstance(self.final_value, (float, int)):
123
117
  raise TypeError(
124
- f"final_l1_coefficient must be float or int, got {type(self.final_l1_coefficient)}."
118
+ f"final_value must be float or int, got {type(self.final_value)}."
125
119
  )
126
120
 
121
+ # Initialize current_value based on whether warm-up is used
122
+ if self.warm_up_steps > 0:
123
+ self.current_value = 0.0
124
+ else:
125
+ self.current_value = self.final_value
126
+
127
127
  def __repr__(self) -> str:
128
128
  return (
129
- f"L1Scheduler(final_l1_value={self.final_l1_coefficient}, "
130
- f"l1_warmup_steps={self.l1_warmup_steps}, "
131
- f"total_steps={self.total_steps})"
129
+ f"{self.__class__.__name__}(final_value={self.final_value}, "
130
+ f"warm_up_steps={self.warm_up_steps})"
132
131
  )
133
132
 
134
- def step(self):
133
+ def step(self) -> float:
135
134
  """
136
- Updates the l1 coefficient of the sparse autoencoder.
135
+ Updates the scalar value based on the current step.
136
+
137
+ Returns:
138
+ The current scalar value after the step.
137
139
  """
138
- step = self.current_step
139
- if step < self.l1_warmup_steps:
140
- self.current_l1_coefficient = self.final_l1_coefficient * (
141
- (1 + step) / self.l1_warmup_steps
142
- ) # type: ignore
140
+ if self.current_step < self.warm_up_steps:
141
+ self.current_value = self.final_value * (
142
+ (self.current_step + 1) / self.warm_up_steps
143
+ )
143
144
  else:
144
- self.current_l1_coefficient = self.final_l1_coefficient # type: ignore
145
+ # Ensure the value stays at final_value after warm-up
146
+ self.current_value = self.final_value
145
147
 
146
148
  self.current_step += 1
149
+ return self.current_value
147
150
 
148
- def state_dict(self):
149
- """State dict for serializing as part of an SAETrainContext."""
151
+ @property
152
+ def value(self) -> float:
153
+ """Returns the current scalar value."""
154
+ return self.current_value
155
+
156
+ def state_dict(self) -> dict[str, Any]:
157
+ """State dict for serialization."""
150
158
  return {
151
- "l1_warmup_steps": self.l1_warmup_steps,
152
- "total_steps": self.total_steps,
153
- "current_l1_coefficient": self.current_l1_coefficient,
154
- "final_l1_coefficient": self.final_l1_coefficient,
159
+ "warm_up_steps": self.warm_up_steps,
160
+ "final_value": self.final_value,
155
161
  "current_step": self.current_step,
162
+ "current_value": self.current_value,
156
163
  }
157
164
 
158
165
  def load_state_dict(self, state_dict: dict[str, Any]):
159
- """Loads all state apart from attached SAE."""
160
- for k in state_dict:
161
- setattr(self, k, state_dict[k])
166
+ """Loads the scheduler state."""
167
+ self.warm_up_steps = state_dict["warm_up_steps"]
168
+ self.final_value = state_dict["final_value"]
169
+ self.current_step = state_dict["current_step"]
170
+ # Maintain consistency: re-calculate current_value based on loaded step
171
+ # This handles resuming correctly if stopped mid-warmup.
172
+ if self.current_step <= self.warm_up_steps and self.warm_up_steps > 0:
173
+ # Use max(0, ...) to handle case where current_step might be loaded as -1 or similar before first step
174
+ step_for_calc = max(0, self.current_step)
175
+ # Recalculate based on the step *before* the one about to be taken
176
+ # Or simply use the saved current_value if available and consistent
177
+ if "current_value" in state_dict:
178
+ self.current_value = state_dict["current_value"]
179
+ else: # Legacy state dicts might not have current_value
180
+ self.current_value = self.final_value * (
181
+ step_for_calc / self.warm_up_steps
182
+ )
183
+
184
+ else:
185
+ self.current_value = self.final_value
@@ -1,6 +1,6 @@
1
1
  import contextlib
2
2
  from dataclasses import dataclass
3
- from typing import Any, Protocol, cast
3
+ from typing import Any, Generic, Protocol, cast
4
4
 
5
5
  import torch
6
6
  import wandb
@@ -11,16 +11,16 @@ from transformer_lens.hook_points import HookedRootModule
11
11
  from sae_lens import __version__
12
12
  from sae_lens.config import LanguageModelSAERunnerConfig
13
13
  from sae_lens.evals import EvalConfig, run_evals
14
+ from sae_lens.saes.sae import (
15
+ T_TRAINING_SAE,
16
+ T_TRAINING_SAE_CONFIG,
17
+ TrainCoefficientConfig,
18
+ TrainingSAE,
19
+ TrainStepInput,
20
+ TrainStepOutput,
21
+ )
14
22
  from sae_lens.training.activations_store import ActivationsStore
15
- from sae_lens.training.optim import L1Scheduler, get_lr_scheduler
16
- from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput
17
-
18
- # used to map between parameters which are updated during finetuning and the config str.
19
- FINETUNING_PARAMETERS = {
20
- "scale": ["scaling_factor"],
21
- "decoder": ["scaling_factor", "W_dec", "b_dec"],
22
- "unrotated_decoder": ["scaling_factor", "b_dec"],
23
- }
23
+ from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler
24
24
 
25
25
 
26
26
  def _log_feature_sparsity(
@@ -29,7 +29,7 @@ def _log_feature_sparsity(
29
29
  return torch.log10(feature_sparsity + eps).detach().cpu()
30
30
 
31
31
 
32
- def _update_sae_lens_training_version(sae: TrainingSAE) -> None:
32
+ def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
33
33
  """
34
34
  Make sure we record the version of SAELens used for the training run
35
35
  """
@@ -38,7 +38,7 @@ def _update_sae_lens_training_version(sae: TrainingSAE) -> None:
38
38
 
39
39
  @dataclass
40
40
  class TrainSAEOutput:
41
- sae: TrainingSAE
41
+ sae: TrainingSAE[Any]
42
42
  checkpoint_path: str
43
43
  log_feature_sparsities: torch.Tensor
44
44
 
@@ -46,13 +46,13 @@ class TrainSAEOutput:
46
46
  class SaveCheckpointFn(Protocol):
47
47
  def __call__(
48
48
  self,
49
- trainer: "SAETrainer",
49
+ trainer: "SAETrainer[Any, Any]",
50
50
  checkpoint_name: str,
51
51
  wandb_aliases: list[str] | None = None,
52
52
  ) -> None: ...
53
53
 
54
54
 
55
- class SAETrainer:
55
+ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
56
56
  """
57
57
  Core SAE class used for inference. For training, see TrainingSAE.
58
58
  """
@@ -60,10 +60,10 @@ class SAETrainer:
60
60
  def __init__(
61
61
  self,
62
62
  model: HookedRootModule,
63
- sae: TrainingSAE,
63
+ sae: T_TRAINING_SAE,
64
64
  activation_store: ActivationsStore,
65
65
  save_checkpoint_fn: SaveCheckpointFn,
66
- cfg: LanguageModelSAERunnerConfig,
66
+ cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
67
67
  ) -> None:
68
68
  self.model = model
69
69
  self.sae = sae
@@ -88,11 +88,11 @@ class SAETrainer:
88
88
  )[1:]
89
89
 
90
90
  self.act_freq_scores = torch.zeros(
91
- cast(int, cfg.d_sae),
91
+ cast(int, cfg.sae.d_sae),
92
92
  device=cfg.device,
93
93
  )
94
94
  self.n_forward_passes_since_fired = torch.zeros(
95
- cast(int, cfg.d_sae),
95
+ cast(int, cfg.sae.d_sae),
96
96
  device=cfg.device,
97
97
  )
98
98
  self.n_frac_active_tokens = 0
@@ -121,11 +121,14 @@ class SAETrainer:
121
121
  lr_end=cfg.lr_end,
122
122
  num_cycles=cfg.n_restart_cycles,
123
123
  )
124
- self.l1_scheduler = L1Scheduler(
125
- l1_warm_up_steps=cfg.l1_warm_up_steps,
126
- total_steps=cfg.total_training_steps,
127
- final_l1_coefficient=cfg.l1_coefficient,
128
- )
124
+ self.coefficient_schedulers = {}
125
+ for name, coeff_cfg in self.sae.get_coefficients().items():
126
+ if not isinstance(coeff_cfg, TrainCoefficientConfig):
127
+ coeff_cfg = TrainCoefficientConfig(value=coeff_cfg, warm_up_steps=0)
128
+ self.coefficient_schedulers[name] = CoefficientScheduler(
129
+ warm_up_steps=coeff_cfg.warm_up_steps,
130
+ final_value=coeff_cfg.value,
131
+ )
129
132
 
130
133
  # Setup autocast if using
131
134
  self.scaler = torch.amp.GradScaler(
@@ -163,15 +166,11 @@ class SAETrainer:
163
166
  def log_feature_sparsity(self) -> torch.Tensor:
164
167
  return _log_feature_sparsity(self.feature_sparsity)
165
168
 
166
- @property
167
- def current_l1_coefficient(self) -> float:
168
- return self.l1_scheduler.current_l1_coefficient
169
-
170
169
  @property
171
170
  def dead_neurons(self) -> torch.Tensor:
172
171
  return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()
173
172
 
174
- def fit(self) -> TrainingSAE:
173
+ def fit(self) -> T_TRAINING_SAE:
175
174
  pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE")
176
175
 
177
176
  self.activations_store.set_norm_scaling_factor_if_needed()
@@ -186,7 +185,7 @@ class SAETrainer:
186
185
 
187
186
  step_output = self._train_step(sae=self.sae, sae_in=layer_acts)
188
187
 
189
- if self.cfg.log_to_wandb:
188
+ if self.cfg.logger.log_to_wandb:
190
189
  self._log_train_step(step_output)
191
190
  self._run_and_log_evals()
192
191
 
@@ -194,9 +193,6 @@ class SAETrainer:
194
193
  self.n_training_steps += 1
195
194
  self._update_pbar(step_output, pbar)
196
195
 
197
- ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already)
198
- self._begin_finetuning_if_needed()
199
-
200
196
  # fold the estimated norm scaling factor into the sae weights
201
197
  if self.activations_store.estimated_norm_scaling_factor is not None:
202
198
  self.sae.fold_activation_norm_scaling_factor(
@@ -216,17 +212,14 @@ class SAETrainer:
216
212
 
217
213
  def _train_step(
218
214
  self,
219
- sae: TrainingSAE,
215
+ sae: T_TRAINING_SAE,
220
216
  sae_in: torch.Tensor,
221
217
  ) -> TrainStepOutput:
222
218
  sae.train()
223
- # Make sure the W_dec is still zero-norm
224
- if self.cfg.normalize_sae_decoder:
225
- sae.set_decoder_norm_to_unit_norm()
226
219
 
227
220
  # log and then reset the feature sparsity every feature_sampling_window steps
228
221
  if (self.n_training_steps + 1) % self.cfg.feature_sampling_window == 0:
229
- if self.cfg.log_to_wandb:
222
+ if self.cfg.logger.log_to_wandb:
230
223
  sparsity_log_dict = self._build_sparsity_log_dict()
231
224
  wandb.log(sparsity_log_dict, step=self.n_training_steps)
232
225
  self._reset_running_sparsity_stats()
@@ -235,9 +228,11 @@ class SAETrainer:
235
228
  # https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
236
229
  with self.autocast_if_enabled:
237
230
  train_step_output = self.sae.training_forward_pass(
238
- sae_in=sae_in,
239
- dead_neuron_mask=self.dead_neurons,
240
- current_l1_coefficient=self.current_l1_coefficient,
231
+ step_input=TrainStepInput(
232
+ sae_in=sae_in,
233
+ dead_neuron_mask=self.dead_neurons,
234
+ coefficients=self.get_coefficients(),
235
+ ),
241
236
  )
242
237
 
243
238
  with torch.no_grad():
@@ -259,18 +254,16 @@ class SAETrainer:
259
254
  self.scaler.step(self.optimizer) # just ctx.optimizer.step() if not autocasting
260
255
  self.scaler.update()
261
256
 
262
- if self.cfg.normalize_sae_decoder:
263
- sae.remove_gradient_parallel_to_decoder_directions()
264
-
265
257
  self.optimizer.zero_grad()
266
258
  self.lr_scheduler.step()
267
- self.l1_scheduler.step()
259
+ for scheduler in self.coefficient_schedulers.values():
260
+ scheduler.step()
268
261
 
269
262
  return train_step_output
270
263
 
271
264
  @torch.no_grad()
272
265
  def _log_train_step(self, step_output: TrainStepOutput):
273
- if (self.n_training_steps + 1) % self.cfg.wandb_log_frequency == 0:
266
+ if (self.n_training_steps + 1) % self.cfg.logger.wandb_log_frequency == 0:
274
267
  wandb.log(
275
268
  self._build_train_step_log_dict(
276
269
  output=step_output,
@@ -279,6 +272,13 @@ class SAETrainer:
279
272
  step=self.n_training_steps,
280
273
  )
281
274
 
275
+ @torch.no_grad()
276
+ def get_coefficients(self) -> dict[str, float]:
277
+ return {
278
+ name: scheduler.value
279
+ for name, scheduler in self.coefficient_schedulers.items()
280
+ }
281
+
282
282
  @torch.no_grad()
283
283
  def _build_train_step_log_dict(
284
284
  self,
@@ -311,19 +311,15 @@ class SAETrainer:
311
311
  "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
312
312
  "sparsity/dead_features": self.dead_neurons.sum().item(),
313
313
  "details/current_learning_rate": current_learning_rate,
314
- "details/current_l1_coefficient": self.current_l1_coefficient,
315
314
  "details/n_training_tokens": n_training_tokens,
315
+ **{
316
+ f"details/{name}_coefficient": scheduler.value
317
+ for name, scheduler in self.coefficient_schedulers.items()
318
+ },
316
319
  }
317
320
  for loss_name, loss_value in output.losses.items():
318
321
  loss_item = _unwrap_item(loss_value)
319
- # special case for l1 loss, which we normalize by the l1 coefficient
320
- if loss_name == "l1_loss":
321
- log_dict[f"losses/{loss_name}"] = (
322
- loss_item / self.current_l1_coefficient
323
- )
324
- log_dict[f"losses/raw_{loss_name}"] = loss_item
325
- else:
326
- log_dict[f"losses/{loss_name}"] = loss_item
322
+ log_dict[f"losses/{loss_name}"] = loss_item
327
323
 
328
324
  return log_dict
329
325
 
@@ -331,7 +327,8 @@ class SAETrainer:
331
327
  def _run_and_log_evals(self):
332
328
  # record loss frequently, but not all the time.
333
329
  if (self.n_training_steps + 1) % (
334
- self.cfg.wandb_log_frequency * self.cfg.eval_every_n_wandb_logs
330
+ self.cfg.logger.wandb_log_frequency
331
+ * self.cfg.logger.eval_every_n_wandb_logs
335
332
  ) == 0:
336
333
  self.sae.eval()
337
334
  ignore_tokens = set()
@@ -358,17 +355,8 @@ class SAETrainer:
358
355
  # Remove metrics that are not useful for wandb logging
359
356
  eval_metrics.pop("metrics/total_tokens_evaluated", None)
360
357
 
361
- W_dec_norm_dist = self.sae.W_dec.detach().float().norm(dim=1).cpu().numpy()
362
- eval_metrics["weights/W_dec_norms"] = wandb.Histogram(W_dec_norm_dist) # type: ignore
363
-
364
- if self.sae.cfg.architecture == "standard":
365
- b_e_dist = self.sae.b_enc.detach().float().cpu().numpy()
366
- eval_metrics["weights/b_e"] = wandb.Histogram(b_e_dist) # type: ignore
367
- elif self.sae.cfg.architecture == "gated":
368
- b_gate_dist = self.sae.b_gate.detach().float().cpu().numpy()
369
- eval_metrics["weights/b_gate"] = wandb.Histogram(b_gate_dist) # type: ignore
370
- b_mag_dist = self.sae.b_mag.detach().float().cpu().numpy()
371
- eval_metrics["weights/b_mag"] = wandb.Histogram(b_mag_dist) # type: ignore
358
+ for key, value in self.sae.log_histograms().items():
359
+ eval_metrics[key] = wandb.Histogram(value) # type: ignore
372
360
 
373
361
  wandb.log(
374
362
  eval_metrics,
@@ -390,7 +378,7 @@ class SAETrainer:
390
378
  @torch.no_grad()
391
379
  def _reset_running_sparsity_stats(self) -> None:
392
380
  self.act_freq_scores = torch.zeros(
393
- self.cfg.d_sae, # type: ignore
381
+ self.cfg.sae.d_sae, # type: ignore
394
382
  device=self.cfg.device,
395
383
  )
396
384
  self.n_frac_active_tokens = 0
@@ -422,25 +410,6 @@ class SAETrainer:
422
410
  pbar.set_description(f"{self.n_training_steps}| {loss_strs}")
423
411
  pbar.update(update_interval * self.cfg.train_batch_size_tokens)
424
412
 
425
- def _begin_finetuning_if_needed(self):
426
- if (not self.started_fine_tuning) and (
427
- self.n_training_tokens > self.cfg.training_tokens
428
- ):
429
- self.started_fine_tuning = True
430
-
431
- # finetuning method should be set in the config
432
- # if not, then we don't finetune
433
- if not isinstance(self.cfg.finetuning_method, str):
434
- return
435
-
436
- for name, param in self.sae.named_parameters():
437
- if name in FINETUNING_PARAMETERS[self.cfg.finetuning_method]:
438
- param.requires_grad = True
439
- else:
440
- param.requires_grad = False
441
-
442
- self.finetuning = True
443
-
444
413
 
445
414
  def _unwrap_item(item: float | torch.Tensor) -> float:
446
415
  return item.item() if isinstance(item, torch.Tensor) else item
@@ -2,23 +2,24 @@ import io
2
2
  from pathlib import Path
3
3
  from tempfile import TemporaryDirectory
4
4
  from textwrap import dedent
5
- from typing import Iterable
5
+ from typing import Any, Iterable
6
6
 
7
7
  from huggingface_hub import HfApi, create_repo, get_hf_file_metadata, hf_hub_url
8
8
  from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
9
9
  from tqdm.autonotebook import tqdm
10
10
 
11
11
  from sae_lens import logger
12
- from sae_lens.config import (
12
+ from sae_lens.constants import (
13
+ RUNNER_CFG_FILENAME,
13
14
  SAE_CFG_FILENAME,
14
15
  SAE_WEIGHTS_FILENAME,
15
16
  SPARSITY_FILENAME,
16
17
  )
17
- from sae_lens.sae import SAE
18
+ from sae_lens.saes.sae import SAE
18
19
 
19
20
 
20
21
  def upload_saes_to_huggingface(
21
- saes_dict: dict[str, SAE | Path | str],
22
+ saes_dict: dict[str, SAE[Any] | Path | str],
22
23
  hf_repo_id: str,
23
24
  hf_revision: str = "main",
24
25
  show_progress: bool = True,
@@ -119,11 +120,16 @@ def _upload_sae(api: HfApi, sae_path: Path, repo_id: str, sae_id: str, revision:
119
120
  revision=revision,
120
121
  repo_type="model",
121
122
  commit_message=f"Upload SAE {sae_id}",
122
- allow_patterns=[SAE_CFG_FILENAME, SAE_WEIGHTS_FILENAME, SPARSITY_FILENAME],
123
+ allow_patterns=[
124
+ SAE_CFG_FILENAME,
125
+ SAE_WEIGHTS_FILENAME,
126
+ SPARSITY_FILENAME,
127
+ RUNNER_CFG_FILENAME,
128
+ ],
123
129
  )
124
130
 
125
131
 
126
- def _build_sae_path(sae_ref: SAE | Path | str, tmp_dir: str) -> Path:
132
+ def _build_sae_path(sae_ref: SAE[Any] | Path | str, tmp_dir: str) -> Path:
127
133
  if isinstance(sae_ref, SAE):
128
134
  sae_ref.save_model(tmp_dir)
129
135
  return Path(tmp_dir)
sae_lens/util.py ADDED
@@ -0,0 +1,28 @@
1
+ from dataclasses import asdict, fields, is_dataclass
2
+ from typing import Sequence, TypeVar
3
+
4
+ K = TypeVar("K")
5
+ V = TypeVar("V")
6
+
7
+
8
+ def filter_valid_dataclass_fields(
9
+ source: dict[str, V] | object,
10
+ destination: object | type,
11
+ whitelist_fields: Sequence[str] | None = None,
12
+ ) -> dict[str, V]:
13
+ """Filter a source dict or dataclass instance to only include fields that are present in the destination dataclass."""
14
+
15
+ if not is_dataclass(destination):
16
+ raise ValueError(f"{destination} is not a dataclass")
17
+
18
+ if is_dataclass(source) and not isinstance(source, type):
19
+ source_dict = asdict(source)
20
+ elif isinstance(source, dict):
21
+ source_dict = source
22
+ else:
23
+ raise ValueError(f"{source} is not a dict or dataclass")
24
+
25
+ valid_field_names = {field.name for field in fields(destination)}
26
+ if whitelist_fields is not None:
27
+ valid_field_names = valid_field_names.union(whitelist_fields)
28
+ return {key: val for key, val in source_dict.items() if key in valid_field_names}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sae-lens
3
- Version: 5.10.3
3
+ Version: 6.0.0rc2
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -0,0 +1,35 @@
1
+ sae_lens/__init__.py,sha256=JZATcdlWGVOXYTHb41hn7dPp7pR2tWgpLAz2ztQOE-A,2747
2
+ sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
4
+ sae_lens/analysis/neuronpedia_integration.py,sha256=DlI08ThI0zwMrBthICt1OFCMyqmaCUDeZxhOk7b7teY,18680
5
+ sae_lens/cache_activations_runner.py,sha256=27jp2hFxZj4foWCRCJJd2VCwYJtMgkvPx6MuIhQBofc,12591
6
+ sae_lens/config.py,sha256=Ff6MRzRlVk8xtgkvHdJEmuPh9Owc10XIWBaUwdypzkU,26062
7
+ sae_lens/constants.py,sha256=HSiSp0j2Umak2buT30seFhkmj7KNuPmB3u4yLXrgfOg,462
8
+ sae_lens/evals.py,sha256=aR0pJMBWBUdZElXPcxUyNnNYWbM2LC5UeaESKAwdOMY,39098
9
+ sae_lens/load_model.py,sha256=tE70sXsyyyGYW7o506O3eiw1MXyyW6DCQojLG49hWYI,6771
10
+ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ sae_lens/loading/pretrained_sae_loaders.py,sha256=IgQ-XSJ5VTLCzmJavPmk1vExBVB-36wW7w-ZNo7tzPY,31214
12
+ sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
13
+ sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
14
+ sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
15
+ sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
16
+ sae_lens/sae_training_runner.py,sha256=lI_d3ywS312dIz0wctm_Sgt3W9ffBOS7ahnDXBljX1s,8320
17
+ sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
18
+ sae_lens/saes/gated_sae.py,sha256=IgWvZxeJpdiu7VqeUnJLC-VWVhz6o8OXvmwCS-LJ-WQ,9426
19
+ sae_lens/saes/jumprelu_sae.py,sha256=lkhafpoYYn4-62tBlmmufmUomoo3CmFFQQ3NNylBNSM,12264
20
+ sae_lens/saes/sae.py,sha256=edJK3VFzOVBPXUX6QJ5fhhoY0wcfEisDmVXiqFRA7Xg,35089
21
+ sae_lens/saes/standard_sae.py,sha256=tMs6Z6Cv44PWa7pLo53xhXFnHMvO5BM6eVYHtRPLpos,6652
22
+ sae_lens/saes/topk_sae.py,sha256=CfF59K4J2XwUvztwg4fBbvFO3PyucLkg4Elkxdk0ozs,9786
23
+ sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
24
+ sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
+ sae_lens/training/activations_store.py,sha256=5V5dExeXWoE0dw-ePOZVnQIbBJwrepRMdsQrRam9Lg8,36790
26
+ sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
27
+ sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
28
+ sae_lens/training/sae_trainer.py,sha256=zYAk_9QJ8AJi2TjDZ1qW_lyoovSBqrJvBHzyYgb89ZY,15251
29
+ sae_lens/training/upload_saes_to_huggingface.py,sha256=tXvR4j25IgMjJ8R9oczwSdy00Tg-P_jAtnPHRt8yF64,4489
30
+ sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
31
+ sae_lens/util.py,sha256=4lqtl7HT9OiyRK8fe8nXtkcn2lOR1uX7ANrAClf6Bv8,1026
32
+ sae_lens-6.0.0rc2.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
33
+ sae_lens-6.0.0rc2.dist-info/METADATA,sha256=Z8Zwb6EknAPB5dOvfduYZewr4nldot-1dQoqz50Co3k,5326
34
+ sae_lens-6.0.0rc2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
35
+ sae_lens-6.0.0rc2.dist-info/RECORD,,