sae-lens 6.9.1.tar.gz → 6.10.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sae_lens-6.9.1 → sae_lens-6.10.0}/PKG-INFO +1 -1
- {sae_lens-6.9.1 → sae_lens-6.10.0}/pyproject.toml +1 -1
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/__init__.py +1 -1
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/config.py +13 -5
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/llm_sae_training_runner.py +53 -8
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/training/sae_trainer.py +25 -29
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/util.py +18 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/LICENSE +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/README.md +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/analysis/__init__.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/analysis/hooked_sae_transformer.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/analysis/neuronpedia_integration.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/cache_activations_runner.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/constants.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/evals.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/load_model.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/loading/__init__.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/loading/pretrained_sae_loaders.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/loading/pretrained_saes_directory.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/pretokenize_runner.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/pretrained_saes.yaml +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/registry.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/saes/__init__.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/saes/batchtopk_sae.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/saes/gated_sae.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/saes/jumprelu_sae.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/saes/sae.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/saes/standard_sae.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/saes/topk_sae.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/saes/transcoder.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/tokenization_and_batching.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/training/__init__.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/training/activation_scaler.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/training/activations_store.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/training/mixing_buffer.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/training/optim.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/training/types.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/training/upload_saes_to_huggingface.py +0 -0
- {sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/tutorial/tsea.py +0 -0
{sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/config.py

@@ -169,8 +169,10 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         eval_batch_size_prompts (int, optional): The batch size for evaluation, in prompts. Useful if evals cause OOM.
         logger (LoggingConfig): Configuration for logging (e.g. W&B).
         n_checkpoints (int): The number of checkpoints to save during training. 0 means no checkpoints.
-        checkpoint_path (str): The path to save checkpoints. A unique ID will be appended to this path.
-
+        checkpoint_path (str | None): The path to save checkpoints. A unique ID will be appended to this path. Set to None to disable checkpoint saving. (default is "checkpoints")
+        save_final_checkpoint (bool): Whether to include an additional final checkpoint when training is finished. (default is False).
+        output_path (str | None): The path to save outputs. Set to None to disable output saving. (default is "output")
+        verbose (bool): Whether to print verbose output. (default is True)
         model_kwargs (dict[str, Any]): Keyword arguments for `model.run_with_cache`
         model_from_pretrained_kwargs (dict[str, Any], optional): Additional keyword arguments to pass to the model's `from_pretrained` method.
         sae_lens_version (str): The version of the sae_lens library.

@@ -254,9 +256,13 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):

     logger: LoggingConfig = field(default_factory=LoggingConfig)

-    #
+    # Outputs/Checkpoints
     n_checkpoints: int = 0
-    checkpoint_path: str = "checkpoints"
+    checkpoint_path: str | None = "checkpoints"
+    save_final_checkpoint: bool = False
+    output_path: str | None = "output"
+
+    # Misc
     verbose: bool = True
     model_kwargs: dict[str, Any] = dict_field(default={})
     model_from_pretrained_kwargs: dict[str, Any] | None = dict_field(default=None)

@@ -394,6 +400,7 @@ class LanguageModelSAERunnerConfig(Generic[T_TRAINING_SAE_CONFIG]):
         return SAETrainerConfig(
             n_checkpoints=self.n_checkpoints,
             checkpoint_path=self.checkpoint_path,
+            save_final_checkpoint=self.save_final_checkpoint,
             total_training_samples=self.total_training_tokens,
             device=self.device,
             autocast=self.autocast,

@@ -618,7 +625,8 @@ class PretokenizeRunnerConfig:
 @dataclass
 class SAETrainerConfig:
     n_checkpoints: int
-    checkpoint_path: str
+    checkpoint_path: str | None
+    save_final_checkpoint: bool
     total_training_samples: int
     device: str
     autocast: bool
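For orientation, a hypothetical sketch of how the new output/checkpoint options might be set. Only the fields added or changed in 6.10.0 are shown; the remaining model, dataset, and SAE arguments (elided below) are assumed to be supplied as in earlier releases.

from sae_lens import LanguageModelSAERunnerConfig

# Hypothetical sketch, not taken from the diff: only the 6.10.0 options are shown;
# all other required constructor arguments are assumed and omitted.
cfg = LanguageModelSAERunnerConfig(
    # ... model, dataset, and SAE settings as in previous releases ...
    n_checkpoints=0,
    checkpoint_path=None,         # new: None disables checkpoint saving entirely
    save_final_checkpoint=False,  # new: opt in to one extra checkpoint when training finishes
    output_path="output",         # new: where the final SAE is written; None disables output saving
)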
{sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/llm_sae_training_runner.py

@@ -8,13 +8,18 @@ from typing import Any, Generic

 import torch
 import wandb
+from safetensors.torch import save_file
 from simple_parsing import ArgumentParser
 from transformer_lens.hook_points import HookedRootModule
 from typing_extensions import deprecated

 from sae_lens import logger
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
-from sae_lens.constants import ACTIVATIONS_STORE_STATE_FILENAME
+from sae_lens.constants import (
+    ACTIVATIONS_STORE_STATE_FILENAME,
+    RUNNER_CFG_FILENAME,
+    SPARSITY_FILENAME,
+)
 from sae_lens.evals import EvalConfig, run_evals
 from sae_lens.load_model import load_model
 from sae_lens.saes.batchtopk_sae import BatchTopKTrainingSAEConfig

@@ -185,11 +190,47 @@ class LanguageModelSAETrainingRunner:
         self._compile_if_needed()
         sae = self.run_trainer_with_interruption_handling(trainer)

+        if self.cfg.output_path is not None:
+            self.save_final_sae(
+                sae=sae,
+                output_path=self.cfg.output_path,
+                log_feature_sparsity=trainer.log_feature_sparsity,
+            )
+
         if self.cfg.logger.log_to_wandb:
             wandb.finish()

         return sae

+    def save_final_sae(
+        self,
+        sae: TrainingSAE[Any],
+        output_path: str,
+        log_feature_sparsity: torch.Tensor | None = None,
+    ):
+        base_output_path = Path(output_path)
+        base_output_path.mkdir(exist_ok=True, parents=True)
+
+        weights_path, cfg_path = sae.save_inference_model(str(base_output_path))
+
+        sparsity_path = None
+        if log_feature_sparsity is not None:
+            sparsity_path = base_output_path / SPARSITY_FILENAME
+            save_file({"sparsity": log_feature_sparsity}, sparsity_path)
+
+        runner_config = self.cfg.to_dict()
+        with open(base_output_path / RUNNER_CFG_FILENAME, "w") as f:
+            json.dump(runner_config, f)
+
+        if self.cfg.logger.log_to_wandb:
+            self.cfg.logger.log(
+                self,
+                weights_path,
+                cfg_path,
+                sparsity_path=sparsity_path,
+                wandb_aliases=["final_model"],
+            )
+
     def _set_sae_metadata(self):
         self.sae.cfg.metadata.dataset_path = self.cfg.dataset_path
         self.sae.cfg.metadata.hook_name = self.cfg.hook_name

@@ -247,20 +288,24 @@ class LanguageModelSAETrainingRunner:
             sae = trainer.fit()

         except (KeyboardInterrupt, InterruptedException):
-            logger.warning("interrupted, saving progress")
-            checkpoint_path = Path(self.cfg.checkpoint_path) / str(
-                trainer.n_training_samples
-            )
-            self.save_checkpoint(checkpoint_path)
-            logger.info("done saving")
+            if self.cfg.checkpoint_path is not None:
+                logger.warning("interrupted, saving progress")
+                checkpoint_path = Path(self.cfg.checkpoint_path) / str(
+                    trainer.n_training_samples
+                )
+                self.save_checkpoint(checkpoint_path)
+                logger.info("done saving")
             raise

         return sae

     def save_checkpoint(
         self,
-        checkpoint_path: Path,
+        checkpoint_path: Path | None,
     ) -> None:
+        if checkpoint_path is None:
+            return
+
         self.activations_store.save(
             str(checkpoint_path / ACTIVATIONS_STORE_STATE_FILENAME)
         )
{sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/training/sae_trainer.py

@@ -22,6 +22,7 @@ from sae_lens.saes.sae import (
 from sae_lens.training.activation_scaler import ActivationScaler
 from sae_lens.training.optim import CoefficientScheduler, get_lr_scheduler
 from sae_lens.training.types import DataProvider
+from sae_lens.util import path_or_tmp_dir


 def _log_feature_sparsity(

@@ -40,7 +41,7 @@ def _update_sae_lens_training_version(sae: TrainingSAE[Any]) -> None:
 class SaveCheckpointFn(Protocol):
     def __call__(
         self,
-        checkpoint_path: Path,
+        checkpoint_path: Path | None,
     ) -> None: ...


@@ -187,12 +188,8 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
         )
         self.activation_scaler.scaling_factor = None

-
-        self.save_checkpoint(
-            checkpoint_name=f"final_{self.n_training_samples}",
-            wandb_aliases=["final_model"],
-            save_inference_model=True,
-        )
+        if self.cfg.save_final_checkpoint:
+            self.save_checkpoint(checkpoint_name=f"final_{self.n_training_samples}")

         pbar.close()
         return self.sae

@@ -201,32 +198,31 @@ class SAETrainer(Generic[T_TRAINING_SAE, T_TRAINING_SAE_CONFIG]):
         self,
         checkpoint_name: str,
         wandb_aliases: list[str] | None = None,
-        save_inference_model: bool = False,
     ) -> None:
-        checkpoint_path = Path(self.cfg.checkpoint_path) / checkpoint_name
-        checkpoint_path.mkdir(exist_ok=True, parents=True)
-
-        save_fn = (
-            self.sae.save_inference_model
-            if save_inference_model
-            else self.sae.save_model
-        )
-        weights_path, cfg_path = save_fn(str(checkpoint_path))
+        checkpoint_path = None
+        if self.cfg.checkpoint_path is not None or self.cfg.logger.log_to_wandb:
+            with path_or_tmp_dir(self.cfg.checkpoint_path) as base_checkpoint_path:
+                checkpoint_path = base_checkpoint_path / checkpoint_name
+                checkpoint_path.mkdir(exist_ok=True, parents=True)

-        sparsity_path = checkpoint_path / SPARSITY_FILENAME
-        save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)
+                weights_path, cfg_path = self.sae.save_model(str(checkpoint_path))

-        activation_scaler_path = checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
-        self.activation_scaler.save(str(activation_scaler_path))
+                sparsity_path = checkpoint_path / SPARSITY_FILENAME
+                save_file({"sparsity": self.log_feature_sparsity}, sparsity_path)

-        if self.cfg.logger.log_to_wandb:
-            self.cfg.logger.log(
-                self,
-                weights_path,
-                cfg_path,
-                sparsity_path=sparsity_path,
-                wandb_aliases=wandb_aliases,
-            )
+                activation_scaler_path = (
+                    checkpoint_path / ACTIVATION_SCALER_CFG_FILENAME
+                )
+                self.activation_scaler.save(str(activation_scaler_path))
+
+                if self.cfg.logger.log_to_wandb:
+                    self.cfg.logger.log(
+                        self,
+                        weights_path,
+                        cfg_path,
+                        sparsity_path=sparsity_path,
+                        wandb_aliases=wandb_aliases,
+                    )

         if self.save_checkpoint_fn is not None:
             self.save_checkpoint_fn(checkpoint_path=checkpoint_path)
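The reworked save_checkpoint only touches disk when there is somewhere for the files to go: a configured checkpoint_path, or a W&B run to upload to. When only W&B logging is enabled, path_or_tmp_dir stages the files in a temporary directory that disappears after the upload. A small illustrative sketch of that staging pattern follows, using a hypothetical stage_artifacts helper that is not part of the package (assumes sae-lens >= 6.10.0):

from sae_lens.util import path_or_tmp_dir

def stage_artifacts(configured_dir: str | None, upload: bool) -> None:
    # Nothing to persist and nothing to upload: skip all disk I/O.
    if configured_dir is None and not upload:
        return
    with path_or_tmp_dir(configured_dir) as base:
        base.mkdir(parents=True, exist_ok=True)  # path_or_tmp_dir does not create real dirs
        artifact = base / "artifact.txt"
        artifact.write_text("example")  # write checkpoint files here
        if upload:
            print(f"would upload {artifact}")  # e.g. hand the paths to a logger
    # With configured_dir=None, the staging directory (and artifact) is gone at this point.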
{sae_lens-6.9.1 → sae_lens-6.10.0}/sae_lens/util.py

@@ -1,5 +1,8 @@
 import re
+import tempfile
+from contextlib import contextmanager
 from dataclasses import asdict, fields, is_dataclass
+from pathlib import Path
 from typing import Sequence, TypeVar

 K = TypeVar("K")

@@ -45,3 +48,18 @@ def extract_layer_from_tlens_hook_name(hook_name: str) -> int | None:
     """
     hook_match = re.search(r"\.(\d+)\.", hook_name)
     return None if hook_match is None else int(hook_match.group(1))
+
+
+@contextmanager
+def path_or_tmp_dir(path: str | Path | None):
+    """Context manager that yields a concrete Path for path.
+
+    - If path is None, creates a TemporaryDirectory and yields its Path.
+      The directory is cleaned up on context exit.
+    - Otherwise, yields Path(path) without creating or cleaning.
+    """
+    if path is None:
+        with tempfile.TemporaryDirectory() as td:
+            yield Path(td)
+    else:
+        yield Path(path)
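A short usage sketch of the new helper, assuming sae-lens 6.10.0 is installed:

from pathlib import Path

from sae_lens.util import path_or_tmp_dir

# An explicit path is passed through unchanged and is neither created nor cleaned up.
with path_or_tmp_dir("checkpoints") as ckpt_dir:
    assert ckpt_dir == Path("checkpoints")

# None yields a freshly created temporary directory that is removed on exit; this is
# what SAETrainer.save_checkpoint relies on when only W&B logging is enabled.
with path_or_tmp_dir(None) as tmp_dir:
    (tmp_dir / "example.txt").write_text("hello")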