sae-lens 6.16.3__tar.gz → 6.18.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (40)
  1. {sae_lens-6.16.3 → sae_lens-6.18.0}/PKG-INFO +1 -1
  2. {sae_lens-6.16.3 → sae_lens-6.18.0}/pyproject.toml +1 -1
  3. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/__init__.py +1 -1
  4. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/cache_activations_runner.py +1 -1
  5. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/config.py +39 -2
  6. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/constants.py +1 -0
  7. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/llm_sae_training_runner.py +9 -4
  8. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/sae.py +7 -1
  9. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/activation_scaler.py +7 -0
  10. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/activations_store.py +46 -3
  11. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/optim.py +11 -0
  12. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/sae_trainer.py +49 -11
  13. {sae_lens-6.16.3 → sae_lens-6.18.0}/LICENSE +0 -0
  14. {sae_lens-6.16.3 → sae_lens-6.18.0}/README.md +0 -0
  15. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/analysis/__init__.py +0 -0
  16. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
  17. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
  18. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/evals.py +0 -0
  19. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/load_model.py +0 -0
  20. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/loading/__init__.py +0 -0
  21. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
  22. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
  23. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/pretokenize_runner.py +0 -0
  24. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/pretrained_saes.yaml +0 -0
  25. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/registry.py +0 -0
  26. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/__init__.py +0 -0
  27. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/batchtopk_sae.py +0 -0
  28. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/gated_sae.py +0 -0
  29. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/jumprelu_sae.py +0 -0
  30. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/matryoshka_batchtopk_sae.py +0 -0
  31. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/standard_sae.py +0 -0
  32. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/topk_sae.py +0 -0
  33. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/transcoder.py +0 -0
  34. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/tokenization_and_batching.py +0 -0
  35. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/__init__.py +0 -0
  36. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/mixing_buffer.py +0 -0
  37. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/types.py +0 -0
  38. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
  39. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/tutorial/tsea.py +0 -0
  40. {sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/util.py +0 -0

{sae_lens-6.16.3 → sae_lens-6.18.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sae-lens
-Version: 6.16.3
+Version: 6.18.0
 Summary: Training and Analyzing Sparse Autoencoders (SAEs)
 License: MIT
 License-File: LICENSE

{sae_lens-6.16.3 → sae_lens-6.18.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sae-lens"
-version = "6.16.3"
+version = "6.18.0"
 description = "Training and Analyzing Sparse Autoencoders (SAEs)"
 authors = ["Joseph Bloom"]
 readme = "README.md"

{sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/__init__.py
@@ -1,5 +1,5 @@
 # ruff: noqa: E402
-__version__ = "6.16.3"
+__version__ = "6.18.0"

 import logging


{sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/cache_activations_runner.py
@@ -10,7 +10,7 @@ from datasets import Array2D, Dataset, Features, Sequence, Value
 from datasets.fingerprint import generate_fingerprint
 from huggingface_hub import HfApi
 from jaxtyping import Float, Int
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from transformer_lens.HookedTransformer import HookedRootModule

 from sae_lens import logger

{sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/config.py
@@ -18,6 +18,7 @@ from datasets import (

 from sae_lens import __version__, logger
 from sae_lens.constants import DTYPE_MAP
+from sae_lens.registry import get_sae_training_class
 from sae_lens.saes.sae import TrainingSAEConfig

 if TYPE_CHECKING:
@@ -171,6 +172,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
         checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
         save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
+        resume_from_checkpoint (str | None): The path to the checkpoint to resume training from. (default is None).
         output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
         verbose (bool): Whether to print verbose output. (default is True)
         model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
@@ -261,6 +263,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
    checkpoint_path: str | None = "checkpoints"
    save_final_checkpoint: bool = False
    output_path: str | None = "output"
+   resume_from_checkpoint: str | None = None

    # Misc
    verbose: bool = True
@@ -385,8 +388,11 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
        return self.sae.to_dict()

    def to_dict(self) -> dict[str, Any]:
-       # Make a shallow copy of config's dictionary
-       d = dict(self.__dict__)
+       """
+       Convert the config to a dictionary.
+       """
+
+       d = asdict(self)

        d["logger"] = asdict(self.logger)
        d["sae"] = self.sae.to_dict()
@@ -396,6 +402,37 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
        d["act_store_device"] = str(self.act_store_device)
        return d

+   @classmethod
+   def from_dict(cls, cfg_dict: dict[str, Any]) -> "LanguageModelSAERunnerConfig[Any]":
+       """
+       Load a LanguageModelSAERunnerConfig from a dictionary given by `to_dict`.
+
+       Args:
+           cfg_dict (dict[str, Any]): The dictionary to load the config from.
+
+       Returns:
+           LanguageModelSAERunnerConfig: The loaded config.
+       """
+       if "sae" not in cfg_dict:
+           raise ValueError("sae field is required in the config dictionary")
+       if "architecture" not in cfg_dict["sae"]:
+           raise ValueError("architecture field is required in the sae dictionary")
+       if "logger" not in cfg_dict:
+           raise ValueError("logger field is required in the config dictionary")
+       sae_config_class = get_sae_training_class(cfg_dict["sae"]["architecture"])[1]
+       sae_cfg = sae_config_class.from_dict(cfg_dict["sae"])
+       logger_cfg = LoggingConfig(**cfg_dict["logger"])
+       updated_cfg_dict: dict[str, Any] = {
+           **cfg_dict,
+           "sae": sae_cfg,
+           "logger": logger_cfg,
+       }
+       output = cls(**updated_cfg_dict)
+       # the post_init always appends to checkpoint path, so we need to set it explicitly here.
+       if "checkpoint_path" in cfg_dict:
+           output.checkpoint_path = cfg_dict["checkpoint_path"]
+       return output
+
    def to_sae_trainer_config(self) -> "SAETrainerConfig":
        return SAETrainerConfig(
            n_checkpoints=self.n_checkpoints,
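
Not part of the diff: a minimal sketch of how the new from_dict round trip might be used with the runner_cfg.json that save_checkpoint writes. The checkpoint directory below is hypothetical.

# Sketch only: rebuild a runner config from a checkpoint's runner_cfg.json.
import json
from pathlib import Path

from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.constants import RUNNER_CFG_FILENAME

checkpoint_dir = Path("checkpoints/example-run")  # hypothetical path

with open(checkpoint_dir / RUNNER_CFG_FILENAME) as f:
    cfg_dict = json.load(f)

# from_dict rebuilds the nested SAE config via its registered architecture,
# rebuilds the LoggingConfig, and keeps checkpoint_path exactly as saved.
cfg = LanguageModelSAERunnerConfig.from_dict(cfg_dict)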

{sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/constants.py
@@ -17,5 +17,6 @@ SAE_WEIGHTS_FILENAME = "sae_weights.safetensors"
 SAE_CFG_FILENAME = "cfg.json"
 RUNNER_CFG_FILENAME = "runner_cfg.json"
 SPARSIFY_WEIGHTS_FILENAME = "sae.safetensors"
+TRAINER_STATE_FILENAME = "trainer_state.pt"
 ACTIVATIONS_STORE_STATE_FILENAME = "activations_store_state.safetensors"
 ACTIVATION_SCALER_CFG_FILENAME = "activation_scaler.json"

{sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/llm_sae_training_runner.py
@@ -16,7 +16,6 @@ from typing_extensions import deprecated
 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
 from sae_lens.constants import (
-    ACTIVATIONS_STORE_STATE_FILENAME,
     RUNNER_CFG_FILENAME,
     SPARSITY_FILENAME,
 )
@@ -112,6 +111,7 @@ class LanguageModelSAETrainingRunner:
        override_dataset: HfDataset | None = None,
        override_model: HookedRootModule | None = None,
        override_sae: TrainingSAE[Any] | None = None,
+       resume_from_checkpoint: Path | str | None = None,
    ):
        if override_dataset is not None:
            logger.warning(
@@ -153,6 +153,7 @@ class LanguageModelSAETrainingRunner:
            )
        else:
            self.sae = override_sae
+
        self.sae.to(self.cfg.device)

    def run(self):
@@ -185,6 +186,12 @@ class LanguageModelSAETrainingRunner:
            cfg=self.cfg.to_sae_trainer_config(),
        )

+       if self.cfg.resume_from_checkpoint is not None:
+           logger.info(f"Resuming from checkpoint: {self.cfg.resume_from_checkpoint}")
+           trainer.load_trainer_state(self.cfg.resume_from_checkpoint)
+           self.sae.load_weights_from_checkpoint(self.cfg.resume_from_checkpoint)
+           self.activations_store.load_from_checkpoint(self.cfg.resume_from_checkpoint)
+
        self._compile_if_needed()
        sae = self.run_trainer_with_interruption_handling(trainer)

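
Not part of the diff: a hedged sketch of the resume flow from the user's side. The import path, the runner call, and the checkpoint path are assumptions; cfg is assumed to be a LanguageModelSAERunnerConfig (for example one rebuilt with from_dict as sketched earlier).

# Sketch only: resume an interrupted training run via the new config field.
from sae_lens.llm_sae_training_runner import LanguageModelSAETrainingRunner

cfg.resume_from_checkpoint = "checkpoints/example-run/12288000"  # hypothetical

# run() now restores trainer state, SAE weights, and the activations store
# position from that directory before training continues.
sae = LanguageModelSAETrainingRunner(cfg).run()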
@@ -304,9 +311,7 @@ class LanguageModelSAETrainingRunner:
        if checkpoint_path is None:
            return

-       self.activations_store.save(
-           str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
-       )
+       self.activations_store.save_to_checkpoint(checkpoint_path)

        runner_config = self.cfg.to_dict()
        with open(checkpoint_path / RUNNER_CFG_FILENAME, "w") as f:

{sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/saes/sae.py
@@ -21,7 +21,7 @@ import einops
 import torch
 from jaxtyping import Float
 from numpy.typing import NDArray
-from safetensors.torch import save_file
+from safetensors.torch import load_file, save_file
 from torch import nn
 from transformer_lens.hook_points import HookedRootModule, HookPoint
 from typing_extensions import deprecated, overload, override
@@ -1018,6 +1018,12 @@ class TrainingSAE(SAE[T_TRAINING_SAE_CONFIG], ABC):
    ) -> type[TrainingSAEConfig]:
        return get_sae_training_class(architecture)[1]

+   def load_weights_from_checkpoint(self, checkpoint_path: Path | str) -> None:
+       checkpoint_path = Path(checkpoint_path)
+       state_dict = load_file(checkpoint_path / SAE_WEIGHTS_FILENAME)
+       self.process_state_dict_for_loading(state_dict)
+       self.load_state_dict(state_dict)
+

 _blank_hook = nn.Identity()


{sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/activation_scaler.py
@@ -1,5 +1,6 @@
 import json
 from dataclasses import dataclass
+from pathlib import Path
 from statistics import mean

 import torch
@@ -51,3 +52,9 @@ class ActivationScaler:

        with open(file_path, "w") as f:
            json.dump({"scaling_factor": self.scaling_factor}, f)
+
+   def load(self, file_path: str | Path):
+       """load the state dict from a file in json format"""
+       with open(file_path) as f:
+           data = json.load(f)
+       self.scaling_factor = data["scaling_factor"]

{sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/activations_store.py
@@ -4,6 +4,7 @@ import json
 import os
 import warnings
 from collections.abc import Generator, Iterator, Sequence
+from pathlib import Path
 from typing import Any, Literal, cast

 import datasets
@@ -13,8 +14,8 @@ from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import HfHubHTTPError
 from jaxtyping import Float, Int
 from requests import HTTPError
-from safetensors.torch import save_file
-from tqdm import tqdm
+from safetensors.torch import load_file, save_file
+from tqdm.auto import tqdm
 from transformer_lens.hook_points import HookedRootModule
 from transformers import AutoTokenizer, PreTrainedTokenizerBase

@@ -24,7 +25,7 @@ from sae_lens.config import (
     HfDataset,
     LanguageModelSAERunnerConfig,
 )
-from sae_lens.constants import DTYPE_MAP
+from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME, DTYPE_MAP
 from sae_lens.pretokenize_runner import get_special_token_from_cfg
 from sae_lens.saes.sae import SAE, T_SAE_CONFIG, T_TRAINING_SAE_CONFIG
 from sae_lens.tokenization_and_batching import concat_and_batch_sequences
@@ -729,6 +730,48 @@ class ActivationsStore:
        """save the state dict to a file in safetensors format"""
        save_file(self.state_dict(), file_path)

+   def save_to_checkpoint(self, checkpoint_path: str | Path):
+       """Save the state dict to a checkpoint path"""
+       self.save(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+   def load_from_checkpoint(self, checkpoint_path: str | Path):
+       """Load the state dict from a checkpoint path"""
+       self.load(str(Path(checkpoint_path) / ACTIVATIONS_STORE_STATE_FILENAME))
+
+   def load(self, file_path: str):
+       """Load the state dict from a file in safetensors format"""
+
+       state_dict = load_file(file_path)
+
+       if "n_dataset_processed" in state_dict:
+           target_n_dataset_processed = state_dict["n_dataset_processed"].item()
+
+           # Only fast-forward if needed
+
+           if target_n_dataset_processed > self.n_dataset_processed:
+               logger.info(
+                   "Fast-forwarding through dataset samples to match checkpoint position"
+               )
+               samples_to_skip = target_n_dataset_processed - self.n_dataset_processed
+
+               pbar = tqdm(
+                   total=samples_to_skip,
+                   desc="Fast-forwarding through dataset",
+                   leave=False,
+               )
+               while target_n_dataset_processed > self.n_dataset_processed:
+                   start = self.n_dataset_processed
+                   try:
+                       # Just consume and ignore the values to fast-forward
+                       next(self.iterable_sequences)
+                   except StopIteration:
+                       logger.warning(
+                           "Dataset exhausted during fast-forward. Resetting dataset."
+                       )
+                       self.iterable_sequences = self._iterate_tokenized_sequences()
+                   pbar.update(self.n_dataset_processed - start)
+               pbar.close()
+

 def validate_pretokenized_dataset_tokenizer(
     dataset_path: str, model_tokenizer: PreTrainedTokenizerBase
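
Not part of the diff: a small sketch of the new ActivationsStore checkpoint round trip. store is assumed to be an existing ActivationsStore, and the path is hypothetical.

# Sketch only: checkpointing an ActivationsStore.
from pathlib import Path

ckpt = Path("checkpoints/example-run/12288000")  # hypothetical
store.save_to_checkpoint(ckpt)   # writes activations_store_state.safetensors

# On a fresh store built from the same dataset/config, load() compares the saved
# n_dataset_processed to the current value and fast-forwards the token iterator
# (with a tqdm bar) until the positions match.
store.load_from_checkpoint(ckpt)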

{sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/optim.py
@@ -2,6 +2,8 @@
 Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
 """

+from typing import Any
+
 import torch.optim as optim
 import torch.optim.lr_scheduler as lr_scheduler

@@ -150,3 +152,12 @@ class CoefficientScheduler:
    def value(self) -> float:
        """Returns the current scalar value."""
        return self.current_value
+
+   def state_dict(self) -> dict[str, Any]:
+       return {
+           "current_step": self.current_step,
+       }
+
+   def load_state_dict(self, state_dict: dict[str, Any]):
+       for k in state_dict:
+           setattr(self, k, state_dict[k])
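
Not part of the diff: the scheduler's resumable state is just its step counter, so restoring it is a plain dict round trip. scheduler and new_scheduler are assumed to be existing CoefficientScheduler instances created with the same settings.

# Sketch only: persist and restore a CoefficientScheduler's position.
state = scheduler.state_dict()        # {"current_step": ...}
new_scheduler.load_state_dict(state)  # restores current_step via setattr
assert new_scheduler.current_step == scheduler.current_step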

{sae_lens-6.16.3 → sae_lens-6.18.0}/sae_lens/training/sae_trainer.py
@@ -1,4 +1,5 @@
 import contextlib
+import math
 from pathlib import Path
 from typing import Any, Callable, Generic, Protocol

@@ -10,7 +11,11 @@ from tqdm.auto import tqdm

 from sae_lens import __version__
 from sae_lens.config import SAETrainerConfig
-from sae_lens.constants import ACTIVATION_SCALER_CFG_FILENAME, SPARSITY_FILENAME
+from sae_lens.constants import (
+    ACTIVATION_SCALER_CFG_FILENAME,
+    SPARSITY_FILENAME,
+    TRAINER_STATE_FILENAME,
+)
 from sae_lens.saes.sae import (
     T_TRAINING_SAE,
     T_TRAINING_SAE_CONFIG,
@@ -56,6 +61,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
    data_provider: DataProvider
    activation_scaler: ActivationScaler
    evaluator: Evaluator[T_TRAINING_SAE] | None
+   coefficient_schedulers: dict[str, CoefficientScheduler]

    def __init__(
        self,
@@ -84,7 +90,9 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
            range(
                0,
                cfg.total_training_samples,
-               cfg.total_training_samples // self.cfg.n_checkpoints,
+               math.ceil(
+                   cfg.total_training_samples / (self.cfg.n_checkpoints + 1)
+               ),
            )
        )[1:]

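
Not part of the diff: a worked example of the new checkpoint spacing, with hypothetical numbers.

# Sketch only: total_training_samples=1_000, n_checkpoints=3.
import math

old_thresholds = list(range(0, 1_000, 1_000 // 3))[1:]                  # [333, 666, 999]
new_thresholds = list(range(0, 1_000, math.ceil(1_000 / (3 + 1))))[1:]  # [250, 500, 750]
# The intermediate checkpoints are now spread evenly across the run instead of
# the last one landing almost at the very end of training.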
@@ -93,11 +101,6 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
            sae.cfg.d_sae, device=cfg.device
        )
        self.n_frac_active_samples = 0
-       # we don't train the scaling factor (initially)
-       # set requires grad to false for the scaling factor
-       for name, param in self.sae.named_parameters():
-           if "scaling_factor" in name:
-               param.requires_grad = False

        self.optimizer = Adam(
            sae.parameters(),
@@ -210,10 +213,7 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
        sparsity_path = checkpoint_path / SPARSITY_FILENAME
        save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)

-       activation_scaler_path = (
-           checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
-       )
-       self.activation_scaler.save(str(activation_scaler_path))
+       self.save_trainer_state(checkpoint_path)

        if self.cfg.logger.log_to_wandb:
            self.cfg.logger.log(
@@ -227,6 +227,44 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
        if self.save_checkpoint_fn is not None:
            self.save_checkpoint_fn(checkpoint_path=checkpoint_path)

+   def save_trainer_state(self, checkpoint_path: Path) -> None:
+       checkpoint_path.mkdir(exist_ok=True, parents=True)
+       scheduler_state_dicts = {
+           name: scheduler.state_dict()
+           for name, scheduler in self.coefficient_schedulers.items()
+       }
+       torch.save(
+           {
+               "optimizer": self.optimizer.state_dict(),
+               "lr_scheduler": self.lr_scheduler.state_dict(),
+               "n_training_samples": self.n_training_samples,
+               "n_training_steps": self.n_training_steps,
+               "act_freq_scores": self.act_freq_scores,
+               "n_forward_passes_since_fired": self.n_forward_passes_since_fired,
+               "n_frac_active_samples": self.n_frac_active_samples,
+               "started_fine_tuning": self.started_fine_tuning,
+               "coefficient_schedulers": scheduler_state_dicts,
+           },
+           str(checkpoint_path / TRAINER_STATE_FILENAME),
+       )
+       activation_scaler_path = checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
+       self.activation_scaler.save(str(activation_scaler_path))
+
+   def load_trainer_state(self, checkpoint_path: Path | str) -> None:
+       checkpoint_path = Path(checkpoint_path)
+       self.activation_scaler.load(checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME)
+       state_dict = torch.load(checkpoint_path / TRAINER_STATE_FILENAME)
+       self.optimizer.load_state_dict(state_dict["optimizer"])
+       self.lr_scheduler.load_state_dict(state_dict["lr_scheduler"])
+       self.n_training_samples = state_dict["n_training_samples"]
+       self.n_training_steps = state_dict["n_training_steps"]
+       self.act_freq_scores = state_dict["act_freq_scores"]
+       self.n_forward_passes_since_fired = state_dict["n_forward_passes_since_fired"]
+       self.n_frac_active_samples = state_dict["n_frac_active_samples"]
+       self.started_fine_tuning = state_dict["started_fine_tuning"]
+       for name, scheduler_state_dict in state_dict["coefficient_schedulers"].items():
+           self.coefficient_schedulers[name].load_state_dict(scheduler_state_dict)
+
    def _train_step(
        self,
        sae: T_TRAINING_SAE,
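
Not part of the diff: a sketch of the low-level resume path that run() wires up, useful when driving the trainer directly. trainer, sae, and store are assumed to be freshly constructed objects for the same config, and the path is hypothetical.

# Sketch only: manually restoring all resumable state from one checkpoint directory.
from pathlib import Path

ckpt = Path("checkpoints/example-run/12288000")  # hypothetical

trainer.load_trainer_state(ckpt)        # trainer_state.pt + activation_scaler.json
sae.load_weights_from_checkpoint(ckpt)  # sae_weights.safetensors
store.load_from_checkpoint(ckpt)        # activations_store_state.safetensors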