sae-lens 5.9.0__py3-none-any.whl → 6.0.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sae_lens/__init__.py +22 -6
- sae_lens/analysis/hooked_sae_transformer.py +2 -2
- sae_lens/config.py +66 -23
- sae_lens/evals.py +6 -5
- sae_lens/{toolkit → loading}/pretrained_sae_loaders.py +33 -25
- sae_lens/regsitry.py +34 -0
- sae_lens/sae_training_runner.py +18 -33
- sae_lens/saes/gated_sae.py +247 -0
- sae_lens/saes/jumprelu_sae.py +368 -0
- sae_lens/saes/sae.py +970 -0
- sae_lens/saes/standard_sae.py +167 -0
- sae_lens/saes/topk_sae.py +305 -0
- sae_lens/training/activations_store.py +2 -2
- sae_lens/training/sae_trainer.py +13 -19
- sae_lens/training/upload_saes_to_huggingface.py +1 -1
- {sae_lens-5.9.0.dist-info → sae_lens-6.0.0rc1.dist-info}/METADATA +3 -3
- sae_lens-6.0.0rc1.dist-info/RECORD +32 -0
- sae_lens/sae.py +0 -747
- sae_lens/training/training_sae.py +0 -705
- sae_lens-5.9.0.dist-info/RECORD +0 -28
- /sae_lens/{toolkit → loading}/__init__.py +0 -0
- /sae_lens/{toolkit → loading}/pretrained_saes_directory.py +0 -0
- {sae_lens-5.9.0.dist-info → sae_lens-6.0.0rc1.dist-info}/LICENSE +0 -0
- {sae_lens-5.9.0.dist-info → sae_lens-6.0.0rc1.dist-info}/WHEEL +0 -0
sae_lens/__init__.py
CHANGED

@@ -1,10 +1,15 @@
 # ruff: noqa: E402
-__version__ = "5.9.0"
+__version__ = "6.0.0-rc.1"

 import logging

 logger = logging.getLogger(__name__)

+from sae_lens.saes.gated_sae import GatedSAE, GatedTrainingSAE
+from sae_lens.saes.jumprelu_sae import JumpReLUSAE, JumpReLUTrainingSAE
+from sae_lens.saes.standard_sae import StandardSAE, StandardTrainingSAE
+from sae_lens.saes.topk_sae import TopKSAE, TopKTrainingSAE
+
 from .analysis.hooked_sae_transformer import HookedSAETransformer
 from .cache_activations_runner import CacheActivationsRunner
 from .config import (
@@ -13,17 +18,26 @@ from .config import (
     PretokenizeRunnerConfig,
 )
 from .evals import run_evals
-from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
-from .sae import SAE, SAEConfig
-from .sae_training_runner import SAETrainingRunner
-from .toolkit.pretrained_sae_loaders import (
+from .loading.pretrained_sae_loaders import (
     PretrainedSaeDiskLoader,
     PretrainedSaeHuggingfaceLoader,
 )
+from .pretokenize_runner import PretokenizeRunner, pretokenize_runner
+from .regsitry import register_sae_class, register_sae_training_class
+from .sae_training_runner import SAETrainingRunner
+from .saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
 from .training.activations_store import ActivationsStore
-from .training.training_sae import TrainingSAE, TrainingSAEConfig
 from .training.upload_saes_to_huggingface import upload_saes_to_huggingface

+register_sae_class("standard", StandardSAE)
+register_sae_training_class("standard", StandardTrainingSAE)
+register_sae_class("gated", GatedSAE)
+register_sae_training_class("gated", GatedTrainingSAE)
+register_sae_class("topk", TopKSAE)
+register_sae_training_class("topk", TopKTrainingSAE)
+register_sae_class("jumprelu", JumpReLUSAE)
+register_sae_training_class("jumprelu", JumpReLUTrainingSAE)
+
 __all__ = [
     "SAE",
     "SAEConfig",
@@ -42,4 +56,6 @@ __all__ = [
     "upload_saes_to_huggingface",
     "PretrainedSaeHuggingfaceLoader",
     "PretrainedSaeDiskLoader",
+    "register_sae_class",
+    "register_sae_training_class",
 ]
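Note: for downstream code, the practical upshot of this __init__.py change is that deep imports move from sae_lens.sae and sae_lens.toolkit to sae_lens.saes.sae and sae_lens.loading, while the top-level names remain re-exported. A minimal sketch of the migration (import paths taken directly from the diff above):

    # sae-lens 5.9.0 deep imports:
    # from sae_lens.sae import SAE, SAEConfig
    # from sae_lens.training.training_sae import TrainingSAE, TrainingSAEConfig
    # from sae_lens.toolkit.pretrained_sae_loaders import PretrainedSaeDiskLoader

    # sae-lens 6.0.0rc1 equivalents:
    from sae_lens.saes.sae import SAE, SAEConfig, TrainingSAE, TrainingSAEConfig
    from sae_lens.loading.pretrained_sae_loaders import PretrainedSaeDiskLoader

    # Top-level re-exports work in both versions:
    from sae_lens import SAE, TrainingSAE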
sae_lens/analysis/hooked_sae_transformer.py
CHANGED

@@ -8,7 +8,7 @@ from transformer_lens.ActivationCache import ActivationCache
 from transformer_lens.hook_points import HookPoint  # Hooking utilities
 from transformer_lens.HookedTransformer import HookedTransformer

-from sae_lens.sae import SAE
+from sae_lens.saes.sae import SAE

 SingleLoss = Float[torch.Tensor, ""]  # Type alias for a single element tensor
 LossPerToken = Float[torch.Tensor, "batch pos-1"]
@@ -275,7 +275,7 @@ class HookedSAETransformer(HookedTransformer):
         .. code-block:: python

             from transformer_lens import HookedSAETransformer
-            from sae_lens.sae import SAE
+            from sae_lens.saes.sae import SAE

             model = HookedSAETransformer.from_pretrained('gpt2-small')
             sae_cfg = SAEConfig(...)
sae_lens/config.py
CHANGED

@@ -1,7 +1,8 @@
 import json
 import math
 import os
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
+from pathlib import Path
 from typing import Any, Literal, cast

 import simple_parsing
@@ -53,6 +54,52 @@ def dict_field(default: dict[str, Any] | None, **kwargs: Any) -> Any:  # type: ignore
     return simple_parsing.helpers.dict_field(default, type=json_dict, **kwargs)


+@dataclass
+class LoggingConfig:
+    # WANDB
+    log_to_wandb: bool = True
+    log_activations_store_to_wandb: bool = False
+    log_optimizer_state_to_wandb: bool = False
+    wandb_project: str = "sae_lens_training"
+    wandb_id: str | None = None
+    run_name: str | None = None
+    wandb_entity: str | None = None
+    wandb_log_frequency: int = 10
+    eval_every_n_wandb_logs: int = 100  # logs every 100 steps.
+
+    def log(
+        self,
+        trainer: Any,  # avoid import cycle from importing SAETrainer
+        weights_path: Path | str,
+        cfg_path: Path | str,
+        sparsity_path: Path | str | None,
+        wandb_aliases: list[str] | None = None,
+    ) -> None:
+        # Avoid wandb saving errors such as:
+        # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
+        sae_name = trainer.sae.get_name().replace("/", "__")
+
+        # save model weights and cfg
+        model_artifact = wandb.Artifact(
+            sae_name,
+            type="model",
+            metadata=dict(trainer.cfg.__dict__),
+        )
+        model_artifact.add_file(str(weights_path))
+        model_artifact.add_file(str(cfg_path))
+        wandb.log_artifact(model_artifact, aliases=wandb_aliases)
+
+        # save log feature sparsity
+        sparsity_artifact = wandb.Artifact(
+            f"{sae_name}_log_feature_sparsity",
+            type="log_feature_sparsity",
+            metadata=dict(trainer.cfg.__dict__),
+        )
+        if sparsity_path is not None:
+            sparsity_artifact.add_file(str(sparsity_path))
+            wandb.log_artifact(sparsity_artifact)
+
+
 @dataclass
 class LanguageModelSAERunnerConfig:
     """
@@ -245,16 +292,7 @@ class LanguageModelSAERunnerConfig:
     n_eval_batches: int = 10
     eval_batch_size_prompts: int | None = None  # useful if evals cause OOM

-
-    log_to_wandb: bool = True
-    log_activations_store_to_wandb: bool = False
-    log_optimizer_state_to_wandb: bool = False
-    wandb_project: str = "mats_sae_training_language_model"
-    wandb_id: str | None = None
-    run_name: str | None = None
-    wandb_entity: str | None = None
-    wandb_log_frequency: int = 10
-    eval_every_n_wandb_logs: int = 100  # logs every 1000 steps.
+    logger: LoggingConfig = field(default_factory=LoggingConfig)

     # Misc
     resume: bool = False
@@ -310,8 +348,8 @@ class LanguageModelSAERunnerConfig:
             self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
         )

-        if self.run_name is None:
-            self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"
+        if self.logger.run_name is None:
+            self.logger.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}"

         if self.model_from_pretrained_kwargs is None:
             if self.model_class_name == "HookedTransformer":
@@ -356,7 +394,7 @@ class LanguageModelSAERunnerConfig:
         if self.lr_end is None:
             self.lr_end = self.lr / 10

-        unique_id = self.wandb_id
+        unique_id = self.logger.wandb_id
         if unique_id is None:
             unique_id = cast(
                 Any, wandb
@@ -388,7 +426,9 @@ class LanguageModelSAERunnerConfig:
         ) // self.train_batch_size_tokens
         logger.info(f"Total training steps: {total_training_steps}")

-        total_wandb_updates = total_training_steps // self.wandb_log_frequency
+        total_wandb_updates = (
+            total_training_steps // self.logger.wandb_log_frequency
+        )
         logger.info(f"Total wandb updates: {total_wandb_updates}")

         # how many times will we sample dead neurons?
@@ -445,7 +485,7 @@ class LanguageModelSAERunnerConfig:
             "hook_name": self.hook_name,
             "hook_layer": self.hook_layer,
             "hook_head_index": self.hook_head_index,
-            "activation_fn_str": self.activation_fn,
+            "activation_fn": self.activation_fn,
             "apply_b_dec_to_input": self.apply_b_dec_to_input,
             "context_size": self.context_size,
             "prepend_bos": self.prepend_bos,
@@ -478,13 +518,16 @@ class LanguageModelSAERunnerConfig:
         }

     def to_dict(self) -> dict[str, Any]:
-        ...
+        # Make a shallow copy of config's dictionary
+        d = dict(self.__dict__)
+
+        d["logger"] = asdict(self.logger)
+
+        # Overwrite fields that might not be JSON-serializable
+        d["dtype"] = str(self.dtype)
+        d["device"] = str(self.device)
+        d["act_store_device"] = str(self.act_store_device)
+        return d

     def to_json(self, path: str) -> None:
         if not os.path.exists(os.path.dirname(path)):
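Note: the W&B settings that were previously flat fields on LanguageModelSAERunnerConfig now live on a nested LoggingConfig, and __post_init__ reads and writes them through self.logger. A minimal sketch of constructing and reading the nested config (field names taken from the diff; the remaining runner fields, model, dataset, SAE, are omitted here and would be set as before):

    from sae_lens.config import LanguageModelSAERunnerConfig, LoggingConfig

    cfg = LanguageModelSAERunnerConfig(
        # ... model / dataset / SAE fields as before ...
        logger=LoggingConfig(
            log_to_wandb=True,
            wandb_project="sae_lens_training",
            wandb_log_frequency=10,
        ),
    )

    # 5.9.0:     cfg.log_to_wandb, cfg.wandb_project, cfg.run_name
    # 6.0.0rc1:  cfg.logger.log_to_wandb, cfg.logger.wandb_project, cfg.logger.run_name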
sae_lens/evals.py
CHANGED

@@ -19,8 +19,8 @@ from tqdm import tqdm
 from transformer_lens import HookedTransformer
 from transformer_lens.hook_points import HookedRootModule

-from sae_lens.sae import SAE
-from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
+from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory
+from sae_lens.saes.sae import SAE
 from sae_lens.training.activations_store import ActivationsStore


@@ -279,7 +279,6 @@ def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
     unit_norm_decoder = (sae.W_dec.T / sae.W_dec.T.norm(dim=0, keepdim=True)).cpu()

     encoder_norms = sae.W_enc.norm(dim=-2).cpu().tolist()
-    encoder_bias = sae.b_enc.cpu().tolist()
     encoder_decoder_cosine_sim = (
         torch.nn.functional.cosine_similarity(
             unit_norm_decoder.T,
@@ -289,11 +288,13 @@ def get_featurewise_weight_based_metrics(sae: SAE) -> dict[str, Any]:
         .tolist()
     )

-    return {
-        "encoder_bias": encoder_bias,
+    metrics = {
         "encoder_norm": encoder_norms,
         "encoder_decoder_cosine_sim": encoder_decoder_cosine_sim,
     }
+    if hasattr(sae, "b_enc") and sae.b_enc is not None:
+        metrics["encoder_bias"] = sae.b_enc.cpu().tolist()  # type: ignore
+    return metrics


 def get_downstream_reconstruction_metrics(
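Note: get_featurewise_weight_based_metrics now guards access to b_enc, since with multiple registered architectures not every SAE variant necessarily exposes an encoder bias. The same guarded-attribute pattern generalizes to other optional parameters; a small sketch (safe_param_list is a hypothetical helper, not part of the package, and sae is any loaded SAE instance):

    import torch

    def safe_param_list(sae, name: str) -> list | None:
        # Return the named parameter as a Python list, or None if this
        # SAE variant does not have that parameter.
        param = getattr(sae, name, None)
        return param.cpu().tolist() if isinstance(param, torch.Tensor) else None

    encoder_bias = safe_param_list(sae, "b_enc")  # None for SAEs without b_enc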
sae_lens/{toolkit → loading}/pretrained_sae_loaders.py
CHANGED

@@ -17,7 +17,7 @@ from sae_lens.config import (
     SAE_WEIGHTS_FILENAME,
     SPARSITY_FILENAME,
 )
-from sae_lens.toolkit.pretrained_saes_directory import (
+from sae_lens.loading.pretrained_saes_directory import (
     get_config_overrides,
     get_pretrained_saes_directory,
     get_repo_id_and_folder_name,
@@ -174,30 +174,38 @@ def get_sae_lens_config_from_disk(


 def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
+    rename_keys_map = {
+        "hook_point": "hook_name",
+        "hook_point_layer": "hook_layer",
+        "hook_point_head_index": "hook_head_index",
+        "activation_fn_str": "activation_fn",
+    }
+    new_cfg = {rename_keys_map.get(k, k): v for k, v in cfg_dict.items()}
+
     # Set default values for backwards compatibility
-    ...
-    if "normalize_activations" in cfg_dict and isinstance(
-        cfg_dict["normalize_activations"], bool
+    new_cfg.setdefault("prepend_bos", True)
+    new_cfg.setdefault("dataset_trust_remote_code", True)
+    new_cfg.setdefault("apply_b_dec_to_input", True)
+    new_cfg.setdefault("finetuning_scaling_factor", False)
+    new_cfg.setdefault("sae_lens_training_version", None)
+    new_cfg.setdefault("activation_fn", new_cfg.get("activation_fn", "relu"))
+    new_cfg.setdefault("architecture", "standard")
+    new_cfg.setdefault("neuronpedia_id", None)
+
+    if "normalize_activations" in new_cfg and isinstance(
+        new_cfg["normalize_activations"], bool
     ):
         # backwards compatibility
-        cfg_dict["normalize_activations"] = (
+        new_cfg["normalize_activations"] = (
             "none"
-            if not cfg_dict["normalize_activations"]
+            if not new_cfg["normalize_activations"]
             else "expected_average_only_in"
         )

-    ...
+    new_cfg.setdefault("normalize_activations", "none")
+    new_cfg.setdefault("device", "cpu")

-    return cfg_dict
+    return new_cfg


 def get_connor_rob_hook_z_config_from_hf(
@@ -223,7 +231,7 @@ def get_connor_rob_hook_z_config_from_hf(
         "hook_name": old_cfg_dict["act_name"],
         "hook_layer": old_cfg_dict["layer"],
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "apply_b_dec_to_input": True,
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
@@ -372,7 +380,7 @@ def get_gemma_2_config_from_hf(
         "hook_name": hook_name,
         "hook_layer": layer,
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -485,7 +493,7 @@ def get_llama_scope_config_from_hf(
         "hook_name": old_cfg_dict["hook_point_in"],
         "hook_layer": int(old_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
@@ -597,8 +605,8 @@ def get_dictionary_learning_config_1_from_hf(

     hook_point_name = f"blocks.{trainer['layer']}.hook_resid_post"

-    activation_fn_str = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
-    activation_fn_kwargs = {"k": trainer["k"]} if activation_fn_str == "topk" else {}
+    activation_fn = "topk" if trainer["dict_class"] == "AutoEncoderTopK" else "relu"
+    activation_fn_kwargs = {"k": trainer["k"]} if activation_fn == "topk" else {}

     return {
         "architecture": (
@@ -612,7 +620,7 @@ def get_dictionary_learning_config_1_from_hf(
         "hook_name": hook_point_name,
         "hook_layer": trainer["layer"],
         "hook_head_index": None,
-        "activation_fn_str": activation_fn_str,
+        "activation_fn": activation_fn,
         "activation_fn_kwargs": activation_fn_kwargs,
         "apply_b_dec_to_input": True,
         "finetuning_scaling_factor": False,
@@ -655,7 +663,7 @@ def get_deepseek_r1_config_from_hf(
         "dataset_path": "lmsys/lmsys-chat-1m",
         "dataset_trust_remote_code": True,
         "sae_lens_training_version": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "normalize_activations": "none",
         "device": device,
         "apply_b_dec_to_input": False,
@@ -810,7 +818,7 @@ def get_llama_scope_r1_distill_config_from_hf(
         "hook_name": huggingface_cfg_dict["hook_point_in"],
         "hook_layer": int(huggingface_cfg_dict["hook_point_in"].split(".")[1]),
         "hook_head_index": None,
-        "activation_fn_str": "relu",
+        "activation_fn": "relu",
         "finetuning_scaling_factor": False,
         "sae_lens_training_version": None,
         "prepend_bos": True,
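Note: handle_config_defaulting now normalizes legacy configs in two steps: rename old keys via rename_keys_map, then fill defaults on the renamed dict. A sketch of the expected behavior on a legacy config (values are illustrative; the assertions follow the logic in the diff above):

    from sae_lens.loading.pretrained_sae_loaders import handle_config_defaulting

    legacy_cfg = {
        "hook_point": "blocks.8.hook_resid_pre",  # old key name
        "hook_point_layer": 8,                    # old key name
        "activation_fn_str": "relu",              # old key name
        "normalize_activations": False,           # old boolean form
    }
    new_cfg = handle_config_defaulting(legacy_cfg)

    assert new_cfg["hook_name"] == "blocks.8.hook_resid_pre"
    assert new_cfg["hook_layer"] == 8
    assert new_cfg["activation_fn"] == "relu"
    assert new_cfg["normalize_activations"] == "none"  # bool False -> "none"
    assert new_cfg["architecture"] == "standard"       # defaulted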
sae_lens/regsitry.py
ADDED

@@ -0,0 +1,34 @@
+from typing import TYPE_CHECKING
+
+# avoid circular imports
+if TYPE_CHECKING:
+    from sae_lens.saes.sae import SAE, TrainingSAE
+
+SAE_CLASS_REGISTRY: dict[str, "type[SAE]"] = {}
+SAE_TRAINING_CLASS_REGISTRY: dict[str, "type[TrainingSAE]"] = {}
+
+
+def register_sae_class(architecture: str, sae_class: "type[SAE]") -> None:
+    if architecture in SAE_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE class for architecture {architecture} already registered."
+        )
+    SAE_CLASS_REGISTRY[architecture] = sae_class
+
+
+def register_sae_training_class(
+    architecture: str, sae_training_class: "type[TrainingSAE]"
+) -> None:
+    if architecture in SAE_TRAINING_CLASS_REGISTRY:
+        raise ValueError(
+            f"SAE training class for architecture {architecture} already registered."
+        )
+    SAE_TRAINING_CLASS_REGISTRY[architecture] = sae_training_class
+
+
+def get_sae_class(architecture: str) -> "type[SAE]":
+    return SAE_CLASS_REGISTRY[architecture]
+
+
+def get_sae_training_class(architecture: str) -> "type[TrainingSAE]":
+    return SAE_TRAINING_CLASS_REGISTRY[architecture]
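Note: the registry maps architecture-name strings to SAE classes; __init__.py registers the four built-in architectures ("standard", "gated", "topk", "jumprelu") at import time, and re-registering an existing name raises ValueError. A sketch of lookup and extension (MyCustomSAE is a hypothetical subclass, not part of the package):

    from sae_lens import register_sae_class
    from sae_lens.regsitry import get_sae_class
    from sae_lens.saes.standard_sae import StandardSAE

    TopKSAE_cls = get_sae_class("topk")  # look up a built-in architecture

    class MyCustomSAE(StandardSAE):
        ...  # hypothetical custom architecture

    register_sae_class("my_custom", MyCustomSAE)  # new name, so no ValueError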
sae_lens/sae_training_runner.py
CHANGED

@@ -13,10 +13,10 @@ from transformer_lens.hook_points import HookedRootModule
 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
 from sae_lens.load_model import load_model
+from sae_lens.saes.sae import TrainingSAE, TrainingSAEConfig
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.geometric_median import compute_geometric_median
 from sae_lens.training.sae_trainer import SAETrainer
-from sae_lens.training.training_sae import TrainingSAE, TrainingSAEConfig


 class InterruptedException(Exception):
@@ -73,14 +73,14 @@ class SAETrainingRunner:

         if override_sae is None:
             if self.cfg.from_pretrained_path is not None:
-                self.sae = TrainingSAE.load_from_pretrained(
+                self.sae = TrainingSAE.load_from_disk(
                     self.cfg.from_pretrained_path, self.cfg.device
                 )
             else:
-                self.sae = TrainingSAE(
+                self.sae = TrainingSAE.from_dict(
                     TrainingSAEConfig.from_dict(
                         self.cfg.get_training_sae_cfg_dict(),
-                    )
+                    ).to_dict()
                 )
             self._init_sae_group_b_decs()
         else:
@@ -91,13 +91,13 @@ class SAETrainingRunner:
         Run the training of the SAE.
         """

-        if self.cfg.log_to_wandb:
+        if self.cfg.logger.log_to_wandb:
            wandb.init(
-                project=self.cfg.wandb_project,
-                entity=self.cfg.wandb_entity,
+                project=self.cfg.logger.wandb_project,
+                entity=self.cfg.logger.wandb_entity,
                 config=cast(Any, self.cfg),
-                name=self.cfg.run_name,
-                id=self.cfg.wandb_id,
+                name=self.cfg.logger.run_name,
+                id=self.cfg.logger.wandb_id,
             )

         trainer = SAETrainer(
@@ -111,7 +111,7 @@ class SAETrainingRunner:
         self._compile_if_needed()
         sae = self.run_trainer_with_interruption_handling(trainer)

-        if self.cfg.log_to_wandb:
+        if self.cfg.logger.log_to_wandb:
             wandb.finish()

         return sae
@@ -175,7 +175,7 @@ class SAETrainingRunner:
                 layer_acts,
                 maxiter=100,
             ).median
-            self.sae.initialize_b_dec_with_precalculated(median)
+            self.sae.initialize_b_dec_with_precalculated(median)
         elif self.cfg.b_dec_init_method == "mean":
             self.activations_store.set_norm_scaling_factor_if_needed()
             layer_acts = self.activations_store.storage_buffer.detach().cpu()[:, 0, :]
@@ -208,29 +208,14 @@ class SAETrainingRunner:
         with open(cfg_path, "w") as f:
             json.dump(config, f)

-        if trainer.cfg.log_to_wandb:
-            # Avoid wandb saving errors such as:
-            # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. Invalid name: sae_google/gemma-2b_etc
-            sae_name = trainer.sae.get_name().replace("/", "__")
-
-            # save model weights and cfg
-            model_artifact = wandb.Artifact(
-                sae_name,
-                type="model",
-                metadata=dict(trainer.cfg.__dict__),
-            )
-            model_artifact.add_file(str(weights_path))
-            model_artifact.add_file(str(cfg_path))
-            wandb.log_artifact(model_artifact, aliases=wandb_aliases)
-
-            # save log feature sparsity
-            sparsity_artifact = wandb.Artifact(
-                f"{sae_name}_log_feature_sparsity",
-                type="log_feature_sparsity",
-                metadata=dict(trainer.cfg.__dict__),
+        if trainer.cfg.logger.log_to_wandb:
+            trainer.cfg.logger.log(
+                trainer,
+                weights_path,
+                cfg_path,
+                sparsity_path=sparsity_path,
+                wandb_aliases=wandb_aliases,
             )
-            sparsity_artifact.add_file(str(sparsity_path))
-            wandb.log_artifact(sparsity_artifact)


 def _parse_cfg_args(args: Sequence[str]) -> LanguageModelSAERunnerConfig:
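Note: SAE construction in the runner changes from calling the TrainingSAE constructor with a config object to a from_dict round-trip, and all W&B settings are read through cfg.logger. Sketched side by side (cfg is a LanguageModelSAERunnerConfig; both call patterns are taken from the diff above):

    # 5.9.0:
    #   sae = TrainingSAE(TrainingSAEConfig.from_dict(cfg.get_training_sae_cfg_dict()))
    # 6.0.0rc1:
    sae = TrainingSAE.from_dict(
        TrainingSAEConfig.from_dict(cfg.get_training_sae_cfg_dict()).to_dict()
    )

    if cfg.logger.log_to_wandb:  # was cfg.log_to_wandb in 5.9.0
        ...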