sae-lens 6.16.3__tar.gz → 6.18.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sae-lens might be problematic.
- {sae_lens-6.16.3 → sae_lens-6.18.0}/PKG-INFO +1 -1
- {sae_lens-6.16.3 → sae_lens-6.18.0}/pyproject.toml +1 -1
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/__init__.py +1 -1
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/cache_activations_runner.py +1 -1
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/config.py +39 -2
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/constants.py +1 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/llm_sae_training_runner.py +9 -4
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/sae.py +7 -1
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/activation_scaler.py +7 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/activations_store.py +46 -3
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/optim.py +11 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/sae_trainer.py +49 -11
- {sae_lens-6.16.3 → sae_lens-6.18.0}/LICENSE +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/README.md +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/tutorial/tsea.py +0 -0
- {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/util.py +0 -0
sae_lens/cache_activations_runner.py

@@ -10,7 +10,7 @@ from datasets import Array2D, Dataset, Features, Sequence, Value
 from datasets.fingerprint import generate_fingerprint
 from huggingface_hub import HfApi
 from jaxtyping import Float, Int
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from transformer_lens.HookedTransformer import HookedRootModule
 
 from sae_lens import logger
sae_lens/config.py

@@ -18,6 +18,7 @@ from datasets import (
 
 from sae_lens import __version__, logger
 from sae_lens.constants import DTYPE_MAP
+from sae_lens.registry import get_sae_training_class
 from sae_lens.saes.sae import TrainingSAEConfig
 
 if TYPE_CHECKING:
@@ -171,6 +172,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
         checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
         save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
+        resume_from_checkpoint (str | None): The path to the checkpoint to resume training from. (default is None).
         output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
         verbose (bool): Whether to print verbose output. (default is True)
         model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
@@ -261,6 +263,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
     checkpoint_path: str | None = "checkpoints"
     save_final_checkpoint: bool = False
     output_path: str | None = "output"
+    resume_from_checkpoint: str | None = None
 
     # Misc
     verbose: bool = True
@@ -385,8 +388,11 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         return self.sae.to_dict()
 
     def to_dict(self) -> dict[str, Any]:
-
-
+        """
+        Convert the config to a dictionary.
+        """
+
+        d = asdict(self)
 
         d["logger"] = asdict(self.logger)
         d["sae"] = self.sae.to_dict()
@@ -396,6 +402,37 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         d["act_store_device"] = str(self.act_store_device)
         return d
 
+    @classmethod
+    def from_dict(cls, cfg_dict: dict[str, Any]) -> "LanguageModelSAERunnerConfig[Any]":
+        """
+        Load a LanguageModelSAERunnerConfig from a dictionary given by `to_dict`.
+
+        Args:
+            cfg_dict (dict[str, Any]): The dictionary to load the config from.
+
+        Returns:
+            LanguageModelSAERunnerConfig: The loaded config.
+        """
+        if "sae" not in cfg_dict:
+            raise ValueError("sae field is required in the config dictionary")
+        if "architecture" not in cfg_dict["sae"]:
+            raise ValueError("architecture field is required in the sae dictionary")
+        if "logger" not in cfg_dict:
+            raise ValueError("logger field is required in the config dictionary")
+        sae_config_class = get_sae_training_class(cfg_dict["sae"]["architecture"])[1]
+        sae_cfg = sae_config_class.from_dict(cfg_dict["sae"])
+        logger_cfg = LoggingConfig(**cfg_dict["logger"])
+        updated_cfg_dict: dict[str, Any] = {
+            **cfg_dict,
+            "sae": sae_cfg,
+            "logger": logger_cfg,
+        }
+        output = cls(**updated_cfg_dict)
+        # the post_init always appends to checkpoint path, so we need to set it explicitly here.
+        if "checkpoint_path" in cfg_dict:
+            output.checkpoint_path = cfg_dict["checkpoint_path"]
+        return output
+
     def to_sae_trainer_config(self) -> "SAETrainerConfig":
         return SAETrainerConfig(
             n_checkpoints=self.n_checkpoints,
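The config changes above pair the existing `to_dict` with a new `from_dict` and add a `resume_from_checkpoint` field. A minimal sketch of how the pair might be used to reload the runner config saved alongside a checkpoint; the checkpoint directory path is a hypothetical placeholder, not taken from the diff:

    import json

    from sae_lens.config import LanguageModelSAERunnerConfig
    from sae_lens.constants import RUNNER_CFG_FILENAME

    checkpoint_dir = "checkpoints/my_run"  # hypothetical path
    with open(f"{checkpoint_dir}/{RUNNER_CFG_FILENAME}") as f:
        cfg_dict = json.load(f)

    # from_dict rebuilds the nested sae/logger configs from the serialized dict
    cfg = LanguageModelSAERunnerConfig.from_dict(cfg_dict)
    cfg.resume_from_checkpoint = checkpoint_dir  # new field in this release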
sae_lens/constants.py

@@ -17,5 +17,6 @@ SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
 SAE_CFG_FILENAME = "cfg.json"
 RUNNER_CFG_FILENAME = "runner_cfg.json"
 SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+TRAINER_STATE_FILENAME = "trainer_state.pt"
 ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
 ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"
sae_lens/llm_sae_training_runner.py

@@ -16,7 +16,6 @@ from typing_extensions import deprecated
 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
 from sae_lens.constants import (
-    ACTIVATIONS_STORE_STATE_FILENAME,
     RUNNER_CFG_FILENAME,
     SPARSITY_FILENAME,
 )
@@ -112,6 +111,7 @@ class LanguageModelSAETrainingRunner:
         override_dataset: HfDataset | None = None,
         override_model: HookedRootModule | None = None,
         override_sae: TrainingSAE[Any] | None = None,
+        resume_from_checkpoint: Path | str | None = None,
     ):
         if override_dataset is not None:
             logger.warning(
@@ -153,6 +153,7 @@ class LanguageModelSAETrainingRunner:
             )
         else:
             self.sae = override_sae
+
         self.sae.to(self.cfg.device)
 
     def run(self):
@@ -185,6 +186,12 @@ class LanguageModelSAETrainingRunner:
             cfg=self.cfg.to_sae_trainer_config(),
         )
 
+        if self.cfg.resume_from_checkpoint is not None:
+            logger.info(f"Resuming from checkpoint: {self.cfg.resume_from_checkpoint}")
+            trainer.load_trainer_state(self.cfg.resume_from_checkpoint)
+            self.sae.load_weights_from_checkpoint(self.cfg.resume_from_checkpoint)
+            self.activations_store.load_from_checkpoint(self.cfg.resume_from_checkpoint)
+
         self._compile_if_needed()
         sae = self.run_trainer_with_interruption_handling(trainer)
 
@@ -304,9 +311,7 @@ class LanguageModelSAETrainingRunner:
         if checkpoint_path is None:
             return
 
-        self.activations_store.save(
-            str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
-        )
+        self.activations_store.save_to_checkpoint(checkpoint_path)
 
         runner_config = self.cfg.to_dict()
         with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:
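Together with the config change, resuming a run is driven by `cfg.resume_from_checkpoint`: when set, `run()` restores the trainer state, SAE weights, and activations-store position before training continues. A hedged sketch, assuming the runner still takes the config as its first constructor argument and that a checkpoint directory from an earlier run exists at the (hypothetical) path below:

    import json

    from sae_lens.config import LanguageModelSAERunnerConfig
    from sae_lens.constants import RUNNER_CFG_FILENAME
    from sae_lens.llm_sae_training_runner import LanguageModelSAETrainingRunner

    checkpoint_dir = "checkpoints/my_run"  # hypothetical
    with open(f"{checkpoint_dir}/{RUNNER_CFG_FILENAME}") as f:
        cfg = LanguageModelSAERunnerConfig.from_dict(json.load(f))

    cfg.resume_from_checkpoint = checkpoint_dir
    LanguageModelSAETrainingRunner(cfg).run()  # restores state instead of starting fresh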
sae_lens/saes/sae.py

@@ -21,7 +21,7 @@ import einops
 import torch
 from jaxtyping import Float
 from numpy.typing import NDArray
-from safetensors.torch import save_file
+from safetensors.torch import load_file, save_file
 from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
 from typing_extensions import deprecated, overload, override
@@ -1018,6 +1018,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
     ) -> type[TrainingSAEConfig]:
         return get_sae_training_class(architecture)[1]
 
+    def load_weights_from_checkpoint(self, checkpoint_path: Path | str) -> None:
+        checkpoint_path = Path(checkpoint_path)
+        state_dict = load_file(checkpoint_path / SAE_WEIGHTS_FILENAME)
+        self.process_state_dict_for_loading(state_dict)
+        self.load_state_dict(state_dict)
+
 
 _blank_hook = nn.Identity()
 
sae_lens/training/activation_scaler.py

@@ -1,5 +1,6 @@
 import json
 from dataclasses import dataclass
+from pathlib import Path
 from statistics import mean
 
 import torch
@@ -51,3 +52,9 @@ class ActivationScaler:
 
         with open(file_path, "w") as f:
             json.dump({"scaling_factor": self.scaling_factor}, f)
+
+    def load(self, file_path: str | Path):
+        """load the state dict from a file in json format"""
+        with open(file_path) as f:
+            data = json.load(f)
+            self.scaling_factor = data["scaling_factor"]
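`ActivationScaler` now has a `load` to mirror its existing `save`; both operate on a small JSON file holding the scaling factor. A quick sketch of the round trip, assuming the class is a plain dataclass constructible with its `scaling_factor` field (the file path is illustrative):

    from sae_lens.training.activation_scaler import ActivationScaler

    scaler = ActivationScaler(scaling_factor=1.5)
    scaler.save("activation_scaler.json")  # writes {"scaling_factor": 1.5}

    restored = ActivationScaler(scaling_factor=None)
    restored.load("activation_scaler.json")  # reads it back into restored.scaling_factor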
sae_lens/training/activations_store.py

@@ -4,6 +4,7 @@ import json
 import os
 import warnings
 from collections.abc import Generator, Iterator, Sequence
+from pathlib import Path
 from typing import Any, Literal, cast
 
 import datasets
@@ -13,8 +14,8 @@ from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import HfHubHTTPError
 from jaxtyping import Float, Int
 from requests import HTTPError
-from safetensors.torch import save_file
-from tqdm import tqdm
+from safetensors.torch import load_file, save_file
+from tqdm.auto import tqdm
 from transformer_lens.hook_points import HookedRootModule
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
@@ -24,7 +25,7 @@ from sae_lens.config import (
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.constants import DTYPE_MAP
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, DTYPE_MAP
 from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
@@ -729,6 +730,48 @@ class ActivationsStore:
         """save the state dict to a file in safetensors format"""
         save_file(self.state_dict(), file_path)
 
+    def save_to_checkpoint(self, checkpoint_path: str | Path):
+        """Save the state dict to a checkpoint path"""
+        self.save(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+    def load_from_checkpoint(self, checkpoint_path: str | Path):
+        """Load the state dict from a checkpoint path"""
+        self.load(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+    def load(self, file_path: str):
+        """Load the state dict from a file in safetensors format"""
+
+        state_dict = load_file(file_path)
+
+        if "n_dataset_processed" in state_dict:
+            target_n_dataset_processed = state_dict["n_dataset_processed"].item()
+
+            # Only fast-forward if needed
+
+            if target_n_dataset_processed > self.n_dataset_processed:
+                logger.info(
+                    "Fast-forwarding through dataset samples to match checkpoint position"
+                )
+                samples_to_skip = target_n_dataset_processed - self.n_dataset_processed
+
+                pbar = tqdm(
+                    total=samples_to_skip,
+                    desc="Fast-forwarding through dataset",
+                    leave=False,
+                )
+                while target_n_dataset_processed > self.n_dataset_processed:
+                    start = self.n_dataset_processed
+                    try:
+                        # Just consume and ignore the values to fast-forward
+                        next(self.iterable_sequences)
+                    except StopIteration:
+                        logger.warning(
+                            "Dataset exhausted during fast-forward. Resetting dataset."
+                        )
+                        self.iterable_sequences = self._iterate_tokenized_sequences()
+                    pbar.update(self.n_dataset_processed - start)
+                pbar.close()
+
 
 def validate_pretokenized_dataset_tokenizer(
     dataset_path: str, model_tokenizer: PreTrainedTokenizerBase
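The new `load`/`load_from_checkpoint` do more than restore counters: if the saved `n_dataset_processed` is ahead of the current store, the token iterator is fast-forwarded so a resumed run continues from roughly the same dataset position. A sketch under the assumption that you already have a configured `ActivationsStore` (building one requires a model and dataset, omitted here):

    from pathlib import Path

    from sae_lens.training.activations_store import ActivationsStore

    def checkpoint_and_restore(store: ActivationsStore, checkpoint_dir: str | Path) -> None:
        """Persist the store's position, then restore it (fast-forwarding if needed)."""
        store.save_to_checkpoint(checkpoint_dir)    # writes activations_store_state.safetensors
        store.load_from_checkpoint(checkpoint_dir)  # reads it back; skips samples to catch up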
sae_lens/training/optim.py

@@ -2,6 +2,8 @@
 Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
 """
 
+from typing import Any
+
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler
 
@@ -150,3 +152,12 @@ class CoefficientScheduler:
     def value(self) -> float:
         """Returns the current scalar value."""
         return self.current_value
+
+    def state_dict(self) -> dict[str, Any]:
+        return {
+            "current_step": self.current_step,
+        }
+
+    def load_state_dict(self, state_dict: dict[str, Any]):
+        for k in state_dict:
+            setattr(self, k, state_dict[k])
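`CoefficientScheduler` gains `state_dict`/`load_state_dict` so the trainer can serialize scheduler progress alongside the optimizer. A small sketch assuming an existing scheduler instance (its constructor arguments are not part of this diff):

    from sae_lens.training.optim import CoefficientScheduler

    def snapshot_and_restore(scheduler: CoefficientScheduler) -> None:
        """Round-trip the scheduler's progress through its new state_dict API."""
        state = scheduler.state_dict()    # {"current_step": ...}
        scheduler.load_state_dict(state)  # restores current_step via setattr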
sae_lens/training/sae_trainer.py

@@ -1,4 +1,5 @@
 import contextlib
+import math
 from pathlib import Path
 from typing import Any, Callable, Generic, Protocol
 
@@ -10,7 +11,11 @@ from tqdm.auto import tqdm
 
 from sae_lens import __version__
 from sae_lens.config import SAETrainerConfig
-from sae_lens.constants import ACTIVATION_SCALER_CFG_FILENAME, SPARSITY_FILENAME
+from sae_lens.constants import (
+    ACTIVATION_SCALER_CFG_FILENAME,
+    SPARSITY_FILENAME,
+    TRAINER_STATE_FILENAME,
+)
 from sae_lens.saes.sae import (
     T_TRAINING_SAE,
     T_TRAINING_SAE_CONFIG,
@@ -56,6 +61,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
     data_provider: DataProvider
     activation_scaler: ActivationScaler
     evaluator: Evaluator[T_TRAINING_SAE] | None
+    coefficient_schedulers: dict[str, CoefficientScheduler]
 
     def __init__(
         self,
@@ -84,7 +90,9 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
             range(
                 0,
                 cfg.total_training_samples,
-
+                math.ceil(
+                    cfg.total_training_samples / (self.cfg.n_checkpoints + 1)
+                ),
             )
         )[1:]
 
@@ -93,11 +101,6 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
             sae.cfg.d_sae, device=cfg.device
         )
        self.n_frac_active_samples = 0
-        # we don't train the scaling factor (initially)
-        # set requires grad to false for the scaling factor
-        for name, param in self.sae.named_parameters():
-            if "scaling_factor" in name:
-                param.requires_grad = False
 
         self.optimizer = Adam(
             sae.parameters(),
@@ -210,10 +213,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
         sparsity_path = checkpoint_path / SPARSITY_FILENAME
         save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
 
-        activation_scaler_path = (
-            checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
-        )
-        self.activation_scaler.save(str(activation_scaler_path))
+        self.save_trainer_state(checkpoint_path)
 
         if self.cfg.logger.log_to_wandb:
             self.cfg.logger.log(
@@ -227,6 +227,44 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
         if self.save_checkpoint_fn is not None:
             self.save_checkpoint_fn(checkpoint_path=checkpoint_path)
 
+    def save_trainer_state(self, checkpoint_path: Path) -> None:
+        checkpoint_path.mkdir(exist_ok=True, parents=True)
+        scheduler_state_dicts = {
+            name: scheduler.state_dict()
+            for name, scheduler in self.coefficient_schedulers.items()
+        }
+        torch.save(
+            {
+                "optimizer": self.optimizer.state_dict(),
+                "lr_scheduler": self.lr_scheduler.state_dict(),
+                "n_training_samples": self.n_training_samples,
+                "n_training_steps": self.n_training_steps,
+                "act_freq_scores": self.act_freq_scores,
+                "n_forward_passes_since_fired": self.n_forward_passes_since_fired,
+                "n_frac_active_samples": self.n_frac_active_samples,
+                "started_fine_tuning": self.started_fine_tuning,
+                "coefficient_schedulers": scheduler_state_dicts,
+            },
+            str(checkpoint_path / TRAINER_STATE_FILENAME),
+        )
+        activation_scaler_path = checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
+        self.activation_scaler.save(str(activation_scaler_path))
+
+    def load_trainer_state(self, checkpoint_path: Path | str) -> None:
+        checkpoint_path = Path(checkpoint_path)
+        self.activation_scaler.load(checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME)
+        state_dict = torch.load(checkpoint_path / TRAINER_STATE_FILENAME)
+        self.optimizer.load_state_dict(state_dict["optimizer"])
+        self.lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
+        self.n_training_samples = state_dict["n_training_samples"]
+        self.n_training_steps = state_dict["n_training_steps"]
+        self.act_freq_scores = state_dict["act_freq_scores"]
+        self.n_forward_passes_since_fired = state_dict["n_forward_passes_since_fired"]
+        self.n_frac_active_samples = state_dict["n_frac_active_samples"]
+        self.started_fine_tuning = state_dict["started_fine_tuning"]
+        for name, scheduler_state_dict in state_dict["coefficient_schedulers"].items():
+            self.coefficient_schedulers[name].load_state_dict(scheduler_state_dict)
+
     def _train_step(
         self,
         sae: T_TRAINING_SAE,
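`save_trainer_state` and `load_trainer_state` bundle everything needed to resume mid-run: optimizer and LR-scheduler state, step and sample counters, feature-firing statistics, coefficient-scheduler progress, and the activation scaler. A hedged sketch assuming an already constructed `SAETrainer` (its constructor arguments are outside this diff):

    from pathlib import Path

    from sae_lens.training.sae_trainer import SAETrainer

    def roundtrip_trainer_state(trainer: SAETrainer, checkpoint_dir: str | Path) -> None:
        """Write trainer_state.pt and activation_scaler.json, then reload them in place."""
        path = Path(checkpoint_dir)
        trainer.save_trainer_state(path)  # torch.save of optimizer/scheduler/counters
        trainer.load_trainer_state(path)  # restores the same state on the trainer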