sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  import contextlib
2
2
  from dataclasses import dataclass
3
- from typing import Any, Protocol, cast
3
+ from typing import Any, Generic, Protocol, cast
4
4
 
5
5
  import torch
6
6
  import wandb
@@ -11,16 +11,16 @@ from transformer_lens.hook_points import HookedRootModule
11
11
  from sae_lens import __version__
12
12
  from sae_lens.config import LanguageModelSAERunnerConfig
13
13
  from sae_lens.evals import EvalConfig, run_evals
14
- from sae_lens.saes.sae import TrainingSAE, TrainStepInput, TrainStepOutput
14
+ from sae_lens.saes.sae import (
15
+ T_TRAINING_SAE,
16
+ T_TRAINING_SAE_CONFIG,
17
+ TrainCoefficientConfig,
18
+ TrainingSAE,
19
+ TrainStepInput,
20
+ TrainStepOutput,
21
+ )
15
22
  from sae_lens.training.activations_store import ActivationsStore
16
- from sae_lens.training.optim import L1Scheduler, get_lr_scheduler
17
-
18
- # used to map between parameters which are updated during finetuning and the config str.
19
- FINETUNING_PARAMETERS = {
20
- "scale": ["scaling_factor"],
21
- "decoder": ["scaling_factor", "W_dec", "b_dec"],
22
- "unrotated_decoder": ["scaling_factor", "b_dec"],
23
- }
23
+ from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler
24
24
 
25
25
 
26
26
  def _log_feature_sparsity(
@@ -29,7 +29,7 @@ def _log_feature_sparsity(
29
29
  return torch.log10(feature_sparsity + eps).detach().cpu()
30
30
 
31
31
 
32
- def _update_sae_lens_training_version(sae: TrainingSAE) -> None:
32
+ def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
33
33
  """
34
34
  Make sure we record the version of SAELens used for the training run
35
35
  """
@@ -38,7 +38,7 @@ def _update_sae_lens_training_version(sae: TrainingSAE) -> None:
38
38
 
39
39
  @dataclass
40
40
  class TrainSAEOutput:
41
- sae: TrainingSAE
41
+ sae: TrainingSAE[Any]
42
42
  checkpoint_path: str
43
43
  log_feature_sparsities: torch.Tensor
44
44
 
@@ -46,13 +46,13 @@ class TrainSAEOutput:
46
46
  class SaveCheckpointFn(Protocol):
47
47
  def __call__(
48
48
  self,
49
- trainer: "SAETrainer",
49
+ trainer: "SAETrainer[Any, Any]",
50
50
  checkpoint_name: str,
51
51
  wandb_aliases: list[str] | None = None,
52
52
  ) -> None: ...
53
53
 
54
54
 
55
- class SAETrainer:
55
+ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
56
56
  """
57
57
  Core SAE class used for inference. For training, see TrainingSAE.
58
58
  """
@@ -60,10 +60,10 @@ class SAETrainer:
60
60
  def __init__(
61
61
  self,
62
62
  model: HookedRootModule,
63
- sae: TrainingSAE,
63
+ sae: T_TRAINING_SAE,
64
64
  activation_store: ActivationsStore,
65
65
  save_checkpoint_fn: SaveCheckpointFn,
66
- cfg: LanguageModelSAERunnerConfig,
66
+ cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
67
67
  ) -> None:
68
68
  self.model = model
69
69
  self.sae = sae
@@ -88,11 +88,11 @@ class SAETrainer:
88
88
  )[1:]
89
89
 
90
90
  self.act_freq_scores = torch.zeros(
91
- cast(int, cfg.d_sae),
91
+ cast(int, cfg.sae.d_sae),
92
92
  device=cfg.device,
93
93
  )
94
94
  self.n_forward_passes_since_fired = torch.zeros(
95
- cast(int, cfg.d_sae),
95
+ cast(int, cfg.sae.d_sae),
96
96
  device=cfg.device,
97
97
  )
98
98
  self.n_frac_active_tokens = 0
@@ -121,11 +121,14 @@ class SAETrainer:
121
121
  lr_end=cfg.lr_end,
122
122
  num_cycles=cfg.n_restart_cycles,
123
123
  )
124
- self.l1_scheduler = L1Scheduler(
125
- l1_warm_up_steps=cfg.l1_warm_up_steps,
126
- total_steps=cfg.total_training_steps,
127
- final_l1_coefficient=cfg.l1_coefficient,
128
- )
124
+ self.coefficient_schedulers = {}
125
+ for name, coeff_cfg in self.sae.get_coefficients().items():
126
+ if not isinstance(coeff_cfg, TrainCoefficientConfig):
127
+ coeff_cfg = TrainCoefficientConfig(value=coeff_cfg, warm_up_steps=0)
128
+ self.coefficient_schedulers[name] = CoefficientScheduler(
129
+ warm_up_steps=coeff_cfg.warm_up_steps,
130
+ final_value=coeff_cfg.value,
131
+ )
129
132
 
130
133
  # Setup autocast if using
131
134
  self.scaler = torch.amp.GradScaler(
@@ -163,15 +166,11 @@ class SAETrainer:
163
166
  def log_feature_sparsity(self) -> torch.Tensor:
164
167
  return _log_feature_sparsity(self.feature_sparsity)
165
168
 
166
- @property
167
- def current_l1_coefficient(self) -> float:
168
- return self.l1_scheduler.current_l1_coefficient
169
-
170
169
  @property
171
170
  def dead_neurons(self) -> torch.Tensor:
172
171
  return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()
173
172
 
174
- def fit(self) -> TrainingSAE:
173
+ def fit(self) -> T_TRAINING_SAE:
175
174
  pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE")
176
175
 
177
176
  self.activations_store.set_norm_scaling_factor_if_needed()
@@ -194,9 +193,6 @@ class SAETrainer:
194
193
  self.n_training_steps += 1
195
194
  self._update_pbar(step_output, pbar)
196
195
 
197
- ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already)
198
- self._begin_finetuning_if_needed()
199
-
200
196
  # fold the estimated norm scaling factor into the sae weights
201
197
  if self.activations_store.estimated_norm_scaling_factor is not None:
202
198
  self.sae.fold_activation_norm_scaling_factor(
@@ -216,13 +212,10 @@ class SAETrainer:
216
212
 
217
213
  def _train_step(
218
214
  self,
219
- sae: TrainingSAE,
215
+ sae: T_TRAINING_SAE,
220
216
  sae_in: torch.Tensor,
221
217
  ) -> TrainStepOutput:
222
218
  sae.train()
223
- # Make sure the W_dec is still zero-norm
224
- if self.cfg.normalize_sae_decoder:
225
- sae.set_decoder_norm_to_unit_norm()
226
219
 
227
220
  # log and then reset the feature sparsity every feature_sampling_window steps
228
221
  if (self.n_training_steps + 1) % self.cfg.feature_sampling_window == 0:
@@ -238,7 +231,7 @@ class SAETrainer:
238
231
  step_input=TrainStepInput(
239
232
  sae_in=sae_in,
240
233
  dead_neuron_mask=self.dead_neurons,
241
- current_l1_coefficient=self.current_l1_coefficient,
234
+ coefficients=self.get_coefficients(),
242
235
  ),
243
236
  )
244
237
 
@@ -261,12 +254,10 @@ class SAETrainer:
261
254
  self.scaler.step(self.optimizer) # just ctx.optimizer.step() if not autocasting
262
255
  self.scaler.update()
263
256
 
264
- if self.cfg.normalize_sae_decoder:
265
- sae.remove_gradient_parallel_to_decoder_directions()
266
-
267
257
  self.optimizer.zero_grad()
268
258
  self.lr_scheduler.step()
269
- self.l1_scheduler.step()
259
+ for scheduler in self.coefficient_schedulers.values():
260
+ scheduler.step()
270
261
 
271
262
  return train_step_output
272
263
 
@@ -281,6 +272,13 @@ class SAETrainer:
281
272
  step=self.n_training_steps,
282
273
  )
283
274
 
275
+ @torch.no_grad()
276
+ def get_coefficients(self) -> dict[str, float]:
277
+ return {
278
+ name: scheduler.value
279
+ for name, scheduler in self.coefficient_schedulers.items()
280
+ }
281
+
284
282
  @torch.no_grad()
285
283
  def _build_train_step_log_dict(
286
284
  self,
@@ -313,19 +311,15 @@ class SAETrainer:
313
311
  "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
314
312
  "sparsity/dead_features": self.dead_neurons.sum().item(),
315
313
  "details/current_learning_rate": current_learning_rate,
316
- "details/current_l1_coefficient": self.current_l1_coefficient,
317
314
  "details/n_training_tokens": n_training_tokens,
315
+ **{
316
+ f"details/{name}_coefficient": scheduler.value
317
+ for name, scheduler in self.coefficient_schedulers.items()
318
+ },
318
319
  }
319
320
  for loss_name, loss_value in output.losses.items():
320
321
  loss_item = _unwrap_item(loss_value)
321
- # special case for l1 loss, which we normalize by the l1 coefficient
322
- if loss_name == "l1_loss":
323
- log_dict[f"losses/{loss_name}"] = (
324
- loss_item / self.current_l1_coefficient
325
- )
326
- log_dict[f"losses/raw_{loss_name}"] = loss_item
327
- else:
328
- log_dict[f"losses/{loss_name}"] = loss_item
322
+ log_dict[f"losses/{loss_name}"] = loss_item
329
323
 
330
324
  return log_dict
331
325
 
@@ -384,7 +378,7 @@ class SAETrainer:
384
378
  @torch.no_grad()
385
379
  def _reset_running_sparsity_stats(self) -> None:
386
380
  self.act_freq_scores = torch.zeros(
387
- self.cfg.d_sae, # type: ignore
381
+ self.cfg.sae.d_sae, # type: ignore
388
382
  device=self.cfg.device,
389
383
  )
390
384
  self.n_frac_active_tokens = 0
@@ -416,25 +410,6 @@ class SAETrainer:
416
410
  pbar.set_description(f"{self.n_training_steps}| {loss_strs}")
417
411
  pbar.update(update_interval * self.cfg.train_batch_size_tokens)
418
412
 
419
- def _begin_finetuning_if_needed(self):
420
- if (not self.started_fine_tuning) and (
421
- self.n_training_tokens > self.cfg.training_tokens
422
- ):
423
- self.started_fine_tuning = True
424
-
425
- # finetuning method should be set in the config
426
- # if not, then we don't finetune
427
- if not isinstance(self.cfg.finetuning_method, str):
428
- return
429
-
430
- for name, param in self.sae.named_parameters():
431
- if name in FINETUNING_PARAMETERS[self.cfg.finetuning_method]:
432
- param.requires_grad = True
433
- else:
434
- param.requires_grad = False
435
-
436
- self.finetuning = True
437
-
438
413
 
439
414
  def _unwrap_item(item: float | torch.Tensor) -> float:
440
415
  return item.item() if isinstance(item, torch.Tensor) else item
@@ -2,14 +2,15 @@ import io
2
2
  from pathlib import Path
3
3
  from tempfile import TemporaryDirectory
4
4
  from textwrap import dedent
5
- from typing import Iterable
5
+ from typing import Any, Iterable
6
6
 
7
7
  from huggingface_hub import HfApi, create_repo, get_hf_file_metadata, hf_hub_url
8
8
  from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
9
9
  from tqdm.autonotebook import tqdm
10
10
 
11
11
  from sae_lens import logger
12
- from sae_lens.config import (
12
+ from sae_lens.constants import (
13
+ RUNNER_CFG_FILENAME,
13
14
  SAE_CFG_FILENAME,
14
15
  SAE_WEIGHTS_FILENAME,
15
16
  SPARSITY_FILENAME,
@@ -18,7 +19,7 @@ from sae_lens.saes.sae import SAE
18
19
 
19
20
 
20
21
  def upload_saes_to_huggingface(
21
- saes_dict: dict[str, SAE | Path | str],
22
+ saes_dict: dict[str, SAE[Any] | Path | str],
22
23
  hf_repo_id: str,
23
24
  hf_revision: str = "main",
24
25
  show_progress: bool = True,
@@ -119,11 +120,16 @@ def _upload_sae(api: HfApi, sae_path: Path, repo_id: str, sae_id: str, revision:
119
120
  revision=revision,
120
121
  repo_type="model",
121
122
  commit_message=f"Upload SAE {sae_id}",
122
- allow_patterns=[SAE_CFG_FILENAME, SAE_WEIGHTS_FILENAME, SPARSITY_FILENAME],
123
+ allow_patterns=[
124
+ SAE_CFG_FILENAME,
125
+ SAE_WEIGHTS_FILENAME,
126
+ SPARSITY_FILENAME,
127
+ RUNNER_CFG_FILENAME,
128
+ ],
123
129
  )
124
130
 
125
131
 
126
- def _build_sae_path(sae_ref: SAE | Path | str, tmp_dir: str) -> Path:
132
+ def _build_sae_path(sae_ref: SAE[Any] | Path | str, tmp_dir: str) -> Path:
127
133
  if isinstance(sae_ref, SAE):
128
134
  sae_ref.save_model(tmp_dir)
129
135
  return Path(tmp_dir)
sae_lens/util.py ADDED
@@ -0,0 +1,28 @@
1
+ from dataclasses import asdict, fields, is_dataclass
2
+ from typing import Sequence, TypeVar
3
+
4
+ K = TypeVar("K")
5
+ V = TypeVar("V")
6
+
7
+
8
+ def filter_valid_dataclass_fields(
9
+ source: dict[str, V] | object,
10
+ destination: object | type,
11
+ whitelist_fields: Sequence[str] | None = None,
12
+ ) -> dict[str, V]:
13
+ """Filter a source dict or dataclass instance to only include fields that are present in the destination dataclass."""
14
+
15
+ if not is_dataclass(destination):
16
+ raise ValueError(f"{destination} is not a dataclass")
17
+
18
+ if is_dataclass(source) and not isinstance(source, type):
19
+ source_dict = asdict(source)
20
+ elif isinstance(source, dict):
21
+ source_dict = source
22
+ else:
23
+ raise ValueError(f"{source} is not a dict or dataclass")
24
+
25
+ valid_field_names = {field.name for field in fields(destination)}
26
+ if whitelist_fields is not None:
27
+ valid_field_names = valid_field_names.union(whitelist_fields)
28
+ return {key: val for key, val in source_dict.items() if key in valid_field_names}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sae-lens
3
- Version: 6.0.0rc1
3
+ Version: 6.0.0rc2
4
4
  Summary: Training and Analyzing Sparse Autoencoders (SAEs)
5
5
  License: MIT
6
6
  Keywords: deep-learning,sparse-autoencoders,mechanistic-interpretability,PyTorch
@@ -0,0 +1,35 @@
1
+ sae_lens/__init__.py,sha256=JZATcdlWGVOXYTHb41hn7dPp7pR2tWgpLAz2ztQOE-A,2747
2
+ sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
4
+ sae_lens/analysis/neuronpedia_integration.py,sha256=DlI08ThI0zwMrBthICt1OFCMyqmaCUDeZxhOk7b7teY,18680
5
+ sae_lens/cache_activations_runner.py,sha256=27jp2hFxZj4foWCRCJJd2VCwYJtMgkvPx6MuIhQBofc,12591
6
+ sae_lens/config.py,sha256=Ff6MRzRlVk8xtgkvHdJEmuPh9Owc10XIWBaUwdypzkU,26062
7
+ sae_lens/constants.py,sha256=HSiSp0j2Umak2buT30seFhkmj7KNuPmB3u4yLXrgfOg,462
8
+ sae_lens/evals.py,sha256=aR0pJMBWBUdZElXPcxUyNnNYWbM2LC5UeaESKAwdOMY,39098
9
+ sae_lens/load_model.py,sha256=tE70sXsyyyGYW7o506O3eiw1MXyyW6DCQojLG49hWYI,6771
10
+ sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ sae_lens/loading/pretrained_sae_loaders.py,sha256=IgQ-XSJ5VTLCzmJavPmk1vExBVB-36wW7w-ZNo7tzPY,31214
12
+ sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
13
+ sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
14
+ sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
15
+ sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
16
+ sae_lens/sae_training_runner.py,sha256=lI_d3ywS312dIz0wctm_Sgt3W9ffBOS7ahnDXBljX1s,8320
17
+ sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
18
+ sae_lens/saes/gated_sae.py,sha256=IgWvZxeJpdiu7VqeUnJLC-VWVhz6o8OXvmwCS-LJ-WQ,9426
19
+ sae_lens/saes/jumprelu_sae.py,sha256=lkhafpoYYn4-62tBlmmufmUomoo3CmFFQQ3NNylBNSM,12264
20
+ sae_lens/saes/sae.py,sha256=edJK3VFzOVBPXUX6QJ5fhhoY0wcfEisDmVXiqFRA7Xg,35089
21
+ sae_lens/saes/standard_sae.py,sha256=tMs6Z6Cv44PWa7pLo53xhXFnHMvO5BM6eVYHtRPLpos,6652
22
+ sae_lens/saes/topk_sae.py,sha256=CfF59K4J2XwUvztwg4fBbvFO3PyucLkg4Elkxdk0ozs,9786
23
+ sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
24
+ sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
+ sae_lens/training/activations_store.py,sha256=5V5dExeXWoE0dw-ePOZVnQIbBJwrepRMdsQrRam9Lg8,36790
26
+ sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
27
+ sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
28
+ sae_lens/training/sae_trainer.py,sha256=zYAk_9QJ8AJi2TjDZ1qW_lyoovSBqrJvBHzyYgb89ZY,15251
29
+ sae_lens/training/upload_saes_to_huggingface.py,sha256=tXvR4j25IgMjJ8R9oczwSdy00Tg-P_jAtnPHRt8yF64,4489
30
+ sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
31
+ sae_lens/util.py,sha256=4lqtl7HT9OiyRK8fe8nXtkcn2lOR1uX7ANrAClf6Bv8,1026
32
+ sae_lens-6.0.0rc2.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
33
+ sae_lens-6.0.0rc2.dist-info/METADATA,sha256=Z8Zwb6EknAPB5dOvfduYZewr4nldot-1dQoqz50Co3k,5326
34
+ sae_lens-6.0.0rc2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
35
+ sae_lens-6.0.0rc2.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 2.1.2
2
+ Generator: poetry-core 2.1.3
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
sae_lens/regsitry.py DELETED
@@ -1,34 +0,0 @@
1
- from typing import TYPE_CHECKING
2
-
3
- # avoid circular imports
4
- if TYPE_CHECKING:
5
- from sae_lens.saes.sae import SAE, TrainingSAE
6
-
7
- SAE_CLASS_REGISTRY: dict[str, "type[SAE]"] = {}
8
- SAE_TRAINING_CLASS_REGISTRY: dict[str, "type[TrainingSAE]"] = {}
9
-
10
-
11
- def register_sae_class(architecture: str, sae_class: "type[SAE]") -> None:
12
- if architecture in SAE_CLASS_REGISTRY:
13
- raise ValueError(
14
- f"SAE class for architecture {architecture} already registered."
15
- )
16
- SAE_CLASS_REGISTRY[architecture] = sae_class
17
-
18
-
19
- def register_sae_training_class(
20
- architecture: str, sae_training_class: "type[TrainingSAE]"
21
- ) -> None:
22
- if architecture in SAE_TRAINING_CLASS_REGISTRY:
23
- raise ValueError(
24
- f"SAE training class for architecture {architecture} already registered."
25
- )
26
- SAE_TRAINING_CLASS_REGISTRY[architecture] = sae_training_class
27
-
28
-
29
- def get_sae_class(architecture: str) -> "type[SAE]":
30
- return SAE_CLASS_REGISTRY[architecture]
31
-
32
-
33
- def get_sae_training_class(architecture: str) -> "type[TrainingSAE]":
34
- return SAE_TRAINING_CLASS_REGISTRY[architecture]
@@ -1,32 +0,0 @@
1
- sae_lens/__init__.py,sha256=ofQyurU7LtxIsg89QFCZe13QsdYpxErRI0x0tiCpB04,2074
2
- sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- sae_lens/analysis/hooked_sae_transformer.py,sha256=RK0mcLhymXdJInXHcagQggxW9Qf4ptePnH7sKXvGGaU,13727
4
- sae_lens/analysis/neuronpedia_integration.py,sha256=dFiKRWfuT5iUfTPBPmZydSaNG3VwqZ1asuNbbQv_NCM,18488
5
- sae_lens/cache_activations_runner.py,sha256=dGK5EHJMHAKDAFyr25fy1COSm-61q-q6kpWENHFMaKk,12561
6
- sae_lens/config.py,sha256=SPjziXrTyOBjObSi-3s0_mza3Z7WH8gd9NT9pVUfosg,34375
7
- sae_lens/evals.py,sha256=tjDKmkUM4fBbP9LHZuBLCx37ux8Px9CliTMme3Wjt1A,38898
8
- sae_lens/load_model.py,sha256=tE70sXsyyyGYW7o506O3eiw1MXyyW6DCQojLG49hWYI,6771
9
- sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- sae_lens/loading/pretrained_sae_loaders.py,sha256=NcqyH2KDL8Dg66-hjXsBAq1-IwdLEpYfKwbkHxSQbrg,29961
11
- sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
12
- sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
13
- sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
14
- sae_lens/regsitry.py,sha256=yCse5NmVH-ZaPET3jW8r7C_py2DL3yoox40GxGzJ0TI,1098
15
- sae_lens/sae_training_runner.py,sha256=VRNSAIsZLfcQMfZB8qdnK45PUXwoNvJ-rKt9BVYjMMY,8244
16
- sae_lens/saes/gated_sae.py,sha256=l5ucq7AZHya6ZClWNNE7CionGSf1ms5m1Ah3IoN6SH4,9916
17
- sae_lens/saes/jumprelu_sae.py,sha256=DRWgY58894cNh_sYAlefObI4rr0Eb6KHu1WuhTCcvB4,13468
18
- sae_lens/saes/sae.py,sha256=fd7OEsSXbmVii6QoYI_TRti6dwaxAQyrBcKyX7PxERw,36779
19
- sae_lens/saes/standard_sae.py,sha256=m2eNL_w6ave-_g7F1eQiwI4qbjMwwjzvxp96RN_WVAw,7110
20
- sae_lens/saes/topk_sae.py,sha256=aBET4F55A4xMIvZ8AazPtyl3oL-9S7krKx78li0uKGk,11370
21
- sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
22
- sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
- sae_lens/training/activations_store.py,sha256=ilJdcnZWfTDus1bdoqIb1wF_7H8_HWLmf8OCGrybmlA,35998
24
- sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
25
- sae_lens/training/optim.py,sha256=AImcc-MAaGDLOBP2hJ4alDFCtaqqgm4cc2eBxIxiQAo,5784
26
- sae_lens/training/sae_trainer.py,sha256=6TkqbzA0fYluRM8ouI_nU9sz-FaP63axxcnDrVfw37E,16279
27
- sae_lens/training/upload_saes_to_huggingface.py,sha256=tVC-2Txw7-9XttGlKzM0OSqU8CK7HDO9vIzDMqEwAYU,4366
28
- sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
29
- sae_lens-6.0.0rc1.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
30
- sae_lens-6.0.0rc1.dist-info/METADATA,sha256=wHH-VRtquu-FjZEOHdPJi3zYW3ns7MCT1fVerbPEylc,5326
31
- sae_lens-6.0.0rc1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
32
- sae_lens-6.0.0rc1.dist-info/RECORD,,