sae-lens 6.0.0rc1__py3-none-any.whl → 6.0.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +50 -16
- sae_lens/analysis/hooked_sae_transformer.py +10 -10
- sae_lens/analysis/neuronpedia_integration.py +13 -11
- sae_lens/cache_activations_runner.py +2 -1
- sae_lens/config.py +59 -231
- sae_lens/constants.py +18 -0
- sae_lens/evals.py +16 -13
- sae_lens/loading/pretrained_sae_loaders.py +36 -3
- sae_lens/registry.py +49 -0
- sae_lens/sae_training_runner.py +22 -21
- sae_lens/saes/__init__.py +48 -0
- sae_lens/saes/gated_sae.py +70 -59
- sae_lens/saes/jumprelu_sae.py +58 -72
- sae_lens/saes/sae.py +250 -272
- sae_lens/saes/standard_sae.py +75 -57
- sae_lens/saes/topk_sae.py +72 -83
- sae_lens/training/activations_store.py +31 -15
- sae_lens/training/optim.py +60 -36
- sae_lens/training/sae_trainer.py +44 -69
- sae_lens/training/upload_saes_to_huggingface.py +11 -5
- sae_lens/util.py +28 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/METADATA +1 -1
- sae_lens-6.0.0rc2.dist-info/RECORD +35 -0
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/WHEEL +1 -1
- sae_lens/regsitry.py +0 -34
- sae_lens-6.0.0rc1.dist-info/RECORD +0 -32
- {sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/LICENSE +0 -0
sae_lens/training/sae_trainer.py
CHANGED
@@ -1,6 +1,6 @@
 import contextlib
 from dataclasses import dataclass
-from typing import Any, Protocol, cast
+from typing import Any, Generic, Protocol, cast

 import torch
 import wandb
@@ -11,16 +11,16 @@ from transformer_lens.hook_points import HookedRootModule
 from sae_lens import __version__
 from sae_lens.config import LanguageModelSAERunnerConfig
 from sae_lens.evals import EvalConfig, run_evals
-from sae_lens.saes.sae import
+from sae_lens.saes.sae import (
+    T_TRAINING_SAE,
+    T_TRAINING_SAE_CONFIG,
+    TrainCoefficientConfig,
+    TrainingSAE,
+    TrainStepInput,
+    TrainStepOutput,
+)
 from sae_lens.training.activations_store import ActivationsStore
-from sae_lens.training.optim import
-
-# used to map between parameters which are updated during finetuning and the config str.
-FINETUNING_PARAMETERS = {
-    "scale": ["scaling_factor"],
-    "decoder": ["scaling_factor", "W_dec", "b_dec"],
-    "unrotated_decoder": ["scaling_factor", "b_dec"],
-}
+from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler


 def _log_feature_sparsity(
@@ -29,7 +29,7 @@ def _log_feature_sparsity(
     return torch.log10(feature_sparsity + eps).detach().cpu()


-def _update_sae_lens_training_version(sae: TrainingSAE) -> None:
+def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
     """
     Make sure we record the version of SAELens used for the training run
     """
@@ -38,7 +38,7 @@ def _update_sae_lens_training_version(sae: TrainingSAE) -> None:

 @dataclass
 class TrainSAEOutput:
-    sae: TrainingSAE
+    sae: TrainingSAE[Any]
     checkpoint_path: str
     log_feature_sparsities: torch.Tensor

@@ -46,13 +46,13 @@ class TrainSAEOutput:
 class SaveCheckpointFn(Protocol):
     def __call__(
         self,
-        trainer: "SAETrainer",
+        trainer: "SAETrainer[Any, Any]",
         checkpoint_name: str,
         wandb_aliases: list[str] | None = None,
     ) -> None: ...


-class SAETrainer:
+class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
     """
     Core SAE class used for inference. For training, see TrainingSAE.
     """
@@ -60,10 +60,10 @@ class SAETrainer:
     def __init__(
         self,
         model: HookedRootModule,
-        sae:
+        sae: T_TRAINING_SAE,
         activation_store: ActivationsStore,
         save_checkpoint_fn: SaveCheckpointFn,
-        cfg: LanguageModelSAERunnerConfig,
+        cfg: LanguageModelSAERunnerConfig[T_TRAINING_SAE_CONFIG],
     ) -> None:
         self.model = model
         self.sae = sae
@@ -88,11 +88,11 @@ class SAETrainer:
         )[1:]

         self.act_freq_scores = torch.zeros(
-            cast(int, cfg.d_sae),
+            cast(int, cfg.sae.d_sae),
             device=cfg.device,
         )
         self.n_forward_passes_since_fired = torch.zeros(
-            cast(int, cfg.d_sae),
+            cast(int, cfg.sae.d_sae),
             device=cfg.device,
         )
         self.n_frac_active_tokens = 0
@@ -121,11 +121,14 @@ class SAETrainer:
             lr_end=cfg.lr_end,
             num_cycles=cfg.n_restart_cycles,
         )
-        self.
-
-
-
-
+        self.coefficient_schedulers = {}
+        for name, coeff_cfg in self.sae.get_coefficients().items():
+            if not isinstance(coeff_cfg, TrainCoefficientConfig):
+                coeff_cfg = TrainCoefficientConfig(value=coeff_cfg, warm_up_steps=0)
+            self.coefficient_schedulers[name] = CoefficientScheduler(
+                warm_up_steps=coeff_cfg.warm_up_steps,
+                final_value=coeff_cfg.value,
+            )

         # Setup autocast if using
         self.scaler = torch.amp.GradScaler(
@@ -163,15 +166,11 @@ class SAETrainer:
     def log_feature_sparsity(self) -> torch.Tensor:
         return _log_feature_sparsity(self.feature_sparsity)

-    @property
-    def current_l1_coefficient(self) -> float:
-        return self.l1_scheduler.current_l1_coefficient
-
     @property
     def dead_neurons(self) -> torch.Tensor:
         return (self.n_forward_passes_since_fired > self.cfg.dead_feature_window).bool()

-    def fit(self) ->
+    def fit(self) -> T_TRAINING_SAE:
         pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE")

         self.activations_store.set_norm_scaling_factor_if_needed()
@@ -194,9 +193,6 @@ class SAETrainer:
             self.n_training_steps += 1
             self._update_pbar(step_output, pbar)

-            ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already)
-            self._begin_finetuning_if_needed()
-
         # fold the estimated norm scaling factor into the sae weights
         if self.activations_store.estimated_norm_scaling_factor is not None:
             self.sae.fold_activation_norm_scaling_factor(
@@ -216,13 +212,10 @@ class SAETrainer:

     def _train_step(
         self,
-        sae:
+        sae: T_TRAINING_SAE,
         sae_in: torch.Tensor,
     ) -> TrainStepOutput:
         sae.train()
-        # Make sure the W_dec is still zero-norm
-        if self.cfg.normalize_sae_decoder:
-            sae.set_decoder_norm_to_unit_norm()

         # log and then reset the feature sparsity every feature_sampling_window steps
         if (self.n_training_steps + 1) % self.cfg.feature_sampling_window == 0:
@@ -238,7 +231,7 @@ class SAETrainer:
             step_input=TrainStepInput(
                 sae_in=sae_in,
                 dead_neuron_mask=self.dead_neurons,
-
+                coefficients=self.get_coefficients(),
             ),
         )

@@ -261,12 +254,10 @@ class SAETrainer:
         self.scaler.step(self.optimizer)  # just ctx.optimizer.step() if not autocasting
         self.scaler.update()

-        if self.cfg.normalize_sae_decoder:
-            sae.remove_gradient_parallel_to_decoder_directions()
-
         self.optimizer.zero_grad()
         self.lr_scheduler.step()
-        self.
+        for scheduler in self.coefficient_schedulers.values():
+            scheduler.step()

         return train_step_output

@@ -281,6 +272,13 @@ class SAETrainer:
             step=self.n_training_steps,
         )

+    @torch.no_grad()
+    def get_coefficients(self) -> dict[str, float]:
+        return {
+            name: scheduler.value
+            for name, scheduler in self.coefficient_schedulers.items()
+        }
+
     @torch.no_grad()
     def _build_train_step_log_dict(
         self,
@@ -313,19 +311,15 @@ class SAETrainer:
             "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(),
             "sparsity/dead_features": self.dead_neurons.sum().item(),
             "details/current_learning_rate": current_learning_rate,
-            "details/current_l1_coefficient": self.current_l1_coefficient,
             "details/n_training_tokens": n_training_tokens,
+            **{
+                f"details/{name}_coefficient": scheduler.value
+                for name, scheduler in self.coefficient_schedulers.items()
+            },
         }
         for loss_name, loss_value in output.losses.items():
             loss_item = _unwrap_item(loss_value)
-
-            if loss_name == "l1_loss":
-                log_dict[f"losses/{loss_name}"] = (
-                    loss_item / self.current_l1_coefficient
-                )
-                log_dict[f"losses/raw_{loss_name}"] = loss_item
-            else:
-                log_dict[f"losses/{loss_name}"] = loss_item
+            log_dict[f"losses/{loss_name}"] = loss_item

         return log_dict

@@ -384,7 +378,7 @@ class SAETrainer:
     @torch.no_grad()
     def _reset_running_sparsity_stats(self) -> None:
         self.act_freq_scores = torch.zeros(
-            self.cfg.d_sae,  # type: ignore
+            self.cfg.sae.d_sae,  # type: ignore
             device=self.cfg.device,
         )
         self.n_frac_active_tokens = 0
@@ -416,25 +410,6 @@ class SAETrainer:
         pbar.set_description(f"{self.n_training_steps}| {loss_strs}")
         pbar.update(update_interval * self.cfg.train_batch_size_tokens)

-    def _begin_finetuning_if_needed(self):
-        if (not self.started_fine_tuning) and (
-            self.n_training_tokens > self.cfg.training_tokens
-        ):
-            self.started_fine_tuning = True
-
-            # finetuning method should be set in the config
-            # if not, then we don't finetune
-            if not isinstance(self.cfg.finetuning_method, str):
-                return
-
-            for name, param in self.sae.named_parameters():
-                if name in FINETUNING_PARAMETERS[self.cfg.finetuning_method]:
-                    param.requires_grad = True
-                else:
-                    param.requires_grad = False
-
-            self.finetuning = True
-

 def _unwrap_item(item: float | torch.Tensor) -> float:
     return item.item() if isinstance(item, torch.Tensor) else item
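Note: the l1-specific scheduler and the finetuning hooks are replaced by one scheduler per coefficient reported by sae.get_coefficients(); their current values reach the loss through TrainStepInput.coefficients. The sketch below is illustrative only, not the library's implementation (which lives in sae_lens/training/optim.py and is not shown in this diff); only the constructor arguments warm_up_steps and final_value, the step() call, and the value attribute are taken from the code above.

# Illustrative sketch of a warm-up coefficient scheduler with the interface
# used by SAETrainer above. This is NOT the actual sae_lens CoefficientScheduler.
class CoefficientScheduler:
    def __init__(self, warm_up_steps: int, final_value: float):
        self.warm_up_steps = warm_up_steps
        self.final_value = final_value
        self._step = 0
        # start at 0 and ramp linearly to final_value over warm_up_steps
        self.value = 0.0 if warm_up_steps > 0 else final_value

    def step(self) -> None:
        self._step += 1
        if self.warm_up_steps > 0 and self._step < self.warm_up_steps:
            self.value = self.final_value * (self._step / self.warm_up_steps)
        else:
            self.value = self.final_value


# Example: a sparsity coefficient warmed up over 4 optimizer steps.
sched = CoefficientScheduler(warm_up_steps=4, final_value=5.0)
for _ in range(6):
    sched.step()
print(sched.value)  # 5.0 once the warm-up window has passed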
sae_lens/training/upload_saes_to_huggingface.py
CHANGED

@@ -2,14 +2,15 @@ import io
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from textwrap import dedent
-from typing import Iterable
+from typing import Any, Iterable

 from huggingface_hub import HfApi, create_repo, get_hf_file_metadata, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
 from tqdm.autonotebook import tqdm

 from sae_lens import logger
-from sae_lens.
+from sae_lens.constants import (
+    RUNNER_CFG_FILENAME,
     SAE_CFG_FILENAME,
     SAE_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
@@ -18,7 +19,7 @@ from sae_lens.saes.sae import SAE


 def upload_saes_to_huggingface(
-    saes_dict: dict[str, SAE | Path | str],
+    saes_dict: dict[str, SAE[Any] | Path | str],
     hf_repo_id: str,
     hf_revision: str = "main",
     show_progress: bool = True,
@@ -119,11 +120,16 @@ def _upload_sae(api: HfApi, sae_path: Path, repo_id: str, sae_id: str, revision:
         revision=revision,
         repo_type="model",
         commit_message=f"Upload SAE {sae_id}",
-        allow_patterns=[
+        allow_patterns=[
+            SAE_CFG_FILENAME,
+            SAE_WEIGHTS_FILENAME,
+            SPARSITY_FILENAME,
+            RUNNER_CFG_FILENAME,
+        ],
     )


-def _build_sae_path(sae_ref: SAE | Path | str, tmp_dir: str) -> Path:
+def _build_sae_path(sae_ref: SAE[Any] | Path | str, tmp_dir: str) -> Path:
     if isinstance(sae_ref, SAE):
         sae_ref.save_model(tmp_dir)
         return Path(tmp_dir)
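As the signature above shows, the dict values may be SAE instances or paths to saved SAEs; the only typing change is SAE -> SAE[Any]. A hedged usage sketch based only on that signature (the repo id, SAE id, and path below are placeholders):

# Usage sketch; repo id, SAE id and path are placeholders, not real resources.
from sae_lens.training.upload_saes_to_huggingface import upload_saes_to_huggingface

upload_saes_to_huggingface(
    {"blocks.0.hook_resid_pre": "path/to/saved/sae"},  # SAE instance or saved path
    hf_repo_id="your-username/your-sae-repo",
    hf_revision="main",
    show_progress=True,
)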
sae_lens/util.py
ADDED
@@ -0,0 +1,28 @@
+from dataclasses import asdict, fields, is_dataclass
+from typing import Sequence, TypeVar
+
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+def filter_valid_dataclass_fields(
+    source: dict[str, V] | object,
+    destination: object | type,
+    whitelist_fields: Sequence[str] | None = None,
+) -> dict[str, V]:
+    """Filter a source dict or dataclass instance to only include fields that are present in the destination dataclass."""
+
+    if not is_dataclass(destination):
+        raise ValueError(f"{destination} is not a dataclass")
+
+    if is_dataclass(source) and not isinstance(source, type):
+        source_dict = asdict(source)
+    elif isinstance(source, dict):
+        source_dict = source
+    else:
+        raise ValueError(f"{source} is not a dict or dataclass")
+
+    valid_field_names = {field.name for field in fields(destination)}
+    if whitelist_fields is not None:
+        valid_field_names = valid_field_names.union(whitelist_fields)
+    return {key: val for key, val in source_dict.items() if key in valid_field_names}
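The new helper trims a dict (or dataclass instance) down to the field names accepted by a destination dataclass. A small usage example; the DestCfg dataclass here is made up purely for illustration:

from dataclasses import dataclass

from sae_lens.util import filter_valid_dataclass_fields


@dataclass
class DestCfg:  # hypothetical destination dataclass
    d_in: int
    d_sae: int


source = {"d_in": 512, "d_sae": 4096, "unrelated_key": "dropped"}
print(filter_valid_dataclass_fields(source, DestCfg))
# {'d_in': 512, 'd_sae': 4096}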
sae_lens-6.0.0rc2.dist-info/RECORD
ADDED

@@ -0,0 +1,35 @@
+sae_lens/__init__.py,sha256=JZATcdlWGVOXYTHb41hn7dPp7pR2tWgpLAz2ztQOE-A,2747
+sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sae_lens/analysis/hooked_sae_transformer.py,sha256=Eyg1Y2hVIHNuiiLOCTgzstOuW6iA-7hPHqaGR8y_vMs,13809
+sae_lens/analysis/neuronpedia_integration.py,sha256=DlI08ThI0zwMrBthICt1OFCMyqmaCUDeZxhOk7b7teY,18680
+sae_lens/cache_activations_runner.py,sha256=27jp2hFxZj4foWCRCJJd2VCwYJtMgkvPx6MuIhQBofc,12591
+sae_lens/config.py,sha256=Ff6MRzRlVk8xtgkvHdJEmuPh9Owc10XIWBaUwdypzkU,26062
+sae_lens/constants.py,sha256=HSiSp0j2Umak2buT30seFhkmj7KNuPmB3u4yLXrgfOg,462
+sae_lens/evals.py,sha256=aR0pJMBWBUdZElXPcxUyNnNYWbM2LC5UeaESKAwdOMY,39098
+sae_lens/load_model.py,sha256=tE70sXsyyyGYW7o506O3eiw1MXyyW6DCQojLG49hWYI,6771
+sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sae_lens/loading/pretrained_sae_loaders.py,sha256=IgQ-XSJ5VTLCzmJavPmk1vExBVB-36wW7w-ZNo7tzPY,31214
+sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
+sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
+sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
+sae_lens/registry.py,sha256=nhy7BPSudSATqW4lo9H_k3Na7sfGHmAf9v-3wpnLL_o,1490
+sae_lens/sae_training_runner.py,sha256=lI_d3ywS312dIz0wctm_Sgt3W9ffBOS7ahnDXBljX1s,8320
+sae_lens/saes/__init__.py,sha256=v6mfeDzyGYtT6x5SszAQtkldTXwPE-V_iwOlrT_pDwQ,1008
+sae_lens/saes/gated_sae.py,sha256=IgWvZxeJpdiu7VqeUnJLC-VWVhz6o8OXvmwCS-LJ-WQ,9426
+sae_lens/saes/jumprelu_sae.py,sha256=lkhafpoYYn4-62tBlmmufmUomoo3CmFFQQ3NNylBNSM,12264
+sae_lens/saes/sae.py,sha256=edJK3VFzOVBPXUX6QJ5fhhoY0wcfEisDmVXiqFRA7Xg,35089
+sae_lens/saes/standard_sae.py,sha256=tMs6Z6Cv44PWa7pLo53xhXFnHMvO5BM6eVYHtRPLpos,6652
+sae_lens/saes/topk_sae.py,sha256=CfF59K4J2XwUvztwg4fBbvFO3PyucLkg4Elkxdk0ozs,9786
+sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
+sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+sae_lens/training/activations_store.py,sha256=5V5dExeXWoE0dw-ePOZVnQIbBJwrepRMdsQrRam9Lg8,36790
+sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
+sae_lens/training/optim.py,sha256=KXdOym-Ly3f2aFbndRc0JEH0Wa7u1BE5ljxGN3YtouQ,6836
+sae_lens/training/sae_trainer.py,sha256=zYAk_9QJ8AJi2TjDZ1qW_lyoovSBqrJvBHzyYgb89ZY,15251
+sae_lens/training/upload_saes_to_huggingface.py,sha256=tXvR4j25IgMjJ8R9oczwSdy00Tg-P_jAtnPHRt8yF64,4489
+sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
+sae_lens/util.py,sha256=4lqtl7HT9OiyRK8fe8nXtkcn2lOR1uX7ANrAClf6Bv8,1026
+sae_lens-6.0.0rc2.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
+sae_lens-6.0.0rc2.dist-info/METADATA,sha256=Z8Zwb6EknAPB5dOvfduYZewr4nldot-1dQoqz50Co3k,5326
+sae_lens-6.0.0rc2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+sae_lens-6.0.0rc2.dist-info/RECORD,,
sae_lens/regsitry.py
DELETED
@@ -1,34 +0,0 @@
-from typing import TYPE_CHECKING
-
-# avoid circular imports
-if TYPE_CHECKING:
-    from sae_lens.saes.sae import SAE, TrainingSAE
-
-SAE_CLASS_REGISTRY: dict[str, "type[SAE]"] = {}
-SAE_TRAINING_CLASS_REGISTRY: dict[str, "type[TrainingSAE]"] = {}
-
-
-def register_sae_class(architecture: str, sae_class: "type[SAE]") -> None:
-    if architecture in SAE_CLASS_REGISTRY:
-        raise ValueError(
-            f"SAE class for architecture {architecture} already registered."
-        )
-    SAE_CLASS_REGISTRY[architecture] = sae_class
-
-
-def register_sae_training_class(
-    architecture: str, sae_training_class: "type[TrainingSAE]"
-) -> None:
-    if architecture in SAE_TRAINING_CLASS_REGISTRY:
-        raise ValueError(
-            f"SAE training class for architecture {architecture} already registered."
-        )
-    SAE_TRAINING_CLASS_REGISTRY[architecture] = sae_training_class
-
-
-def get_sae_class(architecture: str) -> "type[SAE]":
-    return SAE_CLASS_REGISTRY[architecture]
-
-
-def get_sae_training_class(architecture: str) -> "type[TrainingSAE]":
-    return SAE_TRAINING_CLASS_REGISTRY[architecture]
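The misspelled regsitry.py is removed in favour of the correctly named sae_lens/registry.py listed in the summary above (+49 lines, contents not shown in this diff). The sketch below shows how the deleted module's API was used, based only on the code shown here; whether the new registry.py keeps these exact names and signatures is an assumption.

# Usage sketch against the 6.0.0rc1 module shown above (sae_lens/regsitry.py).
# The 6.0.0rc2 replacement lives at sae_lens/registry.py; its interface is not
# shown in this diff, so treat the function names below as unverified there.
from sae_lens.regsitry import get_sae_class, register_sae_class
from sae_lens.saes.sae import SAE


class MyCustomSAE(SAE):  # hypothetical architecture, defined only for illustration
    ...


register_sae_class("my_custom", MyCustomSAE)  # raises ValueError on duplicate registration
assert get_sae_class("my_custom") is MyCustomSAE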
sae_lens-6.0.0rc1.dist-info/RECORD
DELETED

@@ -1,32 +0,0 @@
-sae_lens/__init__.py,sha256=ofQyurU7LtxIsg89QFCZe13QsdYpxErRI0x0tiCpB04,2074
-sae_lens/analysis/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/analysis/hooked_sae_transformer.py,sha256=RK0mcLhymXdJInXHcagQggxW9Qf4ptePnH7sKXvGGaU,13727
-sae_lens/analysis/neuronpedia_integration.py,sha256=dFiKRWfuT5iUfTPBPmZydSaNG3VwqZ1asuNbbQv_NCM,18488
-sae_lens/cache_activations_runner.py,sha256=dGK5EHJMHAKDAFyr25fy1COSm-61q-q6kpWENHFMaKk,12561
-sae_lens/config.py,sha256=SPjziXrTyOBjObSi-3s0_mza3Z7WH8gd9NT9pVUfosg,34375
-sae_lens/evals.py,sha256=tjDKmkUM4fBbP9LHZuBLCx37ux8Px9CliTMme3Wjt1A,38898
-sae_lens/load_model.py,sha256=tE70sXsyyyGYW7o506O3eiw1MXyyW6DCQojLG49hWYI,6771
-sae_lens/loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/loading/pretrained_sae_loaders.py,sha256=NcqyH2KDL8Dg66-hjXsBAq1-IwdLEpYfKwbkHxSQbrg,29961
-sae_lens/loading/pretrained_saes_directory.py,sha256=4Vn-Jex6SveD7EbxcSOBv8cx1gkPfUMLU1QOP-ww1ZE,3752
-sae_lens/pretokenize_runner.py,sha256=0nHQq3s_d80VS8iVK4-e6y_orAYVO8c4RrLGtIDfK_E,6885
-sae_lens/pretrained_saes.yaml,sha256=C_z-7Lxz6ZIy2V-c-4Xw45eAQ926O9aGjocSNuki0xs,573557
-sae_lens/regsitry.py,sha256=yCse5NmVH-ZaPET3jW8r7C_py2DL3yoox40GxGzJ0TI,1098
-sae_lens/sae_training_runner.py,sha256=VRNSAIsZLfcQMfZB8qdnK45PUXwoNvJ-rKt9BVYjMMY,8244
-sae_lens/saes/gated_sae.py,sha256=l5ucq7AZHya6ZClWNNE7CionGSf1ms5m1Ah3IoN6SH4,9916
-sae_lens/saes/jumprelu_sae.py,sha256=DRWgY58894cNh_sYAlefObI4rr0Eb6KHu1WuhTCcvB4,13468
-sae_lens/saes/sae.py,sha256=fd7OEsSXbmVii6QoYI_TRti6dwaxAQyrBcKyX7PxERw,36779
-sae_lens/saes/standard_sae.py,sha256=m2eNL_w6ave-_g7F1eQiwI4qbjMwwjzvxp96RN_WVAw,7110
-sae_lens/saes/topk_sae.py,sha256=aBET4F55A4xMIvZ8AazPtyl3oL-9S7krKx78li0uKGk,11370
-sae_lens/tokenization_and_batching.py,sha256=oUAscjy_LPOrOb8_Ty6eLAcZ0B3HB_wiWjWktgolhG0,4314
-sae_lens/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sae_lens/training/activations_store.py,sha256=ilJdcnZWfTDus1bdoqIb1wF_7H8_HWLmf8OCGrybmlA,35998
-sae_lens/training/geometric_median.py,sha256=3kH8ZJAgKStlnZgs6s1uYGDYh004Bl0r4RLhuwT3lBY,3719
-sae_lens/training/optim.py,sha256=AImcc-MAaGDLOBP2hJ4alDFCtaqqgm4cc2eBxIxiQAo,5784
-sae_lens/training/sae_trainer.py,sha256=6TkqbzA0fYluRM8ouI_nU9sz-FaP63axxcnDrVfw37E,16279
-sae_lens/training/upload_saes_to_huggingface.py,sha256=tVC-2Txw7-9XttGlKzM0OSqU8CK7HDO9vIzDMqEwAYU,4366
-sae_lens/tutorial/tsea.py,sha256=fd1am_XXsf2KMbByDapJo-2qlxduKaa62Z2qcQZ3QKU,18145
-sae_lens-6.0.0rc1.dist-info/LICENSE,sha256=DW6e-hDosiu4CfW0-imI57sV1I5f9UEslpviNQcOAKs,1069
-sae_lens-6.0.0rc1.dist-info/METADATA,sha256=wHH-VRtquu-FjZEOHdPJi3zYW3ns7MCT1fVerbPEylc,5326
-sae_lens-6.0.0rc1.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-sae_lens-6.0.0rc1.dist-info/RECORD,,
{sae_lens-6.0.0rc1.dist-info → sae_lens-6.0.0rc2.dist-info}/LICENSE
File without changes